model.py (9351B)
#!/usr/bin/env python2

from utils.log import *
import cPickle
import numpy
import scipy
import theano
import theano.tensor as T
import theano.sparse as S

from embeddings import *

def L1_norm(l, r):
    """ Negated L1 distance between l and r, reduced along axis 1. """
    return -T.sum(abs(l-r), axis=1)

def L2_norm(l, r):
    """ Negated L2 (Euclidean) distance between l and r, reduced along axis 1. """
    return -T.sqrt(T.sum(T.sqr(l-r), axis=1))

def cosine(l, r):
    """ Cosine similarity between l and r, reduced along axis 1.

    Keyword arguments:
    l -- left operand (symbolic tensor)
    r -- right operand (symbolic tensor)
    """
    # The norms have to be square-rooted: dividing the dot product by the
    # product of the *squared* norms only yields the cosine when both operands
    # already are unit vectors.
    # NOTE(review): a zero vector still produces a division by zero.
    l_norm = T.sqrt(T.sum(T.sqr(l), axis=1))
    r_norm = T.sqrt(T.sum(T.sqr(r), axis=1))
    return T.sum(l*r, axis=1) / (l_norm*r_norm)

class Model(object):
    """ Model class.

    Training model using SGD with a contrastive criterion.

    The model is made of an entity embeddings object and a relations object;
    together with the current epoch they are the only state written by save()
    and read back by __init__().
    """

    def __init__(self, dataset, config, filepath=None):
        """ Initialise a model.

        Keyword arguments:
        dataset -- dataset on which the model will be trained and tested
        config -- config dictionary
        filepath -- path to the Model file; when None a fresh model is built
                    from config, otherwise the model is loaded from that file
        """

        log('# Initialising model "{0}"\n'.format(config['model name']))
        self.dataset = dataset
        self.config = config
        self.tag = config['model name']

        if filepath is None:
            # The relations class is chosen by the configuration.
            Relations = config['relations']
            self.epoch = 0
            self.embeddings = Embeddings(config['rng'], dataset.number_entities, config['dimension'], self.tag+'.embeddings')
            self.relations = Relations(config['rng'], dataset.number_relations, config['dimension'], self.tag+'.relations')
        else:
            log('## Loading model from "{0}"\n'.format(filepath))
            # SECURITY: unpickling can execute arbitrary code, only load model
            # files coming from a trusted source.
            with open(filepath, 'rb') as model_file:
                # Read order must mirror the write order used in save().
                self.epoch = cPickle.load(model_file)
                self.embeddings = cPickle.load(model_file)
                self.relations = cPickle.load(model_file)

    def save(self, filepath):
        """ Save the model in a file.

        Keyword arguments:
        filepath -- path of the file to write
        """
        with open(filepath, 'wb') as model_file:
            # Protocol -1 selects the highest pickle protocol available.
            cPickle.dump(self.epoch, model_file, -1)
            cPickle.dump(self.embeddings, model_file, -1)
            cPickle.dump(self.relations, model_file, -1)

    def build_train(self):
        """ Build theano train functions.

        Sets self.train_function (one update step on a minibatch, returning the
        criterion) and self.normalise_function (applies the embeddings'
        normalisation updates).
        """
        log('## Compiling Theano graph for training model "{0}"\n'.format(self.tag))
        input_relation = S.csr_matrix("relation")
        input_left_positive = S.csr_matrix("left positive")
        input_right_positive = S.csr_matrix("right positive")
        input_left_negative = S.csr_matrix("left negative")
        input_right_negative = S.csr_matrix("right negative")
        inputs = [ input_relation, input_left_positive, input_right_positive, input_left_negative, input_right_negative ]

        left_positive = self.embeddings.embed(input_left_positive)
        right_positive = self.embeddings.embed(input_right_positive)
        left_negative = self.embeddings.embed(input_left_negative)
        right_negative = self.embeddings.embed(input_right_negative)
        relation = self.relations.lookup(input_relation)

        similarity = self.config['similarity']
        margin = self.config['margin']

        # The transformed positive left entity is shared by the positive score
        # and by the score of the triple whose right entity is corrupted.
        transformed_positive = self.relations.transform(left_positive, relation)
        transformed_negative = self.relations.transform(left_negative, relation)
        score_positive = similarity(transformed_positive, right_positive)
        score_left_negative = similarity(transformed_negative, right_positive)
        score_right_negative = similarity(transformed_positive, right_negative)
        score_left = margin + score_positive - score_left_negative
        score_right = margin + score_positive - score_right_negative

        # Hinge: only the examples violating the margin contribute.
        # NOTE(review): the hinge is active while the positive score is not at
        # least `margin` lower than the negative one, so (assuming updates()
        # descends on the criterion) lower scores are treated as better, which
        # matches the ascending argsort used in error(). The similarities
        # defined in this module return negated distances -- confirm that the
        # sign convention is the intended one.
        violating_margin_left = score_left>0
        violating_margin_right = score_right>0
        criterion_left = T.sum(violating_margin_left*score_left)
        criterion_right = T.sum(violating_margin_right*score_right)
        criterion = criterion_left + criterion_right

        self.train_function = theano.function(inputs=inputs, outputs=[criterion], updates=self.updates(criterion))
        self.normalise_function = theano.function(inputs=[], outputs=[], updates=self.embeddings.normalise_updates())

    def build_test(self):
        """ Build theano test functions.

        Sets self.left_scoring_function and self.right_scoring_function, both
        taking (relation, left, right) and returning a one-element list holding
        the scores.
        """
        log('## Compiling Theano graph for testing model "{0}"\n'.format(self.tag))
        input_relation = S.csr_matrix("relation")
        input_left = S.csr_matrix("left")
        input_right = S.csr_matrix("right")
        inputs = [ input_relation, input_left, input_right ]
        left, right = self.embeddings.embed(input_left), self.embeddings.embed(input_right)
        relation = self.relations.lookup(input_relation)

        # Mark the first dimension of the relation parameters and of the known
        # side as broadcastable so they can be combined with every candidate
        # row of the other side.
        relation = [T.addbroadcast(r, 0) for r in relation]
        left_broadcasted = T.addbroadcast(left, 0)
        right_broadcasted = T.addbroadcast(right, 0)
        similarity = self.config['similarity']
        left_score = similarity(self.relations.transform(left_broadcasted, relation), right)
        right_score = similarity(self.relations.transform(left, relation), right_broadcasted)

        self.left_scoring_function = theano.function(inputs=inputs, outputs=[left_score])
        self.right_scoring_function = theano.function(inputs=inputs, outputs=[right_score])

    def updates(self, cost):
        """ Compute the updates to perform a SGD step w.r.t. a given cost.

        Keyword arguments:
        cost -- The cost to optimise.

        Returns the relations updates followed by the embeddings updates, each
        computed with its own learning rate from the configuration.
        """
        lr_relations = self.config['relation learning rate']
        lr_embeddings = self.config['embeddings learning rate']
        return self.relations.updates(cost, lr_relations) + self.embeddings.updates(cost, lr_embeddings)

    def train(self):
        """ Train the model.

        Runs epochs until self.epoch reaches the configured number of epochs,
        validating every 'validation frequency' epochs. build_train() must have
        been called beforehand (build_test() as well for the validation).
        """
        log('# Training the model "{0}"\n'.format(self.tag))

        batch_size = self.config['train batch size']
        validation_frequency = self.config['validation frequency']
        number_epoch = self.config['number of epoch']

        while self.epoch < number_epoch:
            for (relation, left_positive, right_positive, left_negative, right_negative) in self.dataset.training_minibatch(self.config['rng'], batch_size):
                self.train_function(relation, left_positive, right_positive, left_negative, right_negative)
                # Apply the embeddings' normalisation updates after each step.
                self.normalise_function()

            self.epoch += 1
            if self.epoch % validation_frequency == 0:
                self.validate()

    @staticmethod
    def _rank(scores, target):
        """ Return the 1-based position of target in the ascending ordering of scores. """
        # scores is the raw output list of a scoring function (a single array),
        # hence argsort yields a two-dimensional array and the position is read
        # from the column indices returned by where().
        # presumably target appears exactly once; item() raises otherwise.
        positions = numpy.where(numpy.argsort(scores) == target)[1]
        return 1 + positions.item()

    def error(self, name, transform_scores=(lambda x: x)):
        """ Compute the mean rank, standard deviation and top 10 on a given data.

        Keyword arguments:
        name -- name of the data to iterate on, passed to dataset.iterate
        transform_scores -- callable applied to the output of both scoring
                            functions before ranking (identity by default)

        Returns a tuple (mean, std, top10, result) where result is the list of
        all the ranks (two per example) and top10 the proportion of ranks lower
        than or equal to 10.
        """
        result = []
        entities = self.dataset.universe
        for (relation, left, right) in self.dataset.iterate(name):
            # Score every entity as a candidate for one side of the triple, the
            # other side being fixed.
            left_scores = transform_scores(self.left_scoring_function(relation, left, entities))
            right_scores = transform_scores(self.right_scoring_function(relation, entities, right))
            # presumably left and right are one-hot sparse rows, so indices[0]
            # is the index of the expected entity -- verify against the dataset.
            left_rank = self._rank(left_scores, right.indices[0])
            right_rank = self._rank(right_scores, left.indices[0])
            result.extend((left_rank, right_rank))
        mean = numpy.mean(result)
        std = numpy.std(result)
        top10 = numpy.mean([rank <= 10 for rank in result])
        return (mean, std, top10, result)

    def validate(self):
        """ Validate the model.

        Logs the ranking statistics obtained on the 'valid' data, keeps track of
        the best mean rank seen by this instance and, depending on the
        configuration, saves the best model and reports the statistics on the
        'train' data.
        """
        datalog_path = self.config['datalog path']+'/'+self.config['model name']

        log('Validation model "{0}" epoch {1:<5}: begin\n'.format(self.tag, self.epoch))
        (valid_mean, valid_std, valid_top10, valid_distribution) = self.error('valid')
        log('Validation model "{0}" epoch {1:<5}: mean: {2:<15} valid std: {3:<15} valid top10: {4:<15}\n'.format(self.tag, self.epoch, valid_mean, valid_std, valid_top10))
        datalog(datalog_path, 'summary', self.epoch, valid_mean, valid_std, valid_top10)
        datalog(datalog_path, 'distribution', self.epoch, valid_distribution)

        # best_mean only exists once a first validation has been performed.
        if not hasattr(self, 'best_mean') or valid_mean < self.best_mean:
            self.best_mean = valid_mean
            if self.config['save best model']:
                log('Validation model "{0}" epoch {1:<5}: best model so far, saving...\n'.format(self.tag, self.epoch))
                self.save('{0}/{1}.best'.format(self.config['best model save path'], self.config['model name']))
                log('Validation model "{0}" epoch {1:<5}: saved\n'.format(self.tag, self.epoch))

        if self.config['validate on training data']:
            (train_mean, train_std, train_top10, _) = self.error('train')
            log('Validation model "{0}" epoch {1:<5} train mean: {2:<15} std: {3:<15} train top10: {4:<15}\n'.format(self.tag, self.epoch, train_mean, train_std, train_top10))

    def test(self, save=True):
        """ Test the model.

        Keyword arguments:
        save -- when true, also log the beginning of the test and save the
                model as the last model once the test is done
        """
        datalog_path = self.config['datalog path']+'/'+self.config['model name']

        if save:
            log('# Test model "{0}": begin\n'.format(self.tag))

        (mean, std, top10, distribution) = self.error('test')
        log('# Test model "{0}": mean: {1:<15} std: {2:<15} top10: {3:<15}\n'.format(self.tag, mean, std, top10))
        datalog(datalog_path, 'summary', 'test', mean, std, top10)
        datalog(datalog_path, 'distribution', 'test', distribution)

        if save:
            log('# Test model "{0}": saving...\n'.format(self.tag))
            self.save('{0}/{1}.last'.format(self.config['last model save path'], self.config['model name']))
            log('# Test model "{0}": saved\n'.format(self.tag))