commit 6881a2a302c5abc3e2ef4b710fa2033ce83615ea
parent c61b71b63396648f490d9cb10e31de2bcdba601f
Author: Étienne Simon <esimon@esimon.eu>
Date: Wed, 30 Apr 2014 15:52:35 +0200
Save epoch number
Diffstat:
1 file changed, 9 insertions(+), 6 deletions(-)
diff --git a/model.py b/model.py
@@ -43,11 +43,13 @@ class Model(object):
if filepath is None:
Relations = config['relations']
+ self.epoch = 0
self.embeddings = Embeddings(config['rng'], dataset.number_embeddings, config['dimension'], self.tag+'.embeddings')
self.relations = Relations(config['rng'], dataset.number_relations, config['dimension'], self.tag+'.relations')
else:
log('## Loading model from "{0}"\n'.format(filepath))
with open(filepath, 'rb') as file:
+ self.epoch = cPickle.load(file)
self.embeddings = cPickle.load(file)
self.relations = cPickle.load(file)
@@ -122,13 +124,14 @@ class Model(object):
validation_frequency = self.config['validation frequency']
number_epoch = self.config['number of epoch']
- for epoch in xrange(number_epoch):
+ while self.epoch < number_epoch:
for (relation, left_positive, right_positive, left_negative, right_negative) in self.dataset.training_minibatch(batch_size):
self.normalise_function()
self.train_function(relation, left_positive, right_positive, left_negative, right_negative)
- if (epoch+1) % validation_frequency == 0:
- self.validate(epoch+1)
+ self.epoch += 1
+ if self.epoch % validation_frequency == 0:
+ self.validate()
def error(self, name, transform_scores=(lambda x: x)):
""" Compute the mean rank, standard deviation and top 10 on a given data. """
@@ -147,12 +150,12 @@ class Model(object):
top10 = numpy.mean(map(lambda x: x<=10, result))
return (mean, std, top10)
- def validate(self, epoch):
+ def validate(self):
""" Validate the model. """
- log('Validation epoch {:<5}'.format(epoch))
+ log('Validation epoch {:<5}'.format(self.epoch))
(valid_mean, valid_std, valid_top10) = self.error('valid')
log(' valid mean: {0:<15} valid std: {1:<15} valid top10: {2:<15}'.format(valid_mean, valid_std, valid_top10))
- datalog(self.config['datalog path']+'/'+self.config['model name'], epoch, valid_mean, valid_std, valid_top10)
+ datalog(self.config['datalog path']+'/'+self.config['model name'], self.epoch, valid_mean, valid_std, valid_top10)
if not hasattr(self, 'best_mean') or valid_mean < self.best_mean:
self.best_mean = valid_mean
log('(best so far')