taxi

Winning entry to the Kaggle taxi competition
git clone https://esimon.eu/repos/taxi.git
Log | Files | Refs | README

cluster_arrival.py (1052B)


      1 #!/usr/bin/env python2
      2 import numpy
      3 import cPickle
      4 import scipy.misc
      5 import os
      6 
      7 from sklearn.cluster import MeanShift, estimate_bandwidth
      8 from sklearn.datasets.samples_generator import make_blobs
      9 from itertools import cycle
     10 
     11 import data
     12 from data.hdf5 import taxi_it
     13 from data.transformers import add_destination
     14 
     15 print "Generating arrival point list"
     16 dests = []
     17 for v in taxi_it("train"):
     18     if len(v['latitude']) == 0: continue
     19     dests.append([v['latitude'][-1], v['longitude'][-1]])
     20 pts = numpy.array(dests)
     21 
     22 with open(os.path.join(data.path, "arrivals.pkl"), "w") as f:
     23     cPickle.dump(pts, f, protocol=cPickle.HIGHEST_PROTOCOL)
     24 
     25 print "Doing clustering"
     26 bw = estimate_bandwidth(pts, quantile=.1, n_samples=1000)
     27 print bw
     28 bw = 0.001 # (
     29 
     30 ms = MeanShift(bandwidth=bw, bin_seeding=True, min_bin_freq=5)
     31 ms.fit(pts)
     32 cluster_centers = ms.cluster_centers_
     33 
     34 print "Clusters shape: ", cluster_centers.shape
     35 
     36 with open(os.path.join(data.path, "arrival-clusters.pkl"), "w") as f:
     37     cPickle.dump(cluster_centers, f, protocol=cPickle.HIGHEST_PROTOCOL)
     38