transform

old TransE-like models
git clone https://esimon.eu/repos/transform.git
Log | Files | Refs | README

commit 6881a2a302c5abc3e2ef4b710fa2033ce83615ea
parent c61b71b63396648f490d9cb10e31de2bcdba601f
Author: Étienne Simon <esimon@esimon.eu>
Date:   Wed, 30 Apr 2014 15:52:35 +0200

Save epoch number

Diffstat:
Mmodel.py | 15+++++++++------
1 file changed, 9 insertions(+), 6 deletions(-)

diff --git a/model.py b/model.py @@ -43,11 +43,13 @@ class Model(object): if filepath is None: Relations = config['relations'] + self.epoch = 0 self.embeddings = Embeddings(config['rng'], dataset.number_embeddings, config['dimension'], self.tag+'.embeddings') self.relations = Relations(config['rng'], dataset.number_relations, config['dimension'], self.tag+'.relations') else: log('## Loading model from "{0}"\n'.format(filepath)) with open(filepath, 'rb') as file: + self.epoch = cPickle.load(file) self.embeddings = cPickle.load(file) self.relations = cPickle.load(file) @@ -122,13 +124,14 @@ class Model(object): validation_frequency = self.config['validation frequency'] number_epoch = self.config['number of epoch'] - for epoch in xrange(number_epoch): + while self.epoch < number_epoch: for (relation, left_positive, right_positive, left_negative, right_negative) in self.dataset.training_minibatch(batch_size): self.normalise_function() self.train_function(relation, left_positive, right_positive, left_negative, right_negative) - if (epoch+1) % validation_frequency == 0: - self.validate(epoch+1) + self.epoch += 1 + if self.epoch % validation_frequency == 0: + self.validate() def error(self, name, transform_scores=(lambda x: x)): """ Compute the mean rank, standard deviation and top 10 on a given data. """ @@ -147,12 +150,12 @@ class Model(object): top10 = numpy.mean(map(lambda x: x<=10, result)) return (mean, std, top10) - def validate(self, epoch): + def validate(self): """ Validate the model. """ - log('Validation epoch {:<5}'.format(epoch)) + log('Validation epoch {:<5}'.format(self.epoch)) (valid_mean, valid_std, valid_top10) = self.error('valid') log(' valid mean: {0:<15} valid std: {1:<15} valid top10: {2:<15}'.format(valid_mean, valid_std, valid_top10)) - datalog(self.config['datalog path']+'/'+self.config['model name'], epoch, valid_mean, valid_std, valid_top10) + datalog(self.config['datalog path']+'/'+self.config['model name'], self.epoch, valid_mean, valid_std, valid_top10) if not hasattr(self, 'best_mean') or valid_mean < self.best_mean: self.best_mean = valid_mean log('(best so far')