taxi

Winning entry to the Kaggle taxi competition
git clone https://esimon.eu/repos/taxi.git
Log | Files | Refs | README

commit bd08e452093bba68fe2d79b1e9da76488b203720
parent ad5a03c6f60e5b2d543326bf8917b48e5b390b82
Author: Étienne Simon <esimon@esimon.eu>
Date:   Mon, 22 Jun 2015 14:40:19 -0400

Update memory network

Diffstat:
Mconfig/memory_network_1.py | 5+++--
Mdata/cut.py | 5+++--
Mmodel/memory_network.py | 2+-
3 files changed, 7 insertions(+), 5 deletions(-)

diff --git a/config/memory_network_1.py b/config/memory_network_1.py @@ -21,13 +21,13 @@ class MLPConfig(object): prefix_encoder = MLPConfig() prefix_encoder.dim_input = n_begin_end_pts * 2 * 2 + sum(x for (_, _, x) in dim_embeddings) -prefix_encoder.dim_hidden = [50] +prefix_encoder.dim_hidden = [100, 100, 100] prefix_encoder.weights_init = IsotropicGaussian(0.01) prefix_encoder.biases_init = Constant(0.001) candidate_encoder = MLPConfig() candidate_encoder.dim_input = n_begin_end_pts * 2 * 2 + sum(x for (_, _, x) in dim_embeddings) -candidate_encoder.dim_hidden = [50] +candidate_encoder.dim_hidden = [100, 100, 100] candidate_encoder.weights_init = IsotropicGaussian(0.01) candidate_encoder.biases_init = Constant(0.001) @@ -38,6 +38,7 @@ batch_size = 32 valid_set = 'cuts/test_times_0' max_splits = 1 +num_cuts = 1000 train_candidate_size = 1000 valid_candidate_size = 10000 diff --git a/data/cut.py b/data/cut.py @@ -11,14 +11,15 @@ last_time = 1404172787 class TaxiTimeCutScheme(IterationScheme): - def __init__(self, dbfile=None, use_cuts=None): + def __init__(self, num_cuts=100, dbfile=None, use_cuts=None): + self.num_cuts = num_cuts self.dbfile = os.path.join(data.path, 'time_index.db') if dbfile == None else dbfile self.use_cuts = use_cuts def get_request_iterator(self): cuts = self.use_cuts if cuts == None: - cuts = [random.randrange(first_time, last_time) for _ in range(100)] + cuts = [random.randrange(first_time, last_time) for _ in range(self.num_cuts)] l = [] with sqlite3.connect(self.dbfile) as db: diff --git a/model/memory_network.py b/model/memory_network.py @@ -88,7 +88,7 @@ class Stream(object): dataset = TaxiDataset('train') - prefix_stream = DataStream(dataset, iteration_scheme=TaxiTimeCutScheme()) + prefix_stream = DataStream(dataset, iteration_scheme=TaxiTimeCutScheme(self.config.num_cuts)) prefix_stream = transformers.TaxiExcludeTrips(prefix_stream, valid_trips_ids) prefix_stream = transformers.TaxiGenerateSplits(prefix_stream, max_splits=self.config.max_splits) prefix_stream = transformers.taxi_add_datetime(prefix_stream)