commit 94da3eeeb689401e0aac84ad424397c0943a3d6f
parent 7f26ea262456312050e3d6c3e1d3aaf9039bd0b9
Author: Étienne Simon <esimon@esimon.eu>
Date: Thu, 17 Apr 2014 18:16:34 +0200
Clean output
Diffstat:
5 files changed, 41 insertions(+), 38 deletions(-)
diff --git a/dataset.py b/dataset.py
@@ -1,13 +1,13 @@
#!/usr/bin/env python2
+from utils.log import log
import scipy
import numpy
-import sys
import theano
class Dataset(object):
def __init__(self, prefix):
- print >>sys.stderr, '# Loading dataset "{0}"'.format(prefix)
+ log('# Loading dataset "{0}"\n'.format(prefix))
with open(prefix+'/embeddings', 'r') as file:
self.embeddings = file.readlines()
with open(prefix+'/relations', 'r') as file:
diff --git a/main.py b/main.py
@@ -1,5 +1,6 @@
#!/usr/bin/env python2
+from utils.log import log
import sys
import json
@@ -9,7 +10,7 @@ from relations.translations import *
if __name__ == '__main__':
if len(sys.argv)<3:
- print >>sys.stderr, 'Usage: {0} data parameters [model]'.format(sys.argv[0])
+ log('Usage: {0} data parameters [model]\n'.format(sys.argv[0]), file=sys.stderr)
sys.exit(1)
data = sys.argv[1]
config = sys.argv[2]
diff --git a/model.py b/model.py
@@ -1,6 +1,6 @@
#!/usr/bin/env python2
-import sys
+from utils.log import log
import cPickle
import numpy
import scipy
@@ -32,7 +32,7 @@ class Model(object):
hyperparameters -- hyperparameters dictionary
tag -- name of the embeddings for parameter declaration
"""
- print >>sys.stderr, '# Initialising model "{0}"'.format(tag)
+ log('# Initialising model "{0}"\n'.format(tag))
self = cls()
self.embeddings = Embeddings(hyperparameters['rng'], dataset.number_embeddings, hyperparameters['dimension'], tag+'.embeddings')
@@ -54,7 +54,7 @@ class Model(object):
hyperparameters -- hyperparameters dictionary
tag -- name of the embeddings for parameter declaration
"""
- print >>sys.stderr, '# Loading model from "{0}"'.format(filepath)
+ log('# Loading model from "{0}"\n'.format(filepath))
self = cls()
with open(filepath, 'rb') as file:
@@ -75,7 +75,7 @@ class Model(object):
def build(self):
""" Build theano functions. """
- print >>sys.stderr, '## Compiling Theano graph for model "{0}"'.format(self.tag)
+ log('## Compiling Theano graph for model "{0}"\n'.format(self.tag))
self.parameters = self.relations.parameters + self.embeddings.parameters
inputs = tuple(S.csr_matrix() for _ in xrange(5))
@@ -111,7 +111,7 @@ class Model(object):
def train(self):
""" Train the model. """
- print >>sys.stderr, '# Training the model "{0}"'.format(self.tag)
+ log('# Training the model "{0}"\n'.format(self.tag))
batch_size = self.hyperparameters['train batch size']
validation_frequency = self.hyperparameters['validation frequency']
@@ -153,28 +153,27 @@ class Model(object):
def validate(self, epoch):
""" Validate the model. """
- print >>sys.stderr, 'Validation epoch {:<5}'.format(epoch),
+ log('Validation epoch {:<5}'.format(epoch))
(valid_mean, valid_top10) = self.error('valid')
- print >>sys.stderr, 'valid mean: {0:<15} valid top10: {1:<15}'.format(valid_mean, valid_top10),
+ log(' valid mean: {0:<15} valid top10: {1:<15}'.format(valid_mean, valid_top10))
if not hasattr(self, 'best_mean') or valid_mean < self.best_mean:
self.best_mean = valid_mean
+ log('(best so far')
if self.hyperparameters['save best model']:
- print >>sys.stderr, '(best so far, saving...',
+ log(', saving...')
self.save(self.hyperparameters['best model save location'])
- print >>sys.stderr, 'done)'
- else:
- print >>sys.stderr, '(best so far)'
- else:
- print >>sys.stderr, ''
+ log(' done')
+ log(')')
if self.hyperparameters['validate on training data']:
(train_mean, train_top10) = self.error('train')
- print >>sys.stderr, 'train mean: {0:<15} train top10: {1:<15}'.format(train_mean, train_top10)
- else:
- print >>sys.stderr, ''
+ log(' train mean: {0:<15} train top10: {1:<15}'.format(train_mean, train_top10))
+ log('\n')
def test(self):
""" Test the model. """
- print >>sys.stderr, '# Testing the model "{0}"'.format(self.tag),
+ log('# Testing the model "{0}"'.format(self.tag))
(mean, top10) = self.error('test')
- print >>sys.stderr, ' mean: {0:<15} top10: {1:<15}'.format(mean, top10)
+ log(' mean: {0:<15} top10: {1:<15} (saving...'.format(mean, top10))
+ self.save(self.hyperparameters['last model save location'])
+ log(' done)\n')
diff --git a/utils/build Bordes FB15k.py b/utils/build Bordes FB15k.py
@@ -1,7 +1,9 @@
#!/usr/bin/env python2
+from __future__ import print_function
import sys
import os
+from log import log
urls = [ 'https://www.hds.utc.fr/everest/lib/exe/fetch.php?id=en%3Atranse&cache=cache&media=en:fb15k.tgz' ]
@@ -10,22 +12,22 @@ def get_archive(path):
class URLopener(urllib.FancyURLopener):
def http_error_default(self, url, fp, errcode, errmsg, headers):
- print >>sys.stderr, 'Error: {0} {1}'.format(errcode, errmsg)
+ print('Error: {0} {1}'.format(errcode, errmsg), file=sys.stderr)
raise IOError
archive = path+'/archive.tgz'
downloaded = False
for url in urls:
- print >>sys.stderr, 'Downloading dataset from "{0}"...'.format(url),
+ log('Downloading dataset from "{0}"...'.format(url))
try:
URLopener().retrieve(url, archive)
downloaded = True
- print >>sys.stderr, ' done'
+ log(' done\n')
except IOError:
pass
if not downloaded:
- print >>sys.stderr, 'Error: Unable to download dataset.'
+ print('Error: Unable to download dataset.', file=sys.stderr)
sys.exit(1)
def get_raw(path):
@@ -34,28 +36,28 @@ def get_raw(path):
get_archive(path)
- print >>sys.stderr, 'Raw files not found, extracting archive...',
+ log('Raw files not found, extracting archive...')
raw = path+'/raw'
os.mkdir(raw)
import tarfile
tar = tarfile.open(path+'/archive.tgz', 'r:gz')
tar.extractall(raw)
- print >>sys.stderr, ' done'
+ log(' done\n')
def compile_dataset(path):
get_raw(path)
prefix = path+'/raw/FB15k/freebase_mtr100_mte100-'
suffix = '.txt'
- print >>sys.stderr, 'Reading train file...',
+ log('Reading train file...')
with open(prefix+'train'+suffix, 'r') as file:
content = map(lambda line: line.rstrip('\n').split('\t'), file.readlines())
[left, relations, right] = map(set, zip(*content))
entities = left | right
- print >>sys.stderr, ' done'
+ log(' done\n')
- print >>sys.stderr, 'Writting entities...',
+ log('Writting entities...')
e2i, i2e, r2i, i2r = {}, {}, {}, {}
with open(path+'/embeddings', 'w') as file:
i=0
@@ -64,9 +66,9 @@ def compile_dataset(path):
i2e[i]=entity
file.write(entity+'\n')
i+=1
- print >>sys.stderr, ' done ({0} entities written)'.format(i)
+ log(' done ({0} entities written)\n'.format(i))
- print >>sys.stderr, 'Writting relations...',
+ log('Writting relations...')
with open(path+'/relations', 'w') as file:
i=0
for relation in relations:
@@ -74,10 +76,10 @@ def compile_dataset(path):
i2r[i]=relation
file.write(relation+'\n')
i+=1
- print >>sys.stderr, ' done ({0} relations written)'.format(i)
+ log(' done ({0} relations written)\n'.format(i))
for name in ['train', 'valid', 'test']:
- print >>sys.stderr, 'Compiling {0}...'.format(name),
+ log('Compiling {0}...'.format(name))
count = 0
with open(prefix+name+suffix, 'r') as infile:
with open(path+'/'+name, 'w') as outfile:
@@ -87,11 +89,11 @@ def compile_dataset(path):
outfile.write('{0}\t{1}\t{2}\n'.format(e2i[left], r2i[relation], e2i[right]))
else:
count+=1
- print >>sys.stderr, ' done ({0} entit{1} removed)'.format(count, 'y' if count<2 else 'ies')
+ log(' done ({0} entit{1} removed)\n'.format(count, 'y' if count<2 else 'ies'))
if __name__ == '__main__':
if len(sys.argv)<2:
- print >>sys.stderr, 'Usage: {0} path'.format(sys.argv[0])
+ print('Usage: {0} path'.format(sys.argv[0]), file=sys.stderr)
sys.exit(1)
path = sys.argv[1]
@@ -99,4 +101,4 @@ if __name__ == '__main__':
os.mkdir(path)
compile_dataset(path)
- print 'Bordes FB15k was successfully built in {0}'.format(path)
+ log('Bordes FB15k was successfully built in {0}\n'.format(path))
diff --git a/utils/build dummy dataset.py b/utils/build dummy dataset.py
@@ -1,5 +1,6 @@
#!/usr/bin/env python2
+from __future__ import print_function
import sys
import os
import shutil
@@ -34,7 +35,7 @@ def construct_dummy_dataset(kind, prefix, n_embeddings, n_relations):
if __name__ == '__main__':
if len(sys.argv)<5:
- print >>sys.stderr, 'Usage: {0} {{id, halfperm}} dataset_name n_embeddings n_relations'.format(sys.argv[0])
+ print('Usage: {0} {{id, halfperm}} dataset_name n_embeddings n_relations'.format(sys.argv[0]), file=sys.stderr)
sys.exit(1)
kind = sys.argv[1]
prefix = sys.argv[2]