commit 0be3ebaa19f2cf8a630565434e785e5c24929a14
parent 5589a8af8967cfc73d3b6fda8f86acc0d08172b8
Author: Étienne Simon <esimon@esimon.eu>
Date:   Fri, 24 Apr 2015 16:37:49 -0400
Make TaxiData accept multiple files
Diffstat:
| M | data.py |  |  | 48 | +++++++++++++++++++++++++++++++++++++----------- | 
1 file changed, 37 insertions(+), 11 deletions(-)
diff --git a/data.py b/data.py
@@ -55,31 +55,57 @@ class DayType(Enum):
 
 class TaxiData(Dataset):
     provides_sources= ("trip_id","call_type","origin_call","origin_stand","taxi_id","timestamp","day_type","missing_data","polyline")
-
     example_iteration_scheme=None
 
-    def __init__(self, path):
-        self.path=path
+    class State:
+        __slots__ = ('file', 'index', 'reader')
+
+    def __init__(self, pathes, has_header=False):
+        if not isinstance(pathes, list):
+            pathes=[pathes]
+        assert len(pathes)
+        self.pathes=pathes
+        self.has_header=has_header
         super(TaxiData, self).__init__()
 
     def open(self):
-        file=open(self.path)
-        reader=csv.reader(file)
-        reader.next() # Skip header
-        return (file, reader)
+        state=self.State()
+        state.file=open(self.pathes[0])
+        state.index=0
+        state.reader=csv.reader(state.file)
+        if self.has_header:
+            state.reader.next()
+        return state
 
     def close(self, state):
-        state[0].close()
+        state.file.close()
 
     def reset(self, state):
-        state[0].seek(0)
-        state[1]=csv.reader(state[0])
+        if state.index==0:
+            state.file.seek(0)
+        else:
+            state.index=0
+            state.file.close()
+            state.file=open(self.pathes[0])
+        state.reader=csv.reader(state[0])
         return state
 
     def get_data(self, state, request=None):
         if request is not None:
             raise ValueError
-        line=state[1].next()
+        try:
+            line=state.reader.next()
+        except StopIteration:
+            state.file.close()
+            state.index+=1
+            if state.index>=len(self.pathes):
+                raise
+            state.file=open(self.pathes[state.index])
+            state.reader=csv.reader(state.file)
+            if self.has_header:
+                state.reader.next()
+            line=state.reader.next()
+
         line[1]=CallType.from_data(line[1]) # call_type
         line[2]=0 if line[2]=='' or line[2]=='NA' else int(line[2]) # origin_call
         line[3]=0 if line[3]=='' or line[3]=='NA' else int(line[3]) # origin_stand