transform

old TransE-like models
git clone https://esimon.eu/repos/transform.git
Log | Files | Refs | README

commit 1ce979232de97cb635a26d4ec2a574464a567c61
parent 7cb97ecd8c457d0cc0712148b927620fc047558e
Author: Étienne Simon <esimon@esimon.eu>
Date:   Wed, 23 Apr 2014 11:56:17 +0200

Change config keys

Diffstat:
M dataset.py | 10 +---------
M main.py | 2 +-
M model.py | 27 +++++++++------------------
3 files changed, 11 insertions(+), 28 deletions(-)

diff --git a/dataset.py b/dataset.py @@ -17,6 +17,7 @@ class Dataset(object): self.load_file(prefix, 'train') self.load_file(prefix, 'valid') self.load_file(prefix, 'test') + self.universe = scipy.sparse.eye(len(self.embeddings), format='csr', dtype=theano.config.floatX) def load_file(self, prefix, name): with open('{0}/{1}'.format(prefix, name), 'r') as file: @@ -64,12 +65,3 @@ class Dataset(object): right = getattr(self, name+'_right') for i in xrange(N): yield (relation[i], left[i], right[i]) - - def universe_minibatch(self, batch_size): - N = len(self.embeddings) - entities = scipy.sparse.eye(N, format='csr', dtype=theano.config.floatX) - for i in xrange(N/batch_size): - yield entities[i*batch_size:(i+1)*batch_size] - last = (N/batch_size)*batch_size - if last != N: - yield entities[last:N] diff --git a/main.py b/main.py @@ -11,7 +11,7 @@ from relations import * if __name__ == '__main__': if len(sys.argv)<3: - print('Usage: {0} data parameters [model]'.format(sys.argv[0]), file=sys.stderr) + print('Usage: {0} data config [model]'.format(sys.argv[0]), file=sys.stderr) sys.exit(1) data = sys.argv[1] config_path = sys.argv[2] diff --git a/model.py b/model.py @@ -140,26 +140,17 @@ class Model(object): def error(self, name): """ Compute the mean rank and top 10 on a given data. 
""" - batch_size = self.config['test batch size'] count, mean, top10 = 0, 0, 0 for (relation, left, right) in self.dataset.iterate(name): left_scores, right_scores = None, None - for entities in self.dataset.universe_minibatch(batch_size): - left_batch_result = self.left_scoring_function(relation, left, entities) - right_batch_result = self.right_scoring_function(relation, entities, right) - if left_scores is None: - left_scores = numpy.array(left_batch_result, dtype=theano.config.floatX) - else: - left_scores = numpy.concatenate((left_scores, left_batch_result), axis=1) - if right_scores is None: - right_scores = numpy.array(right_batch_result, dtype=theano.config.floatX) - else: - right_scores = numpy.concatenate((right_scores, right_batch_result), axis=1) + entities = self.dataset.universe + left_scores = self.left_scoring_function(relation, left, entities) + right_scores = self.right_scoring_function(relation, entities, right) left_rank = 1+numpy.asscalar(numpy.where(numpy.argsort(left_scores)==right.indices[0])[1]) # FIXME Ugly right_rank = 1+numpy.asscalar(numpy.where(numpy.argsort(right_scores)==left.indices[0])[1]) # FIXME Ugly - count = count + 2 - mean = mean + left_rank + right_rank - top10 = top10 + (left_rank<=10) + (right_rank<=10) + count += 2 + mean += left_rank + right_rank + top10 += (left_rank<=10) + (right_rank<=10) mean = float(mean) / count top10 = float(top10) / count return (mean, top10) @@ -169,13 +160,13 @@ class Model(object): log('Validation epoch {:<5}'.format(epoch)) (valid_mean, valid_top10) = self.error('valid') log(' valid mean: {0:<15} valid top10: {1:<15}'.format(valid_mean, valid_top10)) - datalog(self.config['datalog filepath'], epoch, valid_mean, valid_top10) + datalog(self.config['datalog path']+'/'+self.config['model name'], epoch, valid_mean, valid_top10) if not hasattr(self, 'best_mean') or valid_mean < self.best_mean: self.best_mean = valid_mean log('(best so far') if self.config['save best model']: log(', saving...') - 
self.save(self.config['best model save location']) + self.save('{0}/{1}.best'.format(self.config['best model save path'], self.config['model name'])) log(' done') log(')') @@ -189,5 +180,5 @@ class Model(object): log('# Testing the model "{0}"'.format(self.tag)) (mean, top10) = self.error('test') log(' mean: {0:<15} top10: {1:<15} (saving...'.format(mean, top10)) - self.save(self.config['last model save location']) + self.save('{0}/{1}.last'.format(self.config['last model save path'], self.config['model name'])) log(' done)\n')