taxi

Winning entry to the Kaggle taxi competition
git clone https://esimon.eu/repos/taxi.git
Log | Files | Refs | README

commit 0be3ebaa19f2cf8a630565434e785e5c24929a14
parent 5589a8af8967cfc73d3b6fda8f86acc0d08172b8
Author: Étienne Simon <esimon@esimon.eu>
Date:   Fri, 24 Apr 2015 16:37:49 -0400

Make TaxiData accept multiple files

Diffstat:
M data.py | 48 +++++++++++++++++++++++++++++++++++++-----------
1 file changed, 37 insertions(+), 11 deletions(-)

diff --git a/data.py b/data.py
@@ -55,31 +55,57 @@ class DayType(Enum):
 
 class TaxiData(Dataset):
     provides_sources= ("trip_id","call_type","origin_call","origin_stand","taxi_id","timestamp","day_type","missing_data","polyline")
-    example_iteration_scheme=None
 
-    def __init__(self, path):
-        self.path=path
+    class State:
+        __slots__ = ('file', 'index', 'reader')
+
+    def __init__(self, pathes, has_header=False):
+        if not isinstance(pathes, list):
+            pathes=[pathes]
+        assert len(pathes)
+        self.pathes=pathes
+        self.has_header=has_header
         super(TaxiData, self).__init__()
 
     def open(self):
-        file=open(self.path)
-        reader=csv.reader(file)
-        reader.next() # Skip header
-        return (file, reader)
+        state=self.State()
+        state.file=open(self.pathes[0])
+        state.index=0
+        state.reader=csv.reader(state.file)
+        if self.has_header:
+            state.reader.next()
+        return state
 
     def close(self, state):
-        state[0].close()
+        state.file.close()
 
     def reset(self, state):
-        state[0].seek(0)
-        state[1]=csv.reader(state[0])
+        if state.index==0:
+            state.file.seek(0)
+        else:
+            state.index=0
+            state.file.close()
+            state.file=open(self.pathes[0])
+        state.reader=csv.reader(state[0])
         return state
 
     def get_data(self, state, request=None):
         if request is not None:
             raise ValueError
 
-        line=state[1].next()
+        try:
+            line=state.reader.next()
+        except StopIteration:
+            state.file.close()
+            state.index+=1
+            if state.index>=len(self.pathes):
+                raise
+            state.file=open(self.pathes[state.index])
+            state.reader=csv.reader(state.file)
+            if self.has_header:
+                state.reader.next()
+            line=state.reader.next()
+
         line[1]=CallType.from_data(line[1]) # call_type
         line[2]=0 if line[2]=='' or line[2]=='NA' else int(line[2]) # origin_call
         line[3]=0 if line[3]=='' or line[3]=='NA' else int(line[3]) # origin_stand