transform

old TransE-like models
git clone https://esimon.eu/repos/transform.git
Log | Files | Refs | README

model.py (9351B)


      1 #!/usr/bin/env python2
      2 
      3 from utils.log import *
      4 import cPickle
      5 import numpy
      6 import scipy
      7 import theano
      8 import theano.tensor as T
      9 import theano.sparse as S
     10 
     11 from embeddings import *
     12 
     13 def L1_norm(l, r):
     14     return -T.sum(abs(l-r), axis=1)
     15 
     16 def L2_norm(l, r):
     17     return -T.sqrt(T.sum(T.sqr(l-r), axis=1))
     18 
     19 def cosine(l, r):
     20     l_norm = T.sum(T.sqr(l), axis=1)
     21     r_norm = T.sum(T.sqr(r), axis=1)
     22     return T.sum(l*r, axis=1) / (l_norm*r_norm)
     23 
     24 class Model(object):
     25     """ Model class.
     26 
     27     Training model using SGD with a contrastive criterion.
     28     """
     29 
     30     def __init__(self, dataset, config, filepath=None):
     31         """ Initialise a model.
     32 
     33         Keyword arguments:
     34         dataset -- dataset on which the model will be trained and tested
     35         config -- config dictionary
     36         filepath -- path to the Model file
     37         """
     38 
     39         log('# Initialising model "{0}"\n'.format(config['model name']))
     40         self.dataset = dataset
     41         self.config = config
     42         self.tag = config['model name']
     43 
     44         if filepath is None:
     45             Relations = config['relations']
     46             self.epoch = 0
     47             self.embeddings = Embeddings(config['rng'], dataset.number_entities, config['dimension'], self.tag+'.embeddings')
     48             self.relations = Relations(config['rng'], dataset.number_relations, config['dimension'], self.tag+'.relations')
     49         else:
     50             log('## Loading model from "{0}"\n'.format(filepath))
     51             with open(filepath, 'rb') as file:
     52                 self.epoch = cPickle.load(file)
     53                 self.embeddings = cPickle.load(file)
     54                 self.relations = cPickle.load(file)
     55 
     56     def save(self, filepath):
     57         """ Save the model in a file. """
     58         with open(filepath, 'wb') as file:
     59             cPickle.dump(self.epoch, file, -1)
     60             cPickle.dump(self.embeddings, file, -1)
     61             cPickle.dump(self.relations, file, -1)
     62 
     63     def build_train(self):
     64         """ Build theano train functions. """
     65         log('## Compiling Theano graph for training model "{0}"\n'.format(self.tag))
     66         input_relation = S.csr_matrix("relation")
     67         input_left_positive = S.csr_matrix("left positive")
     68         input_right_positive = S.csr_matrix("right positive")
     69         input_left_negative = S.csr_matrix("left negative")
     70         input_right_negative = S.csr_matrix("right negative")
     71         inputs = [ input_relation, input_left_positive, input_right_positive, input_left_negative, input_right_negative ]
     72         left_positive, right_positive = self.embeddings.embed(input_left_positive), self.embeddings.embed(input_right_positive)
     73         left_negative, right_negative = self.embeddings.embed(input_left_negative), self.embeddings.embed(input_right_negative)
     74         relation = self.relations.lookup(input_relation)
     75 
     76         score_positive = self.config['similarity'](self.relations.transform(left_positive, relation), right_positive)
     77         score_left_negative = self.config['similarity'](self.relations.transform(left_negative, relation), right_positive)
     78         score_right_negative = self.config['similarity'](self.relations.transform(left_positive, relation), right_negative)
     79         score_left = self.config['margin'] + score_positive - score_left_negative
     80         score_right = self.config['margin'] + score_positive - score_right_negative
     81 
     82         violating_margin_left = score_left>0
     83         violating_margin_right = score_right>0
     84         criterion_left = T.sum(violating_margin_left*score_left)
     85         criterion_right = T.sum(violating_margin_right*score_right)
     86         criterion = criterion_left + criterion_right
     87 
     88         self.train_function = theano.function(inputs=inputs, outputs=[criterion], updates=self.updates(criterion))
     89         self.normalise_function = theano.function(inputs=[], outputs=[], updates=self.embeddings.normalise_updates())
     90 
     91     def build_test(self):
     92         """ Build theano test functions. """
     93         log('## Compiling Theano graph for testing model "{0}"\n'.format(self.tag))
     94         input_relation = S.csr_matrix("relation")
     95         input_left = S.csr_matrix("left")
     96         input_right = S.csr_matrix("right")
     97         inputs = [ input_relation, input_left, input_right ]
     98         left, right = self.embeddings.embed(input_left), self.embeddings.embed(input_right)
     99         relation = self.relations.lookup(input_relation)
    100 
    101         relation = map(lambda r: T.addbroadcast(r, 0), relation)
    102         left_broadcasted = T.addbroadcast(left, 0)
    103         right_broadcasted = T.addbroadcast(right, 0)
    104         left_score = self.config['similarity'](self.relations.transform(left_broadcasted, relation), right)
    105         right_score = self.config['similarity'](self.relations.transform(left, relation), right_broadcasted)
    106 
    107         self.left_scoring_function = theano.function(inputs=inputs, outputs=[left_score])
    108         self.right_scoring_function = theano.function(inputs=inputs, outputs=[right_score])
    109 
    110     def updates(self, cost):
    111         """ Compute the updates to perform a SGD step w.r.t. a given cost.
    112 
    113         Keyword arguments:
    114         cost -- The cost to optimise.
    115         """
    116         lr_relations = self.config['relation learning rate']
    117         lr_embeddings = self.config['embeddings learning rate']
    118         return self.relations.updates(cost, lr_relations) + self.embeddings.updates(cost, lr_embeddings)
    119 
    120     def train(self):
    121         """ Train the model. """
    122         log('# Training the model "{0}"\n'.format(self.tag))
    123 
    124         batch_size = self.config['train batch size']
    125         validation_frequency = self.config['validation frequency']
    126         number_epoch = self.config['number of epoch']
    127 
    128         while self.epoch < number_epoch:
    129             for (relation, left_positive, right_positive, left_negative, right_negative) in self.dataset.training_minibatch(self.config['rng'], batch_size):
    130                 self.train_function(relation, left_positive, right_positive, left_negative, right_negative)
    131                 self.normalise_function()
    132 
    133             self.epoch += 1
    134             if self.epoch % validation_frequency == 0:
    135                 self.validate()
    136 
    137     def error(self, name, transform_scores=(lambda x: x)):
    138         """ Compute the mean rank, standard deviation and top 10 on a given data. """
    139         result = []
    140         for (relation, left, right) in self.dataset.iterate(name):
    141             entities = self.dataset.universe
    142             left_scores = self.left_scoring_function(relation, left, entities)
    143             right_scores = self.right_scoring_function(relation, entities, right)
    144             left_scores = transform_scores(left_scores)
    145             right_scores = transform_scores(right_scores)
    146             left_rank = 1+numpy.asscalar(numpy.where(numpy.argsort(left_scores)==right.indices[0])[1]) # FIXME Ugly
    147             right_rank = 1+numpy.asscalar(numpy.where(numpy.argsort(right_scores)==left.indices[0])[1]) # FIXME Ugly
    148             result.extend((left_rank, right_rank))
    149         mean = numpy.mean(result)
    150         std = numpy.std(result)
    151         top10 = numpy.mean(map(lambda x: x<=10, result))
    152         return (mean, std, top10, result)
    153 
    154     def validate(self):
    155         """ Validate the model. """
    156         log('Validation model "{0}" epoch {1:<5}: begin\n'.format(self.tag, self.epoch))
    157         (valid_mean, valid_std, valid_top10, valid_distribution) = self.error('valid')
    158         log('Validation model "{0}" epoch {1:<5}: mean: {2:<15} valid std: {3:<15} valid top10: {4:<15}\n'.format(self.tag, self.epoch, valid_mean, valid_std, valid_top10))
    159         datalog(self.config['datalog path']+'/'+self.config['model name'], 'summary', self.epoch, valid_mean, valid_std, valid_top10)
    160         datalog(self.config['datalog path']+'/'+self.config['model name'], 'distribution', self.epoch, valid_distribution)
    161         if not hasattr(self, 'best_mean') or valid_mean < self.best_mean:
    162             self.best_mean = valid_mean
    163             if self.config['save best model']:
    164                 log('Validation model "{0}" epoch {1:<5}: best model so far, saving...\n'.format(self.tag, self.epoch))
    165                 self.save('{0}/{1}.best'.format(self.config['best model save path'], self.config['model name']))
    166                 log('Validation model "{0}" epoch {1:<5}: saved\n'.format(self.tag, self.epoch))
    167 
    168         if self.config['validate on training data']:
    169             (train_mean, train_std, train_top10, _) = self.error('train')
    170             log('Validation model "{0}" epoch {1:<5} train mean: {2:<15} std: {3:<15} train top10: {4:<15}\n'.format(self.tag, self.epoch, train_mean, train_std, train_top10))
    171 
    172     def test(self, save=True):
    173         """ Test the model. """
    174         if save:
    175             log('# Test model "{0}": begin\n'.format(self.tag))
    176 
    177         (mean, std, top10, distribution) = self.error('test')
    178         log('# Test model "{0}": mean: {1:<15} std: {2:<15} top10: {3:<15}\n'.format(self.tag, mean, std, top10))
    179         datalog(self.config['datalog path']+'/'+self.config['model name'], 'summary', 'test', mean, std, top10)
    180         datalog(self.config['datalog path']+'/'+self.config['model name'], 'distribution', 'test', distribution)
    181 
    182         if save:
    183             log('# Test model "{0}": saving...\n'.format(self.tag))
    184             self.save('{0}/{1}.last'.format(self.config['last model save path'], self.config['model name']))
    185             log('# Test model "{0}": saved\n'.format(self.tag))