commit a67f85dd7a3d6ca69d9adf7cbac2cc796079d223
parent a557b939eb104ec7e0df42193e014e3137eb70f8
Author: Alex Auvolat <alex.auvolat@ens.fr>
Date:   Sat, 25 Jul 2015 14:38:06 -0400
Add monitor_freq config variable
Diffstat:
1 file changed, 10 insertions(+), 5 deletions(-)
diff --git a/train.py b/train.py
@@ -106,16 +106,21 @@ if __name__ == "__main__":
     dump_path = os.path.join('model_data', model_name) + '.pkl'
     logger.info('Dump path: %s' % dump_path)
 
-    extensions=[TrainingDataMonitoring(monitored, prefix='train', every_n_batches=10000),
+    if hasattr(config, 'monitor_freq'):
+        monitor_freq = config.monitor_freq
+    else:
+        monitor_freq = 10000
+
+    extensions=[TrainingDataMonitoring(monitored, prefix='train', every_n_batches=monitor_freq),
                 DataStreamMonitoring(valid_monitored, valid_stream,
                                      prefix='valid',
-                                     every_n_batches=10000),
-                Printing(every_n_batches=10000),
+                                     every_n_batches=monitor_freq),
+                Printing(every_n_batches=monitor_freq),
                 FinishAfter(every_n_batches=10000000),
 
                 SaveLoadParams(dump_path, cg,
                                before_training=True,        # before training -> load params
-                               every_n_batches=10000,       # every N batches -> save params
+                               every_n_batches=monitor_freq,# every N batches -> save params
                                after_epoch=True,            # after epoch -> save params
                                after_training=True,         # after training -> save params
                                ),
@@ -123,7 +128,7 @@ if __name__ == "__main__":
                 RunOnTest(model_name,
                           model,
                           stream,
-                          every_n_batches=10000),
+                          every_n_batches=monitor_freq),
                 ]
 
     if '--progress' in sys.argv: