gbure

Graph-based approaches to unsupervised relation extraction evaluated as a few-shot problem
git clone https://esimon.eu/repos/gbure.git

preprocessing.py (17972B)


      1 from typing import Any, Dict, Iterable, List, Optional, Tuple, Union, cast
      2 import argparse
      3 import collections
      4 import hashlib
      5 import math
      6 import os
      7 import pathlib
      8 import random
      9 import urllib.request
     10 import zipfile
     11 
     12 import torch
     13 import tqdm
     14 import transformers
     15 
     16 from gbure.data.dictionary import Dictionary, RelationDictionary
     17 from gbure.data.graph import Graph
     18 
     19 
     20 def hash_file(path: pathlib.Path, filename: Optional[str] = None, filesize: Optional[int] = None) -> str:
     21     """ Get a unique identifier for the file. """
     22     hasher = hashlib.sha512()
     23     with path.open("rb") as file:
     24         loop = iter(lambda: file.read(2**16), b"")
     25         if filename is not None and filesize is not None:
     26             loop = tqdm.tqdm(loop,
     27                              desc=f"checking {filename} hash",
     28                              total=math.ceil(filesize / 2**16),
     29                              unit_scale=2**16, unit="B", unit_divisor=1024)
     30         for chunk in loop:
     31             hasher.update(chunk)
     32     return hasher.hexdigest()
     33 
     34 
     35 def download(url: str, path: pathlib.Path, filename: str, sha512: str) -> None:
     36     """ Download a file at the given path and check its hash. """
     37     if not path.parent.is_dir():
     38         path.parent.mkdir(parents=True)
     39 
     40     unchecked: pathlib.Path = pathlib.Path(f"{path}.unchecked")
     41     with tqdm.tqdm(unit="B", unit_scale=True, unit_divisor=1024, miniters=1, desc=f"downloading {filename}") as progress:
     42         def report_hook(num_blocks: int, chunk_size: int, total_size: int):
     43             if progress.total is None:
     44                 progress.total = total_size
     45             progress.update(num_blocks * chunk_size - progress.n)
     46         urllib.request.urlretrieve(url, unchecked, report_hook)
     47 
     48     unchecked_hash = hash_file(unchecked, filename, progress.total)
     49     if unchecked_hash != sha512:
     50         raise RuntimeError(f"Downloaded file \"{filename}\" has wrong hash.")
     51     os.rename(unchecked, path)
     52 
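A usage sketch (the URL, target path and digest below are placeholders, not values from this project): `download` fetches into a `.unchecked` file and only renames it to the final path once `hash_file` confirms the expected SHA-512 digest.

    # Hypothetical usage sketch; the URL, path and digest are placeholders.
    archive_url = "https://example.org/data/archive.zip"
    archive_path = pathlib.Path("data/raw/archive.zip")
    expected_sha512 = "0123abcd..."  # placeholder digest

    try:
        download(archive_url, archive_path, "archive.zip", expected_sha512)
    except RuntimeError:
        # On a hash mismatch the partial file is kept as "archive.zip.unchecked"
        # so it can be inspected or re-downloaded.
        pass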
     53 
     54 def get_zip_data(dataset_path: pathlib.Path, directory_name: str, archive_name: str, archive_sha512: str, download_url: str, unzip_directory: bool = False) -> None:
     55     """ Download and extract data zip archive if needed. """
     56     if not (dataset_path / directory_name).exists():
     57         if not (dataset_path / archive_name).exists():
     58             download(download_url, dataset_path / archive_name, archive_name, archive_sha512)
     59 
     60         with zipfile.ZipFile(str(dataset_path / archive_name), "r") as archive:
     61             archive.extractall(dataset_path / directory_name if unzip_directory else dataset_path)
     62 
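Similarly, a hedged sketch of `get_zip_data` (dataset path, directory, archive name, digest and URL are all placeholders); the archive is only downloaded if neither the extracted directory nor the archive itself already exists.

    # Hypothetical usage sketch; every argument value is a placeholder.
    get_zip_data(pathlib.Path("data/some_dataset"), "some_dataset", "some_dataset.zip",
                 archive_sha512="0123abcd...",  # placeholder digest
                 download_url="https://example.org/some_dataset.zip")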
     63 
     64 def base_argument_parser(description: str = "", deterministic: bool = False, parser: Optional[argparse.ArgumentParser] = None) -> argparse.ArgumentParser:
     65     """ Return an argument parser with standard command line arguments used by preprocessing functions. """
     66     assert description != "" or parser is not None
     67     parser: argparse.ArgumentParser = argparse.ArgumentParser(description=description) if parser is None else parser
     68     parser.add_argument("tokenizer",
     69                         type=str,
     70                         nargs='?',
     71                         default="bert-base-cased",
     72                         help="Name of the transformers tokenizer")
     73     if not deterministic:
     74         parser.add_argument("-s", "--seed",
     75                             type=int,
     76                             default=0,
     77                             help="Seed of the RNG for shuffling the dataset")
     78     return parser
     79 
     80 
     81 def dataset_name(args: argparse.Namespace, infix: str = "") -> str:
     82     """ Return the dataset name with a suffix containing non-standard preprocessing parameters. """
     83     suffix: str = ""
     84     if "seed" in args and args.seed != 0:
     85         suffix = f"-s{args.seed}"
     86     return f"{args.tokenizer}{infix}{suffix}"
     87 
     88 
     89 def args_to_serialize(args: argparse.Namespace) -> Dict[str, Any]:
     90     """ Map standard preprocessing command line arguments defined in base_argument_parser to serialize_dataset parameters. """
     91     kwargs = {"tokenizer_name": args.tokenizer}
     92     if "seed" in args:
     93         kwargs["seed"] = args.seed
     94     return kwargs
     95 
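These three helpers are meant to be chained by dataset-specific preprocessing scripts; a minimal sketch with illustrative argument values:

    # Sketch of how the helpers above compose in a dataset-specific script.
    parser = base_argument_parser("Preprocess an illustrative dataset.")
    args = parser.parse_args(["bert-base-cased", "--seed", "1"])

    name = dataset_name(args)         # "bert-base-cased-s1"
    kwargs = args_to_serialize(args)  # {"tokenizer_name": "bert-base-cased", "seed": 1}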
     96 
     97 def make_tokenizer(name: str, path: pathlib.Path) -> transformers.PreTrainedTokenizer:
     98     """ Build the given tokenizer and save it. """
     99     if not path.is_dir():
    100         path.mkdir()
    101 
    102     tokenizer = transformers.AutoTokenizer.from_pretrained(name)
    103     special_tokens = ["<e1>", "</e1>", "<e2>", "</e2>", "<blank/>"]
    104     tokenizer.add_special_tokens({"additional_special_tokens": special_tokens})
    105     tokenizer.save_pretrained(path)
    106 
    107     # fix huggingface transformers issue #6368
    108     config_file = transformers.AutoConfig.from_pretrained(name)
    109     config_file.save_pretrained(path)
    110 
    111     return tokenizer
    112 
    113 
    114 def process_text_2(raw_text: str, tokenizer: transformers.PreTrainedTokenizer) -> Tuple[torch.Tensor, int, int]:
    115     """
    116     Transform a string with two tagged entities into a list of token ids together with the positions of the two entities.
    117 
    118     The returned token list contains the tokens corresponding to the tags.
    119     The two returned positions are the positions of <e1> and <e2>.
    120     """
    121     be1_id: int = tokenizer.convert_tokens_to_ids("<e1>")
    122     be2_id: int = tokenizer.convert_tokens_to_ids("<e2>")
    123 
    124     text: List[int] = tokenizer.encode(raw_text, add_special_tokens=True)
    125     e1_pos: int = text.index(be1_id)
    126     e2_pos: int = text.index(be2_id)
    127     if len(text) > tokenizer.model_max_length:
    128         text = text[:tokenizer.model_max_length]
    129         e1_pos = min(tokenizer.model_max_length-1, e1_pos)
    130         e2_pos = min(tokenizer.model_max_length-1, e2_pos)
    131 
    132     return torch.tensor(text, dtype=torch.int32), e1_pos, e2_pos
    133 
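A sketch of the expected input format, using a tokenizer built by `make_tokenizer` above (the tokenizer path and the sentence are illustrative):

    # Sketch: the two entities are delimited by <e1>...</e1> and <e2>...</e2> tags.
    tokenizer = make_tokenizer("bert-base-cased", pathlib.Path("/tmp/gbure-tokenizer"))
    raw = "<e1> Douglas Adams </e1> was born in <e2> Cambridge </e2> ."
    tokens, e1_pos, e2_pos = process_text_2(raw, tokenizer)
    # tokens is an int32 tensor of token ids (tags included);
    # e1_pos and e2_pos index the <e1> and <e2> tokens inside it.
    assert tokens[e1_pos] == tokenizer.convert_tokens_to_ids("<e1>")
    assert tokens[e2_pos] == tokenizer.convert_tokens_to_ids("<e2>")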
    134 
    135 def process_text_n(raw_text: str, raw_entities: List[Tuple[str, int, int]], tokenizer: transformers.PreTrainedTokenizer) -> Tuple[torch.Tensor, List[Tuple[str, int, int]]]:
    136     """
    137     Transform a string with several tagged entities into a list of token ids together with the positions of the entities.
    138 
    139     The returned token list does not contain the tokens corresponding to the tags.
    140     The returned positions are where the tags should be inserted.
    141     If the leftmost tag is inserted first, the position of subsequent inserts should be shifted accordingly.
    142     """
    143     be1_id: int = tokenizer.convert_tokens_to_ids("<e1>")
    144 
    145     # If one entity ends at a position and another entity starts at the same position, we want to close the first entity before opening the second one; the second field "1 - extremity" enforces this since the list is sorted in lexicographic order.
    146     tag_positions: List[Tuple[int, int, int]] = [
    147             (cast(int, entity[1 + extremity]),  # Position of the tag (start or end of entity) in the sentence.
    148                 cast(int, 1 - extremity),  # Whether this is a start or end of entity.
    149                 i)  # The index of the entity used to rebuild the list at the end.
    150             for i, entity in enumerate(raw_entities) for extremity in [0, 1]]
    151     tag_positions.sort()
    152 
    153     # We insert the tag <e1> at every tag position in order to be able to convert positions in the raw text to positions in the token list.
    154     pieces: List[str] = []
    155     for piece_start, piece_end in zip([(0,)] + tag_positions, tag_positions + [(len(raw_text),)]):
    156         pieces.append(raw_text[piece_start[0]:piece_end[0]])
    157         pieces.append("<e1>")
    158     # Remove the last <e1> added at the end of the sentence.
    159     pieces.pop()
    160 
    161     text: List[int] = tokenizer.encode("".join(pieces), add_special_tokens=True)
    162     if len(text) > tokenizer.model_max_length:
    163         text = text[:tokenizer.model_max_length]
    164 
    165     # New entity list, with converted positions.
    166     entities: List[List[Union[str, int]]] = [[entity[0], -1, -1] for entity in raw_entities]
    167 
    168     j: int = 0  # Counter on the tags.
    169     for i, token in enumerate(text):
    170         if token == be1_id:
    171             # The order of the <e1> tags in the text matches the order in tag_positions.
    172             tag_position: Tuple[int, int, int] = tag_positions[j]
    173 
    174             # tag_position[2] is the index of the entity in raw_entities (and thus entities).
    175             # tag_position[1] is 0 for the end of the entity and 1 for its start.
    176             # Since the returned token list will be pruned of all the <e1> tags, the position of the tag should be shifted by the number of <e1> tokens already encountered, hence "i - j".
    177             entities[tag_position[2]][2 - tag_position[1]] = i - j
    178 
    179             j += 1  # Move to the next tag.
    180 
    181     # Remove all tags
    182     text = list(filter(lambda x: x != be1_id, text))
    183 
    184     # Remove entities which didn't fit inside tokenizer.model_max_length tokens.
    185     entities = list(filter(lambda x: x[1] >= 0 and x[2] >= 0, entities))
    186 
    187     tuple_entities: List[Tuple[str, int, int]] = list(map(tuple, entities))
    188     return torch.tensor(text, dtype=torch.int32), tuple_entities
    189 
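A sketch with character-offset entities (the entity identifiers, offsets and tokenizer path are illustrative):

    # Sketch: raw entities are (identifier, character start, character end) triples.
    tokenizer = make_tokenizer("bert-base-cased", pathlib.Path("/tmp/gbure-tokenizer"))
    raw = "Douglas Adams was born in Cambridge."
    raw_entities = [("Q42", 0, 13), ("Q350", 26, 35)]
    tokens, entities = process_text_n(raw, raw_entities, tokenizer)
    # entities now holds (identifier, token start, token end) triples whose positions
    # index into tokens, which contains no tag tokens.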
    190 
    191 def serialize_supervised_split(
    192         path: pathlib.Path,
    193         split: Iterable[Tuple[str, str, str, str, str]],
    194         tokenizer: transformers.PreTrainedTokenizer,
    195         entity_dictionary: Dictionary,
    196         relation_dictionary: RelationDictionary) -> None:
    197     """
    198     Serialize a supervised split to a given path.
    199 
    200     split is an iterable containing (text, directed relation, undirected relation, e1, e2) tuples.
    201     Entities are ignored.
    202     The relations are raw values (e.g. P42). This function performs the encoding.
    203     """
    204     data: List[Tuple[torch.Tensor, int, int, int]] = []
    205 
    206     # TODO handle entities
    207     for raw_text, relation, relation_base, _, _ in split:
    208         text, e1_pos, e2_pos = process_text_2(raw_text, tokenizer)
    209         relation_id: int = relation_dictionary.encode(relation, relation_base)
    210         data.append((text, e1_pos, e2_pos, relation_id))
    211 
    212     torch.save(("supervised", data), path)
    213 
    214 
    215 def serialize_fewshot_split(
    216         path: pathlib.Path,
    217         split: Iterable[Tuple[str, str, str, str, str]],
    218         tokenizer: transformers.PreTrainedTokenizer,
    219         entity_dictionary: Dictionary,
    220         relation_dictionary: RelationDictionary) -> None:
    221     """
    222     Serialize a fewshot split to a given path.
    223 
    224     split is an iterable containing (text, directed relation, undirected relation, e1, e2) tuples.
    225     The relations and entities are raw values (e.g. P42, Q42). This function performs the encoding.
    226     """
    227     data: Dict[int, List[Tuple[torch.Tensor, int, int, int, int, int]]] = collections.defaultdict(list)
    228 
    229     for raw_text, relation, relation_base, e1, e2 in split:
    230         text, e1_pos, e2_pos = process_text_2(raw_text, tokenizer)
    231         relation_id: int = relation_dictionary.encode(relation, relation_base)
    232         e1_id: int = entity_dictionary.encode(e1)
    233         e2_id: int = entity_dictionary.encode(e2)
    234         data[relation_id].append((text, e1_pos, e2_pos, relation_id, e1_id, e2_id))
    235 
    236     torch.save(("fewshot", list(data.values())), path)
    237 
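A sketch of how a serialized few-shot split can be read back (the path is a placeholder):

    # Sketch: inspect a serialized few-shot split.
    supervision, data = torch.load("data/prepared/valid")  # placeholder path
    assert supervision == "fewshot"
    # data has one entry per relation; each entry is a list of
    # (token ids, e1 position, e2 position, relation id, e1 id, e2 id) tuples.
    text, e1_pos, e2_pos, relation_id, e1_id, e2_id = data[0][0]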
    238 
    239 def serialize_fewshot_sampled_split(
    240         path: pathlib.Path,
    241         name: str,
    242         split: Iterable[Tuple[Tuple[str, str, str], List[List[Tuple[str, str, str]]], int]],
    243         tokenizer_name: str) -> None:
    244     """
    245     Serialize a sampled fewshot split.
    246 
    247     split is an iterable of (query, candidates, answer) tuples.
    248     In these tuples, query is a tuple (text, e1, e2).
    249     The relations are not given.
    250     """
    251     tokenizer: transformers.PreTrainedTokenizer = transformers.AutoTokenizer.from_pretrained(str(path / "tokenizer"))
    252     entity_dictionary = Dictionary()
    253 
    254     data: List[Tuple[torch.Tensor, int, int, int, int, List[List[torch.Tensor]], torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, int]] = []
    255     for train, test, answer in split:
    256         query_text, query_e1_pos, query_e2_pos = process_text_2(train[0], tokenizer)
    257         query_e1 = entity_dictionary.encode(train[1])
    258         query_e2 = entity_dictionary.encode(train[2])
    259 
    260         way = len(test)
    261         shot = len(test[0])
    262         candidates_processed_text: List[List[Tuple[torch.Tensor, int, int]]] = list(map(lambda relation: list(map(lambda candidate: process_text_2(candidate[0], tokenizer), relation)), test))
    263         candidates_text_len = max(map(lambda relation: max(map(lambda candidate: candidate[0].shape[0], relation)), candidates_processed_text))
    264 
    265         candidates_text = [[None]*shot for _ in range(way)]
    266         candidates_e1_pos = torch.empty((way, shot), dtype=torch.int64)
    267         candidates_e2_pos = torch.empty((way, shot), dtype=torch.int64)
    268         candidates_e1 = torch.empty((way, shot), dtype=torch.int64)
    269         candidates_e2 = torch.empty((way, shot), dtype=torch.int64)
    270 
    271         for n, (relation, relation_processed) in enumerate(zip(test, candidates_processed_text)):
    272             for k, (candidate, candidate_processed) in enumerate(zip(relation, relation_processed)):
    273                 candidates_text[n][k] = candidate_processed[0]
    274                 candidates_e1_pos[n, k] = candidate_processed[1]
    275                 candidates_e2_pos[n, k] = candidate_processed[2]
    276                 candidates_e1[n, k] = entity_dictionary.encode(candidate[1])
    277                 candidates_e2[n, k] = entity_dictionary.encode(candidate[2])
    278 
    279         data.append((query_text, query_e1_pos, query_e2_pos, query_e1, query_e2, candidates_text, candidates_e1_pos, candidates_e2_pos, candidates_e1, candidates_e2, answer))
    280     entity_dictionary.save(path / f"{name}.entities")
    281     torch.save(("sampled fewshot", data), path / name)
    282 
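A sketch of the episode layout written above (the file name is a placeholder):

    # Sketch: inspect a serialized sampled few-shot split.
    supervision, episodes = torch.load("data/prepared/test_episodes")  # placeholder path
    assert supervision == "sampled fewshot"
    (query_text, query_e1_pos, query_e2_pos, query_e1, query_e2,
     candidates_text, candidates_e1_pos, candidates_e2_pos,
     candidates_e1, candidates_e2, answer) = episodes[0]
    # candidates_text is a way x shot nested list of token id tensors;
    # the *_pos and entity id tensors all have shape (way, shot).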
    283 
    284 def serialize_dataset(
    285         supervision: str,
    286         path: pathlib.Path,
    287         splits: Dict[str, Iterable[Tuple[str, str, str, str, str]]],
    288         tokenizer_name: str,
    289         unknown_entity: Optional[str] = None,
    290         unknown_relation: Optional[str] = None,
    291         seed: Optional[int] = None) -> None:
    292     """
    293     Serialize a dataset to a given path.
    294 
    295     The splits must be given as iterables of (text, relation, relation_base, e1, e2) tuples.
    296     supervision must be one of "supervised" or "fewshot".
    297     """
    298     if not path.is_dir():
    299         path.mkdir()
    300 
    301     tokenizer: transformers.PreTrainedTokenizer = make_tokenizer(tokenizer_name, path / "tokenizer")
    302     entity_dictionary = Dictionary(unknown=unknown_entity)
    303     relation_dictionary = RelationDictionary(unknown=unknown_relation)
    304 
    305     serialize_split = serialize_supervised_split if supervision == "supervised" else serialize_fewshot_split
    306     for split_name in ["train", "valid", "test"]:
    307         if split_name not in splits:
    308             continue
    309 
    310         split = list(splits[split_name])
    311         if split_name == "train":
    312             rng = random.Random(seed)
    313             rng.shuffle(split)
    314         split = tqdm.tqdm(split, desc=f"{split_name} tokenization")
    315         serialize_split(path / split_name, split, tokenizer, entity_dictionary, relation_dictionary)
    316     entity_dictionary.save(path / "entities")
    317     relation_dictionary.save(path / "relations")
    318 
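A toy end-to-end sketch for the supervised case (the sentence, relation labels, entity identifiers and output path are made up for illustration):

    # Sketch: serialize a tiny in-memory supervised dataset.
    splits = {
        "train": [
            ("<e1> Douglas Adams </e1> was born in <e2> Cambridge </e2> .",
             "place_of_birth(e1,e2)", "place_of_birth", "Q42", "Q350"),
        ],
    }
    serialize_dataset("supervised", pathlib.Path("/tmp/gbure-toy"), splits,
                      tokenizer_name="bert-base-cased", seed=0)
    # Writes /tmp/gbure-toy/{tokenizer,train,entities,relations}.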
    319 
    320 def build_edge_list(data: Iterable[Tuple[str, List[Tuple[str, int, int]]]], tokenizer: transformers.PreTrainedTokenizer) -> Tuple[List[torch.Tensor], Dictionary, List[int], List[Tuple[int, int, int, int, int, int, int]]]:
    321     """
    322     Build a list of edges and nodes corresponding to the given data.
    323 
    324     The tuples in the returned edge list are composed of the following elements:
    325         (entity 1, entity 2, sentence id, entity 1 start, entity 1 end, entity 2 start, entity 2 end)
    326     """
    327     sentences: List[torch.Tensor] = []
    328     entity_dictionary = Dictionary()
    329     degrees: List[int] = []
    330     edges: List[Tuple[int, int, int, int, int, int, int]] = []
    331 
    332     for raw_sentence, raw_entities in data:
    333         sentence: torch.Tensor
    334         entities: List[Tuple[str, int, int]]
    335         sentence, entities = process_text_n(raw_sentence, raw_entities, tokenizer)
    336 
    337         # Buffer the ids to avoid re-hashing the entities
    338         entity_ids: List[Optional[int]] = [None] * len(entities)
    339         edge_added: bool = False
    340 
    341         # Add all edges appearing in the clique corresponding to this sentence
    342         for i, (e1_name, e1_start, e1_end) in enumerate(entities):
    343             for j, (e2_name, e2_start, e2_end) in enumerate(entities[:i]):
    344                 # Soares et al. footnote 2 "We use a window of 40 tokens"
    345                 if max(e2_end - e1_start, e1_end - e2_start) < 40:
    346                     if entity_ids[i] is None:
    347                         entity_ids[i] = entity_dictionary.encode(e1_name)
    348                         if entity_ids[i] >= len(degrees):
    349                             degrees.append(0)
    350                     e1_id: int = cast(int, entity_ids[i])
    351 
    352                     if entity_ids[j] is None:
    353                         entity_ids[j] = entity_dictionary.encode(e2_name)
    354                         if entity_ids[j] >= len(degrees):
    355                             degrees.append(0)
    356                     e2_id: int = cast(int, entity_ids[j])
    357 
    358                     if e1_id <= e2_id:
    359                         edges.append((e1_id, e2_id, len(sentences), e1_start, e1_end, e2_start, e2_end))
    360                     else:
    361                         edges.append((e2_id, e1_id, len(sentences), e2_start, e2_end, e1_start, e1_end))
    362                     degrees[e1_id] += 1
    363                     degrees[e2_id] += 1
    364                     edge_added = True
    365 
    366         if edge_added:
    367             sentences.append(sentence)
    368 
    369     return sentences, entity_dictionary, degrees, edges
    370 
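A sketch on a single toy sentence (entity identifiers and the tokenizer path are illustrative): with three entities that all fall within the 40-token window, the sentence contributes a clique of three edges.

    # Sketch: one sentence with three nearby entities yields three edges.
    tokenizer = make_tokenizer("bert-base-cased", pathlib.Path("/tmp/gbure-tokenizer"))
    raw = "Douglas Adams was born in Cambridge and lived in London."
    raw_entities = [("Q42", 0, 13), ("Q350", 26, 35), ("Q84", 49, 55)]
    sentences, entity_dictionary, degrees, edges = build_edge_list([(raw, raw_entities)], tokenizer)
    # len(edges) == 3: each entity pair within the window is linked, and every edge is a
    # (e1 id, e2 id, sentence id, e1 start, e1 end, e2 start, e2 end) tuple with e1 id <= e2 id.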
    371 
    372 def serialize_unsupervised_dataset(
    373         path: pathlib.Path,
    374         data: Iterable[Tuple[str, List[Tuple[str, int, int]]]],
    375         tokenizer_name: str,
    376         seed: int) -> None:
    377     """
    378     Serialize an unsupervised dataset to a given path.
    379 
    380     The data must be given as an iterable of (sentence, list of entities) tuples,
    381     where each entity is a tuple of (identifier, start index in sentence, end index in sentence).
    382     """
    383     if not path.is_dir():
    384         path.mkdir()
    385 
    386     tokenizer: transformers.PreTrainedTokenizer = make_tokenizer(tokenizer_name, path / "tokenizer")
    387 
    388     sentences: List[torch.Tensor]
    389     entities: Dictionary
    390     edges: List[Tuple[int, int, int, int, int, int, int]]
    391     graph = Graph(*build_edge_list(data, tokenizer))
    392     graph.save(path / "train")