taxi

Winning entry to the Kaggle taxi competition
git clone https://esimon.eu/repos/taxi.git
Log | Files | Refs | README

cut.py (1051B)


      1 from fuel.schemes import IterationScheme
      2 import sqlite3
      3 import random
      4 import os
      5 from picklable_itertools import iter_
      6 
      7 import data
      8 
      9 first_time = 1372636853
     10 last_time = 1404172787
     11 
     12 
     13 class TaxiTimeCutScheme(IterationScheme):
     14     def __init__(self, num_cuts=100, dbfile=None, use_cuts=None):
     15         self.num_cuts = num_cuts
     16         self.dbfile = os.path.join(data.path, 'time_index.db') if dbfile == None else dbfile
     17         self.use_cuts = use_cuts
     18 
     19     def get_request_iterator(self):
     20         cuts = self.use_cuts
     21         if cuts == None:
     22             cuts = [random.randrange(first_time, last_time) for _ in range(self.num_cuts)]
     23 
     24         l = []
     25         with sqlite3.connect(self.dbfile) as db:
     26             c = db.cursor()
     27             for cut in cuts:
     28                 part = [i for (i,) in
     29                     c.execute('SELECT trip FROM trip_times WHERE begin >= ? AND begin <= ? AND end >= ?',
     30                                 (cut - 40000, cut, cut))]
     31                 l = l + part
     32         random.shuffle(l)
     33                 
     34         return iter_(l)
     35