taxi

Winning entry to the Kaggle taxi competition
git clone https://esimon.eu/repos/taxi.git
Log | Files | Refs | README

commit 5b496677ea1db59a6718e5c9b2958177c76cb25f
parent 95b565afb7e1c2a6eb23ca9f7c13cd6efaf55a39
Author: Alex Auvolat <alex.auvolat@ens.fr>
Date:   Tue,  5 May 2015 10:56:21 -0400

Refactor architecture so that embedding sizes can be easily changed.

Diffstat:
Dconfig/simple_mlp_0.py | 19-------------------
Aconfig/simple_mlp_2_cs.py | 25+++++++++++++++++++++++++
Aconfig/simple_mlp_2_noembed.py | 22++++++++++++++++++++++
Dconfig/simple_mlp_tgtcls_0.py | 25-------------------------
Aconfig/simple_mlp_tgtcls_0_cs.py | 29+++++++++++++++++++++++++++++
Dconfig/simple_mlp_tgtcls_1.py | 25-------------------------
Aconfig/simple_mlp_tgtcls_1_cs.py | 29+++++++++++++++++++++++++++++
Mmodel/simple_mlp.py | 30++++++++++++++++--------------
Mmodel/simple_mlp_tgtcls.py | 31+++++++++++++++++--------------
Mtrain.py | 38++++++++++++++++++++------------------
10 files changed, 158 insertions(+), 115 deletions(-)

