taxi

Winning entry to the Kaggle taxi competition
git clone https://esimon.eu/repos/taxi.git
Log | Files | Refs | README

commit fe608831c62c7dba60a3bf57433d97b999e567c8
parent e7aba08e6b209ac7f091eb9f08b49a2c90b070ed
Author: Étienne Simon <esimon@esimon.eu>
Date:   Thu, 23 Jul 2015 18:34:51 -0400

Fix tvt hdf5

Diffstat:
M data/make_tvt.py | 12 ++++++++++++
1 file changed, 12 insertions(+), 0 deletions(-)

diff --git a/data/make_tvt.py b/data/make_tvt.py
@@ -31,6 +31,9 @@ native_fields = {
 all_fields = {
     'path_len': numpy.int16,
     'cluster': numpy.int16,
+    'destination_latitude': numpy.float32,
+    'destination_longitude': numpy.float32,
+    'travel_time': numpy.int32,
 }

 all_fields.update(native_fields)
@@ -125,6 +128,15 @@ def make_tvt(test_cuts_name, valid_cuts_name, outpath):
             i = train_i
             train_i += 1

+        trajlen = len(traindata['latitude'][idtraj])
+        if trajlen == 0:
+            hdata['destination_latitude'] = data.train_gps_mean[0]
+            hdata['destination_longitude'] = data.train_gps_mean[1]
+        else:
+            hdata['destination_latitude'] = traindata['latitude'][idtraj][-1]
+            hdata['destination_longitude'] = traindata['longitude'][idtraj][-1]
+        hdata['travel_time'] = trajlen
+
         for field in native_fields:
             val = traindata[field][idtraj]
             if field in ['latitude', 'longitude']: