taxi

Winning entry to the Kaggle taxi competition
git clone https://esimon.eu/repos/taxi.git
Log | Files | Refs | README

csv_to_hdf5.py (5189B)


      1 #!/usr/bin/env python2
      2 
      3 import ast
      4 import csv
      5 import os
      6 import sys
      7 
      8 import h5py
      9 import numpy
     10 from fuel.converters.base import fill_hdf5_file
     11 
     12 import data
     13 
     14 
     15 taxi_id_dict = {}
     16 origin_call_dict = {0: 0}
     17 
     18 def get_unique_taxi_id(val):
     19     if val in taxi_id_dict:
     20         return taxi_id_dict[val]
     21     else:
     22         taxi_id_dict[val] = len(taxi_id_dict)
     23         return len(taxi_id_dict) - 1
     24 
     25 def get_unique_origin_call(val):
     26     if val in origin_call_dict:
     27         return origin_call_dict[val]
     28     else:
     29         origin_call_dict[val] = len(origin_call_dict)
     30         return len(origin_call_dict) - 1
     31 
     32 def read_stands(input_directory, h5file):
     33     stands_name = numpy.empty(shape=(data.stands_size,), dtype=('a', 24))
     34     stands_latitude = numpy.empty(shape=(data.stands_size,), dtype=numpy.float32)
     35     stands_longitude = numpy.empty(shape=(data.stands_size,), dtype=numpy.float32)
     36     stands_name[0] = 'None'
     37     stands_latitude[0] = stands_longitude[0] = 0
     38     with open(os.path.join(input_directory, 'metaData_taxistandsID_name_GPSlocation.csv'), 'r') as f:
     39         reader = csv.reader(f)
     40         reader.next() # header
     41         for line in reader:
     42             id = int(line[0])
     43             stands_name[id] = line[1]
     44             stands_latitude[id] = float(line[2])
     45             stands_longitude[id] = float(line[3])
     46     return (('stands', 'stands_name', stands_name),
     47             ('stands', 'stands_latitude', stands_latitude),
     48             ('stands', 'stands_longitude', stands_longitude))
     49 
     50 def read_taxis(input_directory, h5file, dataset):
     51     print >> sys.stderr, 'read %s: begin' % dataset
     52     size=getattr(data, '%s_size'%dataset)
     53     trip_id = numpy.empty(shape=(size,), dtype='S19')
     54     call_type = numpy.empty(shape=(size,), dtype=numpy.int8)
     55     origin_call = numpy.empty(shape=(size,), dtype=numpy.int32)
     56     origin_stand = numpy.empty(shape=(size,), dtype=numpy.int8)
     57     taxi_id = numpy.empty(shape=(size,), dtype=numpy.int16)
     58     timestamp = numpy.empty(shape=(size,), dtype=numpy.int32)
     59     day_type = numpy.empty(shape=(size,), dtype=numpy.int8)
     60     missing_data = numpy.empty(shape=(size,), dtype=numpy.bool)
     61     latitude = numpy.empty(shape=(size,), dtype=data.Polyline)
     62     longitude = numpy.empty(shape=(size,), dtype=data.Polyline)
     63     with open(os.path.join(input_directory, '%s.csv'%dataset), 'r') as f:
     64         reader = csv.reader(f)
     65         reader.next() # header
     66         id=0
     67         for line in reader:
     68             if id%10000==0 and id!=0:
     69                 print >> sys.stderr, 'read %s: %d done' % (dataset, id)
     70             trip_id[id] = line[0]
     71             call_type[id] = ord(line[1][0]) - ord('A')
     72             origin_call[id] = 0 if line[2]=='NA' or line[2]=='' else get_unique_origin_call(int(line[2]))
     73             origin_stand[id] = 0 if line[3]=='NA' or line[3]=='' else int(line[3])
     74             taxi_id[id] = get_unique_taxi_id(int(line[4]))
     75             timestamp[id] = int(line[5])
     76             day_type[id] = ord(line[6][0]) - ord('A')
     77             missing_data[id] = line[7][0] == 'T'
     78             polyline = ast.literal_eval(line[8])
     79             latitude[id] = numpy.array([point[1] for point in polyline], dtype=numpy.float32)
     80             longitude[id] = numpy.array([point[0] for point in polyline], dtype=numpy.float32)
     81             id+=1
     82     splits = ()
     83     print >> sys.stderr, 'read %s: writing' % dataset
     84     for name in ['trip_id', 'call_type', 'origin_call', 'origin_stand', 'taxi_id', 'timestamp', 'day_type', 'missing_data', 'latitude', 'longitude']:
     85         splits += ((dataset, name, locals()[name]),)
     86     print >> sys.stderr, 'read %s: end' % dataset
     87     return splits
     88 
     89 def unique(h5file):
     90     unique_taxi_id = numpy.empty(shape=(data.taxi_id_size,), dtype=numpy.int32)
     91     assert len(taxi_id_dict) == data.taxi_id_size
     92     for k, v in taxi_id_dict.items():
     93         unique_taxi_id[v] = k
     94 
     95     unique_origin_call = numpy.empty(shape=(data.origin_call_size,), dtype=numpy.int32)
     96     assert len(origin_call_dict) == data.origin_call_size
     97     for k, v in origin_call_dict.items():
     98         unique_origin_call[v] = k
     99 
    100     return (('unique_taxi_id', 'unique_taxi_id', unique_taxi_id),
    101             ('unique_origin_call', 'unique_origin_call', unique_origin_call))
    102 
    103 def convert(input_directory, save_path):
    104     h5file = h5py.File(save_path, 'w')
    105     split = ()
    106     split += read_stands(input_directory, h5file)
    107     split += read_taxis(input_directory, h5file, 'train')
    108     print 'First origin_call not present in training set: ', len(origin_call_dict)
    109     split += read_taxis(input_directory, h5file, 'test')
    110     split += unique(h5file)
    111 
    112     fill_hdf5_file(h5file, split)
    113 
    114     for name in ['stands_name', 'stands_latitude', 'stands_longitude', 'unique_taxi_id', 'unique_origin_call']:
    115         h5file[name].dims[0].label = 'index'
    116     for name in ['trip_id', 'call_type', 'origin_call', 'origin_stand', 'taxi_id', 'timestamp', 'day_type', 'missing_data', 'latitude', 'longitude']:
    117         h5file[name].dims[0].label = 'batch'
    118 
    119     h5file.flush()
    120     h5file.close()
    121 
    122 if __name__ == '__main__':
    123     if len(sys.argv) != 3:
    124         print >> sys.stderr, 'Usage: %s download_dir output_file' % sys.argv[0]
    125         sys.exit(1)
    126     convert(sys.argv[1], sys.argv[2])