dataset.py (32424B)
from typing import Any, Dict, Iterator, List, Optional, Set, Tuple, Type, Union
import collections
import pathlib
import random

import torch
import transformers

import gbure.data.dictionary
from gbure.data.graph import Graph
import gbure.utils


# Layout of the records returned by torch.load for each dataset type, as indexed by the classes below.
# "supervised": (text, e1 position, e2 position, relation)
SupervisedRecord = Tuple[torch.Tensor, int, int, int]
# "sampled fewshot": (query text, query e1 position, query e2 position, query e1 identifier, query e2 identifier,
#                     candidates text, candidates e1 positions, candidates e2 positions,
#                     candidates e1 identifiers, candidates e2 identifiers, answer)
SampledFewShotRecord = Tuple[torch.Tensor, int, int, int, int, List[List[torch.Tensor]], torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, int]
# "fewshot" (stored grouped by relation): (text, e1 position, e2 position, relation, e1 identifier, e2 identifier)
FewShotRecord = Tuple[torch.Tensor, int, int, int, int, int]


class SupervisedDataset(torch.utils.data.Dataset):
    """
    Read a preprocessed supervised relation extraction dataset.

    A preprocessed dataset can be created from the gbure.data.prepare_* scripts.
    """
    shuffleable: bool = True

    def __init__(self, config: gbure.utils.dotdict, path: pathlib.Path, tokenizer: transformers.PreTrainedTokenizer, evaluation: bool, rng: Optional[random.Random] = None, data: Optional[List[SupervisedRecord]] = None) -> None:
        """
        Initialize a supervised dataset and load the data in RAM.

        Args:
            config: the experiment configuration.
            path: the file read with torch.load when data is not given.
            tokenizer: the tokenizer, stored for consumers of the dataset.
            evaluation: whether the dataset is used for evaluation.
            rng: unused, accepted for interface compatibility with the other datasets.
            data: already loaded records; when given, path is not read.
        """
        super().__init__()

        self.config: gbure.utils.dotdict = config
        self.path: pathlib.Path = path
        self.tokenizer: transformers.PreTrainedTokenizer = tokenizer
        self.evaluation: bool = evaluation
        if data is None:
            self.load()
        else:
            self.data = data

    def load(self) -> None:
        """ Load the dataset into RAM. """
        dstype: str
        self.data: List[SupervisedRecord]
        dstype, self.data = torch.load(self.path)
        assert dstype == "supervised", f"Expected a supervised dataset, got {dstype!r} from {self.path}"

    def __len__(self) -> int:
        """ Get the number of samples in the dataset. """
        return len(self.data)

    def __getitem__(self, index: int) -> Dict[str, Any]:
        """ Get the sample at the given index. """
        record = self.data[index]
        sample: Dict[str, Any] = {}
        sample["text"] = record[0]
        sample["entity_positions"] = torch.tensor(record[1:3], dtype=torch.int64)
        sample["relation"] = record[3]
        return sample


class SampledFewShotDataset(torch.utils.data.Dataset):
    """
    Read a preprocessed few shot relation extraction dataset file containing already sampled few shot problems.

    A preprocessed dataset can be created from the gbure.data.prepare_* scripts.
    """
    shuffleable: bool = True

    def __init__(self, config: gbure.utils.dotdict, path: pathlib.Path, tokenizer: transformers.PreTrainedTokenizer, evaluation: bool, rng: Optional[random.Random] = None, data: Optional[List[SampledFewShotRecord]] = None) -> None:
        """
        Initialize a few shot dataset and load the samples in RAM.

        Args:
            config: the experiment configuration.
            path: the file read with torch.load when data is not given.
            tokenizer: the tokenizer, stored for consumers of the dataset.
            evaluation: whether the dataset is used for evaluation.
            rng: unused, accepted for interface compatibility with the other datasets.
            data: already loaded records; when given, path is not read.
        """
        super().__init__()

        self.config: gbure.utils.dotdict = config
        self.path: pathlib.Path = path
        self.tokenizer: transformers.PreTrainedTokenizer = tokenizer
        self.evaluation: bool = evaluation
        if data is None:
            self.load()
        else:
            self.data = data

    def load(self) -> None:
        """ Load the dataset into RAM. """
        dstype: str
        self.data: List[SampledFewShotRecord]
        dstype, self.data = torch.load(self.path)
        assert dstype == "sampled fewshot", f"Expected a sampled fewshot dataset, got {dstype!r} from {self.path}"

    def __len__(self) -> int:
        """ Get the number of samples in the dataset. """
        return len(self.data)

    def __getitem__(self, index: int) -> Dict[str, Any]:
        """ Get the sample at the given index. """
        record = self.data[index]
        sample: Dict[str, Any] = {}
        sample["query_text"] = record[0]
        sample["query_entity_positions"] = torch.tensor(record[1:3], dtype=torch.int64)
        sample["query_entity_identifiers"] = torch.tensor(record[3:5], dtype=torch.int64)
        sample["candidates_text"] = record[5]
        # The e1 and e2 tensors are stacked along a new last dimension.
        sample["candidates_entity_positions"] = torch.stack(record[6:8], dim=2)
        sample["candidates_entity_identifiers"] = torch.stack(record[8:10], dim=2)
        sample["answer"] = record[10]
        return sample


