deep_question_answering

Implementation of the paper "Teaching Machines to Read and Comprehend" by Google DeepMind
git clone https://esimon.eu/repos/deep_question_answering.git
Log | Files | Refs | README | LICENSE

train.py (3586B)


      1 #!/usr/bin/env python
      2 
      3 import logging
      4 import numpy
      5 import sys
      6 import os
      7 import importlib
      8 
      9 import theano
     10 from theano import tensor
     11 
     12 from blocks.extensions import Printing, SimpleExtension, FinishAfter, ProgressBar
     13 from blocks.extensions.monitoring import DataStreamMonitoring, TrainingDataMonitoring
     14 from blocks.graph import ComputationGraph
     15 from blocks.main_loop import MainLoop
     16 from blocks.model import Model
     17 from blocks.algorithms import GradientDescent
     18 
     19 try:
     20     from blocks.extras.extensions.plot import Plot
     21     plot_avail = True
     22 except ImportError:
     23     plot_avail = False
     24     print "No plotting extension available."
     25 
     26 import data
     27 from paramsaveload import SaveLoadParams
     28 
     29 logging.basicConfig(level='INFO')
     30 logger = logging.getLogger(__name__)
     31 
     32 sys.setrecursionlimit(500000)
     33 
     34 if __name__ == "__main__":
     35     if len(sys.argv) != 2:
     36         print >> sys.stderr, 'Usage: %s config' % sys.argv[0]
     37         sys.exit(1)
     38     model_name = sys.argv[1]
     39     config = importlib.import_module('.%s' % model_name, 'config')
     40 
     41     # Build datastream
     42     path = os.path.join(os.getenv("DATAPATH"), "deepmind-qa/cnn/questions/training")
     43     valid_path = os.path.join(os.getenv("DATAPATH"), "deepmind-qa/cnn/questions/validation")
     44     vocab_path = os.path.join(os.getenv("DATAPATH"), "deepmind-qa/cnn/stats/training/vocab.txt")
     45 
     46     ds, train_stream = data.setup_datastream(path, vocab_path, config)
     47     _, valid_stream = data.setup_datastream(valid_path, vocab_path, config)
     48 
     49     dump_path = os.path.join("model_params", model_name+".pkl")
     50 
     51     # Build model
     52     m = config.Model(config, ds.vocab_size)
     53 
     54     # Build the Blocks stuff for training
     55     model = Model(m.sgd_cost)
     56 
     57     algorithm = GradientDescent(cost=m.sgd_cost,
     58                                 step_rule=config.step_rule,
     59                                 parameters=model.parameters)
     60 
     61     extensions = [
     62             TrainingDataMonitoring(
     63                 [v for l in m.monitor_vars for v in l],
     64                 prefix='train',
     65                 every_n_batches=config.print_freq)
     66     ]
     67     if config.save_freq is not None and dump_path is not None:
     68         extensions += [
     69             SaveLoadParams(path=dump_path,
     70                            model=model,
     71                            before_training=True,
     72                            after_training=True,
     73                            after_epoch=True,
     74                            every_n_batches=config.save_freq)
     75         ]
     76     if valid_stream is not None and config.valid_freq != -1:
     77         extensions += [
     78             DataStreamMonitoring(
     79                 [v for l in m.monitor_vars_valid for v in l],
     80                 valid_stream,
     81                 prefix='valid',
     82                 every_n_batches=config.valid_freq),
     83         ]
     84     if plot_avail:
     85         plot_channels = [['train_' + v.name for v in lt] + ['valid_' + v.name for v in lv]
     86                          for lt, lv in zip(m.monitor_vars, m.monitor_vars_valid)]
     87         extensions += [
     88             Plot(document='deepmind_qa_'+model_name,
     89                  channels=plot_channels,
     90                  # server_url='http://localhost:5006/', # If you need, change this
     91                  every_n_batches=config.print_freq)
     92         ]
     93     extensions += [
     94             Printing(every_n_batches=config.print_freq,
     95                      after_epoch=True),
     96             ProgressBar()
     97     ]
     98 
     99     main_loop = MainLoop(
    100         model=model,
    101         data_stream=train_stream,
    102         algorithm=algorithm,
    103         extensions=extensions
    104     )
    105 
    106     # Run the model !
    107     main_loop.run()
    108     main_loop.profile.report()
    109 
    110 
    111 
    112 #  vim: set sts=4 ts=4 sw=4 tw=0 et :