commit 1ce979232de97cb635a26d4ec2a574464a567c61
parent 7cb97ecd8c457d0cc0712148b927620fc047558e
Author: Étienne Simon <esimon@esimon.eu>
Date: Wed, 23 Apr 2014 11:56:17 +0200
Change config keys
Diffstat:
  M dataset.py                          |      10 +---------
  M main.py                             |       2 +-
  M model.py                            |      27 +++++++++------------------

3 files changed, 11 insertions(+), 28 deletions(-)
diff --git a/dataset.py b/dataset.py
@@ -17,6 +17,7 @@ class Dataset(object):
self.load_file(prefix, 'train')
self.load_file(prefix, 'valid')
self.load_file(prefix, 'test')
+ self.universe = scipy.sparse.eye(len(self.embeddings), format='csr', dtype=theano.config.floatX)
def load_file(self, prefix, name):
with open('{0}/{1}'.format(prefix, name), 'r') as file:
@@ -64,12 +65,3 @@ class Dataset(object):
right = getattr(self, name+'_right')
for i in xrange(N):
yield (relation[i], left[i], right[i])
-
- def universe_minibatch(self, batch_size):
- N = len(self.embeddings)
- entities = scipy.sparse.eye(N, format='csr', dtype=theano.config.floatX)
- for i in xrange(N/batch_size):
- yield entities[i*batch_size:(i+1)*batch_size]
- last = (N/batch_size)*batch_size
- if last != N:
- yield entities[last:N]
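
The universe_minibatch generator removed above sliced rows out of the same sparse identity matrix that the constructor now builds once as self.universe: row i is the one-hot CSR vector of entity i, so a triple can later be scored against every candidate entity in a single call. A minimal sketch (not part of the commit) of what the new attribute holds; the 'float32' literal stands in for theano.config.floatX and N for len(self.embeddings):

    import scipy.sparse

    N = 5                                          # stand-in for len(self.embeddings)
    universe = scipy.sparse.eye(N, format='csr', dtype='float32')
    print(universe.shape)                          # (5, 5)
    print(universe[2].toarray())                   # [[0. 0. 1. 0. 0.]]: one-hot row for entity 2
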
diff --git a/main.py b/main.py
@@ -11,7 +11,7 @@ from relations import *
if __name__ == '__main__':
if len(sys.argv)<3:
- print('Usage: {0} data parameters [model]'.format(sys.argv[0]), file=sys.stderr)
+ print('Usage: {0} data config [model]'.format(sys.argv[0]), file=sys.stderr)
sys.exit(1)
data = sys.argv[1]
config_path = sys.argv[2]
diff --git a/model.py b/model.py
@@ -140,26 +140,17 @@ class Model(object):
def error(self, name):
""" Compute the mean rank and top 10 on a given data. """
- batch_size = self.config['test batch size']
count, mean, top10 = 0, 0, 0
for (relation, left, right) in self.dataset.iterate(name):
left_scores, right_scores = None, None
- for entities in self.dataset.universe_minibatch(batch_size):
- left_batch_result = self.left_scoring_function(relation, left, entities)
- right_batch_result = self.right_scoring_function(relation, entities, right)
- if left_scores is None:
- left_scores = numpy.array(left_batch_result, dtype=theano.config.floatX)
- else:
- left_scores = numpy.concatenate((left_scores, left_batch_result), axis=1)
- if right_scores is None:
- right_scores = numpy.array(right_batch_result, dtype=theano.config.floatX)
- else:
- right_scores = numpy.concatenate((right_scores, right_batch_result), axis=1)
+ entities = self.dataset.universe
+ left_scores = self.left_scoring_function(relation, left, entities)
+ right_scores = self.right_scoring_function(relation, entities, right)
left_rank = 1+numpy.asscalar(numpy.where(numpy.argsort(left_scores)==right.indices[0])[1]) # FIXME Ugly
right_rank = 1+numpy.asscalar(numpy.where(numpy.argsort(right_scores)==left.indices[0])[1]) # FIXME Ugly
- count = count + 2
- mean = mean + left_rank + right_rank
- top10 = top10 + (left_rank<=10) + (right_rank<=10)
+ count += 2
+ mean += left_rank + right_rank
+ top10 += (left_rank<=10) + (right_rank<=10)
mean = float(mean) / count
top10 = float(top10) / count
return (mean, top10)
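
For reference, the rank extraction behind the two "FIXME Ugly" lines above, as a minimal numpy sketch (not part of the commit). It assumes the scores have shape (1, N) and that lower scores rank better, as the ascending argsort suggests; .item() stands in for numpy.asscalar, which newer numpy removes, and the toy values are made up:

    import numpy

    scores = numpy.array([[0.7, 0.1, 0.4, 0.9]])   # toy (1, N) scores, N = 4 candidate entities
    true_index = 2                                 # stand-in for right.indices[0]

    order = numpy.argsort(scores)                  # (1, N): candidate ids, best (lowest score) first
    rank = 1 + numpy.where(order == true_index)[1].item()
    print(rank)                                    # 2: would count toward top10 since rank <= 10
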
@@ -169,13 +160,13 @@ class Model(object):
log('Validation epoch {:<5}'.format(epoch))
(valid_mean, valid_top10) = self.error('valid')
log(' valid mean: {0:<15} valid top10: {1:<15}'.format(valid_mean, valid_top10))
- datalog(self.config['datalog filepath'], epoch, valid_mean, valid_top10)
+ datalog(self.config['datalog path']+'/'+self.config['model name'], epoch, valid_mean, valid_top10)
if not hasattr(self, 'best_mean') or valid_mean < self.best_mean:
self.best_mean = valid_mean
log('(best so far')
if self.config['save best model']:
log(', saving...')
- self.save(self.config['best model save location'])
+ self.save('{0}/{1}.best'.format(self.config['best model save path'], self.config['model name']))
log(' done')
log(')')
@@ -189,5 +180,5 @@ class Model(object):
log('# Testing the model "{0}"'.format(self.tag))
(mean, top10) = self.error('test')
log(' mean: {0:<15} top10: {1:<15} (saving...'.format(mean, top10))
- self.save(self.config['last model save location'])
+ self.save('{0}/{1}.last'.format(self.config['last model save path'], self.config['model name']))
log(' done)\n')
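
Taken together, the model.py hunks change which keys the code reads from its config. A minimal sketch of a config dict that satisfies the new code; the key names come from the diff, while the values and the 'transe' tag are made up for illustration:

    config = {
        'model name': 'transe',              # hypothetical tag; joined onto the paths below
        'datalog path': 'logs',              # validation metrics go to logs/<model name>
        'save best model': True,
        'best model save path': 'models',    # best model saved as models/<model name>.best
        'last model save path': 'models',    # final model saved as models/<model name>.last
    }
    # No longer read by error(): 'test batch size'. Replaced keys: 'datalog filepath',
    # 'best model save location', 'last model save location'.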