commit 1f2dd395d6133edb3375e836d93ad919b5020e28
parent 3d42500856e4bf066482b749e3a824f214791c47
Author: Étienne Simon <esimon@esimon.eu>
Date:   Wed, 30 Apr 2014 19:14:15 +0200
Add model-by-model test
Diffstat:
3 files changed, 15 insertions(+), 5 deletions(-)
diff --git a/meta_model.py b/meta_model.py
@@ -56,6 +56,11 @@ class Meta_model(object):
         (mean, std, top10) = self.error('test')
         log(' mean: {0:<15} std: {1:<15} top10: {2:<15}\n'.format(mean, std, top10))
 
+    def test_all(self):
+        """ Test all the sub models. """
+        for model in self.models:
+            model.test(save=False)
+
     def train(self):
         """ Train the model. """
         threads = [ threading.Thread(target=model.train, args=()) for model in self.models ]
diff --git a/model.py b/model.py
@@ -168,11 +168,15 @@ class Model(object):
             (train_mean, train_std, train_top10) = self.error('train')
             log('Validation model "{0}" epoch {1:<5} train mean: {2:<15} std: {3:<15} train top10: {4:<15}\n'.format(self.tag, self.epochtrain_mean, train_std, train_top10))
 
-    def test(self):
+    def test(self, save=True):
         """ Test the model. """
-        log('# Test model "{0}": begin\n'.format(self.tag))
+        if save:
+            log('# Test model "{0}": begin\n'.format(self.tag))
+
         (mean, std, top10) = self.error('test')
         log('# Test model "{0}": mean: {1:<15} std: {2:<15} top10: {3:<15}\n'.format(self.tag, mean, std, top10))
-        log('# Test model "{0}": saving...\n'.format(self.tag))
-        self.save('{0}/{1}.last'.format(self.config['last model save path'], self.config['model name']))
-        log('# Test model "{0}": saved\n'.format(self.tag))
+
+        if save:
+            log('# Test model "{0}": saving...\n'.format(self.tag))
+            self.save('{0}/{1}.last'.format(self.config['last model save path'], self.config['model name']))
+            log('# Test model "{0}": saved\n'.format(self.tag))
diff --git a/test.py b/test.py
@@ -31,4 +31,5 @@ if __name__ == '__main__':
     data = Dataset(data)
     model = ModelType(data, config, model_pathes)
     model.build_test()
+    model.test_all()
     model.test()