taxi

Winning entry to the Kaggle taxi competition
git clone https://esimon.eu/repos/taxi.git
Log | Files | Refs | README

commit b61a411fbcb98b09ee83f8dd124113c6d7f47737
parent 028402e4a2fafc39cc8fc0e036e79017b9f9c26a
Author: Étienne Simon <esimon@esimon.eu>
Date:   Thu, 23 Jul 2015 12:43:55 -0400

Fix rnn validation

Diffstat:
Mmodel/rnn.py | 2+-
1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/model/rnn.py b/model/rnn.py @@ -134,7 +134,7 @@ class RNN(Initializable): @application(outputs=['cost']) def valid_cost(self, **kwargs): last_id = tensor.cast(kwargs['latitude_mask'].sum(axis=1) - 1, dtype='int64') - return self.cost_matrix(**kwargs)[last_id, tensor.arange(kwargs['latitude_mask'].shape[1])].mean() + return self.cost_matrix(**kwargs)[last_id, tensor.arange(kwargs['latitude_mask'].shape[0])].mean() @valid_cost.property('inputs') def valid_cost_inputs(self):