transform

old TransE-like models
git clone https://esimon.eu/repos/transform.git

meta_model.py (3296B)


#!/usr/bin/env python2

from utils.log import *
from config import *
from model import *
import numpy
import threading

class Meta_model(object):
    """ Meta-model class: wraps several Model instances and combines their scores. """

    def __init__(self, dataset, config, pathes=None):
        self.dataset = dataset
        self.config = config  # keep the full config around, train() checks its 'threaded' flag
        self.combine_scores = config['scores combinator']
        # expand_config (from config.py) is expected to yield one concrete configuration per sub-model.
        configs = expand_config(config)
        if pathes is None:
            pathes = [ '{0}/{1}.best'.format(sub_config['best model save path'], sub_config['model name']) for sub_config in configs ]
        self.models = [ Model(dataset, sub_config, path) for sub_config, path in zip(configs, pathes) ]

    def build_test(self):
        for model in self.models:
            model.build_test()

    def build_train(self):
        for model in self.models:
            model.build_train()

    def left_scoring_function(self, relation, left, right):
        res = [ model.left_scoring_function(relation, left, right) for model in self.models ]
        return numpy.transpose(res).reshape(right.shape[0], len(self.models))

    def right_scoring_function(self, relation, left, right):
        res = [ model.right_scoring_function(relation, left, right) for model in self.models ]
        return numpy.transpose(res).reshape(left.shape[0], len(self.models))

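    # Note: each sub-model returns one score per candidate entity, so the two
    # functions above yield an (entities x models) matrix whose row i holds the
    # scores of candidate i under every sub-model. combine_scores (the config's
    # 'scores combinator', e.g. a sum or a mean across models) is expected to
    # reduce each row to a single score; error() below sorts the combined
    # scores in ascending order, so lower means better.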
    def error(self, name):
        """ Compute the mean rank, standard deviation and top-10 proportion on the given dataset split. """
        result = []
        for (relation, left, right) in self.dataset.iterate(name):
            entities = self.dataset.universe
            raw_left_scores = self.left_scoring_function(relation, left, entities)
            raw_right_scores = self.right_scoring_function(relation, entities, right)
            left_scores = self.combine_scores(raw_left_scores)
            right_scores = self.combine_scores(raw_right_scores)
            # Rank of the true entity: its 1-based position in the ascending sort of the combined scores.
            left_rank = 1+numpy.asscalar(numpy.where(numpy.argsort(left_scores)==right.indices[0])[0]) # FIXME Ugly
            right_rank = 1+numpy.asscalar(numpy.where(numpy.argsort(right_scores)==left.indices[0])[0]) # FIXME Ugly
            result.extend((left_rank, right_rank))
        mean = numpy.mean(result)
        std = numpy.std(result)
        top10 = numpy.mean(map(lambda x: x<=10, result))
        return (mean, std, top10)

    def test(self):
        """ Test the model. """
        log('# Testing the model')
        (mean, std, top10) = self.error('test')
        log(' mean: {0:<15} std: {1:<15} top10: {2:<15}\n'.format(mean, std, top10))

    def test_all(self):
        """ Test all the sub-models. """
        for model in self.models:
            model.test(save=False)

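    # Note: when config['threaded'] is set, the sub-models are trained in
    # parallel Python threads. This only buys real parallelism if the per-model
    # training spends most of its time outside the GIL (e.g. in native or GPU
    # code); otherwise the threads simply interleave.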
    def train(self):
        """ Train the model. """
        if len(self.models)==1:
            self.models[0].train()
        elif 'threaded' in self.config and self.config['threaded']:
            threads = [ threading.Thread(target=model.train, args=()) for model in self.models ]
            for (model, thread) in zip(self.models, threads):
                log('# Starting thread for model {0}\n'.format(model.tag))
                thread.start()
            log('# Waiting for children to join\n')
            for thread in threads:
                thread.join()
            log('# All children joined\n')
        else:
            for model in self.models:
                model.train()
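
For context, a minimal sketch of how this class might be driven. The Dataset module, the directory and model names, and the concrete combinator below are assumptions; only the config keys actually read by meta_model.py are taken from the file above.

#!/usr/bin/env python2
# Hypothetical driver script -- not part of the repository.
import numpy
from dataset import Dataset          # assumed module: exposes .universe and .iterate(name)
from meta_model import Meta_model

config = {
    'scores combinator': lambda scores: numpy.sum(scores, axis=1),  # reduce the per-model columns to one score per entity
    'best model save path': 'models',   # assumed location of the saved '<model name>.best' files
    'model name': 'transe',             # assumed name; expand_config is assumed to pass these keys to each sub-model
    'threaded': False,
}

dataset = Dataset('data/')              # assumed constructor
meta = Meta_model(dataset, config)      # presumably restores each sub-model from its .best file
meta.build_test()                       # builds each sub-model's scoring functions
meta.test()                             # logs mean rank, std and top-10 on the 'test' split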