paramsaveload.py (887B)
"""Blocks extension that saves/restores model parameter values with pickle."""
import logging

import numpy

try:
    import cPickle
except ImportError:  # Python 3: cPickle was merged into the pickle module.
    import pickle as cPickle

from blocks.extensions import SimpleExtension

logging.basicConfig(level='INFO')
logger = logging.getLogger('extensions.SaveLoadParams')


class SaveLoadParams(SimpleExtension):
    """Load parameter values before training; save them on other callbacks.

    Parameters
    ----------
    path : str
        File the pickled parameter values are written to and read from.
    model
        Object exposing ``get_parameter_values()`` and
        ``set_parameter_values(values)`` (presumably a ``blocks.model.Model``
        -- confirm against callers).
    **kwargs
        Forwarded unchanged to ``SimpleExtension.__init__`` (presumably the
        flags selecting which callbacks trigger ``do``).
    """

    def __init__(self, path, model, **kwargs):
        super(SaveLoadParams, self).__init__(**kwargs)

        self.path = path
        self.model = model

    def do_save(self):
        """Pickle the model's current parameter values to ``self.path``."""
        # Binary mode is required: HIGHEST_PROTOCOL is a binary format, which
        # text mode corrupts on Windows (Py2) and rejects outright on Py3.
        with open(self.path, 'wb') as f:
            logger.info('Saving parameters to %s...', self.path)
            cPickle.dump(self.model.get_parameter_values(), f,
                         protocol=cPickle.HIGHEST_PROTOCOL)

    def do_load(self):
        """Restore parameter values from ``self.path`` if it can be opened.

        If the file cannot be opened (``IOError``, e.g. no checkpoint has been
        written yet), the load is skipped and logged, and the model keeps its
        current parameters. Errors raised while unpickling or assigning the
        values are not caught and propagate to the caller.
        """
        # Only the open() is guarded, so real failures during loading surface
        # instead of being silently discarded.
        try:
            f = open(self.path, 'rb')
        except IOError as e:
            logger.info('No parameters loaded from %s (%s)', self.path, e)
            return
        with f:
            logger.info('Loading parameters from %s...', self.path)
            self.model.set_parameter_values(cPickle.load(f))

    def do(self, which_callback, *args):
        """Load on the ``'before_training'`` callback; save on any other."""
        if which_callback == 'before_training':
            self.do_load()
        else:
            self.do_save()