taxi

Winning entry to the Kaggle taxi competition
git clone https://esimon.eu/repos/taxi.git
Log | Files | Refs | README

make_tvt.py (6962B)


      1 #!/usr/bin/env python2
      2 # Split the training set into Training, Validation and Test sets
      3 
      4 import os
      5 import sys
      6 import importlib
      7 import cPickle
      8 
      9 import h5py
     10 import numpy
     11 import theano
     12 
     13 import data
     14 from data.hdf5 import TaxiDataset
     15 from error import hdist
     16 
     17 
     18 native_fields = {
     19     'trip_id': 'S19',
     20     'call_type': numpy.int8,
     21     'origin_call': numpy.int32,
     22     'origin_stand': numpy.int8,
     23     'taxi_id': numpy.int16,
     24     'timestamp': numpy.int32,
     25     'day_type': numpy.int8,
     26     'missing_data': numpy.bool,
     27     'latitude': data.Polyline,
     28     'longitude': data.Polyline,
     29 }
     30 
     31 all_fields = {
     32     'path_len': numpy.int16,
     33     'cluster': numpy.int16,
     34     'destination_latitude': numpy.float32,
     35     'destination_longitude': numpy.float32,
     36     'travel_time': numpy.int32,
     37 }
     38 
     39 all_fields.update(native_fields)
     40 
     41 def cut_me_baby(train, cuts, excl={}):
     42     dset = {}
     43     cuts.sort()
     44     cut_id = 0
     45     for i in xrange(data.train_size):
     46         if i%10000==0 and i!=0:
     47             print >> sys.stderr, 'cut: {:d} done'.format(i)
     48         if i in excl:
     49             continue
     50         time = train['timestamp'][i]
     51         latitude = train['latitude'][i]
     52         longitude = train['longitude'][i]
     53 
     54         if len(latitude) == 0:
     55             continue
     56 
     57         end_time = time + 15 * (len(latitude) - 1)
     58 
     59         while cuts[cut_id] < time:
     60             if cut_id >= len(cuts)-1:
     61                 return dset
     62             cut_id += 1
     63 
     64         if end_time < cuts[cut_id]:
     65             continue
     66         else:
     67             dset[i] = (cuts[cut_id] - time) / 15 + 1
     68 
     69     return dset
     70 
