train.py (1073B)
1 #!/usr/bin/env python2 2 3 from __future__ import print_function 4 import sys 5 6 from dataset import * 7 from model import * 8 from meta_model import * 9 from config import * 10 11 if __name__ == '__main__': 12 if len(sys.argv)<3: 13 print('Usage: {0} data config [models]'.format(sys.argv[0]), file=sys.stderr) 14 sys.exit(1) 15 data = sys.argv[1] 16 config_path = sys.argv[2] 17 18 if len(sys.argv)<4: model_pathes = None 19 elif len(sys.argv)>4: model_pathes = sys.argv[3:] 20 else: model_pathes = sys.argv[3] 21 22 config = load_config(config_path) 23 if config.get('meta', False) and len(sys.argv)<4: 24 model_pathes = [ None ] * config['size'] 25 if not config.get('meta', False) and isinstance(model_pathes, list): 26 print('Error: multiple model specified while running in single mode', file=sys.stderr) 27 sys.exit(1) 28 29 ModelType = Meta_model if config.get('meta', False) else Model 30 data = Dataset(data, config) 31 model = ModelType(data, config, model_pathes) 32 33 model.build_train() 34 model.build_test() 35 model.train() 36 model.test()