attentive_reader.py (6956B)
import theano
from theano import tensor
import numpy

from blocks.bricks import Tanh, Softmax, Linear, MLP, Identity, Rectifier
from blocks.bricks.lookup import LookupTable
from blocks.bricks.recurrent import LSTM

from blocks.filter import VariableFilter
from blocks.roles import WEIGHT
from blocks.graph import ComputationGraph, apply_dropout, apply_noise

def make_bidir_lstm_stack(seq, seq_dim, mask, sizes, skip=True, name=''):
    bricks = []

    curr_dim = [seq_dim]
    curr_hidden = [seq]

    hidden_list = []
    for k, dim in enumerate(sizes):
        fwd_lstm_ins = [Linear(input_dim=d, output_dim=4*dim, name='%s_fwd_lstm_in_%d_%d'%(name,k,l)) for l, d in enumerate(curr_dim)]
        fwd_lstm = LSTM(dim=dim, activation=Tanh(), name='%s_fwd_lstm_%d'%(name,k))

        bwd_lstm_ins = [Linear(input_dim=d, output_dim=4*dim, name='%s_bwd_lstm_in_%d_%d'%(name,k,l)) for l, d in enumerate(curr_dim)]
        bwd_lstm = LSTM(dim=dim, activation=Tanh(), name='%s_bwd_lstm_%d'%(name,k))

        bricks = bricks + [fwd_lstm, bwd_lstm] + fwd_lstm_ins + bwd_lstm_ins

        fwd_tmp = sum(x.apply(v) for x, v in zip(fwd_lstm_ins, curr_hidden))
        bwd_tmp = sum(x.apply(v) for x, v in zip(bwd_lstm_ins, curr_hidden))
        fwd_hidden, _ = fwd_lstm.apply(fwd_tmp, mask=mask)
        bwd_hidden, _ = bwd_lstm.apply(bwd_tmp[::-1], mask=mask[::-1])
        hidden_list = hidden_list + [fwd_hidden, bwd_hidden]
        if skip:
            curr_hidden = [seq, fwd_hidden, bwd_hidden[::-1]]
            curr_dim = [seq_dim, dim, dim]
        else:
            curr_hidden = [fwd_hidden, bwd_hidden[::-1]]
            curr_dim = [dim, dim]

    return bricks, hidden_list

class Model():
    def __init__(self, config, vocab_size):
        question = tensor.imatrix('question')
        question_mask = tensor.imatrix('question_mask')
        context = tensor.imatrix('context')
        context_mask = tensor.imatrix('context_mask')
        answer = tensor.ivector('answer')
        candidates = tensor.imatrix('candidates')
        candidates_mask = tensor.imatrix('candidates_mask')

        bricks = []

        question = question.dimshuffle(1, 0)
        question_mask = question_mask.dimshuffle(1, 0)
        context = context.dimshuffle(1, 0)
        context_mask = context_mask.dimshuffle(1, 0)

        # Embed questions and context
        embed = LookupTable(vocab_size, config.embed_size, name='question_embed')
        bricks.append(embed)

        qembed = embed.apply(question)
        cembed = embed.apply(context)

        qlstms, qhidden_list = make_bidir_lstm_stack(qembed, config.embed_size, question_mask.astype(theano.config.floatX),
                                                     config.question_lstm_size, config.question_skip_connections, 'q')
        clstms, chidden_list = make_bidir_lstm_stack(cembed, config.embed_size, context_mask.astype(theano.config.floatX),
                                                     config.ctx_lstm_size, config.ctx_skip_connections, 'ctx')
        bricks = bricks + qlstms + clstms

        # Calculate question encoding (concatenate layer1)
        if config.question_skip_connections:
            qenc_dim = 2*sum(config.question_lstm_size)
            qenc = tensor.concatenate([h[-1,:,:] for h in qhidden_list], axis=1)
        else:
            qenc_dim = 2*config.question_lstm_size[-1]
            qenc = tensor.concatenate([h[-1,:,:] for h in qhidden_list[-2:]], axis=1)
        qenc.name = 'qenc'

        # Calculate context encoding (concatenate layer1)
        if config.ctx_skip_connections:
            cenc_dim = 2*sum(config.ctx_lstm_size)
            cenc = tensor.concatenate(chidden_list, axis=2)
        else:
            cenc_dim = 2*config.ctx_lstm_size[-1]
            cenc = tensor.concatenate(chidden_list[-2:], axis=2)
        cenc.name = 'cenc'

        # Attention mechanism MLP
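        # cenc has shape (ctx_length, batch, cenc_dim); qenc has shape
        # (batch, qenc_dim). The two Linear bricks below project both into the
        # attention hidden space, the MLP maps each context position to a
        # scalar score, and a softmax over positions turns the scores into
        # attention weights; attended is the weighted sum of cenc over time.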
        attention_mlp = MLP(dims=config.attention_mlp_hidden + [1],
                            activations=config.attention_mlp_activations[1:] + [Identity()],
                            name='attention_mlp')
        attention_qlinear = Linear(input_dim=qenc_dim, output_dim=config.attention_mlp_hidden[0], name='attq')
        attention_clinear = Linear(input_dim=cenc_dim, output_dim=config.attention_mlp_hidden[0], use_bias=False, name='attc')
        bricks += [attention_mlp, attention_qlinear, attention_clinear]
        layer1 = Tanh().apply(attention_clinear.apply(cenc.reshape((cenc.shape[0]*cenc.shape[1], cenc.shape[2])))
                                               .reshape((cenc.shape[0], cenc.shape[1], config.attention_mlp_hidden[0]))
                              + attention_qlinear.apply(qenc)[None, :, :])
        layer1.name = 'layer1'
        att_weights = attention_mlp.apply(layer1.reshape((layer1.shape[0]*layer1.shape[1], layer1.shape[2])))
        att_weights.name = 'att_weights_0'
        att_weights = att_weights.reshape((layer1.shape[0], layer1.shape[1]))
        att_weights.name = 'att_weights'

        attended = tensor.sum(cenc * tensor.nnet.softmax(att_weights.T).T[:, :, None], axis=0)
        attended.name = 'attended'

        # Now we can calculate our output
        out_mlp = MLP(dims=[cenc_dim + qenc_dim] + config.out_mlp_hidden + [config.n_entities],
                      activations=config.out_mlp_activations + [Identity()],
                      name='out_mlp')
        bricks += [out_mlp]
        probs = out_mlp.apply(tensor.concatenate([attended, qenc], axis=1))
        probs.name = 'probs'

        is_candidate = tensor.eq(tensor.arange(config.n_entities, dtype='int32')[None, None, :],
                                 tensor.switch(candidates_mask, candidates, -tensor.ones_like(candidates))[:, :, None]).sum(axis=1)
        probs = tensor.switch(is_candidate, probs, -1000 * tensor.ones_like(probs))

        # Calculate prediction, cost and error rate
        pred = probs.argmax(axis=1)
        cost = Softmax().categorical_cross_entropy(answer, probs).mean()
        error_rate = tensor.neq(answer, pred).mean()

        # Apply weight noise and dropout regularization
        cg = ComputationGraph([cost, error_rate])
        if config.w_noise > 0:
            noise_vars = VariableFilter(roles=[WEIGHT])(cg)
            cg = apply_noise(cg, noise_vars, config.w_noise)
        if config.dropout > 0:
            cg = apply_dropout(cg, qhidden_list + chidden_list, config.dropout)
        [cost_reg, error_rate_reg] = cg.outputs

        # Name the plain and regularized costs / error rates for monitoring
        cost_reg.name = cost.name = 'cost'
        error_rate_reg.name = error_rate.name = 'error_rate'

        self.sgd_cost = cost_reg
        self.monitor_vars = [[cost_reg], [error_rate_reg]]
        self.monitor_vars_valid = [[cost], [error_rate]]

        # Initialize bricks
        for brick in bricks:
            brick.weights_init = config.weights_init
            brick.biases_init = config.biases_init
            brick.initialize()


# vim: set sts=4 ts=4 sw=4 tw=0 et :
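The constructor only reads a handful of attributes off the config object. Below is a minimal sketch of such a configuration as a reading aid: the attribute names are exactly those referenced in the code above, but the concrete values and the vocabulary size are illustrative guesses, not the settings shipped with the original experiment.

from blocks.bricks import Tanh
from blocks.initialization import IsotropicGaussian, Constant

class Config(object):
    embed_size = 200                      # word embedding dimension
    question_lstm_size = [256]            # one bidirectional LSTM layer for the question
    question_skip_connections = True
    ctx_lstm_size = [256]                 # one bidirectional LSTM layer for the context
    ctx_skip_connections = True
    attention_mlp_hidden = [100]          # hidden size of the attention scorer
    attention_mlp_activations = [Tanh()]  # the model only uses the [1:] slice of this list
    out_mlp_hidden = []                   # empty list -> single linear output layer
    out_mlp_activations = []
    n_entities = 550                      # number of answer entity markers
    w_noise = 0.05                        # std of Gaussian weight noise (0 disables it)
    dropout = 0.1                         # dropout rate on the LSTM hidden states (0 disables it)
    weights_init = IsotropicGaussian(0.01)
    biases_init = Constant(0.)

# Builds the symbolic training graph; model.sgd_cost is the loss to optimize.
model = Model(Config(), vocab_size=50000)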