gbure

Graph-based approaches on unsupervised relation extraction evaluated as a few-shot problem
git clone https://esimon.eu/repos/gbure.git
Log | Files | Refs | README | LICENSE

batcher.py (8378B)


      1 from typing import Any, Dict, List, Tuple
      2 import collections
      3 
      4 import torch
      5 
      6 
      7 # Must be kept prefix-sorted!
      8 # (prefix, list depth)
      9 FEATURE_PREFIXES: List[Tuple[str, int]] = [
     10         ("query_e1_neighborhood_", 1),
     11         ("query_e2_neighborhood_", 1),
     12         ("candidates_e1_neighborhood_", 3),
     13         ("candidates_e2_neighborhood_", 3),
     14         ("first_e1_neighborhood_", 1),
     15         ("first_e2_neighborhood_", 1),
     16         ("second_e1_neighborhood_", 1),
     17         ("second_e2_neighborhood_", 1),
     18         ("third_e1_neighborhood_", 1),
     19         ("third_e2_neighborhood_", 1),
     20         ("query_", 0),
     21         ("candidates_", 2),
     22         ("first_", 0),
     23         ("second_", 0),
     24         ("third_", 0),
     25         ("", 0)]
     26 
     27 
     28 class Batcher:
     29     """
     30     Batch a group of sample together.
     31 
     32     Two new features are derived from the "text": its length and a mask.
     33     """
     34     def __init__(self, pad_value: int) -> None:
     35         """ Initialize a Batcher, using the provided value to pad text. """
     36         self.pad_value: int = pad_value
     37 
     38     def add_length_field(self, batch: Dict[str, Any], prefix: str, depth: int) -> None:
     39         """ Add the length field for the given prefix. """
     40         text: List[Any] = batch[f"{prefix}text"]
     41         batch_size: int = len(text)
     42 
     43         if depth == 0:
     44             # text is a list of sentences
     45             lengths: torch.Tensor = torch.empty((batch_size,), dtype=torch.int64)
     46             for b, sentence in enumerate(text):
     47                 lengths[b] = sentence.shape[0]
     48         elif depth == 1:
     49             # text is a list of list of sentences (each sample contains several candidates)
     50             size: int = len(text[0])
     51             lengths: torch.Tensor = torch.empty((batch_size, size), dtype=torch.int64)
     52             for b, sample in enumerate(text):
     53                 for i, sentence in enumerate(sample):
     54                     lengths[b, i] = sentence.shape[0]
     55         elif depth == 2:
     56             # text is a list of list of list of sentences (each sample contains several candidates)
     57             way: int = len(text[0])
     58             shot: int = len(text[0][0])
     59             lengths: torch.Tensor = torch.empty((batch_size, way, shot), dtype=torch.int64)
     60             for b, sample in enumerate(text):
     61                 for w, candidates in enumerate(sample):
     62                     for s, candidate in enumerate(candidates):
     63                         lengths[b, w, s] = candidate.shape[0]
     64         elif depth == 3:
     65             # text is a list of list of list of list of sentences (each sample contains several candidates' neighborhoods)
     66             way: int = len(text[0])
     67             shot: int = len(text[0][0])
     68             size: int = len(text[0][0][0])
     69             lengths: torch.Tensor = torch.empty((batch_size, way, shot, size), dtype=torch.int64)
     70             for b, sample in enumerate(text):
     71                 for w, candidates in enumerate(sample):
     72                     for s, candidate in enumerate(candidates):
     73                         for n, neighbor in enumerate(candidate):
     74                             lengths[b, w, s, n] = neighbor.shape[0]
     75 
     76         batch[f"{prefix}length"] = lengths
     77 
     78     def process_text(self, batch: Dict[str, Any], prefix: str, depth: int, key: str) -> None:
     79         """ Build mask and text batch by padding sentences. """
     80         in_text: List[Any] = batch[f"{prefix}{key}"]
     81         if isinstance(batch[f"{prefix}length"], list):
     82             self.add_length_field(batch, prefix, depth)
     83         max_seq_len: int = max(batch[f"{prefix}length"].max(), 1)
     84         batch_size: int = len(in_text)
     85 
     86         if depth == 0:
     87             # text is a list of sentences
     88             text: torch.Tensor = torch.empty((batch_size, max_seq_len), dtype=torch.int64)
     89             mask: torch.Tensor = torch.empty((batch_size, max_seq_len), dtype=torch.bool)
     90             for b, sentence in enumerate(in_text):
     91                 text[b, :sentence.shape[0]] = sentence
     92                 text[b, sentence.shape[0]:] = self.pad_value
     93                 mask[b, :sentence.shape[0]] = 1
     94                 mask[b, sentence.shape[0]:] = 0
     95         elif depth == 1:
     96             # text is a list of list of sentences (each sample contains several candidates)
     97             # In this case, we are not sure the tensor is full (some neighborhoods might be of different sizes or even empty)
     98             size: int = len(in_text[0])
     99             text: torch.Tensor = torch.empty((batch_size, size, max_seq_len), dtype=torch.int64)
    100             mask: torch.Tensor = torch.zeros((batch_size, size, max_seq_len), dtype=torch.bool)
    101             for b, samples in enumerate(in_text):
    102                 for i, sentence in enumerate(samples):
    103                     text[b, i, :sentence.shape[0]] = sentence
    104                     text[b, i, sentence.shape[0]:] = self.pad_value
    105                     mask[b, i, :sentence.shape[0]] = 1
    106         elif depth == 2:
    107             # text is a list of list of list of sentences (each sample contains several candidates)
    108             # In this case, we are sure the tensor is full (all n way have the save k shots)
    109             way: int = len(in_text[0])
    110             shot: int = len(in_text[0][0])
    111             text: torch.Tensor = torch.empty((batch_size, way, shot, max_seq_len), dtype=torch.int64)
    112             mask: torch.Tensor = torch.empty((batch_size, way, shot, max_seq_len), dtype=torch.bool)
    113             for b, samples in enumerate(in_text):
    114                 for w, candidates in enumerate(samples):
    115                     for s, candidate in enumerate(candidates):
    116                         text[b, w, s, :candidate.shape[0]] = candidate
    117                         text[b, w, s, candidate.shape[0]:] = self.pad_value
    118                         mask[b, w, s, :candidate.shape[0]] = 1
    119                         mask[b, w, s, candidate.shape[0]:] = 0
    120         elif depth == 3:
    121             # text is a list of list of list of list of sentences (each sample contains several candidates' neighborhoods)
    122             # In this case, we are not sure the tensor is full (some neighborhoods might be of different sizes or even empty)
    123             way: int = len(in_text[0])
    124             shot: int = len(in_text[0][0])
    125             size: int = len(in_text[0][0][0])
    126             text: torch.Tensor = torch.empty((batch_size, way, shot, size, max_seq_len), dtype=torch.int64)
    127             mask: torch.Tensor = torch.empty((batch_size, way, shot, size, max_seq_len), dtype=torch.bool)
    128             for b, samples in enumerate(in_text):
    129                 for w, candidates in enumerate(samples):
    130                     for s, candidate in enumerate(candidates):
    131                         for n, neighbor in enumerate(candidate):
    132                             text[b, w, s, n, :neighbor.shape[0]] = neighbor
    133                             text[b, w, s, n, neighbor.shape[0]:] = self.pad_value
    134                             mask[b, w, s, n, :neighbor.shape[0]] = 1
    135                             mask[b, w, s, n, neighbor.shape[0]:] = 0
    136 
    137         batch[f"{prefix}{key}"] = text
    138         if f"{prefix}mask" not in batch:
    139             batch[f"{prefix}mask"] = mask
    140 
    141     def process_int_feature(self, batch: Dict[str, Any], prefix: str, feature: str) -> None:
    142         """ Transform a list of integer into a torch LongTensor. """
    143         # TODO handle neighborhoods of different sizes
    144         if isinstance(batch[f"{prefix}{feature}"][0], torch.Tensor):
    145             batch[f"{prefix}{feature}"] = torch.stack(batch[f"{prefix}{feature}"])
    146         else:
    147             batch[f"{prefix}{feature}"] = torch.tensor(batch[f"{prefix}{feature}"], dtype=torch.int64)
    148 
    149     def __call__(self, samples: List[Dict[str, Any]]) -> Dict[str, Any]:
    150         """ Batch the provided samples """
    151         batch = collections.defaultdict(list)
    152         for sample in samples:
    153             for key, value in sample.items():
    154                 batch[key].append(value)
    155 
    156         for key in list(batch.keys()):
    157             for prefix, depth in FEATURE_PREFIXES:
    158                 if key.startswith(prefix):
    159                     break
    160             feature: str = key[len(prefix):]
    161             if feature in ["text", "mlm_input", "mlm_target"]:
    162                 self.process_text(batch, prefix, depth, feature)
    163             if feature in ["relation", "entity_positions", "entity_identifiers", "entity_degrees", "edge_identifier", "polarity", "answer", "eid"]:
    164                 self.process_int_feature(batch, prefix, feature)
    165 
    166         return batch