transform

old TransE-like models
git clone https://esimon.eu/repos/transform.git
Log | Files | Refs | README

commit 1f2dd395d6133edb3375e836d93ad919b5020e28
parent 3d42500856e4bf066482b749e3a824f214791c47
Author: Étienne Simon <esimon@esimon.eu>
Date:   Wed, 30 Apr 2014 19:14:15 +0200

Add model-by-model test

Diffstat:
Mmeta_model.py | 5+++++
Mmodel.py | 14+++++++++-----
Mtest.py | 1+
3 files changed, 15 insertions(+), 5 deletions(-)

diff --git a/meta_model.py b/meta_model.py @@ -56,6 +56,11 @@ class Meta_model(object): (mean, std, top10) = self.error('test') log(' mean: {0:<15} std: {1:<15} top10: {2:<15}\n'.format(mean, std, top10)) + def test_all(self): + """ Test all the sub models. """ + for model in self.models: + model.test(save=False) + def train(self): """ Train the model. """ threads = [ threading.Thread(target=model.train, args=()) for model in self.models ] diff --git a/model.py b/model.py @@ -168,11 +168,15 @@ class Model(object): (train_mean, train_std, train_top10) = self.error('train') log('Validation model "{0}" epoch {1:<5} train mean: {2:<15} std: {3:<15} train top10: {4:<15}\n'.format(self.tag, self.epochtrain_mean, train_std, train_top10)) - def test(self): + def test(self, save=True): """ Test the model. """ - log('# Test model "{0}": begin\n'.format(self.tag)) + if save: + log('# Test model "{0}": begin\n'.format(self.tag)) + (mean, std, top10) = self.error('test') log('# Test model "{0}": mean: {1:<15} std: {2:<15} top10: {3:<15}\n'.format(self.tag, mean, std, top10)) - log('# Test model "{0}": saving...\n'.format(self.tag)) - self.save('{0}/{1}.last'.format(self.config['last model save path'], self.config['model name'])) - log('# Test model "{0}": saved\n'.format(self.tag)) + + if save: + log('# Test model "{0}": saving...\n'.format(self.tag)) + self.save('{0}/{1}.last'.format(self.config['last model save path'], self.config['model name'])) + log('# Test model "{0}": saved\n'.format(self.tag)) diff --git a/test.py b/test.py @@ -31,4 +31,5 @@ if __name__ == '__main__': data = Dataset(data) model = ModelType(data, config, model_pathes) model.build_test() + model.test_all() model.test()