commit 29a37fe435f216b21cb684071597a32779abe78a
parent 95ff4ae00954ae9a0ffb41297b2f65da251f183b
Author: Étienne Simon <esimon@esimon.eu>
Date: Fri, 13 Jun 2014 13:59:13 +0200
Datalog rank distribution
Diffstat:
1 file changed, 8 insertions(+), 6 deletions(-)
diff --git a/model.py b/model.py
@@ -149,14 +149,15 @@ class Model(object):
mean = numpy.mean(result)
std = numpy.std(result)
top10 = numpy.mean(map(lambda x: x<=10, result))
- return (mean, std, top10)
+ return (mean, std, top10, result)
def validate(self):
""" Validate the model. """
log('Validation model "{0}" epoch {1:<5}: begin\n'.format(self.tag, self.epoch))
- (valid_mean, valid_std, valid_top10) = self.error('valid')
+ (valid_mean, valid_std, valid_top10, valid_distribution) = self.error('valid')
log('Validation model "{0}" epoch {1:<5}: mean: {2:<15} valid std: {3:<15} valid top10: {4:<15}\n'.format(self.tag, self.epoch, valid_mean, valid_std, valid_top10))
- datalog(self.config['datalog path']+'/'+self.config['model name'], self.epoch, valid_mean, valid_std, valid_top10)
+ datalog(self.config['datalog path']+'/'+self.config['model name'], 'summary', self.epoch, valid_mean, valid_std, valid_top10)
+ datalog(self.config['datalog path']+'/'+self.config['model name'], 'distribution', self.epoch, valid_distribution)
if not hasattr(self, 'best_mean') or valid_mean < self.best_mean:
self.best_mean = valid_mean
if self.config['save best model']:
@@ -165,7 +166,7 @@ class Model(object):
log('Validation model "{0}" epoch {1:<5}: saved\n'.format(self.tag, self.epoch))
if self.config['validate on training data']:
- (train_mean, train_std, train_top10) = self.error('train')
+ (train_mean, train_std, train_top10, _) = self.error('train')
log('Validation model "{0}" epoch {1:<5} train mean: {2:<15} std: {3:<15} train top10: {4:<15}\n'.format(self.tag, self.epochtrain_mean, train_std, train_top10))
def test(self, save=True):
@@ -173,9 +174,10 @@ class Model(object):
if save:
log('# Test model "{0}": begin\n'.format(self.tag))
- (mean, std, top10) = self.error('test')
+ (mean, std, top10, distribution) = self.error('test')
log('# Test model "{0}": mean: {1:<15} std: {2:<15} top10: {3:<15}\n'.format(self.tag, mean, std, top10))
- datalog(self.config['datalog path']+'/'+self.config['model name'], 'test', mean, std, top10)
+ datalog(self.config['datalog path']+'/'+self.config['model name'], 'summary', 'test', mean, std, top10)
+ datalog(self.config['datalog path']+'/'+self.config['model name'], 'distribution', 'test', distribution)
if save:
log('# Test model "{0}": saving...\n'.format(self.tag))