commit 051a8e243443b93520a4c6bd506abe1220550802
parent 4be068298edfd7777e6de5572419f9c2bdcbd985
Author: Étienne Simon <esimon@esimon.eu>
Date: Wed, 16 Apr 2014 18:26:15 +0200
Add and change hyperparameters
Diffstat:
1 file changed, 11 insertions(+), 8 deletions(-)
diff --git a/model.py b/model.py
@@ -93,17 +93,17 @@ class Model(object):
Keyword arguments:
cost -- The cost to optimise.
"""
- lr_relations = self.hyperparameters['relation_learning_rate']
- lr_embeddings = self.hyperparameters['embeddings_learning_rate']
+ lr_relations = self.hyperparameters['relation learning rate']
+ lr_embeddings = self.hyperparameters['embeddings learning rate']
return self.relations.updates(cost, lr_relations) + self.embeddings.updates(cost, lr_embeddings)
def train(self):
""" Train the model. """
print >>sys.stderr, '# Training the model "{0}"'.format(self.tag)
- batch_size = self.hyperparameters['train_batch_size']
- validation_frequency = self.hyperparameters['validation_frequency']
- number_epoch = self.hyperparameters['number_epoch']
+ batch_size = self.hyperparameters['train batch size']
+ validation_frequency = self.hyperparameters['validation frequency']
+ number_epoch = self.hyperparameters['number of epoch']
for epoch in xrange(number_epoch):
if epoch % validation_frequency == 0:
@@ -115,7 +115,7 @@ class Model(object):
def error(self, name):
""" Compute the mean rank and top 10 on a given data. """
- batch_size = self.hyperparameters['test_batch_size']
+ batch_size = self.hyperparameters['test batch size']
count, mean, top10 = 0, 0, 0
for (relation, left, right) in self.dataset.iterate(name, batch_size): # TODO Test symmetric
scores = None
@@ -138,8 +138,11 @@ class Model(object):
print >>sys.stderr, 'Validation epoch {:<5}'.format(epoch),
(valid_mean, valid_top10) = self.error('valid')
print >>sys.stderr, 'valid mean: {0:<15} valid top10: {1:<15}'.format(valid_mean, valid_top10),
- (train_mean, train_top10) = self.error('train')
- print >>sys.stderr, 'train mean: {0:<15} train top10: {1:<15}'.format(train_mean, train_top10)
+ if self.hyperparameters['validate on training data']:
+ (train_mean, train_top10) = self.error('train')
+ print >>sys.stderr, 'train mean: {0:<15} train top10: {1:<15}'.format(train_mean, train_top10)
+ else
+ print >>sys.stderr, ''
def test(self):
""" Test the model. """