commit 712035b88be1816d3fbd58ce69ae6464767c780e
parent 66159d9fce0129116e82e74cf3eb1d9e048b253d
Author: Alex Auvolat <alex.auvolat@ens.fr>
Date:   Tue,  5 May 2015 13:11:18 -0400
Add day type and taxi id
Diffstat:
5 files changed, 75 insertions(+), 9 deletions(-)
diff --git a/config/simple_mlp_2_cswdt.py b/config/simple_mlp_2_cswdt.py
@@ -0,0 +1,25 @@
+import model.simple_mlp as model
+
+import data
+
+n_begin_end_pts = 5     # how many points we consider at the beginning and end of the known trajectory
+n_end_pts = 5
+
+n_valid = 1000
+
+dim_embeddings = [
+    ('origin_call', data.n_train_clients+1, 10),
+    ('origin_stand', data.n_stands+1, 10),
+    ('week_of_year', 52, 10),
+    ('day_of_week', 7, 10),
+    ('qhour_of_day', 24 * 4, 10),
+    ('day_type', 3, 10),
+]
+
+dim_input = n_begin_end_pts * 2 * 2 + sum(x for (_, _, x) in dim_embeddings)
+dim_hidden = [200, 100]
+dim_output = 2
+
+learning_rate = 0.0001
+momentum = 0.99
+batch_size = 32
diff --git a/config/simple_mlp_tgtcls_1_cswdt.py b/config/simple_mlp_tgtcls_1_cswdt.py
@@ -14,9 +14,10 @@ with open(data.DATA_PATH + "/arrival-clusters.pkl") as f: tgtcls = cPickle.load(
 dim_embeddings = [
     ('origin_call', data.n_train_clients+1, 10),
     ('origin_stand', data.n_stands+1, 10),
-    ('week_of_year', 53, 10),
+    ('week_of_year', 52, 10),
     ('day_of_week', 7, 10),
-    ('qhour_of_day', 24 * 4, 10)
+    ('qhour_of_day', 24 * 4, 10),
+    ('day_type', 3, 10),
 ]
 
 dim_input = n_begin_end_pts * 2 * 2 + sum(x for (_, _, x) in dim_embeddings)
diff --git a/config/simple_mlp_tgtcls_1_cswdtx.py b/config/simple_mlp_tgtcls_1_cswdtx.py
@@ -0,0 +1,30 @@
+import cPickle
+
+import data
+
+import model.simple_mlp_tgtcls as model
+
+n_begin_end_pts = 5     # how many points we consider at the beginning and end of the known trajectory
+n_end_pts = 5
+
+n_valid = 1000
+
+with open(data.DATA_PATH + "/arrival-clusters.pkl") as f: tgtcls = cPickle.load(f)
+
+dim_embeddings = [
+    ('origin_call', data.n_train_clients+1, 10),
+    ('origin_stand', data.n_stands+1, 10),
+    ('week_of_year', 52, 10),
+    ('day_of_week', 7, 10),
+    ('qhour_of_day', 24 * 4, 10),
+    ('day_type', 3, 10),
+    ('taxi_id', 448, 10),
+]
+
+dim_input = n_begin_end_pts * 2 * 2 + sum(x for (_, _, x) in dim_embeddings)
+dim_hidden = [500]
+dim_output = tgtcls.shape[0]
+
+learning_rate = 0.0001
+momentum = 0.99
+batch_size = 32
diff --git a/data.py b/data.py
@@ -30,9 +30,7 @@ dataset_size = 1710670
 def make_client_ids():
     f = h5py.File(H5DATA_PATH, "r")
     l = f['unique_origin_call']
-    r = {}
-    for i in range(l.shape[0]):
-        r[l[i]] = i
+    r = {l[i]: i for i in range(l.shape[0])}
     return r
 
 client_ids = make_client_ids()
@@ -43,6 +41,18 @@ def get_client_id(n):
     else:
         return 0
 
+# ---- Read taxi IDs and create reverse dictionnary
+
+def make_taxi_ids():
+    f = h5py.File(H5DATA_PATH, "r")
+    l = f['unique_taxi_id']
+    r = {l[i]: i for i in range(l.shape[0])}
+    return r
+
+taxi_ids = make_taxi_ids()
+        
+# ---- Enum types
+
 class CallType(Enum):
     CENTRAL = 0
     STAND = 1
@@ -154,9 +164,9 @@ taxi_columns = [
     ("call_type", lambda l: CallType.from_data(l[1])),
     ("origin_call", lambda l: 0 if l[2] == '' or l[2] == 'NA' else get_client_id(int(l[2]))),
     ("origin_stand", lambda l: 0 if l[3] == '' or l[3] == 'NA' else int(l[3])),
-    ("taxi_id", lambda l: int(l[4])),
+    ("taxi_id", lambda l: taxi_ids[int(l[4])]),
     ("timestamp", lambda l: int(l[5])),
-    ("day_type", lambda l: DayType.from_data(l[6])),
+    ("day_type", lambda l: ord(l[6])-ord('A')),
     ("missing_data", lambda l: l[7][0] == 'T'),
     ("polyline", lambda l: map(tuple, ast.literal_eval(l[8]))),
     ("longitude", lambda l: map(lambda p: p[0], ast.literal_eval(l[8]))),
diff --git a/transformers.py b/transformers.py
@@ -107,7 +107,8 @@ class TaxiAddDateTime(Transformer):
         data = next(self.child_epoch_iterator)
         ts = data[self.id_timestamp]
         date = datetime.datetime.utcfromtimestamp(ts)
-        info = (date.isocalendar()[1] - 1, date.weekday(), date.hour * 4 + date.minute / 15)
+        yearweek = date.isocalendar()[1] - 1
+        info = ((51 if yearweek == 52 else yearweek), date.weekday(), date.hour * 4 + date.minute / 15)
         return data + info
 
 class TaxiExcludeTrips(Transformer):
@@ -122,4 +123,3 @@ class TaxiExcludeTrips(Transformer):
             if not data[self.id_trip_id] in self.exclude: break
         return data
 
-