gbure

Graph-based approaches to unsupervised relation extraction, evaluated as a few-shot problem
git clone https://esimon.eu/repos/gbure.git
Log | Files | Refs | README | LICENSE

dataset.py (32424B)


from typing import Any, Dict, Iterator, List, Optional, Set, Tuple, Type, Union
import collections
import pathlib
import random

import torch
import transformers

import gbure.data.dictionary
from gbure.data.graph import Graph
import gbure.utils
     12 
     13 
     14 class SupervisedDataset(torch.utils.data.Dataset):
     15     """
     16     Read a preprocessed supervised relation extraction dataset.
     17 
     18     A preprocessed dataset can be created from the gbure.data.prepare_* scripts.
     19     """
     20     shuffleable: bool = True
     21 
     22     def __init__(self, config: gbure.utils.dotdict, path: pathlib.Path, tokenizer: transformers.PreTrainedTokenizer, evaluation: bool, rng: Optional[random.Random] = None, data: Optional[List[Tuple[torch.Tensor, int, int, int]]] = None) -> None:
     23         """ Initialize a supervised dataset and load the data in RAM. """
     24         super().__init__()
     25 
     26         self.config: gbure.utils.dotdict = config
     27         self.path: pathlib.Path = path
     28         self.tokenizer: transformers.PreTrainedTokenizer = tokenizer
     29         self.evaluation: bool = evaluation
     30         if data is None:
     31             self.load()
     32         else:
     33             self.data = data
     34 
     35     def load(self) -> None:
     36         """ Load the dataset into RAM. """
     37         dstype: str
     38         self.data: List[Tuple[torch.Tensor, int, int, int]]
     39         dstype, self.data = torch.load(self.path)
     40         assert(dstype == "supervised")
     41 
     42     def __len__(self) -> int:
     43         """ Get the number of samples in the dataset. """
     44         return len(self.data)
     45 
     46     def __getitem__(self, index: int) -> Dict[str, Any]:
     47         """ Get the sample at the given index. """
     48         sample: Dict[str, Any] = {}
     49         sample["text"] = self.data[index][0]
     50         sample["entity_positions"] = torch.tensor(self.data[index][1:3], dtype=torch.int64)
     51         sample["relation"] = self.data[index][3]
     52         return sample
     53 
     54 
     55 class SampledFewShotDataset(torch.utils.data.Dataset):
     56     """
     57     Read a preprocessed few shot relation extraction dataset.npy file containing samples.
     58 
     59     A preprocessed dataset can be created from the gbure.data.prepare_* scripts.
     60     """
     61     shuffleable: bool = True
     62 
     63     def __init__(self, config: gbure.utils.dotdict, path: pathlib.Path, tokenizer: transformers.PreTrainedTokenizer, evaluation: bool, rng: Optional[random.Random] = None, data: Optional[List[Tuple[torch.Tensor, int, int, int, int, List[List[torch.Tensor]], torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, int]]] = None) -> None:
     64         """ Initialize a few shot dataset and load the samples in RAM. """
     65         super().__init__()
     66 
     67         self.config: gbure.utils.dotdict = config
     68         self.path: pathlib.Path = path
     69         self.tokenizer: transformers.PreTrainedTokenizer = tokenizer
     70         self.evaluation: bool = evaluation
     71         if data is None:
     72             self.load()
     73         else:
     74             self.data = data
     75 
     76     def load(self) -> None:
     77         """ Load the dataset into RAM. """
     78         dstype: str
     79         self.data: List[Tuple[torch.Tensor, int, int, int, int, List[List[torch.Tensor]], torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, int]]
     80         dstype, self.data = torch.load(self.path)
     81         assert(dstype == "sampled fewshot")
     82 
     83     def __len__(self) -> int:
     84         """ Get the number of samples in the dataset. """
     85         return len(self.data)
     86 
     87     def __getitem__(self, index: int) -> Dict[str, Any]:
     88         """ Get the sample at the given index. """
     89         sample: Dict[str, Any] = {}
     90         sample["query_text"] = self.data[index][0]
     91         sample["query_entity_positions"] = torch.tensor(self.data[index][1:3], dtype=torch.int64)
     92         sample["query_entity_identifiers"] = torch.tensor(self.data[index][3:5], dtype=torch.int64)
     93         sample["candidates_text"] = self.data[index][5]
     94         sample["candidates_entity_positions"] = torch.stack(self.data[index][6:8], dim=2)
     95         sample["candidates_entity_identifiers"] = torch.stack(self.data[index][8:10], dim=2)
     96         sample["answer"] = self.data[index][10]
     97         return sample
     98 
     99 
    100 class FewShotDataset(torch.utils.data.IterableDataset):
    101     """
    102     Read a preprocessed few shot relation extraction dataset.npy file.
    103 
    104     A preprocessed dataset can be created from the gbure.data.prepare_*
    105     modules.
    106 
    107     Config:
    108         seed: the seed for the random number generator
    109         shot: the number of candidates per relation
    110         way: the number of relation classes used for candidates
    111     """
    112     shuffleable: bool = False  # FIXME ?
    113 
    114     def __init__(self, config: gbure.utils.dotdict, path: pathlib.Path, tokenizer: transformers.PreTrainedTokenizer, evaluation: bool, rng: Optional[random.Random] = None, data: Optional[List[List[Tuple[torch.Tensor, int, int, int, int, int]]]] = None) -> None:
    115         """ Initialize a few shot dataset and load the data in RAM. """
    116         super().__init__()
    117 
    118         self.config: gbure.utils.dotdict = config
    119         self.path: pathlib.Path = path
    120         self.tokenizer: transformers.PreTrainedTokenizer = tokenizer
    121         self.evaluation: bool = evaluation
    122 
    123         if data is None:
    124             self.load()
    125         else:
    126             self.data = data
    127         self.num_relations: int = len(self.data)
    128         self.num_samples_per_relation: int = len(self.data[0])
    129 
    130     def init_seed(self, worker_id: Optional[int] = None) -> None:
    131         """ Initialize the RNG. """
    132         if not self.evaluation:
    133             seed: int = self.config.seed
    134             worker_info = torch.utils.data.get_worker_info()
    135             seed += worker_id if worker_id is not None else (worker_info.id if worker_info is not None else 0)
    136             rng = random.Random(seed)
    137             self.rng = rng
    138 
    139     def load(self) -> None:
    140         """ Load the dataset into RAM. """
    141         dstype: str
    142         self.data: List[List[Tuple[torch.Tensor, int, int, int, int, int]]]
    143         dstype, self.data = torch.load(self.path)
    144         assert(dstype == "fewshot")
    145 
    146     def __len__(self):
    147         """ Get the number of samples in the dataset. """
    148         return self.num_relations * self.num_samples_per_relation * self.config.get("meta_per_sample", 1)
    149 
    150     def get_rng(self, relation: int, sentence: int) -> random.Random:
    151         """ Get the random number generator for the given query. """
    152         if self.evaluation:
    153             return random.Random(self.config.seed * len(self) + relation * self.num_samples_per_relation + sentence)
    154         else:
    155             return self.rng
    156 
    157     @staticmethod
    158     def sample_exclude(rng: random.Random, population: int, exclude: int, size: int) -> List[int]:
    159         """ Chooses size unique random elements from [0, population)\\{exclude}. """
    160         samples: List[int] = rng.sample(range(population-1), size)
    161         return [sample + (1 if sample >= exclude else 0) for sample in samples]
    162 
    163     def sample_meta(self, query_relation: int, query_sid: int) -> Dict[str, Any]:
    164         """ Build a fewshot sample from the given query. """
    165         rng: random.Random = self.get_rng(query_relation, query_sid)
    166 
    167         # positives
    168         candidates: List[List[Tuple[int, int]]] = [[(query_relation, sid) for sid in self.sample_exclude(rng, self.num_samples_per_relation, query_sid, self.config.shot)]]
    169         # negatives
    170         for negative_relation in self.sample_exclude(rng, self.num_relations, query_relation, self.config.way-1):
    171             candidates.append([(negative_relation, sid) for sid in rng.sample(range(self.num_samples_per_relation), self.config.shot)])
    172 
    173         order: List[int] = list(range(self.config.way))
    174         rng.shuffle(order)
    175         candidates = [candidates[i] for i in order]
    176         answer = order.index(0)
    177 
    178         meta: Dict[str, Any] = {}
    179         meta[f"query_text"] = self.data[query_relation][query_sid][0]
    180         meta[f"query_entity_positions"] = torch.tensor(self.data[query_relation][query_sid][1:3], dtype=torch.int64)
    181         meta[f"query_relation"] = self.data[query_relation][query_sid][3]
    182         meta[f"query_entity_identifiers"] = torch.tensor(self.data[query_relation][query_sid][4:6], dtype=torch.int64)
    183         meta[f"candidates_text"] = [[self.data[shot_relation][shot_sid][0] for shot_relation, shot_sid in way] for way in candidates]
    184         meta[f"candidates_entity_positions"] = torch.tensor([[self.data[shot_relation][shot_sid][1:3] for shot_relation, shot_sid in way] for way in candidates], dtype=torch.int64)
    185         meta[f"candidates_relation"] = torch.tensor([[self.data[shot_relation][shot_sid][3] for shot_relation, shot_sid in way] for way in candidates], dtype=torch.int64)
    186         meta[f"candidates_entity_identifiers"] = torch.tensor([[self.data[shot_relation][shot_sid][4:6] for shot_relation, shot_sid in way] for way in candidates], dtype=torch.int64)
    187         meta["answer"] = answer
    188         return meta
    189 
    190     def __iter__(self) -> Iterator[Dict[str, Any]]:
    191         """ Generate samples from the dataset. """
    192         self.order: List[Tuple[int, int]] = [(relation, sid) for relation in range(self.num_relations) for sid in range(self.num_samples_per_relation)]
    193         if not self.evaluation:
    194             self.rng.shuffle(self.order)
    195 
    196         worker_info = torch.utils.data.get_worker_info()
    197         if worker_info is None:
    198             worker_modulo: int = 1
    199             worker_residue: int = 0
    200         else:
    201             worker_modulo: int = worker_info.num_workers
    202             worker_residue: int = worker_info.id
    203 
    204         mps = self.config.get("meta_per_sample", 1)
    205         for index, (relation, sid) in enumerate(self.order):
    206             for j in range(mps):
    207                 if (index*mps+j) % worker_modulo == worker_residue:
    208                     yield self.sample_meta(relation, sid)
    209 
    210 
    211 class UnsupervisedDataset(torch.utils.data.IterableDataset):
    212     """
    213     Read a preprocessed unsupervised relation extraction dataset.
    214 
    215     A preprocessed dataset can be created from the gbure.data.prepare_* scripts.
    216 
    217     Config:
    218         blank_probability: the probability to replace an entity with <blank/>
    219         edge_sampling: the sampling strategy to avoid (or not) popular entities
    220         sample_per_epoch: the number of sample in an epoch
    221         seed: the seed for the random number generator
    222     """
    223     shuffleable: bool = False
    224 
    225     def __init__(self, config: gbure.utils.dotdict, path: Optional[pathlib.Path], tokenizer: transformers.PreTrainedTokenizer, evaluation: bool, rng: Optional[random.Random] = None) -> None:
    226         """ Initialize a supervised dataset and load the data in RAM. """
    227         super().__init__()
    228 
    229         self.config: gbure.utils.dotdict = config
    230         self.path: Optional[pathlib.Path] = path
    231         self.tokenizer: transformers.PreTrainedTokenizer = tokenizer
    232         self.evaluation: bool = evaluation
    233         self.load()
    234         self.init_seed()
    235 
    236     def init_seed(self, worker_id: Optional[int] = None) -> None:
    237         """ Initialize the RNG. """
    238         seed: int = self.config.seed
    239         worker_info = torch.utils.data.get_worker_info()
    240         seed += worker_id if worker_id is not None else (worker_info.id if worker_info is not None else 0)
    241         rng = random.Random(seed)
    242         self.rng = rng
    243 
    244     def load(self) -> None:
    245         """ Load the dataset into RAM. """
    246         if self.path is not None:
    247             self.graph = Graph(path=self.path)
    248             if self.config.get("share_memory"):
    249                 self.graph.share_memory()
    250 
    251     def __len__(self) -> int:
    252         """ Get the number of samples in the dataset. """
    253         return self.config.sample_per_epoch
    254 
    255     def filter_edge(self, eid: int) -> bool:
    256         """ Filter edges according to the length of the corresponding sentence and the size of its neighborhoods. """
    257         edge: torch.Tensor = self.graph.edges[eid]
    258         if self.graph.sentences[edge[2]].shape[0] > self.config.max_sentence_length:
    259             return False
    260         if self.config.get("filter_empty_neighborhood") and (self.graph.degree(edge[0]) <= 1 or self.graph.degree(edge[1]) <= 1):
    261             return False
    262         return True
    263 
    264     def sample_main(self) -> int:
    265         """ Sample the main edge, from which positive and negative edges can be selected. """
    266         # From Soares et al.
    267         # "To prevent a large bias towards relation statements that involve popular entities, we limit the number of relation statements that contain the same entity by randomly sampling a constant number of relation statements that contain any given entity."
    268         # It's hard to guess what was exactly done, so we propose several sampling strategies.
    269         while True:
    270             if self.config.edge_sampling == "uniform-uniform":
    271                 vid: int = self.rng.randint(0, self.graph.order-1)
    272                 reid: int = self.rng.randint(0, self.graph.degree(vid)-1)
    273                 eid: int = self.graph.adj[vid][reid, 1]
    274             elif self.config.edge_sampling == "uniform-inverse degree":
    275                 vid: int = self.rng.randint(0, self.graph.order-1)
    276 
    277                 v2_candidates: torch.Tensor = torch.zeros(self.graph.degree(vid))
    278                 for i, edge in enumerate(self.graph.adj[vid]):
    279                     v2_candidates[i] = self.graph.degree(edge[0])
    280                 v2_candidates /= torch.nn.functional.normalize(v2_candidates, p=1, dim=0)
    281 
    282                 # FIXME slow, double check worker asynchronicity
    283                 reid: int = torch.multinomial(v2_candidates, 1).item()
    284                 eid: int = self.graph.adj[vid][reid, 1]
    285             else:
    286                 raise RuntimeError("Unsuported config value for edge_sampling")
    287             if self.filter_edge(eid):
    288                 return eid
    289 
    290     def eid_to_sample(self, first_eid: int, second_eid: int, polarity: int) -> Dict[str, Any]:
    291         """ Build a pair with the given polarity from two edge ids. """
    292         first_edge: torch.Tensor = self.graph.edges[first_eid].clone()
    293         second_edge: torch.Tensor = self.graph.edges[second_eid].clone()
    294 
    295         self.shuffle_entities(first_edge)
    296         self.align_entities_as(second_edge, first_edge)
    297 
    298         sample: Dict[str, Any] = {"polarity": polarity}
    299         sample.update(self.edge_to_features(first_eid, first_edge, "first_", mlm=True))
    300         sample.update(self.edge_to_features(second_eid, second_edge, "second_", mlm=False))
    301         return sample
    302 
    303     @staticmethod
    304     def invert_entities(edge: torch.Tensor) -> None:
    305         """ Invert the <e1> and <e2> tags of the edge, the text of the entities are not inverted, only the tags. """
    306         # invert vertex ids
    307         tmp = edge[0].clone()
    308         edge[0] = edge[1]
    309         edge[1] = tmp
    310 
    311         # invert entity positions
    312         tmp = edge[3:5].clone()
    313         edge[3:5] = edge[5:7]
    314         edge[5:7] = tmp
    315 
    316     def shuffle_entities(self, edge: torch.Tensor) -> None:
    317         """ Invert the <e1> and <e2> tags with probability ½. """
    318         if self.rng.randint(0, 1):
    319             self.invert_entities(edge)
    320 
    321     @staticmethod
    322     def align_entities_as(edge: torch.Tensor, pattern: torch.Tensor) -> None:
    323         """ Invert entities of an edge if neither of them are in the same position as in the provided pattern. """
    324         if edge[0] != pattern[0] and edge[1] != pattern[1]:
    325             UnsupervisedDataset.invert_entities(edge)
    326 
    327     def mlm_features(self, text: List[int], prefix: str) -> Dict[str, Any]:
    328         """ Extract mlm_input and mlm_target for masked language model loss. """
    329         # Function inspired by HuggingFace's code
    330         mlm_target = torch.tensor(text, dtype=torch.int64)
    331         mlm_input = torch.tensor(text, dtype=torch.int64)
    332 
    333         # Do not mask special tokens
    334         st_mask = self.tokenizer.get_special_tokens_mask(text, already_has_special_tokens=True)
    335         st_mask = torch.tensor(st_mask, dtype=torch.bool)
    336 
    337         mlm_p = torch.full((len(text),), self.config.mlm_probability)
    338         mlm_mask = torch.bernoulli(mlm_p).bool() & st_mask
    339         mlm_target[~mlm_mask] = -100
    340 
    341         masked_p = self.config.mlm_masked_probability
    342         masked_mask = torch.bernoulli(torch.full((len(text),), masked_p)).bool() & mlm_mask
    343         mlm_input[masked_mask] = self.tokenizer.convert_tokens_to_ids(self.tokenizer.mask_token)
    344 
    345         random_p = self.config.mlm_random_probability / (1-masked_p)
    346         random_mask = torch.bernoulli(torch.full((len(text),), random_p)).bool() & mlm_mask & ~masked_mask
    347         random_value = torch.randint(len(self.tokenizer), (len(text),), dtype=torch.long)
    348         mlm_input[random_mask] = random_value[random_mask]
    349 
    350         return {f"{prefix}mlm_input": mlm_input, f"{prefix}mlm_target": mlm_target}
    351 
    352     def edge_to_features(self, eid: int, edge: torch.Tensor, prefix: str, mlm: bool) -> Dict[str, Any]:
    353         """
    354         Convert an edge to the corresponding set of features (token list of the sentence, etc).
    355 
    356         If mlm is True, features for Masked Language Model training are also generated.
    357         """
    358         sample: Dict[str, Any] = {}
    359         sample[f"{prefix}edge_identifier"] = eid
    360         sample[f"{prefix}entity_identifiers"] = edge[0:2]
    361         sample[f"{prefix}entity_degrees"] = torch.tensor([self.graph.degree(edge[0]), self.graph.degree(edge[1])], dtype=torch.int64)
    362         text: List[int] = self.graph.sentences[edge[2]].tolist()
    363 
    364         # Abuse the fact that "</eX>" < "<eX>"
    365         tags: List[Tuple[int, str]] = [(edge[3], "<e1>"), (edge[4], "</e1>"), (edge[5], "<e2>"), (edge[6], "</e2>")]
    366         tags.sort(reverse=True)
    367 
    368         # When we see a start tag <eX>, we know the last tag was </eX>
    369         last_position: int = -1
    370         for position, tag in tags:
    371             if tag.startswith("<e"):  # begin tag
    372                 if self.rng.random() < self.config.get("blank_probability", 0):
    373                     del text[position:last_position]
    374                     text.insert(position, self.tokenizer.convert_tokens_to_ids("<blank/>"))
    375             text.insert(position, self.tokenizer.convert_tokens_to_ids(tag))
    376             last_position = position
    377 
    378         if mlm:
    379             sample.update(self.mlm_features(text, prefix))
    380 
    381         sample[f"{prefix}text"] = torch.tensor(text, dtype=torch.int32)
    382         sample[f"{prefix}entity_positions"] = torch.tensor([
    383                 text.index(self.tokenizer.convert_tokens_to_ids("<e1>")),
    384                 text.index(self.tokenizer.convert_tokens_to_ids("<e2>"))
    385             ], dtype=torch.int64)
    386         return sample
    387 
    388     def sample_parallel(self) -> Dict[str, Any]:
    389         """ Sample two parallel edges and create a positive pair from them. """
    390         while True:
    391             first_eid: int = self.sample_main()
    392             if self.graph.eid_simple_adjacency(first_eid):
    393                 # This edge has no parallel edges from which a positive can be selected
    394                 continue
    395 
    396             adjacency_range: Tuple[int, int] = self.graph.eid_adjacency_range(first_eid)
    397             if self.graph.edges[adjacency_range[0], 2] == self.graph.edges[adjacency_range[1]-1, 2]:
    398                 # All the edges are caused by repetition of an entity in the same sentence
    399                 continue
    400 
    401             # Avoid the range of parallel edges sharing the same sentence
    402             sentence_range: Tuple[int, int] = self.graph.eid_adjacency_range(first_eid, prefix=3)
    403             sentence_card: int = sentence_range[1] - sentence_range[0]
    404 
    405             second_eid: int = self.rng.randint(adjacency_range[0], adjacency_range[1]-sentence_card-1)
    406             if second_eid >= sentence_range[0]:
    407                 second_eid += sentence_card
    408 
    409             if self.filter_edge(second_eid):
    410                 return self.eid_to_sample(first_eid, second_eid, 1)
    411 
    412     def sample_strong_negative(self) -> Dict[str, Any]:
    413         """ Sample a strong negative edge around the two given vertices. """
    414         # TODO consider biaising the sampling away from popular entities here too.
    415         while True:
    416             first_eid: int = self.sample_main()
    417             adjacency_range: Tuple[int, int] = self.graph.eid_adjacency_range(first_eid)
    418             adjacency_size: int = adjacency_range[1] - adjacency_range[0]
    419             vid1: int = self.graph.edges[first_eid, 0]
    420             vid2: int = self.graph.edges[first_eid, 1]
    421             vertex1_degree: int = self.graph.degree(vid1)
    422             vertex2_degree: int = self.graph.degree(vid2)
    423             if vertex1_degree + vertex2_degree <= 2 * adjacency_size:
    424                 # This edge has no other incident edges from which a negative can be selected
    425                 continue
    426 
    427             second_reid: int = self.rng.randint(0, vertex1_degree + vertex2_degree - 2 * adjacency_size - 1)
    428             if second_reid < vertex1_degree - adjacency_size:
    429                 first_reid_begin: int = self.graph.reid_adjacency_begin(vid1, vid2)
    430                 if second_reid >= first_reid_begin:
    431                     second_reid += adjacency_size
    432                 second_eid: int = self.graph.adj[vid1][second_reid, 1]
    433             else:
    434                 second_reid -= vertex1_degree - adjacency_size
    435                 first_reid_begin: int = self.graph.reid_adjacency_begin(vid2, vid1)
    436                 if second_reid >= first_reid_begin:
    437                     second_reid += adjacency_size
    438                 second_eid: int = self.graph.adj[vid2][second_reid, 1]
    439 
    440             if self.filter_edge(second_eid):
    441                 return self.eid_to_sample(first_eid, second_eid, -1)
    442 
    443     def sample_weak_negative(self) -> Dict[str, Any]:
    444         while True:
    445             first_eid: int = self.sample_main()
    446             second_eid: int = self.sample_main()
    447             entities: Set[int] = set([
    448                     self.graph.edges[first_eid, 0],
    449                     self.graph.edges[first_eid, 1],
    450                     self.graph.edges[second_eid, 0],
    451                     self.graph.edges[second_eid, 1]])
    452             if len(entities) == 4:
    453                 return self.eid_to_sample(first_eid, second_eid, -1)
    454 
    455     def sample_triplet(self) -> Dict[str, Any]:
    456         sample: Dict[str, Any] = {}
    457         for prefix in ["first_", "second_", "third_"]:
    458             eid: int = self.sample_main()
    459             edge: torch.Tensor = self.graph.edges[eid].clone()
    460             self.shuffle_entities(edge)
    461             sample.update(self.edge_to_features(eid, edge, prefix, mlm=(prefix == "first_" and self.config.get("language_model_weight", 0) > 0)))
    462         return sample
    463 
    464     def sample(self) -> Dict[str, Any]:
    465         """ Generate a single sample from the dataset. """
    466         if self.config.unsupervised == "mtb":
    467             p: float = self.rng.random()
    468             if p < self.config.strong_negative_probability:
    469                 return self.sample_strong_negative()
    470             elif p < self.config.strong_negative_probability + self.config.weak_negative_probability:
    471                 return self.sample_weak_negative()
    472             else:
    473                 return self.sample_parallel()
    474         elif self.config.unsupervised == "triplet":
    475             return self.sample_triplet()
    476         else:
    477             raise RuntimeError(f"Unknown unsupervised mode {self.config.unsupervised}.")
    478 
    479     def __iter__(self) -> Iterator[Dict[str, Any]]:
    480         """ Generate samples from the dataset. """
    481         worker_info = torch.utils.data.get_worker_info()
    482         if worker_info is None:
    483             sample_count: int = len(self)
    484         else:
    485             sample_count: int = len(self) // worker_info.num_workers
    486             sample_count += (worker_info.id < (len(self) % worker_info.num_workers))
    487 
    488         for index in range(sample_count):
    489             yield self.sample()
    490 
    491 
    492 TYPE_MAGIC: Dict[str, torch.utils.data.Dataset] = {
    493         "supervised": SupervisedDataset,
    494         "fewshot": FewShotDataset,
    495         "sampled fewshot": SampledFewShotDataset,
    496         "unsupervised": UnsupervisedDataset  # not normally used
    497     }
    498 
    499 
    500 class GraphAdapter(UnsupervisedDataset):
    501     """
    502     Post-process a Dataset to add graph features.
    503 
    504     The new features include neighborhood_text, neighborhood_entity_identifiers, etc and are extracted from the entity_identifiers features present in the original sample.
    505     """
    def __init__(self, dataset: torch.utils.data.Dataset, entity_dictionary: gbure.data.dictionary.Dictionary, path: pathlib.Path, graph: Optional[gbure.data.graph.Graph]) -> None:
        """
        Wrap dataset so that its samples can be extended with graph features.

        The graph used for neighborhoods is, by order of preference: the given
        graph, the graph of dataset when it is an UnsupervisedDataset, or a
        graph loaded from path. entity_dictionary is used to decode the entity
        identifiers found in the samples of dataset.
        """
        # Only load a graph from path when no in-memory graph can be reused.
        reuse_graph: bool = graph is not None or isinstance(dataset, UnsupervisedDataset)
        super().__init__(dataset.config, None if reuse_graph else path, dataset.tokenizer, dataset.evaluation, None)
        if reuse_graph:
            self.graph = dataset.graph if graph is None else graph
        self.dataset = dataset
        self.entity_dictionary = entity_dictionary
    517 
    def empty_neighborhood(self, prefix: str) -> Dict[str, Any]:
        """
        Build the features of a neighborhood without any real neighbor.

        When training with config.filter_empty_neighborhood, an empty dict is
        returned instead, which sample_neighborhoods propagates as an empty
        result (presumably so that the sample gets dropped). Otherwise,
        config.neighborhood_size placeholder neighbors are returned: identifiers
        set to -1, degrees and entity positions set to 0, and empty texts.
        """
        size: int = self.config.neighborhood_size
        if self.evaluation or not self.config.get("filter_empty_neighborhood"):
            # FIXME We pad to the same number of neighbors for now, since Batcher.process_int_feature does not support neighborhoods of different sizes yet.
            # Once it is implemented, we can set neighborhood_size = 0
            # NOTE(review): real neighbor texts from edge_to_features are int32, these placeholders are int64.
            return {
                f"{prefix}edge_identifier": torch.full((size,), -1, dtype=torch.int64),
                f"{prefix}entity_identifiers": torch.full((size, 2), -1, dtype=torch.int64),
                f"{prefix}entity_degrees": torch.zeros((size, 2), dtype=torch.int64),
                f"{prefix}text": [torch.zeros((0,), dtype=torch.int64) for _ in range(size)],
                f"{prefix}entity_positions": torch.zeros((size, 2), dtype=torch.int64),
            }
        return {}
    529 
    def sample_neighborhood(self, vid: int, exclude: Optional[int], incoming: bool, prefix: str) -> Dict[str, Any]:
        """
        Sample the neighborhood around the given vertex, excluding a given edge.

        config.neighborhood_size incident edges are drawn: all of them padded
        by draws with replacement when there are too few, or a sample without
        replacement otherwise. Each edge is oriented so that vid sits in column
        int(incoming) of the edge. The features of each edge (edge_to_features,
        no MLM) are gathered under prefix: texts as a list, the rest stacked.
        Falls back to empty_neighborhood when no edge is available.
        """
        available: int = self.graph.degree(vid) if exclude is None else self.graph.degree(vid) - 1
        if available <= 0:
            reids: List[int] = []
        elif available > self.config.neighborhood_size:
            reids = self.rng.sample(range(available), self.config.neighborhood_size)
        else:
            # Take every neighbor once, then pad by drawing with replacement.
            padding: List[int] = self.rng.choices(range(available), k=self.config.neighborhood_size-available)
            reids = [*range(available), *padding]

        neighbors: List[Dict[str, Any]] = []
        for reid in reids:
            eid: int = self.graph.adj[vid][reid, 1]
            if exclude is not None and eid == exclude:
                # With exclude set, reids never reach the last adjacency row, so
                # that row stands in for the excluded edge (assumes exclude occurs
                # exactly once in adj[vid] — TODO confirm).
                eid = self.graph.adj[vid][-1, 1]
            edge: torch.Tensor = self.graph.edges[eid].clone()
            if edge[int(incoming)] != vid:
                self.invert_entities(edge)
            neighbors.append(self.edge_to_features(eid, edge, "", mlm=False))

        if not neighbors:
            return self.empty_neighborhood(prefix)

        return {
            f"{prefix}{feature}": (
                [neighbor[feature] for neighbor in neighbors]
                if feature == "text"
                else torch.stack([neighbor[feature] for neighbor in neighbors]))
            for feature in neighbors[0]
        }
    560 
    def sample_neighborhoods(self, head: int, tail: int, eid: Optional[int], prefix: str) -> Dict[str, Any]:
        """
        Sample the neighborhoods of the two entities of a sample.

        head and tail are entity identifiers of self.entity_dictionary. They
        are translated into vertices of self.graph through the graph's own
        entity dictionary, and an entity absent from the graph gets an empty
        neighborhood. eid, when not None, is passed to sample_neighborhood as
        the edge to exclude.

        Returns the {prefix}e1_neighborhood_* and {prefix}e2_neighborhood_*
        features plus {prefix}entity_degrees, or an empty dict as soon as one
        neighborhood comes back empty.
        """
        vertices: List[Optional[int]] = [
            self.graph.entity_dictionary.encoder.get(self.entity_dictionary.decode(entity))
            for entity in (head, tail)]

        sample: Dict[str, Any] = {}
        degrees: List[int] = []
        for slot, vertex in zip(("e1", "e2"), vertices):
            slot_prefix: str = f"{prefix}{slot}_neighborhood_"
            if vertex is None:
                neighborhood = self.empty_neighborhood(slot_prefix)
                degree: int = 0
            else:
                neighborhood = self.sample_neighborhood(vertex, eid, True, slot_prefix)
                degree = self.graph.degree(vertex)
            if not neighborhood:
                return {}
            sample.update(neighborhood)
            degrees.append(degree)

        sample[f"{prefix}entity_degrees"] = torch.tensor(degrees, dtype=torch.int64)
        return sample
    591 
    def adapt(self, sample: Dict[str, Any]) -> bool:
        """
        Add neighborhood features to sample.

        Every key of sample ending in "entity_identifiers" defines a prefix (the text before that suffix). For each
        prefix, the neighborhoods of the two entities are obtained from self.sample_neighborhoods. If
        "{prefix}edge_identifier" is present, it is forwarded as the edge to exclude.

        For the "candidates_" prefix, entity_identifiers is iterated as ways x shots x (head, tail). Each shot uses
        its own edge_identifier[i, j]. Per-shot features are stacked into ways x shots tensors; features whose name
        ends in "text" are kept as nested lists.

        If no "{prefix}entity_degrees" exists afterwards, an all-zero int64 tensor shaped like entity_identifiers
        is added.

        sample is only updated when every prefix succeeds.

        Returns whether the sample should be kept (False as soon as one neighborhood sampling comes back empty).
        """
        suffix = "entity_identifiers"
        extras: Dict[str, Any] = {}
        for key in sample:
            if not key.endswith(suffix):
                continue
            prefix: str = key[:-len(suffix)]
            entity_identifiers: torch.Tensor = sample[key]

            # If the underlying dataset is built upon a graph, we can exclude the main sample edge, otherwise there is no risk of sampling an edge as being its own neighbor.
            eid: Optional[Union[int, torch.Tensor]] = sample.get(f"{prefix}edge_identifier")

            prefix_extras: Dict[str, Any]
            if prefix == "candidates_":
                prefix_extras = collections.defaultdict(list)
                for i, way in enumerate(entity_identifiers):
                    extras_way: Dict[str, List[Any]] = collections.defaultdict(list)
                    for j, shot in enumerate(way):
                        # Keep eid intact: rebinding it to an int would make eid[i, j] fail on the next shot.
                        shot_eid: Optional[int] = None if eid is None else eid[i, j].item()
                        extras_shot: Dict[str, Any] = self.sample_neighborhoods(shot[0], shot[1], shot_eid, prefix)
                        if not extras_shot:
                            return False
                        for name, value in extras_shot.items():
                            extras_way[name].append(value)
                    for name, values in extras_way.items():
                        prefix_extras[name].append(values if name.endswith("text") else torch.stack(values))
                for name, values in prefix_extras.items():
                    prefix_extras[name] = values if name.endswith("text") else torch.stack(values)
            else:
                prefix_extras = self.sample_neighborhoods(entity_identifiers[0], entity_identifiers[1], eid, prefix)

            if not prefix_extras:
                return False
            extras.update(prefix_extras)

            degrees_key = f"{prefix}entity_degrees"
            if degrees_key not in sample and degrees_key not in extras:
                # One (zero) degree per entity identifier; extras never holds the identifiers themselves.
                extras[degrees_key] = torch.zeros_like(entity_identifiers, dtype=torch.int64)
        sample.update(extras)
        return True
    632 
    def __len__(self) -> int:
        """
        Return the length of the wrapped dataset.

        Samples later rejected by process_sample (when adapt returns False) are still counted, so this may
        overestimate the number of samples actually yielded by __iter__.
        """
        wrapped = self.dataset
        return len(wrapped)
    635 
    def process_sample(self, sample: Dict[str, Any]) -> Iterator[Dict[str, Any]]:
        """
        Yield sample, enriched with neighborhoods when config.neighborhood_size is positive.

        When neighborhoods are requested, self.adapt adds them to sample in place; the sample is dropped (nothing is
        yielded) if adapt rejects it. Otherwise sample is yielded untouched and adapt is not called.
        """
        # TODO define config value to repeat the sampling of neighbors
        wants_neighbors: bool = self.config.get("neighborhood_size", 0) > 0
        if not wants_neighbors or self.adapt(sample):
            yield sample
    643 
    def __iter__(self) -> Iterator[Dict[str, Any]]:
        """
        Iterate over the wrapped dataset, passing every sample through self.process_sample.

        An IterableDataset is consumed as-is; NOTE(review): it is not split across DataLoader workers here, so it
        presumably handles worker sharding itself — confirm. A map-style dataset is split round-robin: the worker
        with id k out of w workers visits indices k, k + w, k + 2w, ...
        """
        source = self.dataset
        if isinstance(source, torch.utils.data.IterableDataset):
            for sample in source:
                yield from self.process_sample(sample)
            return

        # Map-style dataset
        worker_info = torch.utils.data.get_worker_info()
        start, step = (0, 1) if worker_info is None else (worker_info.id, worker_info.num_workers)
        for index in range(start, len(source), step):
            yield from self.process_sample(source[index])
    659 
    660 
    661 def load_dataset(config: gbure.utils.dotdict, split: str, path: pathlib.Path, **kwargs) -> torch.utils.data.Dataset:
    662     if split == "train" and config.get("unsupervised"):
    663         return UnsupervisedDataset(config=config, path=path, **kwargs)
    664 
    665     dstype: str
    666     data: Any
    667     dstype, data = torch.load(path)
    668     return TYPE_MAGIC[dstype](config=config, path=path, data=data, **kwargs)