prepare_sampled_fewrel.py (2555B)
from typing import Any, Dict, Iterable, List, Optional, Tuple
import argparse
import hashlib
import itertools
import json
import pathlib

import tqdm

from gbure.utils import DATA_PATH
from gbure.data.prepare_fewrel import DATASET_PATH, process_sentence
import gbure.data.preprocessing as preprocessing


def read_entry(entry: Dict[str, Any]) -> Tuple[str, str, str]:
    """ Convert the json object of an entry to a preprocessing tuple (sentence, e1, e2). """
    # entry["h"]/entry["t"] hold the head/tail entity; index 1 is the entity identifier.
    return process_sentence(entry), entry["h"][1], entry["t"][1]


def read_file(inpath: pathlib.Path, outpath: Optional[pathlib.Path]) -> Iterable[Tuple[Tuple[str, str, str], List[List[Tuple[str, str, str]]], int]]:
    """
    Yield (test, [[train]]) pair from a file of samples.

    Each input is a triplet (sentence, head entity, tail entity).
    """
    with open(inpath) as file:
        problems = json.load(file)

    # Gold answers are optional; without an answer file every problem gets -1.
    if outpath:
        with open(outpath) as file:
            answers = json.load(file)
    else:
        answers = itertools.repeat(-1)

    description = f"processing {inpath.name}" + (" and " + outpath.name if outpath else "")
    for problem, answer in zip(tqdm.tqdm(problems, desc=description), answers):
        query = read_entry(problem["meta_test"])
        # One list of candidate sentences per way (relation class) of the few-shot episode.
        candidates = [[read_entry(sample) for sample in way] for way in problem["meta_train"]]
        yield (query, candidates, answer)


if __name__ == "__main__":
    parser: argparse.ArgumentParser = preprocessing.base_argument_parser("Prepare a sampled few shot FewRel dataset (generated by sample_io.py).", deterministic=True)
    parser.add_argument("inpath",
                        type=pathlib.Path,
                        help="Path to the file containing the input ")
    parser.add_argument("outpath",
                        type=pathlib.Path,
                        nargs="?",
                        help="Path to the file containing the output (optional)")
    parser.add_argument("-S", "--suffix",
                        type=str,
                        default="",
                        help="Suffix to add to the tokenizer to find the dataset.")

    args: argparse.Namespace = parser.parse_args()

    # Identify the (input, answer) pair by short content hashes so the
    # serialized split name is stable across runs on the same files.
    hashid: str = preprocessing.hash_file(args.inpath)[:8]
    if args.outpath:
        hashid += preprocessing.hash_file(args.outpath)[:8]
    name: str = preprocessing.dataset_name(args, args.suffix)

    preprocessing.serialize_fewshot_sampled_split(
            path=DATASET_PATH / name,
            name=hashid,
            split=read_file(args.inpath, args.outpath),
            **preprocessing.args_to_serialize(args))