commit f258bea01d6ea61fa97a4e0d9d81009bbbd354f8
parent 92227335bd42489fe0a25e7670b36d05fb4a519f
Author: Étienne Simon <esimon@esimon.eu>
Date: Thu, 17 Apr 2014 12:18:07 +0200
Add auto-saving & Fix loading/saving
Diffstat:
2 files changed, 19 insertions(+), 5 deletions(-)
diff --git a/main.py b/main.py
@@ -9,18 +9,22 @@ from relations.translations import *
if __name__ == '__main__':
if len(sys.argv)<3:
- print >>sys.stderr, 'Usage: {0} data parameters'.format(sys.argv[0])
+ print >>sys.stderr, 'Usage: {0} data parameters [model]'.format(sys.argv[0])
sys.exit(1)
data = sys.argv[1]
config = sys.argv[2]
+ model_path = None if len(sys.argv)<4 else sys.argv[3]
with open(config, 'r') as config_file:
hyperparameters = json.load(config_file)
for k, v in hyperparameters.iteritems():
- if isinstance(v, basestring):
- hyperparameters[k] = eval(v)
+ if isinstance(v, basestring) and v.startswith('python:'):
+ hyperparameters[k] = eval(v[7:])
data = Dataset(data)
- model = Model.initialise(Translations, data, hyperparameters, 'TransE')
+ if model_path is None:
+ model = Model.initialise(Translations, data, hyperparameters, 'TransE')
+ else:
+ model = Model.load(model_path, data, hyperparameters, 'TransE')
model.train()
model.test()
diff --git a/model.py b/model.py
@@ -1,6 +1,7 @@
#!/usr/bin/env python2
import sys
+import cPickle
import numpy
import scipy
import theano
@@ -44,13 +45,14 @@ class Model(object):
return self
@classmethod
- def load(cls, filepath, dataset, hyperparameters):
+ def load(cls, filepath, dataset, hyperparameters, tag):
""" Load a model from a file.
Keyword arguments:
filepath -- path to the Model file
dataset -- dataset on which the model will be trained and tested
hyperparameters -- hyperparameters dictionary
+ tag -- name of the embeddings for parameter declaration
"""
print >>sys.stderr, '# Loading model from "{0}"'.format(filepath)
@@ -60,6 +62,7 @@ class Model(object):
self.relations = cPickle.load(file)
self.dataset = dataset;
self.hyperparameters = hyperparameters;
+ self.tag = tag
self.build()
return self
@@ -153,6 +156,13 @@ class Model(object):
print >>sys.stderr, 'Validation epoch {:<5}'.format(epoch),
(valid_mean, valid_top10) = self.error('valid')
print >>sys.stderr, 'valid mean: {0:<15} valid top10: {1:<15}'.format(valid_mean, valid_top10),
+ if not hasattr(self, 'best_mean') or valid_mean > self.best_mean:
+ print >>sys.stderr, ' (best so far',
+ if self.hyperparameters['save best model']:
+ print >>sys.stderr, ' saving',
+ self.save(self.hyperparameters['best model save location'])
+ print >>sys.stderr, ')',
+
if self.hyperparameters['validate on training data']:
(train_mean, train_top10) = self.error('train')
print >>sys.stderr, 'train mean: {0:<15} train top10: {1:<15}'.format(train_mean, train_top10)