taxi

Winning entry to the Kaggle taxi competition
git clone https://esimon.eu/repos/taxi.git
Log | Files | Refs | README

memory_network_bidir_2.py (1474B)


      1 from blocks.initialization import IsotropicGaussian, Constant
      2 
      3 from blocks.bricks import Tanh
      4 
      5 import data
      6 from model.memory_network_bidir import Model, Stream
      7 
      8 
      9 dim_embeddings = [
     10     ('origin_call', data.origin_call_train_size, 10),
     11     ('origin_stand', data.stands_size, 10),
     12     ('week_of_year', 52, 10),
     13     ('day_of_week', 7, 10),
     14     ('qhour_of_day', 24 * 4, 10),
     15     ('day_type', 3, 10),
     16 ]
     17 
     18 embed_weights_init = IsotropicGaussian(0.001)
     19 
     20 
     21 class RNNConfig(object):
     22     __slots__ = ('rec_state_dim', 'dim_embeddings', 'embed_weights_init',
     23                  'dim_hidden', 'weights_init', 'biases_init')
     24 
     25 prefix_encoder = RNNConfig()
     26 prefix_encoder.dim_embeddings = dim_embeddings
     27 prefix_encoder.embed_weights_init = embed_weights_init
     28 prefix_encoder.rec_state_dim = 100
     29 prefix_encoder.dim_hidden = [100, 100]
     30 prefix_encoder.weights_init = IsotropicGaussian(0.01)
     31 prefix_encoder.biases_init = Constant(0.001)
     32 
     33 candidate_encoder = RNNConfig()
     34 candidate_encoder.dim_embeddings = dim_embeddings
     35 candidate_encoder.embed_weights_init = embed_weights_init
     36 candidate_encoder.rec_state_dim = 100
     37 candidate_encoder.dim_hidden = [100, 100]
     38 candidate_encoder.weights_init = IsotropicGaussian(0.01)
     39 candidate_encoder.biases_init = Constant(0.001)
     40 
     41 representation_size = 100
     42 representation_activation = Tanh
     43 
     44 normalize_representation = True
     45 
     46 
     47 batch_size = 100
     48 batch_sort_size = 20
     49 
     50 max_splits = 100
     51 
     52 train_candidate_size = 1000
     53 valid_candidate_size = 1000
     54 test_candidate_size = 1000