class FewShotDataset(torch.utils.data.IterableDataset):
    """
    Read a preprocessed few shot relation extraction dataset file and sample few shot problems from it.

    A preprocessed dataset can be created from the gbure.data.prepare_*
    modules.

    Config:
        seed: the seed for the random number generator
        shot: the number of candidates per relation
        way: the number of relation classes used for candidates
        meta_per_sample (optional, default 1): the number of few shot problems generated per query
    """
    shuffleable: bool = False  # FIXME ?

    def __init__(self, config: gbure.utils.dotdict, path: pathlib.Path, tokenizer: transformers.PreTrainedTokenizer, evaluation: bool, rng: Optional[random.Random] = None, data: Optional[List[List[FewShotRecord]]] = None) -> None:
        """
        Initialize a few shot dataset and load the data in RAM.

        Args:
            config: the experiment configuration.
            path: the file read with torch.load when data is not given.
            tokenizer: the tokenizer, stored for consumers of the dataset.
            evaluation: whether the dataset is used for evaluation (deterministic per-query sampling, no shuffling).
            rng: unused, accepted for interface compatibility with the other datasets.
            data: already loaded records grouped by relation; when given, path is not read.
        """
        super().__init__()

        self.config: gbure.utils.dotdict = config
        self.path: pathlib.Path = path
        self.tokenizer: transformers.PreTrainedTokenizer = tokenizer
        self.evaluation: bool = evaluation

        if data is None:
            self.load()
        else:
            self.data = data
        self.num_relations: int = len(self.data)
        self.num_samples_per_relation: int = len(self.data[0]) if self.data else 0
        # Seed self.rng right away (as UnsupervisedDataset does) so that iterating in the main process works;
        # DataLoader workers can still re-seed through init_seed(worker_id).
        self.init_seed()

    def init_seed(self, worker_id: Optional[int] = None) -> None:
        """ Initialize the RNG used in training mode (config.seed offset by the worker id). """
        if not self.evaluation:
            seed: int = self.config.seed
            worker_info = torch.utils.data.get_worker_info()
            seed += worker_id if worker_id is not None else (worker_info.id if worker_info is not None else 0)
            self.rng = random.Random(seed)

    def load(self) -> None:
        """ Load the dataset into RAM. """
        dstype: str
        self.data: List[List[FewShotRecord]]
        dstype, self.data = torch.load(self.path)
        assert dstype == "fewshot", f"Expected a fewshot dataset, got {dstype!r} from {self.path}"

    def __len__(self) -> int:
        """ Get the number of samples in the dataset. """
        return self.num_relations * self.num_samples_per_relation * self.config.get("meta_per_sample", 1)

    def get_rng(self, relation: int, sentence: int) -> random.Random:
        """ Get the random number generator for the given query: a per-query deterministic one in evaluation, the shared one otherwise. """
        if self.evaluation:
            return random.Random(self.config.seed * len(self) + relation * self.num_samples_per_relation + sentence)
        return self.rng

    @staticmethod
    def sample_exclude(rng: random.Random, population: int, exclude: int, size: int) -> List[int]:
        """ Chooses size unique random elements from [0, population)\\{exclude}. """
        samples: List[int] = rng.sample(range(population-1), size)
        return [sample + (1 if sample >= exclude else 0) for sample in samples]

    def sample_meta(self, query_relation: int, query_sid: int) -> Dict[str, Any]:
        """
        Build a fewshot sample from the given query.

        The candidates are config.way lists of config.shot sentences: one list from the query relation (never containing the
        query itself) and config.way-1 lists from other relations, in a random order. "answer" is the index of the query
        relation's list.
        """
        rng: random.Random = self.get_rng(query_relation, query_sid)

        # positives
        candidates: List[List[Tuple[int, int]]] = [[(query_relation, sid) for sid in self.sample_exclude(rng, self.num_samples_per_relation, query_sid, self.config.shot)]]
        # negatives
        for negative_relation in self.sample_exclude(rng, self.num_relations, query_relation, self.config.way-1):
            candidates.append([(negative_relation, sid) for sid in rng.sample(range(self.num_samples_per_relation), self.config.shot)])

        order: List[int] = list(range(self.config.way))
        rng.shuffle(order)
        candidates = [candidates[i] for i in order]
        answer: int = order.index(0)

        query = self.data[query_relation][query_sid]
        shots = [[self.data[shot_relation][shot_sid] for shot_relation, shot_sid in way] for way in candidates]

        meta: Dict[str, Any] = {}
        meta["query_text"] = query[0]
        meta["query_entity_positions"] = torch.tensor(query[1:3], dtype=torch.int64)
        meta["query_relation"] = query[3]
        meta["query_entity_identifiers"] = torch.tensor(query[4:6], dtype=torch.int64)
        meta["candidates_text"] = [[shot[0] for shot in way] for way in shots]
        meta["candidates_entity_positions"] = torch.tensor([[shot[1:3] for shot in way] for way in shots], dtype=torch.int64)
        meta["candidates_relation"] = torch.tensor([[shot[3] for shot in way] for way in shots], dtype=torch.int64)
        meta["candidates_entity_identifiers"] = torch.tensor([[shot[4:6] for shot in way] for way in shots], dtype=torch.int64)
        meta["answer"] = answer
        return meta

    def __iter__(self) -> Iterator[Dict[str, Any]]:
        """ Generate samples from the dataset, splitting the work between DataLoader workers. """
        worker_info = torch.utils.data.get_worker_info()
        if worker_info is None:
            worker_modulo: int = 1
            worker_residue: int = 0
        else:
            worker_modulo = worker_info.num_workers
            worker_residue = worker_info.id

        mps: int = self.config.get("meta_per_sample", 1)
        # The (query, repetition) pairs are assigned to workers on the canonical, unshuffled order: every worker
        # agrees on this partition even though each worker's RNG is seeded differently. Each worker then shuffles
        # only its own share. Every query is kept (with a possibly null count) so that the shuffled list has the
        # same length, hence the same permutation, as a shuffle of all queries.
        groups: List[Tuple[Tuple[int, int], int]] = []
        for index, query in enumerate((relation, sid) for relation in range(self.num_relations) for sid in range(self.num_samples_per_relation)):
            count: int = sum((index*mps + j) % worker_modulo == worker_residue for j in range(mps))
            groups.append((query, count))
        if not self.evaluation:
            self.rng.shuffle(groups)
        self.order: List[Tuple[int, int]] = [query for query, _ in groups]

        for (relation, sid), count in groups:
            for _ in range(count):
                yield self.sample_meta(relation, sid)


