# data.py
"""Fuel data pipeline for the DeepMind QA (CNN / Daily Mail) corpus.

Each question lives in its own text file under a directory; a ``QADataset``
request is such a file name, and the dataset yields
(context, question, answer, candidates) as int32 word-id arrays.
``setup_datastream`` wires the dataset into a padded, length-sorted,
batched Fuel stream.
"""

import logging
import random
import numpy

import cPickle  # NOTE(review): unused in this module; kept in case other code imports it from here

from picklable_itertools import iter_

from fuel.datasets import Dataset
from fuel.streams import DataStream
from fuel.schemes import IterationScheme, ConstantScheme
from fuel.transformers import (Batch, Mapping, SortMapping, Unpack, Padding,
                               Transformer)

import sys  # NOTE(review): unused in this module
import os

logging.basicConfig(level='INFO')
logger = logging.getLogger(__name__)


class QADataset(Dataset):
    """Dataset of question files; a request is a file name inside ``path``.

    Vocabulary layout: entity tokens first (so any word id < ``n_entities``
    is an entity id), then the words from ``vocab_file``, then the special
    tokens ``<UNK>`` and ``@placeholder``, and finally ``<SEP>`` when
    context and question are concatenated into one sequence.
    """

    def __init__(self, path, vocab_file, n_entities, need_sep_token, **kwargs):
        self.provides_sources = ('context', 'question', 'answer', 'candidates')

        self.path = path

        # Close the vocab file deterministically (the original left it open).
        with open(vocab_file) as f:
            words = [w.rstrip('\n') for w in f]
        self.vocab = (['@entity%d' % i for i in range(n_entities)] +
                      words +
                      ['<UNK>', '@placeholder'] +
                      (['<SEP>'] if need_sep_token else []))

        self.n_entities = n_entities
        self.vocab_size = len(self.vocab)
        self.reverse_vocab = {w: i for i, w in enumerate(self.vocab)}

        super(QADataset, self).__init__(**kwargs)

    def to_word_id(self, w, cand_mapping):
        """Map token ``w`` to a word id.

        Candidate entities go through ``cand_mapping`` (entity ids are
        re-assigned per question); any other ``@entity...`` token is an
        error, and out-of-vocabulary words fall back to ``<UNK>``.
        """
        if w in cand_mapping:
            return cand_mapping[w]
        elif w.startswith('@entity'):
            raise ValueError("Unmapped entity token: %s" % w)
        elif w in self.reverse_vocab:
            return self.reverse_vocab[w]
        else:
            return self.reverse_vocab['<UNK>']

    def to_word_ids(self, s, cand_mapping):
        """Convert a space-separated token string to an int32 id array."""
        return numpy.array([self.to_word_id(x, cand_mapping)
                            for x in s.split(' ')],
                           dtype=numpy.int32)

    def get_data(self, state=None, request=None):
        """Load one question file and return (ctx, q, answer, candidates).

        Expected file layout (0-based lines): line 2 context, line 4
        question, line 6 answer, lines 8+ ``entity:description`` candidates.
        Raises ValueError on a missing request, an unexpected state, or any
        word id falling outside the vocabulary / entity range.
        """
        if request is None or state is not None:
            raise ValueError("Expected a request (name of a question file) and no state.")

        # Close the question file deterministically (the original left it open).
        with open(os.path.join(self.path, request)) as f:
            lines = [line.rstrip('\n') for line in f]

        ctx = lines[2]
        q = lines[4]
        a = lines[6]
        cand = [s.split(':')[0] for s in lines[8:]]

        # Give this question's candidates a random permutation of entity ids
        # so a model cannot memorize global entity identities. If there are
        # more candidates than entity ids, duplicate the id pool (with a
        # warning) rather than fail.
        entities = list(range(self.n_entities))
        while len(cand) > len(entities):
            logger.warning(
                "Too many entities (%d) for question: %s, using duplicate entity identifiers",
                len(cand), request)
            entities = entities + entities
        random.shuffle(entities)
        cand_mapping = {t: k for t, k in zip(cand, entities)}

        ctx = self.to_word_ids(ctx, cand_mapping)
        q = self.to_word_ids(q, cand_mapping)
        cand = numpy.array([self.to_word_id(x, cand_mapping) for x in cand],
                           dtype=numpy.int32)
        a = numpy.int32(self.to_word_id(a, cand_mapping))

        # Sanity checks: answer and candidates must be entity ids, and every
        # word id must lie inside the vocabulary.
        if not a < self.n_entities:
            raise ValueError("Invalid answer token %d" % a)
        if not numpy.all(cand < self.n_entities):
            raise ValueError("Invalid candidate in list %s" % repr(cand))
        if not numpy.all(ctx < self.vocab_size):
            raise ValueError("Context word id out of bounds: %d" % int(ctx.max()))
        if not numpy.all(ctx >= 0):
            raise ValueError("Context word id negative: %d" % int(ctx.min()))
        if not numpy.all(q < self.vocab_size):
            raise ValueError("Question word id out of bounds: %d" % int(q.max()))
        if not numpy.all(q >= 0):
            raise ValueError("Question word id negative: %d" % int(q.min()))

        return (ctx, q, a, cand)


