commit 95ff4ae00954ae9a0ffb41297b2f65da251f183b
parent a6f3d27b7ab6409c39f4c5c2f6a6db351c2851aa
Author: Étienne Simon <esimon@esimon.eu>
Date: Fri, 13 Jun 2014 11:22:07 +0200
Add config parameter to truncate dataset
Diffstat:
3 files changed, 9 insertions(+), 5 deletions(-)
diff --git a/dataset.py b/dataset.py
@@ -6,7 +6,8 @@ import numpy
import theano
class Dataset(object):
- def __init__(self, prefix):
+ def __init__(self, prefix, config):
+ self.config = config
log('# Loading dataset "{0}"\n'.format(prefix))
with open(prefix+'/entities', 'r') as file:
self.entities = file.readlines()
@@ -21,8 +22,11 @@ class Dataset(object):
def load_file(self, prefix, name):
with open('{0}/{1}'.format(prefix, name), 'r') as file:
- content = map(lambda line: map(int, line.split('\t')), file.readlines())
- [left, relation, right] = map(list, zip(*content))
+ lines = file.readlines()
+ if 'truncate {0} set'.format(prefix) in self.config:
+ lines = lines[:self.config['truncate {0} set'.format(prefix)]]
+ content = map(lambda line: map(int, line.split('\t')), lines)
+ [left, relation, right] = map(list, zip(*content))
N = len(relation)
setattr(self, name+'_size', N)
setattr(self, name+'_right', scipy.sparse.csr_matrix(([1]*N, right, range(N+1)), shape=(N, self.number_entities), dtype=theano.config.floatX))
diff --git a/test.py b/test.py
@@ -28,7 +28,7 @@ if __name__ == '__main__':
sys.exit(1)
ModelType = Meta_model if config.get('meta', False) else Model
- data = Dataset(data)
+ data = Dataset(data, config)
model = ModelType(data, config, model_pathes)
model.build_test()
model.test_all()
diff --git a/train.py b/train.py
@@ -27,7 +27,7 @@ if __name__ == '__main__':
sys.exit(1)
ModelType = Meta_model if config.get('meta', False) else Model
- data = Dataset(data)
+ data = Dataset(data, config)
model = ModelType(data, config, model_pathes)
model.build_train()