commit a52d7245480b4e0be22a2c21691afd7a35b26d16
parent 35e7296c5fbe3ed83103ff94676a3cff90030d28
Author: Étienne Simon <esimon@esimon.eu>
Date: Wed, 18 May 2022 00:23:20 +0000
Implement new graph-based approaches
Diffstat:
53 files changed, 4192 insertions(+), 1346 deletions(-)
diff --git a/README b/README
@@ -1,35 +1,17 @@
-Reproduction of Matching the Blanks: Distributional Similarity for Relation Learning by Livio Baldini Soares, Nicholas FitzGerald, Jeffrey Ling, and Tom Kwiatkowski.
+This repository presents some results of graph-based approaches on unsupervised relation extraction evaluated as a fewshot problem.
-This repository currently contains the supervised "entity markers-entity start" model for the Semeval and KBP37 datasets, the unsupervised MTB model and the FewRel dataset will be added latter on.
-In order to reproduce the results, you first need to manually download and extract the datasets (e.g. in /tmp), then execute the following:
+$ export DATA_PATH="/tmp" # the directory for storing datasets (<100MiB for each dataset, except T-REx which need a bit less than 40GiB)
+$ export LOG_PATH="/tmp" # the directory where the models will be saved (around 12GiB per model)
+$ https://esimon.eu/repos/gbure.git
+$ cd gbure
+$ python3 -m gbure.data.prepare_semeval
+$ python3 -m gbure.data.prepare_kbp37
+$ python3 -m gbure.data.prepare_fewrel
+$ python3 -m gbure.train gbure/config/soares_semeval.py
+$ python3 -m gbure.train gbure/config/soares_kbp37.py
+$ python3 -m gbure.train gbure/config/soares_fewrel.py
-$ export DATA_PATH="/tmp" # the directory containing the extracted datasets
-$ export LOG_PATH="/tmp" # the directory where the models will be saved
-$ git clone https://gitlab.lip6.fr/esimon/few-shot-relation-extraction
-$ cd few-shot-relation-extraction
-$ python -m fsre.data.prepare_semeval
-$ python -m fsre.train fsre/config/soares_supervised_semeval.py
-$ python -m fsre.data.prepare_kbp37
-$ python -m fsre.train fsre/config/soares_supervised_kbp37.py
-
-I must be missing something since I'm a bit away from the results reported in the paper.
-Here are the score reached by this repository (official SemEval macro F1 "taking directionality into account"):
-
-On Semeval:
- paper valid: 82.1
- BERT cased valid: 88.55 (std: 0.70)
- BERT uncased valid: 87.49 (std: 0.87)
- paper test: 89.2
- BERT cased test: 88.24 (std: 0.37)
- BERT uncased test: 88.02 (std: 0.65)
-
-On KBP37:
- paper valid: 70.0
- BERT cased valid: 66.92 (std: 0.61)
- BERT uncased valid: 66.48 (std: 0.46)
- paper test: 68.3
- BERT cased test: 67.18 (std: 0.51)
- BERT uncased test: 66.21 (std: 0.62)
-
-The mean and std are computed over 5 runs, all reported results are for "large" BERT models.
-Detailed results can be found at http://www-ia.lip6.fr/~esimon/results.xhtml (temporary link).
+To evaluate on the static (already sampled) FewRel subset used for evaluation you should use the sample_io.py script from FewRel:
+$ python3 sample_io.py $DATA_PATH/FewRel/val.json 50000 5 1 0 input > $DATA_PATH/FewRel/val_50000_5_1_0_input
+$ python3 sample_io.py $DATA_PATH/FewRel/val.json 50000 5 1 0 output > $DATA_PATH/FewRel/val_50000_5_1_0_output
+$ python3 -m gbure.data.prepare_sampled_fewrel bert-base-cased $DATA_PATH/FewRel/val_50000_5_1_0_input $DATA_PATH/FewRel/val_50000_5_1_0_output
diff --git a/fsre/__init__.py b/fsre/__init__.py
@@ -1,4 +0,0 @@
-import fsre.data
-import fsre.utils
-import fsre.metrics
-import fsre.train
diff --git a/fsre/config/soares_supervised_kbp37.py b/fsre/config/soares_supervised_kbp37.py
@@ -1,21 +0,0 @@
-from fsre.model.mtb_supervised import Model
-
-
-dataset_name = "KBP37"
-
-# From Table 1
-bert_model = "bert-large-cased"
-post_transformer_layer = "linear"
-max_epoch = 10
-learning_rate = 3e-5
-true_batch_size = 64
-
-# Guessed
-validation_metric = "half_directed_macro_f1"
-early_stopping_patience = 2
-
-# Implementation details
-seed = 0
-batch_size = 2
-batch_per_sort_bucket = 8
-sort_per_shuffle_bucket = 8
diff --git a/fsre/config/soares_supervised_semeval.py b/fsre/config/soares_supervised_semeval.py
@@ -1,21 +0,0 @@
-from fsre.model.mtb_supervised import Model
-
-
-dataset_name = "SemEval2010_task8_all_data"
-
-# From Table 1
-bert_model = "bert-large-cased"
-post_transformer_layer = "layer_norm"
-max_epoch = 10
-learning_rate = 3e-5
-true_batch_size = 64
-
-# Guessed
-validation_metric = "half_directed_macro_f1"
-early_stopping_patience = 2
-
-# Implementation details
-seed = 0
-batch_size = 2
-batch_per_sort_bucket = 8
-sort_per_shuffle_bucket = 8
diff --git a/fsre/data/__init__.py b/fsre/data/__init__.py
@@ -1,2 +0,0 @@
-from fsre.data.dataset import RelationExtractionDataset
-from fsre.data.relation_dictionary import RelationDictionary
diff --git a/fsre/data/dataset.py b/fsre/data/dataset.py
@@ -1,116 +0,0 @@
-import math
-import numpy
-import torch
-
-
-class RelationExtractionDataset(torch.utils.data.IterableDataset):
- """
- Read a preprocessed Relation Extraction dataset from a .npy file.
-
- When generating data, we first read a large shuffle bucket which is
- shuffled (unless the dataset is for evaluation). This shuffle_bucket
- is then cut down into several sort buckets, each of them is sorted
- so that sentences of similar length end up next to each other. The
- sort buckets are then cut down into batches which pad the sentences.
- However this class generate samples, the proper batching should be
- done by a DataLoader.
- A preprocessed dataset can be created from the fsre.data.prepare_*
- modules.
-
- Config:
- batch_per_sort_bucket: the number of batches in a sort bucket
- batch_size: the number of samples in a batch
- seed: the seed for the random number generator
- sort_per_shuffle_bucket: the number of sort buckets in a shuffle bucket
- """
-
- def __init__(self, config, path, pad, evaluation, rng=None):
- """
- Initialize a Relation Extraction dataset and load the data in RAM.
-
- Args:
- config: global config object
- path: path to the dataset to load, this should be a .npy file
- pad: the value used to pad text in a batch
- evaluation: whether this dataset is an evaluation one (no need to sort then)
- rng: the random number generator to use for shuffling
- """
-
- super().__init__()
-
- self.config = config
- self.pad = pad
- self.evaluation = evaluation
-
- self.data = numpy.load(path, allow_pickle=True)
- if not evaluation:
- self.rng = rng if rng is not None else numpy.random.RandomState(config.seed)
-
- self.batch_size = config.batch_size
- self.sort_bucket_size = config.batch_size * config.batch_per_sort_bucket
- self.shuffle_bucket_size = self.sort_bucket_size * config.sort_per_shuffle_bucket
-
- def __len__(self):
- return len(self.data)
-
- def pad_text(self, text, diff, pad):
- """ Append diff tokens pad to the end of text """
- text = torch.tensor(text, dtype=torch.int64)
- padding = text.new_full((diff,), pad)
- return torch.cat((text, padding))
-
- def iter_sample(self, samples):
- """ Generate samples from a batch """
- lengths = [sample[1].shape[0] for sample in samples]
- max_len = max(lengths)
-
- for sample, length in zip(samples, lengths):
- ds = dict(zip(["id", "text", "e1_pos", "e2_pos", "relation"], sample))
- ds["length"] = length
- ds["mask"] = self.pad_text(numpy.ones_like(ds["text"]), max_len - length, 0)
- ds["text"] = self.pad_text(ds["text"], max_len - length, self.pad)
- yield ds
-
- def iter_batch(self, sort_bucket):
- """ Generate batches from a sort_bucket """
- for batch_start in range(0, len(sort_bucket), self.batch_size):
- batch_end = batch_start + self.batch_size
- batch_end = min(batch_end, len(sort_bucket))
-
- yield from self.iter_sample([self.data[i] for i in sort_bucket[batch_start:batch_end]])
-
- def iter_shuffle_bucket(self, shuffle_bucket):
- """ Generate sort_buckets from a shuffle_bucket """
- worker_info = torch.utils.data.get_worker_info()
- if worker_info is None:
- start = 0
- end = len(shuffle_bucket)
- else:
- work_size = len(shuffle_bucket) / self.sort_bucket_size
- per_worker = int(math.ceil(work_size / float(worker_info.num_workers))) * self.sort_bucket_size
- start = worker_info.id * per_worker
- end = min(start + per_worker, len(shuffle_bucket))
-
- for sort_start in range(start, end, self.sort_bucket_size):
- sort_end = sort_start + self.sort_bucket_size
- sort_end = min(sort_end, end)
-
- sort_bucket = shuffle_bucket[sort_start:sort_end]
-
- sort_bucket.sort(key=lambda x: self.data[x, 1].shape[0])
- yield from self.iter_batch(sort_bucket)
-
- def __iter__(self):
- """ Generate shuffle_buckets from the dataset """
- for shuffle_start in range(0, len(self), self.shuffle_bucket_size):
- shuffle_end = shuffle_start + self.shuffle_bucket_size
- shuffle_end = min(shuffle_end, len(self))
-
- shuffle_bucket = list(range(shuffle_start, shuffle_end))
-
- if not self.evaluation:
- self.rng.shuffle(shuffle_bucket)
- yield from self.iter_shuffle_bucket(shuffle_bucket)
-
- def state_dict(self):
- return {"rng": self.rng}
diff --git a/fsre/data/prepare_kbp37.py b/fsre/data/prepare_kbp37.py
@@ -1,72 +0,0 @@
-import argparse
-import numpy
-import transformers
-import tqdm
-
-from fsre.utils import DATA_PATH
-from fsre.data.relation_dictionary import RelationDictionary
-from fsre.data.prepare_semeval import load_semeval_dataset
-
-TRAIN_SIZE = 15917
-VALID_SIZE = 1724
-TEST_SIZE = 3405
-UNKNOWN_RELATION = "no_relation"
-
-
-def prepare_kbp37(args):
- rng = numpy.random.RandomState(args.seed)
- kbp37_path = DATA_PATH / "KBP37"
- output_path = kbp37_path / args.tokenizer
-
- if not output_path.is_dir():
- output_path.mkdir()
-
- relation_dictionary = RelationDictionary(unknown=UNKNOWN_RELATION)
-
- tokenizer = transformers.BertTokenizer.from_pretrained(args.tokenizer)
- tokenizer.add_special_tokens({"additional_special_tokens": ["<e1>", "</e1>", "<e2>", "</e2>"]})
- tokenizer_path = output_path / "tokenizer"
- if not tokenizer_path.is_dir():
- tokenizer_path.mkdir()
- tokenizer.save_pretrained(tokenizer_path)
-
- train = load_semeval_dataset(
- kbp37_path / "train.txt",
- tokenizer,
- relation_dictionary,
- TRAIN_SIZE)
- rng.shuffle(train)
-
- valid = load_semeval_dataset(
- kbp37_path / "dev.txt",
- tokenizer,
- relation_dictionary,
- VALID_SIZE)
-
- test = load_semeval_dataset(
- kbp37_path / "test.txt",
- tokenizer,
- relation_dictionary,
- TEST_SIZE)
-
- numpy.save(output_path / "train.npy", numpy.array(train))
- numpy.save(output_path / "valid.npy", numpy.array(valid))
- numpy.save(output_path / "test.npy", numpy.array(test))
-
- relation_dictionary.save(output_path / "relations")
-
-
-if __name__ == "__main__":
- parser = argparse.ArgumentParser(
- description="Prepare the KBP37 dataset.")
- parser.add_argument("tokenizer",
- type=str,
- nargs='?',
- default="bert-large-cased",
- help="Name of the transformers tokenizer")
- parser.add_argument("-s", "--seed",
- type=int,
- default=0,
- help="Seed of the RNG for shuffling the dataset")
-
- prepare_kbp37(parser.parse_args())
diff --git a/fsre/data/prepare_semeval.py b/fsre/data/prepare_semeval.py
@@ -1,106 +0,0 @@
-import argparse
-import numpy
-import transformers
-import tqdm
-
-from fsre.utils import DATA_PATH
-from fsre.data.relation_dictionary import RelationDictionary
-
-TRAIN_SIZE = 8000
-TEST_SIZE = 2717
-UNKNOWN_RELATION = "Other"
-
-
-def load_semeval_dataset(path, tokenizer, relation_dictionary, size):
- be1_id = tokenizer.added_tokens_encoder["<e1>"]
- be2_id = tokenizer.added_tokens_encoder["<e2>"]
-
- dataset = []
- with open(path) as infile:
- for _ in tqdm.trange(size):
- idtext_line = infile.readline()
- relation_line = infile.readline()
- if not (idtext_line and relation_line):
- break
-
- id, raw_text = idtext_line.rstrip().split('\t')
- id = int(id)
-
- raw_text = raw_text[1:-1] # remove quotes around text
- text = tokenizer.encode(raw_text, add_special_tokens=True)
- e1_pos = text.index(be1_id)
- e2_pos = text.index(be2_id)
- if len(text) > tokenizer.max_len:
- text = text[:tokenizer.max_len]
- e1_pos = min(tokenizer.max_len-1, e1_pos)
- e2_pos = min(tokenizer.max_len-1, e2_pos)
- text = numpy.array(text, dtype=numpy.int32)
-
- relation_line = relation_line.rstrip()
- dir_start = relation_line.find('(')
- relation_base = relation_line[:dir_start] if dir_start >= 0 else relation_line
- relation = relation_dictionary.encode(relation_line, relation_base)
-
- dataset.append([id, text, e1_pos, e2_pos, relation])
- infile.readline() # Ignore Comment line
- infile.readline() # Ignore empty line
- return dataset
-
-
-def prepare_semeval(args):
- rng = numpy.random.RandomState(args.seed)
- semeval_path = DATA_PATH / "SemEval2010_task8_all_data"
- output_path = semeval_path / args.tokenizer
-
- if not output_path.is_dir():
- output_path.mkdir()
-
- relation_dictionary = RelationDictionary(unknown=UNKNOWN_RELATION)
-
- tokenizer = transformers.BertTokenizer.from_pretrained(args.tokenizer)
- tokenizer.add_special_tokens({"additional_special_tokens": ["<e1>", "</e1>", "<e2>", "</e2>"]})
- tokenizer_path = output_path / "tokenizer"
- if not tokenizer_path.is_dir():
- tokenizer_path.mkdir()
- tokenizer.save_pretrained(tokenizer_path)
-
- dataset = load_semeval_dataset(
- semeval_path / "SemEval2010_task8_training" / "TRAIN_FILE.TXT",
- tokenizer,
- relation_dictionary,
- TRAIN_SIZE)
- rng.shuffle(dataset)
- train = dataset[args.valid_size:]
- valid = dataset[:args.valid_size]
-
- test = load_semeval_dataset(
- semeval_path / "SemEval2010_task8_testing_keys" / "TEST_FILE_FULL.TXT",
- tokenizer,
- relation_dictionary,
- TEST_SIZE)
-
- numpy.save(output_path / "train.npy", numpy.array(train))
- numpy.save(output_path / "valid.npy", numpy.array(valid))
- numpy.save(output_path / "test.npy", numpy.array(test))
-
- relation_dictionary.save(output_path / "relations")
-
-
-if __name__ == "__main__":
- parser = argparse.ArgumentParser(
- description="Prepare the SemEval 2010 Task 8 dataset.")
- parser.add_argument("tokenizer",
- type=str,
- nargs='?',
- default="bert-large-cased",
- help="Name of the transformers tokenizer")
- parser.add_argument("-s", "--seed",
- type=int,
- default=0,
- help="Seed of the RNG for shuffling the dataset")
- parser.add_argument("-v", "--valid-size",
- type=int,
- default=1500,
- help="Size of the validation set")
-
- prepare_semeval(parser.parse_args())
diff --git a/fsre/data/relation_dictionary.py b/fsre/data/relation_dictionary.py
@@ -1,94 +0,0 @@
-import pickle
-
-
-class RelationDictionary:
- """
- A dictionary to be used for relations.
-
- The tokens held by this class are divided between:
- - *relation* such as "Entity-Destination(e1,e2)"
- - *base* such as "Entity-Destination"
- """
-
- def __init__(self, *, unknown=None, path=None):
- self.encoder = {}
- self.decoder = []
-
- self.base_encoder = {}
- self.base_decoder = []
- self.id_to_bid = []
-
- self.unknown = unknown
- if unknown is not None:
- self.encoder[unknown] = 0
- self.decoder.append(unknown)
- self.base_encoder[unknown] = 0
- self.base_decoder.append(unknown)
- self.id_to_bid.append(0)
-
- if path is not None:
- self.load(path)
-
- def __len__(self):
- """ Number of relations in the dictionary. """
- return len(self.decoder)
-
- def base_size(self):
- """ Number of bases in the dictionary. """
- return len(self.base_decoder)
-
- def encode(self, relation, base=None):
- """
- Returns the id corresponding to a relation string.
-
- Args:
- relation: the string of the relation (e.g. "Entity-Destination(e1,e2)")
- base: the string of the base relation (e.g. "Entity-Destination")
- """
-
- if relation is None:
- return None
-
- if base is None:
- return self.encoder[relation]
-
- id = self.encoder.get(relation)
- if id is not None:
- return id
-
- bid = self.base_encoder.get(base)
- if bid is None:
- bid = len(self.base_decoder)
- self.base_encoder[base] = bid
- self.base_decoder.append(base)
-
- id = len(self.decoder)
- self.encoder[relation] = id
- self.decoder.append(relation)
- self.id_to_bid.append(bid)
- return id
-
- def decode(self, id):
- """ Returns the string corresponding to a relation id. """
- return self.decoder[id]
-
- def base_id(self, id):
- """ Returns the base id corresponding to a relation id. """
- return self.id_to_bid[id]
-
- def save(self, path):
- with open(path, "wb") as file:
- pickle.dump({
- "unknown": self.unknown,
- "decoder": self.decoder,
- "encoder": self.encoder,
- "base_encoder": self.base_encoder,
- "base_decoder": self.base_decoder,
- "id_to_bid": self.id_to_bid,
- }, file)
-
- def load(self, path):
- with open(path, "rb") as file:
- data = pickle.load(file)
- for key, value in data.items():
- setattr(self, key, value)
diff --git a/fsre/eval.py b/fsre/eval.py
@@ -1,30 +0,0 @@
-import torch
-
-import fsre
-
-
-class Evaluator(fsre.train.Trainer):
- """
- Evaluate a model.
- """
-
- def __init__(self, config, state_dicts):
- self.eval_config = config
- super().__init__(state_dicts["config"], None, state_dicts)
-
- def run(self):
- self.epoch = self.state_dicts["epoch"]
- self.info()
- self.initialize_rng()
- self.prepare_dataset()
- self.build_model()
- self.count_parameters()
- self.evaluate("test")
-
-
-if __name__ == "__main__":
- fsre.utils.fix_transformers_logging_handler()
- config = fsre.utils.parse_args()
-
- state_dicts = torch.load(config.load)
- Evaluator(config, state_dicts).run()
diff --git a/fsre/metrics.py b/fsre/metrics.py
@@ -1,250 +0,0 @@
-import math
-import numpy
-import torch
-
-
-class Metrics:
- """
- Class for computing metrics.
-
- Twenty metrics are computed:
- - Accuracy
- - Negative Log Likelihood (nll)
- - {directed, undirected, half_directed} {micro, macro} {f1, precision, recall}
- Note that the Accuracy is the true accuracy, taking directionality into account and scoring the unknown relation as any other relation.
- The last 18 metrics follow the SemEval scorer:
- - The unknown ("Other") relation is only scored indirectly
- - Directed is equivalent to the metrics "USING DIRECTIONALITY"
- - Undirected is equivalent to the metrics "IGNORING DIRECTIONALITY"
- - Half-directed is equivalent to the metrics "TAKING DIRECTIONALITY INTO ACCOUNT -- OFFICIAL"
- Note that the directed and half_directed micro metrics are equivalents.
- """
-
- def __init__(self, relation_dictionary):
- """
- Initialize all metrics.
-
- Args:
- relation_dictionary: see class RelationDictionary
- """
-
- self.relation_dictionary = relation_dictionary
- self.n = len(relation_dictionary)
- self.m = relation_dictionary.base_size()
- self.build_mask()
- self.build_base_transition()
- self.crossentropy = torch.nn.CrossEntropyLoss(reduction="sum")
-
- self.size = 0
- self.correct = 0
- self.ce_sum = 0
- self.confusion = numpy.zeros((self.n, self.n), numpy.int64)
-
- def build_mask(self):
- self.mask = numpy.ones(self.n)
- if self.relation_dictionary.unknown is not None:
- assert(self.relation_dictionary.decode(0) == self.relation_dictionary.unknown)
- self.mask[0] = 0
-
- def build_base_transition(self):
- self.base_transition = numpy.zeros((self.n, self.m))
- for id, bid in enumerate(self.relation_dictionary.id_to_bid):
- self.base_transition[id, bid] = 1
-
- def update(self, predictions, target):
- """
- Update metrics with a batch of predictions and corresponding targets.
-
- Args:
- predictions: the predicted logits (before softmax)
- target: the gold relations
- """
-
- self.size += predictions.shape[0]
- self.ce_sum += self.crossentropy(predictions, target).item()
-
- prediction = predictions.argmax(1)
- for p, t in zip(prediction.tolist(), target.tolist()):
- self.confusion[p, t] += 1
- self.correct += (p == t)
-
- @property
- def summary(self):
- return {"accuracy": f"{self.accuracy*100:.2f}",
- "nll": f"{self.nll:.2f}"}
-
- @property
- def all(self):
- keys = ["accuracy", "nll"] + [
- f"{direction}_{level}_{metric}"
- for direction in ["directed", "undirected", "half_directed"]
- for level in ["macro", "micro"]
- for metric in ["f1", "precision", "recall"]]
- return {key: getattr(self, key) for key in keys}
-
- @property
- def base_mask(self):
- return self.mask.dot(self.base_transition).clip(0, 1)
-
- @property
- def base_confusion(self):
- return self.base_transition.T.dot(self.confusion).dot(self.base_transition)
-
- @property
- def accuracy(self):
- return math.nan if self.size == 0 else self.correct / self.size
-
- @property
- def nll(self):
- return math.nan if self.size == 0 else self.ce_sum / self.size
-
- ##########################
- # Directed macro metrics #
- ##########################
-
- @property
- def directed_class_precision(self):
- norm = self.confusion.sum(1)
- norm[norm == 0] = 1
- return self.confusion.diagonal() / norm
-
- @property
- def directed_class_recall(self):
- norm = self.confusion.sum(0)
- norm[norm == 0] = 1
- return self.confusion.diagonal() / norm
-
- @property
- def directed_class_f1(self):
- norm = self.directed_class_precision + self.directed_class_recall
- norm[norm == 0] = 1
- return 2 * self.directed_class_precision * self.directed_class_recall / norm
-
- @property
- def directed_macro_precision(self):
- return numpy.sum(self.directed_class_precision * self.mask) / self.mask.sum()
-
- @property
- def directed_macro_recall(self):
- return numpy.sum(self.directed_class_recall * self.mask) / self.mask.sum()
-
- @property
- def directed_macro_f1(self):
- return numpy.sum(self.directed_class_f1 * self.mask) / self.mask.sum()
-
- ############################
- # Undirected macro metrics #
- ############################
-
- @property
- def undirected_class_precision(self):
- norm = self.base_confusion.sum(1)
- norm[norm == 0] = 1
- return self.base_confusion.diagonal() / norm
-
- @property
- def undirected_class_recall(self):
- norm = self.base_confusion.sum(0)
- norm[norm == 0] = 1
- return self.base_confusion.diagonal() / norm
-
- @property
- def undirected_class_f1(self):
- norm = self.undirected_class_precision + self.undirected_class_recall
- norm[norm == 0] = 1
- return 2 * self.undirected_class_precision * self.undirected_class_recall / norm
-
- @property
- def undirected_macro_precision(self):
- return numpy.sum(self.undirected_class_precision * self.base_mask) / self.base_mask.sum()
-
- @property
- def undirected_macro_recall(self):
- return numpy.sum(self.undirected_class_recall * self.base_mask) / self.base_mask.sum()
-
- @property
- def undirected_macro_f1(self):
- return numpy.sum(self.undirected_class_f1 * self.base_mask) / self.base_mask.sum()
-
- ###############################
- # Half-directed macro metrics #
- ###############################
-
- @property
- def half_directed_class_precision(self):
- norm = self.base_confusion.sum(1)
- norm[norm == 0] = 1
- return self.confusion.diagonal().dot(self.base_transition) / norm
-
- @property
- def half_directed_class_recall(self):
- norm = self.base_confusion.sum(0)
- norm[norm == 0] = 1
- return self.confusion.diagonal().dot(self.base_transition) / norm
-
- @property
- def half_directed_class_f1(self):
- norm = self.half_directed_class_precision + self.half_directed_class_recall
- norm[norm == 0] = 1
- return 2 * self.half_directed_class_precision * self.half_directed_class_recall / norm
-
- @property
- def half_directed_macro_precision(self):
- return numpy.sum(self.half_directed_class_precision * self.base_mask) / self.base_mask.sum()
-
- @property
- def half_directed_macro_recall(self):
- return numpy.sum(self.half_directed_class_recall * self.base_mask) / self.base_mask.sum()
-
- @property
- def half_directed_macro_f1(self):
- return numpy.sum(self.half_directed_class_f1 * self.base_mask) / self.base_mask.sum()
-
- #################
- # Micro metrics #
- #################
-
- @property
- def directed_micro_precision(self):
- norm = numpy.sum(self.confusion.sum(1) * self.mask)
- return 0 if norm == 0 else numpy.sum(self.confusion.diagonal() * self.mask) / norm
-
- @property
- def directed_micro_recall(self):
- norm = numpy.sum(self.confusion.sum(0) * self.mask)
- return 0 if norm == 0 else numpy.sum(self.confusion.diagonal() * self.mask) / norm
-
- @property
- def directed_micro_f1(self):
- norm = self.directed_micro_precision + self.directed_micro_recall
- return 0 if norm == 0 else 2 * (self.directed_micro_precision * self.directed_micro_recall) / norm
-
- @property
- def half_directed_micro_precision(self):
- norm = numpy.sum(self.confusion.sum(1) * self.mask)
- return 0 if norm == 0 else numpy.sum(self.confusion.diagonal() * self.mask) / norm
-
- @property
- def half_directed_micro_recall(self):
- norm = numpy.sum(self.confusion.sum(0) * self.mask)
- return 0 if norm == 0 else numpy.sum(self.confusion.diagonal() * self.mask) / norm
-
- @property
- def half_directed_micro_f1(self):
- norm = self.half_directed_micro_precision + self.half_directed_micro_recall
- return 0 if norm == 0 else 2 * (self.half_directed_micro_precision * self.half_directed_micro_recall) / norm
-
- @property
- def undirected_micro_precision(self):
- norm = numpy.sum(self.base_confusion.sum(1) * self.base_mask)
- return 0 if norm == 0 else numpy.sum(self.base_confusion.diagonal() * self.base_mask) / norm
-
- @property
- def undirected_micro_recall(self):
- norm = numpy.sum(self.base_confusion.sum(0) * self.base_mask)
- return 0 if norm == 0 else numpy.sum(self.base_confusion.diagonal() * self.base_mask) / norm
-
- @property
- def undirected_micro_f1(self):
- norm = self.undirected_micro_precision + self.undirected_micro_recall
- return 0 if norm == 0 else 2 * (self.undirected_micro_precision * self.undirected_micro_recall) / norm
diff --git a/fsre/model/mtb_classifier.py b/fsre/model/mtb_classifier.py
@@ -1,58 +0,0 @@
-import torch
-import torch.nn as nn
-import transformers
-
-
-class Classifier(nn.Module):
- """
- Transformer classifier from Soares et al.
-
- Correspond to the left part of each subfigure of Figure 2 (Deep Transformer Encoder and the green layer above).
-
- Config:
- bert_model: The version of BERT to use (e.g. bert-large-uncased).
- post_transformer_layer: The transformation applied after BERT (must be "linear" or "layer_norm")
- """
-
- def __init__(self, config, tokenizer):
- """
- Instantiate a Soares et al. classifier.
-
- Args:
- config: global config object
- tokenizer: tokenizer used to create the vocabulary
- """
-
- super().__init__()
-
- self.config = config
- self.tokenizer = tokenizer
-
- if self.config.get("load"):
- bert_config = transformers.BertConfig.from_pretrained(self.config.bert_model)
- bert_config.vocab_size = len(tokenizer)
- self.bert = transformers.BertModel(bert_config)
- else:
- self.bert = transformers.BertModel.from_pretrained(self.config.bert_model)
- self.bert.resize_token_embeddings(len(tokenizer))
-
- if self.config.post_transformer_layer == "linear":
- self.post_transformer = nn.Linear(
- in_features=self.output_size,
- out_features=self.output_size)
- elif self.config.post_transformer_layer == "layer_norm":
- self.post_transformer = torch.nn.LayerNorm(self.output_size)
- else:
- assert(False)
-
- @property
- def output_size(self):
- return self.bert.config.hidden_size * 2
-
- def forward(self, inputs):
- bert_out = self.bert(inputs["text"], attention_mask=inputs["mask"])[0]
- batch_ids = torch.arange(bert_out.shape[0], device=bert_out.device, dtype=torch.int64)
- e1_out = bert_out[batch_ids, inputs["e1_pos"]]
- e2_out = bert_out[batch_ids, inputs["e2_pos"]]
- sentence = torch.cat((e1_out, e2_out), dim=1)
- return self.post_transformer(sentence)
diff --git a/fsre/model/mtb_supervised.py b/fsre/model/mtb_supervised.py
@@ -1,39 +0,0 @@
-import torch
-import torch.nn as nn
-import transformers
-
-from fsre.model.mtb_classifier import Classifier
-
-
-class Model(nn.Module):
- """
- Supervised model from Soares et al.
-
- Correspond to the left subfigure of Figure 2.
- """
-
- def __init__(self, config, tokenizer, relation_dictionary):
- """
- Instantiate a Soares et al. supervised model.
-
- Args:
- config: global config object
- tokenizer: tokenizer used to create the vocabulary
- relation_dictionary: dictionary of all relations
- """
-
- super().__init__()
-
- self.config = config
- self.tokenizer = tokenizer
- self.relation_dictionary = relation_dictionary
-
- self.classifier = Classifier(config, tokenizer)
- self.relation_classifier = nn.Linear(
- in_features=self.classifier.output_size,
- out_features=len(relation_dictionary),
- bias=False)
-
- def forward(self, inputs):
- latent = self.classifier(inputs)
- return self.relation_classifier(latent)
diff --git a/fsre/train.py b/fsre/train.py
@@ -1,353 +0,0 @@
-import sys
-import os
-import math
-import time
-import contextlib
-import multiprocessing
-import signal
-import logging
-
-import tqdm
-import torch
-import transformers
-
-import fsre
-
-logger = logging.getLogger(__name__)
-
-
-class Trainer:
- """
- Train a model.
-
- Config:
- Model: the model class to use for training
- batch_size: the number of samples in the batch of data loaded
- bert_model: the model of transformer to use
- dataset_name: name of the dataset to load
- deterministic: run in deterministic mode
- early_stopping_patience: how many epoch to train after best validation score has been reached
- learning_rate: learning rate
- max_epoch: maximum number of epoch
- no_initial_validation: do not run evaluation on the valid dataset before first epoch
- seed: the seed for the random number generator
- sort_per_shuffle_bucket: the number of sort buckets in a shuffle bucket
- test_output: path to a file where the test predictions will be written
- true_batch_size: the actual number of sample in a batch, the number of sample seen before a backward (must be a multiple of batch_size)
- validation_metric: metric used for early stopping
- """
-
- def __init__(self, config, logdir, state_dicts=None):
- self.config = config
- self.logdir = logdir
- self.state_dicts = state_dicts
-
- def run(self):
- self.info()
- self.log_patch()
- self.initialize_rng()
- self.prepare_dataset()
- self.build_model()
- self.count_parameters()
- self.setup_optimizer()
- self.hook_signals()
- self.train()
-
- def environment_check(self):
- python_version = '.'.join(map(str, sys.version_info[:3]))
- torch_version = torch.__version__
- cuda_available = torch.cuda.is_available()
-
- logger.info(f"python version {python_version}")
- logger.info(f"torch version {torch_version}")
- logger.info(f"cuda available {cuda_available}")
-
- def problem(str):
- return f"\033[1m\033[31m{str}\033[0m"
-
- if sys.version_info < (3, 7):
- python_version = problem(python_version)
- if list(map(int, torch_version.split('.'))) < [1, 3]:
- torch_version = problem(torch_version)
- if not cuda_available:
- cuda_available = problem(cuda_available)
-
- print(f"python version: {python_version}, torch version: {torch_version}, cuda available: {cuda_available}")
-
- def detect_gpus(self):
- count = torch.cuda.device_count()
-
- if count == 0:
- print(f"\033[1m\033[31mNo GPU available\033[0m")
- logger.warning("no GPU available")
- self.device = torch.device("cpu")
- else:
- self.device = torch.device("cuda:0")
-
- for i in range(count):
- gp = torch.cuda.get_device_properties(i)
- print(f"GPU{i}: \033[33m{gp.name}\033[0m (Mem: {gp.total_memory/2**30:.2f}GiB CC: {gp.major}.{gp.minor})")
- logger.info(f"GPU{i} {gp.name} {gp.total_memory} {gp.major}.{gp.minor}")
-
- def info(self):
- if self.logdir is None:
- print(f"logdir is \033[1m\033[31mnot set\033[0m, log messages will be discarded")
- else:
- print(f"logdir is \033[1m\033[33m{self.logdir}\033[0m")
-
- self.environment_check()
- self.detect_gpus()
- print("")
-
- print("\033[1m\033[33mConfiguration\033[0m")
- fsre.utils.print_dict(self.config)
- fsre.utils.log_dict(logging.getLogger("config"), self.config)
- print("")
-
- def log_patch(self):
- version = fsre.utils.get_repo_version()
- logger.info(f"repository_version {version}")
- if version == "release":
- print(f"\033[41mRelease version\033[0m\n")
- elif version.endswith('+'):
- print(f"\033[31mUncommited changes detected, saving patch to logdir.\033[0m\n")
- suffix = ""
- if self.state_dicts:
- suffix = time.strftime("%FT%H:%M:%S")
- fsre.utils.save_patch(self.logdir / f"patch{suffix}")
-
- def initialize_rng(self):
- if self.state_dicts:
- torch.random.set_rng_state(self.state_dicts["torch_rng"])
- assert(("cuda_rng" in self.state_dicts) == torch.cuda.is_available())
- if "cuda_rng" in self.state_dicts:
- torch.cuda.random.set_rng_state_all(self.state_dicts["cuda_rng"])
- else:
- torch.manual_seed(self.config.seed)
-
- if self.config.get("deterministic"):
- torch.backends.cudnn.deterministic = True
- torch.backends.cudnn.benchmark = False
-
- def prepare_dataset(self):
- data_dir = fsre.utils.DATA_PATH / self.config.dataset_name / self.config.bert_model
- self.relation_dictionary = fsre.data.RelationDictionary(path=data_dir / "relations")
- self.tokenizer = transformers.BertTokenizer.from_pretrained(data_dir / "tokenizer")
-
- self.dataset = {}
- self.iterator = {}
-
- for dataset in ["train", "valid", "test"]:
- if dataset == "train":
- kwargs = {"rng": self.state_dicts["train_rng"]} if self.state_dicts else {}
-
- self.dataset[dataset] = fsre.data.RelationExtractionDataset(
- self.config,
- data_dir / f"{dataset}.npy",
- pad=self.tokenizer.pad_token_id,
- evaluation=(dataset != "train"),
- **kwargs)
-
- self.iterator[dataset] = lambda dataset=dataset: torch.utils.data.DataLoader(
- dataset=self.dataset[dataset],
- batch_size=self.config.batch_size,
- num_workers=self.config.sort_per_shuffle_bucket,
- pin_memory=(self.device.type == "cuda"))
-
- def build_model(self):
- self.model = self.config.Model(self.config, self.tokenizer, self.relation_dictionary)
-
- if self.state_dicts:
- self.model.load_state_dict(self.state_dicts["model"])
-
- self.loss = torch.nn.CrossEntropyLoss()
- if self.device.type == "cuda":
- self.model.to(self.device)
- self.loss.to(self.device)
-
- def count_parameters(self):
- total = 0
- for parameter in self.model.parameters():
- total += parameter.shape.numel()
- print(f"\033[33mNumber of parameters: {total:,}\033[0m")
-
- def setup_optimizer(self):
- self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.config.learning_rate)
- if self.state_dicts:
- self.optimizer.load_state_dict(self.state_dicts["optimizer"])
-
- def hook_signals(self):
- self.interrupted = False
-
- def handler(sig, frame):
- if multiprocessing.current_process().name != "MainProcess":
- return
-
- print("\n\033[31mInterrupted, training will stop at the end of this epoch.\n\033[1mNEXT ^C WILL KILL THE PROCESS!\033[0m\n", file=sys.stderr)
- self.interrupted = True
- signal.signal(signal.SIGINT, signal.SIG_DFL)
-
- signal.signal(signal.SIGINT, handler)
-
- def eval_context(self, dataset):
- self.test_output_file = None
- if dataset == "test" and self.config.get("test_output"):
- self.test_output_file = open(self.config.test_output, 'w')
- return self.test_output_file
- return contextlib.nullcontext()
-
- def eval_handle_predictions(self, ids, predictions):
- if self.test_output_file is None:
- return
- ids = ids.tolist()
- predictions = predictions.argmax(1).tolist()
- for id, prediction in zip(ids, predictions):
- prediction = self.relation_dictionary.decode(prediction)
- self.test_output_file.write(f"{id}\t{prediction}\n")
-
- def evaluate(self, dataset):
- loop = tqdm.tqdm(
- iterable=self.iterator[dataset](),
- desc=f"Epoch {self.epoch:2} {dataset:5}",
- unit="samples",
- unit_scale=self.config.batch_size,
- total=math.ceil(len(self.dataset[dataset]) / self.config.batch_size),
- leave=False)
-
- self.model.eval()
- output_prediction_to_file = True
- with torch.no_grad(), self.eval_context(dataset):
- has_target = False
- scorer = fsre.metrics.Metrics(self.relation_dictionary)
- correct_prediction = 0
- for batch in loop:
- batch = {key: value.to(self.device) for key, value in batch.items()}
-
- # Pop the target to ensure it's not used by the model
- if "relation" in batch:
- target = batch.pop("relation")
- has_target = True
- predictions = self.model(batch)
- self.eval_handle_predictions(batch["id"], predictions)
- if has_target:
- scorer.update(predictions, target)
- loop.set_postfix(**scorer.summary, refresh=False)
-
- if has_target:
- print(f"Epoch {self.epoch} {dataset:5} accuracy: {scorer.accuracy*100:8.4f}% Half-directed Macro F1: {scorer.half_directed_macro_f1*100:8.4f}% (P: {scorer.half_directed_macro_precision*100:8.4f}% R: {scorer.half_directed_macro_recall*100:8.4f}%) NLL: {scorer.nll:8.4f}")
- for metric, value in scorer.all.items():
- logger.info(f"epoch {self.epoch} {dataset} {metric} {value}")
- return getattr(scorer, self.config.validation_metric)
- return None
-
- def save(self, path):
- state_dicts = {
- "logdir": self.logdir,
- "config": self.config,
- "model": self.model.state_dict(),
- "optimizer": self.optimizer.state_dict(),
- "train_rng": self.dataset["train"].state_dict(),
- "torch_rng": torch.random.get_rng_state(),
- "epoch": self.epoch,
- "best_epoch": self.best_epoch,
- "best_eval": self.best_eval,
- }
- if torch.cuda.is_available():
- state_dicts["cuda_rng"] = torch.cuda.random.get_rng_state_all()
- torch.save(state_dicts, path)
-
- def train(self):
- self.epoch = 0
- self.best_epoch = 0
- self.best_eval = -float("inf")
-
- if self.state_dicts:
- self.epoch = self.state_dicts["epoch"]
- self.best_epoch = self.state_dicts["best_epoch"]
- self.best_eval = self.state_dicts["best_eval"]
-
- best_path = self.logdir / "best"
- batch_per_epoch = int(math.ceil(len(self.dataset["train"]) / self.config.batch_size))
- assert(self.config.true_batch_size % self.config.batch_size == 0)
- data_batch_per_true_batch = self.config.true_batch_size // self.config.batch_size
-
- if not self.config.get("no_initial_validation"):
- self.best_eval = self.evaluate("valid")
-
- for self.epoch in range(self.epoch+1, self.config.max_epoch+1):
- if self.interrupted:
- break
-
- loop = tqdm.tqdm(
- iterable=self.iterator["train"](),
- desc=f"Epoch {self.epoch} train",
- unit="samples",
- unit_scale=self.config.batch_size,
- total=math.ceil(len(self.dataset["train"]) / self.config.batch_size),
- leave=False)
-
- self.model.train()
- self.optimizer.zero_grad()
- total_loss = 0
- total_sample = 0
- for batch_id, batch in enumerate(loop):
- batch = {key: value.to(self.device) for key, value in batch.items()}
-
- # Pop the target to ensure it's not used by the model
- target = batch.pop("relation")
-
- output = self.model(batch)
- loss = self.loss(output, target)
- loss.backward()
-
- total_loss += loss.item()
- total_sample += target.shape[0]
- loop.set_postfix(loss=f"{total_loss / total_sample:.2f}", refresh=False)
-
- if batch_id % data_batch_per_true_batch == data_batch_per_true_batch - 1:
- self.optimizer.step()
- self.optimizer.zero_grad()
-
- print(f"Epoch {self.epoch} train mean loss: {total_loss / total_sample:8.4f}")
- logger.info(f"epoch {self.epoch} train mean_loss {total_loss / total_sample}")
-
- self.save(self.logdir / "checkpoint.new")
- os.rename(self.logdir / "checkpoint.new", self.logdir / "checkpoint")
-
- candidate = self.evaluate("valid")
- if candidate > self.best_eval:
- self.best_epoch = self.epoch
- self.best_eval = candidate
- self.save(best_path)
- logger.info(f"Model saved to {best_path}")
- elif self.epoch - self.best_epoch > self.config.get("early_stopping_patience", self.config.max_epoch):
- break
-
- if self.best_eval != -float("inf") and best_path.exists():
- print(f"Loading best model from {best_path}…", end="", flush=True)
- self.model.load_state_dict(torch.load(best_path)["model"])
- print(" done")
- logger.info(f"{best_path} loaded for evaluation on test set")
-
- self.evaluate("test")
-
-
-if __name__ == "__main__":
- fsre.utils.fix_transformers_logging_handler()
- config = fsre.utils.parse_args()
-
- state_dicts = None
- if config.get("load"):
- state_dicts = torch.load(config.load)
- logdir = state_dicts["logdir"]
- if config.get("reuse_config"):
- config = state_dicts["config"]
- else:
- logdir = fsre.utils.logdir_name("FSRE")
- assert(not logdir.exists())
- logdir.mkdir()
-
- logfile = logdir / "log"
- logging.basicConfig(format="%(asctime)s\t%(levelname)s:%(name)s:%(message)s", filename=logfile, filemode='a', level=logging.INFO)
-
- Trainer(config, logdir, state_dicts).run()
diff --git a/fsre/utils.py b/fsre/utils.py
@@ -1,143 +0,0 @@
-import os
-import sys
-import types
-import importlib
-import subprocess
-import logging
-import time
-import hashlib
-import pathlib
-
-
-def import_environment(name, cast=str):
- try:
- globals()[name] = cast(os.environ[name])
- except KeyError:
- print(f"ERROR: {name} environment variable is not set.",
- file=sys.stderr)
- sys.exit(1)
-
-
-import_environment("DATA_PATH", pathlib.Path)
-import_environment("LOG_PATH", pathlib.Path)
-
-
-class dotdict(dict):
- def __getattr__(self, name):
- if name not in self:
- raise AttributeError(f"Config key {name} not found")
- return dotdict(self[name]) if type(self[name]) is dict else self[name]
- __setattr__ = dict.__setitem__
- __delattr__ = dict.__delitem__
-
-
-def eval_arg(config, arg):
- if '=' in arg:
- key, value = arg.split('=', maxsplit=1)
- value = eval(value, config)
- else:
- key, value = arg, True
- path = key.split('.')
- for d in path[:-1]:
- config = config[d]
- config[path[-1]] = value
- config.pop("__builtins__", None)
-
-
-def import_arg(config, arg):
- if arg.endswith(".py"):
- arg = arg[:-3].replace('/', '.')
- module = importlib.import_module(arg)
- for key, value in vars(module).items():
- if key not in module.__builtins__ \
- and not key.startswith("__") \
- and not isinstance(value, types.ModuleType):
- config[key] = value
-
-
-def parse_args():
- config = {}
- for arg in sys.argv[1:]:
- if arg.startswith("--"):
- eval_arg(config, arg[2:])
- else:
- import_arg(config, arg)
- return dotdict(config)
-
-
-def map_dict(output, input, depth=0):
- for key, value in input.items():
- indent = '\t'*depth
- output(f"{indent}{key}:")
- if isinstance(value, dict):
- output('\n')
- map_dict(output, value, depth+1)
- else:
- output(f" {value}\n")
-
-
-def print_dict(input):
- map_dict(lambda x: print(x, end=""), input)
-
-
-def log_dict(logger, input):
- class log:
- buf = ""
-
- def __call__(self, x):
- self.buf += x
- if self.buf.endswith('\n'):
- logger.info(self.buf[:-1])
- self.buf = ""
- map_dict(log(), input)
-
-
-def get_repo_version():
- repo_dir = pathlib.Path(__file__).parents[0]
- result = subprocess.run(["hg", "id", "-i"],
- stdout=subprocess.PIPE,
- stderr=subprocess.DEVNULL,
- encoding="utf-8",
- cwd=repo_dir)
-
- if result.returncode != 0:
- return "release"
- return result.stdout.rstrip()
-
-
-def experiment_name(name):
- args = ' '.join(sys.argv[1:])
- version = get_repo_version()
- stime = time.strftime("%FT%H:%M:%S")
- return f"{name} {version} {args} {stime}"
-
-
-def logdir_name(name):
- subdir = experiment_name(name).replace('/', '_')
- if len(subdir) > 255:
- sha1 = hashlib.sha1(subdir.encode("utf-8")).digest().hex()[:16]
- subdir = subdir[:255-17] + ' ' + sha1
- return LOG_PATH / subdir
-
-
-def fix_transformers_logging_handler():
- """
- The transformers package from huggingface install its own logger on import,
- I don't want it.
- """
- logger = logging.getLogger()
- for handler in logger.handlers:
- logger.removeHandler(handler)
-
-
-def save_patch(outpath):
- repo_dir = pathlib.Path(__file__).parents[0]
-
- with open(outpath, "w") as outfile:
- result = subprocess.run(["hg", "diff"],
- stdout=outfile,
- stderr=subprocess.DEVNULL,
- encoding="utf-8",
- cwd=repo_dir)
-
- assert(result.returncode == 0)
diff --git a/gbure/__init__.py b/gbure/__init__.py
diff --git a/gbure/config/contrastive_alignment.py b/gbure/config/contrastive_alignment.py
@@ -0,0 +1,53 @@
+from gbure.model.contrastive_alignment import Model
+from gbure.model.fewshot import Model as EvalModel
+from torch.optim import Adam as Optimizer
+from torch.optim.lr_scheduler import LinearLR as Scheduler
+
+
+dataset_name = "T-REx"
+graph_name = "T-REx"
+unsupervised = "triplet"
+
+eval_dataset_name = "FewRel"
+valid_name = "7def1330ba9527d6"
+shot = 1
+way = 5
+
+margin = 1
+neighborhood_size = 3
+filter_empty_neighborhood = True
+sinkhorn_blur = 0.05
+
+# Necessary to make a distance
+linguistic_similarity = "euclidean"
+undefined_poison_whole_meta = True
+
+# From section 4.3
+blank_probability = 0.7
+
+# From section 5
+transformer_model = "bert-base-cased"
+max_epoch = 10
+sample_per_epoch = 100000
+learning_rate = 3e-5
+accumulated_batch_size = 256
+clip_gradient = 1
+
+# Guessed
+post_transformer_layer = "linear" # Maybe we should change this depending on the subsequent task?
+max_sentence_length = 100 # Maybe should be 40 (from footnote 2, guessed from ACL slides)
+language_model_weight = 0
+edge_sampling = "uniform-inverse degree"
+
+# From BERT
+mlm_probability = 0.15
+mlm_masked_probability = 0.8
+mlm_random_probability = 0.1
+
+# Implementation details
+seed = 0
+amp = True
+initial_grad_scale = 1
+batch_size = 2
+eval_batch_size = 1
+workers = 2
diff --git a/gbure/config/gcn_mtb.py b/gbure/config/gcn_mtb.py
@@ -0,0 +1,63 @@
+from gbure.model.matching_the_blanks import Model
+from gbure.model.fewshot import Model as EvalModel
+from torch.optim import Adam as Optimizer
+from torch.optim.lr_scheduler import LinearLR as Scheduler
+
+
+dataset_name = "T-REx"
+graph_name = "T-REx"
+unsupervised = "mtb"
+
+eval_dataset_name = "FewRel"
+valid_name = "7def1330ba9527d6"
+shot = 1
+way = 5
+
+# From Section 4.1
+linguistic_similarity = "dot"
+undefined_poison_whole_meta = True
+
+# Observed to be better
+latent_metric_scale = "standard"
+latent_dot_mean = 1067.65
+latent_dot_std = 111.17
+
+# GCN
+neighborhood_size = 3
+gcn_aggregator = "mean"
+
+# From Section 4.3
+blank_probability = 0.7
+
+# From Section 5
+transformer_model = "bert-base-cased"
+sample_per_epoch = 100000
+learning_rate = 3e-5
+accumulated_batch_size = 2048
+
+# Stated to be 10 in Section 5, but found 5 was better on T-REx dataset.
+max_epoch = 5
+
+# From BERT
+mlm_probability = 0.15
+mlm_masked_probability = 0.8
+mlm_random_probability = 0.1
+
+# Guessed
+# post_transformer_layer might need to be changed depending on the subsequent task
+# the "layer_norm" gives results within expectations for non-finetuned few-shot.
+max_sentence_length = 100 # Maybe should be 40 (from footnote 2, guessed from ACL slides)
+language_model_weight = 1
+edge_sampling = "uniform-inverse degree"
+clip_gradient = 1
+
+strong_negative_probability = 0.5
+weak_negative_probability = 0.0
+
+# Implementation details
+seed = 0
+amp = True
+initial_grad_scale = 1
+batch_size = 2
+eval_batch_size = 1
+workers = 2
diff --git a/gbure/config/mtb.py b/gbure/config/mtb.py
@@ -0,0 +1,57 @@
+from gbure.model.matching_the_blanks import Model
+from gbure.model.fewshot import Model as EvalModel
+from torch.optim import Adam as Optimizer
+from torch.optim.lr_scheduler import LinearLR as Scheduler
+
+
+dataset_name = "T-REx"
+unsupervised = "mtb"
+
+eval_dataset_name = "FewRel"
+valid_name = "7def1330ba9527d6"
+shot = 1
+way = 5
+
+# From Section 4.1
+linguistic_similarity = "dot"
+
+# Observed to be better
+latent_metric_scale = "standard"
+latent_dot_mean = 1067.65
+latent_dot_std = 111.17
+
+# From Section 4.3
+blank_probability = 0.7
+
+# From Section 5
+transformer_model = "bert-base-cased"
+sample_per_epoch = 100000
+learning_rate = 3e-5
+accumulated_batch_size = 2048
+
+# Stated to be 10 in Section 5, but found 5 was better on T-REx dataset.
+max_epoch = 5
+
+# From BERT
+mlm_probability = 0.15
+mlm_masked_probability = 0.8
+mlm_random_probability = 0.1
+
+# Guessed
+# post_transformer_layer might need to be changed depending on the subsequent task
+# the "layer_norm" gives results within expectations for non-finetuned few-shot.
+max_sentence_length = 100 # Maybe should be 40 (from footnote 2, guessed from ACL slides)
+language_model_weight = 1
+edge_sampling = "uniform-inverse degree"
+clip_gradient = 1
+
+strong_negative_probability = 0.5
+weak_negative_probability = 0.0
+
+# Implementation details
+seed = 0
+amp = True
+initial_grad_scale = 1
+batch_size = 8
+eval_batch_size = 2
+workers = 8
diff --git a/gbure/config/nonparametric.py b/gbure/config/nonparametric.py
@@ -0,0 +1,32 @@
+from gbure.model.fewshot import Model
+from torch.optim import SGD as Optimizer
+
+
+max_epoch = 0
+dataset_name = "FewRel"
+graph_name = "T-REx"
+valid_name = "7def1330ba9527d6"
+shot = 1
+way = 5
+
+transformer_model = "bert-base-cased"
+post_transformer_layer = "none"
+learning_rate = 0
+accumulated_batch_size = 256
+
+neighborhood_size = 3
+linguistic_similarity = "dot"
+topological_weight = 0.2
+linguistic_weight = 1
+undefined_poison_whole_meta = True
+
+validation_metric = "accuracy"
+latent_metric_scale = "standard"
+latent_dot_mean = 1067.65
+latent_dot_std = 111.17
+
+# Implementation details
+seed = 0
+amp = True
+batch_size = 2
+workers = 2
diff --git a/gbure/config/soares_fewrel.py b/gbure/config/soares_fewrel.py
@@ -0,0 +1,30 @@
+from gbure.model.fewshot import Model
+from torch.optim import SGD as Optimizer
+
+
+dataset_name = "FewRel"
+shot = 1
+way = 5
+linguistic_similarity = "dot"
+
+# From Table 1
+transformer_model = "bert-base-cased"
+post_transformer_layer = "layer_norm"
+max_epoch = 10
+learning_rate = 1e-4
+accumulated_batch_size = 256
+
+# Guessed
+validation_metric = "accuracy"
+early_stopping_patience = 2
+latent_metric_scale = "standard"
+latent_dot_mean = 1067.65
+latent_dot_std = 111.17
+clip_gradient = 1
+
+# Implementation details
+seed = 0
+amp = True
+initial_grad_scale = 1
+batch_size = 2
+workers = 8
diff --git a/gbure/config/soares_kbp37.py b/gbure/config/soares_kbp37.py
@@ -0,0 +1,23 @@
+from gbure.model.supervised import Model
+from torch.optim import Adam as Optimizer
+
+
+dataset_name = "KBP37"
+linguistic_similarity = "dot"
+
+# From Table 1
+transformer_model = "bert-base-cased"
+post_transformer_layer = "linear"
+max_epoch = 10
+learning_rate = 3e-5
+accumulated_batch_size = 64
+
+# Guessed
+validation_metric = "half_directed_macro_f1"
+early_stopping_patience = 2
+clip_gradient = 1
+
+# Implementation details
+seed = 0
+batch_size = 2
+workers = 8
diff --git a/gbure/config/soares_semeval.py b/gbure/config/soares_semeval.py
@@ -0,0 +1,23 @@
+from gbure.model.supervised import Model
+from torch.optim import Adam as Optimizer
+
+
+dataset_name = "SemEval 2010 Task 8"
+linguistic_similarity = "dot"
+
+# From Table 1
+transformer_model = "bert-base-cased"
+post_transformer_layer = "layer_norm"
+max_epoch = 10
+learning_rate = 3e-5
+accumulated_batch_size = 64
+
+# Guessed
+validation_metric = "half_directed_macro_f1"
+early_stopping_patience = 2
+clip_gradient = 1
+
+# Implementation details
+seed = 0
+batch_size = 8
+workers = 8
diff --git a/gbure/data/__init__.py b/gbure/data/__init__.py
diff --git a/gbure/data/batcher.py b/gbure/data/batcher.py
@@ -0,0 +1,166 @@
+from typing import Any, Dict, List, Tuple
+import collections
+
+import torch
+
+
+# Must be kept prefix-sorted!
+# (prefix, list depth)
+FEATURE_PREFIXES: List[Tuple[str, int]] = [
+ ("query_e1_neighborhood_", 1),
+ ("query_e2_neighborhood_", 1),
+ ("candidates_e1_neighborhood_", 3),
+ ("candidates_e2_neighborhood_", 3),
+ ("first_e1_neighborhood_", 1),
+ ("first_e2_neighborhood_", 1),
+ ("second_e1_neighborhood_", 1),
+ ("second_e2_neighborhood_", 1),
+ ("third_e1_neighborhood_", 1),
+ ("third_e2_neighborhood_", 1),
+ ("query_", 0),
+ ("candidates_", 2),
+ ("first_", 0),
+ ("second_", 0),
+ ("third_", 0),
+ ("", 0)]
+
+
+class Batcher:
+ """
+ Batch a group of sample together.
+
+ Two new features are derived from the "text": its length and a mask.
+ """
+ def __init__(self, pad_value: int) -> None:
+ """ Initialize a Batcher, using the provided value to pad text. """
+ self.pad_value: int = pad_value
+
+ def add_length_field(self, batch: Dict[str, Any], prefix: str, depth: int) -> None:
+ """ Add the length field for the given prefix. """
+ text: List[Any] = batch[f"{prefix}text"]
+ batch_size: int = len(text)
+
+ if depth == 0:
+ # text is a list of sentences
+ lengths: torch.Tensor = torch.empty((batch_size,), dtype=torch.int64)
+ for b, sentence in enumerate(text):
+ lengths[b] = sentence.shape[0]
+ elif depth == 1:
+ # text is a list of list of sentences (each sample contains several candidates)
+ size: int = len(text[0])
+ lengths: torch.Tensor = torch.empty((batch_size, size), dtype=torch.int64)
+ for b, sample in enumerate(text):
+ for i, sentence in enumerate(sample):
+ lengths[b, i] = sentence.shape[0]
+ elif depth == 2:
+ # text is a list of list of list of sentences (each sample contains several candidates)
+ way: int = len(text[0])
+ shot: int = len(text[0][0])
+ lengths: torch.Tensor = torch.empty((batch_size, way, shot), dtype=torch.int64)
+ for b, sample in enumerate(text):
+ for w, candidates in enumerate(sample):
+ for s, candidate in enumerate(candidates):
+ lengths[b, w, s] = candidate.shape[0]
+ elif depth == 3:
+ # text is a list of list of list of list of sentences (each sample contains several candidates' neighborhoods)
+ way: int = len(text[0])
+ shot: int = len(text[0][0])
+ size: int = len(text[0][0][0])
+ lengths: torch.Tensor = torch.empty((batch_size, way, shot, size), dtype=torch.int64)
+ for b, sample in enumerate(text):
+ for w, candidates in enumerate(sample):
+ for s, candidate in enumerate(candidates):
+ for n, neighbor in enumerate(candidate):
+ lengths[b, w, s, n] = neighbor.shape[0]
+
+ batch[f"{prefix}length"] = lengths
+
+ def process_text(self, batch: Dict[str, Any], prefix: str, depth: int, key: str) -> None:
+ """ Build mask and text batch by padding sentences. """
+ in_text: List[Any] = batch[f"{prefix}{key}"]
+ if isinstance(batch[f"{prefix}length"], list):
+ self.add_length_field(batch, prefix, depth)
+ max_seq_len: int = max(batch[f"{prefix}length"].max(), 1)
+ batch_size: int = len(in_text)
+
+ if depth == 0:
+ # text is a list of sentences
+ text: torch.Tensor = torch.empty((batch_size, max_seq_len), dtype=torch.int64)
+ mask: torch.Tensor = torch.empty((batch_size, max_seq_len), dtype=torch.bool)
+ for b, sentence in enumerate(in_text):
+ text[b, :sentence.shape[0]] = sentence
+ text[b, sentence.shape[0]:] = self.pad_value
+ mask[b, :sentence.shape[0]] = 1
+ mask[b, sentence.shape[0]:] = 0
+ elif depth == 1:
+ # text is a list of list of sentences (each sample contains several candidates)
+ # In this case, we are not sure the tensor is full (some neighborhoods might be of different sizes or even empty)
+ size: int = len(in_text[0])
+ text: torch.Tensor = torch.empty((batch_size, size, max_seq_len), dtype=torch.int64)
+ mask: torch.Tensor = torch.zeros((batch_size, size, max_seq_len), dtype=torch.bool)
+ for b, samples in enumerate(in_text):
+ for i, sentence in enumerate(samples):
+ text[b, i, :sentence.shape[0]] = sentence
+ text[b, i, sentence.shape[0]:] = self.pad_value
+ mask[b, i, :sentence.shape[0]] = 1
+ elif depth == 2:
+ # text is a list of list of list of sentences (each sample contains several candidates)
+ # In this case, we are sure the tensor is full (all n way have the save k shots)
+ way: int = len(in_text[0])
+ shot: int = len(in_text[0][0])
+ text: torch.Tensor = torch.empty((batch_size, way, shot, max_seq_len), dtype=torch.int64)
+ mask: torch.Tensor = torch.empty((batch_size, way, shot, max_seq_len), dtype=torch.bool)
+ for b, samples in enumerate(in_text):
+ for w, candidates in enumerate(samples):
+ for s, candidate in enumerate(candidates):
+ text[b, w, s, :candidate.shape[0]] = candidate
+ text[b, w, s, candidate.shape[0]:] = self.pad_value
+ mask[b, w, s, :candidate.shape[0]] = 1
+ mask[b, w, s, candidate.shape[0]:] = 0
+ elif depth == 3:
+ # text is a list of list of list of list of sentences (each sample contains several candidates' neighborhoods)
+ # In this case, we are not sure the tensor is full (some neighborhoods might be of different sizes or even empty)
+ way: int = len(in_text[0])
+ shot: int = len(in_text[0][0])
+ size: int = len(in_text[0][0][0])
+ text: torch.Tensor = torch.empty((batch_size, way, shot, size, max_seq_len), dtype=torch.int64)
+ mask: torch.Tensor = torch.empty((batch_size, way, shot, size, max_seq_len), dtype=torch.bool)
+ for b, samples in enumerate(in_text):
+ for w, candidates in enumerate(samples):
+ for s, candidate in enumerate(candidates):
+ for n, neighbor in enumerate(candidate):
+ text[b, w, s, n, :neighbor.shape[0]] = neighbor
+ text[b, w, s, n, neighbor.shape[0]:] = self.pad_value
+ mask[b, w, s, n, :neighbor.shape[0]] = 1
+ mask[b, w, s, n, neighbor.shape[0]:] = 0
+
+ batch[f"{prefix}{key}"] = text
+ if f"{prefix}mask" not in batch:
+ batch[f"{prefix}mask"] = mask
+
+ def process_int_feature(self, batch: Dict[str, Any], prefix: str, feature: str) -> None:
+ """ Transform a list of integer into a torch LongTensor. """
+ # TODO handle neighborhoods of different sizes
+ if isinstance(batch[f"{prefix}{feature}"][0], torch.Tensor):
+ batch[f"{prefix}{feature}"] = torch.stack(batch[f"{prefix}{feature}"])
+ else:
+ batch[f"{prefix}{feature}"] = torch.tensor(batch[f"{prefix}{feature}"], dtype=torch.int64)
+
+ def __call__(self, samples: List[Dict[str, Any]]) -> Dict[str, Any]:
+ """ Batch the provided samples """
+ batch = collections.defaultdict(list)
+ for sample in samples:
+ for key, value in sample.items():
+ batch[key].append(value)
+
+ for key in list(batch.keys()):
+ for prefix, depth in FEATURE_PREFIXES:
+ if key.startswith(prefix):
+ break
+ feature: str = key[len(prefix):]
+ if feature in ["text", "mlm_input", "mlm_target"]:
+ self.process_text(batch, prefix, depth, feature)
+ if feature in ["relation", "entity_positions", "entity_identifiers", "entity_degrees", "edge_identifier", "polarity", "answer", "eid"]:
+ self.process_int_feature(batch, prefix, feature)
+
+ return batch
diff --git a/gbure/data/dataset.py b/gbure/data/dataset.py
@@ -0,0 +1,668 @@
+from typing import Any, Dict, Iterator, List, Optional, Set, Tuple, Union
+import collections
+import pathlib
+import random
+
+import torch
+import transformers
+
+import gbure.data.dictionary
+from gbure.data.graph import Graph
+import gbure.utils
+
+
+class SupervisedDataset(torch.utils.data.Dataset):
+ """
+ Read a preprocessed supervised relation extraction dataset.
+
+ A preprocessed dataset can be created from the gbure.data.prepare_* scripts.
+ """
+ shuffleable: bool = True
+
+ def __init__(self, config: gbure.utils.dotdict, path: pathlib.Path, tokenizer: transformers.PreTrainedTokenizer, evaluation: bool, rng: Optional[random.Random] = None, data: Optional[List[Tuple[torch.Tensor, int, int, int]]] = None) -> None:
+ """ Initialize a supervised dataset and load the data in RAM. """
+ super().__init__()
+
+ self.config: gbure.utils.dotdict = config
+ self.path: pathlib.Path = path
+ self.tokenizer: transformers.PreTrainedTokenizer = tokenizer
+ self.evaluation: bool = evaluation
+ if data is None:
+ self.load()
+ else:
+ self.data = data
+
+ def load(self) -> None:
+ """ Load the dataset into RAM. """
+ dstype: str
+ self.data: List[Tuple[torch.Tensor, int, int, int]]
+ dstype, self.data = torch.load(self.path)
+ assert(dstype == "supervised")
+
+ def __len__(self) -> int:
+ """ Get the number of samples in the dataset. """
+ return len(self.data)
+
+ def __getitem__(self, index: int) -> Dict[str, Any]:
+ """ Get the sample at the given index. """
+ sample: Dict[str, Any] = {}
+ sample["text"] = self.data[index][0]
+ sample["entity_positions"] = torch.tensor(self.data[index][1:3], dtype=torch.int64)
+ sample["relation"] = self.data[index][3]
+ return sample
+
+
+class SampledFewShotDataset(torch.utils.data.Dataset):
+ """
+ Read a preprocessed few shot relation extraction dataset.npy file containing samples.
+
+ A preprocessed dataset can be created from the gbure.data.prepare_* scripts.
+ """
+ shuffleable: bool = True
+
+ def __init__(self, config: gbure.utils.dotdict, path: pathlib.Path, tokenizer: transformers.PreTrainedTokenizer, evaluation: bool, rng: Optional[random.Random] = None, data: Optional[List[Tuple[torch.Tensor, int, int, int, int, List[List[torch.Tensor]], torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, int]]] = None) -> None:
+ """ Initialize a few shot dataset and load the samples in RAM. """
+ super().__init__()
+
+ self.config: gbure.utils.dotdict = config
+ self.path: pathlib.Path = path
+ self.tokenizer: transformers.PreTrainedTokenizer = tokenizer
+ self.evaluation: bool = evaluation
+ if data is None:
+ self.load()
+ else:
+ self.data = data
+
+ def load(self) -> None:
+ """ Load the dataset into RAM. """
+ dstype: str
+ self.data: List[Tuple[torch.Tensor, int, int, int, int, List[List[torch.Tensor]], torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, int]]
+ dstype, self.data = torch.load(self.path)
+ assert(dstype == "sampled fewshot")
+
+ def __len__(self) -> int:
+ """ Get the number of samples in the dataset. """
+ return len(self.data)
+
+ def __getitem__(self, index: int) -> Dict[str, Any]:
+ """ Get the sample at the given index. """
+ sample: Dict[str, Any] = {}
+ sample["query_text"] = self.data[index][0]
+ sample["query_entity_positions"] = torch.tensor(self.data[index][1:3], dtype=torch.int64)
+ sample["query_entity_identifiers"] = torch.tensor(self.data[index][3:5], dtype=torch.int64)
+ sample["candidates_text"] = self.data[index][5]
+ sample["candidates_entity_positions"] = torch.stack(self.data[index][6:8], dim=2)
+ sample["candidates_entity_identifiers"] = torch.stack(self.data[index][8:10], dim=2)
+ sample["answer"] = self.data[index][10]
+ return sample
+
+
+class FewShotDataset(torch.utils.data.IterableDataset):
+ """
+ Read a preprocessed few shot relation extraction dataset.npy file.
+
+ A preprocessed dataset can be created from the gbure.data.prepare_*
+ modules.
+
+ Config:
+ seed: the seed for the random number generator
+ shot: the number of candidates per relation
+ way: the number of relation classes used for candidates
+ """
+ shuffleable: bool = False # FIXME ?
+
+ def __init__(self, config: gbure.utils.dotdict, path: pathlib.Path, tokenizer: transformers.PreTrainedTokenizer, evaluation: bool, rng: Optional[random.Random] = None, data: Optional[List[List[Tuple[torch.Tensor, int, int, int, int, int]]]] = None) -> None:
+ """ Initialize a few shot dataset and load the data in RAM. """
+ super().__init__()
+
+ self.config: gbure.utils.dotdict = config
+ self.path: pathlib.Path = path
+ self.tokenizer: transformers.PreTrainedTokenizer = tokenizer
+ self.evaluation: bool = evaluation
+
+ if data is None:
+ self.load()
+ else:
+ self.data = data
+ self.num_relations: int = len(self.data)
+ self.num_samples_per_relation: int = len(self.data[0])
+
+ def init_seed(self, worker_id: Optional[int] = None) -> None:
+ """ Initialize the RNG. """
+ if not self.evaluation:
+ seed: int = self.config.seed
+ worker_info = torch.utils.data.get_worker_info()
+ seed += worker_id if worker_id is not None else (worker_info.id if worker_info is not None else 0)
+ rng = random.Random(seed)
+ self.rng = rng
+
+ def load(self) -> None:
+ """ Load the dataset into RAM. """
+ dstype: str
+ self.data: List[List[Tuple[torch.Tensor, int, int, int, int, int]]]
+ dstype, self.data = torch.load(self.path)
+ assert(dstype == "fewshot")
+
+ def __len__(self):
+ """ Get the number of samples in the dataset. """
+ return self.num_relations * self.num_samples_per_relation * self.config.get("meta_per_sample", 1)
+
+ def get_rng(self, relation: int, sentence: int) -> random.Random:
+ """ Get the random number generator for the given query. """
+ if self.evaluation:
+ return random.Random(self.config.seed * len(self) + relation * self.num_samples_per_relation + sentence)
+ else:
+ return self.rng
+
+ @staticmethod
+ def sample_exclude(rng: random.Random, population: int, exclude: int, size: int) -> List[int]:
+ """ Chooses size unique random elements from [0, population)\\{exclude}. """
+ samples: List[int] = rng.sample(range(population-1), size)
+ return [sample + (1 if sample >= exclude else 0) for sample in samples]
+
+ def sample_meta(self, query_relation: int, query_sid: int) -> Dict[str, Any]:
+ """ Build a fewshot sample from the given query. """
+ rng: random.Random = self.get_rng(query_relation, query_sid)
+
+ # positives
+ candidates: List[List[Tuple[int, int]]] = [[(query_relation, sid) for sid in self.sample_exclude(rng, self.num_samples_per_relation, query_sid, self.config.shot)]]
+ # negatives
+ for negative_relation in self.sample_exclude(rng, self.num_relations, query_relation, self.config.way-1):
+ candidates.append([(negative_relation, sid) for sid in rng.sample(range(self.num_samples_per_relation), self.config.shot)])
+
+ order: List[int] = list(range(self.config.way))
+ rng.shuffle(order)
+ candidates = [candidates[i] for i in order]
+ answer = order.index(0)
+
+ meta: Dict[str, Any] = {}
+ meta[f"query_text"] = self.data[query_relation][query_sid][0]
+ meta[f"query_entity_positions"] = torch.tensor(self.data[query_relation][query_sid][1:3], dtype=torch.int64)
+ meta[f"query_relation"] = self.data[query_relation][query_sid][3]
+ meta[f"query_entity_identifiers"] = torch.tensor(self.data[query_relation][query_sid][4:6], dtype=torch.int64)
+ meta[f"candidates_text"] = [[self.data[shot_relation][shot_sid][0] for shot_relation, shot_sid in way] for way in candidates]
+ meta[f"candidates_entity_positions"] = torch.tensor([[self.data[shot_relation][shot_sid][1:3] for shot_relation, shot_sid in way] for way in candidates], dtype=torch.int64)
+ meta[f"candidates_relation"] = torch.tensor([[self.data[shot_relation][shot_sid][3] for shot_relation, shot_sid in way] for way in candidates], dtype=torch.int64)
+ meta[f"candidates_entity_identifiers"] = torch.tensor([[self.data[shot_relation][shot_sid][4:6] for shot_relation, shot_sid in way] for way in candidates], dtype=torch.int64)
+ meta["answer"] = answer
+ return meta
+
+ def __iter__(self) -> Iterator[Dict[str, Any]]:
+ """ Generate samples from the dataset. """
+ self.order: List[Tuple[int, int]] = [(relation, sid) for relation in range(self.num_relations) for sid in range(self.num_samples_per_relation)]
+ if not self.evaluation:
+ self.rng.shuffle(self.order)
+
+ worker_info = torch.utils.data.get_worker_info()
+ if worker_info is None:
+ worker_modulo: int = 1
+ worker_residue: int = 0
+ else:
+ worker_modulo: int = worker_info.num_workers
+ worker_residue: int = worker_info.id
+
+ mps = self.config.get("meta_per_sample", 1)
+ for index, (relation, sid) in enumerate(self.order):
+ for j in range(mps):
+ if (index*mps+j) % worker_modulo == worker_residue:
+ yield self.sample_meta(relation, sid)
+
+
+class UnsupervisedDataset(torch.utils.data.IterableDataset):
+ """
+ Read a preprocessed unsupervised relation extraction dataset.
+
+ A preprocessed dataset can be created from the gbure.data.prepare_* scripts.
+
+ Config:
+ blank_probability: the probability to replace an entity with <blank/>
+ edge_sampling: the sampling strategy to avoid (or not) popular entities
+ sample_per_epoch: the number of sample in an epoch
+ seed: the seed for the random number generator
+ """
+ shuffleable: bool = False
+
+ def __init__(self, config: gbure.utils.dotdict, path: Optional[pathlib.Path], tokenizer: transformers.PreTrainedTokenizer, evaluation: bool, rng: Optional[random.Random] = None) -> None:
+ """ Initialize a supervised dataset and load the data in RAM. """
+ super().__init__()
+
+ self.config: gbure.utils.dotdict = config
+ self.path: Optional[pathlib.Path] = path
+ self.tokenizer: transformers.PreTrainedTokenizer = tokenizer
+ self.evaluation: bool = evaluation
+ self.load()
+ self.init_seed()
+
+ def init_seed(self, worker_id: Optional[int] = None) -> None:
+ """ Initialize the RNG. """
+ seed: int = self.config.seed
+ worker_info = torch.utils.data.get_worker_info()
+ seed += worker_id if worker_id is not None else (worker_info.id if worker_info is not None else 0)
+ rng = random.Random(seed)
+ self.rng = rng
+
+ def load(self) -> None:
+ """ Load the dataset into RAM. """
+ if self.path is not None:
+ self.graph = Graph(path=self.path)
+ if self.config.get("share_memory"):
+ self.graph.share_memory()
+
+ def __len__(self) -> int:
+ """ Get the number of samples in the dataset. """
+ return self.config.sample_per_epoch
+
+ def filter_edge(self, eid: int) -> bool:
+ """ Filter edges according to the length of the corresponding sentence and the size of its neighborhoods. """
+ edge: torch.Tensor = self.graph.edges[eid]
+ if self.graph.sentences[edge[2]].shape[0] > self.config.max_sentence_length:
+ return False
+ if self.config.get("filter_empty_neighborhood") and (self.graph.degree(edge[0]) <= 1 or self.graph.degree(edge[1]) <= 1):
+ return False
+ return True
+
+ def sample_main(self) -> int:
+ """ Sample the main edge, from which positive and negative edges can be selected. """
+ # From Soares et al.
+ # "To prevent a large bias towards relation statements that involve popular entities, we limit the number of relation statements that contain the same entity by randomly sampling a constant number of relation statements that contain any given entity."
+ # It's hard to guess what was exactly done, so we propose several sampling strategies.
+ while True:
+ if self.config.edge_sampling == "uniform-uniform":
+ vid: int = self.rng.randint(0, self.graph.order-1)
+ reid: int = self.rng.randint(0, self.graph.degree(vid)-1)
+ eid: int = self.graph.adj[vid][reid, 1]
+ elif self.config.edge_sampling == "uniform-inverse degree":
+ vid: int = self.rng.randint(0, self.graph.order-1)
+
+ v2_candidates: torch.Tensor = torch.zeros(self.graph.degree(vid))
+ for i, edge in enumerate(self.graph.adj[vid]):
+ v2_candidates[i] = self.graph.degree(edge[0])
+ v2_candidates /= torch.nn.functional.normalize(v2_candidates, p=1, dim=0)
+
+ # FIXME slow, double check worker asynchronicity
+ reid: int = torch.multinomial(v2_candidates, 1).item()
+ eid: int = self.graph.adj[vid][reid, 1]
+ else:
+ raise RuntimeError("Unsuported config value for edge_sampling")
+ if self.filter_edge(eid):
+ return eid
+
+ def eid_to_sample(self, first_eid: int, second_eid: int, polarity: int) -> Dict[str, Any]:
+ """ Build a pair with the given polarity from two edge ids. """
+ first_edge: torch.Tensor = self.graph.edges[first_eid].clone()
+ second_edge: torch.Tensor = self.graph.edges[second_eid].clone()
+
+ self.shuffle_entities(first_edge)
+ self.align_entities_as(second_edge, first_edge)
+
+ sample: Dict[str, Any] = {"polarity": polarity}
+ sample.update(self.edge_to_features(first_eid, first_edge, "first_", mlm=True))
+ sample.update(self.edge_to_features(second_eid, second_edge, "second_", mlm=False))
+ return sample
+
+ @staticmethod
+ def invert_entities(edge: torch.Tensor) -> None:
+ """ Invert the <e1> and <e2> tags of the edge, the text of the entities are not inverted, only the tags. """
+ # invert vertex ids
+ tmp = edge[0].clone()
+ edge[0] = edge[1]
+ edge[1] = tmp
+
+ # invert entity positions
+ tmp = edge[3:5].clone()
+ edge[3:5] = edge[5:7]
+ edge[5:7] = tmp
+
+ def shuffle_entities(self, edge: torch.Tensor) -> None:
+ """ Invert the <e1> and <e2> tags with probability ½. """
+ if self.rng.randint(0, 1):
+ self.invert_entities(edge)
+
+ @staticmethod
+ def align_entities_as(edge: torch.Tensor, pattern: torch.Tensor) -> None:
+ """ Invert entities of an edge if neither of them are in the same position as in the provided pattern. """
+ if edge[0] != pattern[0] and edge[1] != pattern[1]:
+ UnsupervisedDataset.invert_entities(edge)
+
+ def mlm_features(self, text: List[int], prefix: str) -> Dict[str, Any]:
+ """ Extract mlm_input and mlm_target for masked language model loss. """
+ # Function inspired by HuggingFace's code
+ mlm_target = torch.tensor(text, dtype=torch.int64)
+ mlm_input = torch.tensor(text, dtype=torch.int64)
+
+ # Do not mask special tokens
+ st_mask = self.tokenizer.get_special_tokens_mask(text, already_has_special_tokens=True)
+ st_mask = torch.tensor(st_mask, dtype=torch.bool)
+
+ mlm_p = torch.full((len(text),), self.config.mlm_probability)
+ mlm_mask = torch.bernoulli(mlm_p).bool() & st_mask
+ mlm_target[~mlm_mask] = -100
+
+ masked_p = self.config.mlm_masked_probability
+ masked_mask = torch.bernoulli(torch.full((len(text),), masked_p)).bool() & mlm_mask
+ mlm_input[masked_mask] = self.tokenizer.convert_tokens_to_ids(self.tokenizer.mask_token)
+
+ random_p = self.config.mlm_random_probability / (1-masked_p)
+ random_mask = torch.bernoulli(torch.full((len(text),), random_p)).bool() & mlm_mask & ~masked_mask
+ random_value = torch.randint(len(self.tokenizer), (len(text),), dtype=torch.long)
+ mlm_input[random_mask] = random_value[random_mask]
+
+ return {f"{prefix}mlm_input": mlm_input, f"{prefix}mlm_target": mlm_target}
+
+ def edge_to_features(self, eid: int, edge: torch.Tensor, prefix: str, mlm: bool) -> Dict[str, Any]:
+ """
+ Convert an edge to the corresponding set of features (token list of the sentence, etc).
+
+ If mlm is True, features for Masked Language Model training are also generated.
+ """
+ sample: Dict[str, Any] = {}
+ sample[f"{prefix}edge_identifier"] = eid
+ sample[f"{prefix}entity_identifiers"] = edge[0:2]
+ sample[f"{prefix}entity_degrees"] = torch.tensor([self.graph.degree(edge[0]), self.graph.degree(edge[1])], dtype=torch.int64)
+ text: List[int] = self.graph.sentences[edge[2]].tolist()
+
+ # Abuse the fact that "</eX>" < "<eX>"
+ tags: List[Tuple[int, str]] = [(edge[3], "<e1>"), (edge[4], "</e1>"), (edge[5], "<e2>"), (edge[6], "</e2>")]
+ tags.sort(reverse=True)
+
+ # When we see a start tag <eX>, we know the last tag was </eX>
+ last_position: int = -1
+ for position, tag in tags:
+ if tag.startswith("<e"): # begin tag
+ if self.rng.random() < self.config.get("blank_probability", 0):
+ del text[position:last_position]
+ text.insert(position, self.tokenizer.convert_tokens_to_ids("<blank/>"))
+ text.insert(position, self.tokenizer.convert_tokens_to_ids(tag))
+ last_position = position
+
+ if mlm:
+ sample.update(self.mlm_features(text, prefix))
+
+ sample[f"{prefix}text"] = torch.tensor(text, dtype=torch.int32)
+ sample[f"{prefix}entity_positions"] = torch.tensor([
+ text.index(self.tokenizer.convert_tokens_to_ids("<e1>")),
+ text.index(self.tokenizer.convert_tokens_to_ids("<e2>"))
+ ], dtype=torch.int64)
+ return sample
+
+ def sample_parallel(self) -> Dict[str, Any]:
+ """ Sample two parallel edges and create a positive pair from them. """
+ while True:
+ first_eid: int = self.sample_main()
+ if self.graph.eid_simple_adjacency(first_eid):
+ # This edge has no parallel edges from which a positive can be selected
+ continue
+
+ adjacency_range: Tuple[int, int] = self.graph.eid_adjacency_range(first_eid)
+ if self.graph.edges[adjacency_range[0], 2] == self.graph.edges[adjacency_range[1]-1, 2]:
+ # All the edges are caused by repetition of an entity in the same sentence
+ continue
+
+ # Avoid the range of parallel edges sharing the same sentence
+ sentence_range: Tuple[int, int] = self.graph.eid_adjacency_range(first_eid, prefix=3)
+ sentence_card: int = sentence_range[1] - sentence_range[0]
+
+ second_eid: int = self.rng.randint(adjacency_range[0], adjacency_range[1]-sentence_card-1)
+ if second_eid >= sentence_range[0]:
+ second_eid += sentence_card
+
+ if self.filter_edge(second_eid):
+ return self.eid_to_sample(first_eid, second_eid, 1)
+
+ def sample_strong_negative(self) -> Dict[str, Any]:
+ """ Sample a strong negative edge around the two given vertices. """
+ # TODO consider biaising the sampling away from popular entities here too.
+ while True:
+ first_eid: int = self.sample_main()
+ adjacency_range: Tuple[int, int] = self.graph.eid_adjacency_range(first_eid)
+ adjacency_size: int = adjacency_range[1] - adjacency_range[0]
+ vid1: int = self.graph.edges[first_eid, 0]
+ vid2: int = self.graph.edges[first_eid, 1]
+ vertex1_degree: int = self.graph.degree(vid1)
+ vertex2_degree: int = self.graph.degree(vid2)
+ if vertex1_degree + vertex2_degree <= 2 * adjacency_size:
+ # This edge has no other incident edges from which a negative can be selected
+ continue
+
+ second_reid: int = self.rng.randint(0, vertex1_degree + vertex2_degree - 2 * adjacency_size - 1)
+ if second_reid < vertex1_degree - adjacency_size:
+ first_reid_begin: int = self.graph.reid_adjacency_begin(vid1, vid2)
+ if second_reid >= first_reid_begin:
+ second_reid += adjacency_size
+ second_eid: int = self.graph.adj[vid1][second_reid, 1]
+ else:
+ second_reid -= vertex1_degree - adjacency_size
+ first_reid_begin: int = self.graph.reid_adjacency_begin(vid2, vid1)
+ if second_reid >= first_reid_begin:
+ second_reid += adjacency_size
+ second_eid: int = self.graph.adj[vid2][second_reid, 1]
+
+ if self.filter_edge(second_eid):
+ return self.eid_to_sample(first_eid, second_eid, -1)
+
+ def sample_weak_negative(self) -> Dict[str, Any]:
+ while True:
+ first_eid: int = self.sample_main()
+ second_eid: int = self.sample_main()
+ entities: Set[int] = set([
+ self.graph.edges[first_eid, 0],
+ self.graph.edges[first_eid, 1],
+ self.graph.edges[second_eid, 0],
+ self.graph.edges[second_eid, 1]])
+ if len(entities) == 4:
+ return self.eid_to_sample(first_eid, second_eid, -1)
+
+ def sample_triplet(self) -> Dict[str, Any]:
+ sample: Dict[str, Any] = {}
+ for prefix in ["first_", "second_", "third_"]:
+ eid: int = self.sample_main()
+ edge: torch.Tensor = self.graph.edges[eid].clone()
+ self.shuffle_entities(edge)
+ sample.update(self.edge_to_features(eid, edge, prefix, mlm=(prefix == "first_" and self.config.get("language_model_weight", 0) > 0)))
+ return sample
+
+ def sample(self) -> Dict[str, Any]:
+ """ Generate a single sample from the dataset. """
+ if self.config.unsupervised == "mtb":
+ p: float = self.rng.random()
+ if p < self.config.strong_negative_probability:
+ return self.sample_strong_negative()
+ elif p < self.config.strong_negative_probability + self.config.weak_negative_probability:
+ return self.sample_weak_negative()
+ else:
+ return self.sample_parallel()
+ elif self.config.unsupervised == "triplet":
+ return self.sample_triplet()
+ else:
+ raise RuntimeError(f"Unknown unsupervised mode {self.config.unsupervised}.")
+
+ def __iter__(self) -> Iterator[Dict[str, Any]]:
+ """ Generate samples from the dataset. """
+ worker_info = torch.utils.data.get_worker_info()
+ if worker_info is None:
+ sample_count: int = len(self)
+ else:
+ sample_count: int = len(self) // worker_info.num_workers
+ sample_count += (worker_info.id < (len(self) % worker_info.num_workers))
+
+ for index in range(sample_count):
+ yield self.sample()
+
+
+TYPE_MAGIC: Dict[str, torch.utils.data.Dataset] = {
+ "supervised": SupervisedDataset,
+ "fewshot": FewShotDataset,
+ "sampled fewshot": SampledFewShotDataset,
+ "unsupervised": UnsupervisedDataset # not normally used
+ }
+
+
+class GraphAdapter(UnsupervisedDataset):
+ """
+ Post-process a Dataset to add graph features.
+
+ The new features include neighborhood_text, neighborhood_entity_identifiers, etc and are extracted from the entity_identifiers features present in the original sample.
+ """
+ def __init__(self, dataset: torch.utils.data.Dataset, entity_dictionary: gbure.data.dictionary.Dictionary, path: pathlib.Path, graph: Optional[gbure.data.graph.Graph]) -> None:
+ if isinstance(dataset, UnsupervisedDataset) or graph is not None:
+ super().__init__(dataset.config, None, dataset.tokenizer, dataset.evaluation, None)
+ if graph is not None:
+ self.graph = graph
+ else:
+ self.graph = dataset.graph
+ else:
+ super().__init__(dataset.config, path, dataset.tokenizer, dataset.evaluation, None)
+ self.dataset = dataset
+ self.entity_dictionary = entity_dictionary
+
+ def empty_neighborhood(self, prefix: str) -> Dict[str, Any]:
+ neighborhood_size: int = self.config.neighborhood_size
+ if not self.evaluation and self.config.get("filter_empty_neighborhood"):
+ return {}
+ # FIXME We pad to the same number of neighbors for now, since Batcher.process_int_feature does not support neighborhoods of different sizes yet.
+ # Once it is implemented, we can set neighborhood_size = 0
+ return {f"{prefix}edge_identifier": torch.full((neighborhood_size,), -1, dtype=torch.int64),
+ f"{prefix}entity_identifiers": torch.full((neighborhood_size, 2), -1, dtype=torch.int64),
+ f"{prefix}entity_degrees": torch.zeros((neighborhood_size, 2), dtype=torch.int64),
+ f"{prefix}text": [torch.zeros((0,), dtype=torch.int64) for _ in range(neighborhood_size)],
+ f"{prefix}entity_positions": torch.zeros((neighborhood_size, 2), dtype=torch.int64)}
+
+ def sample_neighborhood(self, vid: int, exclude: Optional[int], incoming: bool, prefix: str) -> Dict[str, Any]:
+ """ Sample the neighborhood around the given vertex, excluding a given edge. """
+ number_reids: int = self.graph.degree(vid) - (0 if exclude is None else 1)
+ if number_reids <= 0:
+ reids: List[int] = []
+ elif number_reids <= self.config.neighborhood_size:
+ reids: List[int] = list(range(number_reids)) + self.rng.choices(range(number_reids), k=self.config.neighborhood_size-number_reids)
+ else:
+ reids: List[int] = self.rng.sample(range(number_reids), self.config.neighborhood_size)
+
+ neighbors: List[Dict[str, Any]] = []
+ for reid in reids:
+ eid: int = self.graph.adj[vid][reid, 1]
+ if exclude is not None and eid == exclude:
+ eid = self.graph.adj[vid][-1, 1]
+ edge: torch.Tensor = self.graph.edges[eid].clone()
+ if edge[int(incoming)] != vid:
+ self.invert_entities(edge)
+ neighbors.append(self.edge_to_features(eid, edge, "", mlm=False))
+
+ if not neighbors:
+ return self.empty_neighborhood(prefix)
+
+ sample: Dict[str, Any] = {}
+ for feature in neighbors[0].keys():
+ if feature == "text":
+ sample[f"{prefix}text"] = [neighbor["text"] for neighbor in neighbors]
+ else:
+ sample[f"{prefix}{feature}"] = torch.stack([neighbor[feature] for neighbor in neighbors])
+ return sample
+
+ def sample_neighborhoods(self, head: int, tail: int, eid: Optional[int], prefix: str) -> Dict[str, Any]:
+ """
+ Sample the neighborhood around the given edge.
+
+ edge should be self.graph.edges[eid], optionaly with the entities reversed.
+ """
+ head: Optional[int] = self.graph.entity_dictionary.encoder.get(self.entity_dictionary.decode(head))
+ tail: Optional[int] = self.graph.entity_dictionary.encoder.get(self.entity_dictionary.decode(tail))
+
+ sample: Dict[str, Any] = {}
+ if head is None:
+ e1 = self.empty_neighborhood(f"{prefix}e1_neighborhood_")
+ e1_degree: int = 0
+ else:
+ e1 = self.sample_neighborhood(head, eid, True, f"{prefix}e1_neighborhood_")
+ e1_degree: int = self.graph.degree(head)
+ if not e1:
+ return {}
+ sample.update(e1)
+ if tail is None:
+ e2 = self.empty_neighborhood(f"{prefix}e2_neighborhood_")
+ e2_degree: int = 0
+ else:
+ e2 = self.sample_neighborhood(tail, eid, True, f"{prefix}e2_neighborhood_")
+ e2_degree: int = self.graph.degree(tail)
+ if not e2:
+ return {}
+ sample.update(e2)
+ sample[f"{prefix}entity_degrees"] = torch.tensor([e1_degree, e2_degree], dtype=torch.int64)
+ return sample
+
+ def adapt(self, sample: Dict[str, Any]) -> bool:
+ """
+ Add neighborhood features to sample.
+
+ Returns whether the sample should be kept.
+ """
+ extras: Dict[str, Any] = {}
+ for feature in sample:
+ if feature.endswith("entity_identifiers"):
+ prefix: str = feature[:-len("entity_identifiers")]
+ entity_identifiers: torch.Tensor = sample[f"{prefix}entity_identifiers"]
+
+ # If the underlying dataset is built upon a graph, we can exclude the main sample edge, otherwise there is no risk of sampling an edge as being its own neighbor.
+ eid: Optional[Union[int, torch.Tensor]] = sample.get(f"{prefix}edge_identifier")
+
+ if prefix == "candidates_":
+ prefix_extras: Dict[str, Any] = collections.defaultdict(list)
+ for i, way in enumerate(entity_identifiers):
+ extras_way: Dict[str, List[Any]] = collections.defaultdict(list)
+ for j, shot in enumerate(way):
+ eid: Optional[int] = None if eid is None else eid[i, j].item()
+ extras_shot: Dict[str, Any] = self.sample_neighborhoods(shot[0], shot[1], eid, prefix)
+ if not extras_shot:
+ return False
+ for feature, value in extras_shot.items():
+ extras_way[feature].append(value)
+ for feature, values in extras_way.items():
+ prefix_extras[feature].append(values if feature.endswith("text") else torch.stack(values))
+ for feature, values in prefix_extras.items():
+ prefix_extras[feature] = values if feature.endswith("text") else torch.stack(values)
+ else:
+ prefix_extras: Dict[str, Any] = self.sample_neighborhoods(entity_identifiers[0], entity_identifiers[1], eid, prefix)
+ if prefix_extras:
+ extras.update(prefix_extras)
+ else:
+ return False
+ if f"{prefix}entity_degrees" not in sample and f"{prefix}entity_degrees" not in extras:
+ extras[f"{prefix}entity_degrees"] = torch.zeros_like(extras[f"{prefix}entity_identifiers"], dtype=torch.int64)
+ sample.update(extras)
+ return True
+
+ def __len__(self) -> int:
+ return len(self.dataset)
+
+ def process_sample(self, sample: Dict[str, Any]) -> Iterator[Dict[str, Any]]:
+ # TODO define config value to repeat the sampling of neighbors
+ if self.config.get("neighborhood_size", 0) > 0:
+ if self.adapt(sample):
+ yield sample
+ else:
+ yield sample
+
+ def __iter__(self) -> Iterator[Dict[str, Any]]:
+ if isinstance(self.dataset, torch.utils.data.IterableDataset):
+ for sample in self.dataset:
+ yield from self.process_sample(sample)
+ else: # Map-style dataset
+ worker_info = torch.utils.data.get_worker_info()
+ if worker_info is None:
+ worker_modulo: int = 1
+ worker_residue: int = 0
+ else:
+ worker_modulo: int = worker_info.num_workers
+ worker_residue: int = worker_info.id
+
+ for i in range(worker_residue, len(self.dataset), worker_modulo):
+ yield from self.process_sample(self.dataset[i])
+
+
+def load_dataset(config: gbure.utils.dotdict, split: str, path: pathlib.Path, **kwargs) -> torch.utils.data.Dataset:
+ if split == "train" and config.get("unsupervised"):
+ return UnsupervisedDataset(config=config, path=path, **kwargs)
+
+ dstype: str
+ data: Any
+ dstype, data = torch.load(path)
+ return TYPE_MAGIC[dstype](config=config, path=path, data=data, **kwargs)
diff --git a/gbure/data/dictionary.py b/gbure/data/dictionary.py
@@ -0,0 +1,112 @@
+from typing import Any, Dict, List, Optional
+import pathlib
+import pickle
+
+
+class Dictionary:
+ keys: List[str] = ["keys", "unknown", "decoder", "encoder"]
+
+ def __init__(self, *, unknown: Optional[str] = None, path: Optional[pathlib.Path] = None) -> None:
+ self.encoder: Dict[str, int] = {}
+ self.decoder: List[str] = []
+
+ self.unknown: Optional[str] = unknown
+ if unknown is not None:
+ self.encoder[unknown] = 0
+ self.decoder.append(unknown)
+
+ if path is not None:
+ self.load(path)
+
+ def __len__(self) -> int:
+ """ Number of tokens in the dictionary. """
+ return len(self.decoder)
+
+ def encode(self, token: str) -> int:
+ """ Returns the id corresponding to a token. """
+ id: Optional[int] = self.encoder.get(token)
+ if id is not None:
+ return id
+
+ id = len(self.decoder)
+ self.encoder[token] = id
+ self.decoder.append(token)
+ return id
+
+ def decode(self, id: int) -> str:
+ """ Returns the token corresponding to an id. """
+ return self.decoder[id]
+
+ def save(self, path: pathlib.Path) -> None:
+ with path.open("wb") as file:
+ pickle.dump({key: getattr(self, key) for key in self.keys}, file)
+
+ def load(self, path: pathlib.Path) -> None:
+ with path.open("rb") as file:
+ data: Dict[str, Any] = pickle.load(file)
+ for key, value in data.items():
+ setattr(self, key, value)
+
+
+class RelationDictionary(Dictionary):
+ """
+ A dictionary to be used for relations.
+
+ The tokens held by this class are divided between:
+ - *relation* such as "Entity-Destination(e1,e2)"
+ - *base* such as "Entity-Destination"
+ """
+
+ keys = ["keys", "unknown", "decoder", "encoder", "base_encoder", "base_decoder", "id_to_bid"]
+
+ def __init__(self, *, unknown: Optional[str] = None, path: Optional[pathlib.Path] = None) -> None:
+ self.base_encoder: Dict[str, int] = {}
+ self.base_decoder: List[str] = []
+ self.id_to_bid: List[int] = []
+
+ if unknown is not None:
+ self.base_encoder[unknown] = 0
+ self.base_decoder.append(unknown)
+ self.id_to_bid.append(0)
+
+ super().__init__(unknown=unknown, path=path)
+
+ def base_size(self) -> int:
+ """ Number of bases in the dictionary. """
+ return len(self.base_decoder)
+
+ def encode(self, relation: str, base: Optional[str] = None) -> int:
+ """
+ Returns the id corresponding to a relation string.
+
+ If base is none, do not attempt to insert a new id and returns the id of relation immediately
+
+ Args:
+ relation: the string of the relation (e.g. "Entity-Destination(e1,e2)")
+ base: the string of the base relation (e.g. "Entity-Destination")
+ """
+ if relation is None:
+ return None
+
+ if base is None:
+ return self.encoder[relation]
+
+ id: Optional[int] = self.encoder.get(relation)
+ if id is not None:
+ return id
+
+ bid: Optional[int] = self.base_encoder.get(base)
+ if bid is None:
+ bid = len(self.base_decoder)
+ self.base_encoder[base] = bid
+ self.base_decoder.append(base)
+
+ id = len(self.decoder)
+ self.encoder[relation] = id
+ self.decoder.append(relation)
+ self.id_to_bid.append(bid)
+ return id
+
+ def base_id(self, id: int) -> int:
+ """ Returns the base id corresponding to a relation id. """
+ return self.id_to_bid[id]
diff --git a/gbure/data/graph.py b/gbure/data/graph.py
@@ -0,0 +1,146 @@
+from typing import List, Optional, Tuple, Union
+import pathlib
+
+import torch
+import tqdm
+import transformers
+
+from gbure.data.dictionary import Dictionary
+from gbure.utils import SharedLongTensorList
+
+
+class Graph:
+ """
+ Graph represented by an adjacency list.
+
+ The node of the Graph correspond to entities, an edge between node n1 and node n2 indicates that the two corresponding entities appear together in a sentence, this sentence being the label of the edge.
+ To avoid storing the same sentence several time, a list of sentences is stored and the edges are labeled with the sentence id and entities postions.
+
+ The data is stored in four objects:
+ sentences: sid -> list of tokens without tags (<e1>, etc)
+ entity_dictionary: KB entity identifier -> vid (vertex id)
+ adj: source vid -> [(destination vid, eid)] sorted in lexicographic order
+ edges: eid (edge id) -> (e1 vid, e2 vid, sid, e1 start position, e1 end position, e2 start position, e2 end position)
+ """
+
+ def __init__(self,
+ sentences: Optional[List[torch.Tensor]] = None,
+ entity_dictionary: Optional[Dictionary] = None,
+ degrees: Optional[List[int]] = None,
+ edges: Optional[List[Tuple[int, int, int, int, int, int, int]]] = None,
+ *, path: Optional[pathlib.Path] = None) -> None:
+ """ Initialize a Graph, either with the provided data or by loading the given file. """
+ if path is not None:
+ assert(sentences is None and entity_dictionary is None and degrees is None and edges is None)
+ self.load(path)
+ else:
+ assert(sentences is not None and entity_dictionary is not None and degrees is not None and edges is not None)
+ self.sentences: Union[List[torch.Tensor], SharedLongTensorList] = sentences
+ self.entity_dictionary: Dictionary = entity_dictionary
+ self.compile_edges_adj(edges, degrees)
+
+ def compile_edges_adj(self, edges: List[Tuple[int, int, int, int, int, int, int]], degrees: List[int]) -> None:
+ """ Compile the edge and adjacency list of the graph. """
+ edge_order: List[int] = sorted(range(len(edges)), key=edges.__getitem__)
+
+ self.edges: torch.Tensor = torch.empty((len(edges), 7), dtype=torch.int32)
+ self.adj: Union[List[torch.Tensor], SharedLongTensorList] = [torch.empty((degree, 2), dtype=torch.int32) for degree in degrees]
+
+ for new_id, old_id in enumerate(tqdm.tqdm(edge_order, desc="compiling graph")):
+ edge: torch.Tensor = torch.tensor(edges[old_id], dtype=torch.int32)
+ self.edges[new_id] = edge
+ for source, destination in [(edge[0], edge[1]), (edge[1], edge[0])]:
+ degrees[source] -= 1
+ self.adj[source][degrees[source], 0] = destination
+ self.adj[source][degrees[source], 1] = new_id
+ # Free the memory along the way to avoid storing the graph twice
+ edges[old_id] = None
+
+ for i, vertex in enumerate(tqdm.tqdm(self.adj, desc="sorting adjacency list")):
+ self.adj[i] = torch.stack(sorted(vertex, key=torch.Tensor.tolist)) # pytype: disable=unsupported-operands
+
+ @property
+ def order(self) -> int:
+ """ Number of vertices. """
+ return len(self.adj)
+
+ @property
+ def size(self) -> int:
+ """ Number of edges. """
+ return self.edges.shape[0]
+
+ def degree(self, vertex: int) -> int:
+ """ Number of edges connected to a given vertex. """
+ return self.adj[vertex].shape[0]
+
+ def eid_simple_adjacency(self, eid: int) -> bool:
+ """ Decide whether a given edge is the sole edge between two nodes. """
+ if eid > 0 and (self.edges[eid, :2] == self.edges[eid-1, :2]).all():
+ return False
+ if eid < self.size-1 and (self.edges[eid, :2] == self.edges[eid+1, :2]).all():
+ return False
+ return True
+
+ def eid_adjacency_range(self, eid: int, prefix: int = 2) -> Tuple[int, int]:
+ """ Return the range of edges (in the global edge list) sharing the same end points as eid. """
+ range_start: int = eid
+ while range_start > 0 and (self.edges[eid, :prefix] == self.edges[range_start-1, :prefix]).all():
+ range_start -= 1
+
+ range_end: int = eid+1
+ while range_end < self.size and (self.edges[eid, :prefix] == self.edges[range_end, :prefix]).all():
+ range_end += 1
+
+ return range_start, range_end
+
+ def reid_adjacency_begin(self, source: int, destination: int) -> int:
+ """ Return the first edge from source to destination as relative index in source's adjacency list. """
+ left: int = 0
+ right: int = self.degree(source)
+ while left < right:
+ middle: int = (left + right) // 2
+ if self.adj[source][middle, 0] < destination:
+ left = middle + 1
+ else:
+ right = middle
+
+ return left
+
+ def tagged_sentence(self, eid: int, tokenizer: transformers.PreTrainedTokenizer, invert: bool = False) -> Tuple[List[int], int, int]:
+ """ Get the tagged sentence corresponding to an edge. """
+ edge: torch.Tensor = self.edges[eid]
+ text: List[int] = self.sentences[edge[2]].tolist()
+ # Abuse the fact that "</e1>" < "<e1>"
+ if invert:
+ tags: List[Tuple[int, str]] = [(edge[5], "<e1>"), (edge[6], "</e1>"), (edge[3], "<e2>"), (edge[4], "</e2>")]
+ else:
+ tags: List[Tuple[int, str]] = [(edge[3], "<e1>"), (edge[4], "</e1>"), (edge[5], "<e2>"), (edge[6], "</e2>")]
+ tags.sort(reverse=True)
+ for position, tag in tags:
+ text.insert(position, tokenizer.convert_tokens_to_ids(tag))
+ e1_pos: int = self.edges[eid, 5 if invert else 3].item()
+ e2_pos: int = self.edges[eid, 3 if invert else 5].item()
+ if e1_pos < e2_pos:
+ e2_pos += 2
+ else:
+ e1_pos += 2
+ return text, e1_pos, e2_pos
+
+ def save(self, path: pathlib.Path) -> None:
+ """ Save the graph to the given directory. """
+ if not path.is_dir():
+ path.mkdir()
+
+ self.entity_dictionary.save(path / "entities")
+ for attribute in ["sentences", "edges", "adj"]:
+ torch.save(getattr(self, attribute), path / attribute)
+
+ def load(self, path: pathlib.Path) -> None:
+ """ Load a graph from the given directory. """
+ self.entity_dictionary = Dictionary(path=path / "entities")
+ for attribute in ["sentences", "edges", "adj"]:
+ setattr(self, attribute, torch.load(path / attribute))
+
+ def share_memory(self) -> None:
+ self.sentences = SharedLongTensorList(self.sentences)
+ self.adj = SharedLongTensorList(self.adj, [-1, 2])
diff --git a/gbure/data/prepare_fewrel.py b/gbure/data/prepare_fewrel.py
@@ -0,0 +1,65 @@
+from typing import Any, Dict, Iterable, Tuple
+import argparse
+import json
+import pathlib
+
+import tqdm
+
+from gbure.utils import DATA_PATH
+import gbure.data.preprocessing as preprocessing
+
+DATASET_PATH: pathlib.Path = DATA_PATH / "FewRel"
+DOWNLOAD_URL: str = "https://thunlp.oss-cn-qingdao.aliyuncs.com/fewrel/"
+FILES_SHA512: Dict[str, str] = {
+ "fewrel_train.json": "2ec687d16999bd59bbcac39fdfed319cee3bec14963717c6ee262da981ad64e58b93d5c95a6ea4f6c5fe9c3d09a57d098d25ad61955f4df5f910bc28718e8220",
+ "fewrel_val.json": "32bdac8c9aba880484d00417d823a310657acca5604a06fa7c4c01f8dfb54b9e05ecf67c0c374aa61bd42cb9e90cd62c0d50377cce2d6bc8fa6d1fbbb61d0f5e"
+ }
+
+
+def get_data() -> None:
+ """ Download FewRel's train and val json files if needed. """
+ for filename, sha512 in FILES_SHA512.items():
+ if not (DATASET_PATH / filename).exists():
+ preprocessing.download(DOWNLOAD_URL + filename, DATASET_PATH / filename, filename, sha512)
+
+
+def read_data(path: pathlib.Path) -> Iterable[Tuple[str, str, str, str, str]]:
+ """ Read a FewRel json file and return (text, relation, relation_base, e1, e2) tuples. """
+ with open(path) as file:
+ data = json.load(file)
+
+ for relation, relset in tqdm.tqdm(data.items(), desc=f"loading {path.name}"):
+ relation_dataset = []
+ for sentence in relset:
+ # We assume we know the direction of the relation
+ yield (process_sentence(sentence), relation, relation, sentence["h"][1][1:], sentence["t"][1][1:])
+
+
+def process_sentence(sentence: Dict[str, Any]) -> str:
+ """ Transform a FewRel json sentence object to a tagged sentence string. """
+ tokens = sentence["tokens"]
+ tokens[sentence["h"][2][0][0]] = "<e1>" + tokens[sentence["h"][2][0][0]]
+ tokens[sentence["h"][2][0][-1]] += "</e1>"
+ tokens[sentence["t"][2][0][0]] = "<e2>" + tokens[sentence["t"][2][0][0]]
+ tokens[sentence["t"][2][0][-1]] += "</e2>"
+ return " ".join(tokens)
+
+
+def read_splits() -> Dict[str, Iterable[Tuple[str, str, str, str, str]]]:
+ return {"train": read_data(DATASET_PATH / "fewrel_train.json"),
+ "valid": read_data(DATASET_PATH / "fewrel_val.json")}
+
+
+if __name__ == "__main__":
+ parser: argparse.ArgumentParser = preprocessing.base_argument_parser("Prepare the few shot FewRel dataset.")
+ args: argparse.Namespace = parser.parse_args()
+ name: str = preprocessing.dataset_name(args)
+
+ get_data()
+ preprocessing.serialize_dataset(
+ supervision="fewshot",
+ path=DATASET_PATH / name,
+ splits=read_splits(),
+ unknown_entity=None,
+ unknown_relation=None,
+ **preprocessing.args_to_serialize(args))
diff --git a/gbure/data/prepare_kbp37.py b/gbure/data/prepare_kbp37.py
@@ -0,0 +1,41 @@
+from typing import Dict, Iterable, Tuple
+import argparse
+import pathlib
+
+import tqdm
+
+from gbure.utils import DATA_PATH
+import gbure.data.prepare_semeval
+import gbure.data.preprocessing as preprocessing
+
+DATASET_PATH: pathlib.Path = DATA_PATH / "KBP37"
+COMMIT_ID: str = "7d88486ad632a9c6e9fe6adbc2468049e89bc11d"
+DIRECTORY_NAME: str = f"kbp37-{COMMIT_ID}"
+ARCHIVE_NAME: str = "kbp37_data.zip"
+ARCHIVE_SHA512: str = "f6661df79d327a34ad4198f0405d7c06e05af4d9aab3723282c02270394c6511df41a330118d2be89f390b9d36c1eb2a32800db01a3d4387b990493befb011ac"
+DOWNLOAD_URL: str = f"https://github.com/zhangdongxu/kbp37/archive/{COMMIT_ID}.zip"
+
+TRAIN_SIZE: int = 15917
+VALID_SIZE: int = 1724
+TEST_SIZE: int = 3405
+UNKNOWN_RELATION: str = "no_relation"
+
+
+def read_splits() -> Dict[str, Iterable[Tuple[str, str, str, str, str]]]:
+ return {"train": gbure.data.prepare_semeval.read_data(DATASET_PATH / DIRECTORY_NAME / "train.txt", TRAIN_SIZE),
+ "valid": gbure.data.prepare_semeval.read_data(DATASET_PATH / DIRECTORY_NAME / "dev.txt", VALID_SIZE),
+ "test": gbure.data.prepare_semeval.read_data(DATASET_PATH / DIRECTORY_NAME / "test.txt", TEST_SIZE)}
+
+
+if __name__ == "__main__":
+ parser: argparse.ArgumentParser = preprocessing.base_argument_parser("Prepare the supervised KBP37 dataset.")
+ args: argparse.Namespace = parser.parse_args()
+ name: str = preprocessing.dataset_name(args)
+
+ preprocessing.get_zip_data(DATASET_PATH, DIRECTORY_NAME, ARCHIVE_NAME, ARCHIVE_SHA512, DOWNLOAD_URL)
+ preprocessing.serialize_dataset(
+ supervision="supervised",
+ path=DATASET_PATH / name,
+ splits=read_splits(),
+ unknown_relation=UNKNOWN_RELATION,
+ **preprocessing.args_to_serialize(args))
diff --git a/gbure/data/prepare_sampled_fewrel.py b/gbure/data/prepare_sampled_fewrel.py
@@ -0,0 +1,65 @@
+from typing import Any, Dict, Iterable, List, Optional, Tuple
+import argparse
+import hashlib
+import itertools
+import json
+import pathlib
+
+import tqdm
+
+from gbure.utils import DATA_PATH
+from gbure.data.prepare_fewrel import DATASET_PATH, process_sentence
+import gbure.data.preprocessing as preprocessing
+
+
+def read_entry(entry: Dict[str, Any]) -> Tuple[str, str, str]:
+ """ Convert the json object of an entry to a preprocessing tuple (sentence, e1, e2). """
+ return process_sentence(entry), entry["h"][1], entry["t"][1]
+
+
+def read_file(inpath: pathlib.Path, outpath: Optional[pathlib.Path]) -> Iterable[Tuple[Tuple[str, str, str], List[List[Tuple[str, str, str]]], int]]:
+ """
+ Yield (test, [[train]]) pair from a file of samples.
+
+ Each input is a triplet (sentence, head entity, tail entity).
+ """
+ with open(inpath) as file:
+ data = json.load(file)
+
+ if outpath:
+ with open(outpath) as file:
+ answers = json.load(file)
+ else:
+ answers = itertools.repeat(-1)
+
+ for problem, answer in zip(tqdm.tqdm(data, desc=f"processing {inpath.name}"+(" and "+outpath.name if outpath else "")), answers):
+ test = read_entry(problem["meta_test"])
+ train = list(map(lambda candidates: list(map(read_entry, candidates)), problem["meta_train"]))
+ yield (test, train, answer)
+
+
+if __name__ == "__main__":
+ parser: argparse.ArgumentParser = preprocessing.base_argument_parser("Prepare a sampled few shot FewRel dataset (generated by sample_io.py).", deterministic=True)
+ parser.add_argument("inpath",
+ type=pathlib.Path,
+ help="Path to the file containing the input ")
+ parser.add_argument("outpath",
+ type=pathlib.Path,
+ nargs="?",
+ help="Path to the file containing the output (optional)")
+ parser.add_argument("-S", "--suffix",
+ type=str,
+ default="",
+ help="Suffix to add to the tokenizer to find the dataset.")
+
+ args: argparse.Namespace = parser.parse_args()
+ hashid: str = preprocessing.hash_file(args.inpath)[:8]
+ if args.outpath:
+ hashid += preprocessing.hash_file(args.outpath)[:8]
+ name: str = preprocessing.dataset_name(args, args.suffix)
+
+ preprocessing.serialize_fewshot_sampled_split(
+ path=DATASET_PATH / name,
+ name=hashid,
+ split=read_file(args.inpath, args.outpath),
+ **preprocessing.args_to_serialize(args))
diff --git a/gbure/data/prepare_semeval.py b/gbure/data/prepare_semeval.py
@@ -0,0 +1,83 @@
+from typing import Dict, Iterable, Tuple
+import argparse
+import pathlib
+import random
+
+import tqdm
+
+from gbure.utils import DATA_PATH
+import gbure.data.preprocessing as preprocessing
+
+DATASET_PATH: pathlib.Path = DATA_PATH / "SemEval 2010 Task 8"
+DIRECTORY_NAME: str = "SemEval2010_task8_all_data"
+ARCHIVE_NAME: str = f"{DIRECTORY_NAME}.zip"
+ARCHIVE_SHA512: str = "7ac2d71ba1772105c1f73e4278e4b85cebd9fb95187fb8a153c83215d890f0d2b98929fb2363e8c117d2e2c8e7f9926d7997e38fc57b5730fb00912ff376b66b"
+DOWNLOAD_URL: str = f"https://esimon.eu/GBURE/{ARCHIVE_NAME}"
+
+TRAIN_VALID_SIZE: int = 8000
+TEST_SIZE: int = 2717
+UNKNOWN_RELATION: str = "Other"
+
+
+def read_data(path: pathlib.Path, size: int) -> Iterable[Tuple[str, str, str, str, str]]:
+ """
+ Read a file in SemEval format and return (text, relation, relation_base, e1, e2) tuples.
+
+ For now, the entities are empty, we could encode them using their surface form or use a true entity linker if we want to use this information.
+ """
+ with path.open() as file:
+ for _ in tqdm.trange(size, desc=f"loading {path.name}"):
+ idtext_line: str = file.readline()
+ relation_line: str = file.readline()
+ file.readline() # Ignore Comment line
+ file.readline() # Ignore empty line
+
+ if not (idtext_line and relation_line):
+ break
+
+ id, raw_text = idtext_line.rstrip().split('\t')
+ text = raw_text[1:-1] # remove quotes around text
+ relation = relation_line.rstrip()
+
+ dir_start: int = relation.find('(')
+ relation_base: str = relation[:dir_start] if dir_start >= 0 else relation
+
+ # TODO handle entities
+ yield (text, relation, relation_base, "", "")
+
+
+def split_train_valid(data: Iterable[Tuple[str, str, str, str, str]], valid_size: int, seed: int) -> Tuple[Iterable[Tuple[str, str, str, str, str]], Iterable[Tuple[str, str, str, str, str]]]:
+ data = list(data)
+ rng = random.Random(seed)
+ rng.shuffle(data)
+
+ train = data[valid_size:]
+ valid = data[:valid_size]
+ return train, valid
+
+
+def read_splits(valid_size: int) -> Dict[str, Iterable[Tuple[str, str, str, str, str]]]:
+ splits = {}
+ train_valid = read_data(DATASET_PATH / DIRECTORY_NAME / "SemEval2010_task8_training" / "TRAIN_FILE.TXT", TRAIN_VALID_SIZE)
+ splits["train"], splits["valid"] = split_train_valid(train_valid, args.valid_size, args.seed)
+ splits["test"] = read_data(DATASET_PATH / DIRECTORY_NAME / "SemEval2010_task8_testing_keys" / "TEST_FILE_FULL.TXT", TEST_SIZE)
+ return splits
+
+
+if __name__ == "__main__":
+ parser: argparse.ArgumentParser = preprocessing.base_argument_parser("Prepare the supervised SemEval 2010 Task 8 dataset.")
+ parser.add_argument("-v", "--valid-size",
+ type=int,
+ default=1500,
+ help="Size of the validation set")
+
+ args: argparse.Namespace = parser.parse_args()
+ name: str = preprocessing.dataset_name(args, f"-v{args.valid_size}" if args.valid_size != 1500 else "")
+
+ preprocessing.get_zip_data(DATASET_PATH, DIRECTORY_NAME, ARCHIVE_NAME, ARCHIVE_SHA512, DOWNLOAD_URL)
+ preprocessing.serialize_dataset(
+ supervision="supervised",
+ path=DATASET_PATH / name,
+ splits=read_splits(args.valid_size),
+ unknown_relation=UNKNOWN_RELATION,
+ **preprocessing.args_to_serialize(args))
diff --git a/gbure/data/prepare_trex.py b/gbure/data/prepare_trex.py
@@ -0,0 +1,76 @@
+from typing import Any, Dict, Iterable, List, Optional, Tuple
+import argparse
+import json
+import os
+import pathlib
+import tqdm
+
+from gbure.utils import DATA_PATH
+import gbure.data.preprocessing as preprocessing
+
+DATASET_PATH: pathlib.Path = DATA_PATH / "T-REx"
+DIRECTORY_NAME: str = "raw_data"
+ARCHIVE_NAME: str = f"T-REx.zip"
+ARCHIVE_SHA512: str = "30349fa6f01c1928ce15325521ebd05643787220f9a545eb23b280f9209cb1615f4a855b08604f943a1affb4d1f4f17b94f8434698f347a1cb7a0d820fa9de9f"
+DOWNLOAD_URL: str = f"https://esimon.eu/GBURE/{ARCHIVE_NAME}"
+
+
+def process_json_object(data: List[Dict[str, Any]]) -> Iterable[Tuple[str, List[Tuple[str, int, int]]]]:
+ """ Process a T-REx json object and return (sentence, list of entities) tuples. """
+ for article in data:
+ eid: int = 0
+ for sbs in article["sentences_boundaries"]:
+ entities: List[Tuple[str, int, int]] = []
+ while eid < len(article["entities"]) and article["entities"][eid]["boundaries"][0] < sbs[0]:
+ eid += 1
+
+ while eid < len(article["entities"]) and article["entities"][eid]["boundaries"][1] <= sbs[1]:
+ entity: Dict[str, Any] = article["entities"][eid]
+ eid += 1
+
+ # Ignore date entities
+ if entity["annotator"] != "Wikidata_Spotlight_Entity_Linker":
+ continue
+
+ uri: str = entity["uri"]
+ prefix: str = "http://www.wikidata.org/entity/Q"
+ assert(uri.startswith(prefix))
+ uri = uri[len(prefix):]
+ entities.append((uri, entity["boundaries"][0] - sbs[0], entity["boundaries"][1] - sbs[0]))
+
+ # ignore sentences with less than two entities
+ if len(entities) < 2:
+ continue
+
+ sentence = article["text"][sbs[0]:sbs[1]]
+ yield (sentence, entities)
+
+
+def read_data(subset: Optional[int]) -> Iterable[Tuple[str, List[Tuple[str, int, int]]]]:
+ """ Read all T-REx files and return (sentence, list of entities) tuples. """
+ filenames: List[str] = list(filter(lambda filename: filename.endswith(".json"), os.listdir(DATASET_PATH / DIRECTORY_NAME)))
+
+ # Make the order deterministic.
+ filenames.sort()
+ if subset is not None:
+ filenames = filenames[:subset]
+
+ for filename in tqdm.tqdm(filenames, desc="loading"):
+ with open(DATASET_PATH / DIRECTORY_NAME / filename, "r") as file:
+ data: List[Dict[str, Any]] = json.load(file)
+ yield from process_json_object(data)
+
+
+if __name__ == "__main__":
+ parser: argparse.ArgumentParser = preprocessing.base_argument_parser("Prepare the unsupervised TREx dataset.")
+ parser.add_argument("-S", "--subset",
+ type=int,
+ help="Number of file to process (default to all, only used for creating a debug dataset)")
+ args: argparse.Namespace = parser.parse_args()
+ name: str = preprocessing.dataset_name(args, "" if args.subset is None else f"-ss{args.subset}")
+
+ preprocessing.get_zip_data(DATASET_PATH, DIRECTORY_NAME, ARCHIVE_NAME, ARCHIVE_SHA512, DOWNLOAD_URL, unzip_directory=True)
+ preprocessing.serialize_unsupervised_dataset(
+ path=DATASET_PATH / name,
+ data=read_data(args.subset),
+ **preprocessing.args_to_serialize(args))
diff --git a/gbure/data/preprocessing.py b/gbure/data/preprocessing.py
@@ -0,0 +1,392 @@
+from typing import Any, Dict, Iterable, List, Optional, Tuple, Union, cast
+import argparse
+import collections
+import hashlib
+import math
+import os
+import pathlib
+import random
+import urllib.request
+import zipfile
+
+import torch
+import tqdm
+import transformers
+
+from gbure.data.dictionary import Dictionary, RelationDictionary
+from gbure.data.graph import Graph
+
+
+def hash_file(path: pathlib.Path, filename: Optional[str] = None, filesize: Optional[int] = None) -> str:
+ """ Get a unique identifier for the file. """
+ hasher = hashlib.sha512()
+ with path.open("rb") as file:
+ loop = iter(lambda: file.read(2**16), b"")
+ if filename is not None and filesize is not None:
+ loop = tqdm.tqdm(loop,
+ desc=f"checking {filename} hash",
+ total=math.ceil(filesize / 2**16),
+ unit_scale=2**16, unit="B", unit_divisor=1024)
+ for chunk in loop:
+ hasher.update(chunk)
+ return hasher.hexdigest()
+
+
+def download(url: str, path: pathlib.Path, filename: str, sha512: str) -> None:
+ """ Download a file at the given path and check its hash. """
+ if not path.parent.is_dir():
+ path.parent.mkdir(parents=True)
+
+ unchecked: pathlib.Path = pathlib.Path(f"{path}.unchecked")
+ with tqdm.tqdm(unit="B", unit_scale=True, unit_divisor=1024, miniters=1, desc=f"downloading {filename}") as progress:
+ def report_hook(num_blocks: int, chunk_size: int, total_size: int):
+ if progress.total is None:
+ progress.total = total_size
+ progress.update(num_blocks * chunk_size - progress.n)
+ urllib.request.urlretrieve(url, unchecked, report_hook)
+
+ unchecked_hash = hash_file(unchecked, filename, progress.total)
+ if unchecked_hash != sha512:
+ raise RuntimeError(f"Downloaded file \"{filename}\" has wrong hash.")
+ os.rename(unchecked, path)
+
+
+def get_zip_data(dataset_path: pathlib.Path, directory_name: str, archive_name: str, archive_sha512: str, download_url: str, unzip_directory: bool = False) -> None:
+ """ Download and extract data zip archive if needed. """
+ if not (dataset_path / directory_name).exists():
+ if not (dataset_path / archive_name).exists():
+ download(download_url, dataset_path / archive_name, archive_name, archive_sha512)
+
+ with zipfile.ZipFile(str(dataset_path / archive_name), "r") as archive:
+ archive.extractall(dataset_path / directory_name if unzip_directory else dataset_path)
+
+
+def base_argument_parser(description: str = "", deterministic: bool = False, parser: Optional[argparse.ArgumentParser] = None) -> argparse.ArgumentParser:
+ assert(description != "" or parser is not None)
+ """ Return an argument parser with standard command line arguments used by preprocessing functions. """
+ parser: argparse.ArgumentParser = argparse.ArgumentParser(description=description) if parser is None else parser
+ parser.add_argument("tokenizer",
+ type=str,
+ nargs='?',
+ default="bert-base-cased",
+ help="Name of the transformers tokenizer")
+ if not deterministic:
+ parser.add_argument("-s", "--seed",
+ type=int,
+ default=0,
+ help="Seed of the RNG for shuffling the dataset")
+ return parser
+
+
+def dataset_name(args: argparse.Namespace, infix: str = "") -> str:
+ """ Returns the dataset name with suffix containing non-standard preprocessing parameters. """
+ suffix: str = ""
+ if "seed" in args and args.seed != 0:
+ suffix = f"-s{args.seed}"
+ return f"{args.tokenizer}{infix}{suffix}"
+
+
+def args_to_serialize(args: argparse.Namespace) -> Dict[str, Any]:
+ """ Map standard preprocessing command line arguments defined in base_argument_parser to serialize_supervised_dataset parameters. """
+ kwargs = {"tokenizer_name": args.tokenizer}
+ if "seed" in args:
+ kwargs["seed"] = args.seed
+ return kwargs
+
+
+def make_tokenizer(name: str, path: pathlib.Path) -> transformers.PreTrainedTokenizer:
+ """ Build the given tokenizer and save it. """
+ if not path.is_dir():
+ path.mkdir()
+
+ tokenizer = transformers.AutoTokenizer.from_pretrained(name)
+ special_tokens = ["<e1>", "</e1>", "<e2>", "</e2>", "<blank/>"]
+ tokenizer.add_special_tokens({"additional_special_tokens": special_tokens})
+ tokenizer.save_pretrained(path)
+
+ # fix huggingface transformers issue #6368
+ config_file = transformers.AutoConfig.from_pretrained(name)
+ config_file.save_pretrained(path)
+
+ return tokenizer
+
+
+def process_text_2(raw_text: str, tokenizer: transformers.PreTrainedTokenizer) -> Tuple[torch.Tensor, int, int]:
+ """
+ Transform a string with two entities tagged to a list of token ids together with the positions of the two entities.
+
+ The returned token list contains the token corresponding to the tags.
+ The two returned positions are the positions of <e1> and <e2>.
+ """
+ be1_id: int = tokenizer.convert_tokens_to_ids("<e1>")
+ be2_id: int = tokenizer.convert_tokens_to_ids("<e2>")
+
+ text: List[int] = tokenizer.encode(raw_text, add_special_tokens=True)
+ e1_pos: int = text.index(be1_id)
+ e2_pos: int = text.index(be2_id)
+ if len(text) > tokenizer.model_max_length:
+ text = text[:tokenizer.model_max_length]
+ e1_pos = min(tokenizer.model_max_length-1, e1_pos)
+ e2_pos = min(tokenizer.model_max_length-1, e2_pos)
+
+ return torch.tensor(text, dtype=torch.int32), e1_pos, e2_pos
+
+
+def process_text_n(raw_text: str, raw_entities: List[Tuple[str, int, int]], tokenizer: transformers.PreTrainedTokenizer) -> Tuple[torch.Tensor, List[Tuple[str, int, int]]]:
+ """
+ Transform a string with several entities tagged to a list of token id together with the positions of entities.
+
+ The returned token list does not contain the token corresponding to the tags.
+ The returned postions, are where the tags should be inserted.
+ If the leftmost tag is inserted first, the position of subsequent inserts should be shifted accordingly.
+ """
+ be1_id: int = tokenizer.convert_tokens_to_ids("<e1>")
+
+ # If one entity end at a position, and another entity start at the same position, we want to close the first entity before starting the sencond one, the second field "1 - extremity" has this function since the list is sorted in lexicographic order.
+ tag_positions: List[Tuple[int, int, int]] = [
+ (cast(int, entity[1 + extremity]), # Position of the tag (start or end of entity) in the sentence.
+ cast(int, 1 - extremity), # Whether this is a start or end of entity.
+ i) # The index of the entity used to rebuild the list at the end.
+ for i, entity in enumerate(raw_entities) for extremity in [0, 1]]
+ tag_positions.sort()
+
+ # We insert the tag <e1> at every tag postion in order to be able to convert postions in the raw text to positions in the token list.
+ pieces: List[str] = []
+ for piece_start, piece_end in zip([(0,)] + tag_positions, tag_positions + [(len(raw_text),)]):
+ pieces.append(raw_text[piece_start[0]:piece_end[0]])
+ pieces.append("<e1>")
+ # Remove the last <e1> added at the end of the sentence.
+ pieces.pop()
+
+ text: List[int] = tokenizer.encode("".join(pieces), add_special_tokens=True)
+ if len(text) > tokenizer.model_max_length:
+ text = text[:tokenizer.model_max_length]
+
+ # New entity list, with converted positions.
+ entities: List[List[Union[str, int]]] = [[entity[0], -1, -1] for entity in raw_entities]
+
+ j: int = 0 # Counter on the tags.
+ for i, token in enumerate(text):
+ if token == be1_id:
+ # The order of the <e1> in the text match the one in tag_positions.
+ tag_position: Tuple[int, int, int] = tag_positions[j]
+
+ # tag_position[2] is the index of the entity in raw_entities (and thus entities).
+ # tag_position[1] is 0 for the end of the entity and 1 for its start.
+ # Since the returned token list will be pruned of all the <e1>, the position of the tag should be shifted by the number of <e1> already met, thus "i - j".
+ entities[tag_position[2]][2 - tag_position[1]] = i - j
+
+ j += 1 # Move to the next tag.
+
+ # Remove all tags
+ text = list(filter(lambda x: x != be1_id, text))
+
+ # Remove entities which didn't fit inside tokenizer.model_max_length tokens.
+ entities = list(filter(lambda x: x[1] >= 0 and x[2] >= 0, entities))
+
+ tuple_entities: List[Tuple[str, int, int]] = list(map(tuple, entities))
+ return torch.tensor(text, dtype=torch.int32), tuple_entities
+
+
+def serialize_supervised_split(
+ path: pathlib.Path,
+ split: Iterable[Tuple[str, str, str, str, str]],
+ tokenizer: transformers.PreTrainedTokenizer,
+ entity_dictionary: Dictionary,
+ relation_dictionary: RelationDictionary) -> None:
+ """
+ Serialize a supervised split to a given path.
+
+ split is an iterable containing (text, directed relation, undirected relation, e1, e2) tuples.
+ Entities are ignored.
+ The relations are raw values (e.g. P42). This function performs the encoding.
+ """
+ data: List[Tuple[torch.Tensor, int, int, int]] = []
+
+ # TODO handle entities
+ for raw_text, relation, relation_base, _, _ in split:
+ text, e1_pos, e2_pos = process_text_2(raw_text, tokenizer)
+ relation_id: int = relation_dictionary.encode(relation, relation_base)
+ data.append((text, e1_pos, e2_pos, relation_id))
+
+ torch.save(("supervised", data), path)
+
+
+def serialize_fewshot_split(
+ path: pathlib.Path,
+ split: Iterable[Tuple[str, str, str, str, str]],
+ tokenizer: transformers.PreTrainedTokenizer,
+ entity_dictionary: Dictionary,
+ relation_dictionary: RelationDictionary) -> None:
+ """
+ Serialize a fewshot split to a given path.
+
+ split is an iterable containing (text, directed relation, undirected relation, e1, e2) tuples.
+ The relations and entities are raw values (e.g. P42, Q42). This function performs the encoding.
+ """
+ data: Dict[int, List[Tuple[torch.Tensor, int, int, int, int, int]]] = collections.defaultdict(list)
+
+ for raw_text, relation, relation_base, e1, e2 in split:
+ text, e1_pos, e2_pos = process_text_2(raw_text, tokenizer)
+ relation_id: int = relation_dictionary.encode(relation, relation_base)
+ e1_id: int = entity_dictionary.encode(e1)
+ e2_id: int = entity_dictionary.encode(e2)
+ data[relation_id].append((text, e1_pos, e2_pos, relation_id, e1_id, e2_id))
+
+ torch.save(("fewshot", list(data.values())), path)
+
+
+def serialize_fewshot_sampled_split(
+ path: pathlib.Path,
+ name: str,
+ split: Iterable[Tuple[Tuple[str, str, str], List[List[Tuple[str, str, str]]], int]],
+ tokenizer_name: str) -> None:
+ """
+ Serialize a sampled fewshot split.
+
+ split is an iterable of (query, candidates, answer) tuples.
+ In these tuples, query is a tuple (text, e1, e2).
+ The relations are not given.
+ """
+ tokenizer: transformers.PreTrainedTokenizer = transformers.AutoTokenizer.from_pretrained(str(path / "tokenizer"))
+ entity_dictionary = Dictionary()
+
+ data: List[Tuple[torch.Tensor, int, int, int, int, List[List[torch.Tensor]], torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, int]] = []
+ for train, test, answer in split:
+ query_text, query_e1_pos, query_e2_pos = process_text_2(train[0], tokenizer)
+ query_e1 = entity_dictionary.encode(train[1])
+ query_e2 = entity_dictionary.encode(train[2])
+
+ way = len(test)
+ shot = len(test[0])
+ candidates_processed_text: List[List[Tuple[torch.Tensor, int, int]]] = list(map(lambda relation: list(map(lambda candidate: process_text_2(candidate[0], tokenizer), relation)), test))
+ candidates_text_len = max(map(lambda relation: max(map(lambda candidate: candidate[0].shape[0], relation)), candidates_processed_text))
+
+ candidates_text = [[None]*shot for _ in range(way)]
+ candidates_e1_pos = torch.empty((way, shot), dtype=torch.int64)
+ candidates_e2_pos = torch.empty((way, shot), dtype=torch.int64)
+ candidates_e1 = torch.empty((way, shot), dtype=torch.int64)
+ candidates_e2 = torch.empty((way, shot), dtype=torch.int64)
+
+ for n, (relation, relation_processed) in enumerate(zip(test, candidates_processed_text)):
+ for k, (candidate, candidate_processed) in enumerate(zip(relation, relation_processed)):
+ candidates_text[n][k] = candidate_processed[0]
+ candidates_e1_pos[n, k] = candidate_processed[1]
+ candidates_e2_pos[n, k] = candidate_processed[2]
+ candidates_e1[n, k] = entity_dictionary.encode(candidate[1])
+ candidates_e2[n, k] = entity_dictionary.encode(candidate[2])
+
+ data.append((query_text, query_e1_pos, query_e2_pos, query_e1, query_e2, candidates_text, candidates_e1_pos, candidates_e2_pos, candidates_e1, candidates_e2, answer))
+ entity_dictionary.save(path / f"{name}.entities")
+ torch.save(("sampled fewshot", data), path / name)
+
+
+def serialize_dataset(
+ supervision: str,
+ path: pathlib.Path,
+ splits: Dict[str, Iterable[Tuple[str, str, str, str, str]]],
+ tokenizer_name: str,
+ unknown_entity: Optional[str] = None,
+ unknown_relation: Optional[str] = None,
+ seed: Optional[int] = None) -> None:
+ """
+ Serialize a dataset to a given path.
+
+ The splits must be given as iterables of (text, relation, relation_base, e1, e2) tuples.
+ supervision must be one of "supervised" or "fewshot".
+ """
+ if not path.is_dir():
+ path.mkdir()
+
+ tokenizer: transformers.PreTrainedTokenizer = make_tokenizer(tokenizer_name, path / "tokenizer")
+ entity_dictionary = Dictionary(unknown=unknown_entity)
+ relation_dictionary = RelationDictionary(unknown=unknown_relation)
+
+ serialize_split = serialize_supervised_split if supervision == "supervised" else serialize_fewshot_split
+ for split_name in ["train", "valid", "test"]:
+ if split_name not in splits:
+ continue
+
+ split = list(splits[split_name])
+ if split_name == "train":
+ rng = random.Random(seed)
+ rng.shuffle(split)
+ split = tqdm.tqdm(split, desc=f"{split_name} tokenization")
+ serialize_split(path / split_name, split, tokenizer, entity_dictionary, relation_dictionary)
+ entity_dictionary.save(path / "entities")
+ relation_dictionary.save(path / "relations")
+
+
+def build_edge_list(data: Iterable[Tuple[str, List[Tuple[str, int, int]]]], tokenizer: transformers.PreTrainedTokenizer) -> Tuple[List[torch.Tensor], Dictionary, List[int], List[Tuple[int, int, int, int, int, int, int]]]:
+ """
+ Build a list of edges and nodes corresponding to the given data.
+
+ The tuples in the returned edge list are composed of the following elements:
+ (entity 1, entity 2, sentence id, entity 1 start, entity 1 end, entity 2 start, entity 2 end)
+ """
+ sentences: List[str] = []
+ entity_dictionary = Dictionary()
+ degrees: List[int] = []
+ edges: List[Tuple[int, int, int, int, int, int, int]] = []
+
+ for raw_sentence, raw_entities in data:
+ sentence: torch.Tensor
+ entities: List[Tuple[str, int, int]]
+ sentence, entities = process_text_n(raw_sentence, raw_entities, tokenizer)
+
+ # Buffer the ids to avoid re-hashing the entities
+ entity_ids: List[Optional[int]] = [None] * len(entities)
+ edge_added: bool = False
+
+ # Add all edges appearing in the clique corresponding to this sentence
+ for i, (e1_name, e1_start, e1_end) in enumerate(entities):
+ for j, (e2_name, e2_start, e2_end) in enumerate(entities[:i]):
+ # Soares et al. footnote 2 "We use a window of 40 tokens"
+ if max(e2_end - e1_start, e1_end - e2_start) < 40:
+ if entity_ids[i] is None:
+ entity_ids[i] = entity_dictionary.encode(e1_name)
+ if entity_ids[i] >= len(degrees):
+ degrees.append(0)
+ e1_id: int = cast(int, entity_ids[i])
+
+ if entity_ids[j] is None:
+ entity_ids[j] = entity_dictionary.encode(e2_name)
+ if entity_ids[j] >= len(degrees):
+ degrees.append(0)
+ e2_id: int = cast(int, entity_ids[j])
+
+ if e1_id <= e2_id:
+ edges.append((e1_id, e2_id, len(sentences), e1_start, e1_end, e2_start, e2_end))
+ else:
+ edges.append((e2_id, e1_id, len(sentences), e2_start, e2_end, e1_start, e1_end))
+ degrees[e1_id] += 1
+ degrees[e2_id] += 1
+ edge_added = True
+
+ if edge_added:
+ sentences.append(sentence)
+
+ return sentences, entity_dictionary, degrees, edges
+
+
+def serialize_unsupervised_dataset(
+ path: pathlib.Path,
+ data: Iterable[Tuple[str, List[Tuple[str, int, int]]]],
+ tokenizer_name: str,
+ seed: int) -> None:
+ """
+ Serialize an unsupervised dataset to a given path.
+
+ The data must be given as an iterable of (sentence, list of entities) tuples.
+ Where entities are tuples of (identifier, start indice in sentence, end indice in sentence).
+ """
+ if not path.is_dir():
+ path.mkdir()
+
+ tokenizer: transformers.PreTrainedTokenizer = make_tokenizer(tokenizer_name, path / "tokenizer")
+
+ sentences: List[str]
+ entities: Dictionary
+ edges: List[Tuple[int, int, int, int, int, int, int]]
+ graph = Graph(*build_edge_list(data, tokenizer))
+ graph.save(path / "train")
diff --git a/gbure/eval.py b/gbure/eval.py
@@ -0,0 +1,34 @@
+from typing import Any, Dict
+import pathlib
+
+import torch
+
+import gbure.train
+import gbure.utils
+
+
+class Evaluator(gbure.train.Trainer):
+ """
+ Evaluate a model.
+
+ Config:
+ valid: only evualte on validation split
+ test: only evualte on test split
+ """
+
+ def main(self) -> None:
+ """ Run the experiment (i.e. here, evaluate). """
+ if self.config.get("valid") or not self.config.get("test"):
+ self.evaluate("valid")
+ if self.config.get("test") or not self.config.get("valid"):
+ self.evaluate("test")
+
+
+if __name__ == "__main__":
+ gbure.utils.fix_transformers_logging_handler()
+ config: gbure.utils.dotdict = gbure.utils.parse_args()
+
+ state_dicts: Dict[str, Any] = torch.load(config.load)
+ logdir: pathlib.Path = state_dicts["logdir"]
+ gbure.utils.add_logging_handler(logdir)
+ Evaluator(config, logdir, state_dicts).run()
diff --git a/gbure/metrics.py b/gbure/metrics.py
@@ -0,0 +1,295 @@
+from typing import Dict, Optional, Union
+import math
+
+import torch
+import transformers
+
+from gbure.data.dictionary import RelationDictionary
+import gbure.data.graph
+
+
+class Metrics:
+ """
+ Class for computing metrics.
+
+ Twenty metrics are computed:
+ - Optimized loss (usually negative log likelihood)
+ - Accuracy
+ - {directed, undirected, half_directed} {micro, macro} {f1, precision, recall}
+ Note that the Accuracy is the true accuracy, taking directionality into account and scoring the unknown relation as any other relation.
+ The last 18 metrics follow the SemEval scorer:
+ - The unknown ("Other") relation is only scored indirectly
+ - Directed is equivalent to the metrics "USING DIRECTIONALITY"
+ - Undirected is equivalent to the metrics "IGNORING DIRECTIONALITY"
+ - Half-directed is equivalent to the metrics "TAKING DIRECTIONALITY INTO ACCOUNT -- OFFICIAL"
+ Note that the directed and half_directed micro metrics are equivalents.
+ """
+
+ def __init__(self, config: gbure.utils.dotdict, tokenizer: transformers.PreTrainedTokenizer, relation_dictionary: RelationDictionary, graph: Optional[gbure.data.graph.Graph]) -> None:
+ """ Initialize all metrics. """
+ self.config: gbure.utils.dotdict = config
+ self.tokenizer: transformers.PreTrainedTokenizer = tokenizer
+ self.relation_dictionary: RelationDictionary = relation_dictionary
+ self.graph: Optional[gbure.data.graph.Graph] = graph
+ self.num_relations: int = len(relation_dictionary)
+ self.num_base_relations: int = relation_dictionary.base_size()
+
+ self.mask: torch.Tensor = self.build_mask()
+ self.base_transition: torch.Tensor = self.build_base_transition()
+ self.base_mask: torch.Tensor = self.base_transition.t().mv(self.mask).clamp(0, 1)
+
+ self.loss_sum: float = 0.0
+ self.relation_buckets: int = 1 + self.config.get("neighborhood_size", 0)
+ self.per_bucket_confusion: torch.Tensor = torch.zeros((self.relation_buckets, self.num_relations, self.num_relations), dtype=torch.int32)
+
+ @property
+ def confusion(self):
+ return self.per_bucket_confusion.sum(0)
+
+ def build_mask(self) -> torch.Tensor:
+ """ Return the relation mask used by semeval scorer (which partly ignore the unknown relation). """
+ mask: torch.Tensor = torch.ones(self.num_relations)
+ if self.relation_dictionary.unknown is not None:
+ assert(self.relation_dictionary.decode(0) == self.relation_dictionary.unknown)
+ mask[0] = 0
+ return mask
+
+ def build_base_transition(self) -> torch.Tensor:
+ """ Return the transition matrix from "directed relations" to "undirected relations". """
+ base_transition: torch.Tensor = torch.zeros((self.num_relations, self.num_base_relations))
+ for id, bid in enumerate(self.relation_dictionary.id_to_bid):
+ base_transition[id, bid] = 1
+ return base_transition
+
+ def compute_neighborhood_bucket(self, batch: Dict[str, torch.Tensor], index: int) -> int:
+ """ Return an index between 0 and self.config.neighborhood_size corresponding to the minimum number of neighbors in the sample. """
+ if self.graph is None:
+ return 0
+ neighborhood_size: Union[float, int] = math.inf
+ for feature, value in batch.items():
+ if "degree" in feature and "neighborhood" not in feature:
+ neighborhood_size = min(neighborhood_size, value[index].min().item())
+ # TODO here substract 1 for unsupervised (not that important since we don't care about unsupervised accuracies)
+ return min(neighborhood_size, self.relation_buckets-1) if neighborhood_size != math.inf else 0 # pytype: disable=bad-return-type
+
+ def update(self, batch: Dict[str, torch.Tensor], loss: torch.Tensor, losses: Dict[str, torch.Tensor], variables: Dict[str, torch.Tensor]) -> None:
+ """
+ Update metrics according to the given batch and the outputs of the model on this batch.
+
+ The variables dictionary returned by the model should usually contain a predicted_relation tensor.
+
+ Args:
+ batch: the input values used for evaluation
+ loss: the loss optimized by the model
+ losses: intermediary (unweighted) losses
+ variables: internal variables used by the model to compute the loss
+ """
+ predictions: torch.Tensor = variables.get("predicted_relation")
+ targets: torch.Tensor = batch.get("relation")
+ if targets is None:
+ targets = batch.get("query_relation")
+
+ if predictions is None and targets is None:
+ predictions = variables.get("prediction_relative")
+ targets = batch.get("answer")
+
+ for i, (prediction, target) in enumerate(zip(predictions, targets)):
+ neighborhood_bucket: int = self.compute_neighborhood_bucket(batch, i)
+ self.per_bucket_confusion[neighborhood_bucket, prediction, target] += 1
+
+ batch_size: int = predictions.shape[0]
+ self.loss_sum += loss.item() * batch_size
+
+ @property
+ def summary(self) -> Dict[str, str]:
+ """ Return a summary of metrics to be quickly displayed. """
+ metrics: Dict[str, str] = {"accuracy": f"{self.accuracy*100:.2f}",
+ "loss": f"{self.loss:.2f}"}
+ if self.relation_buckets > 1:
+ metrics.update({"accuracy_non_empty": f"{self.accuracy_non_empty*100:.2f}",
+ "accuracy_full": f"{self.accuracy_full*100:.2f}"})
+ return metrics
+
+ @property
+ def all(self) -> Dict[str, float]:
+ """ Return a dictionary of all metrics. """
+ keys = ["accuracy", "accuracy_non_empty", "accuracy_full", "loss"] + [
+ f"{direction}_{level}_{metric}"
+ for direction in ["directed", "undirected", "half_directed"]
+ for level in ["macro", "micro"]
+ for metric in ["f1", "precision", "recall"]]
+ return {key: getattr(self, key) for key in keys}
+
+ @property
+ def base_confusion(self) -> torch.Tensor:
+ """ Confusion matrix between "undirected" relation classes. """
+ return self.base_transition.t().mm(self.confusion.type_as(self.base_transition)).mm(self.base_transition)
+
+ @property
+ def accuracy(self) -> float:
+ return math.nan if self.confusion.sum() == 0 else self.confusion.diagonal().sum() / self.confusion.sum().type(torch.float32)
+
+ @property
+ def accuracy_non_empty(self) -> float:
+ non_empty: torch.Tensor = self.per_bucket_confusion[1:].sum(0)
+ return math.nan if non_empty.sum() == 0 else non_empty.diagonal().sum() / non_empty.sum().type(torch.float32)
+
+ @property
+ def accuracy_full(self) -> float:
+ full: torch.Tensor = self.per_bucket_confusion[-1]
+ return math.nan if full.sum() == 0 else full.diagonal().sum() / full.sum().type(torch.float32)
+
+ @property
+ def loss(self) -> float:
+ return math.nan if self.confusion.sum() == 0 else self.loss_sum / self.confusion.sum().type(torch.float32)
+
+ ##########################
+ # Directed macro metrics #
+ ##########################
+
+ @property
+ def directed_class_precision(self) -> torch.Tensor:
+ norm: torch.Tensor = self.confusion.sum(1)
+ norm[norm == 0] = 1
+ return self.confusion.diagonal() / norm.type(torch.float32)
+
+ @property
+ def directed_class_recall(self) -> torch.Tensor:
+ norm: torch.Tensor = self.confusion.sum(0)
+ norm[norm == 0] = 1
+ return self.confusion.diagonal() / norm.type(torch.float32)
+
+ @property
+ def directed_class_f1(self) -> torch.Tensor:
+ norm: torch.Tensor = self.directed_class_precision + self.directed_class_recall
+ norm[norm == 0] = 1
+ return 2 * self.directed_class_precision * self.directed_class_recall / norm
+
+ @property
+ def directed_macro_precision(self) -> float:
+ return ((self.directed_class_precision * self.mask).sum() / self.mask.sum()).item()
+
+ @property
+ def directed_macro_recall(self) -> float:
+ return ((self.directed_class_recall * self.mask).sum() / self.mask.sum()).item()
+
+ @property
+ def directed_macro_f1(self) -> float:
+ return ((self.directed_class_f1 * self.mask).sum() / self.mask.sum()).item()
+
+ ############################
+ # Undirected macro metrics #
+ ############################
+
+ @property
+ def undirected_class_precision(self) -> torch.Tensor:
+ norm: torch.Tensor = self.base_confusion.sum(1)
+ norm[norm == 0] = 1
+ return self.base_confusion.diagonal() / norm
+
+ @property
+ def undirected_class_recall(self) -> torch.Tensor:
+ norm: torch.Tensor = self.base_confusion.sum(0)
+ norm[norm == 0] = 1
+ return self.base_confusion.diagonal() / norm
+
+ @property
+ def undirected_class_f1(self) -> torch.Tensor:
+ norm: torch.Tensor = self.undirected_class_precision + self.undirected_class_recall
+ norm[norm == 0] = 1
+ return 2 * self.undirected_class_precision * self.undirected_class_recall / norm
+
+ @property
+ def undirected_macro_precision(self) -> float:
+ return ((self.undirected_class_precision * self.base_mask).sum() / self.base_mask.sum()).item()
+
+ @property
+ def undirected_macro_recall(self) -> float:
+ return ((self.undirected_class_recall * self.base_mask).sum() / self.base_mask.sum()).item()
+
+ @property
+ def undirected_macro_f1(self) -> float:
+ return ((self.undirected_class_f1 * self.base_mask).sum() / self.base_mask.sum()).item()
+
+ ###############################
+ # Half-directed macro metrics #
+ ###############################
+
+ @property
+ def half_directed_class_precision(self) -> torch.Tensor:
+ norm: torch.Tensor = self.base_confusion.sum(1)
+ norm[norm == 0] = 1
+ return self.base_transition.t().mv(self.confusion.diagonal().type_as(self.base_transition)) / norm
+
+ @property
+ def half_directed_class_recall(self) -> torch.Tensor:
+ norm: torch.Tensor = self.base_confusion.sum(0)
+ norm[norm == 0] = 1
+ return self.base_transition.t().mv(self.confusion.diagonal().type_as(self.base_transition)) / norm
+
+ @property
+ def half_directed_class_f1(self) -> torch.Tensor:
+ norm: torch.Tensor = self.half_directed_class_precision + self.half_directed_class_recall
+ norm[norm == 0] = 1
+ return 2 * self.half_directed_class_precision * self.half_directed_class_recall / norm
+
+ @property
+ def half_directed_macro_precision(self) -> float:
+ return ((self.half_directed_class_precision * self.base_mask).sum() / self.base_mask.sum()).item()
+
+ @property
+ def half_directed_macro_recall(self) -> float:
+ return ((self.half_directed_class_recall * self.base_mask).sum() / self.base_mask.sum()).item()
+
+ @property
+ def half_directed_macro_f1(self) -> float:
+ return ((self.half_directed_class_f1 * self.base_mask).sum() / self.base_mask.sum()).item()
+
+ #################
+ # Micro metrics #
+ #################
+
+ @property
+ def directed_micro_precision(self) -> float:
+ norm: torch.Tensor = (self.confusion.sum(1) * self.mask).sum()
+ return 0 if norm == 0 else ((self.confusion.diagonal() * self.mask).sum() / norm).item()
+
+ @property
+ def directed_micro_recall(self) -> float:
+ norm: torch.Tensor = (self.confusion.sum(0) * self.mask).sum()
+ return 0 if norm == 0 else ((self.confusion.diagonal() * self.mask).sum() / norm).item()
+
+ @property
+ def directed_micro_f1(self) -> float:
+ norm: float = self.directed_micro_precision + self.directed_micro_recall
+ return 0 if norm == 0 else 2 * (self.directed_micro_precision * self.directed_micro_recall) / norm
+
+ @property
+ def half_directed_micro_precision(self) -> float:
+ norm: torch.Tensor = (self.confusion.sum(1) * self.mask).sum()
+ return 0 if norm == 0 else ((self.confusion.diagonal() * self.mask).sum() / norm).item()
+
+ @property
+ def half_directed_micro_recall(self) -> float:
+ norm: torch.Tensor = (self.confusion.sum(0) * self.mask).sum()
+ return 0 if norm == 0 else ((self.confusion.diagonal() * self.mask).sum() / norm).item()
+
+ @property
+ def half_directed_micro_f1(self) -> float:
+ norm: float = self.half_directed_micro_precision + self.half_directed_micro_recall
+ return 0 if norm == 0 else 2 * (self.half_directed_micro_precision * self.half_directed_micro_recall) / norm
+
+ @property
+ def undirected_micro_precision(self) -> float:
+ norm: torch.Tensor = (self.base_confusion.sum(1) * self.base_mask).sum()
+ return 0 if norm == 0 else ((self.base_confusion.diagonal() * self.base_mask).sum() / norm).item()
+
+ @property
+ def undirected_micro_recall(self) -> float:
+ norm: torch.Tensor = (self.base_confusion.sum(0) * self.base_mask).sum()
+ return 0 if norm == 0 else ((self.base_confusion.diagonal() * self.base_mask).sum() / norm).item()
+
+ @property
+ def undirected_micro_f1(self) -> float:
+ norm: float = self.undirected_micro_precision + self.undirected_micro_recall
+ return 0 if norm == 0 else 2 * (self.undirected_micro_precision * self.undirected_micro_recall) / norm
diff --git a/gbure/model/__init__.py b/gbure/model/__init__.py
diff --git a/gbure/model/contrastive_alignment.py b/gbure/model/contrastive_alignment.py
@@ -0,0 +1,85 @@
+from typing import Any, Dict, List, Tuple
+
+import torch
+import transformers
+
+from gbure.model.linguistic_encoder import LinguisticEncoder
+from gbure.model.masked_lm import MaskedLM
+from gbure.model.similarity import LinguisticSimilarity, TopologicalSimilarity
+from gbure.model.topological_encoder import TopologicalEncoder
+import gbure.utils
+
+
+class Model(torch.nn.Module):
+ """
+ Unsupervised pre-training model from Soares et al.
+
+ Correspond to the model explained in section 4.
+ """
+
+ def __init__(self, config: gbure.utils.dotdict, tokenizer: transformers.PreTrainedTokenizer, relation_dictionary: None) -> None:
+ """
+ Instantiate a Soares et al. matching the blanks model.
+
+ Args:
+ config: global config object
+ tokenizer: tokenizer used to create the vocabulary
+ relation_dictionary: dictionary of all relations (unused)
+ margin: maximum enforced meta-distance between positive and negative distances
+ """
+ super().__init__()
+
+ self.config: gbure.utils.dotdict = config
+ self.tokenizer: transformers.PreTrainedTokenizer = tokenizer
+
+ self.transformer: transformers.PreTrainedModel = transformers.AutoModelForMaskedLM
+ self.language_model: torch.nn.Module = MaskedLM(config, tokenizer)
+ self.linguistic_encoder: torch.nn.Module = LinguisticEncoder(config, tokenizer, transformer=self.language_model.encoder)
+ self.topological_encoder: torch.nn.Module = TopologicalEncoder(config, self.linguistic_encoder)
+ self.linguistic_similarity: torch.nn.Module = LinguisticSimilarity(config)
+ self.topological_similarity: torch.nn.Module = TopologicalSimilarity(config)
+
+ def forward(self, **batch) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], Dict[str, torch.Tensor]]:
+ """ Compute the unsupervised matching the blanks loss between the given pairs. """
+ linguistic_embeddings: List[torch.Tensor] = []
+ topological_embeddings: List[Any] = []
+
+ for i, order in enumerate(["first_", "second_", "third_"]):
+ linguistic_embeddings.append(self.linguistic_encoder(batch[f"{order}text"], batch[f"{order}mask"], batch[f"{order}entity_positions"])[0])
+ topological_embeddings.append(self.topological_encoder(order, linguistic_embeddings[-1], **batch))
+
+ # d(a_1, a_2, [1, 0])
+ positive_linguistic_similarity: torch.Tensor = self.linguistic_similarity(linguistic_embeddings[0], linguistic_embeddings[1])
+ # d(a_1, a_3, [1, 0])
+ negative_linguistic_similarity: torch.Tensor = self.linguistic_similarity(linguistic_embeddings[0], linguistic_embeddings[2])
+
+ # d(a_1, a_2, [0, 1])
+ positive_topological_similarity: torch.Tensor = self.topological_similarity(topological_embeddings[0], topological_embeddings[1])[0]
+ # d(a_1, a_3, [0, 1])
+ negative_topological_similarity: torch.Tensor = self.topological_similarity(topological_embeddings[0], topological_embeddings[2])[0]
+
+ # (d(a_1, a_2, [1, 0]) - d(a_1, a_2, [0, 1]))²
+ positive: torch.Tensor = 2 * (positive_linguistic_similarity - positive_topological_similarity)**2
+ # (d(a_1, a_3, [1, 0]) - d(a_1, a_2, [0, 1]))² + (d(a_1, a_2, [1, 0]) - d(a_1, a_3, [0, 1]))²
+ negative: torch.Tensor = ((positive_linguistic_similarity - negative_topological_similarity)**2 + (negative_linguistic_similarity - positive_topological_similarity)**2)
+
+ contrastive_loss: torch.Tensor = torch.nn.functional.relu(self.config.margin + positive - negative).mean()
+ if self.config.get("language_model_weight", 0) > 0:
+ lm_loss: torch.Tensor = self.language_model(batch["first_mlm_input"], batch["first_mlm_target"], batch["first_mask"])
+ else:
+ lm_loss: int = 0
+ loss: torch.Tensor = contrastive_loss + self.config.get("language_model_weight", 0) * lm_loss
+
+ losses: Dict[str, torch.Tensor] = {
+ "positive": positive.mean(),
+ "negative": negative.mean(),
+ "contrastive": contrastive_loss,
+ "reconstruction": lm_loss}
+ variables: Dict[str, torch.Tensor] = {
+ "positive_linguistic_similarity": positive_linguistic_similarity,
+ "negative_linguistic_similarity": negative_linguistic_similarity,
+ "positive_topological_similarity": positive_topological_similarity,
+ "negative_topological_similarity": negative_topological_similarity
+ }
+
+ return loss, losses, variables
diff --git a/gbure/model/fewshot.py b/gbure/model/fewshot.py
@@ -0,0 +1,117 @@
+from typing import Dict, Optional, Tuple
+
+import torch
+import transformers
+
+import gbure.data.dictionary
+import gbure.model.linguistic_encoder
+import gbure.model.similarity
+import gbure.model.topological_encoder
+import gbure.utils
+
+
+class Model(torch.nn.Module):
+ """
+ Few shot model from Soares et al.
+
+ Correspond to the right subfigure of Figure 2.
+ """
+
+ def __init__(self, config: gbure.utils.dotdict, tokenizer: transformers.PreTrainedTokenizer, relation_dictionary: gbure.data.dictionary.RelationDictionary, train_model: Optional[torch.nn.Module] = None) -> None:
+ """
+ Instantiate a Soares et al. few shot model.
+
+ Args:
+ config: global config object
+ tokenizer: tokenizer used to create the vocabulary
+ relation_dictionary: dictionary of all relations
+ train_model: unsupervised model used to initialize the few shot model.
+ """
+ super().__init__()
+
+ self.config: gbure.utils.dotdict = config
+ self.tokenizer: transformers.PreTrainedTokenizer = tokenizer
+ self.relation_dictionary: gbure.data.dictionary.RelationDictionary = relation_dictionary
+
+ if train_model is None:
+ self.linguistic_encoder: torch.nn.Module = gbure.model.linguistic_encoder.LinguisticEncoder(config, tokenizer)
+ self.linguistic_similarity: torch.nn.Module = gbure.model.similarity.LinguisticSimilarity(config)
+ else:
+ self.linguistic_encoder: torch.nn.Module = train_model.linguistic_encoder
+ self.linguistic_similarity: torch.nn.Module = train_model.linguistic_similarity
+ self.loss_module = torch.nn.NLLLoss(reduction="mean")
+
+ if self.config.get("neighborhood_size", 0) > 0:
+ if train_model is None:
+ self.topological_encoder: torch.nn.Module = gbure.model.topological_encoder.TopologicalEncoder(config, self.linguistic_encoder)
+ self.topological_similarity: torch.nn.Module = gbure.model.similarity.TopologicalSimilarity(config)
+ else:
+ self.topological_encoder: torch.nn.Module = train_model.topological_encoder
+ self.topological_similarity: torch.nn.Module = train_model.topological_similarity
+ if not self.config.get("undefined_poison_whole_meta"):
+ if train_model is not None:
+ self.neutral_topological_similarity = train_model.neutral_topological_similarity
+ elif self.config.get("neutral_topological_similarity") is not None:
+ self.neutral_topological_similarity: float = self.config.neutral_topological_similarity
+ else:
+ self.neutral_topological_similarity: torch.nn.Parameter = torch.nn.Parameter(torch.zeros(1, dtype=torch.float32))
+
+ def combine_similarities(self, linguistic: torch.Tensor, topological: torch.Tensor, topological_mask: Optional[torch.Tensor]) -> torch.Tensor:
+ """ Combine linguistic and topological similarities into a single value. """
+ # topological is of dimension (batch, way, shot, slot)
+ if topological_mask is not None:
+ if self.config.get("undefined_poison_whole_meta"):
+ topological *= topological_mask.prod(1, keepdim=True).prod(2, keepdim=True)
+ else:
+ topological += (~topological_mask) * self.neutral_topological_similarity
+ topological = topological.mean(3)
+ return self.config.get("linguistic_weight", 1) * linguistic + self.config.get("topological_weight", 1) * topological
+
+ def forward(self,
+ query_text: torch.Tensor,
+ query_mask: torch.Tensor,
+ query_entity_positions: torch.Tensor,
+ candidates_text: torch.Tensor,
+ candidates_mask: torch.Tensor,
+ candidates_entity_positions: torch.Tensor,
+ answer: torch.Tensor,
+ query_relation: Optional[torch.Tensor] = None,
+ candidates_relation: Optional[torch.Tensor] = None,
+ **batch) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], Dict[str, torch.Tensor]]:
+ """ Compute the fewshot loss on the given query and candidates. """
+ batch_size: int = query_text.shape[0]
+ fake_batch_size: int = batch_size * self.config.shot * self.config.way
+
+ query: torch.Tensor = self.linguistic_encoder(query_text, query_mask, query_entity_positions)[0]
+ candidates: torch.Tensor = self.linguistic_encoder(
+ candidates_text.view(fake_batch_size, -1),
+ candidates_mask.view(fake_batch_size, -1),
+ candidates_entity_positions.view(fake_batch_size, 2)
+ )[0].view(batch_size, self.config.way, self.config.shot, -1)
+ logits: torch.Tensor = self.linguistic_similarity(candidates, query.unsqueeze(1).unsqueeze(2))
+
+ if self.config.get("neighborhood_size", 0) > 0:
+ topological_query = self.topological_encoder("query_", query, degree_delta=1, **batch)
+ topological_candidates = self.topological_encoder("candidates_", candidates, degree_delta=1, **batch)
+ if isinstance(topological_query, torch.Tensor):
+ topological_query = topological_query.view(topological_query.shape[0], 1, 1, -1)
+ else:
+ topological_query = tuple(x.view(x.shape[0], 1, 1, *x.shape[1:]) for x in topological_query)
+
+ topological_similarity, topological_mask = self.topological_similarity(topological_query, topological_candidates)
+ logits = self.combine_similarities(logits, topological_similarity, topological_mask)
+
+ log_probabilities = torch.nn.functional.log_softmax(
+ logits.view(batch_size, self.config.way*self.config.shot),
+ dim=1).view(batch_size, self.config.way, self.config.shot)
+ log_probabilities = log_probabilities.logsumexp(2)
+
+ loss: torch.Tensor = self.loss_module(log_probabilities, answer)
+ prediction: torch.Tensor = log_probabilities.argmax(1)
+
+ variables = {"prediction_logits": logits, "prediction_relative": prediction}
+ if candidates_relation is not None:
+ batch_ids: torch.Tensor = torch.arange(batch_size, device=loss.device)
+ variables["predicted_relation"] = candidates_relation[batch_ids, prediction, 0]
+
+ return loss, {}, variables
diff --git a/gbure/model/linguistic_encoder.py b/gbure/model/linguistic_encoder.py
@@ -0,0 +1,75 @@
+from typing import Callable, Optional, Tuple
+
+import torch
+import transformers
+
+import gbure.utils
+
+
+class LinguisticEncoder(torch.nn.Module):
+ """
+ Transformer encoder from Soares et al.
+
+ Correspond to the left part of each subfigure of Figure 2 (Deep Transformer Encoder and the green layer above).
+ We only implement the "entity markers, entity start" variant (which is the one with the best performance).
+
+ Config:
+ transformer_model: Which transformer to use (e.g. bert-large-uncased).
+ post_transformer_layer: The transformation applied after the transformer (must be "linear" or "layer_norm")
+ """
+
+ def __init__(self, config: gbure.utils.dotdict, tokenizer: transformers.PreTrainedTokenizer, transformer: Optional[transformers.PreTrainedModel] = None) -> None:
+ """
+ Instantiate a Soares et al. encoder.
+
+ Args:
+ config: global config object
+ tokenizer: tokenizer used to create the vocabulary
+ transformer: the transformer to use instead of loading a pre-trained one
+ """
+ super().__init__()
+
+ self.config: gbure.utils.dotdict = config
+ self.tokenizer: transformers.PreTrainedTokenizer = tokenizer
+
+ self.transformer: transformers.PreTrainedModel
+ if transformer is not None:
+ self.transformer = transformer
+ elif self.config.get("load") or self.config.get("pretrained"):
+ # TODO introduce a config parameter to change the initialization of <tags> embeddings
+ transformer_config = transformers.AutoConfig.from_pretrained(self.config.transformer_model)
+ transformer_config.vocab_size = len(tokenizer)
+ self.transformer = transformers.AutoModel.from_config(transformer_config)
+ else:
+ self.transformer = transformers.AutoModel.from_pretrained(self.config.transformer_model)
+ self.transformer.resize_token_embeddings(len(tokenizer))
+
+ self.post_transformer: Callable[[torch.Tensor], torch.Tensor]
+ if self.config.post_transformer_layer == "linear":
+ self.post_transformer_linear = torch.nn.Linear(in_features=self.output_size, out_features=self.output_size)
+ self.post_transformer = lambda x: self.post_transformer_linear(x)
+ elif self.config.post_transformer_layer == "layer_norm":
+ # It is not clear whether a Linear should be added before the layer_norm, see Soares et al. section 3.3
+ # Setting elementwise_affine to True (the default) makes little sense when computing similarity scores.
+ self.post_transformer_linear = torch.nn.Linear(in_features=self.output_size, out_features=self.output_size)
+ self.post_transformer_activation = torch.nn.LayerNorm(self.output_size, elementwise_affine=False)
+ self.post_transformer = lambda x: self.post_transformer_activation(self.post_transformer_linear(x))
+ elif self.config.post_transformer_layer == "none":
+ self.post_transformer = lambda x: x
+ else:
+ raise RuntimeError("Unsuported config value for post_transformer_layer")
+
+ @property
+ def output_size(self) -> int:
+ """ Dimension of the representation returned by the model. """
+ return 2 * self.transformer.config.hidden_size
+
+ def forward(self, text: torch.Tensor, mask: torch.Tensor, entity_positions: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+ """ Encode the given text into a fixed size representation. """
+ batch_size: int = text.shape[0]
+ batch_ids: torch.Tensor = torch.arange(batch_size, device=text.device, dtype=torch.int64).unsqueeze(1)
+
+ # The first element of the tuple is the Batch×Sentence×Hidden output matrix.
+ transformer_out: torch.Tensor = self.transformer(text, attention_mask=mask)[0]
+ sentence: torch.Tensor = transformer_out[batch_ids, entity_positions].view(batch_size, self.output_size)
+ return self.post_transformer(sentence), transformer_out
diff --git a/gbure/model/masked_lm.py b/gbure/model/masked_lm.py
@@ -0,0 +1,59 @@
+from typing import Callable
+
+import torch
+import transformers
+
+import gbure.utils
+
+
+class MaskedLM(torch.nn.Module):
+ """
+ Masked language model to be used on top of a transformer.
+
+ This class is only useful for unsupervised pre-training, Soares et al. keep the BERT loss alongside their "matching the blanks" loss.
+
+ Config:
+ transformer_model: Which transformer to use (e.g. bert-large-uncased).
+ """
+
+ def __init__(self, config: gbure.utils.dotdict, tokenizer: transformers.PreTrainedTokenizer) -> None:
+ """
+ Instantiate a masked language model.
+
+ Args:
+ config: global config object
+ tokenizer: tokenizer used to create the vocabulary
+ """
+ super().__init__()
+
+ self.config: gbure.utils.dotdict = config
+ self.tokenizer: transformers.PreTrainedTokenizer = tokenizer
+
+ self.transformer: transformers.PreTrainedModel
+ if self.config.get("load") or self.config.get("pretrained"):
+ # TODO introduce a config parameter to change the initialization of <tags> embeddings
+ transformer_config = transformers.AutoConfig.from_pretrained(self.config.transformer_model)
+ transformer_config.vocab_size = len(tokenizer)
+ self.transformer = transformers.AutoModelForMaskedLM(transformer_config)
+ else:
+ self.transformer = transformers.AutoModelForMaskedLM.from_pretrained(self.config.transformer_model)
+ self.transformer.resize_token_embeddings(len(tokenizer))
+
+ @property
+ def encoder(self) -> transformers.PreTrainedModel:
+ if isinstance(self.transformer, transformers.BertForMaskedLM):
+ return self.transformer.bert
+ elif isinstance(self.transformer, transformers.DistilBertForMaskedLM):
+ return self.transformer.distilbert
+ else:
+ raise RuntimeError("Unknown transformer model, can't split masked language model off")
+
+ def forward(self, input: torch.Tensor, target: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
+ """
+ Compute masked language model loss from transformer's output.
+
+ Sadly, huggingface does not provide an unified interface, so we need to do some copy-pasting.
+ """
+ masked_target: torch.Tensor = target * mask - 100 * (1 - mask.long())
+ output = self.transformer(input_ids=input, attention_mask=mask, labels=masked_target, return_dict=True)
+ return output.loss
diff --git a/gbure/model/matching_the_blanks.py b/gbure/model/matching_the_blanks.py
@@ -0,0 +1,109 @@
+from typing import Dict, Optional, Tuple
+import math
+
+import torch
+import transformers
+
+import gbure.model.linguistic_encoder
+import gbure.model.topological_encoder
+import gbure.model.masked_lm
+import gbure.model.similarity
+import gbure.utils
+
+
+class Model(torch.nn.Module):
+ """
+ Unsupervised pre-training model from Soares et al.
+
+ Correspond to the model explained in Section 4.
+
+ Config:
+ linguistic_weight: factor for the linguistic similarity
+ topological_weight: factor for the topological similarity
+ """
+
+ def __init__(self, config: gbure.utils.dotdict, tokenizer: transformers.PreTrainedTokenizer, relation_dictionary: None) -> None:
+ """
+ Instantiate a Soares et al. matching the blanks model.
+
+ Args:
+ config: global config object
+ tokenizer: tokenizer used to create the vocabulary
+ relation_dictionary: dictionary of all relations (unused)
+ """
+ super().__init__()
+
+ self.config: gbure.utils.dotdict = config
+ self.tokenizer: transformers.PreTrainedTokenizer = tokenizer
+
+ self.transformer: transformers.PreTrainedModel = transformers.AutoModelForMaskedLM
+ self.language_model: torch.nn.Module = gbure.model.masked_lm.MaskedLM(config, tokenizer)
+ self.linguistic_encoder: torch.nn.Module = gbure.model.linguistic_encoder.LinguisticEncoder(config, tokenizer, transformer=self.language_model.encoder)
+ self.linguistic_similarity: torch.nn.Module = gbure.model.similarity.LinguisticSimilarity(config)
+
+ if self.config.get("neighborhood_size", 0) > 0:
+ self.topological_encoder: torch.nn.Module = gbure.model.topological_encoder.TopologicalEncoder(config, self.linguistic_encoder)
+ self.topological_similarity: torch.nn.Module = gbure.model.similarity.TopologicalSimilarity(config)
+ if not self.config.get("undefined_poison_whole_meta"):
+ if self.config.get("neutral_topological_similarity") is not None:
+ self.neutral_topological_similarity: float = self.config.neutral_topological_similarity
+ else:
+ self.neutral_topological_similarity: torch.nn.Parameter = torch.nn.Parameter(torch.zeros(1, dtype=torch.float32))
+
+ def combine_similarities(self, linguistic: torch.Tensor, topological: torch.Tensor, topological_mask: Optional[torch.Tensor]) -> torch.Tensor:
+ """ Combine linguistic and topological similarities into a single value. """
+ # topological is of dimension (batch, slot)
+ if topological_mask is not None:
+ if not self.config.get("undefined_poison_whole_meta"):
+ topological += (~topological_mask) * self.neutral_topological_similarity
+ topological = topological.mean(1)
+ return self.config.get("linguistic_weight", 1) * linguistic + self.config.get("topological_weight", 1) * topological
+
+ def forward(self,
+ first_text: torch.Tensor,
+ first_mask: torch.Tensor,
+ first_entity_positions: torch.Tensor,
+ second_text: torch.Tensor,
+ second_mask: torch.Tensor,
+ second_entity_positions: torch.Tensor,
+ polarity: torch.Tensor,
+ first_mlm_input: torch.Tensor,
+ first_mlm_target: torch.Tensor,
+ **batch) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], Dict[str, torch.Tensor]]:
+ """ Compute the unsupervised matching the blanks loss between the given pairs. """
+ first: torch.Tensor
+ first_transformer_out: torch.Tensor
+ first, first_transformer_out = self.linguistic_encoder(first_text, first_mask, first_entity_positions)
+ second: torch.Tensor = self.linguistic_encoder(second_text, second_mask, second_entity_positions)[0]
+ linguistic_similarity: torch.Tensor = self.linguistic_similarity(first, second)
+ lm_loss: torch.Tensor = self.language_model(first_mlm_input, first_mlm_target, first_mask)
+
+ similarity: torch.Tensor = linguistic_similarity
+ if self.config.get("neighborhood_size", 0) > 0:
+ topological_first = self.topological_encoder("first_", first, **batch)
+ topological_second = self.topological_encoder("second_", second, **batch)
+
+ topological_similarity, topological_mask = self.topological_similarity(topological_first, topological_second)
+ similarity = self.combine_similarities(linguistic_similarity, topological_similarity, topological_mask)
+
+ # There seem to be a mistake in Soares et al. §4.1 in the equation of p(l=1|r,r')
+ # The equation use 1/(1+exp(x)) which seems counter intuitive since the case where r=r' would lead to a low probability.
+ # By default 1/(1+exp(-x)) is used, but the equation given in the paper can be used with --reverse_sigmoid.
+ if self.config.get("reverse_sigmoid"):
+ scores: torch.Tensor = - torch.nn.functional.logsigmoid(- polarity * similarity)
+ else:
+ scores: torch.Tensor = - torch.nn.functional.logsigmoid(polarity * similarity)
+ mtb_loss: torch.Tensor = scores.mean()
+
+ is_positive: torch.Tensor = (polarity + 1) // 2
+ is_negative: torch.Tensor = 1 - is_positive
+ positive: torch.Tensor = (scores * is_positive).sum() / is_positive.sum()
+ negative: torch.Tensor = (scores * is_negative).sum() / is_negative.sum()
+
+ losses: Dict[str, torch.Tensor] = {
+ "positive": positive,
+ "negative": negative,
+ "mtb": mtb_loss,
+ "reconstruction": lm_loss}
+ loss: torch.Tensor = mtb_loss + self.config.language_model_weight * lm_loss
+ return loss, losses, {"similarity": similarity}
diff --git a/gbure/model/similarity.py b/gbure/model/similarity.py
@@ -0,0 +1,134 @@
+from typing import Dict, List, Optional, Tuple, Union
+
+import geomloss
+import torch
+
+import gbure.utils
+
+
+class LinguisticSimilarity(torch.nn.Module):
+ """
+ Compute the similarity between two bertcoder representations.
+
+ The bertcoder embeddings are all in the same direction of space since they all correspond to the tags <e1> and <e2>
+ Thus after a dot product, the activations are not standardized anymore, even after scaling by √d
+
+ Config:
+ linguistic_similarity: the function used to compute the similarity between embeddings
+ linguistic_similarity_delta: an additive constant to add to all similarity (useful to have a strictly positive cosine)
+ latent_metric_scale: how to scale the similarity once computed
+ latent_dot_mean: when latent_metric_scale=="standard", the value to substract
+ latent_dot_std: when latent_metric_scale=="standard", the value to divide by
+ """
+
+ def __init__(self, config: gbure.utils.dotdict) -> None:
+ super().__init__()
+ self.config = config
+
+ def forward(self, lhs: torch.Tensor, rhs: torch.Tensor) -> torch.Tensor:
+ """ Compute the similarity between lhs and rhs. """
+ if self.config.linguistic_similarity == "dot":
+ similarity: torch.Tensor = torch.einsum('b...d,b...d->b...', lhs, rhs)
+ elif self.config.linguistic_similarity == "cosine":
+ similarity: torch.Tensor = torch.nn.functional.cosine_similarity(lhs, rhs, dim=-1)
+ elif self.config.linguistic_similarity == "euclidean":
+ similarity: torch.Tensor = - torch.sum((lhs - rhs)**2, dim=-1)
+ else:
+ raise RuntimeError(f"Unknown config value for linguistic_similarity: {self.config.linguistic_similarity}")
+ similarity += self.config.get("linguistic_similarity_delta", 0)
+
+ encoder_output_size: int = lhs.shape[-1]
+ if self.config.get("latent_metric_scale") == "sqrt":
+ # If X, Y ~ N(0, I_d):
+ # E[X·Y] = 0
+ # Var[X·Y] = √d
+ similarity = similarity / (encoder_output_size ** 0.5)
+ elif self.config.get("latent_metric_scale") == "full":
+ similarity = similarity / encoder_output_size
+ elif self.config.get("latent_metric_scale") == "match":
+ # If X ~ N(0, 1):
+ # E[X²] = 1
+ # Var[X²] = √2
+ similarity = (similarity - encoder_output_size) / ((2**0.5) * encoder_output_size)
+ elif self.config.get("latent_metric_scale") == "standard":
+ similarity = (similarity - self.config.latent_dot_mean) / self.config.latent_dot_std
+ elif self.config.get("latent_metric_scale") is not None:
+ raise RuntimeError("Unsuported config value for latent_metric_scale")
+
+ return similarity
+
+
+class TopologicalSimilarity(torch.nn.Module):
+ """
+ Compute the similarity between two neighborhood representations.
+
+ This similarity is either an inner product in the case of fixed-size representations or the negative of the 1-Wasserstein distance.
+ """
+
+ def __init__(self, config: gbure.utils.dotdict) -> None:
+ super().__init__()
+ self.config = config
+ if self.config.get("gcn_aggregator", "none") == "none":
+ self.sinkhorn = geomloss.SamplesLoss(loss="sinkhorn", p=self.config.get("wasserstein_underlying_distance", 2), blur=self.config.get("sinkhorn_blur", 0.05))
+
+ @staticmethod
+ def merge_shapes(lhs: List[int], rhs: List[int]) -> List[int]:
+ """ Find the shape of an elementwise operation between two tensors of the given shapes. """
+ lhs = [1] * (len(rhs) - len(lhs)) + lhs
+ rhs = [1] * (len(lhs) - len(rhs)) + rhs
+ res = []
+ for left, right in zip(lhs, rhs):
+ if left == right:
+ res.append(left)
+ elif left == 1 or right == 1:
+ res.append(left + right - 1)
+ else:
+ raise RuntimeError(f"Incompatible shapes {lhs} and {rhs}")
+ return res
+
+ def forward(self, lhs: Union[Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor], rhs: Union[Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
+ """
+ Compute the distance between two topological embeddings.
+
+ The embeddings are either a single vector (per sample) which was pooled using a GCN, or four matrices corresponding to:
+ - the embeddings of the head neighborhood
+ - the embeddings of the tail neighborhood
+ - the mask of the head neighborhood
+ - the mask of the tail neighborhood
+ """
+ if isinstance(lhs, tuple):
+ shape: List[int] = self.merge_shapes(list(lhs[0].shape)[:-2], list(rhs[0].shape)[:-2])
+
+ def expand_tuple(t: Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+ return (t[0].expand(tuple(shape)+t[0].shape[-2:]),
+ t[1].expand(tuple(shape)+t[1].shape[-2:]),
+ t[2].expand(tuple(shape)+t[2].shape[-1:]),
+ t[3].expand(tuple(shape)+t[3].shape[-1:]))
+
+ lhs = expand_tuple(lhs)
+ rhs = expand_tuple(rhs)
+
+ cumulative_shape: List[int] = []
+ accumulator: int = 1
+ for x in shape:
+ cumulative_shape.append(accumulator)
+ accumulator *= x
+ result: torch.Tensor = lhs[0].new_zeros(shape+[2])
+ result_mask: torch.Tensor = lhs[0].new_zeros(shape+[2], dtype=torch.bool)
+ for i in range(accumulator):
+ indices = []
+ for denominator, modulo in zip(cumulative_shape, shape):
+ indices.append(i // denominator % modulo)
+ indices = tuple(indices)
+ for slot in range(2):
+ sample_lhs: torch.Tensor = lhs[slot][indices][lhs[2+slot][indices]]
+ sample_rhs: torch.Tensor = rhs[slot][indices][rhs[2+slot][indices]]
+ if sample_lhs.numel() > 0 and sample_rhs.numel() > 0:
+ result[indices+(slot,)] = - self.sinkhorn(sample_lhs, sample_rhs)
+ result_mask[indices+(slot,)] = True
+ else:
+ result_mask[indices+(slot,)] = False
+ return result, result_mask
+ else:
+ # TODO implement other topological similarities
+ return torch.einsum('b...d,b...d->b...', lhs, rhs), None
diff --git a/gbure/model/supervised.py b/gbure/model/supervised.py
@@ -0,0 +1,45 @@
+from typing import Dict, Tuple
+
+import torch
+import transformers
+
+from gbure.data.dictionary import RelationDictionary
+from gbure.model.linguistic_encoder import LinguisticEncoder
+import gbure.utils
+
+
+class Model(torch.nn.Module):
+ """
+ Supervised model from Soares et al.
+
+ Correspond to the left subfigure of Figure 2.
+ """
+
+ def __init__(self, config: gbure.utils.dotdict, tokenizer: transformers.PreTrainedTokenizer, relation_dictionary: RelationDictionary) -> None:
+ """
+ Instantiate a Soares et al. supervised model.
+
+ Args:
+ config: global config object
+ tokenizer: tokenizer used to create the vocabulary
+ relation_dictionary: dictionary of all relations
+ """
+ super().__init__()
+
+ self.config: gbure.utils.dotdict = config
+ self.tokenizer: transformers.PreTrainedTokenizer = tokenizer
+ self.relation_dictionary: RelationDictionary = relation_dictionary
+
+ self.encoder: torch.nn.Module = LinguisticEncoder(config, tokenizer)
+ self.relation_encoder = torch.nn.Linear(
+ in_features=self.encoder.output_size,
+ out_features=len(relation_dictionary),
+ bias=False)
+ self.loss_module = torch.nn.CrossEntropyLoss(reduction="mean")
+
+ def forward(self, text: torch.Tensor, mask: torch.Tensor, entity_positions: torch.Tensor, relation: torch.Tensor, **kwargs) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], Dict[str, torch.Tensor]]:
+ """ Compute the supervised loss on the given text and target relation. """
+ latent: torch.Tensor = self.encoder(text, mask, entity_positions)[0]
+ logits: torch.Tensor = self.relation_encoder(latent)
+ loss: torch.Tensor = self.loss_module(logits, relation)
+ return loss, {}, {"prediction_logits": logits, "predicted_relation": logits.argmax(1)}
diff --git a/gbure/model/topological_encoder.py b/gbure/model/topological_encoder.py
@@ -0,0 +1,65 @@
+from typing import List, Optional, Tuple, Union
+import functools
+import math
+import operator
+
+import torch
+
+import gbure.utils
+
+
+class TopologicalEncoder(torch.nn.Module):
+ """
+ Encoder for neighborhoods.
+
+ Config:
+ gcn_aggregator: aggregator used to pool the representations of several neighbors into a fixed-size one.
+ """
+
+ def __init__(self, config: gbure.utils.dotdict, linguistic_encoder: torch.nn.Module) -> None:
+ """
+ Instantiate a Soares et al. encoder.
+
+ Args:
+ config: global config object
+ linguistic_encoder: the model used to get a fixed-size representation of text
+ """
+ super().__init__()
+ self.config: gbure.utils.dotdict = config
+ self.linguistic_encoder: torch.nn.Module = linguistic_encoder
+
+ if self.config.get("gcn_aggregator", "") in ["mean", "chebyshev"]:
+ self.gcn_layer: torch.nn.Module = torch.nn.Linear(in_features=self.linguistic_encoder.output_size, out_features=self.linguistic_encoder.output_size)
+
+ def forward(self, prefix: str, loop: torch.Tensor, degree_delta: int = 0, **batch) -> Union[Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor]:
+ """
+ Encode the neighborhood of the given prefix.
+
+ When a gcn_aggregator is defined, this result in a fixed-size representation, otherwise it returns clouds of points to be compared using optimal transport.
+ The degree_delta parameters changes the degrees used to compute various GCN weighting. This can be useful when the sample comes from outside the graph so 1 should be added to the degrees.
+ """
+ linguistic_embeddings: List[torch.Tensor] = []
+ masks: List[torch.Tensor] = []
+ for slot in [1, 2]:
+ fake_batch_size: int = functools.reduce(operator.mul, batch[f"{prefix}e{slot}_neighborhood_text"].shape[:-1])
+ mask: torch.Tensor = batch[f"{prefix}e{slot}_neighborhood_mask"].view(fake_batch_size, -1)[:, 0].unsqueeze(1)
+ linguistic_embeddings.append((self.linguistic_encoder(
+ batch[f"{prefix}e{slot}_neighborhood_text"].view(fake_batch_size, -1),
+ batch[f"{prefix}e{slot}_neighborhood_mask"].view(fake_batch_size, -1),
+ batch[f"{prefix}e{slot}_neighborhood_entity_positions"].view(fake_batch_size, 2)
+ )[0] * mask).view(*batch[f"{prefix}e{slot}_neighborhood_text"].shape[:-1], self.linguistic_encoder.output_size))
+ masks.append(mask.view(*batch[f"{prefix}e{slot}_neighborhood_text"].shape[:-1]))
+
+ if self.config.get("gcn_aggregator", "") == "mean":
+ head: torch.Tensor = linguistic_embeddings[0].sum(-2)
+ tail: torch.Tensor = linguistic_embeddings[1].sum(-2)
+ neighborhood_size: torch.Tensor = sum(mask.sum(-1, keepdim=True) for mask in masks)
+ return self.gcn_layer((loop + head + tail) / (neighborhood_size + 1))
+ elif self.config.get("gcn_aggregator", "") == "chebyshev":
+ pre_embedding: torch.Tensor = loop / torch.sqrt(2 * (batch[f"{prefix}entity_degrees"].sum(-1, keepdim=True) - 1 + degree_delta))
+ for slot in [1, 2]:
+ weights = 1 / torch.sqrt(batch[f"{prefix}e{slot}_neighborhood_entity_degrees"].sum(-1, keepdim=True) - 1 + degree_delta)
+ pre_embedding += torch.sum(weights * linguistic_embeddings[slot-1], dim=-2)
+ return self.gcn_layer(pre_embedding)
+ else:
+ return (linguistic_embeddings[0], linguistic_embeddings[1], masks[0], masks[1])
diff --git a/gbure/outputs.py b/gbure/outputs.py
@@ -0,0 +1,55 @@
+from __future__ import annotations
+from typing import Dict, Optional, Type
+import pathlib
+import types
+
+import torch
+import transformers
+
+from gbure.data.dictionary import RelationDictionary
+
+
+class Outputs:
+ """
+ Class for outputing data about the learned model.
+ """
+ # TODO parametrize this class
+
+ def __init__(self, logdir: pathlib.Path, tokenizer: transformers.PreTrainedTokenizer, relation_dictionary: RelationDictionary) -> None:
+ """ Initialize the outputs variables, but do not acquire necessary resources yet. """
+ self.logdir: pathlib.Path = logdir
+ self.tokenizer: transformers.PreTrainedTokenizer = tokenizer
+ self.relation_dictionary: RelationDictionary = relation_dictionary
+
+ def __enter__(self) -> Outputs:
+ """ Acquire resources needed for outputing the data. """
+ # TODO parametrize this file name with split and epoch
+ self.target_prediction_file = (self.logdir / "target_prediction").open("w")
+ return self
+
+ def __exit__(self,
+ exc_type: Optional[Type[BaseException]],
+ exc_inst: Optional[BaseException],
+ exc_tb: Optional[types.TracebackType]) -> None:
+ """ Free used resources. """
+ self.target_prediction_file.close()
+
+ def update(self, batch: Dict[str, torch.Tensor], loss: torch.Tensor, losses: Dict[str, torch.Tensor], variables: Dict[str, torch.Tensor]) -> None:
+ """
+ Update model output data files with the given batch and the outputs of the model on this batch.
+
+ The variables dictionary returned by the model must contain a predicted_relation tensor.
+
+ Args:
+ batch: the input values used for evaluation
+ loss: the loss optimized by the model
+ losses: intermediary (unweighted) losses
+ variables: internal variables used by the model to compute the loss
+ """
+ targets: torch.Tensor = batch.get("relation", batch.get("query_relation"))
+ predictions: torch.Tensor = variables.get("predicted_relation")
+ if targets is not None and predictions is not None:
+ for prediction, target in zip(predictions, targets):
+ print(f"{self.relation_dictionary.decode(target)}\t{self.relation_dictionary.decode(prediction)}", file=self.target_prediction_file)
+ elif "prediction_relative" in variables:
+ print("\n".join(map(str, variables["prediction_relative"].tolist())), file=self.target_prediction_file)
diff --git a/gbure/train.py b/gbure/train.py
@@ -0,0 +1,421 @@
+from typing import Any, Callable, Dict, Iterable, List, Optional
+import contextlib
+import gc
+import logging
+import math
+import os
+import pathlib
+import shutil
+
+from torch.utils.tensorboard import SummaryWriter
+import torch
+import torch.utils
+import tqdm
+import transformers
+
+import gbure.data.batcher
+import gbure.data.dataset
+import gbure.data.dictionary
+import gbure.data.graph
+import gbure.metrics
+import gbure.outputs
+import gbure.utils
+
+logger = logging.getLogger(__name__)
+
+
+class Trainer(gbure.utils.Experiment):
+ """
+ Train a model.
+
+ Config:
+ Model: the model class to use for training and evaluation
+ Optimizer: the optimizer class to use for training
+ Scheduler: the learning rate scheduler class to use
+ accumulated_batch_size: the actual number of sample in a batch after accumulation, that is the number of samples seen before an optimizer step (must be a multiple of batch_size)
+ amp: enable automatic mixed precision and gradient scaler
+ batch_size: the number of samples in the batch of data loaded
+ eval_batch_size: batch size used for evaluation
+ clip_gradient: the maximum norm of the gradient
+ dataset_name: name of the dataset to load
+ dataset_spec: dataset specification, usually None, can be used to select a (smaller) test version
+ eval_dataset_name: overwrite evaluation dataset
+ eval_dataset_spec: overwrite evaluation dataset specification
+ early_stopping_patience: how many epoch to train after best validation score has been reached
+ learning_rate: learning rate
+ max_epoch: maximum number of epoch
+ no_initial_validation: do not run evaluation on the valid dataset before first epoch
+ optimizer_hyperparameters: hyperparameters for the optimizer (e.g. weight decay, etc)
+ pretrained: path to a pretrained model to load
+ scheduler_parameters: the parameters for initializing the Scheduler
+ test_output: path to a file where the test predictions will be written
+ transformer_model: the model of transformer to use
+ unsupervised: train an unsupervised model
+ validation_metric: metric used for early stopping
+ workers: number of data generating workers to spawn
+ """
+
+ def init(self) -> None:
+ """ Prepare training. """
+ self.prepare_dataset()
+ self.build_model()
+ self.count_parameters()
+ self.setup_optimizer()
+ self.init_writer()
+ self.init_epochs()
+
+ def main(self) -> None:
+ """ Run the experiment (i.e. here, train). """
+ self.train()
+
+ def prepare_dataset(self) -> None:
+ """ Load datasets and create iterators. """
+ dataset_spec: str = self.config.transformer_model + self.config.get("dataset_spec", "")
+ self.data_dir: pathlib.Path = gbure.utils.DATA_PATH / self.config.dataset_name / dataset_spec
+
+ self.tokenizer: transformers.PreTrainedTokenizer = transformers.AutoTokenizer.from_pretrained(str(self.data_dir / "tokenizer"))
+ self.batcher = gbure.data.batcher.Batcher(self.tokenizer.pad_token_id)
+ self.relation_dictionary: Optional[gbure.data.dictionary.RelationDictionary] = None if self.config.get("unsupervised") else gbure.data.dictionary.RelationDictionary(path=self.data_dir / "relations")
+ entities_path: pathlib.Path = self.data_dir / "train" / "entities" if self.config.get("unsupervised") else self.data_dir / "entities"
+ self.entity_dictionary: gbure.data.dictionary.Dictionary = gbure.data.dictionary.Dictionary(path=entities_path)
+
+ if self.config.get("eval_dataset_name"):
+ eval_dataset_spec: str = self.config.transformer_model + self.config.get("eval_dataset_spec", "")
+ self.eval_data_dir: pathlib.Path = gbure.utils.DATA_PATH / self.config.eval_dataset_name / eval_dataset_spec
+ # We assume the tokenizer is the same.
+ self.eval_relation_dictionary: gbure.data.dictionary.RelationDictionary = gbure.data.dictionary.RelationDictionary(path=self.eval_data_dir / "relations")
+ # FIXME doesn't work when a test dataset is specified …
+ entities_path: pathlib.Path = self.data_dir / f"{self.config.valid_name}.entities" if self.config.get("valid_name") else self.data_dir / "entities"
+ self.eval_entity_dictionary: gbure.data.dictionary.Dictionary = gbure.data.dictionary.Dictionary(path=self.eval_data_dir / "entities")
+ else:
+ self.eval_data_dir: pathlib.Path = self.data_dir
+ self.eval_relation_dictionary: Optional[gbure.data.dictionary.RelationDictionary] = self.relation_dictionary
+ self.eval_entity_dictionary: gbure.data.dictionary.Dictionary = self.entity_dictionary
+
+ self.dataset: Dict[str, torch.utils.data.Dataset] = {}
+ self.iterator: Dict[str, Callable[[], Any]] = {}
+ graph_buffer: Optional[gbure.data.graph.Graph] = None
+
+ for split in ["train", "valid", "test"]:
+ kwargs: Dict[str, Any] = {"rng": self.state_dicts.get("train_rng")} if split == "train" and self.state_dicts else {}
+ split_path = pathlib.Path(split)
+ data_dir = self.data_dir if split == "train" else self.eval_data_dir
+ if f"{split}_name" in self.config:
+ split_path = pathlib.Path(self.config[f"{split}_name"])
+
+ try:
+ self.dataset[split] = gbure.data.dataset.load_dataset(
+ config=self.config,
+ split=split,
+ path=data_dir / split_path,
+ tokenizer=self.tokenizer,
+ evaluation=(split != "train"),
+ **kwargs)
+ except FileNotFoundError:
+ if split == "train":
+ raise
+ continue
+
+ if self.config.get("graph_name"):
+ graph_spec: str = self.config.transformer_model + self.config.get("graph_spec", "")
+ graph_dir: pathlib.Path = gbure.utils.DATA_PATH / self.config.graph_name / graph_spec
+ entity_dictionary = self.entity_dictionary if split == "train" else self.eval_entity_dictionary
+ self.dataset[split] = gbure.data.dataset.GraphAdapter(
+ self.dataset[split],
+ entity_dictionary,
+ path=graph_dir / "train",
+ graph=graph_buffer)
+ graph_buffer = self.dataset[split].graph
+
+ self.iterator[split] = lambda trainer=self, split=split: tqdm.tqdm(
+ iterable=torch.utils.data.DataLoader(
+ dataset=trainer.dataset[split],
+ collate_fn=trainer.batcher,
+ batch_size=trainer.config.batch_size if split == "train" else trainer.config.get("eval_batch_size", trainer.config.batch_size),
+ shuffle=(split == "train" and trainer.dataset[split].shuffleable),
+ num_workers=trainer.config.workers,
+ worker_init_fn=getattr(trainer.dataset[split], "init_seed", None),
+ pin_memory=(trainer.device.type == "cuda")),
+ desc=f"Epoch {trainer.epoch:2} {split:5}",
+ unit="samples",
+ unit_scale=trainer.config.batch_size if split == "train" else trainer.config.get("eval_batch_size", trainer.config.batch_size),
+ total=math.ceil(len(trainer.dataset[split]) / (trainer.config.batch_size if split == "train" else trainer.config.get("eval_batch_size", trainer.config.batch_size))),
+ leave=False)
+
+ def build_model(self) -> None:
+ """ Instantiate the model. """
+ self.model: torch.nn.Module = self.config.Model(self.config, self.tokenizer, self.relation_dictionary)
+ if self.state_dicts:
+ missing_keys: List[str]
+ unexpected_keys: List[str]
+ missing_keys, unexpected_keys = self.model.load_state_dict(self.state_dicts["model"], strict=False)
+ if missing_keys or unexpected_keys:
+ unexpected_display: List[str] = unexpected_keys
+ if self.config.get("pretrained"):
+ unexpected_display = list(filter(lambda key: not key.startswith("language_model."), unexpected_keys))
+ if len(unexpected_display) < len(unexpected_keys):
+ unexpected_display.append("language_model.*")
+ print("\033[31mLoading state_dict mismatch.\033[0m\n")
+ print(f"\033[31mMissing keys: {' '.join(missing_keys)}.\033[0m\n")
+ print(f"\033[31mUnexpected keys: {' '.join(unexpected_display)}.\033[0m\n")
+ assert(self.config.get("pretrained"))
+ self.model.to(self.device)
+ if self.config.get("EvalModel"):
+ self.eval_model: torch.nn.Module = self.config.EvalModel(self.config, self.tokenizer, self.relation_dictionary, train_model=self.model)
+ else:
+ self.eval_model: torch.nn.Module = self.model
+ self.mixed_precision = torch.cuda.amp.autocast if self.config.get("amp") else contextlib.nullcontext
+
+ def count_parameters(self) -> None:
+ """ Display the number of parameters used by the model. """
+ total: int = 0
+ for parameter in self.model.parameters():
+ total += parameter.shape.numel()
+ print(f"\033[33mNumber of parameters: {total:,}\033[0m")
+
+ def setup_optimizer(self) -> None:
+ """ Instantiate the optimizer. """
+ optimizer_hyperparameters: Dict[str, Any] = self.config.get("optimizer_hyperparameters", {})
+ optimizer_hyperparameters["lr"] = self.config.learning_rate
+ self.optimizer: torch.optim.optimizer.Optimizer = self.config.Optimizer(self.model.parameters(), **optimizer_hyperparameters)
+ # if self.state_dicts and "optimizer" in self.state_dicts:
+ # self.optimizer.load_state_dict(self.state_dicts["optimizer"])
+ if self.config.get("Scheduler"):
+ total_iters: int = math.ceil(self.config.max_epoch * len(self.dataset["train"]) / self.config.accumulated_batch_size)
+ self.scheduler = self.config.Scheduler(self.optimizer, total_iters=total_iters, **self.config.get("scheduler_parameters", {}))
+ if self.config.get("amp"):
+ self.scaler = torch.cuda.amp.GradScaler(self.config.get("initial_grad_scale", 65536))
+
+ def save(self, path: pathlib.Path) -> None:
+ """ Save the model and all additional state which might be needed to resume training. """
+ state_dicts: Dict[str, Any] = {
+ "logdir": self.logdir,
+ "config": self.config,
+ "model": self.model.state_dict(),
+ "optimizer": self.optimizer.state_dict(),
+ "torch_rng": torch.random.get_rng_state(),
+ "epoch": self.epoch,
+ "best_epoch": self.best_epoch,
+ "best_eval": self.best_eval,
+ "global_batch_id": self.global_batch_id,
+ "results": self.results,
+ }
+ if self.config.workers == 0:
+ state_dicts["train_rng"] = getattr(self.dataset["train"], "rng", None)
+ if torch.cuda.is_available():
+ state_dicts["cuda_rng"] = torch.cuda.random.get_rng_state_all()
+ torch.save(state_dicts, path)
+
+ def init_writer(self) -> None:
+ """ Initialize summary writer for tensorboard logging. """
+ self.writer = SummaryWriter(log_dir=self.logdir)
+ self.results: Dict[str, float] = {}
+ if self.state_dicts:
+ self.results = self.state_dicts.get("results", self.results)
+
+ def close(self) -> None:
+ """ Close the writer. """
+ self.writer.close()
+
+ def init_epochs(self) -> None:
+ """
+ Given initial values for early stopping.
+
+ Assume a "higher is better" metric.
+ """
+ self.epoch: int = 0
+ self.best_epoch: int = 0
+ self.best_eval: float = -math.inf
+ self.global_batch_id: int = 0
+
+ if self.state_dicts:
+ self.epoch = self.state_dicts.get("epoch", self.epoch)
+ self.best_epoch = self.state_dicts.get("best_epoch", self.best_epoch)
+ self.best_eval = self.state_dicts.get("best_eval", self.best_eval)
+ self.global_batch_id = self.state_dicts.get("global_batch_id", self.global_batch_id)
+
+ def checkpoint(self) -> None:
+ """ Save current model. """
+ self.save(self.logdir / "checkpoint.new")
+ os.rename(self.logdir / "checkpoint.new", self.logdir / "checkpoint")
+
+ def update_best_model(self, candidate: float) -> bool:
+ """ Update current best model and return whether we should early stop. """
+ if candidate is math.nan:
+ return False
+
+ if candidate > self.best_eval:
+ logger.info(f"Updating best model at epoch {self.epoch} (metric {candidate} > {self.best_eval})")
+ self.best_epoch = self.epoch
+ self.best_eval = candidate
+ shutil.copy(self.logdir / "checkpoint", self.logdir / "best")
+ return (self.epoch - self.best_epoch > self.config.get("early_stopping_patience", self.config.max_epoch))
+
+ def load_best_model(self) -> None:
+ """ Load the model with the best validation score. """
+ best_path: pathlib.Path = self.logdir / "best"
+ if self.best_eval != -math.inf and best_path.exists():
+ print(f"Loading best model from {best_path}…", end="", flush=True)
+ self.model.load_state_dict(torch.load(best_path)["model"])
+ print(" done")
+ logger.info(f"{best_path} loaded")
+
+ def train_apply_grads(self) -> None:
+ """ Apply the gradients to the parameters. """
+ if self.config.get("amp"):
+ if self.config.get("clip_gradient"):
+ self.scaler.unscale_(self.optimizer)
+ torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.config.clip_gradient)
+ self.scaler.step(self.optimizer)
+ self.scaler.update()
+ else:
+ if self.config.get("clip_gradient"):
+ torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.config.clip_gradient)
+ self.optimizer.step()
+ self.optimizer.zero_grad()
+ if self.config.get("Scheduler"):
+ self.scheduler.step()
+
+ def train(self) -> None:
+ """ Train the model with intermediate evaluation on the validation dataset and final evaluation on the test dataset. """
+ assert(self.config.accumulated_batch_size % self.config.batch_size == 0)
+ data_batch_per_true_batch: int = self.config.accumulated_batch_size // self.config.batch_size
+
+ if not self.config.get("no_initial_validation"):
+ self.best_eval = self.evaluate("valid")
+
+ for self.epoch in range(self.epoch+1, self.config.max_epoch+1):
+ if self.interrupted:
+ break
+ self.model.train()
+ self.optimizer.zero_grad()
+ total_loss: float = 0.0
+ total_sample: int = 0
+
+ epoch_loop = self.iterator["train"]()
+ for batch_id, batch in enumerate(epoch_loop):
+ batch: Dict[str, torch.Tensor] = {key: value.to(self.device) for key, value in batch.items()}
+
+ loss: torch.Tensor
+ losses: Dict[str, torch.Tensor]
+ variables: Dict[str, torch.Tensor]
+ with self.mixed_precision():
+ loss, losses, variables = self.model(**batch)
+ if self.config.get("amp"):
+ self.scaler.scale(loss).backward()
+ else:
+ loss.backward()
+
+ if torch.isfinite(loss):
+ total_loss += loss.item()
+ total_sample += next(iter(batch.values())).shape[0]
+ epoch_loop.set_postfix(loss=f"{total_loss / total_sample:.2f} .", refresh=False)
+ else:
+ epoch_loop.set_postfix(loss=f"{total_loss / (total_sample if total_sample else 1):.2f} NaN", refresh=False)
+
+ self.writer.add_scalar("Loss/train", loss, self.global_batch_id)
+ for key, value in losses.items():
+ self.writer.add_scalar(f"{key}/train", value, self.global_batch_id)
+
+ if batch_id % 1000 == 0:
+ for key, value in variables.items():
+ self.writer.add_histogram(f"{key}/train", value, self.global_batch_id)
+ if batch_id % data_batch_per_true_batch == data_batch_per_true_batch - 1:
+ self.train_apply_grads()
+ self.global_batch_id += 1
+
+ if total_sample > 0:
+ self.results[f"train.loss"] = float(total_loss / total_sample)
+ else:
+ total_sample = 1
+
+ if batch_id % data_batch_per_true_batch != data_batch_per_true_batch - 1:
+ self.train_apply_grads()
+
+ print(f"Epoch {self.epoch:2} train mean loss: {total_loss / total_sample:8.4f}")
+ logger.info(f"epoch {self.epoch} train mean_loss {total_loss / total_sample}")
+
+ self.checkpoint()
+ candidate: float = self.evaluate("valid")
+ if self.update_best_model(candidate):
+ break
+
+ if "test" in self.iterator:
+ self.load_best_model()
+ self.evaluate("test")
+ self.write_results()
+
+ def metric_message(self, split: str) -> str:
+ """ Message to be displayed after an evaluation. """
+ message = "Epoch {self.epoch:2} {split:5} loss: {metrics.loss:8.4f}"
+ if not isinstance(self.dataset[split], gbure.data.dataset.UnsupervisedDataset) or hasattr(self.dataset[split], "dataset") and not isinstance(self.dataset[split].dataset, gbure.data.dataset.UnsupervisedDataset):
+ message += " accuracy: {metrics.accuracy:8.6f}"
+ if isinstance(self.dataset[split], gbure.data.dataset.SupervisedDataset):
+ message += " half-directed Macro F1: {metrics.half_directed_macro_f1:8.6f} (P: {metrics.half_directed_macro_precision:8.6f} R: {metrics.half_directed_macro_recall:8.6f})"
+ if isinstance(self.dataset[split], gbure.data.dataset.GraphAdapter) and (isinstance(self.dataset[split].dataset, gbure.data.dataset.FewShotDataset) or isinstance(self.dataset[split].dataset, gbure.data.dataset.SampledFewShotDataset)):
+ message += " accuracy non-empty: {metrics.accuracy_non_empty:8.6f} accuracy full: {metrics.accuracy_full:8.6f}"
+ return message
+
+ def evaluate(self, split: str) -> float:
+ """ Evaluate the model on the given split and return the early stopping metric. """
+ if split not in self.iterator:
+ if not self.config.get("unsupervised"):
+ print(f"\033[31m{split} dataset does not exist, evaluation skipped.\033[0m")
+ return math.nan
+
+ self.model.zero_grad(set_to_none=True)
+ gc.collect()
+ self.eval_model.eval()
+ with torch.no_grad(), gbure.outputs.Outputs(self.logdir, self.tokenizer, self.eval_relation_dictionary) as outputs:
+ metrics = gbure.metrics.Metrics(self.config, self.tokenizer, self.eval_relation_dictionary, getattr(self.dataset[split], "graph", None))
+ loop = self.iterator[split]()
+ for batch in loop:
+ batch: Dict[str, torch.Tensor] = {key: value.to(self.device) for key, value in batch.items()}
+ loss, losses, variables = self.eval_model(**batch)
+ metrics.update(batch, loss, losses, variables)
+ outputs.update(batch, loss, losses, variables)
+ loop.set_postfix(**metrics.summary, refresh=False)
+ print(self.metric_message(split).format(**locals()))
+ for metric, value in metrics.all.items():
+ logger.info(f"epoch {self.epoch} {split} {metric} {value}")
+
+ candidate: float = getattr(metrics, self.config.validation_metric) if self.config.get("validation_metric") else math.nan
+ if split == "test" or (split == "valid" and candidate > self.best_eval):
+ for key, value in metrics.all.items():
+ self.results[f"{split}.{key}"] = float(value)
+ return candidate
+
+ def write_results(self):
+ """ Write best valid and test score as tensorboard events. """
+ self.writer.add_hparams(
+ hparam_dict=gbure.utils.flatten_dict(self.config),
+ metric_dict=self.results,
+ run_name="hparams")
+
+
+if __name__ == "__main__":
+ gbure.utils.fix_transformers_logging_handler()
+
+ config: gbure.utils.dotdict = gbure.utils.parse_args()
+ logdir: pathlib.Path
+ state_dicts: Optional[Dict[str, Any]] = None
+ new_log_dir: bool = True
+
+ if config.get("load"):
+ state_dicts = torch.load(config.load, map_location=torch.device('cpu'))
+ if not config.get("overwrite_config"):
+ config = state_dicts["config"]
+ logdir = state_dicts["logdir"]
+ new_log_dir = False
+
+ if config.get("pretrained"):
+ state_dicts = torch.load(config.pretrained, map_location=torch.device('cpu'))
+ state_dicts = {"model": state_dicts["model"]}
+
+ if new_log_dir:
+ logdir = gbure.utils.LOG_PATH / gbure.utils.logdir_name("GBURE")
+ assert(not logdir.exists())
+ logdir.mkdir()
+
+ gbure.utils.add_logging_handler(logdir)
+ Trainer(config, logdir, state_dicts).run()
diff --git a/gbure/utils.py b/gbure/utils.py
@@ -0,0 +1,387 @@
+from typing import Any, Callable, Dict, List, NoReturn, Optional, Union
+import hashlib
+import importlib
+import logging
+import multiprocessing
+import os
+import pathlib
+import signal
+import subprocess
+import sys
+import time
+import types
+
+import torch
+
+logger = logging.getLogger(__name__)
+
+_HAS_DYNAMIC_ATTRIBUTES = True
+
+
+def import_environment(name: str, cast: type = str) -> None:
+ """ Import an environment variable into the global namespace. """
+ try:
+ globals()[name] = cast(os.environ[name])
+ except KeyError:
+ print(f"ERROR: {name} environment variable is not set.", file=sys.stderr)
+ sys.exit(1)
+
+
+import_environment("DATA_PATH", pathlib.Path)
+import_environment("LOG_PATH", pathlib.Path)
+
+
+class dotdict(dict):
+ """ Dictionary which can be access through var.key instead of var["key"]. """
+ def __getattr__(self, name: str) -> Any:
+ if name not in self:
+ raise AttributeError(f"Config key {name} not found")
+ return dotdict(self[name]) if type(self[name]) is dict else self[name]
+ __setattr__ = dict.__setitem__
+ __delattr__ = dict.__delitem__
+
+
+def eval_arg(config: Dict[str, Any], arg: str) -> None:
+ """
+ Evaluate arg in the context config, and update it.
+
+ The argument is expected to be of the form:
+ (parent.)*key(=value)?
+ If no value is provided, the key is assumed to be a boolean and True is assigned to it.
+ When passing a string argument through the shell, it must be enclosed in quote (like all python string), which usually need to be escaped.
+ """
+ key: str
+ value: Any
+ if '=' in arg:
+ key, value = arg.split('=', maxsplit=1)
+ value = eval(value, config)
+ else:
+ key, value = arg, True
+ path: List[str] = key.split('.')
+ for component in path[:-1]:
+ config = config[component]
+ config[path[-1]] = value
+ config.pop("__builtins__", None)
+
+
+def import_arg(config: Dict[str, Any], arg: str) -> None:
+ """
+ Load file arg, and update config with its content.
+
+ The file is loaded in an independent context, all the variable defined in the file (even through import) are added to config, with the exception of builtins and whole modules.
+ """
+ if arg.endswith(".py"):
+ arg = arg[:-3].replace('/', '.')
+ module: types.ModuleType = importlib.import_module(arg)
+ for key, value in vars(module).items():
+ if key not in module.__builtins__ and not key.startswith("__") and not isinstance(value, types.ModuleType): # pytype: disable=attribute-error
+ config[key] = value
+
+
+def parse_args() -> dotdict:
+ """
+ Parse command line arguments and return config dictionary.
+
+ Two kind of argument are supported:
+ - When the argument starts with -- it is evaluated by the eval_arg function
+ - Otherwise the argument is assumed to be a file which is loaded by the import_arg function
+ """
+ config: Dict[str, Any] = {}
+ config["config"] = config
+ for arg in sys.argv[1:]:
+ if arg.startswith("--"):
+ eval_arg(config, arg[2:])
+ else:
+ import_arg(config, arg)
+ config.pop("config")
+ return dotdict(config)
+
+
+def display_dict(output: Callable[[str], None], input: Dict[str, Any], depth: int = 0) -> None:
+ """ Display nested dictionaries in input using the provided output function. """
+ for key, value in input.items():
+ indent = '\t'*depth
+ output(f"{indent}{key}:")
+ if isinstance(value, dict):
+ output('\n')
+ display_dict(output, value, depth+1)
+ else:
+ output(f" {value}\n")
+
+
+def print_dict(input: Dict[str, Any]) -> None:
+ """ Print dictionary to standard output. """
+ display_dict(lambda x: print(x, end=""), input)
+
+
+def log_dict(logger: logging.Logger, input: Dict[str, Any]) -> None:
+ """ Log dictionary to the provided logger. """
+ class log:
+ buf: str = ""
+
+ def __call__(self, x: str) -> None:
+ self.buf += x
+ if self.buf.endswith('\n'):
+ logger.info(self.buf[:-1])
+ self.buf = ""
+ display_dict(log(), input)
+
+
+def flatten_dict(input: Dict[str, Any]) -> Dict[str, Union[bool, int, float, str]]:
+ """
+ Replace nested dict by dot-separated keys, and cast keys to simple types.
+
+ repr() is used to cast non-base-type to str.
+ """
+ def impl(result: Dict[str, Union[bool, int, float, str]], input: Dict[str, Any], prefix: str):
+ for key, value in input.items():
+ if isinstance(value, dict):
+ impl(result, value, f"{key}.")
+ else:
+ result[f"{prefix}{key}"] = value if type(value) in [bool, int, float, str] else repr(value)
+
+ result: Dict[str, Union[bool, int, float, str]] = {}
+ impl(result, input, "")
+ return result
+
+
+def get_repo_version() -> str:
+ """ Get the code repository version. """
+ repo_dir = pathlib.Path(__file__).parents[0]
+ result: subprocess.CompletedProcess = subprocess.run(
+ ["git", "rev-parse", "HEAD"],
+ stdout=subprocess.PIPE,
+ stderr=subprocess.DEVNULL,
+ encoding="utf-8",
+ cwd=repo_dir)
+
+ if result.returncode != 0:
+ return "release"
+ commit_hash: str = result.stdout.strip()[:8]
+
+ result = subprocess.run(
+ ["git", "status", "--porcelain"],
+ stdout=subprocess.PIPE,
+ stderr=subprocess.DEVNULL,
+ encoding="utf-8",
+ cwd=repo_dir)
+ modified_flag: str = ""
+ for line in result.stdout.split('\n'):
+ if line.startswith(" M "):
+ modified_flag = "+"
+ break
+
+ return f"{commit_hash}{modified_flag}"
+
+
+def experiment_name(name: str) -> str:
+ """ Name of the experiment (contains repository version, argument and time). """
+ args: str = ' '.join(sys.argv[1:])
+ version: str = get_repo_version()
+ stime: str = time.strftime("%FT%H:%M:%S")
+ return f"{name} {version} {args} {stime}"
+
+
+def logdir_name(name: str) -> str:
+ """ Name of the experiment directory, it should be the experiment_name clipped because of filesystem constraints. """
+ subdir: str = experiment_name(name).replace('/', '_')
+ if len(subdir) > 255:
+ sha1: str = hashlib.sha1(subdir.encode("utf-8")).digest().hex()[:16]
+ subdir = subdir[:255-17] + ' ' + sha1
+ return subdir
+
+
+def fix_transformers_logging_handler() -> None:
+ """ The transformers package from huggingface install its own logger on import, I don't want it. """
+ logger: logging.Logger = logging.getLogger()
+ for handler in logger.handlers:
+ logger.removeHandler(handler)
+
+
+def add_logging_handler(logdir: pathlib.Path) -> None:
+ logfile: pathlib.Path = logdir / "log"
+ logging.basicConfig(format="%(asctime)s\t%(levelname)s:%(name)s:%(message)s", filename=logfile, filemode='a', level=logging.INFO)
+
+
+def save_patch(outpath: pathlib.Path) -> None:
+ """ Save a file at the given patch containing the diff between the current code and the last commit. """
+ repo_dir = pathlib.Path(__file__).parents[0]
+
+ with outpath.open("w") as outfile:
+ result: subprocess.CompletedProcess = subprocess.run(
+ ["git", "diff", "HEAD"],
+ stdout=outfile,
+ stderr=subprocess.DEVNULL,
+ encoding="utf-8",
+ cwd=repo_dir)
+
+ assert(result.returncode == 0)
+
+
+class Experiment:
+ """
+ Base class for running an experiment.
+
+ Calling run on an instance of this class will call the init() then main() functions.
+ The sole purpose of this class is to make an experiment "prettier": it displays config values, store a diff of the repo in the experiment logdir, etc.
+
+ Config:
+ deterministic: run in deterministic mode
+ seed: seed for random number generators
+ """
+
+ _HAS_DYNAMIC_ATTRIBUTES = True
+
+ def __init__(self, config: dotdict, logdir: pathlib.Path, state_dicts: Optional[Dict[str, Any]] = None) -> None:
+ self.config = config
+ self.logdir = logdir
+ self.state_dicts = state_dicts
+
+ def init(self) -> None:
+ """ Prepare the experiment (e.g. initialize datasets and models). """
+ pass
+
+ def main(self) -> NoReturn:
+ """ Run the experiment in itself. """
+ raise NotImplementedError("Subclasses must implement a main method")
+
+ def close(self) -> None:
+ """ Free used resources (e.g. close opened files). """
+ pass
+
+ def run(self) -> None:
+ """ Run the whole experiment with setting ups, etc. """
+ self.log_environment()
+ self.log_patch()
+ self.initialize_rng()
+ self.init()
+ self.hook_signals()
+ self.main()
+ self.close()
+
+ def log_environment(self) -> None:
+ """ Display information about the environment. """
+ print(f"logdir is \033[1m\033[33m{self.logdir}\033[0m")
+ self.version_check()
+ self.detect_gpus()
+ print("")
+
+ print("\033[1m\033[33mConfiguration\033[0m")
+ print_dict(self.config)
+ log_dict(logging.getLogger("config"), self.config)
+ print("")
+
+ def version_check(self) -> None:
+ """ Check the version of the main dependencies. """
+ python_version: str = '.'.join(map(str, sys.version_info[:3]))
+ torch_version: str = torch.__version__
+ cuda_available: str = str(torch.cuda.is_available())
+
+ logger.info(f"python version {python_version}")
+ logger.info(f"torch version {torch_version}")
+ logger.info(f"cuda available {cuda_available}")
+
+ def problem(msg: str) -> str:
+ return f"\033[1m\033[31m{msg}\033[0m"
+ if sys.version_info < (3, 7):
+ python_version = problem(python_version)
+ if list(map(int, torch_version.split('+')[0].split('.'))) < [1, 6]:
+ torch_version = problem(torch_version)
+ if cuda_available == "False":
+ cuda_available = problem(cuda_available)
+ print(f"python version: {python_version}, torch version: {torch_version}, cuda available: {cuda_available}")
+
+ def detect_gpus(self) -> None:
+ """ Display available gpus and set self.device. """
+ count: int = torch.cuda.device_count()
+
+ if count == 0:
+ print(f"\033[1m\033[31mNo GPU available\033[0m")
+ logger.warning("no GPU available")
+ self.device = torch.device("cpu")
+ else:
+ self.device = torch.device("cuda:0")
+
+ for i in range(count):
+ gp = torch.cuda.get_device_properties(i)
+ print(f"GPU{i}: \033[33m{gp.name}\033[0m (Mem: {gp.total_memory/2**30:.2f}GiB CC: {gp.major}.{gp.minor})")
+ logger.info(f"GPU{i} {gp.name} {gp.total_memory} {gp.major}.{gp.minor}")
+
+ def log_patch(self) -> None:
+ """ Check the version of the code, and save a patch to logpath if it was modified. """
+ version: str = get_repo_version()
+ logger.info(f"repository_version {version}")
+ if version == "release":
+ print(f"\033[41mRelease version\033[0m\n")
+ elif version.endswith('+'):
+ print(f"\033[31mUncommited changes detected, saving patch to logdir.\033[0m\n")
+ suffix: str = ""
+ if self.state_dicts: # Reloading an existing Trainer
+ suffix = time.strftime("%FT%H:%M:%S")
+ save_patch(self.logdir / f"patch{suffix}")
+
+ def initialize_rng(self) -> None:
+ if self.state_dicts and "torch_rng" in self.state_dicts:
+ torch.random.set_rng_state(self.state_dicts["torch_rng"])
+ assert(("cuda_rng" in self.state_dicts) == torch.cuda.is_available())
+ if "cuda_rng" in self.state_dicts:
+ torch.cuda.random.set_rng_state_all(self.state_dicts["cuda_rng"])
+ else:
+ torch.manual_seed(self.config.seed)
+
+ if self.config.get("deterministic") and torch.backends.cudnn.enabled:
+ torch.backends.cudnn.deterministic = True
+ torch.backends.cudnn.benchmark = False
+
+ def hook_signals(self) -> None:
+ """ Change the behavior of SIGINT (^C) to change a variable `self.interrupted' before killing the process. """
+ self.interrupted: bool = False
+
+ def handler(sig: int, frame: types.FrameType) -> None:
+ if multiprocessing.current_process().name != "MainProcess":
+ return
+
+ print("\n\033[31mInterrupted, execution will stop at the end of this epoch.\n\033[1mNEXT ^C WILL KILL THE PROCESS!\033[0m\n", file=sys.stderr)
+ self.interrupted = True
+ signal.signal(signal.SIGINT, signal.SIG_DFL)
+
+ signal.signal(signal.SIGINT, handler)
+
+
+class SharedLongTensorList:
+ def __init__(self, tensor_list: List[torch.Tensor], view: List[int] = [-1]):
+ self.view = view
+
+ total_element: int = 0
+ for tensor in tensor_list:
+ total_element += tensor.numel()
+ self.data: torch.Tensor = torch.empty(total_element, dtype=torch.int64)
+ self.indices: torch.Tensor = torch.empty(len(tensor_list)+1, dtype=torch.int64)
+
+ data_pos: int = total_element
+ indices_pos: int = len(tensor_list)
+ self.indices[indices_pos] = data_pos
+ while tensor_list:
+ tensor: torch.Tensor = tensor_list.pop()
+ tensor_size: int = tensor.numel()
+
+ indices_pos -= 1
+ data_pos -= tensor_size
+
+ self.indices[indices_pos] = data_pos
+ self.data[data_pos:data_pos+tensor_size] = tensor.flatten()
+ assert(data_pos == 0)
+ assert(indices_pos == 0)
+
+ def __len__(self) -> int:
+ return self.indices.shape[0]-1
+
+ def __getitem__(self, key: Union[int, slice]) -> torch.Tensor:
+ if isinstance(key, slice):
+ return [self[value] for value in range(*key.indices(len(self)))]
+ elif isinstance(key, int):
+ return self.data[self.indices[key]:self.indices[key+1]].view(*self.view)
+ elif isinstance(key, torch.Tensor):
+ return self[key.item()]
+ else:
+ raise TypeError("Invalid argument type.")
diff --git a/requirements.txt b/requirements.txt
@@ -1,4 +0,0 @@
-numpy>=1.16
-torch==1.3.0
-tqdm>=4
-transformers==2.0.0
diff --git a/scripts/contrastive_alignment.sh b/scripts/contrastive_alignment.sh
@@ -0,0 +1,58 @@
+#!/bin/bash
+#SBATCH -A lco@gpu
+#SBATCH -C v100-32g
+#SBATCH --job-name=gbure_contrastive
+#SBATCH --ntasks=1
+#SBATCH --gres=gpu:1
+#SBATCH --cpus-per-task=10
+#SBATCH --distribution=block:block
+#SBATCH --hint=nomultithread
+#SBATCH --time=20:00:00
+#SBATCH --output=/gpfswork/rech/lco/url46ht/Étienne/runs/%x_%j.stdout
+#SBATCH --error=/gpfswork/rech/lco/url46ht/Étienne/runs/%x_%j.stderr
+#SBATCH --array=0-23
+
+# %x = nom du job
+# %j = id du job
+
+module purge
+source ~/.bashrc
+conda activate Étienne
+export DATA_PATH=$WORK/Étienne/data
+export LOG_PATH=$WORK/Étienne/log
+export HF_DATASETS_OFFLINE=1
+export TRANSFORMERS_OFFLINE=1
+cd $WORK/Étienne/code
+
+# echo des commandes lancées
+set -x
+
+config=""
+
+if [ $(( $SLURM_ARRAY_TASK_ID % 3 )) -eq 0 ]; then
+ config="$config --language_model_weight=0"
+elif [ $(( $SLURM_ARRAY_TASK_ID % 3 )) -eq 1 ]; then
+ config="$config --language_model_weight=0.1"
+else
+ config="$config --language_model_weight=1"
+fi
+
+if [ $(( $SLURM_ARRAY_TASK_ID / 3 % 4 )) -eq 0 ]; then
+ config="$config --margin=0.1"
+elif [ $(( $SLURM_ARRAY_TASK_ID / 3 % 4 )) -eq 1 ]; then
+ config="$config --margin=1"
+elif [ $(( $SLURM_ARRAY_TASK_ID / 3 % 4 )) -eq 2 ]; then
+ config="$config --margin=10"
+else
+ config="$config --margin=100"
+fi
+
+if [ $(( $SLURM_ARRAY_TASK_ID / 12 % 2 )) -eq 0 ]; then
+ config="$config --topological_weight=1"
+else
+ config="$config --topological_weight=0.2"
+fi
+
+config="$config --seed=$(($SLURM_ARRAY_TASK_ID / 24))"
+
+python -m gbure.train gbure/config/contrastive_alignment.py $config --post_transformer_layer=\"layer_norm\"
diff --git a/scripts/mtb.sh b/scripts/mtb.sh
@@ -0,0 +1,47 @@
+#!/bin/bash
+#SBATCH -A lco@gpu
+#SBATCH -C v100-32g
+#SBATCH --job-name=gbure_mtb
+#SBATCH --ntasks=1
+#SBATCH --gres=gpu:1
+#SBATCH --cpus-per-task=10
+#SBATCH --distribution=block:block
+#SBATCH --hint=nomultithread
+#SBATCH --time=20:00:00
+#SBATCH --output=/gpfswork/rech/lco/url46ht/Étienne/runs/%x_%j.stdout
+#SBATCH --error=/gpfswork/rech/lco/url46ht/Étienne/runs/%x_%j.stderr
+#SBATCH --array=0-3
+
+# %x = nom du job
+# %j = id du job
+
+module purge
+source ~/.bashrc
+conda activate Étienne
+export DATA_PATH=$WORK/Étienne/data
+export LOG_PATH=$WORK/Étienne/log
+export HF_DATASETS_OFFLINE=1
+export TRANSFORMERS_OFFLINE=1
+cd $WORK/Étienne/code
+
+# echo des commandes lancées
+set -x
+
+config=""
+
+if [ $(( $SLURM_ARRAY_TASK_ID % 2 )) -eq 0 ]; then
+ config="$config --latent_metric_scale=\"standard\" --latent_dot_mean=1067.65 --latent_dot_std=111.17"
+else
+ config="$config --latent_metric_scale=\"sqrt\" --latent_dot_mean=None --latent_dot_std=None"
+fi
+
+if [ $(( $SLURM_ARRAY_TASK_ID / 2 % 2 )) -eq 0 ]; then
+ config="$config --filter_empty_neighborhood=True"
+else
+ config="$config --filter_empty_neighborhood=False"
+fi
+
+config="$config --seed=$(($SLURM_ARRAY_TASK_ID / 4))"
+
+# We use the gcn_mtb config with a 0 topological weight in order to get accuracies in function of neighborhood sizes
+python -m gbure.train gbure/config/gcn_mtb.py $config --post_transformer_layer=\"layer_norm\" --topological_weight=0
diff --git a/scripts/mtb_gcn.sh b/scripts/mtb_gcn.sh
@@ -0,0 +1,54 @@
+#!/bin/bash
+#SBATCH -A lco@gpu
+#SBATCH -C v100-32g
+#SBATCH --job-name=gbure_gcn_mtb
+#SBATCH --ntasks=1
+#SBATCH --gres=gpu:1
+#SBATCH --cpus-per-task=10
+#SBATCH --distribution=block:block
+#SBATCH --hint=nomultithread
+#SBATCH --time=20:00:00
+#SBATCH --output=/gpfswork/rech/lco/url46ht/Étienne/runs/%x_%j.stdout
+#SBATCH --error=/gpfswork/rech/lco/url46ht/Étienne/runs/%x_%j.stderr
+#SBATCH --array=0-11
+
+# %x = nom du job
+# %j = id du job
+
+module purge
+source ~/.bashrc
+conda activate Étienne
+export DATA_PATH=$WORK/Étienne/data
+export LOG_PATH=$WORK/Étienne/log
+export HF_DATASETS_OFFLINE=1
+export TRANSFORMERS_OFFLINE=1
+cd $WORK/Étienne/code
+
+# echo des commandes lancées
+set -x
+
+config=""
+
+if [ $(( $SLURM_ARRAY_TASK_ID % 3 )) -eq 0 ]; then
+ config="$config --gcn_aggregator=\"none\""
+elif [ $(( $SLURM_ARRAY_TASK_ID % 3 )) -eq 1 ]; then
+ config="$config --gcn_aggregator=\"mean\""
+else
+ config="$config --gcn_aggregator=\"chebyshev\""
+fi
+
+if [ $(( $SLURM_ARRAY_TASK_ID / 3 % 2 )) -eq 0 ]; then
+ config="$config --topological_weight=1"
+else
+ config="$config --topological_weight=0.2"
+fi
+
+if [ $(( $SLURM_ARRAY_TASK_ID / 6 % 2 )) -eq 0 ]; then
+ config="$config --filter_empty_neighborhood=True"
+else
+ config="$config --filter_empty_neighborhood=False"
+fi
+
+config="$config --seed=$(($SLURM_ARRAY_TASK_ID / 12))"
+
+python -m gbure.train gbure/config/gcn_mtb.py $config --post_transformer_layer=\"layer_norm\"
diff --git a/scripts/nonparametric.sh b/scripts/nonparametric.sh
@@ -0,0 +1,42 @@
+#!/bin/bash
+#SBATCH -A lco@gpu
+#SBATCH -C v100-32g
+#SBATCH --job-name=gbure_nonparametric
+#SBATCH --ntasks=1
+#SBATCH --gres=gpu:1
+#SBATCH --cpus-per-task=20
+#SBATCH --distribution=block:block
+#SBATCH --hint=nomultithread
+#SBATCH --time=20:00:00
+#SBATCH --output=/gpfswork/rech/lco/url46ht/Étienne/runs/%x_%j.stdout
+#SBATCH --error=/gpfswork/rech/lco/url46ht/Étienne/runs/%x_%j.stderr
+#SBATCH --array=0-2
+
+# %x = nom du job
+# %j = id du job
+
+module purge
+source ~/.bashrc
+conda activate Étienne
+export DATA_PATH=$WORK/Étienne/data
+export LOG_PATH=$WORK/Étienne/log
+export HF_DATASETS_OFFLINE=1
+export TRANSFORMERS_OFFLINE=1
+cd $WORK/Étienne/code
+
+# echo des commandes lancées
+set -x
+
+config=""
+
+if [ $(( $SLURM_ARRAY_TASK_ID % 3 )) -eq 0 ]; then
+ config="$config --undefined_poison_whole_meta=True"
+elif [ $(( $SLURM_ARRAY_TASK_ID % 3 )) -eq 1 ]; then
+ config="$config --undefined_poison_whole_meta=False --neutral_topological_similarity=None"
+else
+ config="$config --undefined_poison_whole_meta=False --neutral_topological_similarity=(1536*2)**0.5"
+fi
+
+for topological_weight in 0.1 0.15 0.2 0.22 0.25 0.3 0.5 1; do
+ python -m gbure.train gbure/config/nonparametric.py $config --topological_weight=$topological_weight
+done