def make_tvt(test_cuts_name, valid_cuts_name, outpath):
    """Split the original training set into train/valid/test and write
    them contiguously into a single HDF5 file.

    test_cuts_name, valid_cuts_name: module names under data.cuts whose
        `cuts` attribute lists the cut timestamps for each split.
    outpath: destination path of the HDF5 file.
    """
    trainset = TaxiDataset('train')
    traindata = trainset.get_data(None, slice(0, trainset.num_examples))
    # Sort every column by departure timestamp so cut_me_baby can sweep
    # through trips and cuts in a single forward pass.
    idsort = traindata[trainset.sources.index('timestamp')].argsort()

    traindata = dict(zip(trainset.sources, (t[idsort] for t in traindata)))

    print >> sys.stderr, 'test cut begin'
    test_cuts = importlib.import_module('.%s' % test_cuts_name, 'data.cuts').cuts
    test = cut_me_baby(traindata, test_cuts)

    print >> sys.stderr, 'valid cut begin'
    valid_cuts = importlib.import_module('.%s' % valid_cuts_name, 'data.cuts').cuts
    # Trips already claimed by the test split are excluded from valid.
    valid = cut_me_baby(traindata, valid_cuts, test)

    test_size = len(test)
    valid_size = len(valid)
    train_size = data.train_size - test_size - valid_size

    # Summary table on stdout (progress messages go to stderr).
    print ' set   | size    | ratio'
    print ' ----- | ------- | -----'
    print ' train | {:>7d} | {:>5.3f}'.format(train_size, float(train_size)/data.train_size)
    print ' valid | {:>7d} | {:>5.3f}'.format(valid_size, float(valid_size)/data.train_size)
    print ' test  | {:>7d} | {:>5.3f}'.format(test_size , float(test_size )/data.train_size)

    with open(os.path.join(data.path, 'arrival-clusters.pkl'), 'r') as f:
        clusters = cPickle.load(f)

    # Compile a Theano function mapping one (latitude, longitude) pair to
    # the index of the nearest arrival cluster under hdist (presumably a
    # haversine distance -- confirm in the error module).
    print >> sys.stderr, 'compiling cluster assignment function'
    latitude = theano.tensor.scalar('latitude')
    longitude = theano.tensor.scalar('longitude')
    coords = theano.tensor.stack(latitude, longitude).dimshuffle('x', 0)
    parent = theano.tensor.argmin(hdist(clusters, coords))
    cluster = theano.function([latitude, longitude], parent)

    # Clients (origin_call values) seen in the train split; used below to
    # anonymize unseen clients in valid/test.
    train_clients = set()

    print >> sys.stderr, 'preparing hdf5 data'
    hdata = {k: numpy.empty(shape=(data.train_size,), dtype=v) for k, v in all_fields.iteritems()}

    # The three splits are laid out contiguously: [train | valid | test].
    # Each split keeps its own running write index.
    train_i = 0
    valid_i = train_size
    test_i = train_size + valid_size

    print >> sys.stderr, 'write: begin'
    for idtraj in xrange(data.train_size):
        if idtraj%10000==0 and idtraj!=0:
            print >> sys.stderr, 'write: {:d} done'.format(idtraj)
        in_test = idtraj in test
        in_valid = not in_test and idtraj in valid
        in_train = not in_test and not in_valid

        # NOTE(review): these membership tests recompute in_test/in_valid
        # from just above; the flags could be reused.
        if idtraj in test:
            i = test_i
            test_i += 1
        elif idtraj in valid:
            i = valid_i
            valid_i += 1
        else:
            train_clients.add(traindata['origin_call'][idtraj])
            i = train_i
            train_i += 1

        trajlen = len(traindata['latitude'][idtraj])
        if trajlen == 0:
            # Empty trajectory: fall back to the mean training-set GPS point.
            hdata['destination_latitude'][i] = data.train_gps_mean[0]
            hdata['destination_longitude'][i] = data.train_gps_mean[1]
        else:
            # Destination = last recorded point of the FULL trajectory
            # (computed before any cut truncation below).
            hdata['destination_latitude'][i] = traindata['latitude'][idtraj][-1]
            hdata['destination_longitude'][i] = traindata['longitude'][idtraj][-1]
        # NOTE(review): stores the number of GPS points, not seconds;
        # consumers presumably rescale by the 15 s sampling period.
        hdata['travel_time'][i] = trajlen

        for field in native_fields:
            val = traindata[field][idtraj]
            if field in ['latitude', 'longitude']:
                # Valid/test paths are truncated at their cut point so the
                # stored prefix never reveals the part after the cut.
                if in_test:
                    val = val[:test[idtraj]]
                elif in_valid:
                    val = val[:valid[idtraj]]
            hdata[field][i] = val

        plen = len(hdata['latitude'][i])
        hdata['path_len'][i] = plen
        # Cluster of the first kept GPS point; -1 flags an empty path.
        hdata['cluster'][i] = -1 if plen==0 else cluster(hdata['latitude'][i][0], hdata['longitude'][i][0])

    print >> sys.stderr, 'write: end'

    # Clients never seen in the train split are mapped to 0 in valid/test
    # so origin_call remains a closed vocabulary.
    print >> sys.stderr, 'removing useless origin_call'
    for i in xrange(train_size, data.train_size):
        if hdata['origin_call'][i] not in train_clients:
            hdata['origin_call'][i] = 0

    print >> sys.stderr, 'preparing split array'

    # Split descriptor stored in the file's attrs: one row per
    # (split, source) pair -- appears to follow Fuel's H5PYDataset
    # 'split' convention (TODO confirm against data.hdf5.TaxiDataset).
    split_array = numpy.empty(len(all_fields)*3, dtype=numpy.dtype([
        ('split', 'a', 64),
        ('source', 'a', 21),
        ('start', numpy.int64, 1),
        ('stop', numpy.int64, 1),
        ('indices', h5py.special_dtype(ref=h5py.Reference)),
        ('available', numpy.bool, 1),
        ('comment', 'a', 1)]))

    flen = len(all_fields)
    for i, field in enumerate(all_fields):
        # Rows [0, flen) describe train, [flen, 2*flen) valid, the rest test.
        split_array[i]['split'] = 'train'.encode('utf8')
        split_array[i+flen]['split'] = 'valid'.encode('utf8')
        split_array[i+2*flen]['split'] = 'test'.encode('utf8')
        split_array[i]['start'] = 0
        split_array[i]['stop'] = train_size
        split_array[i+flen]['start'] = train_size
        split_array[i+flen]['stop'] = train_size + valid_size
        split_array[i+2*flen]['start'] = train_size + valid_size
        split_array[i+2*flen]['stop'] = train_size + valid_size + test_size

        for d in [0, flen, 2*flen]:
            split_array[i+d]['source'] = field.encode('utf8')

    split_array[:]['indices'] = None
    split_array[:]['available'] = True
    split_array[:]['comment'] = '.'.encode('utf8')

    print >> sys.stderr, 'writing hdf5 file'
    file = h5py.File(outpath, 'w')
    for k in all_fields.keys():
        # maxshape allows the datasets to be resized later if needed.
        file.create_dataset(k, data=hdata[k], maxshape=(data.train_size,))

    file.attrs['split'] = split_array

    file.flush()
    file.close()
    202 
    203 if __name__ == '__main__':
    204     if len(sys.argv) < 3 or len(sys.argv) > 4:
    205         print >> sys.stderr, 'Usage: %s test_cutfile valid_cutfile [outfile]' % sys.argv[0]
    206         sys.exit(1)
    207     outpath = os.path.join(data.path, 'tvt.hdf5') if len(sys.argv) < 4 else sys.argv[3]
    208     make_tvt(sys.argv[1], sys.argv[2], outpath)