transform

old TransE-like models
git clone https://esimon.eu/repos/transform.git
Log | Files | Refs | README

commit 051a8e243443b93520a4c6bd506abe1220550802
parent 4be068298edfd7777e6de5572419f9c2bdcbd985
Author: Étienne Simon <esimon@esimon.eu>
Date:   Wed, 16 Apr 2014 18:26:15 +0200

Add and change hyperparameters

Diffstat:
Mmodel.py | 19+++++++++++--------
1 file changed, 11 insertions(+), 8 deletions(-)

diff --git a/model.py b/model.py @@ -93,17 +93,17 @@ class Model(object): Keyword arguments: cost -- The cost to optimise. """ - lr_relations = self.hyperparameters['relation_learning_rate'] - lr_embeddings = self.hyperparameters['embeddings_learning_rate'] + lr_relations = self.hyperparameters['relation learning rate'] + lr_embeddings = self.hyperparameters['embeddings learning rate'] return self.relations.updates(cost, lr_relations) + self.embeddings.updates(cost, lr_embeddings) def train(self): """ Train the model. """ print >>sys.stderr, '# Training the model "{0}"'.format(self.tag) - batch_size = self.hyperparameters['train_batch_size'] - validation_frequency = self.hyperparameters['validation_frequency'] - number_epoch = self.hyperparameters['number_epoch'] + batch_size = self.hyperparameters['train batch size'] + validation_frequency = self.hyperparameters['validation frequency'] + number_epoch = self.hyperparameters['number of epoch'] for epoch in xrange(number_epoch): if epoch % validation_frequency == 0: @@ -115,7 +115,7 @@ class Model(object): def error(self, name): """ Compute the mean rank and top 10 on a given data. """ - batch_size = self.hyperparameters['test_batch_size'] + batch_size = self.hyperparameters['test batch size'] count, mean, top10 = 0, 0, 0 for (relation, left, right) in self.dataset.iterate(name, batch_size): # TODO Test symmetric scores = None @@ -138,8 +138,11 @@ class Model(object): print >>sys.stderr, 'Validation epoch {:<5}'.format(epoch), (valid_mean, valid_top10) = self.error('valid') print >>sys.stderr, 'valid mean: {0:<15} valid top10: {1:<15}'.format(valid_mean, valid_top10), - (train_mean, train_top10) = self.error('train') - print >>sys.stderr, 'train mean: {0:<15} train top10: {1:<15}'.format(train_mean, train_top10) + if self.hyperparameters['validate on training data']: + (train_mean, train_top10) = self.error('train') + print >>sys.stderr, 'train mean: {0:<15} train top10: {1:<15}'.format(train_mean, train_top10) + else + print >>sys.stderr, '' def test(self): """ Test the model. """