 # gbure

Graph-based approaches to unsupervised relation extraction, evaluated as a few-shot problem
git clone https://esimon.eu/repos/gbure.git

batcher.py (8378B)

```      1 from typing import Any, Dict, List, Tuple
2 import collections
3
4 import torch
5
6
# Must be kept prefix-sorted: the most specific prefixes come first, since the
# batcher takes the FIRST prefix a key starts with (the final "" catches all).
# Each entry is (prefix, list depth): the depth tells how many levels of
# nesting the corresponding features carry inside a batch.
FEATURE_PREFIXES: List[Tuple[str, int]] = [
    # Neighborhood features of a single sentence: one list per sample.
    ("query_e1_neighborhood_", 1),
    ("query_e2_neighborhood_", 1),
    # Neighborhood features of few-shot candidates: way x shot x neighbors.
    ("candidates_e1_neighborhood_", 3),
    ("candidates_e2_neighborhood_", 3),
    ("first_e1_neighborhood_", 1),
    ("first_e2_neighborhood_", 1),
    ("second_e1_neighborhood_", 1),
    ("second_e2_neighborhood_", 1),
    ("third_e1_neighborhood_", 1),
    ("third_e2_neighborhood_", 1),
    # Plain (non-neighborhood) features.
    ("query_", 0),
    ("candidates_", 2),
    ("first_", 0),
    ("second_", 0),
    ("third_", 0),
    ("", 0),
]
26
27
class Batcher:
    """
    Batch a group of samples together.

    Two new features are derived from the "text": its length and a mask.
    Text tensors are padded with ``pad_value`` up to the longest sentence in
    the batch; the boolean mask is True on real tokens and False on padding.
    """

    def __init__(self, pad_value: int) -> None:
        """ Initialize a Batcher, using the provided value to pad text. """
        # Token id written into the padded tail of every text tensor.
        self.pad_value: int = pad_value

    def add_length_field(self, batch: Dict[str, Any], prefix: str, depth: int) -> None:
        """ Add the length field for the given prefix, derived from its text. """
        text: List[Any] = batch[f"{prefix}text"]
        batch_size: int = len(text)

        lengths: torch.Tensor
        if depth == 0:
            # text is a list of sentences (1D LongTensors).
            lengths = torch.zeros((batch_size,), dtype=torch.int64)
            for b, sentence in enumerate(text):
                lengths[b] = sentence.shape[0]
        elif depth == 1:
            # text is a list of lists of sentences (one neighborhood per sample).
            # Neighborhoods may differ in size or be empty, so size the tensor
            # on the largest one; absent entries keep length 0.
            size: int = max(map(len, text), default=0)
            lengths = torch.zeros((batch_size, size), dtype=torch.int64)
            for b, sample in enumerate(text):
                for i, sentence in enumerate(sample):
                    lengths[b, i] = sentence.shape[0]
        elif depth == 2:
            # text[b][w][s] is a sentence; every sample has the same n-way
            # k-shot grid, so the first sample fixes the dimensions.
            way: int = len(text[0]) if batch_size > 0 else 0
            shot: int = len(text[0][0]) if way > 0 else 0
            lengths = torch.zeros((batch_size, way, shot), dtype=torch.int64)
            for b, sample in enumerate(text):
                for w, candidates in enumerate(sample):
                    for s, candidate in enumerate(candidates):
                        lengths[b, w, s] = candidate.shape[0]
        elif depth == 3:
            # text[b][w][s][n] is a sentence from a candidate's neighborhood.
            # Neighborhood sizes may vary, so size on the largest one.
            way = len(text[0]) if batch_size > 0 else 0
            shot = len(text[0][0]) if way > 0 else 0
            size = max((len(candidate)
                        for sample in text
                        for candidates in sample
                        for candidate in candidates), default=0)
            lengths = torch.zeros((batch_size, way, shot, size), dtype=torch.int64)
            for b, sample in enumerate(text):
                for w, candidates in enumerate(sample):
                    for s, candidate in enumerate(candidates):
                        for n, neighbor in enumerate(candidate):
                            lengths[b, w, s, n] = neighbor.shape[0]
        else:
            raise ValueError(f"unsupported feature depth {depth}")

        batch[f"{prefix}length"] = lengths

    def process_text(self, batch: Dict[str, Any], prefix: str, depth: int, key: str) -> None:
        """ Pad the given text-like feature into a LongTensor and derive its mask. """
        in_text: List[Any] = batch[f"{prefix}{key}"]
        # The batch is built with a defaultdict(list), so the length field is
        # still a plain (possibly empty) list the first time we get here;
        # compute it from the text before using it.
        if isinstance(batch[f"{prefix}length"], list):
            self.add_length_field(batch, prefix, depth)
        lengths: torch.Tensor = batch[f"{prefix}length"]
        # Keep at least one column so downstream code never sees a 0-width tensor.
        max_seq_len: int = max(int(lengths.max()) if lengths.numel() > 0 else 0, 1)
        batch_size: int = len(in_text)

        text: torch.Tensor
        mask: torch.Tensor
        if depth == 0:
            # in_text is a list of sentences.
            text = torch.full((batch_size, max_seq_len), self.pad_value, dtype=torch.int64)
            mask = torch.zeros((batch_size, max_seq_len), dtype=torch.bool)
            for b, sentence in enumerate(in_text):
                length: int = sentence.shape[0]
                text[b, :length] = sentence
                mask[b, :length] = 1
        elif depth == 1:
            # in_text is a list of lists of sentences (neighborhoods); the
            # tensor may not be full (neighborhoods of different sizes or
            # empty), which the pad initialization and zero mask handle.
            size: int = lengths.shape[1]
            text = torch.full((batch_size, size, max_seq_len), self.pad_value, dtype=torch.int64)
            mask = torch.zeros((batch_size, size, max_seq_len), dtype=torch.bool)
            for b, sample in enumerate(in_text):
                for i, sentence in enumerate(sample):
                    length = sentence.shape[0]
                    text[b, i, :length] = sentence
                    mask[b, i, :length] = 1
        elif depth == 2:
            # in_text is batch x way x shot sentences; the grid is full
            # (all n ways have the same k shots), only lengths vary.
            way: int = lengths.shape[1]
            shot: int = lengths.shape[2]
            text = torch.full((batch_size, way, shot, max_seq_len), self.pad_value, dtype=torch.int64)
            mask = torch.zeros((batch_size, way, shot, max_seq_len), dtype=torch.bool)
            for b, sample in enumerate(in_text):
                for w, candidates in enumerate(sample):
                    for s, candidate in enumerate(candidates):
                        length = candidate.shape[0]
                        text[b, w, s, :length] = candidate
                        mask[b, w, s, :length] = 1
        elif depth == 3:
            # in_text is batch x way x shot x neighbors sentences; the tensor
            # may not be full (neighborhoods of different sizes or empty).
            way = lengths.shape[1]
            shot = lengths.shape[2]
            size = lengths.shape[3]
            text = torch.full((batch_size, way, shot, size, max_seq_len), self.pad_value, dtype=torch.int64)
            mask = torch.zeros((batch_size, way, shot, size, max_seq_len), dtype=torch.bool)
            for b, sample in enumerate(in_text):
                for w, candidates in enumerate(sample):
                    for s, candidate in enumerate(candidates):
                        for n, neighbor in enumerate(candidate):
                            length = neighbor.shape[0]
                            text[b, w, s, n, :length] = neighbor
                            mask[b, w, s, n, :length] = 1
        else:
            raise ValueError(f"unsupported feature depth {depth}")

        batch[f"{prefix}{key}"] = text
        # "mlm_input" and "mlm_target" reuse the mask derived from "text",
        # so only the first processed key for this prefix stores it.
        if f"{prefix}mask" not in batch:
            batch[f"{prefix}mask"] = mask

    def process_int_feature(self, batch: Dict[str, Any], prefix: str, feature: str) -> None:
        """ Transform a list of integers (or of tensors) into a torch LongTensor. """
        # TODO handle neighborhoods of different sizes
        values: List[Any] = batch[f"{prefix}{feature}"]
        # Check the ELEMENTS, not the list itself: per-sample tensor features
        # (e.g. entity positions) are stacked, plain ints are tensorized.
        if values and isinstance(values[0], torch.Tensor):
            batch[f"{prefix}{feature}"] = torch.stack(values)
        else:
            batch[f"{prefix}{feature}"] = torch.tensor(values, dtype=torch.int64)

    def __call__(self, samples: List[Dict[str, Any]]) -> Dict[str, Any]:
        """ Batch the provided samples into a single feature dictionary. """
        # Transpose list-of-dicts into dict-of-lists; defaultdict also makes
        # the not-yet-computed "length" fields read as empty lists.
        batch = collections.defaultdict(list)
        for sample in samples:
            for key, value in sample.items():
                batch[key].append(value)

        # Iterate over a snapshot since processing inserts new keys (mask, length).
        for key in list(batch.keys()):
            # FEATURE_PREFIXES is prefix-sorted, so the first match is the
            # most specific one; the trailing "" entry always matches.
            for prefix, depth in FEATURE_PREFIXES:
                if key.startswith(prefix):
                    break
            feature: str = key[len(prefix):]
            if feature in ["text", "mlm_input", "mlm_target"]:
                self.process_text(batch, prefix, depth, feature)
            if feature in ["relation", "entity_positions", "entity_identifiers", "entity_degrees", "edge_identifier", "polarity", "answer", "eid"]:
                self.process_int_feature(batch, prefix, feature)

        return batch
```