taxi

Winning entry to the Kaggle taxi competition
git clone https://esimon.eu/repos/taxi.git
Log | Files | Refs | README

ext_test.py (3705B)


      1 #!/usr/bin/env python
      2 
      3 import logging
      4 import os
      5 import csv
      6 
      7 from blocks.model import Model
      8 from blocks.extensions import SimpleExtension
      9 
     10 logger = logging.getLogger(__name__)
     11 
     12 class RunOnTest(SimpleExtension):
     13     def __init__(self, model_name, model, stream, **kwargs):
     14         super(RunOnTest, self).__init__(**kwargs)
     15 
     16         self.model_name = model_name
     17 
     18         cg = Model(model.predict(**stream.inputs()))
     19 
     20         self.inputs = cg.inputs
     21         self.outputs = model.predict.outputs
     22 
     23         req_vars_test = model.predict.inputs + ['trip_id']
     24         self.test_stream = stream.test(req_vars_test)
     25 
     26         self.function = cg.get_theano_function()
     27 
     28         self.best_dvc = None
     29         self.best_tvc = None
     30 
     31     def do(self, which_callback, *args):
     32         iter_no = self.main_loop.log.status['iterations_done']
     33         if 'valid_destination_cost' in self.main_loop.log.current_row:
     34             dvc = self.main_loop.log.current_row['valid_destination_cost']
     35         elif 'valid_model_cost_cost' in self.main_loop.log.current_row:
     36             dvc = self.main_loop.log.current_row['valid_model_cost_cost']
     37         elif 'valid_model_valid_cost_cost' in self.main_loop.log.current_row:
     38             dvc = self.main_loop.log.current_row['valid_model_valid_cost_cost']
     39         else:
     40             raise RuntimeError("Unknown model type")
     41 
     42         if 'valid_time_cost' in self.main_loop.log.current_row:
     43             tvc = self.main_loop.log.current_row['valid_time_cost']
     44         elif 'valid_model_cost_cost' in self.main_loop.log.current_row:
     45             tvc = self.main_loop.log.current_row['valid_model_cost_cost']
     46         elif 'valid_model_valid_cost_cost' in self.main_loop.log.current_row:
     47             tvc = self.main_loop.log.current_row['valid_model_valid_cost_cost']
     48         else:
     49             raise RuntimeError("Unknown model type")
     50 
     51         output_dvc = (self.best_dvc is None or dvc < self.best_dvc) and 'destination' in self.outputs
     52         output_tvc = (self.best_tvc is None or tvc < self.best_tvc) and 'duration' in self.outputs
     53 
     54         if not output_dvc and not output_tvc:
     55             return
     56 
     57         if output_dvc:
     58             self.best_dvc = dvc
     59             dest_outname = 'test-dest-%s-it%09d-cost%.3f.csv' % (self.model_name, iter_no, dvc)
     60             dest_outfile = open(os.path.join('output', dest_outname), 'w')
     61             dest_outcsv = csv.writer(dest_outfile)
     62             dest_outcsv.writerow(["TRIP_ID", "LATITUDE", "LONGITUDE"])
     63             logger.info("Generating output for test set: %s" % dest_outname)
     64         if output_tvc:
     65             self.best_tvc = tvc
     66             time_outname = 'test-time-%s-it%09d-cost%.3f.csv' % (self.model_name, iter_no, tvc)
     67             time_outfile = open(os.path.join('output', time_outname), 'w')
     68             time_outcsv = csv.writer(time_outfile)
     69             time_outcsv.writerow(["TRIP_ID", "TRAVEL_TIME"])
     70             logger.info("Generating output for test set: %s" % time_outname)
     71 
     72         for d in self.test_stream.get_epoch_iterator(as_dict=True):
     73             input_values = [d[k.name] for k in self.inputs]
     74             output_values = self.function(*input_values)
     75             for i in range(d['trip_id'].shape[0]):
     76                 if output_dvc:
     77                     destination = output_values[self.outputs.index('destination')]
     78                     dest_outcsv.writerow([d['trip_id'][i], destination[i, 0], destination[i, 1]])
     79                 if output_tvc:
     80                     duration = output_values[self.outputs.index('duration')]
     81                     time_outcsv.writerow([d['trip_id'][i], int(round(duration[i]))])
     82 
     83         if output_dvc:
     84             dest_outfile.close()
     85         if output_tvc:
     86             time_outfile.close()
     87