meta_model.py (3296B)
1 #!/usr/bin/env python2 2 3 from utils.log import * 4 from config import * 5 from model import * 6 import numpy 7 import threading 8 9 class Meta_model(object): 10 """ Meta-model class. """ 11 12 def __init__(self, dataset, config, pathes=None): 13 self.dataset = dataset 14 self.combine_scores = config['scores combinator'] 15 configs = expand_config(config) 16 if pathes is None: 17 pathes = [ '{0}/{1}.best'.format(config['best model save path'], config['model name']) for config in configs ] 18 self.models = [ Model(dataset, config, path) for config, path in zip(configs, pathes) ] 19 20 def build_test(self): 21 for model in self.models: 22 model.build_test() 23 24 def build_train(self): 25 for model in self.models: 26 model.build_train() 27 28 def left_scoring_function(self, relation, left, right): 29 res = [ model.left_scoring_function(relation, left, right) for model in self.models ] 30 return numpy.transpose(res).reshape(right.shape[0], len(self.models)) 31 32 def right_scoring_function(self, relation, left, right): 33 res = [ model.right_scoring_function(relation, left, right) for model in self.models ] 34 return numpy.transpose(res).reshape(left.shape[0], len(self.models)) 35 36 def error(self, name): 37 """ Compute the mean rank, standard deviation and top 10 on a given data. """ 38 result = [] 39 for (relation, left, right) in self.dataset.iterate(name): 40 entities = self.dataset.universe 41 raw_left_scores = self.left_scoring_function(relation, left, entities) 42 raw_right_scores = self.right_scoring_function(relation, entities, right) 43 left_scores = self.combine_scores(raw_left_scores) 44 right_scores = self.combine_scores(raw_right_scores) 45 left_rank = 1+numpy.asscalar(numpy.where(numpy.argsort(left_scores)==right.indices[0])[0]) # FIXME Ugly 46 right_rank = 1+numpy.asscalar(numpy.where(numpy.argsort(right_scores)==left.indices[0])[0]) # FIXME Ugly 47 result.extend((left_rank, right_rank)) 48 mean = numpy.mean(result) 49 std = numpy.std(result) 50 top10 = numpy.mean(map(lambda x: x<=10, result)) 51 return (mean, std, top10) 52 53 def test(self): 54 """ Test the model. """ 55 log('# Testing the model') 56 (mean, std, top10) = self.error('test') 57 log(' mean: {0:<15} std: {1:<15} top10: {2:<15}\n'.format(mean, std, top10)) 58 59 def test_all(self): 60 """ Test all the sub models. """ 61 for model in self.models: 62 model.test(save=False) 63 64 def train(self): 65 """ Train the model. """ 66 if len(self.models)==1: 67 self.models[0].train() 68 elif 'threaded' in self.config and self.config['threaded']: 69 threads = [ threading.Thread(target=model.train, args=()) for model in self.models ] 70 for (model, thread) in zip(self.models, threads): 71 log('# Starting thread for model {0}\n'.format(model.tag)) 72 thread.start() 73 log('# Waiting for children to join\n') 74 for thread in threads: 75 thread.join() 76 log('# All children joined\n') 77 else: 78 for model in self.models: 79 model.train()
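A minimal usage sketch (not part of the original file). The config keys ('scores combinator', 'best model save path', 'model name', 'threaded') are the ones this class actually reads; the Dataset import, its constructor arguments, and the summing combinator are illustrative assumptions.

#!/usr/bin/env python2
# Hypothetical driver script -- a sketch only.
import numpy
from dataset import Dataset          # assumed module and class name
from meta_model import Meta_model

config = {
    'scores combinator': lambda scores: numpy.sum(scores, axis=1),  # collapse the (n_entities, n_models) score matrix to one score per entity
    'best model save path': 'models',
    'model name': 'example',
    'threaded': True,
    # ... plus whatever expand_config() needs to derive the per-model configs
}

dataset = Dataset('data/example')    # assumed constructor
meta = Meta_model(dataset, config)
meta.build_train()
meta.train()
meta.build_test()
meta.test()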