deep_question_answering

Implementation of the models from "Teaching Machines to Read and Comprehend", a paper by Google DeepMind.
git clone https://esimon.eu/repos/deep_question_answering.git
Log | Files | Refs | README | LICENSE

deep_lstm.py (3769B)


      1 import theano
      2 from theano import tensor
      3 import numpy
      4 
      5 from blocks.bricks import Tanh, Softmax, Linear, MLP, Identity, Rectifier
      6 from blocks.bricks.lookup import LookupTable
      7 from blocks.bricks.recurrent import LSTM
      8 
      9 from blocks.graph import ComputationGraph, apply_dropout
     10 
     11 
     12 class Model():
     13     def __init__(self, config, vocab_size):
     14         question = tensor.imatrix('question')
     15         question_mask = tensor.imatrix('question_mask')
     16         answer = tensor.ivector('answer')
     17         candidates = tensor.imatrix('candidates')
     18         candidates_mask = tensor.imatrix('candidates_mask')
     19 
     20         bricks = []
     21 
     22 
     23         # set time as first dimension
     24         question = question.dimshuffle(1, 0)
     25         question_mask = question_mask.dimshuffle(1, 0)
     26 
     27         # Embed questions
     28         embed = LookupTable(vocab_size, config.embed_size, name='question_embed')
     29         bricks.append(embed)
     30         qembed = embed.apply(question)
     31 
     32         # Create and apply LSTM stack
     33         curr_dim = config.embed_size
     34         curr_hidden = qembed
     35 
     36         hidden_list = []
     37         for k, dim in enumerate(config.lstm_size):
     38             lstm_in = Linear(input_dim=curr_dim, output_dim=4*dim, name='lstm_in_%d'%k)
     39             lstm = LSTM(dim=dim, activation=Tanh(), name='lstm_%d'%k)
     40             bricks = bricks + [lstm_in, lstm]
     41 
     42             tmp = lstm_in.apply(curr_hidden)
     43             hidden, _ = lstm.apply(tmp, mask=question_mask.astype(theano.config.floatX))
     44             hidden_list.append(hidden)
     45             if config.skip_connections:
     46                 curr_hidden = tensor.concatenate([hidden, qembed], axis=2)
     47                 curr_dim =  dim + config.embed_size
     48             else:
     49                 curr_hidden = hidden
     50                 curr_dim = dim
     51 
     52         # Create and apply output MLP
     53         if config.skip_connections:
     54             out_mlp = MLP(dims=[sum(config.lstm_size)] + config.out_mlp_hidden + [config.n_entities],
     55                           activations=config.out_mlp_activations + [Identity()],
     56                           name='out_mlp')
     57             bricks.append(out_mlp)
     58 
     59             probs = out_mlp.apply(tensor.concatenate([h[-1,:,:] for h in hidden_list], axis=1))
     60         else:
     61             out_mlp = MLP(dims=[config.lstm_size[-1]] + config.out_mlp_hidden + [config.n_entities],
     62                           activations=config.out_mlp_activations + [Identity()],
     63                           name='out_mlp')
     64             bricks.append(out_mlp)
     65 
     66             probs = out_mlp.apply(hidden_list[-1][-1,:,:])
     67 
     68         is_candidate = tensor.eq(tensor.arange(config.n_entities, dtype='int32')[None, None, :],
     69                                  tensor.switch(candidates_mask, candidates, -tensor.ones_like(candidates))[:, :, None]).sum(axis=1)
     70         probs = tensor.switch(is_candidate, probs, -1000 * tensor.ones_like(probs))
     71 
     72         # Calculate prediction, cost and error rate
     73         pred = probs.argmax(axis=1)
     74         cost = Softmax().categorical_cross_entropy(answer, probs).mean()
     75         error_rate = tensor.neq(answer, pred).mean()
     76 
     77         # Apply dropout
     78         cg = ComputationGraph([cost, error_rate])
     79         if config.dropout > 0:
     80             cg = apply_dropout(cg, hidden_list, config.dropout)
     81         [cost_reg, error_rate_reg] = cg.outputs
     82 
     83         # Other stuff
     84         cost_reg.name = cost.name = 'cost'
     85         error_rate_reg.name = error_rate.name = 'error_rate'
     86 
     87         self.sgd_cost = cost_reg
     88         self.monitor_vars = [[cost_reg], [error_rate_reg]]
     89         self.monitor_vars_valid = [[cost], [error_rate]]
     90 
     91         # Initialize bricks
     92         for brick in bricks:
     93             brick.weights_init = config.weights_init
     94             brick.biases_init = config.biases_init
     95             brick.initialize()
     96 
     97         
     98 
     99 #  vim: set sts=4 ts=4 sw=4 tw=0 et :