transform

old TransE-like models
git clone https://esimon.eu/repos/transform.git
Log | Files | Refs | README

commit 94da3eeeb689401e0aac84ad424397c0943a3d6f
parent 7f26ea262456312050e3d6c3e1d3aaf9039bd0b9
Author: Étienne Simon <esimon@esimon.eu>
Date:   Thu, 17 Apr 2014 18:16:34 +0200

Clean output

Diffstat:
M dataset.py | 4 ++--
M main.py | 3 ++-
M model.py | 35 +++++++++++++++++------------------
M utils/build Bordes FB15k.py | 34 ++++++++++++++++++----------------
M utils/build dummy dataset.py | 3 ++-
5 files changed, 41 insertions(+), 38 deletions(-)

diff --git a/dataset.py b/dataset.py @@ -1,13 +1,13 @@ #!/usr/bin/env python2 +from utils.log import log import scipy import numpy -import sys import theano class Dataset(object): def __init__(self, prefix): - print >>sys.stderr, '# Loading dataset "{0}"'.format(prefix) + log('# Loading dataset "{0}"\n'.format(prefix)) with open(prefix+'/embeddings', 'r') as file: self.embeddings = file.readlines() with open(prefix+'/relations', 'r') as file: diff --git a/main.py b/main.py @@ -1,5 +1,6 @@ #!/usr/bin/env python2 +from utils.log import log import sys import json @@ -9,7 +10,7 @@ from relations.translations import * if __name__ == '__main__': if len(sys.argv)<3: - print >>sys.stderr, 'Usage: {0} data parameters [model]'.format(sys.argv[0]) + log('Usage: {0} data parameters [model]\n'.format(sys.argv[0]), file=sys.stderr) sys.exit(1) data = sys.argv[1] config = sys.argv[2] diff --git a/model.py b/model.py @@ -1,6 +1,6 @@ #!/usr/bin/env python2 -import sys +from utils.log import log import cPickle import numpy import scipy @@ -32,7 +32,7 @@ class Model(object): hyperparameters -- hyperparameters dictionary tag -- name of the embeddings for parameter declaration """ - print >>sys.stderr, '# Initialising model "{0}"'.format(tag) + log('# Initialising model "{0}"\n'.format(tag)) self = cls() self.embeddings = Embeddings(hyperparameters['rng'], dataset.number_embeddings, hyperparameters['dimension'], tag+'.embeddings') @@ -54,7 +54,7 @@ class Model(object): hyperparameters -- hyperparameters dictionary tag -- name of the embeddings for parameter declaration """ - print >>sys.stderr, '# Loading model from "{0}"'.format(filepath) + log('# Loading model from "{0}"\n'.format(filepath)) self = cls() with open(filepath, 'rb') as file: @@ -75,7 +75,7 @@ class Model(object): def build(self): """ Build theano functions. 
""" - print >>sys.stderr, '## Compiling Theano graph for model "{0}"'.format(self.tag) + log('## Compiling Theano graph for model "{0}"\n'.format(self.tag)) self.parameters = self.relations.parameters + self.embeddings.parameters inputs = tuple(S.csr_matrix() for _ in xrange(5)) @@ -111,7 +111,7 @@ class Model(object): def train(self): """ Train the model. """ - print >>sys.stderr, '# Training the model "{0}"'.format(self.tag) + log('# Training the model "{0}"\n'.format(self.tag)) batch_size = self.hyperparameters['train batch size'] validation_frequency = self.hyperparameters['validation frequency'] @@ -153,28 +153,27 @@ class Model(object): def validate(self, epoch): """ Validate the model. """ - print >>sys.stderr, 'Validation epoch {:<5}'.format(epoch), + log('Validation epoch {:<5}'.format(epoch)) (valid_mean, valid_top10) = self.error('valid') - print >>sys.stderr, 'valid mean: {0:<15} valid top10: {1:<15}'.format(valid_mean, valid_top10), + log(' valid mean: {0:<15} valid top10: {1:<15}'.format(valid_mean, valid_top10)) if not hasattr(self, 'best_mean') or valid_mean < self.best_mean: self.best_mean = valid_mean + log('(best so far') if self.hyperparameters['save best model']: - print >>sys.stderr, '(best so far, saving...', + log(', saving...') self.save(self.hyperparameters['best model save location']) - print >>sys.stderr, 'done)' - else: - print >>sys.stderr, '(best so far)' - else: - print >>sys.stderr, '' + log(' done') + log(')') if self.hyperparameters['validate on training data']: (train_mean, train_top10) = self.error('train') - print >>sys.stderr, 'train mean: {0:<15} train top10: {1:<15}'.format(train_mean, train_top10) - else: - print >>sys.stderr, '' + log(' train mean: {0:<15} train top10: {1:<15}'.format(train_mean, train_top10)) + log('\n') def test(self): """ Test the model. 
""" - print >>sys.stderr, '# Testing the model "{0}"'.format(self.tag), + log('# Testing the model "{0}"'.format(self.tag)) (mean, top10) = self.error('test') - print >>sys.stderr, ' mean: {0:<15} top10: {1:<15}'.format(mean, top10) + log(' mean: {0:<15} top10: {1:<15} (saving...'.format(mean, top10)) + self.save(self.hyperparameters['last model save location']) + log(' done)\n') diff --git a/utils/build Bordes FB15k.py b/utils/build Bordes FB15k.py @@ -1,7 +1,9 @@ #!/usr/bin/env python2 +from __future__ import print_function import sys import os +from log import log urls = [ 'https://www.hds.utc.fr/everest/lib/exe/fetch.php?id=en%3Atranse&cache=cache&media=en:fb15k.tgz' ] @@ -10,22 +12,22 @@ def get_archive(path): class URLopener(urllib.FancyURLopener): def http_error_default(self, url, fp, errcode, errmsg, headers): - print >>sys.stderr, 'Error: {0} {1}'.format(errcode, errmsg) + print('Error: {0} {1}'.format(errcode, errmsg), file=sys.stderr) raise IOError archive = path+'/archive.tgz' downloaded = False for url in urls: - print >>sys.stderr, 'Downloading dataset from "{0}"...'.format(url), + log('Downloading dataset from "{0}"...'.format(url)) try: URLopener().retrieve(url, archive) downloaded = True - print >>sys.stderr, ' done' + log(' done\n') except IOError: pass if not downloaded: - print >>sys.stderr, 'Error: Unable to download dataset.' 
+ print('Error: Unable to download dataset.', file=sys.stderr) sys.exit(1) def get_raw(path): @@ -34,28 +36,28 @@ def get_raw(path): get_archive(path) - print >>sys.stderr, 'Raw files not found, extracting archive...', + log('Raw files not found, extracting archive...') raw = path+'/raw' os.mkdir(raw) import tarfile tar = tarfile.open(path+'/archive.tgz', 'r:gz') tar.extractall(raw) - print >>sys.stderr, ' done' + log(' done\n') def compile_dataset(path): get_raw(path) prefix = path+'/raw/FB15k/freebase_mtr100_mte100-' suffix = '.txt' - print >>sys.stderr, 'Reading train file...', + log('Reading train file...') with open(prefix+'train'+suffix, 'r') as file: content = map(lambda line: line.rstrip('\n').split('\t'), file.readlines()) [left, relations, right] = map(set, zip(*content)) entities = left | right - print >>sys.stderr, ' done' + log(' done\n') - print >>sys.stderr, 'Writting entities...', + log('Writting entities...') e2i, i2e, r2i, i2r = {}, {}, {}, {} with open(path+'/embeddings', 'w') as file: i=0 @@ -64,9 +66,9 @@ def compile_dataset(path): i2e[i]=entity file.write(entity+'\n') i+=1 - print >>sys.stderr, ' done ({0} entities written)'.format(i) + log(' done ({0} entities written)\n'.format(i)) - print >>sys.stderr, 'Writting relations...', + log('Writting relations...') with open(path+'/relations', 'w') as file: i=0 for relation in relations: @@ -74,10 +76,10 @@ def compile_dataset(path): i2r[i]=relation file.write(relation+'\n') i+=1 - print >>sys.stderr, ' done ({0} relations written)'.format(i) + log(' done ({0} relations written)\n'.format(i)) for name in ['train', 'valid', 'test']: - print >>sys.stderr, 'Compiling {0}...'.format(name), + log('Compiling {0}...'.format(name)) count = 0 with open(prefix+name+suffix, 'r') as infile: with open(path+'/'+name, 'w') as outfile: @@ -87,11 +89,11 @@ def compile_dataset(path): outfile.write('{0}\t{1}\t{2}\n'.format(e2i[left], r2i[relation], e2i[right])) else: count+=1 - print >>sys.stderr, ' done ({0} 
entit{1} removed)'.format(count, 'y' if count<2 else 'ies') + log(' done ({0} entit{1} removed)\n'.format(count, 'y' if count<2 else 'ies')) if __name__ == '__main__': if len(sys.argv)<2: - print >>sys.stderr, 'Usage: {0} path'.format(sys.argv[0]) + print('Usage: {0} path'.format(sys.argv[0]), file=sys.stderr) sys.exit(1) path = sys.argv[1] @@ -99,4 +101,4 @@ if __name__ == '__main__': os.mkdir(path) compile_dataset(path) - print 'Bordes FB15k was successfully built in {0}'.format(path) + log('Bordes FB15k was successfully built in {0}\n'.format(path)) diff --git a/utils/build dummy dataset.py b/utils/build dummy dataset.py @@ -1,5 +1,6 @@ #!/usr/bin/env python2 +from __future__ import print_function import sys import os import shutil @@ -34,7 +35,7 @@ def construct_dummy_dataset(kind, prefix, n_embeddings, n_relations): if __name__ == '__main__': if len(sys.argv)<5: - print >>sys.stderr, 'Usage: {0} {{id, halfperm}} dataset_name n_embeddings n_relations'.format(sys.argv[0]) + print('Usage: {0} {{id, halfperm}} dataset_name n_embeddings n_relations'.format(sys.argv[0]), file=sys.stderr) sys.exit(1) kind = sys.argv[1] prefix = sys.argv[2]