taxi

Winning entry to the Kaggle taxi competition
git clone https://esimon.eu/repos/taxi.git
Log | Files | Refs | README

rfc4180.py (3460B)


      1 import ast
      2 import csv
      3 import numpy
      4 import os
      5 
      6 from fuel.datasets import Dataset
      7 from fuel.streams import DataStream
      8 from fuel.iterator import DataIterator
      9 
     10 import data
     11 from data.hdf5 import origin_call_normalize, taxi_id_normalize
     12 
     13 
     14 class TaxiData(Dataset):
     15     example_iteration_scheme=None
     16 
     17     class State:
     18         __slots__ = ('file', 'index', 'reader')
     19 
     20     def __init__(self, pathes, columns, has_header=False):
     21         if not isinstance(pathes, list):
     22             pathes=[pathes]
     23         assert len(pathes)>0
     24         self.columns=columns
     25         self.provides_sources = tuple(map(lambda x: x[0], columns))
     26         self.pathes=pathes
     27         self.has_header=has_header
     28         super(TaxiData, self).__init__()
     29 
     30     def open(self):
     31         state=self.State()
     32         state.file=open(self.pathes[0])
     33         state.index=0
     34         state.reader=csv.reader(state.file)
     35         if self.has_header:
     36             state.reader.next()
     37         return state
     38 
     39     def close(self, state):
     40         state.file.close()
     41 
     42     def reset(self, state):
     43         if state.index==0:
     44             state.file.seek(0)
     45         else:
     46             state.index=0
     47             state.file.close()
     48             state.file=open(self.pathes[0])
     49         state.reader=csv.reader(state.file)
     50         return state
     51 
     52     def get_data(self, state, request=None):
     53         if request is not None:
     54             raise ValueError
     55         try:
     56             line=state.reader.next()
     57         except (ValueError, StopIteration):
     58             # print state.index
     59             state.file.close()
     60             state.index+=1
     61             if state.index>=len(self.pathes):
     62                 raise StopIteration
     63             state.file=open(self.pathes[state.index])
     64             state.reader=csv.reader(state.file)
     65             if self.has_header:
     66                 state.reader.next()
     67             return self.get_data(state)
     68 
     69         values = []
     70         for _, constructor in self.columns:
     71             values.append(constructor(line))
     72         return tuple(values)
     73 
     74 taxi_columns = [
     75     ("trip_id", lambda l: l[0]),
     76     ("call_type", lambda l: ord(l[1])-ord('A')),
     77     ("origin_call", lambda l: 0 if l[2] == '' or l[2] == 'NA' else origin_call_normalize(int(l[2]))),
     78     ("origin_stand", lambda l: 0 if l[3] == '' or l[3] == 'NA' else int(l[3])),
     79     ("taxi_id", lambda l: taxi_id_normalize(int(l[4]))),
     80     ("timestamp", lambda l: int(l[5])),
     81     ("day_type", lambda l: ord(l[6])-ord('A')),
     82     ("missing_data", lambda l: l[7][0] == 'T'),
     83     ("polyline", lambda l: map(tuple, ast.literal_eval(l[8]))),
     84     ("longitude", lambda l: map(lambda p: p[0], ast.literal_eval(l[8]))),
     85     ("latitude", lambda l: map(lambda p: p[1], ast.literal_eval(l[8]))),
     86 ]
     87 
     88 taxi_columns_valid = taxi_columns + [
     89     ("destination_longitude", lambda l: numpy.float32(float(l[9]))),
     90     ("destination_latitude", lambda l: numpy.float32(float(l[10]))),
     91     ("time", lambda l: int(l[11])),
     92 ]
     93 
     94 train_file = os.path.join(data.path, 'train.csv')
     95 valid_file = os.path.join(data.path, 'valid2-cut.csv')
     96 test_file = os.path.join(data.path, 'test.csv')
     97 
     98 train_data=TaxiData(train_file, taxi_columns, has_header=True)
     99 valid_data = TaxiData(valid_file, taxi_columns_valid)
    100 test_data = TaxiData(test_file, taxi_columns, has_header=True)
    101 
    102 with open(os.path.join(data.path, 'valid2-cut-ids.txt')) as f:
    103     valid_trips = [l for l in f]
    104 
    105 def train_it():
    106     return DataIterator(DataStream(train_data))
    107 
    108 def test_it():
    109     return DataIterator(DataStream(valid_data))