taxi

Winning entry to the Kaggle taxi competition
git clone https://esimon.eu/repos/taxi.git
Log | Files | Refs | README

init_valid.py (1680B)


      1 #!/usr/bin/env python2
      2 # Initialize the valid hdf5
      3 
      4 import os
      5 import sys
      6 
      7 import h5py
      8 import numpy
      9 
     10 import data
     11 
     12 
     13 _fields = {
     14     'trip_id': 'S19',
     15     'call_type': numpy.int8,
     16     'origin_call': numpy.int32,
     17     'origin_stand': numpy.int8,
     18     'taxi_id': numpy.int16,
     19     'timestamp': numpy.int32,
     20     'day_type': numpy.int8,
     21     'missing_data': numpy.bool,
     22     'latitude': data.Polyline,
     23     'longitude': data.Polyline,
     24     'destination_latitude': numpy.float32,
     25     'destination_longitude': numpy.float32,
     26     'travel_time': numpy.int32,
     27 }
     28 
     29 
     30 def init_valid(path):
     31     h5file = h5py.File(path, 'w')
     32     
     33     for k, v in _fields.iteritems():
     34         h5file.create_dataset(k, (0,), dtype=v, maxshape=(None,))
     35 
     36     split_array = numpy.empty(len(_fields), dtype=numpy.dtype([
     37         ('split', 'a', 64),
     38         ('source', 'a', 21),
     39         ('start', numpy.int64, 1),
     40         ('stop', numpy.int64, 1),
     41         ('indices', h5py.special_dtype(ref=h5py.Reference)),
     42         ('available', numpy.bool, 1),
     43         ('comment', 'a', 1)]))
     44 
     45     split_array[:]['split'] = 'dummy'.encode('utf8')
     46     for (i, k) in enumerate(_fields.keys()):
     47         split_array[i]['source'] = k.encode('utf8')
     48     split_array[:]['start'] = 0
     49     split_array[:]['stop'] = 0
     50     split_array[:]['available'] = False
     51     split_array[:]['indices'] = None
     52     split_array[:]['comment'] = '.'.encode('utf8')
     53     h5file.attrs['split'] = split_array
     54 
     55     h5file.flush()
     56     h5file.close()
     57 
     58 if __name__ == '__main__':
     59     if len(sys.argv) > 2:
     60         print >> sys.stderr, 'Usage: %s [file]' % sys.argv[0]
     61         sys.exit(1)
     62     init_valid(sys.argv[1] if len(sys.argv) == 2 else os.path.join(data.path, 'valid.hdf5'))