taxi

Winning entry to the Kaggle taxi competition
git clone https://esimon.eu/repos/taxi.git
Log | Files | Refs | README

make_valid_cut.py (2481B)


      1 #!/usr/bin/env python2
      2 # Make a valid dataset by cutting the training set at specified timestamps
      3 
      4 import os
      5 import sys
      6 import importlib
      7 
      8 import h5py
      9 import numpy
     10 
     11 import data
     12 from data.hdf5 import taxi_it
     13 
     14 
     15 _fields = ['trip_id', 'call_type', 'origin_call', 'origin_stand', 'taxi_id', 'timestamp', 'day_type', 'missing_data', 'latitude', 'longitude', 'destination_latitude', 'destination_longitude', 'travel_time']
     16 
     17 def make_valid(cutfile, outpath):
     18     cuts = importlib.import_module('.%s' % cutfile, 'data.cuts').cuts
     19 
     20     print "Number of cuts:", len(cuts)
     21 
     22     valid = []
     23 
     24     for line in taxi_it('train'):
     25         time = line['timestamp']
     26         latitude = line['latitude']
     27         longitude = line['longitude']
     28 
     29         if len(latitude) == 0:
     30             continue
     31 
     32         for ts in cuts:
     33             if time <= ts and time + 15 * (len(latitude) - 1) >= ts:
     34                 # keep it
     35                 n = (ts - time) / 15 + 1
     36                 line.update({
     37                     'latitude': latitude[:n],
     38                     'longitude': longitude[:n],
     39                     'destination_latitude': latitude[-1],
     40                     'destination_longitude': longitude[-1],
     41                     'travel_time': 15 * (len(latitude)-1)
     42                 })
     43                 valid.append(line)
     44                 break
     45 
     46     print "Number of trips in validation set:", len(valid)
     47     
     48     file = h5py.File(outpath, 'a')
     49     clen = file['trip_id'].shape[0]
     50     alen = len(valid)
     51     for field in _fields:
     52         dset = file[field]
     53         dset.resize((clen + alen,))
     54         for i in xrange(alen):
     55             dset[clen + i] = valid[i][field]
     56 
     57     splits = file.attrs['split']
     58     slen = splits.shape[0]
     59     splits = numpy.resize(splits, (slen+len(_fields),))
     60     for (i, field) in enumerate(_fields):
     61         splits[slen+i]['split'] = ('cuts/%s' % cutfile).encode('utf8')
     62         splits[slen+i]['source'] = field.encode('utf8')
     63         splits[slen+i]['start'] = clen
     64         splits[slen+i]['stop'] = alen
     65         splits[slen+i]['indices'] = None
     66         splits[slen+i]['available'] = True
     67         splits[slen+i]['comment'] = '.'
     68     file.attrs['split'] = splits
     69 
     70     file.flush()
     71     file.close()
     72 
     73 if __name__ == '__main__':
     74     if len(sys.argv) < 2 or len(sys.argv) > 3:
     75         print >> sys.stderr, 'Usage: %s cutfile [outfile]' % sys.argv[0]
     76         sys.exit(1)
     77     outpath = os.path.join(data.path, 'valid.hdf5') if len(sys.argv) < 3 else sys.argv[2]
     78     make_valid(sys.argv[1], outpath)