diff --git a/config/simple_mlp_0.py b/config/simple_mlp_0.py @@ -1,19 +0,0 @@ -import model.simple_mlp as model - -n_dow = 7 # number of division for dayofweek/dayofmonth/hourofday -n_dom = 31 -n_hour = 24 - -n_begin_end_pts = 5 # how many points we consider at the beginning and end of the known trajectory -n_end_pts = 5 - -n_valid = 1000 - -dim_embed = 10 -dim_input = n_begin_end_pts * 2 * 2 + dim_embed + dim_embed -dim_hidden = [200, 100] -dim_output = 2 - -learning_rate = 0.0001 -momentum = 0.99 -batch_size = 32 diff --git a/config/simple_mlp_2_cs.py b/config/simple_mlp_2_cs.py @@ -0,0 +1,25 @@ +import model.simple_mlp as model + +import data + +n_dow = 7 # number of division for dayofweek/dayofmonth/hourofday +n_dom = 31 +n_hour = 24 + +n_begin_end_pts = 5 # how many points we consider at the beginning and end of the known trajectory +n_end_pts = 5 + +n_valid = 1000 + +dim_embeddings = [ + ('origin_call', data.n_train_clients+1, 10), + ('origin_stand', data.n_stands+1, 10) +] + +dim_input = n_begin_end_pts * 2 * 2 + sum(x for (_, _, x) in dim_embeddings) +dim_hidden = [200, 100] +dim_output = 2 + +learning_rate = 0.0001 +momentum = 0.99 +batch_size = 32 diff --git a/config/simple_mlp_2_noembed.py b/config/simple_mlp_2_noembed.py @@ -0,0 +1,22 @@ +import model.simple_mlp as model + +import data + +n_dow = 7 # number of division for dayofweek/dayofmonth/hourofday +n_dom = 31 +n_hour = 24 + +n_begin_end_pts = 5 # how many points we consider at the beginning and end of the known trajectory +n_end_pts = 5 + +n_valid = 1000 + +dim_embeddings = [] # do not use embeddings + +dim_input = n_begin_end_pts * 2 * 2 + sum(x for (_, _, x) in dim_embeddings) +dim_hidden = [200, 100] +dim_output = 2 + +learning_rate = 0.0001 +momentum = 0.99 +batch_size = 32 diff --git a/config/simple_mlp_tgtcls_0.py b/config/simple_mlp_tgtcls_0.py @@ -1,25 +0,0 @@ -import cPickle - -import data - -import model.simple_mlp_tgtcls as model - -n_dow = 7 # number of division for dayofweek/dayofmonth/hourofday -n_dom = 31 -n_hour = 24 - -n_begin_end_pts = 5 # how many points we consider at the beginning and end of the known trajectory -n_end_pts = 5 - -n_valid = 1000 - -with open(data.DATA_PATH + "/arrival-clusters.pkl") as f: tgtcls = cPickle.load(f) - -dim_embed = 10 -dim_input = n_begin_end_pts * 2 * 2 + dim_embed + dim_embed -dim_hidden = [] -dim_output = tgtcls.shape[0] - -learning_rate = 0.0001 -momentum = 0.99 -batch_size = 32 diff --git a/config/simple_mlp_tgtcls_0_cs.py b/config/simple_mlp_tgtcls_0_cs.py @@ -0,0 +1,29 @@ +import cPickle + +import data + +import model.simple_mlp_tgtcls as model + +n_dow = 7 # number of division for dayofweek/dayofmonth/hourofday +n_dom = 31 +n_hour = 24 + +n_begin_end_pts = 5 # how many points we consider at the beginning and end of the known trajectory +n_end_pts = 5 + +n_valid = 1000 + +with open(data.DATA_PATH + "/arrival-clusters.pkl") as f: tgtcls = cPickle.load(f) + +dim_embeddings = [ + ('origin_call', data.n_train_clients+1, 10), + ('origin_stand', data.n_stands+1, 10) +] + +dim_input = n_begin_end_pts * 2 * 2 + sum(x for (_, _, x) in dim_embeddings) +dim_hidden = [] +dim_output = tgtcls.shape[0] + +learning_rate = 0.0001 +momentum = 0.99 +batch_size = 32 diff --git a/config/simple_mlp_tgtcls_1.py b/config/simple_mlp_tgtcls_1.py @@ -1,25 +0,0 @@ -import cPickle - -import data - -import model.simple_mlp_tgtcls as model - -n_dow = 7 # number of division for dayofweek/dayofmonth/hourofday -n_dom = 31 -n_hour = 24 - -n_begin_end_pts = 5 # how many points we consider at the beginning and end of the known trajectory -n_end_pts = 5 - -n_valid = 1000 - -with open(data.DATA_PATH + "/arrival-clusters.pkl") as f: tgtcls = cPickle.load(f) - -dim_embed = 10 -dim_input = n_begin_end_pts * 2 * 2 + dim_embed + dim_embed -dim_hidden = [500] -dim_output = tgtcls.shape[0] - -learning_rate = 0.0001 -momentum = 0.99 -batch_size = 32 diff --git a/config/simple_mlp_tgtcls_1_cs.py b/config/simple_mlp_tgtcls_1_cs.py @@ -0,0 +1,29 @@ +import cPickle + +import data + +import model.simple_mlp_tgtcls as model + +n_dow = 7 # number of division for dayofweek/dayofmonth/hourofday +n_dom = 31 +n_hour = 24 + +n_begin_end_pts = 5 # how many points we consider at the beginning and end of the known trajectory +n_end_pts = 5 + +n_valid = 1000 + +with open(data.DATA_PATH + "/arrival-clusters.pkl") as f: tgtcls = cPickle.load(f) + +dim_embeddings = [ + ('origin_call', data.n_train_clients+1, 10), + ('origin_stand', data.n_stands+1, 10) +] + +dim_input = n_begin_end_pts * 2 * 2 + sum(x for (_, _, x) in dim_embeddings) +dim_hidden = [500] +dim_output = tgtcls.shape[0] + +learning_rate = 0.0001 +momentum = 0.99 +batch_size = 32 diff --git a/model/simple_mlp.py b/model/simple_mlp.py @@ -17,25 +17,27 @@ class Model(object): x_lastk_latitude = (tensor.matrix('last_k_latitude') - data.porto_center[0]) / data.data_std[0] x_lastk_longitude = (tensor.matrix('last_k_longitude') - data.porto_center[1]) / data.data_std[1] - x_client = tensor.lvector('origin_call') - x_stand = tensor.lvector('origin_stand') + input_list = [x_firstk_latitude, x_firstk_longitude, x_lastk_latitude, x_lastk_longitude] + embed_tables = [] + + self.require_inputs = ['first_k_latitude', 'first_k_longitude', 'last_k_latitude', 'last_k_longitude'] + + for (varname, num, dim) in config.dim_embeddings: + self.require_inputs.append(varname) + vardata = tensor.lvector(varname) + tbl = LookupTable(length=num, dim=dim, name='%s_lookup'%varname) + embed_tables.append(tbl) + input_list.append(tbl.apply(vardata)) y = tensor.concatenate((tensor.vector('destination_latitude')[:, None], tensor.vector('destination_longitude')[:, None]), axis=1) # Define the model - client_embed_table = LookupTable(length=data.n_train_clients+1, dim=config.dim_embed, name='client_lookup') - stand_embed_table = LookupTable(length=data.n_stands+1, dim=config.dim_embed, name='stand_lookup') mlp = MLP(activations=[Rectifier() for _ in config.dim_hidden] + [Identity()], dims=[config.dim_input] + config.dim_hidden + [config.dim_output]) # Create the Theano variables - client_embed = client_embed_table.apply(x_client) - stand_embed = stand_embed_table.apply(x_stand) - inputs = tensor.concatenate([x_firstk_latitude, x_firstk_longitude, - x_lastk_latitude, x_lastk_longitude, - client_embed, stand_embed], - axis=1) + inputs = tensor.concatenate(input_list, axis=1) # inputs = theano.printing.Print("inputs")(inputs) outputs = mlp.apply(inputs) @@ -55,13 +57,13 @@ class Model(object): hcost.name = 'hcost' # Initialization - client_embed_table.weights_init = IsotropicGaussian(0.001) - stand_embed_table.weights_init = IsotropicGaussian(0.001) + for tbl in embed_tables: + tbl.weights_init = IsotropicGaussian(0.001) mlp.weights_init = IsotropicGaussian(0.01) mlp.biases_init = Constant(0.001) - client_embed_table.initialize() - stand_embed_table.initialize() + for tbl in embed_tables: + tbl.initialize() mlp.initialize() self.cost = cost diff --git a/model/simple_mlp_tgtcls.py b/model/simple_mlp_tgtcls.py @@ -20,26 +20,29 @@ class Model(object): x_lastk_latitude = (tensor.matrix('last_k_latitude') - data.porto_center[0]) / data.data_std[0] x_lastk_longitude = (tensor.matrix('last_k_longitude') - data.porto_center[1]) / data.data_std[1] - x_client = tensor.lvector('origin_call') - x_stand = tensor.lvector('origin_stand') + input_list = [x_firstk_latitude, x_firstk_longitude, x_lastk_latitude, x_lastk_longitude] + embed_tables = [] + + self.require_inputs = ['first_k_latitude', 'first_k_longitude', 'last_k_latitude', 'last_k_longitude'] + + for (varname, num, dim) in config.dim_embeddings: + self.require_inputs.append(varname) + vardata = tensor.lvector(varname) + tbl = LookupTable(length=num, dim=dim, name='%s_lookup'%varname) + embed_tables.append(tbl) + input_list.append(tbl.apply(vardata)) y = tensor.concatenate((tensor.vector('destination_latitude')[:, None], tensor.vector('destination_longitude')[:, None]), axis=1) # Define the model - client_embed_table = LookupTable(length=data.n_train_clients+1, dim=config.dim_embed, name='client_lookup') - stand_embed_table = LookupTable(length=data.n_stands+1, dim=config.dim_embed, name='stand_lookup') mlp = MLP(activations=[Rectifier() for _ in config.dim_hidden] + [Softmax()], dims=[config.dim_input] + config.dim_hidden + [config.dim_output]) classes = theano.shared(numpy.array(config.tgtcls, dtype=theano.config.floatX), name='classes') # Create the Theano variables - client_embed = client_embed_table.apply(x_client) - stand_embed = stand_embed_table.apply(x_stand) - inputs = tensor.concatenate([x_firstk_latitude, x_firstk_longitude, - x_lastk_latitude, x_lastk_longitude, - client_embed, stand_embed], - axis=1) + inputs = tensor.concatenate(input_list, axis=1) + # inputs = theano.printing.Print("inputs")(inputs) cls_probas = mlp.apply(inputs) outputs = tensor.dot(cls_probas, classes) @@ -56,13 +59,13 @@ class Model(object): hcost.name = 'hcost' # Initialization - client_embed_table.weights_init = IsotropicGaussian(0.001) - stand_embed_table.weights_init = IsotropicGaussian(0.001) + for tbl in embed_tables: + tbl.weights_init = IsotropicGaussian(0.001) mlp.weights_init = IsotropicGaussian(0.01) mlp.biases_init = Constant(0.001) - client_embed_table.initialize() - stand_embed_table.initialize() + for tbl in embed_tables: + tbl.initialize() mlp.initialize() self.cost = cost diff --git a/train.py b/train.py @@ -42,7 +42,7 @@ if __name__ == "__main__": config = importlib.import_module(model_name) -def setup_train_stream(): +def setup_train_stream(req_vars): # Load the training and test data train = H5PYDataset(data.H5DATA_PATH, which_set='train', @@ -51,34 +51,33 @@ def setup_train_stream(): train = DataStream(train, iteration_scheme=SequentialExampleScheme(data.dataset_size - config.n_valid)) train = transformers.filter_out_trips(data.valid_trips, train) train = transformers.TaxiGenerateSplits(train, max_splits=100) + train = transformers.add_first_k(config.n_begin_end_pts, train) train = transformers.add_last_k(config.n_begin_end_pts, train) - train = transformers.Select(train, ('origin_stand', 'origin_call', 'first_k_latitude', - 'last_k_latitude', 'first_k_longitude', 'last_k_longitude', - 'destination_latitude', 'destination_longitude')) + train = transformers.Select(train, tuple(req_vars)) + train_stream = Batch(train, iteration_scheme=ConstantScheme(config.batch_size)) return train_stream -def setup_valid_stream(): +def setup_valid_stream(req_vars): valid = DataStream(data.valid_data) + valid = transformers.add_first_k(config.n_begin_end_pts, valid) valid = transformers.add_last_k(config.n_begin_end_pts, valid) - valid = transformers.Select(valid, ('origin_stand', 'origin_call', 'first_k_latitude', - 'last_k_latitude', 'first_k_longitude', 'last_k_longitude', - 'destination_latitude', 'destination_longitude')) + valid = transformers.Select(valid, tuple(req_vars)) + valid_stream = Batch(valid, iteration_scheme=ConstantScheme(1000)) return valid_stream -def setup_test_stream(): - test = data.test_data +def setup_test_stream(req_vars): + test = DataStream(data.test_data) - test = DataStream(test) test = transformers.add_first_k(config.n_begin_end_pts, test) test = transformers.add_last_k(config.n_begin_end_pts, test) - test = transformers.Select(test, ('trip_id', 'origin_stand', 'origin_call', 'first_k_latitude', - 'last_k_latitude', 'first_k_longitude', 'last_k_longitude')) + test = transformers.Select(test, tuple(req_vars)) + test_stream = Batch(test, iteration_scheme=ConstantScheme(1000)) return test_stream @@ -91,8 +90,11 @@ def main(): hcost = model.hcost outputs = model.outputs - train_stream = setup_train_stream() - valid_stream = setup_valid_stream() + req_vars = model.require_inputs + [ 'destination_latitude', 'destination_longitude' ] + req_vars_test = model.require_inputs + [ 'trip_id' ] + + train_stream = setup_train_stream(req_vars) + valid_stream = setup_valid_stream(req_vars) # Training cg = ComputationGraph(cost) @@ -110,7 +112,7 @@ def main(): # Checkpoint('model.pkl', every_n_batches=100), Dump('model_data/' + model_name, every_n_batches=1000), LoadFromDump('model_data/' + model_name), - FinishAfter(after_epoch=10), + FinishAfter(after_epoch=42), ] main_loop = MainLoop( @@ -122,9 +124,9 @@ def main(): main_loop.profile.report() # Produce an output on the test data - test_stream = setup_test_stream() + test_stream = setup_test_stream(req_vars_test) - outfile = open("test-output-%s.csv" % model_name, "w") + outfile = open("output/test-output-%s.csv" % model_name, "w") outcsv = csv.writer(outfile) outcsv.writerow(["TRIP_ID", "LATITUDE", "LONGITUDE"]) for out in apply_model.Apply(outputs=outputs, stream=test_stream, return_vars=['trip_id', 'outputs']):