deep_question_answering

Implementation of the reading-comprehension models proposed by Google DeepMind in "Teaching Machines to Read and Comprehend" (Hermann et al., 2015)
git clone https://esimon.eu/repos/deep_question_answering.git
Log | Files | Refs | README | LICENSE

paramsaveload.py (887B)


      1 import logging
      2 
      3 import numpy
      4 
      5 import cPickle
      6 
      7 from blocks.extensions import SimpleExtension
      8 
      9 logging.basicConfig(level='INFO')
     10 logger = logging.getLogger('extensions.SaveLoadParams')
     11 
     12 class SaveLoadParams(SimpleExtension):
     13 	def __init__(self, path, model, **kwargs):
     14 		super(SaveLoadParams, self).__init__(**kwargs)
     15 
     16 		self.path = path
     17 		self.model = model
     18 	
     19 	def do_save(self):
     20 		with open(self.path, 'w') as f:
     21 			logger.info('Saving parameters to %s...'%self.path)
     22 			cPickle.dump(self.model.get_parameter_values(), f, protocol=cPickle.HIGHEST_PROTOCOL)
     23 	
     24 	def do_load(self):
     25 		try:
     26 			with open(self.path, 'r') as f:
     27 				logger.info('Loading parameters from %s...'%self.path)
     28 				self.model.set_parameter_values(cPickle.load(f))
     29 		except IOError:
     30 			pass
     31 
     32 	def do(self, which_callback, *args):
     33 		if which_callback == 'before_training':
     34 			self.do_load()
     35 		else:
     36 			self.do_save()
     37