dataset.py (3299B)
#!/usr/bin/env python2
"""Loading and batching of a triple dataset stored as one-hot sparse matrices."""

from utils.log import *
import scipy
import scipy.sparse
import numpy
import theano


class Dataset(object):
    """Triples (left, relation, right) loaded from a dataset directory.

    The directory `prefix` must contain the files `entities` and `relations`
    (one entry per line) and the files `train`, `valid` and `test`, each
    holding one triple per line as three tab-separated integers in
    left, relation, right order.

    For every split `name` the following attributes are set:
      * `<name>_size`: number of triples in the split,
      * `<name>_left`, `<name>_right`: CSR matrices of shape
        (size, number_entities) with a single 1 per row,
      * `<name>_relation`: CSR matrix of shape (size, number_relations)
        with a single 1 per row.
    """

    def __init__(self, prefix, config=None):
        """Load every split found under `prefix`.

        Args:
            prefix: path of the dataset directory.
            config: optional dict of options. The key
                'truncate <name> set' (e.g. 'truncate train set') limits
                the number of lines read for that split.
        """
        # A None sentinel avoids sharing one mutable default dict between
        # all instances; a dict passed by the caller is stored as-is.
        self.config = {} if config is None else config
        log('# Loading dataset "{0}"\n'.format(prefix))
        with open(prefix+'/entities', 'r') as handle:
            self.entities = handle.read().splitlines()
        with open(prefix+'/relations', 'r') as handle:
            self.relations = handle.read().splitlines()
        self.number_entities = len(self.entities)
        self.number_relations = len(self.relations)
        self.load_file(prefix, 'train')
        self.load_file(prefix, 'valid')
        self.load_file(prefix, 'test')
        # Identity matrix: row i is the one-hot encoding of entity i.
        self.universe = scipy.sparse.eye(self.number_entities, format='csr', dtype=theano.config.floatX)

    @staticmethod
    def _one_hot(indices, width):
        """Return a CSR matrix whose row i holds a single 1 at column indices[i]."""
        count = len(indices)
        data = numpy.ones(count)
        columns = numpy.asarray(indices, dtype=numpy.int64)
        # Exactly one stored value per row, hence indptr = 0, 1, ..., count.
        row_pointers = numpy.arange(count + 1)
        return scipy.sparse.csr_matrix((data, columns, row_pointers), shape=(count, width), dtype=theano.config.floatX)

    def load_file(self, prefix, name):
        """Read the split `name` from directory `prefix` into sparse matrices.

        Sets the attributes `<name>_size`, `<name>_left`, `<name>_relation`
        and `<name>_right` on the instance.

        Args:
            prefix: path of the dataset directory.
            name: split file name ('train', 'valid' or 'test').

        Raises:
            ValueError: if a non-blank line is not made of exactly three
                tab-separated integers.
        """
        with open('{0}/{1}'.format(prefix, name), 'r') as handle:
            lines = handle.readlines()

        # The per-split key ('truncate train set') is looked up first. The key
        # built from the directory prefix is what the previous code tested
        # for, so it is still honoured as a fallback.
        for key in ('truncate {0} set'.format(name), 'truncate {0} set'.format(prefix)):
            if key in self.config:
                lines = lines[:self.config[key]]
                break

        # Blank lines are skipped instead of failing in int().
        triples = [[int(field) for field in line.split('\t')] for line in lines if line.strip()]
        if triples:
            left, relation, right = (list(column) for column in zip(*triples))
        else:
            # Empty split: produce matrices with zero rows.
            left, relation, right = [], [], []

        setattr(self, name+'_size', len(relation))
        setattr(self, name+'_right', self._one_hot(right, self.number_entities))
        setattr(self, name+'_relation', self._one_hot(relation, self.number_relations))
        setattr(self, name+'_left', self._one_hot(left, self.number_entities))

    def training_minibatch(self, rng, batch_size):
        """Yield shuffled training batches with randomly sampled negatives.

        Args:
            rng: random generator providing `randint(low, high, size=...)`
                and `permutation(n)` (e.g. a numpy RandomState).
            batch_size: target number of triples per batch. The training set
                is split into `train_size // batch_size` nearly equal
                batches (at least one).

        Yields:
            Tuples (relation, left_positive, right_positive, left_negative,
            right_negative) of sparse matrices sharing the same row count.

        Raises:
            ValueError: if `batch_size` is not positive. As this is a
                generator, the error surfaces on the first iteration.
        """
        if batch_size <= 0:
            raise ValueError('batch_size must be positive, got {0}'.format(batch_size))
        if self.train_size == 0:
            return

        # Sampling corrupted entities
        def sample_matrix():
            sampled = rng.randint(0, self.number_entities, size=self.train_size)
            return self._one_hot(sampled, self.number_entities)
        corrupted_left = sample_matrix()
        corrupted_right = sample_matrix()

        # Shuffling training set
        order = rng.permutation(self.train_size)
        train_left = self.train_left[order, :]
        train_right = self.train_right[order, :]
        train_relation = self.train_relation[order, :]

        # Yielding batches
        # Floor division keeps the batch count an integer on Python 3 as well;
        # max() guarantees one batch when batch_size exceeds the set size.
        batch_count = max(1, int(self.train_size // batch_size))
        bounds = numpy.linspace(0, self.train_size, 1 + batch_count)
        for i in range(batch_count):
            f = int(bounds[i])
            t = int(bounds[i+1])
            left_positive = train_left[f:t]
            right_positive = train_right[f:t]
            left_negative = corrupted_left[f:t]
            right_negative = corrupted_right[f:t]
            relation = train_relation[f:t]
            yield (relation, left_positive, right_positive, left_negative, right_negative)

    def iterate(self, name):
        """Yield (relation, left, right) single-row matrices of the split `name`, in file order."""
        N = getattr(self, name+'_size')
        relation = getattr(self, name+'_relation')
        left = getattr(self, name+'_left')
        right = getattr(self, name+'_right')
        for i in range(N):
            yield (relation[i], left[i], right[i])