gbure

Graph-based approaches on unsupervised relation extraction evaluated as a fewshot problem
git clone https://esimon.eu/repos/gbure.git
Log | Files | Refs | README | LICENSE

commit a52d7245480b4e0be22a2c21691afd7a35b26d16
parent 35e7296c5fbe3ed83103ff94676a3cff90030d28
Author: Étienne Simon <esimon@esimon.eu>
Date:   Wed, 18 May 2022 00:23:20 +0000

Implement new graph-based approaches

Diffstat:
MREADME | 48+++++++++++++++---------------------------------
Dfsre/__init__.py | 4----
Dfsre/config/soares_supervised_kbp37.py | 21---------------------
Dfsre/config/soares_supervised_semeval.py | 21---------------------
Dfsre/data/__init__.py | 2--
Dfsre/data/dataset.py | 116-------------------------------------------------------------------------------
Dfsre/data/prepare_kbp37.py | 72------------------------------------------------------------------------
Dfsre/data/prepare_semeval.py | 106-------------------------------------------------------------------------------
Dfsre/data/relation_dictionary.py | 94-------------------------------------------------------------------------------
Dfsre/eval.py | 30------------------------------
Dfsre/metrics.py | 250-------------------------------------------------------------------------------
Dfsre/model/mtb_classifier.py | 58----------------------------------------------------------
Dfsre/model/mtb_supervised.py | 39---------------------------------------
Dfsre/train.py | 353-------------------------------------------------------------------------------
Dfsre/utils.py | 143-------------------------------------------------------------------------------
Agbure/__init__.py | 0
Agbure/config/contrastive_alignment.py | 53+++++++++++++++++++++++++++++++++++++++++++++++++++++
Agbure/config/gcn_mtb.py | 63+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
Agbure/config/mtb.py | 57+++++++++++++++++++++++++++++++++++++++++++++++++++++++++
Agbure/config/nonparametric.py | 32++++++++++++++++++++++++++++++++
Agbure/config/soares_fewrel.py | 30++++++++++++++++++++++++++++++
Agbure/config/soares_kbp37.py | 23+++++++++++++++++++++++
Agbure/config/soares_semeval.py | 23+++++++++++++++++++++++
Agbure/data/__init__.py | 0
Agbure/data/batcher.py | 166+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
Agbure/data/dataset.py | 668+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
Agbure/data/dictionary.py | 112+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
Agbure/data/graph.py | 146+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
Agbure/data/prepare_fewrel.py | 65+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
Agbure/data/prepare_kbp37.py | 41+++++++++++++++++++++++++++++++++++++++++
Agbure/data/prepare_sampled_fewrel.py | 65+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
Agbure/data/prepare_semeval.py | 83+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
Agbure/data/prepare_trex.py | 76++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
Agbure/data/preprocessing.py | 392+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
Agbure/eval.py | 34++++++++++++++++++++++++++++++++++
Agbure/metrics.py | 295+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
Agbure/model/__init__.py | 0
Agbure/model/contrastive_alignment.py | 85+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
Agbure/model/fewshot.py | 117+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
Agbure/model/linguistic_encoder.py | 75+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
Agbure/model/masked_lm.py | 59+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
Agbure/model/matching_the_blanks.py | 109+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
Agbure/model/similarity.py | 134+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
Agbure/model/supervised.py | 45+++++++++++++++++++++++++++++++++++++++++++++
Agbure/model/topological_encoder.py | 65+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
Agbure/outputs.py | 55+++++++++++++++++++++++++++++++++++++++++++++++++++++++
Agbure/train.py | 421+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
Agbure/utils.py | 387+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
Drequirements.txt | 4----
Ascripts/contrastive_alignment.sh | 58++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
Ascripts/mtb.sh | 47+++++++++++++++++++++++++++++++++++++++++++++++
Ascripts/mtb_gcn.sh | 54++++++++++++++++++++++++++++++++++++++++++++++++++++++
Ascripts/nonparametric.sh | 42++++++++++++++++++++++++++++++++++++++++++
53 files changed, 4192 insertions(+), 1346 deletions(-)

