taxi

Winning entry to the Kaggle taxi competition
git clone https://esimon.eu/repos/taxi.git
Log | Files | Refs | README

rnn_tgtcls.py (604B)


      1 import numpy
      2 import theano
      3 from theano import tensor
      4 from blocks.bricks.base import lazy
      5 from blocks.bricks import Softmax
      6 
      7 from model.rnn import RNN, Stream
      8 
      9 
     10 class Model(RNN):
     11     @lazy()
     12     def __init__(self, config, **kwargs):
     13         super(Model, self).__init__(config, output_dim=config.tgtcls.shape[0], **kwargs)
     14         self.classes = theano.shared(numpy.array(config.tgtcls, dtype=theano.config.floatX), name='classes')
     15         self.softmax = Softmax()
     16         self.children.append(self.softmax)
     17 
     18     def process_rto(self, rto):
     19         return tensor.dot(self.softmax.apply(rto), self.classes)