class UnsupervisedDataset(torch.utils.data.IterableDataset):
    """
    Read a preprocessed unsupervised relation extraction dataset.

    A preprocessed dataset can be created from the gbure.data.prepare_* scripts.

    Config:
        blank_probability: the probability to replace an entity with <blank/>
        edge_sampling: the sampling strategy to avoid (or not) popular entities
        sample_per_epoch: the number of sample in an epoch
        seed: the seed for the random number generator
    """
    shuffleable: bool = False

    def __init__(self, config: gbure.utils.dotdict, path: Optional[pathlib.Path], tokenizer: transformers.PreTrainedTokenizer, evaluation: bool, rng: Optional[random.Random] = None) -> None:
        """
        Initialize an unsupervised dataset and load the graph in RAM.

        Args:
            config: the experiment configuration.
            path: the graph file; when None, no graph is loaded (self.graph must then be set by the caller).
            tokenizer: the tokenizer used to convert the entity tags into token ids.
            evaluation: whether the dataset is used for evaluation.
            rng: unused, accepted for interface compatibility with the other datasets.
        """
        super().__init__()

        self.config: gbure.utils.dotdict = config
        self.path: Optional[pathlib.Path] = path
        self.tokenizer: transformers.PreTrainedTokenizer = tokenizer
        self.evaluation: bool = evaluation
        self.load()
        self.init_seed()

    def init_seed(self, worker_id: Optional[int] = None) -> None:
        """ Initialize the RNG (config.seed offset by the worker id). """
        seed: int = self.config.seed
        worker_info = torch.utils.data.get_worker_info()
        seed += worker_id if worker_id is not None else (worker_info.id if worker_info is not None else 0)
        self.rng = random.Random(seed)

    def load(self) -> None:
        """ Load the dataset into RAM. """
        if self.path is not None:
            self.graph = Graph(path=self.path)
            if self.config.get("share_memory"):
                self.graph.share_memory()

    def __len__(self) -> int:
        """ Get the number of samples in the dataset. """
        return self.config.sample_per_epoch

    def filter_edge(self, eid: int) -> bool:
        """ Filter edges according to the length of the corresponding sentence and the size of its neighborhoods. """
        edge: torch.Tensor = self.graph.edges[eid]
        if self.graph.sentences[edge[2]].shape[0] > self.config.max_sentence_length:
            return False
        if self.config.get("filter_empty_neighborhood") and (self.graph.degree(edge[0]) <= 1 or self.graph.degree(edge[1]) <= 1):
            return False
        return True

    def sample_main(self) -> int:
        """
        Sample the main edge, from which positive and negative edges can be selected.

        A vertex is drawn uniformly, then one of its incident edges is drawn either uniformly ("uniform-uniform") or with a
        probability inversely proportional to the degree of the vertex stored in the first column of the adjacency entry
        ("uniform-inverse degree"). Draws are repeated until an edge passes filter_edge.

        Raises:
            RuntimeError: when config.edge_sampling is not one of the two strategies above.
        """
        # From Soares et al.
        # "To prevent a large bias towards relation statements that involve popular entities, we limit the number of relation statements that contain the same entity by randomly sampling a constant number of relation statements that contain any given entity."
        # It's hard to guess what was exactly done, so we propose several sampling strategies.
        strategy: str = self.config.edge_sampling
        if strategy not in ("uniform-uniform", "uniform-inverse degree"):
            raise RuntimeError(f"Unsupported config value for edge_sampling: {strategy!r}")

        while True:
            vid: int = self.rng.randint(0, self.graph.order-1)
            degree: int = self.graph.degree(vid)
            if degree == 0:
                # An isolated vertex has no incident edge to pick, draw another vertex.
                continue
            if strategy == "uniform-uniform":
                reid: int = self.rng.randint(0, degree-1)
            else:
                # NOTE(review): assumes column 0 of an adjacency entry is the neighbor vertex (column 1 is the edge id).
                weights: List[float] = [1.0 / max(int(self.graph.degree(edge[0])), 1) for edge in self.graph.adj[vid]]
                # Drawn with self.rng so that the choice is reproducible from config.seed in every worker.
                reid = self.rng.choices(range(degree), weights=weights)[0]
            eid: int = self.graph.adj[vid][reid, 1]
            if self.filter_edge(eid):
                return eid

    def eid_to_sample(self, first_eid: int, second_eid: int, polarity: int) -> Dict[str, Any]:
        """ Build a pair with the given polarity from two edge ids. """
        first_edge: torch.Tensor = self.graph.edges[first_eid].clone()
        second_edge: torch.Tensor = self.graph.edges[second_eid].clone()

        self.shuffle_entities(first_edge)
        self.align_entities_as(second_edge, first_edge)

        sample: Dict[str, Any] = {"polarity": polarity}
        sample.update(self.edge_to_features(first_eid, first_edge, "first_", mlm=True))
        sample.update(self.edge_to_features(second_eid, second_edge, "second_", mlm=False))
        return sample

    @staticmethod
    def invert_entities(edge: torch.Tensor) -> None:
        """ Invert the <e1> and <e2> tags of the edge in place, the text of the entities are not inverted, only the tags. """
        # invert vertex ids
        tmp = edge[0].clone()
        edge[0] = edge[1]
        edge[1] = tmp

        # invert entity positions
        tmp = edge[3:5].clone()
        edge[3:5] = edge[5:7]
        edge[5:7] = tmp

    def shuffle_entities(self, edge: torch.Tensor) -> None:
        """ Invert the <e1> and <e2> tags with probability ½. """
        if self.rng.randint(0, 1):
            self.invert_entities(edge)

    @staticmethod
    def align_entities_as(edge: torch.Tensor, pattern: torch.Tensor) -> None:
        """ Invert entities of an edge if neither of them are in the same position as in the provided pattern. """
        if edge[0] != pattern[0] and edge[1] != pattern[1]:
            UnsupervisedDataset.invert_entities(edge)

    def mlm_features(self, text: List[int], prefix: str) -> Dict[str, Any]:
        """
        Extract mlm_input and mlm_target for masked language model loss.

        Each token which is not a special token is selected with probability config.mlm_probability. A selected token is
        replaced by the mask token with probability config.mlm_masked_probability, by a random token with probability
        config.mlm_random_probability, and kept unchanged otherwise. mlm_target contains the original token at selected
        positions and -100 elsewhere.
        """
        # Function inspired by HuggingFace's code
        length: int = len(text)
        mlm_target = torch.tensor(text, dtype=torch.int64)
        mlm_input = torch.tensor(text, dtype=torch.int64)

        # Do not mask special tokens
        special_tokens = torch.tensor(self.tokenizer.get_special_tokens_mask(text, already_has_special_tokens=True), dtype=torch.bool)

        mlm_mask = torch.bernoulli(torch.full((length,), float(self.config.mlm_probability))).bool() & ~special_tokens
        mlm_target[~mlm_mask] = -100

        masked_p: float = float(self.config.mlm_masked_probability)
        masked_mask = torch.bernoulli(torch.full((length,), masked_p)).bool() & mlm_mask
        mlm_input[masked_mask] = self.tokenizer.convert_tokens_to_ids(self.tokenizer.mask_token)

        # Probability of a random replacement among the selected tokens not already replaced by the mask token.
        random_p: float = min(self.config.mlm_random_probability / (1-masked_p), 1.0) if masked_p < 1 else 0.0
        random_mask = torch.bernoulli(torch.full((length,), random_p)).bool() & mlm_mask & ~masked_mask
        random_value = torch.randint(len(self.tokenizer), (length,), dtype=torch.long)
        mlm_input[random_mask] = random_value[random_mask]

        return {f"{prefix}mlm_input": mlm_input, f"{prefix}mlm_target": mlm_target}

    def edge_to_features(self, eid: int, edge: torch.Tensor, prefix: str, mlm: bool) -> Dict[str, Any]:
        """
        Convert an edge to the corresponding set of features (token list of the sentence, etc).

        The <e1>, </e1>, <e2>, </e2> tags are inserted in the sentence at the positions stored in edge[3:7]; each entity
        is replaced by <blank/> with probability config.blank_probability.
        If mlm is True, features for Masked Language Model training are also generated.
        """
        sample: Dict[str, Any] = {}
        sample[f"{prefix}edge_identifier"] = eid
        sample[f"{prefix}entity_identifiers"] = edge[0:2]
        sample[f"{prefix}entity_degrees"] = torch.tensor([self.graph.degree(edge[0]), self.graph.degree(edge[1])], dtype=torch.int64)
        text: List[int] = self.graph.sentences[edge[2]].tolist()

        # Abuse the fact that "</eX>" < "<eX>"
        tags: List[Tuple[int, str]] = [(edge[3], "<e1>"), (edge[4], "</e1>"), (edge[5], "<e2>"), (edge[6], "</e2>")]
        tags.sort(reverse=True)

        # Tags are inserted from the end of the sentence so that earlier positions stay valid.
        # When we see a start tag <eX>, we know the last tag was </eX>
        last_position: int = -1
        for position, tag in tags:
            if tag.startswith("<e"):  # begin tag
                if self.rng.random() < self.config.get("blank_probability", 0):
                    del text[position:last_position]
                    text.insert(position, self.tokenizer.convert_tokens_to_ids("<blank/>"))
            text.insert(position, self.tokenizer.convert_tokens_to_ids(tag))
            last_position = position

        if mlm:
            sample.update(self.mlm_features(text, prefix))

        sample[f"{prefix}text"] = torch.tensor(text, dtype=torch.int32)
        sample[f"{prefix}entity_positions"] = torch.tensor([
                text.index(self.tokenizer.convert_tokens_to_ids("<e1>")),
                text.index(self.tokenizer.convert_tokens_to_ids("<e2>"))
            ], dtype=torch.int64)
        return sample

    def sample_parallel(self) -> Dict[str, Any]:
        """ Sample two parallel edges and create a positive pair from them. """
        while True:
            first_eid: int = self.sample_main()
            if self.graph.eid_simple_adjacency(first_eid):
                # This edge has no parallel edges from which a positive can be selected
                continue

            adjacency_range: Tuple[int, int] = self.graph.eid_adjacency_range(first_eid)
            if self.graph.edges[adjacency_range[0], 2] == self.graph.edges[adjacency_range[1]-1, 2]:
                # All the edges are caused by repetition of an entity in the same sentence
                continue

            # Avoid the range of parallel edges sharing the same sentence
            sentence_range: Tuple[int, int] = self.graph.eid_adjacency_range(first_eid, prefix=3)
            sentence_card: int = sentence_range[1] - sentence_range[0]

            second_eid: int = self.rng.randint(adjacency_range[0], adjacency_range[1]-sentence_card-1)
            if second_eid >= sentence_range[0]:
                second_eid += sentence_card

            if self.filter_edge(second_eid):
                return self.eid_to_sample(first_eid, second_eid, 1)

    def sample_strong_negative(self) -> Dict[str, Any]:
        """
        Sample a main edge and a second edge incident to one of its entities, then create a negative pair from them.

        The second edge is drawn uniformly among the edges of both entities, excluding the range of edges joining the two
        entities of the main edge.
        """
        # TODO consider biaising the sampling away from popular entities here too.
        while True:
            first_eid: int = self.sample_main()
            adjacency_range: Tuple[int, int] = self.graph.eid_adjacency_range(first_eid)
            adjacency_size: int = adjacency_range[1] - adjacency_range[0]
            vid1: int = self.graph.edges[first_eid, 0]
            vid2: int = self.graph.edges[first_eid, 1]
            vertex1_degree: int = self.graph.degree(vid1)
            vertex2_degree: int = self.graph.degree(vid2)
            if vertex1_degree + vertex2_degree <= 2 * adjacency_size:
                # This edge has no other incident edges from which a negative can be selected
                continue

            second_reid: int = self.rng.randint(0, vertex1_degree + vertex2_degree - 2 * adjacency_size - 1)
            if second_reid < vertex1_degree - adjacency_size:
                first_reid_begin: int = self.graph.reid_adjacency_begin(vid1, vid2)
                if second_reid >= first_reid_begin:
                    second_reid += adjacency_size
                second_eid: int = self.graph.adj[vid1][second_reid, 1]
            else:
                second_reid -= vertex1_degree - adjacency_size
                first_reid_begin = self.graph.reid_adjacency_begin(vid2, vid1)
                if second_reid >= first_reid_begin:
                    second_reid += adjacency_size
                second_eid = self.graph.adj[vid2][second_reid, 1]

            if self.filter_edge(second_eid):
                return self.eid_to_sample(first_eid, second_eid, -1)

    def sample_weak_negative(self) -> Dict[str, Any]:
        """ Sample two main edges involving four distinct entities and create a negative pair from them. """
        while True:
            first_eid: int = self.sample_main()
            second_eid: int = self.sample_main()
            entities: Set[int] = set([
                self.graph.edges[first_eid, 0],
                self.graph.edges[first_eid, 1],
                self.graph.edges[second_eid, 0],
                self.graph.edges[second_eid, 1]])
            if len(entities) == 4:
                return self.eid_to_sample(first_eid, second_eid, -1)

    def sample_triplet(self) -> Dict[str, Any]:
        """ Sample three independent main edges; only the first one gets MLM features, when config.language_model_weight > 0. """
        sample: Dict[str, Any] = {}
        for prefix in ["first_", "second_", "third_"]:
            eid: int = self.sample_main()
            edge: torch.Tensor = self.graph.edges[eid].clone()
            self.shuffle_entities(edge)
            sample.update(self.edge_to_features(eid, edge, prefix, mlm=(prefix == "first_" and self.config.get("language_model_weight", 0) > 0)))
        return sample

    def sample(self) -> Dict[str, Any]:
        """ Generate a single sample from the dataset. """
        if self.config.unsupervised == "mtb":
            p: float = self.rng.random()
            if p < self.config.strong_negative_probability:
                return self.sample_strong_negative()
            elif p < self.config.strong_negative_probability + self.config.weak_negative_probability:
                return self.sample_weak_negative()
            else:
                return self.sample_parallel()
        elif self.config.unsupervised == "triplet":
            return self.sample_triplet()
        else:
            raise RuntimeError(f"Unknown unsupervised mode {self.config.unsupervised}.")

    def __iter__(self) -> Iterator[Dict[str, Any]]:
        """ Generate samples from the dataset, splitting len(self) samples between DataLoader workers. """
        worker_info = torch.utils.data.get_worker_info()
        if worker_info is None:
            sample_count: int = len(self)
        else:
            sample_count = len(self) // worker_info.num_workers
            sample_count += (worker_info.id < (len(self) % worker_info.num_workers))

        for _ in range(sample_count):
            yield self.sample()


