taxi

Winning entry to the Kaggle taxi competition
git clone https://esimon.eu/repos/taxi.git
Log | Files | Refs | README

stream.py (4579B)


      1 from fuel.transformers import Batch, Padding, Mapping, SortMapping, Unpack, MultiProcessing, Filter
      2 from fuel.streams import DataStream
      3 from fuel.schemes import ConstantScheme, ShuffledExampleScheme
      4 
      5 from theano import tensor
      6 
      7 import data
      8 from data import transformers
      9 from data.hdf5 import TaxiDataset, TaxiStream
     10 
     11 
     12 class StreamRec(object):
     13     def __init__(self, config):
     14         self.config = config
     15 
     16     def train(self, req_vars):
     17         stream = TaxiDataset('train', data.traintest_ds)
     18 
     19         if hasattr(self.config, 'use_cuts_for_training') and self.config.use_cuts_for_training:
     20             stream = DataStream(stream, iteration_scheme=TaxiTimeCutScheme())
     21         else:
     22             stream = DataStream(stream, iteration_scheme=ShuffledExampleScheme(stream.num_examples))
     23 
     24         if not data.tvt:
     25             valid = TaxiDataset(data.valid_set, data.valid_ds, sources=('trip_id',))
     26             valid_trips_ids = valid.get_data(None, slice(0, valid.num_examples))[0]
     27             stream = transformers.TaxiExcludeTrips(stream, valid_trips_ids)
     28 
     29         if hasattr(self.config, 'max_splits'):
     30             stream = transformers.TaxiGenerateSplits(stream, max_splits=self.config.max_splits)
     31         elif not data.tvt:
     32             stream = transformers.add_destination(stream)
     33 
     34         if hasattr(self.config, 'train_max_len'):
     35             idx = stream.sources.index('latitude')
     36             def max_len_filter(x):
     37                 return len(x[idx]) <= self.config.train_max_len
     38             stream = Filter(stream, max_len_filter)
     39 
     40         stream = transformers.TaxiExcludeEmptyTrips(stream)
     41         stream = transformers.taxi_add_datetime(stream)
     42         stream = transformers.Select(stream, tuple(v for v in req_vars if not v.endswith('_mask')))
     43 
     44         stream = transformers.balanced_batch(stream, key='latitude',
     45                                              batch_size=self.config.batch_size,
     46                                              batch_sort_size=self.config.batch_sort_size)
     47         stream = Padding(stream, mask_sources=['latitude', 'longitude'])
     48         stream = transformers.Select(stream, req_vars)
     49         stream = MultiProcessing(stream)
     50 
     51         return stream
     52 
     53     def valid(self, req_vars):
     54         stream = TaxiStream(data.valid_set, data.valid_ds)
     55 
     56         stream = transformers.taxi_add_datetime(stream)
     57         stream = transformers.Select(stream, tuple(v for v in req_vars if not v.endswith('_mask')))
     58 
     59         stream = transformers.balanced_batch(stream, key='latitude',
     60                                              batch_size=self.config.batch_size,
     61                                              batch_sort_size=self.config.batch_sort_size)
     62         stream = Padding(stream, mask_sources=['latitude', 'longitude'])
     63         stream = transformers.Select(stream, req_vars)
     64         stream = MultiProcessing(stream)
     65 
     66         return stream
     67 
     68     def test(self, req_vars):
     69         stream = TaxiStream('test', data.traintest_ds)
     70         
     71         stream = transformers.taxi_add_datetime(stream)
     72         stream = transformers.taxi_remove_test_only_clients(stream)
     73 
     74         stream = transformers.Select(stream, tuple(v for v in req_vars if not v.endswith('_mask')))
     75 
     76         stream = Batch(stream, iteration_scheme=ConstantScheme(self.config.batch_size))
     77         stream = Padding(stream, mask_sources=['latitude', 'longitude'])
     78         stream = transformers.Select(stream, req_vars)
     79         return stream
     80 
     81     def inputs(self):
     82         return {'call_type': tensor.bvector('call_type'),
     83                 'origin_call': tensor.ivector('origin_call'),
     84                 'origin_stand': tensor.bvector('origin_stand'),
     85                 'taxi_id': tensor.wvector('taxi_id'),
     86                 'timestamp': tensor.ivector('timestamp'),
     87                 'day_type': tensor.bvector('day_type'),
     88                 'missing_data': tensor.bvector('missing_data'),
     89                 'latitude': tensor.matrix('latitude'),
     90                 'longitude': tensor.matrix('longitude'),
     91                 'latitude_mask': tensor.matrix('latitude_mask'),
     92                 'longitude_mask': tensor.matrix('longitude_mask'),
     93                 'destination_latitude': tensor.vector('destination_latitude'),
     94                 'destination_longitude': tensor.vector('destination_longitude'),
     95                 'travel_time': tensor.ivector('travel_time'),
     96                 'input_time': tensor.ivector('input_time'),
     97                 'week_of_year': tensor.bvector('week_of_year'),
     98                 'day_of_week': tensor.bvector('day_of_week'),
     99                 'qhour_of_day': tensor.bvector('qhour_of_day')}
    100