diff --git a/README b/README @@ -1,35 +1,17 @@ -Reproduction of Matching the Blanks: Distributional Similarity for Relation Learning by Livio Baldini Soares, Nicholas FitzGerald, Jeffrey Ling, and Tom Kwiatkowski. +This repository presents some results of graph-based approaches on unsupervised relation extraction evaluated as a fewshot problem. -This repository currently contains the supervised "entity markers-entity start" model for the Semeval and KBP37 datasets, the unsupervised MTB model and the FewRel dataset will be added latter on. -In order to reproduce the results, you first need to manually download and extract the datasets (e.g. in /tmp), then execute the following: +$ export DATA_PATH="/tmp" # the directory for storing datasets (<100MiB for each dataset, except T-REx which need a bit less than 40GiB) +$ export LOG_PATH="/tmp" # the directory where the models will be saved (around 12GiB per model) +$ https://esimon.eu/repos/gbure.git +$ cd gbure +$ python3 -m gbure.data.prepare_semeval +$ python3 -m gbure.data.prepare_kbp37 +$ python3 -m gbure.data.prepare_fewrel +$ python3 -m gbure.train gbure/config/soares_semeval.py +$ python3 -m gbure.train gbure/config/soares_kbp37.py +$ python3 -m gbure.train gbure/config/soares_fewrel.py -$ export DATA_PATH="/tmp" # the directory containing the extracted datasets -$ export LOG_PATH="/tmp" # the directory where the models will be saved -$ git clone https://gitlab.lip6.fr/esimon/few-shot-relation-extraction -$ cd few-shot-relation-extraction -$ python -m fsre.data.prepare_semeval -$ python -m fsre.train fsre/config/soares_supervised_semeval.py -$ python -m fsre.data.prepare_kbp37 -$ python -m fsre.train fsre/config/soares_supervised_kbp37.py - -I must be missing something since I'm a bit away from the results reported in the paper. -Here are the score reached by this repository (official SemEval macro F1 "taking directionality into account"): - -On Semeval: - paper valid: 82.1 - BERT cased valid: 88.55 (std: 0.70) - BERT uncased valid: 87.49 (std: 0.87) - paper test: 89.2 - BERT cased test: 88.24 (std: 0.37) - BERT uncased test: 88.02 (std: 0.65) - -On KBP37: - paper valid: 70.0 - BERT cased valid: 66.92 (std: 0.61) - BERT uncased valid: 66.48 (std: 0.46) - paper test: 68.3 - BERT cased test: 67.18 (std: 0.51) - BERT uncased test: 66.21 (std: 0.62) - -The mean and std are computed over 5 runs, all reported results are for "large" BERT models. -Detailed results can be found at http://www-ia.lip6.fr/~esimon/results.xhtml (temporary link). +To evaluate on the static (already sampled) FewRel subset used for evaluation you should use the sample_io.py script from FewRel: +$ python3 sample_io.py $DATA_PATH/FewRel/val.json 50000 5 1 0 input > $DATA_PATH/FewRel/val_50000_5_1_0_input +$ python3 sample_io.py $DATA_PATH/FewRel/val.json 50000 5 1 0 output > $DATA_PATH/FewRel/val_50000_5_1_0_output +$ python3 -m gbure.data.prepare_sampled_fewrel bert-base-cased $DATA_PATH/FewRel/val_50000_5_1_0_input $DATA_PATH/FewRel/val_50000_5_1_0_output diff --git a/fsre/__init__.py b/fsre/__init__.py @@ -1,4 +0,0 @@ -import fsre.data -import fsre.utils -import fsre.metrics -import fsre.train diff --git a/fsre/config/soares_supervised_kbp37.py b/fsre/config/soares_supervised_kbp37.py @@ -1,21 +0,0 @@ -from fsre.model.mtb_supervised import Model - - -dataset_name = "KBP37" - -# From Table 1 -bert_model = "bert-large-cased" -post_transformer_layer = "linear" -max_epoch = 10 -learning_rate = 3e-5 -true_batch_size = 64 - -# Guessed -validation_metric = "half_directed_macro_f1" -early_stopping_patience = 2 - -# Implementation details -seed = 0 -batch_size = 2 -batch_per_sort_bucket = 8 -sort_per_shuffle_bucket = 8 diff --git a/fsre/config/soares_supervised_semeval.py b/fsre/config/soares_supervised_semeval.py @@ -1,21 +0,0 @@ -from fsre.model.mtb_supervised import Model - - -dataset_name = "SemEval2010_task8_all_data" - -# From Table 1 -bert_model = "bert-large-cased" -post_transformer_layer = "layer_norm" -max_epoch = 10 -learning_rate = 3e-5 -true_batch_size = 64 - -# Guessed -validation_metric = "half_directed_macro_f1" -early_stopping_patience = 2 - -# Implementation details -seed = 0 -batch_size = 2 -batch_per_sort_bucket = 8 -sort_per_shuffle_bucket = 8 diff --git a/fsre/data/__init__.py b/fsre/data/__init__.py @@ -1,2 +0,0 @@ -from fsre.data.dataset import RelationExtractionDataset -from fsre.data.relation_dictionary import RelationDictionary diff --git a/fsre/data/dataset.py b/fsre/data/dataset.py @@ -1,116 +0,0 @@ -import math -import numpy -import torch - - -class RelationExtractionDataset(torch.utils.data.IterableDataset): - """ - Read a preprocessed Relation Extraction dataset from a .npy file. - - When generating data, we first read a large shuffle bucket which is - shuffled (unless the dataset is for evaluation). This shuffle_bucket - is then cut down into several sort buckets, each of them is sorted - so that sentences of similar length end up next to each other. The - sort buckets are then cut down into batches which pad the sentences. - However this class generate samples, the proper batching should be - done by a DataLoader. - A preprocessed dataset can be created from the fsre.data.prepare_* - modules. - - Config: - batch_per_sort_bucket: the number of batches in a sort bucket - batch_size: the number of samples in a batch - seed: the seed for the random number generator - sort_per_shuffle_bucket: the number of sort buckets in a shuffle bucket - """ - - def __init__(self, config, path, pad, evaluation, rng=None): - """ - Initialize a Relation Extraction dataset and load the data in RAM. - - Args: - config: global config object - path: path to the dataset to load, this should be a .npy file - pad: the value used to pad text in a batch - evaluation: whether this dataset is an evaluation one (no need to sort then) - rng: the random number generator to use for shuffling - """ - - super().__init__() - - self.config = config - self.pad = pad - self.evaluation = evaluation - - self.data = numpy.load(path, allow_pickle=True) - if not evaluation: - self.rng = rng if rng is not None else numpy.random.RandomState(config.seed) - - self.batch_size = config.batch_size - self.sort_bucket_size = config.batch_size * config.batch_per_sort_bucket - self.shuffle_bucket_size = self.sort_bucket_size * config.sort_per_shuffle_bucket - - def __len__(self): - return len(self.data) - - def pad_text(self, text, diff, pad): - """ Append diff tokens pad to the end of text """ - text = torch.tensor(text, dtype=torch.int64) - padding = text.new_full((diff,), pad) - return torch.cat((text, padding)) - - def iter_sample(self, samples): - """ Generate samples from a batch """ - lengths = [sample[1].shape[0] for sample in samples] - max_len = max(lengths) - - for sample, length in zip(samples, lengths): - ds = dict(zip(["id", "text", "e1_pos", "e2_pos", "relation"], sample)) - ds["length"] = length - ds["mask"] = self.pad_text(numpy.ones_like(ds["text"]), max_len - length, 0) - ds["text"] = self.pad_text(ds["text"], max_len - length, self.pad) - yield ds - - def iter_batch(self, sort_bucket): - """ Generate batches from a sort_bucket """ - for batch_start in range(0, len(sort_bucket), self.batch_size): - batch_end = batch_start + self.batch_size - batch_end = min(batch_end, len(sort_bucket)) - - yield from self.iter_sample([self.data[i] for i in sort_bucket[batch_start:batch_end]]) - - def iter_shuffle_bucket(self, shuffle_bucket): - """ Generate sort_buckets from a shuffle_bucket """ - worker_info = torch.utils.data.get_worker_info() - if worker_info is None: - start = 0 - end = len(shuffle_bucket) - else: - work_size = len(shuffle_bucket) / self.sort_bucket_size - per_worker = int(math.ceil(work_size / float(worker_info.num_workers))) * self.sort_bucket_size - start = worker_info.id * per_worker - end = min(start + per_worker, len(shuffle_bucket)) - - for sort_start in range(start, end, self.sort_bucket_size): - sort_end = sort_start + self.sort_bucket_size - sort_end = min(sort_end, end) - - sort_bucket = shuffle_bucket[sort_start:sort_end] - - sort_bucket.sort(key=lambda x: self.data[x, 1].shape[0]) - yield from self.iter_batch(sort_bucket) - - def __iter__(self): - """ Generate shuffle_buckets from the dataset """ - for shuffle_start in range(0, len(self), self.shuffle_bucket_size): - shuffle_end = shuffle_start + self.shuffle_bucket_size - shuffle_end = min(shuffle_end, len(self)) - - shuffle_bucket = list(range(shuffle_start, shuffle_end)) - - if not self.evaluation: - self.rng.shuffle(shuffle_bucket) - yield from self.iter_shuffle_bucket(shuffle_bucket) - - def state_dict(self): - return {"rng": self.rng} diff --git a/fsre/data/prepare_kbp37.py b/fsre/data/prepare_kbp37.py @@ -1,72 +0,0 @@ -import argparse -import numpy -import transformers -import tqdm - -from fsre.utils import DATA_PATH -from fsre.data.relation_dictionary import RelationDictionary -from fsre.data.prepare_semeval import load_semeval_dataset - -TRAIN_SIZE = 15917 -VALID_SIZE = 1724 -TEST_SIZE = 3405 -UNKNOWN_RELATION = "no_relation" - - -def prepare_kbp37(args): - rng = numpy.random.RandomState(args.seed) - kbp37_path = DATA_PATH / "KBP37" - output_path = kbp37_path / args.tokenizer - - if not output_path.is_dir(): - output_path.mkdir() - - relation_dictionary = RelationDictionary(unknown=UNKNOWN_RELATION) - - tokenizer = transformers.BertTokenizer.from_pretrained(args.tokenizer) - tokenizer.add_special_tokens({"additional_special_tokens": ["<e1>", "</e1>", "<e2>", "</e2>"]}) - tokenizer_path = output_path / "tokenizer" - if not tokenizer_path.is_dir(): - tokenizer_path.mkdir() - tokenizer.save_pretrained(tokenizer_path) - - train = load_semeval_dataset( - kbp37_path / "train.txt", - tokenizer, - relation_dictionary, - TRAIN_SIZE) - rng.shuffle(train) - - valid = load_semeval_dataset( - kbp37_path / "dev.txt", - tokenizer, - relation_dictionary, - VALID_SIZE) - - test = load_semeval_dataset( - kbp37_path / "test.txt", - tokenizer, - relation_dictionary, - TEST_SIZE) - - numpy.save(output_path / "train.npy", numpy.array(train)) - numpy.save(output_path / "valid.npy", numpy.array(valid)) - numpy.save(output_path / "test.npy", numpy.array(test)) - - relation_dictionary.save(output_path / "relations") - - -if __name__ == "__main__": - parser = argparse.ArgumentParser( - description="Prepare the KBP37 dataset.") - parser.add_argument("tokenizer", - type=str, - nargs='?', - default="bert-large-cased", - help="Name of the transformers tokenizer") - parser.add_argument("-s", "--seed", - type=int, - default=0, - help="Seed of the RNG for shuffling the dataset") - - prepare_kbp37(parser.parse_args()) diff --git a/fsre/data/prepare_semeval.py b/fsre/data/prepare_semeval.py @@ -1,106 +0,0 @@ -import argparse -import numpy -import transformers -import tqdm - -from fsre.utils import DATA_PATH -from fsre.data.relation_dictionary import RelationDictionary - -TRAIN_SIZE = 8000 -TEST_SIZE = 2717 -UNKNOWN_RELATION = "Other" - - -def load_semeval_dataset(path, tokenizer, relation_dictionary, size): - be1_id = tokenizer.added_tokens_encoder["<e1>"] - be2_id = tokenizer.added_tokens_encoder["<e2>"] - - dataset = [] - with open(path) as infile: - for _ in tqdm.trange(size): - idtext_line = infile.readline() - relation_line = infile.readline() - if not (idtext_line and relation_line): - break - - id, raw_text = idtext_line.rstrip().split('\t') - id = int(id) - - raw_text = raw_text[1:-1] # remove quotes around text - text = tokenizer.encode(raw_text, add_special_tokens=True) - e1_pos = text.index(be1_id) - e2_pos = text.index(be2_id) - if len(text) > tokenizer.max_len: - text = text[:tokenizer.max_len] - e1_pos = min(tokenizer.max_len-1, e1_pos) - e2_pos = min(tokenizer.max_len-1, e2_pos) - text = numpy.array(text, dtype=numpy.int32) - - relation_line = relation_line.rstrip() - dir_start = relation_line.find('(') - relation_base = relation_line[:dir_start] if dir_start >= 0 else relation_line - relation = relation_dictionary.encode(relation_line, relation_base) - - dataset.append([id, text, e1_pos, e2_pos, relation]) - infile.readline() # Ignore Comment line - infile.readline() # Ignore empty line - return dataset - - -def prepare_semeval(args): - rng = numpy.random.RandomState(args.seed) - semeval_path = DATA_PATH / "SemEval2010_task8_all_data" - output_path = semeval_path / args.tokenizer - - if not output_path.is_dir(): - output_path.mkdir() - - relation_dictionary = RelationDictionary(unknown=UNKNOWN_RELATION) - - tokenizer = transformers.BertTokenizer.from_pretrained(args.tokenizer) - tokenizer.add_special_tokens({"additional_special_tokens": ["<e1>", "</e1>", "<e2>", "</e2>"]}) - tokenizer_path = output_path / "tokenizer" - if not tokenizer_path.is_dir(): - tokenizer_path.mkdir() - tokenizer.save_pretrained(tokenizer_path) - - dataset = load_semeval_dataset( - semeval_path / "SemEval2010_task8_training" / "TRAIN_FILE.TXT", - tokenizer, - relation_dictionary, - TRAIN_SIZE) - rng.shuffle(dataset) - train = dataset[args.valid_size:] - valid = dataset[:args.valid_size] - - test = load_semeval_dataset( - semeval_path / "SemEval2010_task8_testing_keys" / "TEST_FILE_FULL.TXT", - tokenizer, - relation_dictionary, - TEST_SIZE) - - numpy.save(output_path / "train.npy", numpy.array(train)) - numpy.save(output_path / "valid.npy", numpy.array(valid)) - numpy.save(output_path / "test.npy", numpy.array(test)) - - relation_dictionary.save(output_path / "relations") - - -if __name__ == "__main__": - parser = argparse.ArgumentParser( - description="Prepare the SemEval 2010 Task 8 dataset.") - parser.add_argument("tokenizer", - type=str, - nargs='?', - default="bert-large-cased", - help="Name of the transformers tokenizer") - parser.add_argument("-s", "--seed", - type=int, - default=0, - help="Seed of the RNG for shuffling the dataset") - parser.add_argument("-v", "--valid-size", - type=int, - default=1500, - help="Size of the validation set") - - prepare_semeval(parser.parse_args()) diff --git a/fsre/data/relation_dictionary.py b/fsre/data/relation_dictionary.py @@ -1,94 +0,0 @@ -import pickle - - -class RelationDictionary: - """ - A dictionary to be used for relations. - - The tokens held by this class are divided between: - - *relation* such as "Entity-Destination(e1,e2)" - - *base* such as "Entity-Destination" - """ - - def __init__(self, *, unknown=None, path=None): - self.encoder = {} - self.decoder = [] - - self.base_encoder = {} - self.base_decoder = [] - self.id_to_bid = [] - - self.unknown = unknown - if unknown is not None: - self.encoder[unknown] = 0 - self.decoder.append(unknown) - self.base_encoder[unknown] = 0 - self.base_decoder.append(unknown) - self.id_to_bid.append(0) - - if path is not None: - self.load(path) - - def __len__(self): - """ Number of relations in the dictionary. """ - return len(self.decoder) - - def base_size(self): - """ Number of bases in the dictionary. """ - return len(self.base_decoder) - - def encode(self, relation, base=None): - """ - Returns the id corresponding to a relation string. - - Args: - relation: the string of the relation (e.g. "Entity-Destination(e1,e2)") - base: the string of the base relation (e.g. "Entity-Destination") - """ - - if relation is None: - return None - - if base is None: - return self.encoder[relation] - - id = self.encoder.get(relation) - if id is not None: - return id - - bid = self.base_encoder.get(base) - if bid is None: - bid = len(self.base_decoder) - self.base_encoder[base] = bid - self.base_decoder.append(base) - - id = len(self.decoder) - self.encoder[relation] = id - self.decoder.append(relation) - self.id_to_bid.append(bid) - return id - - def decode(self, id): - """ Returns the string corresponding to a relation id. """ - return self.decoder[id] - - def base_id(self, id): - """ Returns the base id corresponding to a relation id. """ - return self.id_to_bid[id] - - def save(self, path): - with open(path, "wb") as file: - pickle.dump({ - "unknown": self.unknown, - "decoder": self.decoder, - "encoder": self.encoder, - "base_encoder": self.base_encoder, - "base_decoder": self.base_decoder, - "id_to_bid": self.id_to_bid, - }, file) - - def load(self, path): - with open(path, "rb") as file: - data = pickle.load(file) - for key, value in data.items(): - setattr(self, key, value) diff --git a/fsre/eval.py b/fsre/eval.py @@ -1,30 +0,0 @@ -import torch - -import fsre - - -class Evaluator(fsre.train.Trainer): - """ - Evaluate a model. - """ - - def __init__(self, config, state_dicts): - self.eval_config = config - super().__init__(state_dicts["config"], None, state_dicts) - - def run(self): - self.epoch = self.state_dicts["epoch"] - self.info() - self.initialize_rng() - self.prepare_dataset() - self.build_model() - self.count_parameters() - self.evaluate("test") - - -if __name__ == "__main__": - fsre.utils.fix_transformers_logging_handler() - config = fsre.utils.parse_args() - - state_dicts = torch.load(config.load) - Evaluator(config, state_dicts).run() diff --git a/fsre/metrics.py b/fsre/metrics.py @@ -1,250 +0,0 @@ -import math -import numpy -import torch - - -class Metrics: - """ - Class for computing metrics. - - Twenty metrics are computed: - - Accuracy - - Negative Log Likelihood (nll) - - {directed, undirected, half_directed} {micro, macro} {f1, precision, recall} - Note that the Accuracy is the true accuracy, taking directionality into account and scoring the unknown relation as any other relation. - The last 18 metrics follow the SemEval scorer: - - The unknown ("Other") relation is only scored indirectly - - Directed is equivalent to the metrics "USING DIRECTIONALITY" - - Undirected is equivalent to the metrics "IGNORING DIRECTIONALITY" - - Half-directed is equivalent to the metrics "TAKING DIRECTIONALITY INTO ACCOUNT -- OFFICIAL" - Note that the directed and half_directed micro metrics are equivalents. - """ - - def __init__(self, relation_dictionary): - """ - Initialize all metrics. - - Args: - relation_dictionary: see class RelationDictionary - """ - - self.relation_dictionary = relation_dictionary - self.n = len(relation_dictionary) - self.m = relation_dictionary.base_size() - self.build_mask() - self.build_base_transition() - self.crossentropy = torch.nn.CrossEntropyLoss(reduction="sum") - - self.size = 0 - self.correct = 0 - self.ce_sum = 0 - self.confusion = numpy.zeros((self.n, self.n), numpy.int64) - - def build_mask(self): - self.mask = numpy.ones(self.n) - if self.relation_dictionary.unknown is not None: - assert(self.relation_dictionary.decode(0) == self.relation_dictionary.unknown) - self.mask[0] = 0 - - def build_base_transition(self): - self.base_transition = numpy.zeros((self.n, self.m)) - for id, bid in enumerate(self.relation_dictionary.id_to_bid): - self.base_transition[id, bid] = 1 - - def update(self, predictions, target): - """ - Update metrics with a batch of predictions and corresponding targets. - - Args: - predictions: the predicted logits (before softmax) - target: the gold relations - """ - - self.size += predictions.shape[0] - self.ce_sum += self.crossentropy(predictions, target).item() - - prediction = predictions.argmax(1) - for p, t in zip(prediction.tolist(), target.tolist()): - self.confusion[p, t] += 1 - self.correct += (p == t) - - @property - def summary(self): - return {"accuracy": f"{self.accuracy*100:.2f}", - "nll": f"{self.nll:.2f}"} - - @property - def all(self): - keys = ["accuracy", "nll"] + [ - f"{direction}_{level}_{metric}" - for direction in ["directed", "undirected", "half_directed"] - for level in ["macro", "micro"] - for metric in ["f1", "precision", "recall"]] - return {key: getattr(self, key) for key in keys} - - @property - def base_mask(self): - return self.mask.dot(self.base_transition).clip(0, 1) - - @property - def base_confusion(self): - return self.base_transition.T.dot(self.confusion).dot(self.base_transition) - - @property - def accuracy(self): - return math.nan if self.size == 0 else self.correct / self.size - - @property - def nll(self): - return math.nan if self.size == 0 else self.ce_sum / self.size - - ########################## - # Directed macro metrics # - ########################## - - @property - def directed_class_precision(self): - norm = self.confusion.sum(1) - norm[norm == 0] = 1 - return self.confusion.diagonal() / norm - - @property - def directed_class_recall(self): - norm = self.confusion.sum(0) - norm[norm == 0] = 1 - return self.confusion.diagonal() / norm - - @property - def directed_class_f1(self): - norm = self.directed_class_precision + self.directed_class_recall - norm[norm == 0] = 1 - return 2 * self.directed_class_precision * self.directed_class_recall / norm - - @property - def directed_macro_precision(self): - return numpy.sum(self.directed_class_precision * self.mask) / self.mask.sum() - - @property - def directed_macro_recall(self): - return numpy.sum(self.directed_class_recall * self.mask) / self.mask.sum() - - @property - def directed_macro_f1(self): - return numpy.sum(self.directed_class_f1 * self.mask) / self.mask.sum() - - ############################ - # Undirected macro metrics # - ############################ - - @property - def undirected_class_precision(self): - norm = self.base_confusion.sum(1) - norm[norm == 0] = 1 - return self.base_confusion.diagonal() / norm - - @property - def undirected_class_recall(self): - norm = self.base_confusion.sum(0) - norm[norm == 0] = 1 - return self.base_confusion.diagonal() / norm - - @property - def undirected_class_f1(self): - norm = self.undirected_class_precision + self.undirected_class_recall - norm[norm == 0] = 1 - return 2 * self.undirected_class_precision * self.undirected_class_recall / norm - - @property - def undirected_macro_precision(self): - return numpy.sum(self.undirected_class_precision * self.base_mask) / self.base_mask.sum() - - @property - def undirected_macro_recall(self): - return numpy.sum(self.undirected_class_recall * self.base_mask) / self.base_mask.sum() - - @property - def undirected_macro_f1(self): - return numpy.sum(self.undirected_class_f1 * self.base_mask) / self.base_mask.sum() - - ############################### - # Half-directed macro metrics # - ############################### - - @property - def half_directed_class_precision(self): - norm = self.base_confusion.sum(1) - norm[norm == 0] = 1 - return self.confusion.diagonal().dot(self.base_transition) / norm - - @property - def half_directed_class_recall(self): - norm = self.base_confusion.sum(0) - norm[norm == 0] = 1 - return self.confusion.diagonal().dot(self.base_transition) / norm - - @property - def half_directed_class_f1(self): - norm = self.half_directed_class_precision + self.half_directed_class_recall - norm[norm == 0] = 1 - return 2 * self.half_directed_class_precision * self.half_directed_class_recall / norm - - @property - def half_directed_macro_precision(self): - return numpy.sum(self.half_directed_class_precision * self.base_mask) / self.base_mask.sum() - - @property - def half_directed_macro_recall(self): - return numpy.sum(self.half_directed_class_recall * self.base_mask) / self.base_mask.sum() - - @property - def half_directed_macro_f1(self): - return numpy.sum(self.half_directed_class_f1 * self.base_mask) / self.base_mask.sum() - - ################# - # Micro metrics # - ################# - - @property - def directed_micro_precision(self): - norm = numpy.sum(self.confusion.sum(1) * self.mask) - return 0 if norm == 0 else numpy.sum(self.confusion.diagonal() * self.mask) / norm - - @property - def directed_micro_recall(self): - norm = numpy.sum(self.confusion.sum(0) * self.mask) - return 0 if norm == 0 else numpy.sum(self.confusion.diagonal() * self.mask) / norm - - @property - def directed_micro_f1(self): - norm = self.directed_micro_precision + self.directed_micro_recall - return 0 if norm == 0 else 2 * (self.directed_micro_precision * self.directed_micro_recall) / norm - - @property - def half_directed_micro_precision(self): - norm = numpy.sum(self.confusion.sum(1) * self.mask) - return 0 if norm == 0 else numpy.sum(self.confusion.diagonal() * self.mask) / norm - - @property - def half_directed_micro_recall(self): - norm = numpy.sum(self.confusion.sum(0) * self.mask) - return 0 if norm == 0 else numpy.sum(self.confusion.diagonal() * self.mask) / norm - - @property - def half_directed_micro_f1(self): - norm = self.half_directed_micro_precision + self.half_directed_micro_recall - return 0 if norm == 0 else 2 * (self.half_directed_micro_precision * self.half_directed_micro_recall) / norm - - @property - def undirected_micro_precision(self): - norm = numpy.sum(self.base_confusion.sum(1) * self.base_mask) - return 0 if norm == 0 else numpy.sum(self.base_confusion.diagonal() * self.base_mask) / norm - - @property - def undirected_micro_recall(self): - norm = numpy.sum(self.base_confusion.sum(0) * self.base_mask) - return 0 if norm == 0 else numpy.sum(self.base_confusion.diagonal() * self.base_mask) / norm - - @property - def undirected_micro_f1(self): - norm = self.undirected_micro_precision + self.undirected_micro_recall - return 0 if norm == 0 else 2 * (self.undirected_micro_precision * self.undirected_micro_recall) / norm diff --git a/fsre/model/mtb_classifier.py b/fsre/model/mtb_classifier.py @@ -1,58 +0,0 @@ -import torch -import torch.nn as nn -import transformers - - -class Classifier(nn.Module): - """ - Transformer classifier from Soares et al. - - Correspond to the left part of each subfigure of Figure 2 (Deep Transformer Encoder and the green layer above). - - Config: - bert_model: The version of BERT to use (e.g. bert-large-uncased). - post_transformer_layer: The transformation applied after BERT (must be "linear" or "layer_norm") - """ - - def __init__(self, config, tokenizer): - """ - Instantiate a Soares et al. classifier. - - Args: - config: global config object - tokenizer: tokenizer used to create the vocabulary - """ - - super().__init__() - - self.config = config - self.tokenizer = tokenizer - - if self.config.get("load"): - bert_config = transformers.BertConfig.from_pretrained(self.config.bert_model) - bert_config.vocab_size = len(tokenizer) - self.bert = transformers.BertModel(bert_config) - else: - self.bert = transformers.BertModel.from_pretrained(self.config.bert_model) - self.bert.resize_token_embeddings(len(tokenizer)) - - if self.config.post_transformer_layer == "linear": - self.post_transformer = nn.Linear( - in_features=self.output_size, - out_features=self.output_size) - elif self.config.post_transformer_layer == "layer_norm": - self.post_transformer = torch.nn.LayerNorm(self.output_size) - else: - assert(False) - - @property - def output_size(self): - return self.bert.config.hidden_size * 2 - - def forward(self, inputs): - bert_out = self.bert(inputs["text"], attention_mask=inputs["mask"])[0] - batch_ids = torch.arange(bert_out.shape[0], device=bert_out.device, dtype=torch.int64) - e1_out = bert_out[batch_ids, inputs["e1_pos"]] - e2_out = bert_out[batch_ids, inputs["e2_pos"]] - sentence = torch.cat((e1_out, e2_out), dim=1) - return self.post_transformer(sentence) diff --git a/fsre/model/mtb_supervised.py b/fsre/model/mtb_supervised.py @@ -1,39 +0,0 @@ -import torch -import torch.nn as nn -import transformers - -from fsre.model.mtb_classifier import Classifier - - -class Model(nn.Module): - """ - Supervised model from Soares et al. - - Correspond to the left subfigure of Figure 2. - """ - - def __init__(self, config, tokenizer, relation_dictionary): - """ - Instantiate a Soares et al. supervised model. - - Args: - config: global config object - tokenizer: tokenizer used to create the vocabulary - relation_dictionary: dictionary of all relations - """ - - super().__init__() - - self.config = config - self.tokenizer = tokenizer - self.relation_dictionary = relation_dictionary - - self.classifier = Classifier(config, tokenizer) - self.relation_classifier = nn.Linear( - in_features=self.classifier.output_size, - out_features=len(relation_dictionary), - bias=False) - - def forward(self, inputs): - latent = self.classifier(inputs) - return self.relation_classifier(latent) diff --git a/fsre/train.py b/fsre/train.py @@ -1,353 +0,0 @@ -import sys -import os -import math -import time -import contextlib -import multiprocessing -import signal -import logging - -import tqdm -import torch -import transformers - -import fsre - -logger = logging.getLogger(__name__) - - -class Trainer: - """ - Train a model. - - Config: - Model: the model class to use for training - batch_size: the number of samples in the batch of data loaded - bert_model: the model of transformer to use - dataset_name: name of the dataset to load - deterministic: run in deterministic mode - early_stopping_patience: how many epoch to train after best validation score has been reached - learning_rate: learning rate - max_epoch: maximum number of epoch - no_initial_validation: do not run evaluation on the valid dataset before first epoch - seed: the seed for the random number generator - sort_per_shuffle_bucket: the number of sort buckets in a shuffle bucket - test_output: path to a file where the test predictions will be written - true_batch_size: the actual number of sample in a batch, the number of sample seen before a backward (must be a multiple of batch_size) - validation_metric: metric used for early stopping - """ - - def __init__(self, config, logdir, state_dicts=None): - self.config = config - self.logdir = logdir - self.state_dicts = state_dicts - - def run(self): - self.info() - self.log_patch() - self.initialize_rng() - self.prepare_dataset() - self.build_model() - self.count_parameters() - self.setup_optimizer() - self.hook_signals() - self.train() - - def environment_check(self): - python_version = '.'.join(map(str, sys.version_info[:3])) - torch_version = torch.__version__ - cuda_available = torch.cuda.is_available() - - logger.info(f"python version {python_version}") - logger.info(f"torch version {torch_version}") - logger.info(f"cuda available {cuda_available}") - - def problem(str): - return f"\033[1m\033[31m{str}\033[0m" - - if sys.version_info < (3, 7): - python_version = problem(python_version) - if list(map(int, torch_version.split('.'))) < [1, 3]: - torch_version = problem(torch_version) - if not cuda_available: - cuda_available = problem(cuda_available) - - print(f"python version: {python_version}, torch version: {torch_version}, cuda available: {cuda_available}") - - def detect_gpus(self): - count = torch.cuda.device_count() - - if count == 0: - print(f"\033[1m\033[31mNo GPU available\033[0m") - logger.warning("no GPU available") - self.device = torch.device("cpu") - else: - self.device = torch.device("cuda:0") - - for i in range(count): - gp = torch.cuda.get_device_properties(i) - print(f"GPU{i}: \033[33m{gp.name}\033[0m (Mem: {gp.total_memory/2**30:.2f}GiB CC: {gp.major}.{gp.minor})") - logger.info(f"GPU{i} {gp.name} {gp.total_memory} {gp.major}.{gp.minor}") - - def info(self): - if self.logdir is None: - print(f"logdir is \033[1m\033[31mnot set\033[0m, log messages will be discarded") - else: - print(f"logdir is \033[1m\033[33m{self.logdir}\033[0m") - - self.environment_check() - self.detect_gpus() - print("") - - print("\033[1m\033[33mConfiguration\033[0m") - fsre.utils.print_dict(self.config) - fsre.utils.log_dict(logging.getLogger("config"), self.config) - print("") - - def log_patch(self): - version = fsre.utils.get_repo_version() - logger.info(f"repository_version {version}") - if version == "release": - print(f"\033[41mRelease version\033[0m\n") - elif version.endswith('+'): - print(f"\033[31mUncommited changes detected, saving patch to logdir.\033[0m\n") - suffix = "" - if self.state_dicts: - suffix = time.strftime("%FT%H:%M:%S") - fsre.utils.save_patch(self.logdir / f"patch{suffix}") - - def initialize_rng(self): - if self.state_dicts: - torch.random.set_rng_state(self.state_dicts["torch_rng"]) - assert(("cuda_rng" in self.state_dicts) == torch.cuda.is_available()) - if "cuda_rng" in self.state_dicts: - torch.cuda.random.set_rng_state_all(self.state_dicts["cuda_rng"]) - else: - torch.manual_seed(self.config.seed) - - if self.config.get("deterministic"): - torch.backends.cudnn.deterministic = True - torch.backends.cudnn.benchmark = False - - def prepare_dataset(self): - data_dir = fsre.utils.DATA_PATH / self.config.dataset_name / self.config.bert_model - self.relation_dictionary = fsre.data.RelationDictionary(path=data_dir / "relations") - self.tokenizer = transformers.BertTokenizer.from_pretrained(data_dir / "tokenizer") - - self.dataset = {} - self.iterator = {} - - for dataset in ["train", "valid", "test"]: - if dataset == "train": - kwargs = {"rng": self.state_dicts["train_rng"]} if self.state_dicts else {} - - self.dataset[dataset] = fsre.data.RelationExtractionDataset( - self.config, - data_dir / f"{dataset}.npy", - pad=self.tokenizer.pad_token_id, - evaluation=(dataset != "train"), - **kwargs) - - self.iterator[dataset] = lambda dataset=dataset: torch.utils.data.DataLoader( - dataset=self.dataset[dataset], - batch_size=self.config.batch_size, - num_workers=self.config.sort_per_shuffle_bucket, - pin_memory=(self.device.type == "cuda")) - - def build_model(self): - self.model = self.config.Model(self.config, self.tokenizer, self.relation_dictionary) - - if self.state_dicts: - self.model.load_state_dict(self.state_dicts["model"]) - - self.loss = torch.nn.CrossEntropyLoss() - if self.device.type == "cuda": - self.model.to(self.device) - self.loss.to(self.device) - - def count_parameters(self): - total = 0 - for parameter in self.model.parameters(): - total += parameter.shape.numel() - print(f"\033[33mNumber of parameters: {total:,}\033[0m") - - def setup_optimizer(self): - self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.config.learning_rate) - if self.state_dicts: - self.optimizer.load_state_dict(self.state_dicts["optimizer"]) - - def hook_signals(self): - self.interrupted = False - - def handler(sig, frame): - if multiprocessing.current_process().name != "MainProcess": - return - - print("\n\033[31mInterrupted, training will stop at the end of this epoch.\n\033[1mNEXT ^C WILL KILL THE PROCESS!\033[0m\n", file=sys.stderr) - self.interrupted = True - signal.signal(signal.SIGINT, signal.SIG_DFL) - - signal.signal(signal.SIGINT, handler) - - def eval_context(self, dataset): - self.test_output_file = None - if dataset == "test" and self.config.get("test_output"): - self.test_output_file = open(self.config.test_output, 'w') - return self.test_output_file - return contextlib.nullcontext() - - def eval_handle_predictions(self, ids, predictions): - if self.test_output_file is None: - return - ids = ids.tolist() - predictions = predictions.argmax(1).tolist() - for id, prediction in zip(ids, predictions): - prediction = self.relation_dictionary.decode(prediction) - self.test_output_file.write(f"{id}\t{prediction}\n") - - def evaluate(self, dataset): - loop = tqdm.tqdm( - iterable=self.iterator[dataset](), - desc=f"Epoch {self.epoch:2} {dataset:5}", - unit="samples", - unit_scale=self.config.batch_size, - total=math.ceil(len(self.dataset[dataset]) / self.config.batch_size), - leave=False) - - self.model.eval() - output_prediction_to_file = True - with torch.no_grad(), self.eval_context(dataset): - has_target = False - scorer = fsre.metrics.Metrics(self.relation_dictionary) - correct_prediction = 0 - for batch in loop: - batch = {key: value.to(self.device) for key, value in batch.items()} - - # Pop the target to ensure it's not used by the model - if "relation" in batch: - target = batch.pop("relation") - has_target = True - predictions = self.model(batch) - self.eval_handle_predictions(batch["id"], predictions) - if has_target: - scorer.update(predictions, target) - loop.set_postfix(**scorer.summary, refresh=False) - - if has_target: - print(f"Epoch {self.epoch} {dataset:5} accuracy: {scorer.accuracy*100:8.4f}% Half-directed Macro F1: {scorer.half_directed_macro_f1*100:8.4f}% (P: {scorer.half_directed_macro_precision*100:8.4f}% R: {scorer.half_directed_macro_recall*100:8.4f}%) NLL: {scorer.nll:8.4f}") - for metric, value in scorer.all.items(): - logger.info(f"epoch {self.epoch} {dataset} {metric} {value}") - return getattr(scorer, self.config.validation_metric) - return None - - def save(self, path): - state_dicts = { - "logdir": self.logdir, - "config": self.config, - "model": self.model.state_dict(), - "optimizer": self.optimizer.state_dict(), - "train_rng": self.dataset["train"].state_dict(), - "torch_rng": torch.random.get_rng_state(), - "epoch": self.epoch, - "best_epoch": self.best_epoch, - "best_eval": self.best_eval, - } - if torch.cuda.is_available(): - state_dicts["cuda_rng"] = torch.cuda.random.get_rng_state_all() - torch.save(state_dicts, path) - - def train(self): - self.epoch = 0 - self.best_epoch = 0 - self.best_eval = -float("inf") - - if self.state_dicts: - self.epoch = self.state_dicts["epoch"] - self.best_epoch = self.state_dicts["best_epoch"] - self.best_eval = self.state_dicts["best_eval"] - - best_path = self.logdir / "best" - batch_per_epoch = int(math.ceil(len(self.dataset["train"]) / self.config.batch_size)) - assert(self.config.true_batch_size % self.config.batch_size == 0) - data_batch_per_true_batch = self.config.true_batch_size // self.config.batch_size - - if not self.config.get("no_initial_validation"): - self.best_eval = self.evaluate("valid") - - for self.epoch in range(self.epoch+1, self.config.max_epoch+1): - if self.interrupted: - break - - loop = tqdm.tqdm( - iterable=self.iterator["train"](), - desc=f"Epoch {self.epoch} train", - unit="samples", - unit_scale=self.config.batch_size, - total=math.ceil(len(self.dataset["train"]) / self.config.batch_size), - leave=False) - - self.model.train() - self.optimizer.zero_grad() - total_loss = 0 - total_sample = 0 - for batch_id, batch in enumerate(loop): - batch = {key: value.to(self.device) for key, value in batch.items()} - - # Pop the target to ensure it's not used by the model - target = batch.pop("relation") - - output = self.model(batch) - loss = self.loss(output, target) - loss.backward() - - total_loss += loss.item() - total_sample += target.shape[0] - loop.set_postfix(loss=f"{total_loss / total_sample:.2f}", refresh=False) - - if batch_id % data_batch_per_true_batch == data_batch_per_true_batch - 1: - self.optimizer.step() - self.optimizer.zero_grad() - - print(f"Epoch {self.epoch} train mean loss: {total_loss / total_sample:8.4f}") - logger.info(f"epoch {self.epoch} train mean_loss {total_loss / total_sample}") - - self.save(self.logdir / "checkpoint.new") - os.rename(self.logdir / "checkpoint.new", self.logdir / "checkpoint") - - candidate = self.evaluate("valid") - if candidate > self.best_eval: - self.best_epoch = self.epoch - self.best_eval = candidate - self.save(best_path) - logger.info(f"Model saved to {best_path}") - elif self.epoch - self.best_epoch > self.config.get("early_stopping_patience", self.config.max_epoch): - break - - if self.best_eval != -float("inf") and best_path.exists(): - print(f"Loading best model from {best_path}…", end="", flush=True) - self.model.load_state_dict(torch.load(best_path)["model"]) - print(" done") - logger.info(f"{best_path} loaded for evaluation on test set") - - self.evaluate("test") - - -if __name__ == "__main__": - fsre.utils.fix_transformers_logging_handler() - config = fsre.utils.parse_args() - - state_dicts = None - if config.get("load"): - state_dicts = torch.load(config.load) - logdir = state_dicts["logdir"] - if config.get("reuse_config"): - config = state_dicts["config"] - else: - logdir = fsre.utils.logdir_name("FSRE") - assert(not logdir.exists()) - logdir.mkdir() - - logfile = logdir / "log" - logging.basicConfig(format="%(asctime)s\t%(levelname)s:%(name)s:%(message)s", filename=logfile, filemode='a', level=logging.INFO) - - Trainer(config, logdir, state_dicts).run() diff --git a/fsre/utils.py b/fsre/utils.py @@ -1,143 +0,0 @@ -import os -import sys -import types -import importlib -import subprocess -import logging -import time -import hashlib -import pathlib - - -def import_environment(name, cast=str): - try: - globals()[name] = cast(os.environ[name]) - except KeyError: - print(f"ERROR: {name} environment variable is not set.", - file=sys.stderr) - sys.exit(1) - - -import_environment("DATA_PATH", pathlib.Path) -import_environment("LOG_PATH", pathlib.Path) - - -class dotdict(dict): - def __getattr__(self, name): - if name not in self: - raise AttributeError(f"Config key {name} not found") - return dotdict(self[name]) if type(self[name]) is dict else self[name] - __setattr__ = dict.__setitem__ - __delattr__ = dict.__delitem__ - - -def eval_arg(config, arg): - if '=' in arg: - key, value = arg.split('=', maxsplit=1) - value = eval(value, config) - else: - key, value = arg, True - path = key.split('.') - for d in path[:-1]: - config = config[d] - config[path[-1]] = value - config.pop("__builtins__", None) - - -def import_arg(config, arg): - if arg.endswith(".py"): - arg = arg[:-3].replace('/', '.') - module = importlib.import_module(arg) - for key, value in vars(module).items(): - if key not in module.__builtins__ \ - and not key.startswith("__") \ - and not isinstance(value, types.ModuleType): - config[key] = value - - -def parse_args(): - config = {} - for arg in sys.argv[1:]: - if arg.startswith("--"): - eval_arg(config, arg[2:]) - else: - import_arg(config, arg) - return dotdict(config) - - -def map_dict(output, input, depth=0): - for key, value in input.items(): - indent = '\t'*depth - output(f"{indent}{key}:") - if isinstance(value, dict): - output('\n') - map_dict(output, value, depth+1) - else: - output(f" {value}\n") - - -def print_dict(input): - map_dict(lambda x: print(x, end=""), input) - - -def log_dict(logger, input): - class log: - buf = "" - - def __call__(self, x): - self.buf += x - if self.buf.endswith('\n'): - logger.info(self.buf[:-1]) - self.buf = "" - map_dict(log(), input) - - -def get_repo_version(): - repo_dir = pathlib.Path(__file__).parents[0] - result = subprocess.run(["hg", "id", "-i"], - stdout=subprocess.PIPE, - stderr=subprocess.DEVNULL, - encoding="utf-8", - cwd=repo_dir) - - if result.returncode != 0: - return "release" - return result.stdout.rstrip() - - -def experiment_name(name): - args = ' '.join(sys.argv[1:]) - version = get_repo_version() - stime = time.strftime("%FT%H:%M:%S") - return f"{name} {version} {args} {stime}" - - -def logdir_name(name): - subdir = experiment_name(name).replace('/', '_') - if len(subdir) > 255: - sha1 = hashlib.sha1(subdir.encode("utf-8")).digest().hex()[:16] - subdir = subdir[:255-17] + ' ' + sha1 - return LOG_PATH / subdir - - -def fix_transformers_logging_handler(): - """ - The transformers package from huggingface install its own logger on import, - I don't want it. - """ - logger = logging.getLogger() - for handler in logger.handlers: - logger.removeHandler(handler) - - -def save_patch(outpath): - repo_dir = pathlib.Path(__file__).parents[0] - - with open(outpath, "w") as outfile: - result = subprocess.run(["hg", "diff"], - stdout=outfile, - stderr=subprocess.DEVNULL, - encoding="utf-8", - cwd=repo_dir) - - assert(result.returncode == 0) diff --git a/gbure/__init__.py b/gbure/__init__.py diff --git a/gbure/config/contrastive_alignment.py b/gbure/config/contrastive_alignment.py @@ -0,0 +1,53 @@ +from gbure.model.contrastive_alignment import Model +from gbure.model.fewshot import Model as EvalModel +from torch.optim import Adam as Optimizer +from torch.optim.lr_scheduler import LinearLR as Scheduler + + +dataset_name = "T-REx" +graph_name = "T-REx" +unsupervised = "triplet" + +eval_dataset_name = "FewRel" +valid_name = "7def1330ba9527d6" +shot = 1 +way = 5 + +margin = 1 +neighborhood_size = 3 +filter_empty_neighborhood = True +sinkhorn_blur = 0.05 + +# Necessary to make a distance +linguistic_similarity = "euclidean" +undefined_poison_whole_meta = True + +# From section 4.3 +blank_probability = 0.7 + +# From section 5 +transformer_model = "bert-base-cased" +max_epoch = 10 +sample_per_epoch = 100000 +learning_rate = 3e-5 +accumulated_batch_size = 256 +clip_gradient = 1 + +# Guessed +post_transformer_layer = "linear" # Maybe we should change this depending on the subsequent task? +max_sentence_length = 100 # Maybe should be 40 (from footnote 2, guessed from ACL slides) +language_model_weight = 0 +edge_sampling = "uniform-inverse degree" + +# From BERT +mlm_probability = 0.15 +mlm_masked_probability = 0.8 +mlm_random_probability = 0.1 + +# Implementation details +seed = 0 +amp = True +initial_grad_scale = 1 +batch_size = 2 +eval_batch_size = 1 +workers = 2 diff --git a/gbure/config/gcn_mtb.py b/gbure/config/gcn_mtb.py @@ -0,0 +1,63 @@ +from gbure.model.matching_the_blanks import Model +from gbure.model.fewshot import Model as EvalModel +from torch.optim import Adam as Optimizer +from torch.optim.lr_scheduler import LinearLR as Scheduler + + +dataset_name = "T-REx" +graph_name = "T-REx" +unsupervised = "mtb" + +eval_dataset_name = "FewRel" +valid_name = "7def1330ba9527d6" +shot = 1 +way = 5 + +# From Section 4.1 +linguistic_similarity = "dot" +undefined_poison_whole_meta = True + +# Observed to be better +latent_metric_scale = "standard" +latent_dot_mean = 1067.65 +latent_dot_std = 111.17 + +# GCN +neighborhood_size = 3 +gcn_aggregator = "mean" + +# From Section 4.3 +blank_probability = 0.7 + +# From Section 5 +transformer_model = "bert-base-cased" +sample_per_epoch = 100000 +learning_rate = 3e-5 +accumulated_batch_size = 2048 + +# Stated to be 10 in Section 5, but found 5 was better on T-REx dataset. +max_epoch = 5 + +# From BERT +mlm_probability = 0.15 +mlm_masked_probability = 0.8 +mlm_random_probability = 0.1 + +# Guessed +# post_transformer_layer might need to be changed depending on the subsequent task +# the "layer_norm" gives results within expectations for non-finetuned few-shot. +max_sentence_length = 100 # Maybe should be 40 (from footnote 2, guessed from ACL slides) +language_model_weight = 1 +edge_sampling = "uniform-inverse degree" +clip_gradient = 1 + +strong_negative_probability = 0.5 +weak_negative_probability = 0.0 + +# Implementation details +seed = 0 +amp = True +initial_grad_scale = 1 +batch_size = 2 +eval_batch_size = 1 +workers = 2 diff --git a/gbure/config/mtb.py b/gbure/config/mtb.py @@ -0,0 +1,57 @@ +from gbure.model.matching_the_blanks import Model +from gbure.model.fewshot import Model as EvalModel +from torch.optim import Adam as Optimizer +from torch.optim.lr_scheduler import LinearLR as Scheduler + + +dataset_name = "T-REx" +unsupervised = "mtb" + +eval_dataset_name = "FewRel" +valid_name = "7def1330ba9527d6" +shot = 1 +way = 5 + +# From Section 4.1 +linguistic_similarity = "dot" + +# Observed to be better +latent_metric_scale = "standard" +latent_dot_mean = 1067.65 +latent_dot_std = 111.17 + +# From Section 4.3 +blank_probability = 0.7 + +# From Section 5 +transformer_model = "bert-base-cased" +sample_per_epoch = 100000 +learning_rate = 3e-5 +accumulated_batch_size = 2048 + +# Stated to be 10 in Section 5, but found 5 was better on T-REx dataset. +max_epoch = 5 + +# From BERT +mlm_probability = 0.15 +mlm_masked_probability = 0.8 +mlm_random_probability = 0.1 + +# Guessed +# post_transformer_layer might need to be changed depending on the subsequent task +# the "layer_norm" gives results within expectations for non-finetuned few-shot. +max_sentence_length = 100 # Maybe should be 40 (from footnote 2, guessed from ACL slides) +language_model_weight = 1 +edge_sampling = "uniform-inverse degree" +clip_gradient = 1 + +strong_negative_probability = 0.5 +weak_negative_probability = 0.0 + +# Implementation details +seed = 0 +amp = True +initial_grad_scale = 1 +batch_size = 8 +eval_batch_size = 2 +workers = 8 diff --git a/gbure/config/nonparametric.py b/gbure/config/nonparametric.py @@ -0,0 +1,32 @@ +from gbure.model.fewshot import Model +from torch.optim import SGD as Optimizer + + +max_epoch = 0 +dataset_name = "FewRel" +graph_name = "T-REx" +valid_name = "7def1330ba9527d6" +shot = 1 +way = 5 + +transformer_model = "bert-base-cased" +post_transformer_layer = "none" +learning_rate = 0 +accumulated_batch_size = 256 + +neighborhood_size = 3 +linguistic_similarity = "dot" +topological_weight = 0.2 +linguistic_weight = 1 +undefined_poison_whole_meta = True + +validation_metric = "accuracy" +latent_metric_scale = "standard" +latent_dot_mean = 1067.65 +latent_dot_std = 111.17 + +# Implementation details +seed = 0 +amp = True +batch_size = 2 +workers = 2 diff --git a/gbure/config/soares_fewrel.py b/gbure/config/soares_fewrel.py @@ -0,0 +1,30 @@ +from gbure.model.fewshot import Model +from torch.optim import SGD as Optimizer + + +dataset_name = "FewRel" +shot = 1 +way = 5 +linguistic_similarity = "dot" + +# From Table 1 +transformer_model = "bert-base-cased" +post_transformer_layer = "layer_norm" +max_epoch = 10 +learning_rate = 1e-4 +accumulated_batch_size = 256 + +# Guessed +validation_metric = "accuracy" +early_stopping_patience = 2 +latent_metric_scale = "standard" +latent_dot_mean = 1067.65 +latent_dot_std = 111.17 +clip_gradient = 1 + +# Implementation details +seed = 0 +amp = True +initial_grad_scale = 1 +batch_size = 2 +workers = 8 diff --git a/gbure/config/soares_kbp37.py b/gbure/config/soares_kbp37.py @@ -0,0 +1,23 @@ +from gbure.model.supervised import Model +from torch.optim import Adam as Optimizer + + +dataset_name = "KBP37" +linguistic_similarity = "dot" + +# From Table 1 +transformer_model = "bert-base-cased" +post_transformer_layer = "linear" +max_epoch = 10 +learning_rate = 3e-5 +accumulated_batch_size = 64 + +# Guessed +validation_metric = "half_directed_macro_f1" +early_stopping_patience = 2 +clip_gradient = 1 + +# Implementation details +seed = 0 +batch_size = 2 +workers = 8 diff --git a/gbure/config/soares_semeval.py b/gbure/config/soares_semeval.py @@ -0,0 +1,23 @@ +from gbure.model.supervised import Model +from torch.optim import Adam as Optimizer + + +dataset_name = "SemEval 2010 Task 8" +linguistic_similarity = "dot" + +# From Table 1 +transformer_model = "bert-base-cased" +post_transformer_layer = "layer_norm" +max_epoch = 10 +learning_rate = 3e-5 +accumulated_batch_size = 64 + +# Guessed +validation_metric = "half_directed_macro_f1" +early_stopping_patience = 2 +clip_gradient = 1 + +# Implementation details +seed = 0 +batch_size = 8 +workers = 8 diff --git a/gbure/data/__init__.py b/gbure/data/__init__.py diff --git a/gbure/data/batcher.py b/gbure/data/batcher.py @@ -0,0 +1,166 @@ +from typing import Any, Dict, List, Tuple +import collections + +import torch + + +# Must be kept prefix-sorted! +# (prefix, list depth) +FEATURE_PREFIXES: List[Tuple[str, int]] = [ + ("query_e1_neighborhood_", 1), + ("query_e2_neighborhood_", 1), + ("candidates_e1_neighborhood_", 3), + ("candidates_e2_neighborhood_", 3), + ("first_e1_neighborhood_", 1), + ("first_e2_neighborhood_", 1), + ("second_e1_neighborhood_", 1), + ("second_e2_neighborhood_", 1), + ("third_e1_neighborhood_", 1), + ("third_e2_neighborhood_", 1), + ("query_", 0), + ("candidates_", 2), + ("first_", 0), + ("second_", 0), + ("third_", 0), + ("", 0)] + + +class Batcher: + """ + Batch a group of sample together. + + Two new features are derived from the "text": its length and a mask. + """ + def __init__(self, pad_value: int) -> None: + """ Initialize a Batcher, using the provided value to pad text. """ + self.pad_value: int = pad_value + + def add_length_field(self, batch: Dict[str, Any], prefix: str, depth: int) -> None: + """ Add the length field for the given prefix. """ + text: List[Any] = batch[f"{prefix}text"] + batch_size: int = len(text) + + if depth == 0: + # text is a list of sentences + lengths: torch.Tensor = torch.empty((batch_size,), dtype=torch.int64) + for b, sentence in enumerate(text): + lengths[b] = sentence.shape[0] + elif depth == 1: + # text is a list of list of sentences (each sample contains several candidates) + size: int = len(text[0]) + lengths: torch.Tensor = torch.empty((batch_size, size), dtype=torch.int64) + for b, sample in enumerate(text): + for i, sentence in enumerate(sample): + lengths[b, i] = sentence.shape[0] + elif depth == 2: + # text is a list of list of list of sentences (each sample contains several candidates) + way: int = len(text[0]) + shot: int = len(text[0][0]) + lengths: torch.Tensor = torch.empty((batch_size, way, shot), dtype=torch.int64) + for b, sample in enumerate(text): + for w, candidates in enumerate(sample): + for s, candidate in enumerate(candidates): + lengths[b, w, s] = candidate.shape[0] + elif depth == 3: + # text is a list of list of list of list of sentences (each sample contains several candidates' neighborhoods) + way: int = len(text[0]) + shot: int = len(text[0][0]) + size: int = len(text[0][0][0]) + lengths: torch.Tensor = torch.empty((batch_size, way, shot, size), dtype=torch.int64) + for b, sample in enumerate(text): + for w, candidates in enumerate(sample): + for s, candidate in enumerate(candidates): + for n, neighbor in enumerate(candidate): + lengths[b, w, s, n] = neighbor.shape[0] + + batch[f"{prefix}length"] = lengths + + def process_text(self, batch: Dict[str, Any], prefix: str, depth: int, key: str) -> None: + """ Build mask and text batch by padding sentences. """ + in_text: List[Any] = batch[f"{prefix}{key}"] + if isinstance(batch[f"{prefix}length"], list): + self.add_length_field(batch, prefix, depth) + max_seq_len: int = max(batch[f"{prefix}length"].max(), 1) + batch_size: int = len(in_text) + + if depth == 0: + # text is a list of sentences + text: torch.Tensor = torch.empty((batch_size, max_seq_len), dtype=torch.int64) + mask: torch.Tensor = torch.empty((batch_size, max_seq_len), dtype=torch.bool) + for b, sentence in enumerate(in_text): + text[b, :sentence.shape[0]] = sentence + text[b, sentence.shape[0]:] = self.pad_value + mask[b, :sentence.shape[0]] = 1 + mask[b, sentence.shape[0]:] = 0 + elif depth == 1: + # text is a list of list of sentences (each sample contains several candidates) + # In this case, we are not sure the tensor is full (some neighborhoods might be of different sizes or even empty) + size: int = len(in_text[0]) + text: torch.Tensor = torch.empty((batch_size, size, max_seq_len), dtype=torch.int64) + mask: torch.Tensor = torch.zeros((batch_size, size, max_seq_len), dtype=torch.bool) + for b, samples in enumerate(in_text): + for i, sentence in enumerate(samples): + text[b, i, :sentence.shape[0]] = sentence + text[b, i, sentence.shape[0]:] = self.pad_value + mask[b, i, :sentence.shape[0]] = 1 + elif depth == 2: + # text is a list of list of list of sentences (each sample contains several candidates) + # In this case, we are sure the tensor is full (all n way have the save k shots) + way: int = len(in_text[0]) + shot: int = len(in_text[0][0]) + text: torch.Tensor = torch.empty((batch_size, way, shot, max_seq_len), dtype=torch.int64) + mask: torch.Tensor = torch.empty((batch_size, way, shot, max_seq_len), dtype=torch.bool) + for b, samples in enumerate(in_text): + for w, candidates in enumerate(samples): + for s, candidate in enumerate(candidates): + text[b, w, s, :candidate.shape[0]] = candidate + text[b, w, s, candidate.shape[0]:] = self.pad_value + mask[b, w, s, :candidate.shape[0]] = 1 + mask[b, w, s, candidate.shape[0]:] = 0 + elif depth == 3: + # text is a list of list of list of list of sentences (each sample contains several candidates' neighborhoods) + # In this case, we are not sure the tensor is full (some neighborhoods might be of different sizes or even empty) + way: int = len(in_text[0]) + shot: int = len(in_text[0][0]) + size: int = len(in_text[0][0][0]) + text: torch.Tensor = torch.empty((batch_size, way, shot, size, max_seq_len), dtype=torch.int64) + mask: torch.Tensor = torch.empty((batch_size, way, shot, size, max_seq_len), dtype=torch.bool) + for b, samples in enumerate(in_text): + for w, candidates in enumerate(samples): + for s, candidate in enumerate(candidates): + for n, neighbor in enumerate(candidate): + text[b, w, s, n, :neighbor.shape[0]] = neighbor + text[b, w, s, n, neighbor.shape[0]:] = self.pad_value + mask[b, w, s, n, :neighbor.shape[0]] = 1 + mask[b, w, s, n, neighbor.shape[0]:] = 0 + + batch[f"{prefix}{key}"] = text + if f"{prefix}mask" not in batch: + batch[f"{prefix}mask"] = mask + + def process_int_feature(self, batch: Dict[str, Any], prefix: str, feature: str) -> None: + """ Transform a list of integer into a torch LongTensor. """ + # TODO handle neighborhoods of different sizes + if isinstance(batch[f"{prefix}{feature}"][0], torch.Tensor): + batch[f"{prefix}{feature}"] = torch.stack(batch[f"{prefix}{feature}"]) + else: + batch[f"{prefix}{feature}"] = torch.tensor(batch[f"{prefix}{feature}"], dtype=torch.int64) + + def __call__(self, samples: List[Dict[str, Any]]) -> Dict[str, Any]: + """ Batch the provided samples """ + batch = collections.defaultdict(list) + for sample in samples: + for key, value in sample.items(): + batch[key].append(value) + + for key in list(batch.keys()): + for prefix, depth in FEATURE_PREFIXES: + if key.startswith(prefix): + break + feature: str = key[len(prefix):] + if feature in ["text", "mlm_input", "mlm_target"]: + self.process_text(batch, prefix, depth, feature) + if feature in ["relation", "entity_positions", "entity_identifiers", "entity_degrees", "edge_identifier", "polarity", "answer", "eid"]: + self.process_int_feature(batch, prefix, feature) + + return batch diff --git a/gbure/data/dataset.py b/gbure/data/dataset.py @@ -0,0 +1,668 @@ +from typing import Any, Dict, Iterator, List, Optional, Set, Tuple, Union +import collections +import pathlib +import random + +import torch +import transformers + +import gbure.data.dictionary +from gbure.data.graph import Graph +import gbure.utils + + +class SupervisedDataset(torch.utils.data.Dataset): + """ + Read a preprocessed supervised relation extraction dataset. + + A preprocessed dataset can be created from the gbure.data.prepare_* scripts. + """ + shuffleable: bool = True + + def __init__(self, config: gbure.utils.dotdict, path: pathlib.Path, tokenizer: transformers.PreTrainedTokenizer, evaluation: bool, rng: Optional[random.Random] = None, data: Optional[List[Tuple[torch.Tensor, int, int, int]]] = None) -> None: + """ Initialize a supervised dataset and load the data in RAM. """ + super().__init__() + + self.config: gbure.utils.dotdict = config + self.path: pathlib.Path = path + self.tokenizer: transformers.PreTrainedTokenizer = tokenizer + self.evaluation: bool = evaluation + if data is None: + self.load() + else: + self.data = data + + def load(self) -> None: + """ Load the dataset into RAM. """ + dstype: str + self.data: List[Tuple[torch.Tensor, int, int, int]] + dstype, self.data = torch.load(self.path) + assert(dstype == "supervised") + + def __len__(self) -> int: + """ Get the number of samples in the dataset. """ + return len(self.data) + + def __getitem__(self, index: int) -> Dict[str, Any]: + """ Get the sample at the given index. """ + sample: Dict[str, Any] = {} + sample["text"] = self.data[index][0] + sample["entity_positions"] = torch.tensor(self.data[index][1:3], dtype=torch.int64) + sample["relation"] = self.data[index][3] + return sample + + +class SampledFewShotDataset(torch.utils.data.Dataset): + """ + Read a preprocessed few shot relation extraction dataset.npy file containing samples. + + A preprocessed dataset can be created from the gbure.data.prepare_* scripts. + """ + shuffleable: bool = True + + def __init__(self, config: gbure.utils.dotdict, path: pathlib.Path, tokenizer: transformers.PreTrainedTokenizer, evaluation: bool, rng: Optional[random.Random] = None, data: Optional[List[Tuple[torch.Tensor, int, int, int, int, List[List[torch.Tensor]], torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, int]]] = None) -> None: + """ Initialize a few shot dataset and load the samples in RAM. """ + super().__init__() + + self.config: gbure.utils.dotdict = config + self.path: pathlib.Path = path + self.tokenizer: transformers.PreTrainedTokenizer = tokenizer + self.evaluation: bool = evaluation + if data is None: + self.load() + else: + self.data = data + + def load(self) -> None: + """ Load the dataset into RAM. """ + dstype: str + self.data: List[Tuple[torch.Tensor, int, int, int, int, List[List[torch.Tensor]], torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, int]] + dstype, self.data = torch.load(self.path) + assert(dstype == "sampled fewshot") + + def __len__(self) -> int: + """ Get the number of samples in the dataset. """ + return len(self.data) + + def __getitem__(self, index: int) -> Dict[str, Any]: + """ Get the sample at the given index. """ + sample: Dict[str, Any] = {} + sample["query_text"] = self.data[index][0] + sample["query_entity_positions"] = torch.tensor(self.data[index][1:3], dtype=torch.int64) + sample["query_entity_identifiers"] = torch.tensor(self.data[index][3:5], dtype=torch.int64) + sample["candidates_text"] = self.data[index][5] + sample["candidates_entity_positions"] = torch.stack(self.data[index][6:8], dim=2) + sample["candidates_entity_identifiers"] = torch.stack(self.data[index][8:10], dim=2) + sample["answer"] = self.data[index][10] + return sample + + +class FewShotDataset(torch.utils.data.IterableDataset): + """ + Read a preprocessed few shot relation extraction dataset.npy file. + + A preprocessed dataset can be created from the gbure.data.prepare_* + modules. + + Config: + seed: the seed for the random number generator + shot: the number of candidates per relation + way: the number of relation classes used for candidates + """ + shuffleable: bool = False # FIXME ? + + def __init__(self, config: gbure.utils.dotdict, path: pathlib.Path, tokenizer: transformers.PreTrainedTokenizer, evaluation: bool, rng: Optional[random.Random] = None, data: Optional[List[List[Tuple[torch.Tensor, int, int, int, int, int]]]] = None) -> None: + """ Initialize a few shot dataset and load the data in RAM. """ + super().__init__() + + self.config: gbure.utils.dotdict = config + self.path: pathlib.Path = path + self.tokenizer: transformers.PreTrainedTokenizer = tokenizer + self.evaluation: bool = evaluation + + if data is None: + self.load() + else: + self.data = data + self.num_relations: int = len(self.data) + self.num_samples_per_relation: int = len(self.data[0]) + + def init_seed(self, worker_id: Optional[int] = None) -> None: + """ Initialize the RNG. """ + if not self.evaluation: + seed: int = self.config.seed + worker_info = torch.utils.data.get_worker_info() + seed += worker_id if worker_id is not None else (worker_info.id if worker_info is not None else 0) + rng = random.Random(seed) + self.rng = rng + + def load(self) -> None: + """ Load the dataset into RAM. """ + dstype: str + self.data: List[List[Tuple[torch.Tensor, int, int, int, int, int]]] + dstype, self.data = torch.load(self.path) + assert(dstype == "fewshot") + + def __len__(self): + """ Get the number of samples in the dataset. """ + return self.num_relations * self.num_samples_per_relation * self.config.get("meta_per_sample", 1) + + def get_rng(self, relation: int, sentence: int) -> random.Random: + """ Get the random number generator for the given query. """ + if self.evaluation: + return random.Random(self.config.seed * len(self) + relation * self.num_samples_per_relation + sentence) + else: + return self.rng + + @staticmethod + def sample_exclude(rng: random.Random, population: int, exclude: int, size: int) -> List[int]: + """ Chooses size unique random elements from [0, population)\\{exclude}. """ + samples: List[int] = rng.sample(range(population-1), size) + return [sample + (1 if sample >= exclude else 0) for sample in samples] + + def sample_meta(self, query_relation: int, query_sid: int) -> Dict[str, Any]: + """ Build a fewshot sample from the given query. """ + rng: random.Random = self.get_rng(query_relation, query_sid) + + # positives + candidates: List[List[Tuple[int, int]]] = [[(query_relation, sid) for sid in self.sample_exclude(rng, self.num_samples_per_relation, query_sid, self.config.shot)]] + # negatives + for negative_relation in self.sample_exclude(rng, self.num_relations, query_relation, self.config.way-1): + candidates.append([(negative_relation, sid) for sid in rng.sample(range(self.num_samples_per_relation), self.config.shot)]) + + order: List[int] = list(range(self.config.way)) + rng.shuffle(order) + candidates = [candidates[i] for i in order] + answer = order.index(0) + + meta: Dict[str, Any] = {} + meta[f"query_text"] = self.data[query_relation][query_sid][0] + meta[f"query_entity_positions"] = torch.tensor(self.data[query_relation][query_sid][1:3], dtype=torch.int64) + meta[f"query_relation"] = self.data[query_relation][query_sid][3] + meta[f"query_entity_identifiers"] = torch.tensor(self.data[query_relation][query_sid][4:6], dtype=torch.int64) + meta[f"candidates_text"] = [[self.data[shot_relation][shot_sid][0] for shot_relation, shot_sid in way] for way in candidates] + meta[f"candidates_entity_positions"] = torch.tensor([[self.data[shot_relation][shot_sid][1:3] for shot_relation, shot_sid in way] for way in candidates], dtype=torch.int64) + meta[f"candidates_relation"] = torch.tensor([[self.data[shot_relation][shot_sid][3] for shot_relation, shot_sid in way] for way in candidates], dtype=torch.int64) + meta[f"candidates_entity_identifiers"] = torch.tensor([[self.data[shot_relation][shot_sid][4:6] for shot_relation, shot_sid in way] for way in candidates], dtype=torch.int64) + meta["answer"] = answer + return meta + + def __iter__(self) -> Iterator[Dict[str, Any]]: + """ Generate samples from the dataset. """ + self.order: List[Tuple[int, int]] = [(relation, sid) for relation in range(self.num_relations) for sid in range(self.num_samples_per_relation)] + if not self.evaluation: + self.rng.shuffle(self.order) + + worker_info = torch.utils.data.get_worker_info() + if worker_info is None: + worker_modulo: int = 1 + worker_residue: int = 0 + else: + worker_modulo: int = worker_info.num_workers + worker_residue: int = worker_info.id + + mps = self.config.get("meta_per_sample", 1) + for index, (relation, sid) in enumerate(self.order): + for j in range(mps): + if (index*mps+j) % worker_modulo == worker_residue: + yield self.sample_meta(relation, sid) + + +class UnsupervisedDataset(torch.utils.data.IterableDataset): + """ + Read a preprocessed unsupervised relation extraction dataset. + + A preprocessed dataset can be created from the gbure.data.prepare_* scripts. + + Config: + blank_probability: the probability to replace an entity with <blank/> + edge_sampling: the sampling strategy to avoid (or not) popular entities + sample_per_epoch: the number of sample in an epoch + seed: the seed for the random number generator + """ + shuffleable: bool = False + + def __init__(self, config: gbure.utils.dotdict, path: Optional[pathlib.Path], tokenizer: transformers.PreTrainedTokenizer, evaluation: bool, rng: Optional[random.Random] = None) -> None: + """ Initialize a supervised dataset and load the data in RAM. """ + super().__init__() + + self.config: gbure.utils.dotdict = config + self.path: Optional[pathlib.Path] = path + self.tokenizer: transformers.PreTrainedTokenizer = tokenizer + self.evaluation: bool = evaluation + self.load() + self.init_seed() + + def init_seed(self, worker_id: Optional[int] = None) -> None: + """ Initialize the RNG. """ + seed: int = self.config.seed + worker_info = torch.utils.data.get_worker_info() + seed += worker_id if worker_id is not None else (worker_info.id if worker_info is not None else 0) + rng = random.Random(seed) + self.rng = rng + + def load(self) -> None: + """ Load the dataset into RAM. """ + if self.path is not None: + self.graph = Graph(path=self.path) + if self.config.get("share_memory"): + self.graph.share_memory() + + def __len__(self) -> int: + """ Get the number of samples in the dataset. """ + return self.config.sample_per_epoch + + def filter_edge(self, eid: int) -> bool: + """ Filter edges according to the length of the corresponding sentence and the size of its neighborhoods. """ + edge: torch.Tensor = self.graph.edges[eid] + if self.graph.sentences[edge[2]].shape[0] > self.config.max_sentence_length: + return False + if self.config.get("filter_empty_neighborhood") and (self.graph.degree(edge[0]) <= 1 or self.graph.degree(edge[1]) <= 1): + return False + return True + + def sample_main(self) -> int: + """ Sample the main edge, from which positive and negative edges can be selected. """ + # From Soares et al. + # "To prevent a large bias towards relation statements that involve popular entities, we limit the number of relation statements that contain the same entity by randomly sampling a constant number of relation statements that contain any given entity." + # It's hard to guess what was exactly done, so we propose several sampling strategies. + while True: + if self.config.edge_sampling == "uniform-uniform": + vid: int = self.rng.randint(0, self.graph.order-1) + reid: int = self.rng.randint(0, self.graph.degree(vid)-1) + eid: int = self.graph.adj[vid][reid, 1] + elif self.config.edge_sampling == "uniform-inverse degree": + vid: int = self.rng.randint(0, self.graph.order-1) + + v2_candidates: torch.Tensor = torch.zeros(self.graph.degree(vid)) + for i, edge in enumerate(self.graph.adj[vid]): + v2_candidates[i] = self.graph.degree(edge[0]) + v2_candidates /= torch.nn.functional.normalize(v2_candidates, p=1, dim=0) + + # FIXME slow, double check worker asynchronicity + reid: int = torch.multinomial(v2_candidates, 1).item() + eid: int = self.graph.adj[vid][reid, 1] + else: + raise RuntimeError("Unsuported config value for edge_sampling") + if self.filter_edge(eid): + return eid + + def eid_to_sample(self, first_eid: int, second_eid: int, polarity: int) -> Dict[str, Any]: + """ Build a pair with the given polarity from two edge ids. """ + first_edge: torch.Tensor = self.graph.edges[first_eid].clone() + second_edge: torch.Tensor = self.graph.edges[second_eid].clone() + + self.shuffle_entities(first_edge) + self.align_entities_as(second_edge, first_edge) + + sample: Dict[str, Any] = {"polarity": polarity} + sample.update(self.edge_to_features(first_eid, first_edge, "first_", mlm=True)) + sample.update(self.edge_to_features(second_eid, second_edge, "second_", mlm=False)) + return sample + + @staticmethod + def invert_entities(edge: torch.Tensor) -> None: + """ Invert the <e1> and <e2> tags of the edge, the text of the entities are not inverted, only the tags. """ + # invert vertex ids + tmp = edge[0].clone() + edge[0] = edge[1] + edge[1] = tmp + + # invert entity positions + tmp = edge[3:5].clone() + edge[3:5] = edge[5:7] + edge[5:7] = tmp + + def shuffle_entities(self, edge: torch.Tensor) -> None: + """ Invert the <e1> and <e2> tags with probability ½. """ + if self.rng.randint(0, 1): + self.invert_entities(edge) + + @staticmethod + def align_entities_as(edge: torch.Tensor, pattern: torch.Tensor) -> None: + """ Invert entities of an edge if neither of them are in the same position as in the provided pattern. """ + if edge[0] != pattern[0] and edge[1] != pattern[1]: + UnsupervisedDataset.invert_entities(edge) + + def mlm_features(self, text: List[int], prefix: str) -> Dict[str, Any]: + """ Extract mlm_input and mlm_target for masked language model loss. """ + # Function inspired by HuggingFace's code + mlm_target = torch.tensor(text, dtype=torch.int64) + mlm_input = torch.tensor(text, dtype=torch.int64) + + # Do not mask special tokens + st_mask = self.tokenizer.get_special_tokens_mask(text, already_has_special_tokens=True) + st_mask = torch.tensor(st_mask, dtype=torch.bool) + + mlm_p = torch.full((len(text),), self.config.mlm_probability) + mlm_mask = torch.bernoulli(mlm_p).bool() & st_mask + mlm_target[~mlm_mask] = -100 + + masked_p = self.config.mlm_masked_probability + masked_mask = torch.bernoulli(torch.full((len(text),), masked_p)).bool() & mlm_mask + mlm_input[masked_mask] = self.tokenizer.convert_tokens_to_ids(self.tokenizer.mask_token) + + random_p = self.config.mlm_random_probability / (1-masked_p) + random_mask = torch.bernoulli(torch.full((len(text),), random_p)).bool() & mlm_mask & ~masked_mask + random_value = torch.randint(len(self.tokenizer), (len(text),), dtype=torch.long) + mlm_input[random_mask] = random_value[random_mask] + + return {f"{prefix}mlm_input": mlm_input, f"{prefix}mlm_target": mlm_target} + + def edge_to_features(self, eid: int, edge: torch.Tensor, prefix: str, mlm: bool) -> Dict[str, Any]: + """ + Convert an edge to the corresponding set of features (token list of the sentence, etc). + + If mlm is True, features for Masked Language Model training are also generated. + """ + sample: Dict[str, Any] = {} + sample[f"{prefix}edge_identifier"] = eid + sample[f"{prefix}entity_identifiers"] = edge[0:2] + sample[f"{prefix}entity_degrees"] = torch.tensor([self.graph.degree(edge[0]), self.graph.degree(edge[1])], dtype=torch.int64) + text: List[int] = self.graph.sentences[edge[2]].tolist() + + # Abuse the fact that "</eX>" < "<eX>" + tags: List[Tuple[int, str]] = [(edge[3], "<e1>"), (edge[4], "</e1>"), (edge[5], "<e2>"), (edge[6], "</e2>")] + tags.sort(reverse=True) + + # When we see a start tag <eX>, we know the last tag was </eX> + last_position: int = -1 + for position, tag in tags: + if tag.startswith("<e"): # begin tag + if self.rng.random() < self.config.get("blank_probability", 0): + del text[position:last_position] + text.insert(position, self.tokenizer.convert_tokens_to_ids("<blank/>")) + text.insert(position, self.tokenizer.convert_tokens_to_ids(tag)) + last_position = position + + if mlm: + sample.update(self.mlm_features(text, prefix)) + + sample[f"{prefix}text"] = torch.tensor(text, dtype=torch.int32) + sample[f"{prefix}entity_positions"] = torch.tensor([ + text.index(self.tokenizer.convert_tokens_to_ids("<e1>")), + text.index(self.tokenizer.convert_tokens_to_ids("<e2>")) + ], dtype=torch.int64) + return sample + + def sample_parallel(self) -> Dict[str, Any]: + """ Sample two parallel edges and create a positive pair from them. """ + while True: + first_eid: int = self.sample_main() + if self.graph.eid_simple_adjacency(first_eid): + # This edge has no parallel edges from which a positive can be selected + continue + + adjacency_range: Tuple[int, int] = self.graph.eid_adjacency_range(first_eid) + if self.graph.edges[adjacency_range[0], 2] == self.graph.edges[adjacency_range[1]-1, 2]: + # All the edges are caused by repetition of an entity in the same sentence + continue + + # Avoid the range of parallel edges sharing the same sentence + sentence_range: Tuple[int, int] = self.graph.eid_adjacency_range(first_eid, prefix=3) + sentence_card: int = sentence_range[1] - sentence_range[0] + + second_eid: int = self.rng.randint(adjacency_range[0], adjacency_range[1]-sentence_card-1) + if second_eid >= sentence_range[0]: + second_eid += sentence_card + + if self.filter_edge(second_eid): + return self.eid_to_sample(first_eid, second_eid, 1) + + def sample_strong_negative(self) -> Dict[str, Any]: + """ Sample a strong negative edge around the two given vertices. """ + # TODO consider biaising the sampling away from popular entities here too. + while True: + first_eid: int = self.sample_main() + adjacency_range: Tuple[int, int] = self.graph.eid_adjacency_range(first_eid) + adjacency_size: int = adjacency_range[1] - adjacency_range[0] + vid1: int = self.graph.edges[first_eid, 0] + vid2: int = self.graph.edges[first_eid, 1] + vertex1_degree: int = self.graph.degree(vid1) + vertex2_degree: int = self.graph.degree(vid2) + if vertex1_degree + vertex2_degree <= 2 * adjacency_size: + # This edge has no other incident edges from which a negative can be selected + continue + + second_reid: int = self.rng.randint(0, vertex1_degree + vertex2_degree - 2 * adjacency_size - 1) + if second_reid < vertex1_degree - adjacency_size: + first_reid_begin: int = self.graph.reid_adjacency_begin(vid1, vid2) + if second_reid >= first_reid_begin: + second_reid += adjacency_size + second_eid: int = self.graph.adj[vid1][second_reid, 1] + else: + second_reid -= vertex1_degree - adjacency_size + first_reid_begin: int = self.graph.reid_adjacency_begin(vid2, vid1) + if second_reid >= first_reid_begin: + second_reid += adjacency_size + second_eid: int = self.graph.adj[vid2][second_reid, 1] + + if self.filter_edge(second_eid): + return self.eid_to_sample(first_eid, second_eid, -1) + + def sample_weak_negative(self) -> Dict[str, Any]: + while True: + first_eid: int = self.sample_main() + second_eid: int = self.sample_main() + entities: Set[int] = set([ + self.graph.edges[first_eid, 0], + self.graph.edges[first_eid, 1], + self.graph.edges[second_eid, 0], + self.graph.edges[second_eid, 1]]) + if len(entities) == 4: + return self.eid_to_sample(first_eid, second_eid, -1) + + def sample_triplet(self) -> Dict[str, Any]: + sample: Dict[str, Any] = {} + for prefix in ["first_", "second_", "third_"]: + eid: int = self.sample_main() + edge: torch.Tensor = self.graph.edges[eid].clone() + self.shuffle_entities(edge) + sample.update(self.edge_to_features(eid, edge, prefix, mlm=(prefix == "first_" and self.config.get("language_model_weight", 0) > 0))) + return sample + + def sample(self) -> Dict[str, Any]: + """ Generate a single sample from the dataset. """ + if self.config.unsupervised == "mtb": + p: float = self.rng.random() + if p < self.config.strong_negative_probability: + return self.sample_strong_negative() + elif p < self.config.strong_negative_probability + self.config.weak_negative_probability: + return self.sample_weak_negative() + else: + return self.sample_parallel() + elif self.config.unsupervised == "triplet": + return self.sample_triplet() + else: + raise RuntimeError(f"Unknown unsupervised mode {self.config.unsupervised}.") + + def __iter__(self) -> Iterator[Dict[str, Any]]: + """ Generate samples from the dataset. """ + worker_info = torch.utils.data.get_worker_info() + if worker_info is None: + sample_count: int = len(self) + else: + sample_count: int = len(self) // worker_info.num_workers + sample_count += (worker_info.id < (len(self) % worker_info.num_workers)) + + for index in range(sample_count): + yield self.sample() + + +TYPE_MAGIC: Dict[str, torch.utils.data.Dataset] = { + "supervised": SupervisedDataset, + "fewshot": FewShotDataset, + "sampled fewshot": SampledFewShotDataset, + "unsupervised": UnsupervisedDataset # not normally used + } + + +class GraphAdapter(UnsupervisedDataset): + """ + Post-process a Dataset to add graph features. + + The new features include neighborhood_text, neighborhood_entity_identifiers, etc and are extracted from the entity_identifiers features present in the original sample. + """ + def __init__(self, dataset: torch.utils.data.Dataset, entity_dictionary: gbure.data.dictionary.Dictionary, path: pathlib.Path, graph: Optional[gbure.data.graph.Graph]) -> None: + if isinstance(dataset, UnsupervisedDataset) or graph is not None: + super().__init__(dataset.config, None, dataset.tokenizer, dataset.evaluation, None) + if graph is not None: + self.graph = graph + else: + self.graph = dataset.graph + else: + super().__init__(dataset.config, path, dataset.tokenizer, dataset.evaluation, None) + self.dataset = dataset + self.entity_dictionary = entity_dictionary + + def empty_neighborhood(self, prefix: str) -> Dict[str, Any]: + neighborhood_size: int = self.config.neighborhood_size + if not self.evaluation and self.config.get("filter_empty_neighborhood"): + return {} + # FIXME We pad to the same number of neighbors for now, since Batcher.process_int_feature does not support neighborhoods of different sizes yet. + # Once it is implemented, we can set neighborhood_size = 0 + return {f"{prefix}edge_identifier": torch.full((neighborhood_size,), -1, dtype=torch.int64), + f"{prefix}entity_identifiers": torch.full((neighborhood_size, 2), -1, dtype=torch.int64), + f"{prefix}entity_degrees": torch.zeros((neighborhood_size, 2), dtype=torch.int64), + f"{prefix}text": [torch.zeros((0,), dtype=torch.int64) for _ in range(neighborhood_size)], + f"{prefix}entity_positions": torch.zeros((neighborhood_size, 2), dtype=torch.int64)} + + def sample_neighborhood(self, vid: int, exclude: Optional[int], incoming: bool, prefix: str) -> Dict[str, Any]: + """ Sample the neighborhood around the given vertex, excluding a given edge. """ + number_reids: int = self.graph.degree(vid) - (0 if exclude is None else 1) + if number_reids <= 0: + reids: List[int] = [] + elif number_reids <= self.config.neighborhood_size: + reids: List[int] = list(range(number_reids)) + self.rng.choices(range(number_reids), k=self.config.neighborhood_size-number_reids) + else: + reids: List[int] = self.rng.sample(range(number_reids), self.config.neighborhood_size) + + neighbors: List[Dict[str, Any]] = [] + for reid in reids: + eid: int = self.graph.adj[vid][reid, 1] + if exclude is not None and eid == exclude: + eid = self.graph.adj[vid][-1, 1] + edge: torch.Tensor = self.graph.edges[eid].clone() + if edge[int(incoming)] != vid: + self.invert_entities(edge) + neighbors.append(self.edge_to_features(eid, edge, "", mlm=False)) + + if not neighbors: + return self.empty_neighborhood(prefix) + + sample: Dict[str, Any] = {} + for feature in neighbors[0].keys(): + if feature == "text": + sample[f"{prefix}text"] = [neighbor["text"] for neighbor in neighbors] + else: + sample[f"{prefix}{feature}"] = torch.stack([neighbor[feature] for neighbor in neighbors]) + return sample + + def sample_neighborhoods(self, head: int, tail: int, eid: Optional[int], prefix: str) -> Dict[str, Any]: + """ + Sample the neighborhood around the given edge. + + edge should be self.graph.edges[eid], optionaly with the entities reversed. + """ + head: Optional[int] = self.graph.entity_dictionary.encoder.get(self.entity_dictionary.decode(head)) + tail: Optional[int] = self.graph.entity_dictionary.encoder.get(self.entity_dictionary.decode(tail)) + + sample: Dict[str, Any] = {} + if head is None: + e1 = self.empty_neighborhood(f"{prefix}e1_neighborhood_") + e1_degree: int = 0 + else: + e1 = self.sample_neighborhood(head, eid, True, f"{prefix}e1_neighborhood_") + e1_degree: int = self.graph.degree(head) + if not e1: + return {} + sample.update(e1) + if tail is None: + e2 = self.empty_neighborhood(f"{prefix}e2_neighborhood_") + e2_degree: int = 0 + else: + e2 = self.sample_neighborhood(tail, eid, True, f"{prefix}e2_neighborhood_") + e2_degree: int = self.graph.degree(tail) + if not e2: + return {} + sample.update(e2) + sample[f"{prefix}entity_degrees"] = torch.tensor([e1_degree, e2_degree], dtype=torch.int64) + return sample + + def adapt(self, sample: Dict[str, Any]) -> bool: + """ + Add neighborhood features to sample. + + Returns whether the sample should be kept. + """ + extras: Dict[str, Any] = {} + for feature in sample: + if feature.endswith("entity_identifiers"): + prefix: str = feature[:-len("entity_identifiers")] + entity_identifiers: torch.Tensor = sample[f"{prefix}entity_identifiers"] + + # If the underlying dataset is built upon a graph, we can exclude the main sample edge, otherwise there is no risk of sampling an edge as being its own neighbor. + eid: Optional[Union[int, torch.Tensor]] = sample.get(f"{prefix}edge_identifier") + + if prefix == "candidates_": + prefix_extras: Dict[str, Any] = collections.defaultdict(list) + for i, way in enumerate(entity_identifiers): + extras_way: Dict[str, List[Any]] = collections.defaultdict(list) + for j, shot in enumerate(way): + eid: Optional[int] = None if eid is None else eid[i, j].item() + extras_shot: Dict[str, Any] = self.sample_neighborhoods(shot[0], shot[1], eid, prefix) + if not extras_shot: + return False + for feature, value in extras_shot.items(): + extras_way[feature].append(value) + for feature, values in extras_way.items(): + prefix_extras[feature].append(values if feature.endswith("text") else torch.stack(values)) + for feature, values in prefix_extras.items(): + prefix_extras[feature] = values if feature.endswith("text") else torch.stack(values) + else: + prefix_extras: Dict[str, Any] = self.sample_neighborhoods(entity_identifiers[0], entity_identifiers[1], eid, prefix) + if prefix_extras: + extras.update(prefix_extras) + else: + return False + if f"{prefix}entity_degrees" not in sample and f"{prefix}entity_degrees" not in extras: + extras[f"{prefix}entity_degrees"] = torch.zeros_like(extras[f"{prefix}entity_identifiers"], dtype=torch.int64) + sample.update(extras) + return True + + def __len__(self) -> int: + return len(self.dataset) + + def process_sample(self, sample: Dict[str, Any]) -> Iterator[Dict[str, Any]]: + # TODO define config value to repeat the sampling of neighbors + if self.config.get("neighborhood_size", 0) > 0: + if self.adapt(sample): + yield sample + else: + yield sample + + def __iter__(self) -> Iterator[Dict[str, Any]]: + if isinstance(self.dataset, torch.utils.data.IterableDataset): + for sample in self.dataset: + yield from self.process_sample(sample) + else: # Map-style dataset + worker_info = torch.utils.data.get_worker_info() + if worker_info is None: + worker_modulo: int = 1 + worker_residue: int = 0 + else: + worker_modulo: int = worker_info.num_workers + worker_residue: int = worker_info.id + + for i in range(worker_residue, len(self.dataset), worker_modulo): + yield from self.process_sample(self.dataset[i]) + + +def load_dataset(config: gbure.utils.dotdict, split: str, path: pathlib.Path, **kwargs) -> torch.utils.data.Dataset: + if split == "train" and config.get("unsupervised"): + return UnsupervisedDataset(config=config, path=path, **kwargs) + + dstype: str + data: Any + dstype, data = torch.load(path) + return TYPE_MAGIC[dstype](config=config, path=path, data=data, **kwargs) diff --git a/gbure/data/dictionary.py b/gbure/data/dictionary.py @@ -0,0 +1,112 @@ +from typing import Any, Dict, List, Optional +import pathlib +import pickle + + +class Dictionary: + keys: List[str] = ["keys", "unknown", "decoder", "encoder"] + + def __init__(self, *, unknown: Optional[str] = None, path: Optional[pathlib.Path] = None) -> None: + self.encoder: Dict[str, int] = {} + self.decoder: List[str] = [] + + self.unknown: Optional[str] = unknown + if unknown is not None: + self.encoder[unknown] = 0 + self.decoder.append(unknown) + + if path is not None: + self.load(path) + + def __len__(self) -> int: + """ Number of tokens in the dictionary. """ + return len(self.decoder) + + def encode(self, token: str) -> int: + """ Returns the id corresponding to a token. """ + id: Optional[int] = self.encoder.get(token) + if id is not None: + return id + + id = len(self.decoder) + self.encoder[token] = id + self.decoder.append(token) + return id + + def decode(self, id: int) -> str: + """ Returns the token corresponding to an id. """ + return self.decoder[id] + + def save(self, path: pathlib.Path) -> None: + with path.open("wb") as file: + pickle.dump({key: getattr(self, key) for key in self.keys}, file) + + def load(self, path: pathlib.Path) -> None: + with path.open("rb") as file: + data: Dict[str, Any] = pickle.load(file) + for key, value in data.items(): + setattr(self, key, value) + + +class RelationDictionary(Dictionary): + """ + A dictionary to be used for relations. + + The tokens held by this class are divided between: + - *relation* such as "Entity-Destination(e1,e2)" + - *base* such as "Entity-Destination" + """ + + keys = ["keys", "unknown", "decoder", "encoder", "base_encoder", "base_decoder", "id_to_bid"] + + def __init__(self, *, unknown: Optional[str] = None, path: Optional[pathlib.Path] = None) -> None: + self.base_encoder: Dict[str, int] = {} + self.base_decoder: List[str] = [] + self.id_to_bid: List[int] = [] + + if unknown is not None: + self.base_encoder[unknown] = 0 + self.base_decoder.append(unknown) + self.id_to_bid.append(0) + + super().__init__(unknown=unknown, path=path) + + def base_size(self) -> int: + """ Number of bases in the dictionary. """ + return len(self.base_decoder) + + def encode(self, relation: str, base: Optional[str] = None) -> int: + """ + Returns the id corresponding to a relation string. + + If base is none, do not attempt to insert a new id and returns the id of relation immediately + + Args: + relation: the string of the relation (e.g. "Entity-Destination(e1,e2)") + base: the string of the base relation (e.g. "Entity-Destination") + """ + if relation is None: + return None + + if base is None: + return self.encoder[relation] + + id: Optional[int] = self.encoder.get(relation) + if id is not None: + return id + + bid: Optional[int] = self.base_encoder.get(base) + if bid is None: + bid = len(self.base_decoder) + self.base_encoder[base] = bid + self.base_decoder.append(base) + + id = len(self.decoder) + self.encoder[relation] = id + self.decoder.append(relation) + self.id_to_bid.append(bid) + return id + + def base_id(self, id: int) -> int: + """ Returns the base id corresponding to a relation id. """ + return self.id_to_bid[id] diff --git a/gbure/data/graph.py b/gbure/data/graph.py @@ -0,0 +1,146 @@ +from typing import List, Optional, Tuple, Union +import pathlib + +import torch +import tqdm +import transformers + +from gbure.data.dictionary import Dictionary +from gbure.utils import SharedLongTensorList + + +class Graph: + """ + Graph represented by an adjacency list. + + The node of the Graph correspond to entities, an edge between node n1 and node n2 indicates that the two corresponding entities appear together in a sentence, this sentence being the label of the edge. + To avoid storing the same sentence several time, a list of sentences is stored and the edges are labeled with the sentence id and entities postions. + + The data is stored in four objects: + sentences: sid -> list of tokens without tags (<e1>, etc) + entity_dictionary: KB entity identifier -> vid (vertex id) + adj: source vid -> [(destination vid, eid)] sorted in lexicographic order + edges: eid (edge id) -> (e1 vid, e2 vid, sid, e1 start position, e1 end position, e2 start position, e2 end position) + """ + + def __init__(self, + sentences: Optional[List[torch.Tensor]] = None, + entity_dictionary: Optional[Dictionary] = None, + degrees: Optional[List[int]] = None, + edges: Optional[List[Tuple[int, int, int, int, int, int, int]]] = None, + *, path: Optional[pathlib.Path] = None) -> None: + """ Initialize a Graph, either with the provided data or by loading the given file. """ + if path is not None: + assert(sentences is None and entity_dictionary is None and degrees is None and edges is None) + self.load(path) + else: + assert(sentences is not None and entity_dictionary is not None and degrees is not None and edges is not None) + self.sentences: Union[List[torch.Tensor], SharedLongTensorList] = sentences + self.entity_dictionary: Dictionary = entity_dictionary + self.compile_edges_adj(edges, degrees) + + def compile_edges_adj(self, edges: List[Tuple[int, int, int, int, int, int, int]], degrees: List[int]) -> None: + """ Compile the edge and adjacency list of the graph. """ + edge_order: List[int] = sorted(range(len(edges)), key=edges.__getitem__) + + self.edges: torch.Tensor = torch.empty((len(edges), 7), dtype=torch.int32) + self.adj: Union[List[torch.Tensor], SharedLongTensorList] = [torch.empty((degree, 2), dtype=torch.int32) for degree in degrees] + + for new_id, old_id in enumerate(tqdm.tqdm(edge_order, desc="compiling graph")): + edge: torch.Tensor = torch.tensor(edges[old_id], dtype=torch.int32) + self.edges[new_id] = edge + for source, destination in [(edge[0], edge[1]), (edge[1], edge[0])]: + degrees[source] -= 1 + self.adj[source][degrees[source], 0] = destination + self.adj[source][degrees[source], 1] = new_id + # Free the memory along the way to avoid storing the graph twice + edges[old_id] = None + + for i, vertex in enumerate(tqdm.tqdm(self.adj, desc="sorting adjacency list")): + self.adj[i] = torch.stack(sorted(vertex, key=torch.Tensor.tolist)) # pytype: disable=unsupported-operands + + @property + def order(self) -> int: + """ Number of vertices. """ + return len(self.adj) + + @property + def size(self) -> int: + """ Number of edges. """ + return self.edges.shape[0] + + def degree(self, vertex: int) -> int: + """ Number of edges connected to a given vertex. """ + return self.adj[vertex].shape[0] + + def eid_simple_adjacency(self, eid: int) -> bool: + """ Decide whether a given edge is the sole edge between two nodes. """ + if eid > 0 and (self.edges[eid, :2] == self.edges[eid-1, :2]).all(): + return False + if eid < self.size-1 and (self.edges[eid, :2] == self.edges[eid+1, :2]).all(): + return False + return True + + def eid_adjacency_range(self, eid: int, prefix: int = 2) -> Tuple[int, int]: + """ Return the range of edges (in the global edge list) sharing the same end points as eid. """ + range_start: int = eid + while range_start > 0 and (self.edges[eid, :prefix] == self.edges[range_start-1, :prefix]).all(): + range_start -= 1 + + range_end: int = eid+1 + while range_end < self.size and (self.edges[eid, :prefix] == self.edges[range_end, :prefix]).all(): + range_end += 1 + + return range_start, range_end + + def reid_adjacency_begin(self, source: int, destination: int) -> int: + """ Return the first edge from source to destination as relative index in source's adjacency list. """ + left: int = 0 + right: int = self.degree(source) + while left < right: + middle: int = (left + right) // 2 + if self.adj[source][middle, 0] < destination: + left = middle + 1 + else: + right = middle + + return left + + def tagged_sentence(self, eid: int, tokenizer: transformers.PreTrainedTokenizer, invert: bool = False) -> Tuple[List[int], int, int]: + """ Get the tagged sentence corresponding to an edge. """ + edge: torch.Tensor = self.edges[eid] + text: List[int] = self.sentences[edge[2]].tolist() + # Abuse the fact that "</e1>" < "<e1>" + if invert: + tags: List[Tuple[int, str]] = [(edge[5], "<e1>"), (edge[6], "</e1>"), (edge[3], "<e2>"), (edge[4], "</e2>")] + else: + tags: List[Tuple[int, str]] = [(edge[3], "<e1>"), (edge[4], "</e1>"), (edge[5], "<e2>"), (edge[6], "</e2>")] + tags.sort(reverse=True) + for position, tag in tags: + text.insert(position, tokenizer.convert_tokens_to_ids(tag)) + e1_pos: int = self.edges[eid, 5 if invert else 3].item() + e2_pos: int = self.edges[eid, 3 if invert else 5].item() + if e1_pos < e2_pos: + e2_pos += 2 + else: + e1_pos += 2 + return text, e1_pos, e2_pos + + def save(self, path: pathlib.Path) -> None: + """ Save the graph to the given directory. """ + if not path.is_dir(): + path.mkdir() + + self.entity_dictionary.save(path / "entities") + for attribute in ["sentences", "edges", "adj"]: + torch.save(getattr(self, attribute), path / attribute) + + def load(self, path: pathlib.Path) -> None: + """ Load a graph from the given directory. """ + self.entity_dictionary = Dictionary(path=path / "entities") + for attribute in ["sentences", "edges", "adj"]: + setattr(self, attribute, torch.load(path / attribute)) + + def share_memory(self) -> None: + self.sentences = SharedLongTensorList(self.sentences) + self.adj = SharedLongTensorList(self.adj, [-1, 2]) diff --git a/gbure/data/prepare_fewrel.py b/gbure/data/prepare_fewrel.py @@ -0,0 +1,65 @@ +from typing import Any, Dict, Iterable, Tuple +import argparse +import json +import pathlib + +import tqdm + +from gbure.utils import DATA_PATH +import gbure.data.preprocessing as preprocessing + +DATASET_PATH: pathlib.Path = DATA_PATH / "FewRel" +DOWNLOAD_URL: str = "https://thunlp.oss-cn-qingdao.aliyuncs.com/fewrel/" +FILES_SHA512: Dict[str, str] = { + "fewrel_train.json": "2ec687d16999bd59bbcac39fdfed319cee3bec14963717c6ee262da981ad64e58b93d5c95a6ea4f6c5fe9c3d09a57d098d25ad61955f4df5f910bc28718e8220", + "fewrel_val.json": "32bdac8c9aba880484d00417d823a310657acca5604a06fa7c4c01f8dfb54b9e05ecf67c0c374aa61bd42cb9e90cd62c0d50377cce2d6bc8fa6d1fbbb61d0f5e" + } + + +def get_data() -> None: + """ Download FewRel's train and val json files if needed. """ + for filename, sha512 in FILES_SHA512.items(): + if not (DATASET_PATH / filename).exists(): + preprocessing.download(DOWNLOAD_URL + filename, DATASET_PATH / filename, filename, sha512) + + +def read_data(path: pathlib.Path) -> Iterable[Tuple[str, str, str, str, str]]: + """ Read a FewRel json file and return (text, relation, relation_base, e1, e2) tuples. """ + with open(path) as file: + data = json.load(file) + + for relation, relset in tqdm.tqdm(data.items(), desc=f"loading {path.name}"): + relation_dataset = [] + for sentence in relset: + # We assume we know the direction of the relation + yield (process_sentence(sentence), relation, relation, sentence["h"][1][1:], sentence["t"][1][1:]) + + +def process_sentence(sentence: Dict[str, Any]) -> str: + """ Transform a FewRel json sentence object to a tagged sentence string. """ + tokens = sentence["tokens"] + tokens[sentence["h"][2][0][0]] = "<e1>" + tokens[sentence["h"][2][0][0]] + tokens[sentence["h"][2][0][-1]] += "</e1>" + tokens[sentence["t"][2][0][0]] = "<e2>" + tokens[sentence["t"][2][0][0]] + tokens[sentence["t"][2][0][-1]] += "</e2>" + return " ".join(tokens) + + +def read_splits() -> Dict[str, Iterable[Tuple[str, str, str, str, str]]]: + return {"train": read_data(DATASET_PATH / "fewrel_train.json"), + "valid": read_data(DATASET_PATH / "fewrel_val.json")} + + +if __name__ == "__main__": + parser: argparse.ArgumentParser = preprocessing.base_argument_parser("Prepare the few shot FewRel dataset.") + args: argparse.Namespace = parser.parse_args() + name: str = preprocessing.dataset_name(args) + + get_data() + preprocessing.serialize_dataset( + supervision="fewshot", + path=DATASET_PATH / name, + splits=read_splits(), + unknown_entity=None, + unknown_relation=None, + **preprocessing.args_to_serialize(args)) diff --git a/gbure/data/prepare_kbp37.py b/gbure/data/prepare_kbp37.py @@ -0,0 +1,41 @@ +from typing import Dict, Iterable, Tuple +import argparse +import pathlib + +import tqdm + +from gbure.utils import DATA_PATH +import gbure.data.prepare_semeval +import gbure.data.preprocessing as preprocessing + +DATASET_PATH: pathlib.Path = DATA_PATH / "KBP37" +COMMIT_ID: str = "7d88486ad632a9c6e9fe6adbc2468049e89bc11d" +DIRECTORY_NAME: str = f"kbp37-{COMMIT_ID}" +ARCHIVE_NAME: str = "kbp37_data.zip" +ARCHIVE_SHA512: str = "f6661df79d327a34ad4198f0405d7c06e05af4d9aab3723282c02270394c6511df41a330118d2be89f390b9d36c1eb2a32800db01a3d4387b990493befb011ac" +DOWNLOAD_URL: str = f"https://github.com/zhangdongxu/kbp37/archive/{COMMIT_ID}.zip" + +TRAIN_SIZE: int = 15917 +VALID_SIZE: int = 1724 +TEST_SIZE: int = 3405 +UNKNOWN_RELATION: str = "no_relation" + + +def read_splits() -> Dict[str, Iterable[Tuple[str, str, str, str, str]]]: + return {"train": gbure.data.prepare_semeval.read_data(DATASET_PATH / DIRECTORY_NAME / "train.txt", TRAIN_SIZE), + "valid": gbure.data.prepare_semeval.read_data(DATASET_PATH / DIRECTORY_NAME / "dev.txt", VALID_SIZE), + "test": gbure.data.prepare_semeval.read_data(DATASET_PATH / DIRECTORY_NAME / "test.txt", TEST_SIZE)} + + +if __name__ == "__main__": + parser: argparse.ArgumentParser = preprocessing.base_argument_parser("Prepare the supervised KBP37 dataset.") + args: argparse.Namespace = parser.parse_args() + name: str = preprocessing.dataset_name(args) + + preprocessing.get_zip_data(DATASET_PATH, DIRECTORY_NAME, ARCHIVE_NAME, ARCHIVE_SHA512, DOWNLOAD_URL) + preprocessing.serialize_dataset( + supervision="supervised", + path=DATASET_PATH / name, + splits=read_splits(), + unknown_relation=UNKNOWN_RELATION, + **preprocessing.args_to_serialize(args)) diff --git a/gbure/data/prepare_sampled_fewrel.py b/gbure/data/prepare_sampled_fewrel.py @@ -0,0 +1,65 @@ +from typing import Any, Dict, Iterable, List, Optional, Tuple +import argparse +import hashlib +import itertools +import json +import pathlib + +import tqdm + +from gbure.utils import DATA_PATH +from gbure.data.prepare_fewrel import DATASET_PATH, process_sentence +import gbure.data.preprocessing as preprocessing + + +def read_entry(entry: Dict[str, Any]) -> Tuple[str, str, str]: + """ Convert the json object of an entry to a preprocessing tuple (sentence, e1, e2). """ + return process_sentence(entry), entry["h"][1], entry["t"][1] + + +def read_file(inpath: pathlib.Path, outpath: Optional[pathlib.Path]) -> Iterable[Tuple[Tuple[str, str, str], List[List[Tuple[str, str, str]]], int]]: + """ + Yield (test, [[train]]) pair from a file of samples. + + Each input is a triplet (sentence, head entity, tail entity). + """ + with open(inpath) as file: + data = json.load(file) + + if outpath: + with open(outpath) as file: + answers = json.load(file) + else: + answers = itertools.repeat(-1) + + for problem, answer in zip(tqdm.tqdm(data, desc=f"processing {inpath.name}"+(" and "+outpath.name if outpath else "")), answers): + test = read_entry(problem["meta_test"]) + train = list(map(lambda candidates: list(map(read_entry, candidates)), problem["meta_train"])) + yield (test, train, answer) + + +if __name__ == "__main__": + parser: argparse.ArgumentParser = preprocessing.base_argument_parser("Prepare a sampled few shot FewRel dataset (generated by sample_io.py).", deterministic=True) + parser.add_argument("inpath", + type=pathlib.Path, + help="Path to the file containing the input ") + parser.add_argument("outpath", + type=pathlib.Path, + nargs="?", + help="Path to the file containing the output (optional)") + parser.add_argument("-S", "--suffix", + type=str, + default="", + help="Suffix to add to the tokenizer to find the dataset.") + + args: argparse.Namespace = parser.parse_args() + hashid: str = preprocessing.hash_file(args.inpath)[:8] + if args.outpath: + hashid += preprocessing.hash_file(args.outpath)[:8] + name: str = preprocessing.dataset_name(args, args.suffix) + + preprocessing.serialize_fewshot_sampled_split( + path=DATASET_PATH / name, + name=hashid, + split=read_file(args.inpath, args.outpath), + **preprocessing.args_to_serialize(args)) diff --git a/gbure/data/prepare_semeval.py b/gbure/data/prepare_semeval.py @@ -0,0 +1,83 @@ +from typing import Dict, Iterable, Tuple +import argparse +import pathlib +import random + +import tqdm + +from gbure.utils import DATA_PATH +import gbure.data.preprocessing as preprocessing + +DATASET_PATH: pathlib.Path = DATA_PATH / "SemEval 2010 Task 8" +DIRECTORY_NAME: str = "SemEval2010_task8_all_data" +ARCHIVE_NAME: str = f"{DIRECTORY_NAME}.zip" +ARCHIVE_SHA512: str = "7ac2d71ba1772105c1f73e4278e4b85cebd9fb95187fb8a153c83215d890f0d2b98929fb2363e8c117d2e2c8e7f9926d7997e38fc57b5730fb00912ff376b66b" +DOWNLOAD_URL: str = f"https://esimon.eu/GBURE/{ARCHIVE_NAME}" + +TRAIN_VALID_SIZE: int = 8000 +TEST_SIZE: int = 2717 +UNKNOWN_RELATION: str = "Other" + + +def read_data(path: pathlib.Path, size: int) -> Iterable[Tuple[str, str, str, str, str]]: + """ + Read a file in SemEval format and return (text, relation, relation_base, e1, e2) tuples. + + For now, the entities are empty, we could encode them using their surface form or use a true entity linker if we want to use this information. + """ + with path.open() as file: + for _ in tqdm.trange(size, desc=f"loading {path.name}"): + idtext_line: str = file.readline() + relation_line: str = file.readline() + file.readline() # Ignore Comment line + file.readline() # Ignore empty line + + if not (idtext_line and relation_line): + break + + id, raw_text = idtext_line.rstrip().split('\t') + text = raw_text[1:-1] # remove quotes around text + relation = relation_line.rstrip() + + dir_start: int = relation.find('(') + relation_base: str = relation[:dir_start] if dir_start >= 0 else relation + + # TODO handle entities + yield (text, relation, relation_base, "", "") + + +def split_train_valid(data: Iterable[Tuple[str, str, str, str, str]], valid_size: int, seed: int) -> Tuple[Iterable[Tuple[str, str, str, str, str]], Iterable[Tuple[str, str, str, str, str]]]: + data = list(data) + rng = random.Random(seed) + rng.shuffle(data) + + train = data[valid_size:] + valid = data[:valid_size] + return train, valid + + +def read_splits(valid_size: int) -> Dict[str, Iterable[Tuple[str, str, str, str, str]]]: + splits = {} + train_valid = read_data(DATASET_PATH / DIRECTORY_NAME / "SemEval2010_task8_training" / "TRAIN_FILE.TXT", TRAIN_VALID_SIZE) + splits["train"], splits["valid"] = split_train_valid(train_valid, args.valid_size, args.seed) + splits["test"] = read_data(DATASET_PATH / DIRECTORY_NAME / "SemEval2010_task8_testing_keys" / "TEST_FILE_FULL.TXT", TEST_SIZE) + return splits + + +if __name__ == "__main__": + parser: argparse.ArgumentParser = preprocessing.base_argument_parser("Prepare the supervised SemEval 2010 Task 8 dataset.") + parser.add_argument("-v", "--valid-size", + type=int, + default=1500, + help="Size of the validation set") + + args: argparse.Namespace = parser.parse_args() + name: str = preprocessing.dataset_name(args, f"-v{args.valid_size}" if args.valid_size != 1500 else "") + + preprocessing.get_zip_data(DATASET_PATH, DIRECTORY_NAME, ARCHIVE_NAME, ARCHIVE_SHA512, DOWNLOAD_URL) + preprocessing.serialize_dataset( + supervision="supervised", + path=DATASET_PATH / name, + splits=read_splits(args.valid_size), + unknown_relation=UNKNOWN_RELATION, + **preprocessing.args_to_serialize(args)) diff --git a/gbure/data/prepare_trex.py b/gbure/data/prepare_trex.py @@ -0,0 +1,76 @@ +from typing import Any, Dict, Iterable, List, Optional, Tuple +import argparse +import json +import os +import pathlib +import tqdm + +from gbure.utils import DATA_PATH +import gbure.data.preprocessing as preprocessing + +DATASET_PATH: pathlib.Path = DATA_PATH / "T-REx" +DIRECTORY_NAME: str = "raw_data" +ARCHIVE_NAME: str = f"T-REx.zip" +ARCHIVE_SHA512: str = "30349fa6f01c1928ce15325521ebd05643787220f9a545eb23b280f9209cb1615f4a855b08604f943a1affb4d1f4f17b94f8434698f347a1cb7a0d820fa9de9f" +DOWNLOAD_URL: str = f"https://esimon.eu/GBURE/{ARCHIVE_NAME}" + + +def process_json_object(data: List[Dict[str, Any]]) -> Iterable[Tuple[str, List[Tuple[str, int, int]]]]: + """ Process a T-REx json object and return (sentence, list of entities) tuples. """ + for article in data: + eid: int = 0 + for sbs in article["sentences_boundaries"]: + entities: List[Tuple[str, int, int]] = [] + while eid < len(article["entities"]) and article["entities"][eid]["boundaries"][0] < sbs[0]: + eid += 1 + + while eid < len(article["entities"]) and article["entities"][eid]["boundaries"][1] <= sbs[1]: + entity: Dict[str, Any] = article["entities"][eid] + eid += 1 + + # Ignore date entities + if entity["annotator"] != "Wikidata_Spotlight_Entity_Linker": + continue + + uri: str = entity["uri"] + prefix: str = "http://www.wikidata.org/entity/Q" + assert(uri.startswith(prefix)) + uri = uri[len(prefix):] + entities.append((uri, entity["boundaries"][0] - sbs[0], entity["boundaries"][1] - sbs[0])) + + # ignore sentences with less than two entities + if len(entities) < 2: + continue + + sentence = article["text"][sbs[0]:sbs[1]] + yield (sentence, entities) + + +def read_data(subset: Optional[int]) -> Iterable[Tuple[str, List[Tuple[str, int, int]]]]: + """ Read all T-REx files and return (sentence, list of entities) tuples. """ + filenames: List[str] = list(filter(lambda filename: filename.endswith(".json"), os.listdir(DATASET_PATH / DIRECTORY_NAME))) + + # Make the order deterministic. + filenames.sort() + if subset is not None: + filenames = filenames[:subset] + + for filename in tqdm.tqdm(filenames, desc="loading"): + with open(DATASET_PATH / DIRECTORY_NAME / filename, "r") as file: + data: List[Dict[str, Any]] = json.load(file) + yield from process_json_object(data) + + +if __name__ == "__main__": + parser: argparse.ArgumentParser = preprocessing.base_argument_parser("Prepare the unsupervised TREx dataset.") + parser.add_argument("-S", "--subset", + type=int, + help="Number of file to process (default to all, only used for creating a debug dataset)") + args: argparse.Namespace = parser.parse_args() + name: str = preprocessing.dataset_name(args, "" if args.subset is None else f"-ss{args.subset}") + + preprocessing.get_zip_data(DATASET_PATH, DIRECTORY_NAME, ARCHIVE_NAME, ARCHIVE_SHA512, DOWNLOAD_URL, unzip_directory=True) + preprocessing.serialize_unsupervised_dataset( + path=DATASET_PATH / name, + data=read_data(args.subset), + **preprocessing.args_to_serialize(args)) diff --git a/gbure/data/preprocessing.py b/gbure/data/preprocessing.py @@ -0,0 +1,392 @@ +from typing import Any, Dict, Iterable, List, Optional, Tuple, Union, cast +import argparse +import collections +import hashlib +import math +import os +import pathlib +import random +import urllib.request +import zipfile + +import torch +import tqdm +import transformers + +from gbure.data.dictionary import Dictionary, RelationDictionary +from gbure.data.graph import Graph + + +def hash_file(path: pathlib.Path, filename: Optional[str] = None, filesize: Optional[int] = None) -> str: + """ Get a unique identifier for the file. """ + hasher = hashlib.sha512() + with path.open("rb") as file: + loop = iter(lambda: file.read(2**16), b"") + if filename is not None and filesize is not None: + loop = tqdm.tqdm(loop, + desc=f"checking {filename} hash", + total=math.ceil(filesize / 2**16), + unit_scale=2**16, unit="B", unit_divisor=1024) + for chunk in loop: + hasher.update(chunk) + return hasher.hexdigest() + + +def download(url: str, path: pathlib.Path, filename: str, sha512: str) -> None: + """ Download a file at the given path and check its hash. """ + if not path.parent.is_dir(): + path.parent.mkdir(parents=True) + + unchecked: pathlib.Path = pathlib.Path(f"{path}.unchecked") + with tqdm.tqdm(unit="B", unit_scale=True, unit_divisor=1024, miniters=1, desc=f"downloading {filename}") as progress: + def report_hook(num_blocks: int, chunk_size: int, total_size: int): + if progress.total is None: + progress.total = total_size + progress.update(num_blocks * chunk_size - progress.n) + urllib.request.urlretrieve(url, unchecked, report_hook) + + unchecked_hash = hash_file(unchecked, filename, progress.total) + if unchecked_hash != sha512: + raise RuntimeError(f"Downloaded file \"{filename}\" has wrong hash.") + os.rename(unchecked, path) + + +def get_zip_data(dataset_path: pathlib.Path, directory_name: str, archive_name: str, archive_sha512: str, download_url: str, unzip_directory: bool = False) -> None: + """ Download and extract data zip archive if needed. """ + if not (dataset_path / directory_name).exists(): + if not (dataset_path / archive_name).exists(): + download(download_url, dataset_path / archive_name, archive_name, archive_sha512) + + with zipfile.ZipFile(str(dataset_path / archive_name), "r") as archive: + archive.extractall(dataset_path / directory_name if unzip_directory else dataset_path) + + +def base_argument_parser(description: str = "", deterministic: bool = False, parser: Optional[argparse.ArgumentParser] = None) -> argparse.ArgumentParser: + assert(description != "" or parser is not None) + """ Return an argument parser with standard command line arguments used by preprocessing functions. """ + parser: argparse.ArgumentParser = argparse.ArgumentParser(description=description) if parser is None else parser + parser.add_argument("tokenizer", + type=str, + nargs='?', + default="bert-base-cased", + help="Name of the transformers tokenizer") + if not deterministic: + parser.add_argument("-s", "--seed", + type=int, + default=0, + help="Seed of the RNG for shuffling the dataset") + return parser + + +def dataset_name(args: argparse.Namespace, infix: str = "") -> str: + """ Returns the dataset name with suffix containing non-standard preprocessing parameters. """ + suffix: str = "" + if "seed" in args and args.seed != 0: + suffix = f"-s{args.seed}" + return f"{args.tokenizer}{infix}{suffix}" + + +def args_to_serialize(args: argparse.Namespace) -> Dict[str, Any]: + """ Map standard preprocessing command line arguments defined in base_argument_parser to serialize_supervised_dataset parameters. """ + kwargs = {"tokenizer_name": args.tokenizer} + if "seed" in args: + kwargs["seed"] = args.seed + return kwargs + + +def make_tokenizer(name: str, path: pathlib.Path) -> transformers.PreTrainedTokenizer: + """ Build the given tokenizer and save it. """ + if not path.is_dir(): + path.mkdir() + + tokenizer = transformers.AutoTokenizer.from_pretrained(name) + special_tokens = ["<e1>", "</e1>", "<e2>", "</e2>", "<blank/>"] + tokenizer.add_special_tokens({"additional_special_tokens": special_tokens}) + tokenizer.save_pretrained(path) + + # fix huggingface transformers issue #6368 + config_file = transformers.AutoConfig.from_pretrained(name) + config_file.save_pretrained(path) + + return tokenizer + + +def process_text_2(raw_text: str, tokenizer: transformers.PreTrainedTokenizer) -> Tuple[torch.Tensor, int, int]: + """ + Transform a string with two entities tagged to a list of token ids together with the positions of the two entities. + + The returned token list contains the token corresponding to the tags. + The two returned positions are the positions of <e1> and <e2>. + """ + be1_id: int = tokenizer.convert_tokens_to_ids("<e1>") + be2_id: int = tokenizer.convert_tokens_to_ids("<e2>") + + text: List[int] = tokenizer.encode(raw_text, add_special_tokens=True) + e1_pos: int = text.index(be1_id) + e2_pos: int = text.index(be2_id) + if len(text) > tokenizer.model_max_length: + text = text[:tokenizer.model_max_length] + e1_pos = min(tokenizer.model_max_length-1, e1_pos) + e2_pos = min(tokenizer.model_max_length-1, e2_pos) + + return torch.tensor(text, dtype=torch.int32), e1_pos, e2_pos + + +def process_text_n(raw_text: str, raw_entities: List[Tuple[str, int, int]], tokenizer: transformers.PreTrainedTokenizer) -> Tuple[torch.Tensor, List[Tuple[str, int, int]]]: + """ + Transform a string with several entities tagged to a list of token id together with the positions of entities. + + The returned token list does not contain the token corresponding to the tags. + The returned postions, are where the tags should be inserted. + If the leftmost tag is inserted first, the position of subsequent inserts should be shifted accordingly. + """ + be1_id: int = tokenizer.convert_tokens_to_ids("<e1>") + + # If one entity end at a position, and another entity start at the same position, we want to close the first entity before starting the sencond one, the second field "1 - extremity" has this function since the list is sorted in lexicographic order. + tag_positions: List[Tuple[int, int, int]] = [ + (cast(int, entity[1 + extremity]), # Position of the tag (start or end of entity) in the sentence. + cast(int, 1 - extremity), # Whether this is a start or end of entity. + i) # The index of the entity used to rebuild the list at the end. + for i, entity in enumerate(raw_entities) for extremity in [0, 1]] + tag_positions.sort() + + # We insert the tag <e1> at every tag postion in order to be able to convert postions in the raw text to positions in the token list. + pieces: List[str] = [] + for piece_start, piece_end in zip([(0,)] + tag_positions, tag_positions + [(len(raw_text),)]): + pieces.append(raw_text[piece_start[0]:piece_end[0]]) + pieces.append("<e1>") + # Remove the last <e1> added at the end of the sentence. + pieces.pop() + + text: List[int] = tokenizer.encode("".join(pieces), add_special_tokens=True) + if len(text) > tokenizer.model_max_length: + text = text[:tokenizer.model_max_length] + + # New entity list, with converted positions. + entities: List[List[Union[str, int]]] = [[entity[0], -1, -1] for entity in raw_entities] + + j: int = 0 # Counter on the tags. + for i, token in enumerate(text): + if token == be1_id: + # The order of the <e1> in the text match the one in tag_positions. + tag_position: Tuple[int, int, int] = tag_positions[j] + + # tag_position[2] is the index of the entity in raw_entities (and thus entities). + # tag_position[1] is 0 for the end of the entity and 1 for its start. + # Since the returned token list will be pruned of all the <e1>, the position of the tag should be shifted by the number of <e1> already met, thus "i - j". + entities[tag_position[2]][2 - tag_position[1]] = i - j + + j += 1 # Move to the next tag. + + # Remove all tags + text = list(filter(lambda x: x != be1_id, text)) + + # Remove entities which didn't fit inside tokenizer.model_max_length tokens. + entities = list(filter(lambda x: x[1] >= 0 and x[2] >= 0, entities)) + + tuple_entities: List[Tuple[str, int, int]] = list(map(tuple, entities)) + return torch.tensor(text, dtype=torch.int32), tuple_entities + + +def serialize_supervised_split( + path: pathlib.Path, + split: Iterable[Tuple[str, str, str, str, str]], + tokenizer: transformers.PreTrainedTokenizer, + entity_dictionary: Dictionary, + relation_dictionary: RelationDictionary) -> None: + """ + Serialize a supervised split to a given path. + + split is an iterable containing (text, directed relation, undirected relation, e1, e2) tuples. + Entities are ignored. + The relations are raw values (e.g. P42). This function performs the encoding. + """ + data: List[Tuple[torch.Tensor, int, int, int]] = [] + + # TODO handle entities + for raw_text, relation, relation_base, _, _ in split: + text, e1_pos, e2_pos = process_text_2(raw_text, tokenizer) + relation_id: int = relation_dictionary.encode(relation, relation_base) + data.append((text, e1_pos, e2_pos, relation_id)) + + torch.save(("supervised", data), path) + + +def serialize_fewshot_split( + path: pathlib.Path, + split: Iterable[Tuple[str, str, str, str, str]], + tokenizer: transformers.PreTrainedTokenizer, + entity_dictionary: Dictionary, + relation_dictionary: RelationDictionary) -> None: + """ + Serialize a fewshot split to a given path. + + split is an iterable containing (text, directed relation, undirected relation, e1, e2) tuples. + The relations and entities are raw values (e.g. P42, Q42). This function performs the encoding. + """ + data: Dict[int, List[Tuple[torch.Tensor, int, int, int, int, int]]] = collections.defaultdict(list) + + for raw_text, relation, relation_base, e1, e2 in split: + text, e1_pos, e2_pos = process_text_2(raw_text, tokenizer) + relation_id: int = relation_dictionary.encode(relation, relation_base) + e1_id: int = entity_dictionary.encode(e1) + e2_id: int = entity_dictionary.encode(e2) + data[relation_id].append((text, e1_pos, e2_pos, relation_id, e1_id, e2_id)) + + torch.save(("fewshot", list(data.values())), path) + + +def serialize_fewshot_sampled_split( + path: pathlib.Path, + name: str, + split: Iterable[Tuple[Tuple[str, str, str], List[List[Tuple[str, str, str]]], int]], + tokenizer_name: str) -> None: + """ + Serialize a sampled fewshot split. + + split is an iterable of (query, candidates, answer) tuples. + In these tuples, query is a tuple (text, e1, e2). + The relations are not given. + """ + tokenizer: transformers.PreTrainedTokenizer = transformers.AutoTokenizer.from_pretrained(str(path / "tokenizer")) + entity_dictionary = Dictionary() + + data: List[Tuple[torch.Tensor, int, int, int, int, List[List[torch.Tensor]], torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, int]] = [] + for train, test, answer in split: + query_text, query_e1_pos, query_e2_pos = process_text_2(train[0], tokenizer) + query_e1 = entity_dictionary.encode(train[1]) + query_e2 = entity_dictionary.encode(train[2]) + + way = len(test) + shot = len(test[0]) + candidates_processed_text: List[List[Tuple[torch.Tensor, int, int]]] = list(map(lambda relation: list(map(lambda candidate: process_text_2(candidate[0], tokenizer), relation)), test)) + candidates_text_len = max(map(lambda relation: max(map(lambda candidate: candidate[0].shape[0], relation)), candidates_processed_text)) + + candidates_text = [[None]*shot for _ in range(way)] + candidates_e1_pos = torch.empty((way, shot), dtype=torch.int64) + candidates_e2_pos = torch.empty((way, shot), dtype=torch.int64) + candidates_e1 = torch.empty((way, shot), dtype=torch.int64) + candidates_e2 = torch.empty((way, shot), dtype=torch.int64) + + for n, (relation, relation_processed) in enumerate(zip(test, candidates_processed_text)): + for k, (candidate, candidate_processed) in enumerate(zip(relation, relation_processed)): + candidates_text[n][k] = candidate_processed[0] + candidates_e1_pos[n, k] = candidate_processed[1] + candidates_e2_pos[n, k] = candidate_processed[2] + candidates_e1[n, k] = entity_dictionary.encode(candidate[1]) + candidates_e2[n, k] = entity_dictionary.encode(candidate[2]) + + data.append((query_text, query_e1_pos, query_e2_pos, query_e1, query_e2, candidates_text, candidates_e1_pos, candidates_e2_pos, candidates_e1, candidates_e2, answer)) + entity_dictionary.save(path / f"{name}.entities") + torch.save(("sampled fewshot", data), path / name) + + +def serialize_dataset( + supervision: str, + path: pathlib.Path, + splits: Dict[str, Iterable[Tuple[str, str, str, str, str]]], + tokenizer_name: str, + unknown_entity: Optional[str] = None, + unknown_relation: Optional[str] = None, + seed: Optional[int] = None) -> None: + """ + Serialize a dataset to a given path. + + The splits must be given as iterables of (text, relation, relation_base, e1, e2) tuples. + supervision must be one of "supervised" or "fewshot". + """ + if not path.is_dir(): + path.mkdir() + + tokenizer: transformers.PreTrainedTokenizer = make_tokenizer(tokenizer_name, path / "tokenizer") + entity_dictionary = Dictionary(unknown=unknown_entity) + relation_dictionary = RelationDictionary(unknown=unknown_relation) + + serialize_split = serialize_supervised_split if supervision == "supervised" else serialize_fewshot_split + for split_name in ["train", "valid", "test"]: + if split_name not in splits: + continue + + split = list(splits[split_name]) + if split_name == "train": + rng = random.Random(seed) + rng.shuffle(split) + split = tqdm.tqdm(split, desc=f"{split_name} tokenization") + serialize_split(path / split_name, split, tokenizer, entity_dictionary, relation_dictionary) + entity_dictionary.save(path / "entities") + relation_dictionary.save(path / "relations") + + +def build_edge_list(data: Iterable[Tuple[str, List[Tuple[str, int, int]]]], tokenizer: transformers.PreTrainedTokenizer) -> Tuple[List[torch.Tensor], Dictionary, List[int], List[Tuple[int, int, int, int, int, int, int]]]: + """ + Build a list of edges and nodes corresponding to the given data. + + The tuples in the returned edge list are composed of the following elements: + (entity 1, entity 2, sentence id, entity 1 start, entity 1 end, entity 2 start, entity 2 end) + """ + sentences: List[str] = [] + entity_dictionary = Dictionary() + degrees: List[int] = [] + edges: List[Tuple[int, int, int, int, int, int, int]] = [] + + for raw_sentence, raw_entities in data: + sentence: torch.Tensor + entities: List[Tuple[str, int, int]] + sentence, entities = process_text_n(raw_sentence, raw_entities, tokenizer) + + # Buffer the ids to avoid re-hashing the entities + entity_ids: List[Optional[int]] = [None] * len(entities) + edge_added: bool = False + + # Add all edges appearing in the clique corresponding to this sentence + for i, (e1_name, e1_start, e1_end) in enumerate(entities): + for j, (e2_name, e2_start, e2_end) in enumerate(entities[:i]): + # Soares et al. footnote 2 "We use a window of 40 tokens" + if max(e2_end - e1_start, e1_end - e2_start) < 40: + if entity_ids[i] is None: + entity_ids[i] = entity_dictionary.encode(e1_name) + if entity_ids[i] >= len(degrees): + degrees.append(0) + e1_id: int = cast(int, entity_ids[i]) + + if entity_ids[j] is None: + entity_ids[j] = entity_dictionary.encode(e2_name) + if entity_ids[j] >= len(degrees): + degrees.append(0) + e2_id: int = cast(int, entity_ids[j]) + + if e1_id <= e2_id: + edges.append((e1_id, e2_id, len(sentences), e1_start, e1_end, e2_start, e2_end)) + else: + edges.append((e2_id, e1_id, len(sentences), e2_start, e2_end, e1_start, e1_end)) + degrees[e1_id] += 1 + degrees[e2_id] += 1 + edge_added = True + + if edge_added: + sentences.append(sentence) + + return sentences, entity_dictionary, degrees, edges + + +def serialize_unsupervised_dataset( + path: pathlib.Path, + data: Iterable[Tuple[str, List[Tuple[str, int, int]]]], + tokenizer_name: str, + seed: int) -> None: + """ + Serialize an unsupervised dataset to a given path. + + The data must be given as an iterable of (sentence, list of entities) tuples. + Where entities are tuples of (identifier, start indice in sentence, end indice in sentence). + """ + if not path.is_dir(): + path.mkdir() + + tokenizer: transformers.PreTrainedTokenizer = make_tokenizer(tokenizer_name, path / "tokenizer") + + sentences: List[str] + entities: Dictionary + edges: List[Tuple[int, int, int, int, int, int, int]] + graph = Graph(*build_edge_list(data, tokenizer)) + graph.save(path / "train") diff --git a/gbure/eval.py b/gbure/eval.py @@ -0,0 +1,34 @@ +from typing import Any, Dict +import pathlib + +import torch + +import gbure.train +import gbure.utils + + +class Evaluator(gbure.train.Trainer): + """ + Evaluate a model. + + Config: + valid: only evualte on validation split + test: only evualte on test split + """ + + def main(self) -> None: + """ Run the experiment (i.e. here, evaluate). """ + if self.config.get("valid") or not self.config.get("test"): + self.evaluate("valid") + if self.config.get("test") or not self.config.get("valid"): + self.evaluate("test") + + +if __name__ == "__main__": + gbure.utils.fix_transformers_logging_handler() + config: gbure.utils.dotdict = gbure.utils.parse_args() + + state_dicts: Dict[str, Any] = torch.load(config.load) + logdir: pathlib.Path = state_dicts["logdir"] + gbure.utils.add_logging_handler(logdir) + Evaluator(config, logdir, state_dicts).run() diff --git a/gbure/metrics.py b/gbure/metrics.py @@ -0,0 +1,295 @@ +from typing import Dict, Optional, Union +import math + +import torch +import transformers + +from gbure.data.dictionary import RelationDictionary +import gbure.data.graph + + +class Metrics: + """ + Class for computing metrics. + + Twenty metrics are computed: + - Optimized loss (usually negative log likelihood) + - Accuracy + - {directed, undirected, half_directed} {micro, macro} {f1, precision, recall} + Note that the Accuracy is the true accuracy, taking directionality into account and scoring the unknown relation as any other relation. + The last 18 metrics follow the SemEval scorer: + - The unknown ("Other") relation is only scored indirectly + - Directed is equivalent to the metrics "USING DIRECTIONALITY" + - Undirected is equivalent to the metrics "IGNORING DIRECTIONALITY" + - Half-directed is equivalent to the metrics "TAKING DIRECTIONALITY INTO ACCOUNT -- OFFICIAL" + Note that the directed and half_directed micro metrics are equivalents. + """ + + def __init__(self, config: gbure.utils.dotdict, tokenizer: transformers.PreTrainedTokenizer, relation_dictionary: RelationDictionary, graph: Optional[gbure.data.graph.Graph]) -> None: + """ Initialize all metrics. """ + self.config: gbure.utils.dotdict = config + self.tokenizer: transformers.PreTrainedTokenizer = tokenizer + self.relation_dictionary: RelationDictionary = relation_dictionary + self.graph: Optional[gbure.data.graph.Graph] = graph + self.num_relations: int = len(relation_dictionary) + self.num_base_relations: int = relation_dictionary.base_size() + + self.mask: torch.Tensor = self.build_mask() + self.base_transition: torch.Tensor = self.build_base_transition() + self.base_mask: torch.Tensor = self.base_transition.t().mv(self.mask).clamp(0, 1) + + self.loss_sum: float = 0.0 + self.relation_buckets: int = 1 + self.config.get("neighborhood_size", 0) + self.per_bucket_confusion: torch.Tensor = torch.zeros((self.relation_buckets, self.num_relations, self.num_relations), dtype=torch.int32) + + @property + def confusion(self): + return self.per_bucket_confusion.sum(0) + + def build_mask(self) -> torch.Tensor: + """ Return the relation mask used by semeval scorer (which partly ignore the unknown relation). """ + mask: torch.Tensor = torch.ones(self.num_relations) + if self.relation_dictionary.unknown is not None: + assert(self.relation_dictionary.decode(0) == self.relation_dictionary.unknown) + mask[0] = 0 + return mask + + def build_base_transition(self) -> torch.Tensor: + """ Return the transition matrix from "directed relations" to "undirected relations". """ + base_transition: torch.Tensor = torch.zeros((self.num_relations, self.num_base_relations)) + for id, bid in enumerate(self.relation_dictionary.id_to_bid): + base_transition[id, bid] = 1 + return base_transition + + def compute_neighborhood_bucket(self, batch: Dict[str, torch.Tensor], index: int) -> int: + """ Return an index between 0 and self.config.neighborhood_size corresponding to the minimum number of neighbors in the sample. """ + if self.graph is None: + return 0 + neighborhood_size: Union[float, int] = math.inf + for feature, value in batch.items(): + if "degree" in feature and "neighborhood" not in feature: + neighborhood_size = min(neighborhood_size, value[index].min().item()) + # TODO here substract 1 for unsupervised (not that important since we don't care about unsupervised accuracies) + return min(neighborhood_size, self.relation_buckets-1) if neighborhood_size != math.inf else 0 # pytype: disable=bad-return-type + + def update(self, batch: Dict[str, torch.Tensor], loss: torch.Tensor, losses: Dict[str, torch.Tensor], variables: Dict[str, torch.Tensor]) -> None: + """ + Update metrics according to the given batch and the outputs of the model on this batch. + + The variables dictionary returned by the model should usually contain a predicted_relation tensor. + + Args: + batch: the input values used for evaluation + loss: the loss optimized by the model + losses: intermediary (unweighted) losses + variables: internal variables used by the model to compute the loss + """ + predictions: torch.Tensor = variables.get("predicted_relation") + targets: torch.Tensor = batch.get("relation") + if targets is None: + targets = batch.get("query_relation") + + if predictions is None and targets is None: + predictions = variables.get("prediction_relative") + targets = batch.get("answer") + + for i, (prediction, target) in enumerate(zip(predictions, targets)): + neighborhood_bucket: int = self.compute_neighborhood_bucket(batch, i) + self.per_bucket_confusion[neighborhood_bucket, prediction, target] += 1 + + batch_size: int = predictions.shape[0] + self.loss_sum += loss.item() * batch_size + + @property + def summary(self) -> Dict[str, str]: + """ Return a summary of metrics to be quickly displayed. """ + metrics: Dict[str, str] = {"accuracy": f"{self.accuracy*100:.2f}", + "loss": f"{self.loss:.2f}"} + if self.relation_buckets > 1: + metrics.update({"accuracy_non_empty": f"{self.accuracy_non_empty*100:.2f}", + "accuracy_full": f"{self.accuracy_full*100:.2f}"}) + return metrics + + @property + def all(self) -> Dict[str, float]: + """ Return a dictionary of all metrics. """ + keys = ["accuracy", "accuracy_non_empty", "accuracy_full", "loss"] + [ + f"{direction}_{level}_{metric}" + for direction in ["directed", "undirected", "half_directed"] + for level in ["macro", "micro"] + for metric in ["f1", "precision", "recall"]] + return {key: getattr(self, key) for key in keys} + + @property + def base_confusion(self) -> torch.Tensor: + """ Confusion matrix between "undirected" relation classes. """ + return self.base_transition.t().mm(self.confusion.type_as(self.base_transition)).mm(self.base_transition) + + @property + def accuracy(self) -> float: + return math.nan if self.confusion.sum() == 0 else self.confusion.diagonal().sum() / self.confusion.sum().type(torch.float32) + + @property + def accuracy_non_empty(self) -> float: + non_empty: torch.Tensor = self.per_bucket_confusion[1:].sum(0) + return math.nan if non_empty.sum() == 0 else non_empty.diagonal().sum() / non_empty.sum().type(torch.float32) + + @property + def accuracy_full(self) -> float: + full: torch.Tensor = self.per_bucket_confusion[-1] + return math.nan if full.sum() == 0 else full.diagonal().sum() / full.sum().type(torch.float32) + + @property + def loss(self) -> float: + return math.nan if self.confusion.sum() == 0 else self.loss_sum / self.confusion.sum().type(torch.float32) + + ########################## + # Directed macro metrics # + ########################## + + @property + def directed_class_precision(self) -> torch.Tensor: + norm: torch.Tensor = self.confusion.sum(1) + norm[norm == 0] = 1 + return self.confusion.diagonal() / norm.type(torch.float32) + + @property + def directed_class_recall(self) -> torch.Tensor: + norm: torch.Tensor = self.confusion.sum(0) + norm[norm == 0] = 1 + return self.confusion.diagonal() / norm.type(torch.float32) + + @property + def directed_class_f1(self) -> torch.Tensor: + norm: torch.Tensor = self.directed_class_precision + self.directed_class_recall + norm[norm == 0] = 1 + return 2 * self.directed_class_precision * self.directed_class_recall / norm + + @property + def directed_macro_precision(self) -> float: + return ((self.directed_class_precision * self.mask).sum() / self.mask.sum()).item() + + @property + def directed_macro_recall(self) -> float: + return ((self.directed_class_recall * self.mask).sum() / self.mask.sum()).item() + + @property + def directed_macro_f1(self) -> float: + return ((self.directed_class_f1 * self.mask).sum() / self.mask.sum()).item() + + ############################ + # Undirected macro metrics # + ############################ + + @property + def undirected_class_precision(self) -> torch.Tensor: + norm: torch.Tensor = self.base_confusion.sum(1) + norm[norm == 0] = 1 + return self.base_confusion.diagonal() / norm + + @property + def undirected_class_recall(self) -> torch.Tensor: + norm: torch.Tensor = self.base_confusion.sum(0) + norm[norm == 0] = 1 + return self.base_confusion.diagonal() / norm + + @property + def undirected_class_f1(self) -> torch.Tensor: + norm: torch.Tensor = self.undirected_class_precision + self.undirected_class_recall + norm[norm == 0] = 1 + return 2 * self.undirected_class_precision * self.undirected_class_recall / norm + + @property + def undirected_macro_precision(self) -> float: + return ((self.undirected_class_precision * self.base_mask).sum() / self.base_mask.sum()).item() + + @property + def undirected_macro_recall(self) -> float: + return ((self.undirected_class_recall * self.base_mask).sum() / self.base_mask.sum()).item() + + @property + def undirected_macro_f1(self) -> float: + return ((self.undirected_class_f1 * self.base_mask).sum() / self.base_mask.sum()).item() + + ############################### + # Half-directed macro metrics # + ############################### + + @property + def half_directed_class_precision(self) -> torch.Tensor: + norm: torch.Tensor = self.base_confusion.sum(1) + norm[norm == 0] = 1 + return self.base_transition.t().mv(self.confusion.diagonal().type_as(self.base_transition)) / norm + + @property + def half_directed_class_recall(self) -> torch.Tensor: + norm: torch.Tensor = self.base_confusion.sum(0) + norm[norm == 0] = 1 + return self.base_transition.t().mv(self.confusion.diagonal().type_as(self.base_transition)) / norm + + @property + def half_directed_class_f1(self) -> torch.Tensor: + norm: torch.Tensor = self.half_directed_class_precision + self.half_directed_class_recall + norm[norm == 0] = 1 + return 2 * self.half_directed_class_precision * self.half_directed_class_recall / norm + + @property + def half_directed_macro_precision(self) -> float: + return ((self.half_directed_class_precision * self.base_mask).sum() / self.base_mask.sum()).item() + + @property + def half_directed_macro_recall(self) -> float: + return ((self.half_directed_class_recall * self.base_mask).sum() / self.base_mask.sum()).item() + + @property + def half_directed_macro_f1(self) -> float: + return ((self.half_directed_class_f1 * self.base_mask).sum() / self.base_mask.sum()).item() + + ################# + # Micro metrics # + ################# + + @property + def directed_micro_precision(self) -> float: + norm: torch.Tensor = (self.confusion.sum(1) * self.mask).sum() + return 0 if norm == 0 else ((self.confusion.diagonal() * self.mask).sum() / norm).item() + + @property + def directed_micro_recall(self) -> float: + norm: torch.Tensor = (self.confusion.sum(0) * self.mask).sum() + return 0 if norm == 0 else ((self.confusion.diagonal() * self.mask).sum() / norm).item() + + @property + def directed_micro_f1(self) -> float: + norm: float = self.directed_micro_precision + self.directed_micro_recall + return 0 if norm == 0 else 2 * (self.directed_micro_precision * self.directed_micro_recall) / norm + + @property + def half_directed_micro_precision(self) -> float: + norm: torch.Tensor = (self.confusion.sum(1) * self.mask).sum() + return 0 if norm == 0 else ((self.confusion.diagonal() * self.mask).sum() / norm).item() + + @property + def half_directed_micro_recall(self) -> float: + norm: torch.Tensor = (self.confusion.sum(0) * self.mask).sum() + return 0 if norm == 0 else ((self.confusion.diagonal() * self.mask).sum() / norm).item() + + @property + def half_directed_micro_f1(self) -> float: + norm: float = self.half_directed_micro_precision + self.half_directed_micro_recall + return 0 if norm == 0 else 2 * (self.half_directed_micro_precision * self.half_directed_micro_recall) / norm + + @property + def undirected_micro_precision(self) -> float: + norm: torch.Tensor = (self.base_confusion.sum(1) * self.base_mask).sum() + return 0 if norm == 0 else ((self.base_confusion.diagonal() * self.base_mask).sum() / norm).item() + + @property + def undirected_micro_recall(self) -> float: + norm: torch.Tensor = (self.base_confusion.sum(0) * self.base_mask).sum() + return 0 if norm == 0 else ((self.base_confusion.diagonal() * self.base_mask).sum() / norm).item() + + @property + def undirected_micro_f1(self) -> float: + norm: float = self.undirected_micro_precision + self.undirected_micro_recall + return 0 if norm == 0 else 2 * (self.undirected_micro_precision * self.undirected_micro_recall) / norm diff --git a/gbure/model/__init__.py b/gbure/model/__init__.py diff --git a/gbure/model/contrastive_alignment.py b/gbure/model/contrastive_alignment.py @@ -0,0 +1,85 @@ +from typing import Any, Dict, List, Tuple + +import torch +import transformers + +from gbure.model.linguistic_encoder import LinguisticEncoder +from gbure.model.masked_lm import MaskedLM +from gbure.model.similarity import LinguisticSimilarity, TopologicalSimilarity +from gbure.model.topological_encoder import TopologicalEncoder +import gbure.utils + + +class Model(torch.nn.Module): + """ + Unsupervised pre-training model from Soares et al. + + Correspond to the model explained in section 4. + """ + + def __init__(self, config: gbure.utils.dotdict, tokenizer: transformers.PreTrainedTokenizer, relation_dictionary: None) -> None: + """ + Instantiate a Soares et al. matching the blanks model. + + Args: + config: global config object + tokenizer: tokenizer used to create the vocabulary + relation_dictionary: dictionary of all relations (unused) + margin: maximum enforced meta-distance between positive and negative distances + """ + super().__init__() + + self.config: gbure.utils.dotdict = config + self.tokenizer: transformers.PreTrainedTokenizer = tokenizer + + self.transformer: transformers.PreTrainedModel = transformers.AutoModelForMaskedLM + self.language_model: torch.nn.Module = MaskedLM(config, tokenizer) + self.linguistic_encoder: torch.nn.Module = LinguisticEncoder(config, tokenizer, transformer=self.language_model.encoder) + self.topological_encoder: torch.nn.Module = TopologicalEncoder(config, self.linguistic_encoder) + self.linguistic_similarity: torch.nn.Module = LinguisticSimilarity(config) + self.topological_similarity: torch.nn.Module = TopologicalSimilarity(config) + + def forward(self, **batch) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], Dict[str, torch.Tensor]]: + """ Compute the unsupervised matching the blanks loss between the given pairs. """ + linguistic_embeddings: List[torch.Tensor] = [] + topological_embeddings: List[Any] = [] + + for i, order in enumerate(["first_", "second_", "third_"]): + linguistic_embeddings.append(self.linguistic_encoder(batch[f"{order}text"], batch[f"{order}mask"], batch[f"{order}entity_positions"])[0]) + topological_embeddings.append(self.topological_encoder(order, linguistic_embeddings[-1], **batch)) + + # d(a_1, a_2, [1, 0]) + positive_linguistic_similarity: torch.Tensor = self.linguistic_similarity(linguistic_embeddings[0], linguistic_embeddings[1]) + # d(a_1, a_3, [1, 0]) + negative_linguistic_similarity: torch.Tensor = self.linguistic_similarity(linguistic_embeddings[0], linguistic_embeddings[2]) + + # d(a_1, a_2, [0, 1]) + positive_topological_similarity: torch.Tensor = self.topological_similarity(topological_embeddings[0], topological_embeddings[1])[0] + # d(a_1, a_3, [0, 1]) + negative_topological_similarity: torch.Tensor = self.topological_similarity(topological_embeddings[0], topological_embeddings[2])[0] + + # (d(a_1, a_2, [1, 0]) - d(a_1, a_2, [0, 1]))² + positive: torch.Tensor = 2 * (positive_linguistic_similarity - positive_topological_similarity)**2 + # (d(a_1, a_3, [1, 0]) - d(a_1, a_2, [0, 1]))² + (d(a_1, a_2, [1, 0]) - d(a_1, a_3, [0, 1]))² + negative: torch.Tensor = ((positive_linguistic_similarity - negative_topological_similarity)**2 + (negative_linguistic_similarity - positive_topological_similarity)**2) + + contrastive_loss: torch.Tensor = torch.nn.functional.relu(self.config.margin + positive - negative).mean() + if self.config.get("language_model_weight", 0) > 0: + lm_loss: torch.Tensor = self.language_model(batch["first_mlm_input"], batch["first_mlm_target"], batch["first_mask"]) + else: + lm_loss: int = 0 + loss: torch.Tensor = contrastive_loss + self.config.get("language_model_weight", 0) * lm_loss + + losses: Dict[str, torch.Tensor] = { + "positive": positive.mean(), + "negative": negative.mean(), + "contrastive": contrastive_loss, + "reconstruction": lm_loss} + variables: Dict[str, torch.Tensor] = { + "positive_linguistic_similarity": positive_linguistic_similarity, + "negative_linguistic_similarity": negative_linguistic_similarity, + "positive_topological_similarity": positive_topological_similarity, + "negative_topological_similarity": negative_topological_similarity + } + + return loss, losses, variables diff --git a/gbure/model/fewshot.py b/gbure/model/fewshot.py @@ -0,0 +1,117 @@ +from typing import Dict, Optional, Tuple + +import torch +import transformers + +import gbure.data.dictionary +import gbure.model.linguistic_encoder +import gbure.model.similarity +import gbure.model.topological_encoder +import gbure.utils + + +class Model(torch.nn.Module): + """ + Few shot model from Soares et al. + + Correspond to the right subfigure of Figure 2. + """ + + def __init__(self, config: gbure.utils.dotdict, tokenizer: transformers.PreTrainedTokenizer, relation_dictionary: gbure.data.dictionary.RelationDictionary, train_model: Optional[torch.nn.Module] = None) -> None: + """ + Instantiate a Soares et al. few shot model. + + Args: + config: global config object + tokenizer: tokenizer used to create the vocabulary + relation_dictionary: dictionary of all relations + train_model: unsupervised model used to initialize the few shot model. + """ + super().__init__() + + self.config: gbure.utils.dotdict = config + self.tokenizer: transformers.PreTrainedTokenizer = tokenizer + self.relation_dictionary: gbure.data.dictionary.RelationDictionary = relation_dictionary + + if train_model is None: + self.linguistic_encoder: torch.nn.Module = gbure.model.linguistic_encoder.LinguisticEncoder(config, tokenizer) + self.linguistic_similarity: torch.nn.Module = gbure.model.similarity.LinguisticSimilarity(config) + else: + self.linguistic_encoder: torch.nn.Module = train_model.linguistic_encoder + self.linguistic_similarity: torch.nn.Module = train_model.linguistic_similarity + self.loss_module = torch.nn.NLLLoss(reduction="mean") + + if self.config.get("neighborhood_size", 0) > 0: + if train_model is None: + self.topological_encoder: torch.nn.Module = gbure.model.topological_encoder.TopologicalEncoder(config, self.linguistic_encoder) + self.topological_similarity: torch.nn.Module = gbure.model.similarity.TopologicalSimilarity(config) + else: + self.topological_encoder: torch.nn.Module = train_model.topological_encoder + self.topological_similarity: torch.nn.Module = train_model.topological_similarity + if not self.config.get("undefined_poison_whole_meta"): + if train_model is not None: + self.neutral_topological_similarity = train_model.neutral_topological_similarity + elif self.config.get("neutral_topological_similarity") is not None: + self.neutral_topological_similarity: float = self.config.neutral_topological_similarity + else: + self.neutral_topological_similarity: torch.nn.Parameter = torch.nn.Parameter(torch.zeros(1, dtype=torch.float32)) + + def combine_similarities(self, linguistic: torch.Tensor, topological: torch.Tensor, topological_mask: Optional[torch.Tensor]) -> torch.Tensor: + """ Combine linguistic and topological similarities into a single value. """ + # topological is of dimension (batch, way, shot, slot) + if topological_mask is not None: + if self.config.get("undefined_poison_whole_meta"): + topological *= topological_mask.prod(1, keepdim=True).prod(2, keepdim=True) + else: + topological += (~topological_mask) * self.neutral_topological_similarity + topological = topological.mean(3) + return self.config.get("linguistic_weight", 1) * linguistic + self.config.get("topological_weight", 1) * topological + + def forward(self, + query_text: torch.Tensor, + query_mask: torch.Tensor, + query_entity_positions: torch.Tensor, + candidates_text: torch.Tensor, + candidates_mask: torch.Tensor, + candidates_entity_positions: torch.Tensor, + answer: torch.Tensor, + query_relation: Optional[torch.Tensor] = None, + candidates_relation: Optional[torch.Tensor] = None, + **batch) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], Dict[str, torch.Tensor]]: + """ Compute the fewshot loss on the given query and candidates. """ + batch_size: int = query_text.shape[0] + fake_batch_size: int = batch_size * self.config.shot * self.config.way + + query: torch.Tensor = self.linguistic_encoder(query_text, query_mask, query_entity_positions)[0] + candidates: torch.Tensor = self.linguistic_encoder( + candidates_text.view(fake_batch_size, -1), + candidates_mask.view(fake_batch_size, -1), + candidates_entity_positions.view(fake_batch_size, 2) + )[0].view(batch_size, self.config.way, self.config.shot, -1) + logits: torch.Tensor = self.linguistic_similarity(candidates, query.unsqueeze(1).unsqueeze(2)) + + if self.config.get("neighborhood_size", 0) > 0: + topological_query = self.topological_encoder("query_", query, degree_delta=1, **batch) + topological_candidates = self.topological_encoder("candidates_", candidates, degree_delta=1, **batch) + if isinstance(topological_query, torch.Tensor): + topological_query = topological_query.view(topological_query.shape[0], 1, 1, -1) + else: + topological_query = tuple(x.view(x.shape[0], 1, 1, *x.shape[1:]) for x in topological_query) + + topological_similarity, topological_mask = self.topological_similarity(topological_query, topological_candidates) + logits = self.combine_similarities(logits, topological_similarity, topological_mask) + + log_probabilities = torch.nn.functional.log_softmax( + logits.view(batch_size, self.config.way*self.config.shot), + dim=1).view(batch_size, self.config.way, self.config.shot) + log_probabilities = log_probabilities.logsumexp(2) + + loss: torch.Tensor = self.loss_module(log_probabilities, answer) + prediction: torch.Tensor = log_probabilities.argmax(1) + + variables = {"prediction_logits": logits, "prediction_relative": prediction} + if candidates_relation is not None: + batch_ids: torch.Tensor = torch.arange(batch_size, device=loss.device) + variables["predicted_relation"] = candidates_relation[batch_ids, prediction, 0] + + return loss, {}, variables diff --git a/gbure/model/linguistic_encoder.py b/gbure/model/linguistic_encoder.py @@ -0,0 +1,75 @@ +from typing import Callable, Optional, Tuple + +import torch +import transformers + +import gbure.utils + + +class LinguisticEncoder(torch.nn.Module): + """ + Transformer encoder from Soares et al. + + Correspond to the left part of each subfigure of Figure 2 (Deep Transformer Encoder and the green layer above). + We only implement the "entity markers, entity start" variant (which is the one with the best performance). + + Config: + transformer_model: Which transformer to use (e.g. bert-large-uncased). + post_transformer_layer: The transformation applied after the transformer (must be "linear" or "layer_norm") + """ + + def __init__(self, config: gbure.utils.dotdict, tokenizer: transformers.PreTrainedTokenizer, transformer: Optional[transformers.PreTrainedModel] = None) -> None: + """ + Instantiate a Soares et al. encoder. + + Args: + config: global config object + tokenizer: tokenizer used to create the vocabulary + transformer: the transformer to use instead of loading a pre-trained one + """ + super().__init__() + + self.config: gbure.utils.dotdict = config + self.tokenizer: transformers.PreTrainedTokenizer = tokenizer + + self.transformer: transformers.PreTrainedModel + if transformer is not None: + self.transformer = transformer + elif self.config.get("load") or self.config.get("pretrained"): + # TODO introduce a config parameter to change the initialization of <tags> embeddings + transformer_config = transformers.AutoConfig.from_pretrained(self.config.transformer_model) + transformer_config.vocab_size = len(tokenizer) + self.transformer = transformers.AutoModel.from_config(transformer_config) + else: + self.transformer = transformers.AutoModel.from_pretrained(self.config.transformer_model) + self.transformer.resize_token_embeddings(len(tokenizer)) + + self.post_transformer: Callable[[torch.Tensor], torch.Tensor] + if self.config.post_transformer_layer == "linear": + self.post_transformer_linear = torch.nn.Linear(in_features=self.output_size, out_features=self.output_size) + self.post_transformer = lambda x: self.post_transformer_linear(x) + elif self.config.post_transformer_layer == "layer_norm": + # It is not clear whether a Linear should be added before the layer_norm, see Soares et al. section 3.3 + # Setting elementwise_affine to True (the default) makes little sense when computing similarity scores. + self.post_transformer_linear = torch.nn.Linear(in_features=self.output_size, out_features=self.output_size) + self.post_transformer_activation = torch.nn.LayerNorm(self.output_size, elementwise_affine=False) + self.post_transformer = lambda x: self.post_transformer_activation(self.post_transformer_linear(x)) + elif self.config.post_transformer_layer == "none": + self.post_transformer = lambda x: x + else: + raise RuntimeError("Unsuported config value for post_transformer_layer") + + @property + def output_size(self) -> int: + """ Dimension of the representation returned by the model. """ + return 2 * self.transformer.config.hidden_size + + def forward(self, text: torch.Tensor, mask: torch.Tensor, entity_positions: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + """ Encode the given text into a fixed size representation. """ + batch_size: int = text.shape[0] + batch_ids: torch.Tensor = torch.arange(batch_size, device=text.device, dtype=torch.int64).unsqueeze(1) + + # The first element of the tuple is the Batch×Sentence×Hidden output matrix. + transformer_out: torch.Tensor = self.transformer(text, attention_mask=mask)[0] + sentence: torch.Tensor = transformer_out[batch_ids, entity_positions].view(batch_size, self.output_size) + return self.post_transformer(sentence), transformer_out diff --git a/gbure/model/masked_lm.py b/gbure/model/masked_lm.py @@ -0,0 +1,59 @@ +from typing import Callable + +import torch +import transformers + +import gbure.utils + + +class MaskedLM(torch.nn.Module): + """ + Masked language model to be used on top of a transformer. + + This class is only useful for unsupervised pre-training, Soares et al. keep the BERT loss alongside their "matching the blanks" loss. + + Config: + transformer_model: Which transformer to use (e.g. bert-large-uncased). + """ + + def __init__(self, config: gbure.utils.dotdict, tokenizer: transformers.PreTrainedTokenizer) -> None: + """ + Instantiate a masked language model. + + Args: + config: global config object + tokenizer: tokenizer used to create the vocabulary + """ + super().__init__() + + self.config: gbure.utils.dotdict = config + self.tokenizer: transformers.PreTrainedTokenizer = tokenizer + + self.transformer: transformers.PreTrainedModel + if self.config.get("load") or self.config.get("pretrained"): + # TODO introduce a config parameter to change the initialization of <tags> embeddings + transformer_config = transformers.AutoConfig.from_pretrained(self.config.transformer_model) + transformer_config.vocab_size = len(tokenizer) + self.transformer = transformers.AutoModelForMaskedLM(transformer_config) + else: + self.transformer = transformers.AutoModelForMaskedLM.from_pretrained(self.config.transformer_model) + self.transformer.resize_token_embeddings(len(tokenizer)) + + @property + def encoder(self) -> transformers.PreTrainedModel: + if isinstance(self.transformer, transformers.BertForMaskedLM): + return self.transformer.bert + elif isinstance(self.transformer, transformers.DistilBertForMaskedLM): + return self.transformer.distilbert + else: + raise RuntimeError("Unknown transformer model, can't split masked language model off") + + def forward(self, input: torch.Tensor, target: torch.Tensor, mask: torch.Tensor) -> torch.Tensor: + """ + Compute masked language model loss from transformer's output. + + Sadly, huggingface does not provide an unified interface, so we need to do some copy-pasting. + """ + masked_target: torch.Tensor = target * mask - 100 * (1 - mask.long()) + output = self.transformer(input_ids=input, attention_mask=mask, labels=masked_target, return_dict=True) + return output.loss diff --git a/gbure/model/matching_the_blanks.py b/gbure/model/matching_the_blanks.py @@ -0,0 +1,109 @@ +from typing import Dict, Optional, Tuple +import math + +import torch +import transformers + +import gbure.model.linguistic_encoder +import gbure.model.topological_encoder +import gbure.model.masked_lm +import gbure.model.similarity +import gbure.utils + + +class Model(torch.nn.Module): + """ + Unsupervised pre-training model from Soares et al. + + Correspond to the model explained in Section 4. + + Config: + linguistic_weight: factor for the linguistic similarity + topological_weight: factor for the topological similarity + """ + + def __init__(self, config: gbure.utils.dotdict, tokenizer: transformers.PreTrainedTokenizer, relation_dictionary: None) -> None: + """ + Instantiate a Soares et al. matching the blanks model. + + Args: + config: global config object + tokenizer: tokenizer used to create the vocabulary + relation_dictionary: dictionary of all relations (unused) + """ + super().__init__() + + self.config: gbure.utils.dotdict = config + self.tokenizer: transformers.PreTrainedTokenizer = tokenizer + + self.transformer: transformers.PreTrainedModel = transformers.AutoModelForMaskedLM + self.language_model: torch.nn.Module = gbure.model.masked_lm.MaskedLM(config, tokenizer) + self.linguistic_encoder: torch.nn.Module = gbure.model.linguistic_encoder.LinguisticEncoder(config, tokenizer, transformer=self.language_model.encoder) + self.linguistic_similarity: torch.nn.Module = gbure.model.similarity.LinguisticSimilarity(config) + + if self.config.get("neighborhood_size", 0) > 0: + self.topological_encoder: torch.nn.Module = gbure.model.topological_encoder.TopologicalEncoder(config, self.linguistic_encoder) + self.topological_similarity: torch.nn.Module = gbure.model.similarity.TopologicalSimilarity(config) + if not self.config.get("undefined_poison_whole_meta"): + if self.config.get("neutral_topological_similarity") is not None: + self.neutral_topological_similarity: float = self.config.neutral_topological_similarity + else: + self.neutral_topological_similarity: torch.nn.Parameter = torch.nn.Parameter(torch.zeros(1, dtype=torch.float32)) + + def combine_similarities(self, linguistic: torch.Tensor, topological: torch.Tensor, topological_mask: Optional[torch.Tensor]) -> torch.Tensor: + """ Combine linguistic and topological similarities into a single value. """ + # topological is of dimension (batch, slot) + if topological_mask is not None: + if not self.config.get("undefined_poison_whole_meta"): + topological += (~topological_mask) * self.neutral_topological_similarity + topological = topological.mean(1) + return self.config.get("linguistic_weight", 1) * linguistic + self.config.get("topological_weight", 1) * topological + + def forward(self, + first_text: torch.Tensor, + first_mask: torch.Tensor, + first_entity_positions: torch.Tensor, + second_text: torch.Tensor, + second_mask: torch.Tensor, + second_entity_positions: torch.Tensor, + polarity: torch.Tensor, + first_mlm_input: torch.Tensor, + first_mlm_target: torch.Tensor, + **batch) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], Dict[str, torch.Tensor]]: + """ Compute the unsupervised matching the blanks loss between the given pairs. """ + first: torch.Tensor + first_transformer_out: torch.Tensor + first, first_transformer_out = self.linguistic_encoder(first_text, first_mask, first_entity_positions) + second: torch.Tensor = self.linguistic_encoder(second_text, second_mask, second_entity_positions)[0] + linguistic_similarity: torch.Tensor = self.linguistic_similarity(first, second) + lm_loss: torch.Tensor = self.language_model(first_mlm_input, first_mlm_target, first_mask) + + similarity: torch.Tensor = linguistic_similarity + if self.config.get("neighborhood_size", 0) > 0: + topological_first = self.topological_encoder("first_", first, **batch) + topological_second = self.topological_encoder("second_", second, **batch) + + topological_similarity, topological_mask = self.topological_similarity(topological_first, topological_second) + similarity = self.combine_similarities(linguistic_similarity, topological_similarity, topological_mask) + + # There seem to be a mistake in Soares et al. §4.1 in the equation of p(l=1|r,r') + # The equation use 1/(1+exp(x)) which seems counter intuitive since the case where r=r' would lead to a low probability. + # By default 1/(1+exp(-x)) is used, but the equation given in the paper can be used with --reverse_sigmoid. + if self.config.get("reverse_sigmoid"): + scores: torch.Tensor = - torch.nn.functional.logsigmoid(- polarity * similarity) + else: + scores: torch.Tensor = - torch.nn.functional.logsigmoid(polarity * similarity) + mtb_loss: torch.Tensor = scores.mean() + + is_positive: torch.Tensor = (polarity + 1) // 2 + is_negative: torch.Tensor = 1 - is_positive + positive: torch.Tensor = (scores * is_positive).sum() / is_positive.sum() + negative: torch.Tensor = (scores * is_negative).sum() / is_negative.sum() + + losses: Dict[str, torch.Tensor] = { + "positive": positive, + "negative": negative, + "mtb": mtb_loss, + "reconstruction": lm_loss} + loss: torch.Tensor = mtb_loss + self.config.language_model_weight * lm_loss + return loss, losses, {"similarity": similarity} diff --git a/gbure/model/similarity.py b/gbure/model/similarity.py @@ -0,0 +1,134 @@ +from typing import Dict, List, Optional, Tuple, Union + +import geomloss +import torch + +import gbure.utils + + +class LinguisticSimilarity(torch.nn.Module): + """ + Compute the similarity between two bertcoder representations. + + The bertcoder embeddings are all in the same direction of space since they all correspond to the tags <e1> and <e2> + Thus after a dot product, the activations are not standardized anymore, even after scaling by √d + + Config: + linguistic_similarity: the function used to compute the similarity between embeddings + linguistic_similarity_delta: an additive constant to add to all similarity (useful to have a strictly positive cosine) + latent_metric_scale: how to scale the similarity once computed + latent_dot_mean: when latent_metric_scale=="standard", the value to substract + latent_dot_std: when latent_metric_scale=="standard", the value to divide by + """ + + def __init__(self, config: gbure.utils.dotdict) -> None: + super().__init__() + self.config = config + + def forward(self, lhs: torch.Tensor, rhs: torch.Tensor) -> torch.Tensor: + """ Compute the similarity between lhs and rhs. """ + if self.config.linguistic_similarity == "dot": + similarity: torch.Tensor = torch.einsum('b...d,b...d->b...', lhs, rhs) + elif self.config.linguistic_similarity == "cosine": + similarity: torch.Tensor = torch.nn.functional.cosine_similarity(lhs, rhs, dim=-1) + elif self.config.linguistic_similarity == "euclidean": + similarity: torch.Tensor = - torch.sum((lhs - rhs)**2, dim=-1) + else: + raise RuntimeError(f"Unknown config value for linguistic_similarity: {self.config.linguistic_similarity}") + similarity += self.config.get("linguistic_similarity_delta", 0) + + encoder_output_size: int = lhs.shape[-1] + if self.config.get("latent_metric_scale") == "sqrt": + # If X, Y ~ N(0, I_d): + # E[X·Y] = 0 + # Var[X·Y] = √d + similarity = similarity / (encoder_output_size ** 0.5) + elif self.config.get("latent_metric_scale") == "full": + similarity = similarity / encoder_output_size + elif self.config.get("latent_metric_scale") == "match": + # If X ~ N(0, 1): + # E[X²] = 1 + # Var[X²] = √2 + similarity = (similarity - encoder_output_size) / ((2**0.5) * encoder_output_size) + elif self.config.get("latent_metric_scale") == "standard": + similarity = (similarity - self.config.latent_dot_mean) / self.config.latent_dot_std + elif self.config.get("latent_metric_scale") is not None: + raise RuntimeError("Unsuported config value for latent_metric_scale") + + return similarity + + +class TopologicalSimilarity(torch.nn.Module): + """ + Compute the similarity between two neighborhood representations. + + This similarity is either an inner product in the case of fixed-size representations or the negative of the 1-Wasserstein distance. + """ + + def __init__(self, config: gbure.utils.dotdict) -> None: + super().__init__() + self.config = config + if self.config.get("gcn_aggregator", "none") == "none": + self.sinkhorn = geomloss.SamplesLoss(loss="sinkhorn", p=self.config.get("wasserstein_underlying_distance", 2), blur=self.config.get("sinkhorn_blur", 0.05)) + + @staticmethod + def merge_shapes(lhs: List[int], rhs: List[int]) -> List[int]: + """ Find the shape of an elementwise operation between two tensors of the given shapes. """ + lhs = [1] * (len(rhs) - len(lhs)) + lhs + rhs = [1] * (len(lhs) - len(rhs)) + rhs + res = [] + for left, right in zip(lhs, rhs): + if left == right: + res.append(left) + elif left == 1 or right == 1: + res.append(left + right - 1) + else: + raise RuntimeError(f"Incompatible shapes {lhs} and {rhs}") + return res + + def forward(self, lhs: Union[Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor], rhs: Union[Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Compute the distance between two topological embeddings. + + The embeddings are either a single vector (per sample) which was pooled using a GCN, or four matrices corresponding to: + - the embeddings of the head neighborhood + - the embeddings of the tail neighborhood + - the mask of the head neighborhood + - the mask of the tail neighborhood + """ + if isinstance(lhs, tuple): + shape: List[int] = self.merge_shapes(list(lhs[0].shape)[:-2], list(rhs[0].shape)[:-2]) + + def expand_tuple(t: Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + return (t[0].expand(tuple(shape)+t[0].shape[-2:]), + t[1].expand(tuple(shape)+t[1].shape[-2:]), + t[2].expand(tuple(shape)+t[2].shape[-1:]), + t[3].expand(tuple(shape)+t[3].shape[-1:])) + + lhs = expand_tuple(lhs) + rhs = expand_tuple(rhs) + + cumulative_shape: List[int] = [] + accumulator: int = 1 + for x in shape: + cumulative_shape.append(accumulator) + accumulator *= x + result: torch.Tensor = lhs[0].new_zeros(shape+[2]) + result_mask: torch.Tensor = lhs[0].new_zeros(shape+[2], dtype=torch.bool) + for i in range(accumulator): + indices = [] + for denominator, modulo in zip(cumulative_shape, shape): + indices.append(i // denominator % modulo) + indices = tuple(indices) + for slot in range(2): + sample_lhs: torch.Tensor = lhs[slot][indices][lhs[2+slot][indices]] + sample_rhs: torch.Tensor = rhs[slot][indices][rhs[2+slot][indices]] + if sample_lhs.numel() > 0 and sample_rhs.numel() > 0: + result[indices+(slot,)] = - self.sinkhorn(sample_lhs, sample_rhs) + result_mask[indices+(slot,)] = True + else: + result_mask[indices+(slot,)] = False + return result, result_mask + else: + # TODO implement other topological similarities + return torch.einsum('b...d,b...d->b...', lhs, rhs), None diff --git a/gbure/model/supervised.py b/gbure/model/supervised.py @@ -0,0 +1,45 @@ +from typing import Dict, Tuple + +import torch +import transformers + +from gbure.data.dictionary import RelationDictionary +from gbure.model.linguistic_encoder import LinguisticEncoder +import gbure.utils + + +class Model(torch.nn.Module): + """ + Supervised model from Soares et al. + + Correspond to the left subfigure of Figure 2. + """ + + def __init__(self, config: gbure.utils.dotdict, tokenizer: transformers.PreTrainedTokenizer, relation_dictionary: RelationDictionary) -> None: + """ + Instantiate a Soares et al. supervised model. + + Args: + config: global config object + tokenizer: tokenizer used to create the vocabulary + relation_dictionary: dictionary of all relations + """ + super().__init__() + + self.config: gbure.utils.dotdict = config + self.tokenizer: transformers.PreTrainedTokenizer = tokenizer + self.relation_dictionary: RelationDictionary = relation_dictionary + + self.encoder: torch.nn.Module = LinguisticEncoder(config, tokenizer) + self.relation_encoder = torch.nn.Linear( + in_features=self.encoder.output_size, + out_features=len(relation_dictionary), + bias=False) + self.loss_module = torch.nn.CrossEntropyLoss(reduction="mean") + + def forward(self, text: torch.Tensor, mask: torch.Tensor, entity_positions: torch.Tensor, relation: torch.Tensor, **kwargs) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], Dict[str, torch.Tensor]]: + """ Compute the supervised loss on the given text and target relation. """ + latent: torch.Tensor = self.encoder(text, mask, entity_positions)[0] + logits: torch.Tensor = self.relation_encoder(latent) + loss: torch.Tensor = self.loss_module(logits, relation) + return loss, {}, {"prediction_logits": logits, "predicted_relation": logits.argmax(1)} diff --git a/gbure/model/topological_encoder.py b/gbure/model/topological_encoder.py @@ -0,0 +1,65 @@ +from typing import List, Optional, Tuple, Union +import functools +import math +import operator + +import torch + +import gbure.utils + + +class TopologicalEncoder(torch.nn.Module): + """ + Encoder for neighborhoods. + + Config: + gcn_aggregator: aggregator used to pool the representations of several neighbors into a fixed-size one. + """ + + def __init__(self, config: gbure.utils.dotdict, linguistic_encoder: torch.nn.Module) -> None: + """ + Instantiate a Soares et al. encoder. + + Args: + config: global config object + linguistic_encoder: the model used to get a fixed-size representation of text + """ + super().__init__() + self.config: gbure.utils.dotdict = config + self.linguistic_encoder: torch.nn.Module = linguistic_encoder + + if self.config.get("gcn_aggregator", "") in ["mean", "chebyshev"]: + self.gcn_layer: torch.nn.Module = torch.nn.Linear(in_features=self.linguistic_encoder.output_size, out_features=self.linguistic_encoder.output_size) + + def forward(self, prefix: str, loop: torch.Tensor, degree_delta: int = 0, **batch) -> Union[Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor]: + """ + Encode the neighborhood of the given prefix. + + When a gcn_aggregator is defined, this result in a fixed-size representation, otherwise it returns clouds of points to be compared using optimal transport. + The degree_delta parameters changes the degrees used to compute various GCN weighting. This can be useful when the sample comes from outside the graph so 1 should be added to the degrees. + """ + linguistic_embeddings: List[torch.Tensor] = [] + masks: List[torch.Tensor] = [] + for slot in [1, 2]: + fake_batch_size: int = functools.reduce(operator.mul, batch[f"{prefix}e{slot}_neighborhood_text"].shape[:-1]) + mask: torch.Tensor = batch[f"{prefix}e{slot}_neighborhood_mask"].view(fake_batch_size, -1)[:, 0].unsqueeze(1) + linguistic_embeddings.append((self.linguistic_encoder( + batch[f"{prefix}e{slot}_neighborhood_text"].view(fake_batch_size, -1), + batch[f"{prefix}e{slot}_neighborhood_mask"].view(fake_batch_size, -1), + batch[f"{prefix}e{slot}_neighborhood_entity_positions"].view(fake_batch_size, 2) + )[0] * mask).view(*batch[f"{prefix}e{slot}_neighborhood_text"].shape[:-1], self.linguistic_encoder.output_size)) + masks.append(mask.view(*batch[f"{prefix}e{slot}_neighborhood_text"].shape[:-1])) + + if self.config.get("gcn_aggregator", "") == "mean": + head: torch.Tensor = linguistic_embeddings[0].sum(-2) + tail: torch.Tensor = linguistic_embeddings[1].sum(-2) + neighborhood_size: torch.Tensor = sum(mask.sum(-1, keepdim=True) for mask in masks) + return self.gcn_layer((loop + head + tail) / (neighborhood_size + 1)) + elif self.config.get("gcn_aggregator", "") == "chebyshev": + pre_embedding: torch.Tensor = loop / torch.sqrt(2 * (batch[f"{prefix}entity_degrees"].sum(-1, keepdim=True) - 1 + degree_delta)) + for slot in [1, 2]: + weights = 1 / torch.sqrt(batch[f"{prefix}e{slot}_neighborhood_entity_degrees"].sum(-1, keepdim=True) - 1 + degree_delta) + pre_embedding += torch.sum(weights * linguistic_embeddings[slot-1], dim=-2) + return self.gcn_layer(pre_embedding) + else: + return (linguistic_embeddings[0], linguistic_embeddings[1], masks[0], masks[1]) diff --git a/gbure/outputs.py b/gbure/outputs.py @@ -0,0 +1,55 @@ +from __future__ import annotations +from typing import Dict, Optional, Type +import pathlib +import types + +import torch +import transformers + +from gbure.data.dictionary import RelationDictionary + + +class Outputs: + """ + Class for outputing data about the learned model. + """ + # TODO parametrize this class + + def __init__(self, logdir: pathlib.Path, tokenizer: transformers.PreTrainedTokenizer, relation_dictionary: RelationDictionary) -> None: + """ Initialize the outputs variables, but do not acquire necessary resources yet. """ + self.logdir: pathlib.Path = logdir + self.tokenizer: transformers.PreTrainedTokenizer = tokenizer + self.relation_dictionary: RelationDictionary = relation_dictionary + + def __enter__(self) -> Outputs: + """ Acquire resources needed for outputing the data. """ + # TODO parametrize this file name with split and epoch + self.target_prediction_file = (self.logdir / "target_prediction").open("w") + return self + + def __exit__(self, + exc_type: Optional[Type[BaseException]], + exc_inst: Optional[BaseException], + exc_tb: Optional[types.TracebackType]) -> None: + """ Free used resources. """ + self.target_prediction_file.close() + + def update(self, batch: Dict[str, torch.Tensor], loss: torch.Tensor, losses: Dict[str, torch.Tensor], variables: Dict[str, torch.Tensor]) -> None: + """ + Update model output data files with the given batch and the outputs of the model on this batch. + + The variables dictionary returned by the model must contain a predicted_relation tensor. + + Args: + batch: the input values used for evaluation + loss: the loss optimized by the model + losses: intermediary (unweighted) losses + variables: internal variables used by the model to compute the loss + """ + targets: torch.Tensor = batch.get("relation", batch.get("query_relation")) + predictions: torch.Tensor = variables.get("predicted_relation") + if targets is not None and predictions is not None: + for prediction, target in zip(predictions, targets): + print(f"{self.relation_dictionary.decode(target)}\t{self.relation_dictionary.decode(prediction)}", file=self.target_prediction_file) + elif "prediction_relative" in variables: + print("\n".join(map(str, variables["prediction_relative"].tolist())), file=self.target_prediction_file) diff --git a/gbure/train.py b/gbure/train.py @@ -0,0 +1,421 @@ +from typing import Any, Callable, Dict, Iterable, List, Optional +import contextlib +import gc +import logging +import math +import os +import pathlib +import shutil + +from torch.utils.tensorboard import SummaryWriter +import torch +import torch.utils +import tqdm +import transformers + +import gbure.data.batcher +import gbure.data.dataset +import gbure.data.dictionary +import gbure.data.graph +import gbure.metrics +import gbure.outputs +import gbure.utils + +logger = logging.getLogger(__name__) + + +class Trainer(gbure.utils.Experiment): + """ + Train a model. + + Config: + Model: the model class to use for training and evaluation + Optimizer: the optimizer class to use for training + Scheduler: the learning rate scheduler class to use + accumulated_batch_size: the actual number of sample in a batch after accumulation, that is the number of samples seen before an optimizer step (must be a multiple of batch_size) + amp: enable automatic mixed precision and gradient scaler + batch_size: the number of samples in the batch of data loaded + eval_batch_size: batch size used for evaluation + clip_gradient: the maximum norm of the gradient + dataset_name: name of the dataset to load + dataset_spec: dataset specification, usually None, can be used to select a (smaller) test version + eval_dataset_name: overwrite evaluation dataset + eval_dataset_spec: overwrite evaluation dataset specification + early_stopping_patience: how many epoch to train after best validation score has been reached + learning_rate: learning rate + max_epoch: maximum number of epoch + no_initial_validation: do not run evaluation on the valid dataset before first epoch + optimizer_hyperparameters: hyperparameters for the optimizer (e.g. weight decay, etc) + pretrained: path to a pretrained model to load + scheduler_parameters: the parameters for initializing the Scheduler + test_output: path to a file where the test predictions will be written + transformer_model: the model of transformer to use + unsupervised: train an unsupervised model + validation_metric: metric used for early stopping + workers: number of data generating workers to spawn + """ + + def init(self) -> None: + """ Prepare training. """ + self.prepare_dataset() + self.build_model() + self.count_parameters() + self.setup_optimizer() + self.init_writer() + self.init_epochs() + + def main(self) -> None: + """ Run the experiment (i.e. here, train). """ + self.train() + + def prepare_dataset(self) -> None: + """ Load datasets and create iterators. """ + dataset_spec: str = self.config.transformer_model + self.config.get("dataset_spec", "") + self.data_dir: pathlib.Path = gbure.utils.DATA_PATH / self.config.dataset_name / dataset_spec + + self.tokenizer: transformers.PreTrainedTokenizer = transformers.AutoTokenizer.from_pretrained(str(self.data_dir / "tokenizer")) + self.batcher = gbure.data.batcher.Batcher(self.tokenizer.pad_token_id) + self.relation_dictionary: Optional[gbure.data.dictionary.RelationDictionary] = None if self.config.get("unsupervised") else gbure.data.dictionary.RelationDictionary(path=self.data_dir / "relations") + entities_path: pathlib.Path = self.data_dir / "train" / "entities" if self.config.get("unsupervised") else self.data_dir / "entities" + self.entity_dictionary: gbure.data.dictionary.Dictionary = gbure.data.dictionary.Dictionary(path=entities_path) + + if self.config.get("eval_dataset_name"): + eval_dataset_spec: str = self.config.transformer_model + self.config.get("eval_dataset_spec", "") + self.eval_data_dir: pathlib.Path = gbure.utils.DATA_PATH / self.config.eval_dataset_name / eval_dataset_spec + # We assume the tokenizer is the same. + self.eval_relation_dictionary: gbure.data.dictionary.RelationDictionary = gbure.data.dictionary.RelationDictionary(path=self.eval_data_dir / "relations") + # FIXME doesn't work when a test dataset is specified … + entities_path: pathlib.Path = self.data_dir / f"{self.config.valid_name}.entities" if self.config.get("valid_name") else self.data_dir / "entities" + self.eval_entity_dictionary: gbure.data.dictionary.Dictionary = gbure.data.dictionary.Dictionary(path=self.eval_data_dir / "entities") + else: + self.eval_data_dir: pathlib.Path = self.data_dir + self.eval_relation_dictionary: Optional[gbure.data.dictionary.RelationDictionary] = self.relation_dictionary + self.eval_entity_dictionary: gbure.data.dictionary.Dictionary = self.entity_dictionary + + self.dataset: Dict[str, torch.utils.data.Dataset] = {} + self.iterator: Dict[str, Callable[[], Any]] = {} + graph_buffer: Optional[gbure.data.graph.Graph] = None + + for split in ["train", "valid", "test"]: + kwargs: Dict[str, Any] = {"rng": self.state_dicts.get("train_rng")} if split == "train" and self.state_dicts else {} + split_path = pathlib.Path(split) + data_dir = self.data_dir if split == "train" else self.eval_data_dir + if f"{split}_name" in self.config: + split_path = pathlib.Path(self.config[f"{split}_name"]) + + try: + self.dataset[split] = gbure.data.dataset.load_dataset( + config=self.config, + split=split, + path=data_dir / split_path, + tokenizer=self.tokenizer, + evaluation=(split != "train"), + **kwargs) + except FileNotFoundError: + if split == "train": + raise + continue + + if self.config.get("graph_name"): + graph_spec: str = self.config.transformer_model + self.config.get("graph_spec", "") + graph_dir: pathlib.Path = gbure.utils.DATA_PATH / self.config.graph_name / graph_spec + entity_dictionary = self.entity_dictionary if split == "train" else self.eval_entity_dictionary + self.dataset[split] = gbure.data.dataset.GraphAdapter( + self.dataset[split], + entity_dictionary, + path=graph_dir / "train", + graph=graph_buffer) + graph_buffer = self.dataset[split].graph + + self.iterator[split] = lambda trainer=self, split=split: tqdm.tqdm( + iterable=torch.utils.data.DataLoader( + dataset=trainer.dataset[split], + collate_fn=trainer.batcher, + batch_size=trainer.config.batch_size if split == "train" else trainer.config.get("eval_batch_size", trainer.config.batch_size), + shuffle=(split == "train" and trainer.dataset[split].shuffleable), + num_workers=trainer.config.workers, + worker_init_fn=getattr(trainer.dataset[split], "init_seed", None), + pin_memory=(trainer.device.type == "cuda")), + desc=f"Epoch {trainer.epoch:2} {split:5}", + unit="samples", + unit_scale=trainer.config.batch_size if split == "train" else trainer.config.get("eval_batch_size", trainer.config.batch_size), + total=math.ceil(len(trainer.dataset[split]) / (trainer.config.batch_size if split == "train" else trainer.config.get("eval_batch_size", trainer.config.batch_size))), + leave=False) + + def build_model(self) -> None: + """ Instantiate the model. """ + self.model: torch.nn.Module = self.config.Model(self.config, self.tokenizer, self.relation_dictionary) + if self.state_dicts: + missing_keys: List[str] + unexpected_keys: List[str] + missing_keys, unexpected_keys = self.model.load_state_dict(self.state_dicts["model"], strict=False) + if missing_keys or unexpected_keys: + unexpected_display: List[str] = unexpected_keys + if self.config.get("pretrained"): + unexpected_display = list(filter(lambda key: not key.startswith("language_model."), unexpected_keys)) + if len(unexpected_display) < len(unexpected_keys): + unexpected_display.append("language_model.*") + print("\033[31mLoading state_dict mismatch.\033[0m\n") + print(f"\033[31mMissing keys: {' '.join(missing_keys)}.\033[0m\n") + print(f"\033[31mUnexpected keys: {' '.join(unexpected_display)}.\033[0m\n") + assert(self.config.get("pretrained")) + self.model.to(self.device) + if self.config.get("EvalModel"): + self.eval_model: torch.nn.Module = self.config.EvalModel(self.config, self.tokenizer, self.relation_dictionary, train_model=self.model) + else: + self.eval_model: torch.nn.Module = self.model + self.mixed_precision = torch.cuda.amp.autocast if self.config.get("amp") else contextlib.nullcontext + + def count_parameters(self) -> None: + """ Display the number of parameters used by the model. """ + total: int = 0 + for parameter in self.model.parameters(): + total += parameter.shape.numel() + print(f"\033[33mNumber of parameters: {total:,}\033[0m") + + def setup_optimizer(self) -> None: + """ Instantiate the optimizer. """ + optimizer_hyperparameters: Dict[str, Any] = self.config.get("optimizer_hyperparameters", {}) + optimizer_hyperparameters["lr"] = self.config.learning_rate + self.optimizer: torch.optim.optimizer.Optimizer = self.config.Optimizer(self.model.parameters(), **optimizer_hyperparameters) + # if self.state_dicts and "optimizer" in self.state_dicts: + # self.optimizer.load_state_dict(self.state_dicts["optimizer"]) + if self.config.get("Scheduler"): + total_iters: int = math.ceil(self.config.max_epoch * len(self.dataset["train"]) / self.config.accumulated_batch_size) + self.scheduler = self.config.Scheduler(self.optimizer, total_iters=total_iters, **self.config.get("scheduler_parameters", {})) + if self.config.get("amp"): + self.scaler = torch.cuda.amp.GradScaler(self.config.get("initial_grad_scale", 65536)) + + def save(self, path: pathlib.Path) -> None: + """ Save the model and all additional state which might be needed to resume training. """ + state_dicts: Dict[str, Any] = { + "logdir": self.logdir, + "config": self.config, + "model": self.model.state_dict(), + "optimizer": self.optimizer.state_dict(), + "torch_rng": torch.random.get_rng_state(), + "epoch": self.epoch, + "best_epoch": self.best_epoch, + "best_eval": self.best_eval, + "global_batch_id": self.global_batch_id, + "results": self.results, + } + if self.config.workers == 0: + state_dicts["train_rng"] = getattr(self.dataset["train"], "rng", None) + if torch.cuda.is_available(): + state_dicts["cuda_rng"] = torch.cuda.random.get_rng_state_all() + torch.save(state_dicts, path) + + def init_writer(self) -> None: + """ Initialize summary writer for tensorboard logging. """ + self.writer = SummaryWriter(log_dir=self.logdir) + self.results: Dict[str, float] = {} + if self.state_dicts: + self.results = self.state_dicts.get("results", self.results) + + def close(self) -> None: + """ Close the writer. """ + self.writer.close() + + def init_epochs(self) -> None: + """ + Given initial values for early stopping. + + Assume a "higher is better" metric. + """ + self.epoch: int = 0 + self.best_epoch: int = 0 + self.best_eval: float = -math.inf + self.global_batch_id: int = 0 + + if self.state_dicts: + self.epoch = self.state_dicts.get("epoch", self.epoch) + self.best_epoch = self.state_dicts.get("best_epoch", self.best_epoch) + self.best_eval = self.state_dicts.get("best_eval", self.best_eval) + self.global_batch_id = self.state_dicts.get("global_batch_id", self.global_batch_id) + + def checkpoint(self) -> None: + """ Save current model. """ + self.save(self.logdir / "checkpoint.new") + os.rename(self.logdir / "checkpoint.new", self.logdir / "checkpoint") + + def update_best_model(self, candidate: float) -> bool: + """ Update current best model and return whether we should early stop. """ + if candidate is math.nan: + return False + + if candidate > self.best_eval: + logger.info(f"Updating best model at epoch {self.epoch} (metric {candidate} > {self.best_eval})") + self.best_epoch = self.epoch + self.best_eval = candidate + shutil.copy(self.logdir / "checkpoint", self.logdir / "best") + return (self.epoch - self.best_epoch > self.config.get("early_stopping_patience", self.config.max_epoch)) + + def load_best_model(self) -> None: + """ Load the model with the best validation score. """ + best_path: pathlib.Path = self.logdir / "best" + if self.best_eval != -math.inf and best_path.exists(): + print(f"Loading best model from {best_path}…", end="", flush=True) + self.model.load_state_dict(torch.load(best_path)["model"]) + print(" done") + logger.info(f"{best_path} loaded") + + def train_apply_grads(self) -> None: + """ Apply the gradients to the parameters. """ + if self.config.get("amp"): + if self.config.get("clip_gradient"): + self.scaler.unscale_(self.optimizer) + torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.config.clip_gradient) + self.scaler.step(self.optimizer) + self.scaler.update() + else: + if self.config.get("clip_gradient"): + torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.config.clip_gradient) + self.optimizer.step() + self.optimizer.zero_grad() + if self.config.get("Scheduler"): + self.scheduler.step() + + def train(self) -> None: + """ Train the model with intermediate evaluation on the validation dataset and final evaluation on the test dataset. """ + assert(self.config.accumulated_batch_size % self.config.batch_size == 0) + data_batch_per_true_batch: int = self.config.accumulated_batch_size // self.config.batch_size + + if not self.config.get("no_initial_validation"): + self.best_eval = self.evaluate("valid") + + for self.epoch in range(self.epoch+1, self.config.max_epoch+1): + if self.interrupted: + break + self.model.train() + self.optimizer.zero_grad() + total_loss: float = 0.0 + total_sample: int = 0 + + epoch_loop = self.iterator["train"]() + for batch_id, batch in enumerate(epoch_loop): + batch: Dict[str, torch.Tensor] = {key: value.to(self.device) for key, value in batch.items()} + + loss: torch.Tensor + losses: Dict[str, torch.Tensor] + variables: Dict[str, torch.Tensor] + with self.mixed_precision(): + loss, losses, variables = self.model(**batch) + if self.config.get("amp"): + self.scaler.scale(loss).backward() + else: + loss.backward() + + if torch.isfinite(loss): + total_loss += loss.item() + total_sample += next(iter(batch.values())).shape[0] + epoch_loop.set_postfix(loss=f"{total_loss / total_sample:.2f} .", refresh=False) + else: + epoch_loop.set_postfix(loss=f"{total_loss / (total_sample if total_sample else 1):.2f} NaN", refresh=False) + + self.writer.add_scalar("Loss/train", loss, self.global_batch_id) + for key, value in losses.items(): + self.writer.add_scalar(f"{key}/train", value, self.global_batch_id) + + if batch_id % 1000 == 0: + for key, value in variables.items(): + self.writer.add_histogram(f"{key}/train", value, self.global_batch_id) + if batch_id % data_batch_per_true_batch == data_batch_per_true_batch - 1: + self.train_apply_grads() + self.global_batch_id += 1 + + if total_sample > 0: + self.results[f"train.loss"] = float(total_loss / total_sample) + else: + total_sample = 1 + + if batch_id % data_batch_per_true_batch != data_batch_per_true_batch - 1: + self.train_apply_grads() + + print(f"Epoch {self.epoch:2} train mean loss: {total_loss / total_sample:8.4f}") + logger.info(f"epoch {self.epoch} train mean_loss {total_loss / total_sample}") + + self.checkpoint() + candidate: float = self.evaluate("valid") + if self.update_best_model(candidate): + break + + if "test" in self.iterator: + self.load_best_model() + self.evaluate("test") + self.write_results() + + def metric_message(self, split: str) -> str: + """ Message to be displayed after an evaluation. """ + message = "Epoch {self.epoch:2} {split:5} loss: {metrics.loss:8.4f}" + if not isinstance(self.dataset[split], gbure.data.dataset.UnsupervisedDataset) or hasattr(self.dataset[split], "dataset") and not isinstance(self.dataset[split].dataset, gbure.data.dataset.UnsupervisedDataset): + message += " accuracy: {metrics.accuracy:8.6f}" + if isinstance(self.dataset[split], gbure.data.dataset.SupervisedDataset): + message += " half-directed Macro F1: {metrics.half_directed_macro_f1:8.6f} (P: {metrics.half_directed_macro_precision:8.6f} R: {metrics.half_directed_macro_recall:8.6f})" + if isinstance(self.dataset[split], gbure.data.dataset.GraphAdapter) and (isinstance(self.dataset[split].dataset, gbure.data.dataset.FewShotDataset) or isinstance(self.dataset[split].dataset, gbure.data.dataset.SampledFewShotDataset)): + message += " accuracy non-empty: {metrics.accuracy_non_empty:8.6f} accuracy full: {metrics.accuracy_full:8.6f}" + return message + + def evaluate(self, split: str) -> float: + """ Evaluate the model on the given split and return the early stopping metric. """ + if split not in self.iterator: + if not self.config.get("unsupervised"): + print(f"\033[31m{split} dataset does not exist, evaluation skipped.\033[0m") + return math.nan + + self.model.zero_grad(set_to_none=True) + gc.collect() + self.eval_model.eval() + with torch.no_grad(), gbure.outputs.Outputs(self.logdir, self.tokenizer, self.eval_relation_dictionary) as outputs: + metrics = gbure.metrics.Metrics(self.config, self.tokenizer, self.eval_relation_dictionary, getattr(self.dataset[split], "graph", None)) + loop = self.iterator[split]() + for batch in loop: + batch: Dict[str, torch.Tensor] = {key: value.to(self.device) for key, value in batch.items()} + loss, losses, variables = self.eval_model(**batch) + metrics.update(batch, loss, losses, variables) + outputs.update(batch, loss, losses, variables) + loop.set_postfix(**metrics.summary, refresh=False) + print(self.metric_message(split).format(**locals())) + for metric, value in metrics.all.items(): + logger.info(f"epoch {self.epoch} {split} {metric} {value}") + + candidate: float = getattr(metrics, self.config.validation_metric) if self.config.get("validation_metric") else math.nan + if split == "test" or (split == "valid" and candidate > self.best_eval): + for key, value in metrics.all.items(): + self.results[f"{split}.{key}"] = float(value) + return candidate + + def write_results(self): + """ Write best valid and test score as tensorboard events. """ + self.writer.add_hparams( + hparam_dict=gbure.utils.flatten_dict(self.config), + metric_dict=self.results, + run_name="hparams") + + +if __name__ == "__main__": + gbure.utils.fix_transformers_logging_handler() + + config: gbure.utils.dotdict = gbure.utils.parse_args() + logdir: pathlib.Path + state_dicts: Optional[Dict[str, Any]] = None + new_log_dir: bool = True + + if config.get("load"): + state_dicts = torch.load(config.load, map_location=torch.device('cpu')) + if not config.get("overwrite_config"): + config = state_dicts["config"] + logdir = state_dicts["logdir"] + new_log_dir = False + + if config.get("pretrained"): + state_dicts = torch.load(config.pretrained, map_location=torch.device('cpu')) + state_dicts = {"model": state_dicts["model"]} + + if new_log_dir: + logdir = gbure.utils.LOG_PATH / gbure.utils.logdir_name("GBURE") + assert(not logdir.exists()) + logdir.mkdir() + + gbure.utils.add_logging_handler(logdir) + Trainer(config, logdir, state_dicts).run() diff --git a/gbure/utils.py b/gbure/utils.py @@ -0,0 +1,387 @@ +from typing import Any, Callable, Dict, List, NoReturn, Optional, Union +import hashlib +import importlib +import logging +import multiprocessing +import os +import pathlib +import signal +import subprocess +import sys +import time +import types + +import torch + +logger = logging.getLogger(__name__) + +_HAS_DYNAMIC_ATTRIBUTES = True + + +def import_environment(name: str, cast: type = str) -> None: + """ Import an environment variable into the global namespace. """ + try: + globals()[name] = cast(os.environ[name]) + except KeyError: + print(f"ERROR: {name} environment variable is not set.", file=sys.stderr) + sys.exit(1) + + +import_environment("DATA_PATH", pathlib.Path) +import_environment("LOG_PATH", pathlib.Path) + + +class dotdict(dict): + """ Dictionary which can be access through var.key instead of var["key"]. """ + def __getattr__(self, name: str) -> Any: + if name not in self: + raise AttributeError(f"Config key {name} not found") + return dotdict(self[name]) if type(self[name]) is dict else self[name] + __setattr__ = dict.__setitem__ + __delattr__ = dict.__delitem__ + + +def eval_arg(config: Dict[str, Any], arg: str) -> None: + """ + Evaluate arg in the context config, and update it. + + The argument is expected to be of the form: + (parent.)*key(=value)? + If no value is provided, the key is assumed to be a boolean and True is assigned to it. + When passing a string argument through the shell, it must be enclosed in quote (like all python string), which usually need to be escaped. + """ + key: str + value: Any + if '=' in arg: + key, value = arg.split('=', maxsplit=1) + value = eval(value, config) + else: + key, value = arg, True + path: List[str] = key.split('.') + for component in path[:-1]: + config = config[component] + config[path[-1]] = value + config.pop("__builtins__", None) + + +def import_arg(config: Dict[str, Any], arg: str) -> None: + """ + Load file arg, and update config with its content. + + The file is loaded in an independent context, all the variable defined in the file (even through import) are added to config, with the exception of builtins and whole modules. + """ + if arg.endswith(".py"): + arg = arg[:-3].replace('/', '.') + module: types.ModuleType = importlib.import_module(arg) + for key, value in vars(module).items(): + if key not in module.__builtins__ and not key.startswith("__") and not isinstance(value, types.ModuleType): # pytype: disable=attribute-error + config[key] = value + + +def parse_args() -> dotdict: + """ + Parse command line arguments and return config dictionary. + + Two kind of argument are supported: + - When the argument starts with -- it is evaluated by the eval_arg function + - Otherwise the argument is assumed to be a file which is loaded by the import_arg function + """ + config: Dict[str, Any] = {} + config["config"] = config + for arg in sys.argv[1:]: + if arg.startswith("--"): + eval_arg(config, arg[2:]) + else: + import_arg(config, arg) + config.pop("config") + return dotdict(config) + + +def display_dict(output: Callable[[str], None], input: Dict[str, Any], depth: int = 0) -> None: + """ Display nested dictionaries in input using the provided output function. """ + for key, value in input.items(): + indent = '\t'*depth + output(f"{indent}{key}:") + if isinstance(value, dict): + output('\n') + display_dict(output, value, depth+1) + else: + output(f" {value}\n") + + +def print_dict(input: Dict[str, Any]) -> None: + """ Print dictionary to standard output. """ + display_dict(lambda x: print(x, end=""), input) + + +def log_dict(logger: logging.Logger, input: Dict[str, Any]) -> None: + """ Log dictionary to the provided logger. """ + class log: + buf: str = "" + + def __call__(self, x: str) -> None: + self.buf += x + if self.buf.endswith('\n'): + logger.info(self.buf[:-1]) + self.buf = "" + display_dict(log(), input) + + +def flatten_dict(input: Dict[str, Any]) -> Dict[str, Union[bool, int, float, str]]: + """ + Replace nested dict by dot-separated keys, and cast keys to simple types. + + repr() is used to cast non-base-type to str. + """ + def impl(result: Dict[str, Union[bool, int, float, str]], input: Dict[str, Any], prefix: str): + for key, value in input.items(): + if isinstance(value, dict): + impl(result, value, f"{key}.") + else: + result[f"{prefix}{key}"] = value if type(value) in [bool, int, float, str] else repr(value) + + result: Dict[str, Union[bool, int, float, str]] = {} + impl(result, input, "") + return result + + +def get_repo_version() -> str: + """ Get the code repository version. """ + repo_dir = pathlib.Path(__file__).parents[0] + result: subprocess.CompletedProcess = subprocess.run( + ["git", "rev-parse", "HEAD"], + stdout=subprocess.PIPE, + stderr=subprocess.DEVNULL, + encoding="utf-8", + cwd=repo_dir) + + if result.returncode != 0: + return "release" + commit_hash: str = result.stdout.strip()[:8] + + result = subprocess.run( + ["git", "status", "--porcelain"], + stdout=subprocess.PIPE, + stderr=subprocess.DEVNULL, + encoding="utf-8", + cwd=repo_dir) + modified_flag: str = "" + for line in result.stdout.split('\n'): + if line.startswith(" M "): + modified_flag = "+" + break + + return f"{commit_hash}{modified_flag}" + + +def experiment_name(name: str) -> str: + """ Name of the experiment (contains repository version, argument and time). """ + args: str = ' '.join(sys.argv[1:]) + version: str = get_repo_version() + stime: str = time.strftime("%FT%H:%M:%S") + return f"{name} {version} {args} {stime}" + + +def logdir_name(name: str) -> str: + """ Name of the experiment directory, it should be the experiment_name clipped because of filesystem constraints. """ + subdir: str = experiment_name(name).replace('/', '_') + if len(subdir) > 255: + sha1: str = hashlib.sha1(subdir.encode("utf-8")).digest().hex()[:16] + subdir = subdir[:255-17] + ' ' + sha1 + return subdir + + +def fix_transformers_logging_handler() -> None: + """ The transformers package from huggingface install its own logger on import, I don't want it. """ + logger: logging.Logger = logging.getLogger() + for handler in logger.handlers: + logger.removeHandler(handler) + + +def add_logging_handler(logdir: pathlib.Path) -> None: + logfile: pathlib.Path = logdir / "log" + logging.basicConfig(format="%(asctime)s\t%(levelname)s:%(name)s:%(message)s", filename=logfile, filemode='a', level=logging.INFO) + + +def save_patch(outpath: pathlib.Path) -> None: + """ Save a file at the given patch containing the diff between the current code and the last commit. """ + repo_dir = pathlib.Path(__file__).parents[0] + + with outpath.open("w") as outfile: + result: subprocess.CompletedProcess = subprocess.run( + ["git", "diff", "HEAD"], + stdout=outfile, + stderr=subprocess.DEVNULL, + encoding="utf-8", + cwd=repo_dir) + + assert(result.returncode == 0) + + +class Experiment: + """ + Base class for running an experiment. + + Calling run on an instance of this class will call the init() then main() functions. + The sole purpose of this class is to make an experiment "prettier": it displays config values, store a diff of the repo in the experiment logdir, etc. + + Config: + deterministic: run in deterministic mode + seed: seed for random number generators + """ + + _HAS_DYNAMIC_ATTRIBUTES = True + + def __init__(self, config: dotdict, logdir: pathlib.Path, state_dicts: Optional[Dict[str, Any]] = None) -> None: + self.config = config + self.logdir = logdir + self.state_dicts = state_dicts + + def init(self) -> None: + """ Prepare the experiment (e.g. initialize datasets and models). """ + pass + + def main(self) -> NoReturn: + """ Run the experiment in itself. """ + raise NotImplementedError("Subclasses must implement a main method") + + def close(self) -> None: + """ Free used resources (e.g. close opened files). """ + pass + + def run(self) -> None: + """ Run the whole experiment with setting ups, etc. """ + self.log_environment() + self.log_patch() + self.initialize_rng() + self.init() + self.hook_signals() + self.main() + self.close() + + def log_environment(self) -> None: + """ Display information about the environment. """ + print(f"logdir is \033[1m\033[33m{self.logdir}\033[0m") + self.version_check() + self.detect_gpus() + print("") + + print("\033[1m\033[33mConfiguration\033[0m") + print_dict(self.config) + log_dict(logging.getLogger("config"), self.config) + print("") + + def version_check(self) -> None: + """ Check the version of the main dependencies. """ + python_version: str = '.'.join(map(str, sys.version_info[:3])) + torch_version: str = torch.__version__ + cuda_available: str = str(torch.cuda.is_available()) + + logger.info(f"python version {python_version}") + logger.info(f"torch version {torch_version}") + logger.info(f"cuda available {cuda_available}") + + def problem(msg: str) -> str: + return f"\033[1m\033[31m{msg}\033[0m" + if sys.version_info < (3, 7): + python_version = problem(python_version) + if list(map(int, torch_version.split('+')[0].split('.'))) < [1, 6]: + torch_version = problem(torch_version) + if cuda_available == "False": + cuda_available = problem(cuda_available) + print(f"python version: {python_version}, torch version: {torch_version}, cuda available: {cuda_available}") + + def detect_gpus(self) -> None: + """ Display available gpus and set self.device. """ + count: int = torch.cuda.device_count() + + if count == 0: + print(f"\033[1m\033[31mNo GPU available\033[0m") + logger.warning("no GPU available") + self.device = torch.device("cpu") + else: + self.device = torch.device("cuda:0") + + for i in range(count): + gp = torch.cuda.get_device_properties(i) + print(f"GPU{i}: \033[33m{gp.name}\033[0m (Mem: {gp.total_memory/2**30:.2f}GiB CC: {gp.major}.{gp.minor})") + logger.info(f"GPU{i} {gp.name} {gp.total_memory} {gp.major}.{gp.minor}") + + def log_patch(self) -> None: + """ Check the version of the code, and save a patch to logpath if it was modified. """ + version: str = get_repo_version() + logger.info(f"repository_version {version}") + if version == "release": + print(f"\033[41mRelease version\033[0m\n") + elif version.endswith('+'): + print(f"\033[31mUncommited changes detected, saving patch to logdir.\033[0m\n") + suffix: str = "" + if self.state_dicts: # Reloading an existing Trainer + suffix = time.strftime("%FT%H:%M:%S") + save_patch(self.logdir / f"patch{suffix}") + + def initialize_rng(self) -> None: + if self.state_dicts and "torch_rng" in self.state_dicts: + torch.random.set_rng_state(self.state_dicts["torch_rng"]) + assert(("cuda_rng" in self.state_dicts) == torch.cuda.is_available()) + if "cuda_rng" in self.state_dicts: + torch.cuda.random.set_rng_state_all(self.state_dicts["cuda_rng"]) + else: + torch.manual_seed(self.config.seed) + + if self.config.get("deterministic") and torch.backends.cudnn.enabled: + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False + + def hook_signals(self) -> None: + """ Change the behavior of SIGINT (^C) to change a variable `self.interrupted' before killing the process. """ + self.interrupted: bool = False + + def handler(sig: int, frame: types.FrameType) -> None: + if multiprocessing.current_process().name != "MainProcess": + return + + print("\n\033[31mInterrupted, execution will stop at the end of this epoch.\n\033[1mNEXT ^C WILL KILL THE PROCESS!\033[0m\n", file=sys.stderr) + self.interrupted = True + signal.signal(signal.SIGINT, signal.SIG_DFL) + + signal.signal(signal.SIGINT, handler) + + +class SharedLongTensorList: + def __init__(self, tensor_list: List[torch.Tensor], view: List[int] = [-1]): + self.view = view + + total_element: int = 0 + for tensor in tensor_list: + total_element += tensor.numel() + self.data: torch.Tensor = torch.empty(total_element, dtype=torch.int64) + self.indices: torch.Tensor = torch.empty(len(tensor_list)+1, dtype=torch.int64) + + data_pos: int = total_element + indices_pos: int = len(tensor_list) + self.indices[indices_pos] = data_pos + while tensor_list: + tensor: torch.Tensor = tensor_list.pop() + tensor_size: int = tensor.numel() + + indices_pos -= 1 + data_pos -= tensor_size + + self.indices[indices_pos] = data_pos + self.data[data_pos:data_pos+tensor_size] = tensor.flatten() + assert(data_pos == 0) + assert(indices_pos == 0) + + def __len__(self) -> int: + return self.indices.shape[0]-1 + + def __getitem__(self, key: Union[int, slice]) -> torch.Tensor: + if isinstance(key, slice): + return [self[value] for value in range(*key.indices(len(self)))] + elif isinstance(key, int): + return self.data[self.indices[key]:self.indices[key+1]].view(*self.view) + elif isinstance(key, torch.Tensor): + return self[key.item()] + else: + raise TypeError("Invalid argument type.") diff --git a/requirements.txt b/requirements.txt @@ -1,4 +0,0 @@ -numpy>=1.16 -torch==1.3.0 -tqdm>=4 -transformers==2.0.0 diff --git a/scripts/contrastive_alignment.sh b/scripts/contrastive_alignment.sh @@ -0,0 +1,58 @@ +#!/bin/bash +#SBATCH -A lco@gpu +#SBATCH -C v100-32g +#SBATCH --job-name=gbure_contrastive +#SBATCH --ntasks=1 +#SBATCH --gres=gpu:1 +#SBATCH --cpus-per-task=10 +#SBATCH --distribution=block:block +#SBATCH --hint=nomultithread +#SBATCH --time=20:00:00 +#SBATCH --output=/gpfswork/rech/lco/url46ht/Étienne/runs/%x_%j.stdout +#SBATCH --error=/gpfswork/rech/lco/url46ht/Étienne/runs/%x_%j.stderr +#SBATCH --array=0-23 + +# %x = nom du job +# %j = id du job + +module purge +source ~/.bashrc +conda activate Étienne +export DATA_PATH=$WORK/Étienne/data +export LOG_PATH=$WORK/Étienne/log +export HF_DATASETS_OFFLINE=1 +export TRANSFORMERS_OFFLINE=1 +cd $WORK/Étienne/code + +# echo des commandes lancées +set -x + +config="" + +if [ $(( $SLURM_ARRAY_TASK_ID % 3 )) -eq 0 ]; then + config="$config --language_model_weight=0" +elif [ $(( $SLURM_ARRAY_TASK_ID % 3 )) -eq 1 ]; then + config="$config --language_model_weight=0.1" +else + config="$config --language_model_weight=1" +fi + +if [ $(( $SLURM_ARRAY_TASK_ID / 3 % 4 )) -eq 0 ]; then + config="$config --margin=0.1" +elif [ $(( $SLURM_ARRAY_TASK_ID / 3 % 4 )) -eq 1 ]; then + config="$config --margin=1" +elif [ $(( $SLURM_ARRAY_TASK_ID / 3 % 4 )) -eq 2 ]; then + config="$config --margin=10" +else + config="$config --margin=100" +fi + +if [ $(( $SLURM_ARRAY_TASK_ID / 12 % 2 )) -eq 0 ]; then + config="$config --topological_weight=1" +else + config="$config --topological_weight=0.2" +fi + +config="$config --seed=$(($SLURM_ARRAY_TASK_ID / 24))" + +python -m gbure.train gbure/config/contrastive_alignment.py $config --post_transformer_layer=\"layer_norm\" diff --git a/scripts/mtb.sh b/scripts/mtb.sh @@ -0,0 +1,47 @@ +#!/bin/bash +#SBATCH -A lco@gpu +#SBATCH -C v100-32g +#SBATCH --job-name=gbure_mtb +#SBATCH --ntasks=1 +#SBATCH --gres=gpu:1 +#SBATCH --cpus-per-task=10 +#SBATCH --distribution=block:block +#SBATCH --hint=nomultithread +#SBATCH --time=20:00:00 +#SBATCH --output=/gpfswork/rech/lco/url46ht/Étienne/runs/%x_%j.stdout +#SBATCH --error=/gpfswork/rech/lco/url46ht/Étienne/runs/%x_%j.stderr +#SBATCH --array=0-3 + +# %x = nom du job +# %j = id du job + +module purge +source ~/.bashrc +conda activate Étienne +export DATA_PATH=$WORK/Étienne/data +export LOG_PATH=$WORK/Étienne/log +export HF_DATASETS_OFFLINE=1 +export TRANSFORMERS_OFFLINE=1 +cd $WORK/Étienne/code + +# echo des commandes lancées +set -x + +config="" + +if [ $(( $SLURM_ARRAY_TASK_ID % 2 )) -eq 0 ]; then + config="$config --latent_metric_scale=\"standard\" --latent_dot_mean=1067.65 --latent_dot_std=111.17" +else + config="$config --latent_metric_scale=\"sqrt\" --latent_dot_mean=None --latent_dot_std=None" +fi + +if [ $(( $SLURM_ARRAY_TASK_ID / 2 % 2 )) -eq 0 ]; then + config="$config --filter_empty_neighborhood=True" +else + config="$config --filter_empty_neighborhood=False" +fi + +config="$config --seed=$(($SLURM_ARRAY_TASK_ID / 4))" + +# We use the gcn_mtb config with a 0 topological weight in order to get accuracies in function of neighborhood sizes +python -m gbure.train gbure/config/gcn_mtb.py $config --post_transformer_layer=\"layer_norm\" --topological_weight=0 diff --git a/scripts/mtb_gcn.sh b/scripts/mtb_gcn.sh @@ -0,0 +1,54 @@ +#!/bin/bash +#SBATCH -A lco@gpu +#SBATCH -C v100-32g +#SBATCH --job-name=gbure_gcn_mtb +#SBATCH --ntasks=1 +#SBATCH --gres=gpu:1 +#SBATCH --cpus-per-task=10 +#SBATCH --distribution=block:block +#SBATCH --hint=nomultithread +#SBATCH --time=20:00:00 +#SBATCH --output=/gpfswork/rech/lco/url46ht/Étienne/runs/%x_%j.stdout +#SBATCH --error=/gpfswork/rech/lco/url46ht/Étienne/runs/%x_%j.stderr +#SBATCH --array=0-11 + +# %x = nom du job +# %j = id du job + +module purge +source ~/.bashrc +conda activate Étienne +export DATA_PATH=$WORK/Étienne/data +export LOG_PATH=$WORK/Étienne/log +export HF_DATASETS_OFFLINE=1 +export TRANSFORMERS_OFFLINE=1 +cd $WORK/Étienne/code + +# echo des commandes lancées +set -x + +config="" + +if [ $(( $SLURM_ARRAY_TASK_ID % 3 )) -eq 0 ]; then + config="$config --gcn_aggregator=\"none\"" +elif [ $(( $SLURM_ARRAY_TASK_ID % 3 )) -eq 1 ]; then + config="$config --gcn_aggregator=\"mean\"" +else + config="$config --gcn_aggregator=\"chebyshev\"" +fi + +if [ $(( $SLURM_ARRAY_TASK_ID / 3 % 2 )) -eq 0 ]; then + config="$config --topological_weight=1" +else + config="$config --topological_weight=0.2" +fi + +if [ $(( $SLURM_ARRAY_TASK_ID / 6 % 2 )) -eq 0 ]; then + config="$config --filter_empty_neighborhood=True" +else + config="$config --filter_empty_neighborhood=False" +fi + +config="$config --seed=$(($SLURM_ARRAY_TASK_ID / 12))" + +python -m gbure.train gbure/config/gcn_mtb.py $config --post_transformer_layer=\"layer_norm\" diff --git a/scripts/nonparametric.sh b/scripts/nonparametric.sh @@ -0,0 +1,42 @@ +#!/bin/bash +#SBATCH -A lco@gpu +#SBATCH -C v100-32g +#SBATCH --job-name=gbure_nonparametric +#SBATCH --ntasks=1 +#SBATCH --gres=gpu:1 +#SBATCH --cpus-per-task=20 +#SBATCH --distribution=block:block +#SBATCH --hint=nomultithread +#SBATCH --time=20:00:00 +#SBATCH --output=/gpfswork/rech/lco/url46ht/Étienne/runs/%x_%j.stdout +#SBATCH --error=/gpfswork/rech/lco/url46ht/Étienne/runs/%x_%j.stderr +#SBATCH --array=0-2 + +# %x = nom du job +# %j = id du job + +module purge +source ~/.bashrc +conda activate Étienne +export DATA_PATH=$WORK/Étienne/data +export LOG_PATH=$WORK/Étienne/log +export HF_DATASETS_OFFLINE=1 +export TRANSFORMERS_OFFLINE=1 +cd $WORK/Étienne/code + +# echo des commandes lancées +set -x + +config="" + +if [ $(( $SLURM_ARRAY_TASK_ID % 3 )) -eq 0 ]; then + config="$config --undefined_poison_whole_meta=True" +elif [ $(( $SLURM_ARRAY_TASK_ID % 3 )) -eq 1 ]; then + config="$config --undefined_poison_whole_meta=False --neutral_topological_similarity=None" +else + config="$config --undefined_poison_whole_meta=False --neutral_topological_similarity=(1536*2)**0.5" +fi + +for topological_weight in 0.1 0.15 0.2 0.22 0.25 0.3 0.5 1; do + python -m gbure.train gbure/config/nonparametric.py $config --topological_weight=$topological_weight +done