class QAIterator(IterationScheme):
    """Iteration scheme yielding the question file names found in ``path``."""
    requests_examples = True

    def __init__(self, path, shuffle=False, **kwargs):
        self.path = path
        self.shuffle = shuffle

        super(QAIterator, self).__init__(**kwargs)

    def get_request_iterator(self):
        # Only plain files count as questions; subdirectories are skipped.
        filenames = [f for f in os.listdir(self.path)
                     if os.path.isfile(os.path.join(self.path, f))]
        if self.shuffle:
            random.shuffle(filenames)
        return iter_(filenames)

# -------------- DATASTREAM SETUP --------------------


class ConcatCtxAndQuestion(Transformer):
    """Merge the context and question sources into one sequence.

    The output drops the separate 'context' source; the combined sequence is
    emitted under 'question'. An optional separator token id is inserted
    between the two parts.
    """
    produces_examples = True

    def __init__(self, stream, concat_question_before, separator_token=None, **kwargs):
        assert stream.sources == ('context', 'question', 'answer', 'candidates')
        self.sources = ('question', 'answer', 'candidates')

        # Empty array when no separator is requested, so a single
        # concatenate call covers both cases.
        self.sep = numpy.array([separator_token] if separator_token is not None else [],
                               dtype=numpy.int32)
        self.concat_question_before = concat_question_before

        super(ConcatCtxAndQuestion, self).__init__(stream, **kwargs)

    def get_data(self, request=None):
        if request is not None:
            raise ValueError('Unsupported: request')

        ctx, q, a, cand = next(self.child_epoch_iterator)

        if self.concat_question_before:
            return (numpy.concatenate([q, self.sep, ctx]), a, cand)
        else:
            return (numpy.concatenate([ctx, self.sep, q]), a, cand)


class _balanced_batch_helper(object):
    """Sort key: length of the sequence stored at index ``key`` of an example."""
    def __init__(self, key):
        self.key = key

    def __call__(self, data):
        return data[self.key].shape[0]


def setup_datastream(path, vocab_file, config):
    """Build the (dataset, stream) pair described by ``config``.

    The stream groups ``sort_batch_count`` batches, sorts them by sequence
    length so padding is minimized, then re-batches to ``batch_size`` and
    pads 'context', 'question' and 'candidates' with int32 masks.
    """
    ds = QADataset(path, vocab_file, config.n_entities,
                   need_sep_token=config.concat_ctx_and_question)
    it = QAIterator(path, shuffle=config.shuffle_questions)

    stream = DataStream(ds, iteration_scheme=it)

    if config.concat_ctx_and_question:
        stream = ConcatCtxAndQuestion(stream, config.concat_question_before,
                                      ds.reverse_vocab['<SEP>'])

    # Sort sets of multiple batches to make batches of similar sizes
    stream = Batch(stream,
                   iteration_scheme=ConstantScheme(config.batch_size * config.sort_batch_count))
    comparison = _balanced_batch_helper(
        stream.sources.index('question' if config.concat_ctx_and_question else 'context'))
    stream = Mapping(stream, SortMapping(comparison))
    stream = Unpack(stream)

    stream = Batch(stream, iteration_scheme=ConstantScheme(config.batch_size))
    stream = Padding(stream, mask_sources=['context', 'question', 'candidates'],
                     mask_dtype='int32')

    return ds, stream


if __name__ == "__main__":
    # Smoke test: print a few batches from the CNN training set.
    class DummyConfig:
        def __init__(self):
            self.shuffle_entities = True
            self.shuffle_questions = False
            self.concat_ctx_and_question = False
            self.concat_question_before = False
            self.batch_size = 2
            self.sort_batch_count = 1000

    ds, stream = setup_datastream(
        os.path.join(os.getenv("DATAPATH"), "deepmind-qa/cnn/questions/training"),
        os.path.join(os.getenv("DATAPATH"), "deepmind-qa/cnn/stats/training/vocab.txt"),
        DummyConfig())

    for i, d in enumerate(stream.get_epoch_iterator()):
        print('--')
        print(d)
        if i > 2:
            break

# vim: set sts=4 ts=4 sw=4 tw=0 et :