transform

old TransE-like models
git clone https://esimon.eu/repos/transform.git
Log | Files | Refs | README

commit 29a37fe435f216b21cb684071597a32779abe78a
parent 95ff4ae00954ae9a0ffb41297b2f65da251f183b
Author: Étienne Simon <esimon@esimon.eu>
Date:   Fri, 13 Jun 2014 13:59:13 +0200

Datalog rank distribution

Diffstat:
Mmodel.py | 14++++++++------
1 file changed, 8 insertions(+), 6 deletions(-)

diff --git a/model.py b/model.py @@ -149,14 +149,15 @@ class Model(object): mean = numpy.mean(result) std = numpy.std(result) top10 = numpy.mean(map(lambda x: x<=10, result)) - return (mean, std, top10) + return (mean, std, top10, result) def validate(self): """ Validate the model. """ log('Validation model "{0}" epoch {1:<5}: begin\n'.format(self.tag, self.epoch)) - (valid_mean, valid_std, valid_top10) = self.error('valid') + (valid_mean, valid_std, valid_top10, valid_distribution) = self.error('valid') log('Validation model "{0}" epoch {1:<5}: mean: {2:<15} valid std: {3:<15} valid top10: {4:<15}\n'.format(self.tag, self.epoch, valid_mean, valid_std, valid_top10)) - datalog(self.config['datalog path']+'/'+self.config['model name'], self.epoch, valid_mean, valid_std, valid_top10) + datalog(self.config['datalog path']+'/'+self.config['model name'], 'summary', self.epoch, valid_mean, valid_std, valid_top10) + datalog(self.config['datalog path']+'/'+self.config['model name'], 'distribution', self.epoch, valid_distribution) if not hasattr(self, 'best_mean') or valid_mean < self.best_mean: self.best_mean = valid_mean if self.config['save best model']: @@ -165,7 +166,7 @@ class Model(object): log('Validation model "{0}" epoch {1:<5}: saved\n'.format(self.tag, self.epoch)) if self.config['validate on training data']: - (train_mean, train_std, train_top10) = self.error('train') + (train_mean, train_std, train_top10, _) = self.error('train') log('Validation model "{0}" epoch {1:<5} train mean: {2:<15} std: {3:<15} train top10: {4:<15}\n'.format(self.tag, self.epochtrain_mean, train_std, train_top10)) def test(self, save=True): @@ -173,9 +174,10 @@ class Model(object): if save: log('# Test model "{0}": begin\n'.format(self.tag)) - (mean, std, top10) = self.error('test') + (mean, std, top10, distribution) = self.error('test') log('# Test model "{0}": mean: {1:<15} std: {2:<15} top10: {3:<15}\n'.format(self.tag, mean, std, top10)) - datalog(self.config['datalog path']+'/'+self.config['model name'], 'test', mean, std, top10) + datalog(self.config['datalog path']+'/'+self.config['model name'], 'summary', 'test', mean, std, top10) + datalog(self.config['datalog path']+'/'+self.config['model name'], 'distribution', 'test', distribution) if save: log('# Test model "{0}": saving...\n'.format(self.tag))