TYPE_MAGIC: Dict[str, Type[torch.utils.data.Dataset]] = {
        "supervised": SupervisedDataset,
        "fewshot": FewShotDataset,
        "sampled fewshot": SampledFewShotDataset,
        "unsupervised": UnsupervisedDataset  # not normally used
    }


class GraphAdapter(UnsupervisedDataset):
    """
    Post-process a Dataset to add graph features.

    The new features include neighborhood_text, neighborhood_entity_identifiers, etc and are extracted from the entity_identifiers features present in the original sample.
    """
    def __init__(self, dataset: torch.utils.data.Dataset, entity_dictionary: gbure.data.dictionary.Dictionary, path: pathlib.Path, graph: Optional[gbure.data.graph.Graph]) -> None:
        """
        Wrap dataset, sampling neighborhoods from a graph.

        The graph used is, in order of preference: the graph argument, the graph of dataset when it is an
        UnsupervisedDataset, or the graph loaded from path. entity_dictionary decodes the entity identifiers of dataset.
        """
        if isinstance(dataset, UnsupervisedDataset) or graph is not None:
            super().__init__(dataset.config, None, dataset.tokenizer, dataset.evaluation, None)
            self.graph = graph if graph is not None else dataset.graph
        else:
            super().__init__(dataset.config, path, dataset.tokenizer, dataset.evaluation, None)
        self.dataset = dataset
        self.entity_dictionary = entity_dictionary

    def empty_neighborhood(self, prefix: str) -> Dict[str, Any]:
        """
        Build padding features for an entity without neighbors to sample.

        Returns an empty dict (meaning the sample must be discarded) in training when config.filter_empty_neighborhood is set.
        """
        neighborhood_size: int = self.config.neighborhood_size
        if not self.evaluation and self.config.get("filter_empty_neighborhood"):
            return {}
        # FIXME We pad to the same number of neighbors for now, since Batcher.process_int_feature does not support neighborhoods of different sizes yet.
        # Once it is implemented, we can set neighborhood_size = 0
        # The padding text uses the same dtype as the text produced by edge_to_features.
        return {f"{prefix}edge_identifier": torch.full((neighborhood_size,), -1, dtype=torch.int64),
                f"{prefix}entity_identifiers": torch.full((neighborhood_size, 2), -1, dtype=torch.int64),
                f"{prefix}entity_degrees": torch.zeros((neighborhood_size, 2), dtype=torch.int64),
                f"{prefix}text": [torch.zeros((0,), dtype=torch.int32) for _ in range(neighborhood_size)],
                f"{prefix}entity_positions": torch.zeros((neighborhood_size, 2), dtype=torch.int64)}

    def sample_neighborhood(self, vid: int, exclude: Optional[int], incoming: bool, prefix: str) -> Dict[str, Any]:
        """
        Sample the neighborhood around the given vertex, excluding a given edge.

        config.neighborhood_size edges are drawn: without replacement when the vertex has enough edges, otherwise all edges
        plus random repetitions. Each edge is oriented so that vid is its entity number int(incoming).
        """
        number_reids: int = self.graph.degree(vid) - (0 if exclude is None else 1)
        if number_reids <= 0:
            reids: List[int] = []
        elif number_reids <= self.config.neighborhood_size:
            reids = list(range(number_reids)) + self.rng.choices(range(number_reids), k=self.config.neighborhood_size-number_reids)
        else:
            reids = self.rng.sample(range(number_reids), self.config.neighborhood_size)

        neighbors: List[Dict[str, Any]] = []
        for reid in reids:
            eid: int = self.graph.adj[vid][reid, 1]
            # The last incident edge stands in for the excluded one, which keeps the draw uniform over the others.
            if exclude is not None and eid == exclude:
                eid = self.graph.adj[vid][-1, 1]
            edge: torch.Tensor = self.graph.edges[eid].clone()
            if edge[int(incoming)] != vid:
                self.invert_entities(edge)
            neighbors.append(self.edge_to_features(eid, edge, "", mlm=False))

        if not neighbors:
            return self.empty_neighborhood(prefix)

        sample: Dict[str, Any] = {}
        for feature in neighbors[0].keys():
            if feature == "text":
                sample[f"{prefix}text"] = [neighbor["text"] for neighbor in neighbors]
            else:
                sample[f"{prefix}{feature}"] = torch.stack([neighbor[feature] for neighbor in neighbors])
        return sample

    def sample_neighborhoods(self, head: int, tail: int, eid: Optional[int], prefix: str) -> Dict[str, Any]:
        """
        Sample the neighborhoods of the two entities of a relation statement.

        head and tail are identifiers of self.entity_dictionary, translated into vertices of self.graph through the graph's
        own entity dictionary; an entity missing from the graph gets an empty neighborhood and a degree of 0. eid, when not
        None, is an edge excluded from both neighborhoods.

        Returns an empty dict when the sample must be discarded.
        """
        head_vid: Optional[int] = self.graph.entity_dictionary.encoder.get(self.entity_dictionary.decode(head))
        tail_vid: Optional[int] = self.graph.entity_dictionary.encoder.get(self.entity_dictionary.decode(tail))

        sample: Dict[str, Any] = {}
        for vid, entity in ((head_vid, "e1"), (tail_vid, "e2")):
            neighborhood_prefix: str = f"{prefix}{entity}_neighborhood_"
            if vid is None:
                neighborhood = self.empty_neighborhood(neighborhood_prefix)
            else:
                neighborhood = self.sample_neighborhood(vid, eid, True, neighborhood_prefix)
            if not neighborhood:
                return {}
            sample.update(neighborhood)

        e1_degree: int = 0 if head_vid is None else self.graph.degree(head_vid)
        e2_degree: int = 0 if tail_vid is None else self.graph.degree(tail_vid)
        sample[f"{prefix}entity_degrees"] = torch.tensor([e1_degree, e2_degree], dtype=torch.int64)
        return sample

    def _candidates_neighborhoods(self, entity_identifiers: torch.Tensor, eid: Optional[torch.Tensor], prefix: str) -> Dict[str, Any]:
        """ Sample the neighborhoods of every (way, shot) candidate; return an empty dict as soon as one must be discarded. """
        prefix_extras: Dict[str, List[Any]] = collections.defaultdict(list)
        for i, way in enumerate(entity_identifiers):
            extras_way: Dict[str, List[Any]] = collections.defaultdict(list)
            for j, shot in enumerate(way):
                # The per-shot edge id gets its own name: eid must stay the (way, shot) tensor for the next iterations.
                shot_eid: Optional[int] = None if eid is None else eid[i, j].item()
                extras_shot: Dict[str, Any] = self.sample_neighborhoods(shot[0], shot[1], shot_eid, prefix)
                if not extras_shot:
                    return {}
                for name, value in extras_shot.items():
                    extras_way[name].append(value)
            for name, values in extras_way.items():
                prefix_extras[name].append(values if name.endswith("text") else torch.stack(values))
        return {name: values if name.endswith("text") else torch.stack(values) for name, values in prefix_extras.items()}

    def adapt(self, sample: Dict[str, Any]) -> bool:
        """
        Add neighborhood features to sample.

        Returns whether the sample should be kept.
        """
        extras: Dict[str, Any] = {}
        for key in sample:
            if not key.endswith("entity_identifiers"):
                continue
            prefix: str = key[:-len("entity_identifiers")]
            entity_identifiers: torch.Tensor = sample[key]

            # If the underlying dataset is built upon a graph, we can exclude the main sample edge, otherwise there is no risk of sampling an edge as being its own neighbor.
            eid: Optional[Union[int, torch.Tensor]] = sample.get(f"{prefix}edge_identifier")

            if prefix == "candidates_":
                prefix_extras: Dict[str, Any] = self._candidates_neighborhoods(entity_identifiers, eid, prefix)
            else:
                prefix_extras = self.sample_neighborhoods(entity_identifiers[0], entity_identifiers[1], eid, prefix)
            if not prefix_extras:
                return False
            extras.update(prefix_extras)

            if f"{prefix}entity_degrees" not in sample and f"{prefix}entity_degrees" not in extras:
                extras[f"{prefix}entity_degrees"] = torch.zeros_like(entity_identifiers, dtype=torch.int64)
        sample.update(extras)
        return True

    def __len__(self) -> int:
        """ Get the number of samples of the wrapped dataset. """
        return len(self.dataset)

    def process_sample(self, sample: Dict[str, Any]) -> Iterator[Dict[str, Any]]:
        """ Yield the sample with its neighborhood features, or nothing when it must be discarded. """
        # TODO define config value to repeat the sampling of neighbors
        if self.config.get("neighborhood_size", 0) > 0:
            if self.adapt(sample):
                yield sample
        else:
            yield sample

    def __iter__(self) -> Iterator[Dict[str, Any]]:
        """ Generate the adapted samples, splitting map-style datasets between DataLoader workers. """
        if isinstance(self.dataset, torch.utils.data.IterableDataset):
            for sample in self.dataset:
                yield from self.process_sample(sample)
        else:  # Map-style dataset
            worker_info = torch.utils.data.get_worker_info()
            if worker_info is None:
                worker_modulo: int = 1
                worker_residue: int = 0
            else:
                worker_modulo = worker_info.num_workers
                worker_residue = worker_info.id

            for i in range(worker_residue, len(self.dataset), worker_modulo):
                yield from self.process_sample(self.dataset[i])


def load_dataset(config: gbure.utils.dotdict, split: str, path: pathlib.Path, **kwargs) -> torch.utils.data.Dataset:
    """
    Load the dataset stored at path.

    The train split of an unsupervised configuration is read as an UnsupervisedDataset graph; otherwise the dataset class
    is selected by the type tag stored in the file (see TYPE_MAGIC). kwargs are forwarded to the dataset constructor.
    """
    if split == "train" and config.get("unsupervised"):
        return UnsupervisedDataset(config=config, path=path, **kwargs)

    dstype: str
    data: Any
    dstype, data = torch.load(path)
    return TYPE_MAGIC[dstype](config=config, path=path, data=data, **kwargs)