taxi

Winning entry to the Kaggle taxi competition
git clone https://esimon.eu/repos/taxi.git
Log | Files | Refs | README

commit dd8ae5ea8ed0c7cb1a7880b1e1887c6e23cdf910
parent 9013799ebca1c426c3c3e9019eb71018b253b025
Author: Alex Auvolat <alex.auvolat@ens.fr>
Date:   Thu, 23 Jul 2015 10:07:38 -0400

Fix RNN prediction function

Diffstat:
Mmodel/rnn.py | 4++--
1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/model/rnn.py b/model/rnn.py @@ -99,7 +99,7 @@ class RNN(Initializable): res = self.predict_all(**kwargs)[0] last_id = tensor.cast(kwargs['latitude_mask'].sum(axis=0) - 1, dtype='int64') - return res[last_id] + return res[last_id, tensor.arange(kwargs['latitude_mask'].shape[1])] @predict.property('inputs') def predict_inputs(self): @@ -134,7 +134,7 @@ class RNN(Initializable): @application(outputs=['cost']) def valid_cost(self, **kwargs): last_id = tensor.cast(kwargs['latitude_mask'].sum(axis=1) - 1, dtype='int64') - return self.cost_matrix(**kwargs)[last_id].mean() + return self.cost_matrix(**kwargs)[last_id, tensor.arange(kwargs['latitude_mask'].shape[1])].mean() @valid_cost.property('inputs') def valid_cost_inputs(self):