deep_bidir_lstm.py (4651B)
"""Deep bidirectional LSTM reader model (Theano + Blocks).

Builds the symbolic computation graph for a question-answering model:
an embedding layer, a stack of bidirectional LSTMs (optionally with
skip connections from the embeddings), and an output MLP that scores a
fixed set of entities, restricted to the candidates of each example.
"""
import theano
from theano import tensor
import numpy

from blocks.bricks import Tanh, Softmax, Linear, MLP, Identity, Rectifier
from blocks.bricks.lookup import LookupTable
from blocks.bricks.recurrent import LSTM

from blocks.filter import VariableFilter
from blocks.roles import WEIGHT
from blocks.graph import ComputationGraph, apply_dropout, apply_noise


class Model():
    """Deep bidirectional LSTM over a question, predicting one entity.

    Constructing an instance builds the whole symbolic graph and exposes:
      - ``self.sgd_cost``: regularized cost to minimize,
      - ``self.monitor_vars`` / ``self.monitor_vars_valid``: variables to
        monitor during training / validation.
    """

    def __init__(self, config, vocab_size):
        # Symbolic inputs. NOTE(review): `question`/`candidates` appear to be
        # batch-major int32 matrices (batch, time) with 0/1 masks of the same
        # shape, and `answer` one entity id per example — confirm against the
        # data stream that feeds these named variables.
        question = tensor.imatrix('question')
        question_mask = tensor.imatrix('question_mask')
        answer = tensor.ivector('answer')
        candidates = tensor.imatrix('candidates')
        candidates_mask = tensor.imatrix('candidates_mask')

        # All bricks are collected here so they can be initialized at the end.
        bricks = []


        # set time as first dimension (Blocks recurrent bricks iterate over
        # the leading axis)
        question = question.dimshuffle(1, 0)
        question_mask = question_mask.dimshuffle(1, 0)

        # Embed questions: (time, batch) -> (time, batch, embed_size)
        embed = LookupTable(vocab_size, config.embed_size, name='question_embed')
        bricks.append(embed)
        qembed = embed.apply(question)

        # Create and apply LSTM stack. `curr_hidden` holds the inputs feeding
        # the next layer (several tensors when skip connections are on), and
        # `curr_dim` their respective feature dimensions.
        curr_dim = [config.embed_size]
        curr_hidden = [qembed]

        hidden_list = []
        for k, dim in enumerate(config.lstm_size):
            # One Linear per input stream, projecting to 4*dim: the four LSTM
            # gate pre-activations expected by the Blocks LSTM brick.
            fwd_lstm_ins = [Linear(input_dim=d, output_dim=4*dim, name='fwd_lstm_in_%d_%d'%(k,l)) for l, d in enumerate(curr_dim)]
            fwd_lstm = LSTM(dim=dim, activation=Tanh(), name='fwd_lstm_%d'%k)

            bwd_lstm_ins = [Linear(input_dim=d, output_dim=4*dim, name='bwd_lstm_in_%d_%d'%(k,l)) for l, d in enumerate(curr_dim)]
            bwd_lstm = LSTM(dim=dim, activation=Tanh(), name='bwd_lstm_%d'%k)

            bricks = bricks + [fwd_lstm, bwd_lstm] + fwd_lstm_ins + bwd_lstm_ins

            # Sum the projections of all input streams into one gate input.
            fwd_tmp = sum(x.apply(v) for x, v in zip(fwd_lstm_ins, curr_hidden))
            bwd_tmp = sum(x.apply(v) for x, v in zip(bwd_lstm_ins, curr_hidden))
            # Backward direction: reverse inputs and mask along the time axis
            # (leading axis), so the LSTM reads the sequence right-to-left.
            fwd_hidden, _ = fwd_lstm.apply(fwd_tmp, mask=question_mask.astype(theano.config.floatX))
            bwd_hidden, _ = bwd_lstm.apply(bwd_tmp[::-1], mask=question_mask.astype(theano.config.floatX)[::-1])
            # Keep references to the raw hidden sequences; these exact graph
            # variables are matched later by apply_dropout.
            hidden_list = hidden_list + [fwd_hidden, bwd_hidden]
            if config.skip_connections:
                # Next layer sees the embeddings plus both directions
                # (backward re-reversed into forward time order).
                curr_hidden = [qembed, fwd_hidden, bwd_hidden[::-1]]
                curr_dim = [config.embed_size, dim, dim]
            else:
                curr_hidden = [fwd_hidden, bwd_hidden[::-1]]
                curr_dim = [dim, dim]

        # Create and apply output MLP over the final hidden states.
        # h[-1,:,:] is the last timestep of each recorded sequence; for the
        # backward LSTMs that timestep corresponds to the start of the input.
        if config.skip_connections:
            # Concatenate the final states of every layer and direction.
            out_mlp = MLP(dims=[2*sum(config.lstm_size)] + config.out_mlp_hidden + [config.n_entities],
                          activations=config.out_mlp_activations + [Identity()],
                          name='out_mlp')
            bricks.append(out_mlp)

            probs = out_mlp.apply(tensor.concatenate([h[-1,:,:] for h in hidden_list], axis=1))
        else:
            # Only the top layer's two directions feed the MLP.
            out_mlp = MLP(dims=[2*config.lstm_size[-1]] + config.out_mlp_hidden + [config.n_entities],
                          activations=config.out_mlp_activations + [Identity()],
                          name='out_mlp')
            bricks.append(out_mlp)

            probs = out_mlp.apply(tensor.concatenate([h[-1,:,:] for h in hidden_list[-2:]], axis=1))

        # is_candidate[b, e] != 0 iff entity id e occurs among example b's
        # (unmasked) candidates; masked-out slots are replaced by -1 so they
        # can never match an id from arange(n_entities).
        is_candidate = tensor.eq(tensor.arange(config.n_entities, dtype='int32')[None, None, :],
                                 tensor.switch(candidates_mask, candidates, -tensor.ones_like(candidates))[:, :, None]).sum(axis=1)
        # Push non-candidate scores to an effectively -inf logit so they get
        # ~zero probability under the softmax.
        probs = tensor.switch(is_candidate, probs, -1000 * tensor.ones_like(probs))

        # Calculate prediction, cost and error rate. `probs` are unnormalized
        # scores; the Softmax brick normalizes inside the cross-entropy.
        pred = probs.argmax(axis=1)
        cost = Softmax().categorical_cross_entropy(answer, probs).mean()
        error_rate = tensor.neq(answer, pred).mean()

        # Apply regularization (weight noise and/or dropout) by graph rewrite.
        cg = ComputationGraph([cost, error_rate])
        if config.w_noise > 0:
            noise_vars = VariableFilter(roles=[WEIGHT])(cg)
            cg = apply_noise(cg, noise_vars, config.w_noise)
        if config.dropout > 0:
            # Dropout is applied on the LSTM hidden-state sequences collected
            # above, not on the inputs.
            cg = apply_dropout(cg, hidden_list, config.dropout)
        [cost_reg, error_rate_reg] = cg.outputs

        # Other stuff: regularized variables drive training, the clean ones
        # are used for validation monitoring.
        cost_reg.name = cost.name = 'cost'
        error_rate_reg.name = error_rate.name = 'error_rate'

        self.sgd_cost = cost_reg
        self.monitor_vars = [[cost_reg], [error_rate_reg]]
        self.monitor_vars_valid = [[cost], [error_rate]]

        # Initialize bricks with the configured schemes (must happen after
        # all apply() calls so every brick's dimensions are known).
        for brick in bricks:
            brick.weights_init = config.weights_init
            brick.biases_init = config.biases_init
            brick.initialize()




# vim: set sts=4 ts=4 sw=4 tw=0 et :