taxi

Winning entry to the Kaggle taxi competition
git clone https://esimon.eu/repos/taxi.git
Log | Files | Refs | README

mlp_emb.py (5366B)


      1 from theano import tensor
      2 
      3 from fuel.transformers import Batch, MultiProcessing
      4 from fuel.streams import DataStream
      5 from fuel.schemes import ConstantScheme, ShuffledExampleScheme
      6 from blocks.bricks import application, MLP, Rectifier, Initializable, Identity
      7 
      8 import error
      9 import data
     10 from data import transformers
     11 from data.hdf5 import TaxiDataset, TaxiStream
     12 from data.cut import TaxiTimeCutScheme
     13 from model import ContextEmbedder
     14 
     15 
     16 class Model(Initializable):
     17     def __init__(self, config, **kwargs):
     18         super(Model, self).__init__(**kwargs)
     19         self.config = config
     20 
     21         self.context_embedder = ContextEmbedder(config)
     22         self.mlp = MLP(activations=[Rectifier() for _ in config.dim_hidden] + [Identity()],
     23                        dims=[config.dim_input] + config.dim_hidden + [config.dim_output])
     24 
     25         self.inputs = self.context_embedder.inputs # + self.extremities.keys()
     26         self.children = [ self.context_embedder, self.mlp ]
     27 
     28     def _push_initialization_config(self):
     29         self.mlp.weights_init = self.config.mlp_weights_init
     30         self.mlp.biases_init = self.config.mlp_biases_init
     31 
     32     @application(outputs=['destination'])
     33     def predict(self, **kwargs):
     34         embeddings = tuple(self.context_embedder.apply(**{k: kwargs[k] for k in self.context_embedder.inputs }))
     35 
     36         inputs = tensor.concatenate(embeddings, axis=1)
     37         outputs = self.mlp.apply(inputs)
     38 
     39         if self.config.output_mode == "destination":
     40             return data.train_gps_std * outputs + data.train_gps_mean
     41         elif self.config.dim_output == "clusters":
     42             return tensor.dot(outputs, self.classes)
     43 
     44     @predict.property('inputs')
     45     def predict_inputs(self):
     46         return self.inputs
     47 
     48     @application(outputs=['cost'])
     49     def cost(self, **kwargs):
     50         y_hat = self.predict(**kwargs)
     51         y = tensor.concatenate((kwargs['destination_latitude'][:, None],
     52                                 kwargs['destination_longitude'][:, None]), axis=1)
     53 
     54         return error.erdist(y_hat, y).mean()
     55 
     56     @cost.property('inputs')
     57     def cost_inputs(self):
     58         return self.inputs + ['destination_latitude', 'destination_longitude']
     59 
     60 
     61 class Stream(object):
     62     def __init__(self, config):
     63         self.config = config
     64 
     65     def train(self, req_vars):
     66         valid = TaxiDataset(self.config.valid_set, 'valid.hdf5', sources=('trip_id',))
     67         valid_trips_ids = valid.get_data(None, slice(0, valid.num_examples))[0]
     68 
     69         stream = TaxiDataset('train')
     70 
     71         if hasattr(self.config, 'use_cuts_for_training') and self.config.use_cuts_for_training:
     72             stream = DataStream(stream, iteration_scheme=TaxiTimeCutScheme())
     73         else:
     74             stream = DataStream(stream, iteration_scheme=ShuffledExampleScheme(stream.num_examples))
     75 
     76         stream = transformers.TaxiExcludeTrips(stream, valid_trips_ids)
     77         stream = transformers.TaxiGenerateSplits(stream, max_splits=self.config.max_splits)
     78 
     79         stream = transformers.taxi_add_datetime(stream)
     80         # stream = transformers.taxi_add_first_last_len(stream, self.config.n_begin_end_pts)
     81         stream = transformers.Select(stream, tuple(req_vars))
     82         
     83         stream = Batch(stream, iteration_scheme=ConstantScheme(self.config.batch_size))
     84 
     85         stream = MultiProcessing(stream)
     86 
     87         return stream
     88 
     89     def valid(self, req_vars):
     90         stream = TaxiStream(self.config.valid_set, 'valid.hdf5')
     91 
     92         stream = transformers.taxi_add_datetime(stream)
     93         # stream = transformers.taxi_add_first_last_len(stream, self.config.n_begin_end_pts)
     94         stream = transformers.Select(stream, tuple(req_vars))
     95         return Batch(stream, iteration_scheme=ConstantScheme(1000))
     96 
     97     def test(self, req_vars):
     98         stream = TaxiStream('test')
     99         
    100         stream = transformers.taxi_add_datetime(stream)
    101         # stream = transformers.taxi_add_first_last_len(stream, self.config.n_begin_end_pts)
    102         stream = transformers.taxi_remove_test_only_clients(stream)
    103 
    104         return Batch(stream, iteration_scheme=ConstantScheme(1))
    105 
    106     def inputs(self):
    107         return {'call_type': tensor.bvector('call_type'),
    108                 'origin_call': tensor.ivector('origin_call'),
    109                 'origin_stand': tensor.bvector('origin_stand'),
    110                 'taxi_id': tensor.wvector('taxi_id'),
    111                 'timestamp': tensor.ivector('timestamp'),
    112                 'day_type': tensor.bvector('day_type'),
    113                 'missing_data': tensor.bvector('missing_data'),
    114                 'latitude': tensor.matrix('latitude'),
    115                 'longitude': tensor.matrix('longitude'),
    116                 'destination_latitude': tensor.vector('destination_latitude'),
    117                 'destination_longitude': tensor.vector('destination_longitude'),
    118                 'travel_time': tensor.ivector('travel_time'),
    119                 'first_k_latitude': tensor.matrix('first_k_latitude'),
    120                 'first_k_longitude': tensor.matrix('first_k_longitude'),
    121                 'last_k_latitude': tensor.matrix('last_k_latitude'),
    122                 'last_k_longitude': tensor.matrix('last_k_longitude'),
    123                 'input_time': tensor.ivector('input_time'),
    124                 'week_of_year': tensor.bvector('week_of_year'),
    125                 'day_of_week': tensor.bvector('day_of_week'),
    126                 'qhour_of_day': tensor.bvector('qhour_of_day')}