taxi

Winning entry in the Kaggle taxi competition
git clone https://esimon.eu/repos/taxi.git

commit 448e848796757ad9f0a2f681886f868b8f22e81f
parent c2c88c48a0404de0eb834df71fa53ae63fdfd1c7
Author: Alex Auvolat <alex.auvolat@ens.fr>
Date:   Fri, 22 May 2015 15:51:02 -0400

Add ElementwiseRemoveNotFinite step rule.

Diffstat:
M train.py | 37 +++++++++++++++++++++++++++++++++++--
1 file changed, 35 insertions(+), 2 deletions(-)

diff --git a/train.py b/train.py
@@ -7,8 +7,10 @@ import os
 import sys
 from functools import reduce
 
+from theano import tensor
+
 from blocks import roles
-from blocks.algorithms import AdaDelta, CompositeRule, GradientDescent, RemoveNotFinite
+from blocks.algorithms import AdaDelta, CompositeRule, GradientDescent, RemoveNotFinite, StepRule
 from blocks.extensions import Printing, FinishAfter
 from blocks.extensions.monitoring import DataStreamMonitoring, TrainingDataMonitoring
 from blocks.extensions.plot import Plot
@@ -21,6 +23,37 @@ from blocks.model import Model
 
 logger = logging.getLogger(__name__)
 
+
+class ElementwiseRemoveNotFinite(StepRule):
+    """A step rule that replaces non-finite step coefficients element-wise.
+
+    Replaces non-finite elements (such as ``inf`` or ``NaN``) in a step
+    (the parameter update of a single shared variable)
+    with a scaled version of the parameters being updated instead.
+
+    Parameters
+    ----------
+    scaler : float, optional
+        The scaling applied to the parameter in case the step contains
+        non-finite elements. Defaults to 0.1.
+
+    Notes
+    -----
+    This trick was originally used in the GroundHog_ framework.
+
+    .. _GroundHog: https://github.com/lisa-groundhog/GroundHog
+
+    """
+    def __init__(self, scaler=0.1):
+        self.scaler = scaler
+
+    def compute_step(self, param, previous_step):
+        not_finite = tensor.isnan(previous_step) + tensor.isinf(previous_step)
+        step = tensor.switch(not_finite, self.scaler * param, previous_step)
+
+        return step, []
+
+
 if __name__ == "__main__":
     if len(sys.argv) != 2:
         print >> sys.stderr, 'Usage: %s config' % sys.argv[0]
@@ -66,7 +99,7 @@ if __name__ == "__main__":
     algorithm = GradientDescent(
         cost=cost,
         step_rule=CompositeRule([
-            RemoveNotFinite(),
+            ElementwiseRemoveNotFinite(),
             AdaDelta(),
             #Momentum(learning_rate=config.learning_rate, momentum=config.momentum),
         ]),
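
For illustration only (not part of the commit): a minimal standalone sketch of the element-wise replacement that compute_step performs. The toy parameter values, variable names, and the compiled demo function are assumptions made for this example; only the isnan/isinf/switch pattern comes from the step rule above.

import numpy as np
import theano
from theano import tensor

scaler = 0.1  # same default as ElementwiseRemoveNotFinite

# Hypothetical toy parameter and incoming step (illustrative values only).
param = theano.shared(np.array([1.0, 2.0, 3.0], dtype=theano.config.floatX))
previous_step = tensor.vector('previous_step')

# Element-wise mask of NaN/inf entries; the addition acts as a logical OR.
not_finite = tensor.isnan(previous_step) + tensor.isinf(previous_step)
# Keep finite coordinates; replace bad ones with a scaled copy of the parameter.
step = tensor.switch(not_finite, scaler * param, previous_step)

demo = theano.function([previous_step], step)
print(demo(np.array([0.5, np.nan, np.inf], dtype=theano.config.floatX)))
# [ 0.5  0.2  0.3 ] -- only the non-finite coordinates are touched, whereas the
# stock RemoveNotFinite swapped out in the CompositeRule above reacts to the
# step as a whole (hence the "Elementwise" in the new rule's name).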