transform

old TransE-like models
git clone https://esimon.eu/repos/transform.git
Log | Files | Refs | README

commit f258bea01d6ea61fa97a4e0d9d81009bbbd354f8
parent 92227335bd42489fe0a25e7670b36d05fb4a519f
Author: Étienne Simon <esimon@esimon.eu>
Date:   Thu, 17 Apr 2014 12:18:07 +0200

Add auto-saving & Fix loading/saving

Diffstat:
Mmain.py | 12++++++++----
Mmodel.py | 12+++++++++++-
2 files changed, 19 insertions(+), 5 deletions(-)

diff --git a/main.py b/main.py @@ -9,18 +9,22 @@ from relations.translations import * if __name__ == '__main__': if len(sys.argv)<3: - print >>sys.stderr, 'Usage: {0} data parameters'.format(sys.argv[0]) + print >>sys.stderr, 'Usage: {0} data parameters [model]'.format(sys.argv[0]) sys.exit(1) data = sys.argv[1] config = sys.argv[2] + model_path = None if len(sys.argv)<4 else sys.argv[3] with open(config, 'r') as config_file: hyperparameters = json.load(config_file) for k, v in hyperparameters.iteritems(): - if isinstance(v, basestring): - hyperparameters[k] = eval(v) + if isinstance(v, basestring) and v.startswith('python:'): + hyperparameters[k] = eval(v[7:]) data = Dataset(data) - model = Model.initialise(Translations, data, hyperparameters, 'TransE') + if model_path is None: + model = Model.initialise(Translations, data, hyperparameters, 'TransE') + else: + model = Model.load(model_path, data, hyperparameters, 'TransE') model.train() model.test() diff --git a/model.py b/model.py @@ -1,6 +1,7 @@ #!/usr/bin/env python2 import sys +import cPickle import numpy import scipy import theano @@ -44,13 +45,14 @@ class Model(object): return self @classmethod - def load(cls, filepath, dataset, hyperparameters): + def load(cls, filepath, dataset, hyperparameters, tag): """ Load a model from a file. Keyword arguments: filepath -- path to the Model file dataset -- dataset on which the model will be trained and tested hyperparameters -- hyperparameters dictionary + tag -- name of the embeddings for parameter declaration """ print >>sys.stderr, '# Loading model from "{0}"'.format(filepath) @@ -60,6 +62,7 @@ class Model(object): self.relations = cPickle.load(file) self.dataset = dataset; self.hyperparameters = hyperparameters; + self.tag = tag self.build() return self @@ -153,6 +156,13 @@ class Model(object): print >>sys.stderr, 'Validation epoch {:<5}'.format(epoch), (valid_mean, valid_top10) = self.error('valid') print >>sys.stderr, 'valid mean: {0:<15} valid top10: {1:<15}'.format(valid_mean, valid_top10), + if not hasattr(self, 'best_mean') or valid_mean > self.best_mean: + print >>sys.stderr, ' (best so far', + if self.hyperparameters['save best model']: + print >>sys.stderr, ' saving', + self.save(self.hyperparameters['best model save location']) + print >>sys.stderr, ')', + if self.hyperparameters['validate on training data']: (train_mean, train_top10) = self.error('train') print >>sys.stderr, 'train mean: {0:<15} train top10: {1:<15}'.format(train_mean, train_top10)