train.py (3586B)
#!/usr/bin/env python
"""Train a model described by a ``config`` sub-module on the deepmind-qa CNN questions.

Usage::

    train.py <config-name>

``<config-name>`` is imported as ``config.<config-name>``.  From the way it is
used below, that module must provide ``Model``, ``step_rule``, ``print_freq``,
``save_freq`` and ``valid_freq``.

The ``DATAPATH`` environment variable must point at the directory that
contains the ``deepmind-qa`` dataset tree.
"""

import logging
import numpy
import sys
import os
import importlib

import theano
from theano import tensor

from blocks.extensions import Printing, SimpleExtension, FinishAfter, ProgressBar
from blocks.extensions.monitoring import DataStreamMonitoring, TrainingDataMonitoring
from blocks.graph import ComputationGraph
from blocks.main_loop import MainLoop
from blocks.model import Model
from blocks.algorithms import GradientDescent

try:
    from blocks.extras.extensions.plot import Plot
    plot_avail = True
except ImportError:
    plot_avail = False
    # Written with sys.stdout.write rather than the ``print`` statement so the
    # file parses on both Python 2 and Python 3.
    sys.stdout.write("No plotting extension available.\n")

import data
from paramsaveload import SaveLoadParams

logging.basicConfig(level='INFO')
logger = logging.getLogger(__name__)

# NOTE(review): presumably raised so that very deep computation graphs can be
# traversed/pickled without hitting the default limit -- confirm.
sys.setrecursionlimit(500000)


def _dataset_paths():
    """Return ``(train_path, valid_path, vocab_path)`` rooted at ``$DATAPATH``.

    Exits with status 1 and a message on stderr when ``DATAPATH`` is not set,
    instead of letting ``os.path.join(None, ...)`` fail with an obscure error.
    """
    root = os.getenv("DATAPATH")
    if root is None:
        sys.stderr.write(
            "DATAPATH environment variable is not set; it must point to the "
            "directory containing 'deepmind-qa'.\n")
        sys.exit(1)
    return (
        os.path.join(root, "deepmind-qa/cnn/questions/training"),
        os.path.join(root, "deepmind-qa/cnn/questions/validation"),
        os.path.join(root, "deepmind-qa/cnn/stats/training/vocab.txt"),
    )


def _build_extensions(config, m, model, model_name, valid_stream, dump_path):
    """Assemble the list of main-loop extensions for one training run.

    :param config: imported config module (reads ``print_freq``, ``save_freq``
        and ``valid_freq``).
    :param m: object built by ``config.Model``; ``monitor_vars`` and
        ``monitor_vars_valid`` are iterated as iterables of iterables of
        named variables.
    :param model: the ``blocks.model.Model`` handed to ``SaveLoadParams``.
    :param model_name: used to name the plot document.
    :param valid_stream: validation data stream, or ``None`` to skip
        validation monitoring.
    :param dump_path: parameter file path, or ``None`` to skip saving/loading.
    :returns: list of extension instances, in the order they were added.
    """
    extensions = [
        TrainingDataMonitoring(
            [var for group in m.monitor_vars for var in group],
            prefix='train',
            every_n_batches=config.print_freq),
    ]

    if config.save_freq is not None and dump_path is not None:
        extensions.append(
            SaveLoadParams(path=dump_path,
                           model=model,
                           before_training=True,
                           after_training=True,
                           after_epoch=True,
                           every_n_batches=config.save_freq))

    # A valid_freq of -1 in the config disables validation monitoring.
    if valid_stream is not None and config.valid_freq != -1:
        extensions.append(
            DataStreamMonitoring(
                [var for group in m.monitor_vars_valid for var in group],
                valid_stream,
                prefix='valid',
                every_n_batches=config.valid_freq))

    if plot_avail:
        # One plot per pair of (train group, valid group); channel names carry
        # the same 'train_' / 'valid_' prefixes as the monitoring extensions.
        plot_channels = [
            ['train_' + v.name for v in train_group] +
            ['valid_' + v.name for v in valid_group]
            for train_group, valid_group
            in zip(m.monitor_vars, m.monitor_vars_valid)
        ]
        extensions.append(
            Plot(document='deepmind_qa_' + model_name,
                 channels=plot_channels,
                 # server_url='http://localhost:5006/', # If you need, change this
                 every_n_batches=config.print_freq))

    extensions.extend([
        Printing(every_n_batches=config.print_freq,
                 after_epoch=True),
        ProgressBar(),
    ])
    return extensions


def main(argv):
    """Run training for the config named in ``argv[1]``.

    :param argv: full argument vector (program name first); must have exactly
        two entries, otherwise a usage line is written to stderr and the
        process exits with status 1.
    """
    if len(argv) != 2:
        sys.stderr.write('Usage: %s config\n' % argv[0])
        sys.exit(1)
    model_name = argv[1]
    config = importlib.import_module('.%s' % model_name, 'config')

    # Build datastream
    path, valid_path, vocab_path = _dataset_paths()

    ds, train_stream = data.setup_datastream(path, vocab_path, config)
    _, valid_stream = data.setup_datastream(valid_path, vocab_path, config)

    dump_path = os.path.join("model_params", model_name + ".pkl")

    # Build model
    m = config.Model(config, ds.vocab_size)

    # Build the Blocks stuff for training
    model = Model(m.sgd_cost)

    algorithm = GradientDescent(cost=m.sgd_cost,
                                step_rule=config.step_rule,
                                parameters=model.parameters)

    extensions = _build_extensions(config, m, model, model_name,
                                   valid_stream, dump_path)

    main_loop = MainLoop(
        model=model,
        data_stream=train_stream,
        algorithm=algorithm,
        extensions=extensions
    )

    # Run the model !
    main_loop.run()
    main_loop.profile.report()


if __name__ == "__main__":
    main(sys.argv)


# vim: set sts=4 ts=4 sw=4 tw=0 et :