commit bafbf5ff3a8f8f39573719da56083cacac5345f9
parent a2553cee69c02fca47496c1cb5c23dc99b5ac5c8
Author: Étienne Simon <esimon@esimon.eu>
Date: Thu, 21 Nov 2019 12:38:01 +0100
Fix metrics
Diffstat:
11 files changed, 409 insertions(+), 151 deletions(-)
diff --git a/README b/README
@@ -1,23 +1,35 @@
Reproduction of Matching the Blanks: Distributional Similarity for Relation Learning by Livio Baldini Soares, Nicholas FitzGerald, Jeffrey Ling, and Tom Kwiatkowski.
-This repository currently contains the supervised model for the Semeval and KBP37 datasets, the unsupervised MTB model and the FewRel dataset will be added latter on.
-In order to reproduce the results, you first need to manually download and extract the datasets (e.g. in /tmp), then do the following:
+This repository currently contains the supervised "entity markers-entity start" model for the SemEval and KBP37 datasets; the unsupervised MTB model and the FewRel dataset will be added later on.
+In order to reproduce the results, you first need to manually download and extract the datasets (e.g. in /tmp), then execute the following:
$ export DATA_PATH="/tmp" # the directory containing the extracted datasets
$ export LOG_PATH="/tmp" # the directory where the models will be saved
$ git clone https://gitlab.lip6.fr/esimon/few-shot-relation-extraction
$ cd few-shot-relation-extraction
-$ python -m fsre.data.prepare_semeval bert-large-cased
+$ python -m fsre.data.prepare_semeval
$ python -m fsre.train fsre/config/soares_supervised_semeval.py
-$ python -m fsre.data.prepare_kbp37 bert-large-cased
+$ python -m fsre.data.prepare_kbp37
$ python -m fsre.train fsre/config/soares_supervised_kbp37.py
-I must be missing something since I'm still far from the results reported in the paper:
+I must be missing something, since the results are still a bit off from those reported in the paper.
+Here are the scores reached by this repository (official SemEval macro F1 "taking directionality into account"):
On Semeval:
- valid macro F1: paper 82.1 vs us 81.5 (accuracy 84.7)
- test macro F1: paper 89.2 vs us 81.8 (accuracy 84.8)
+ paper valid: 82.1
+ BERT cased valid: 88.55 (std: 0.70)
+ BERT uncased valid: 87.49 (std: 0.87)
+ paper test: 89.2
+ BERT cased test: 88.24 (std: 0.37)
+ BERT uncased test: 88.02 (std: 0.65)
On KBP37:
- valid macro F1: paper 70.0 vs us 65.7 (accuracy 65.5)
- test macro F1: paper 68.3 vs us 63.2 (accuracy 64.2)
+ paper valid: 70.0
+ BERT cased valid: 66.92 (std: 0.61)
+ BERT uncased valid: 66.48 (std: 0.46)
+ paper test: 68.3
+ BERT cased test: 67.18 (std: 0.51)
+ BERT uncased test: 66.21 (std: 0.62)
+
+The mean and std are computed over 5 runs; all reported results are for "large" BERT models.
+Detailed results can be found at http://www-ia.lip6.fr/~esimon/results.xhtml (temporary link).
diff --git a/fsre/__init__.py b/fsre/__init__.py
@@ -1,3 +1,4 @@
import fsre.data
import fsre.utils
import fsre.metrics
+import fsre.train
diff --git a/fsre/config/soares_supervised_kbp37.py b/fsre/config/soares_supervised_kbp37.py
@@ -11,7 +11,7 @@ learning_rate = 3e-5
true_batch_size = 64
# Guessed
-validation_metric = "f1"
+validation_metric = "half_directed_macro_f1"
early_stopping_patience = 2
# Implementation details
diff --git a/fsre/config/soares_supervised_semeval.py b/fsre/config/soares_supervised_semeval.py
@@ -11,7 +11,7 @@ learning_rate = 3e-5
true_batch_size = 64
# Guessed
-validation_metric = "f1"
+validation_metric = "half_directed_macro_f1"
early_stopping_patience = 2
# Implementation details
diff --git a/fsre/data/prepare_fewrel.py b/fsre/data/prepare_fewrel.py
@@ -1,61 +0,0 @@
-import argparse
-import numpy
-import transformers
-import tqdm
-
-from fsre.utils import DATA_PATH
-from fsre.data.relation_dictionary import RelationDictionary
-
-
-def load_fewrel_dataset(path, tokenizer, relation_dictionary):
- pass
-
-
-def prepare_fewrel(args):
- rng = numpy.random.RandomState(args.seed)
- fewrel_path = DATA_PATH / "FewRel"
- output_path = fewrel_path / args.tokenizer
-
- if not output_path.is_dir():
- output_path.mkdir()
-
- relation_dictionary = RelationDictionary()
-
- tokenizer = transformers.BertTokenizer.from_pretrained(args.tokenizer)
- tokenizer.add_special_tokens({"additional_special_tokens": ["<e1>", "</e1>", "<e2>", "</e2>"]})
- tokenizer_path = output_path / "tokenizer"
- if not tokenizer_path.is_dir():
- tokenizer_path.mkdir()
- tokenizer.save_pretrained(tokenizer_path)
-
- train = load_fewrel_dataset(
- fewrel_path / "train.json",
- tokenizer,
- relation_dictionary)
- rng.shuffle(train)
-
- valid = load_fewrel_dataset(
- fewrel_path / "val.json",
- tokenizer,
- relation_dictionary)
-
- numpy.save(output_path / "train.npy", numpy.array(train))
- numpy.save(output_path / "valid.npy", numpy.array(valid))
-
- relation_dictionary.save(output_path / "relations")
-
-
-if __name__ == "__main__":
- parser = argparse.ArgumentParser(
- description="Prepare the FewRel dataset.")
- parser.add_argument("tokenizer",
- type=str,
- nargs='?',
- default="bert-large-uncased",
- help="Name of the transformers tokenizer")
- parser.add_argument("-s", "--seed",
- type=int,
- default=0,
- help="Seed of the RNG for shuffling the dataset")
-
- prepare_fewrel(parser.parse_args())
diff --git a/fsre/data/prepare_kbp37.py b/fsre/data/prepare_kbp37.py
@@ -10,6 +10,7 @@ from fsre.data.prepare_semeval import load_semeval_dataset
TRAIN_SIZE = 15917
VALID_SIZE = 1724
TEST_SIZE = 3405
+UNKNOWN_RELATION = "no_relation"
def prepare_kbp37(args):
@@ -20,7 +21,7 @@ def prepare_kbp37(args):
if not output_path.is_dir():
output_path.mkdir()
- relation_dictionary = RelationDictionary()
+ relation_dictionary = RelationDictionary(unknown=UNKNOWN_RELATION)
tokenizer = transformers.BertTokenizer.from_pretrained(args.tokenizer)
tokenizer.add_special_tokens({"additional_special_tokens": ["<e1>", "</e1>", "<e2>", "</e2>"]})
diff --git a/fsre/data/prepare_semeval.py b/fsre/data/prepare_semeval.py
@@ -8,6 +8,7 @@ from fsre.data.relation_dictionary import RelationDictionary
TRAIN_SIZE = 8000
TEST_SIZE = 2717
+UNKNOWN_RELATION = "Other"
def load_semeval_dataset(path, tokenizer, relation_dictionary, size):
@@ -24,12 +25,21 @@ def load_semeval_dataset(path, tokenizer, relation_dictionary, size):
id, raw_text = idtext_line.rstrip().split('\t')
id = int(id)
+
raw_text = raw_text[1:-1] # remove quotes around text
text = tokenizer.encode(raw_text, add_special_tokens=True)
e1_pos = text.index(be1_id)
e2_pos = text.index(be2_id)
+ if len(text) > tokenizer.max_len:
+ text = text[:tokenizer.max_len]
+ e1_pos = min(tokenizer.max_len-1, e1_pos)
+ e2_pos = min(tokenizer.max_len-1, e2_pos)
text = numpy.array(text, dtype=numpy.int32)
- relation = relation_dictionary.encode(relation_line.rstrip())
+
+ relation_line = relation_line.rstrip()
+ dir_start = relation_line.find('(')
+ relation_base = relation_line[:dir_start] if dir_start >= 0 else relation_line
+ relation = relation_dictionary.encode(relation_line, relation_base)
dataset.append([id, text, e1_pos, e2_pos, relation])
infile.readline() # Ignore Comment line
@@ -45,7 +55,7 @@ def prepare_semeval(args):
if not output_path.is_dir():
output_path.mkdir()
- relation_dictionary = RelationDictionary()
+ relation_dictionary = RelationDictionary(unknown=UNKNOWN_RELATION)
tokenizer = transformers.BertTokenizer.from_pretrained(args.tokenizer)
tokenizer.add_special_tokens({"additional_special_tokens": ["<e1>", "</e1>", "<e2>", "</e2>"]})
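As a side note, the base extraction added above can be exercised in isolation; a minimal sketch (the helper name is hypothetical, the labels follow the standard SemEval format):

    # Split a SemEval label into its directed and base forms,
    # mirroring the parsing added to load_semeval_dataset above.
    def split_relation(relation_line):
        dir_start = relation_line.find('(')
        base = relation_line[:dir_start] if dir_start >= 0 else relation_line
        return relation_line, base

    assert split_relation("Entity-Destination(e1,e2)") == ("Entity-Destination(e1,e2)", "Entity-Destination")
    assert split_relation("Other") == ("Other", "Other")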
diff --git a/fsre/data/relation_dictionary.py b/fsre/data/relation_dictionary.py
@@ -1,31 +1,94 @@
+import pickle
+
+
class RelationDictionary:
- """ A very simple dictionary to be used for relations. """
- def __init__(self, path=None):
+ """
+ A dictionary to be used for relations.
+
+ The tokens held by this class are divided into two kinds:
+ - *relation* such as "Entity-Destination(e1,e2)"
+ - *base* such as "Entity-Destination"
+ """
+
+ def __init__(self, *, unknown=None, path=None):
self.encoder = {}
self.decoder = []
+
+ self.base_encoder = {}
+ self.base_decoder = []
+ self.id_to_bid = []
+
+ self.unknown = unknown
+ if unknown is not None:
+ self.encoder[unknown] = 0
+ self.decoder.append(unknown)
+ self.base_encoder[unknown] = 0
+ self.base_decoder.append(unknown)
+ self.id_to_bid.append(0)
+
if path is not None:
self.load(path)
def __len__(self):
+ """ Number of relations in the dictionary. """
return len(self.decoder)
- def encode(self, token):
- id = self.encoder.get(token)
+ def base_size(self):
+ """ Number of bases in the dictionary. """
+ return len(self.base_decoder)
+
+ def encode(self, relation, base=None):
+ """
+ Returns the id corresponding to a relation string.
+
+ Args:
+ relation: the string of the relation (e.g. "Entity-Destination(e1,e2)")
+ base: the string of the base relation (e.g. "Entity-Destination")
+ """
+
+ if relation is None:
+ return None
+
+ if base is None:
+ return self.encoder[relation]
+
+ id = self.encoder.get(relation)
if id is not None:
return id
+
+ bid = self.base_encoder.get(base)
+ if bid is None:
+ bid = len(self.base_decoder)
+ self.base_encoder[base] = bid
+ self.base_decoder.append(base)
+
id = len(self.decoder)
- self.encoder[token] = id
- self.decoder.append(token)
+ self.encoder[relation] = id
+ self.decoder.append(relation)
+ self.id_to_bid.append(bid)
return id
def decode(self, id):
+ """ Returns the string corresponding to a relation id. """
return self.decoder[id]
+ def base_id(self, id):
+ """ Returns the base id corresponding to a relation id. """
+ return self.id_to_bid[id]
+
def save(self, path):
- with open(path, 'w') as file:
- file.writelines(map(lambda x: f"{x}\n", self.decoder))
+ with open(path, "wb") as file:
+ pickle.dump({
+ "unknown": self.unknown,
+ "decoder": self.decoder,
+ "encoder": self.encoder,
+ "base_encoder": self.base_encoder,
+ "base_decoder": self.base_decoder,
+ "id_to_bid": self.id_to_bid,
+ }, file)
def load(self, path):
- with open(path, 'r') as file:
- self.decoder = list(map(str.rstrip, file.readlines()))
- self.encoder = dict(zip(self.decoder, range(len(self.decoder))))
+ with open(path, "rb") as file:
+ data = pickle.load(file)
+ for key, value in data.items():
+ setattr(self, key, value)
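To make the new two-level interface concrete, a small usage sketch (the relation strings are illustrative, following the SemEval format documented in the class):

    from fsre.data.relation_dictionary import RelationDictionary

    dictionary = RelationDictionary(unknown="Other")
    forward = dictionary.encode("Entity-Destination(e1,e2)", "Entity-Destination")
    backward = dictionary.encode("Entity-Destination(e2,e1)", "Entity-Destination")

    assert dictionary.decode(forward) == "Entity-Destination(e1,e2)"
    # Both directions share one base id, distinct from the unknown base (id 0).
    assert dictionary.base_id(forward) == dictionary.base_id(backward) == 1
    assert dictionary.encode("Other") == 0
    assert len(dictionary) == 3 and dictionary.base_size() == 2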
diff --git a/fsre/eval.py b/fsre/eval.py
@@ -0,0 +1,30 @@
+import torch
+
+import fsre
+
+
+class Evaluator(fsre.train.Trainer):
+ """
+ Evaluate a model.
+ """
+
+ def __init__(self, config, state_dicts):
+ self.eval_config = config
+ super().__init__(state_dicts["config"], None, state_dicts)
+
+ def run(self):
+ self.epoch = self.state_dicts["epoch"]
+ self.info()
+ self.initialize_rng()
+ self.prepare_dataset()
+ self.build_model()
+ self.count_parameters()
+ self.evaluate("test")
+
+
+if __name__ == "__main__":
+ fsre.utils.fix_transformers_logging_handler()
+ config = fsre.utils.parse_args()
+
+ state_dicts = torch.load(config.load)
+ Evaluator(config, state_dicts).run()
diff --git a/fsre/metrics.py b/fsre/metrics.py
@@ -1,3 +1,4 @@
+import math
import numpy
import torch
@@ -6,28 +7,49 @@ class Metrics:
"""
Class for computing metrics.
- Five metrics are computed:
+ Twenty metrics are computed:
- Accuracy
- - Macro F1
- - Macro Precision
- - Macro Recall
- - Negative Log Likelihood
+ - Negative Log Likelihood (nll)
+ - {directed, undirected, half_directed} {micro, macro} {f1, precision, recall}
+ Note that the accuracy is the true accuracy: it takes directionality into account and scores the unknown relation like any other relation.
+ The last 18 metrics follow the SemEval scorer:
+ - The unknown ("Other") relation is only scored indirectly
+ - Directed is equivalent to the metrics "USING DIRECTIONALITY"
+ - Undirected is equivalent to the metrics "IGNORING DIRECTIONALITY"
+ - Half-directed is equivalent to the metrics "TAKING DIRECTIONALITY INTO ACCOUNT -- OFFICIAL"
+ Note that the directed and half_directed micro metrics are equivalent.
"""
- def __init__(self, nclass):
+ def __init__(self, relation_dictionary):
"""
Initialize all metrics.
Args:
- nclass: number of relations
+ relation_dictionary: see class RelationDictionary
"""
- self.n = nclass
+ self.relation_dictionary = relation_dictionary
+ self.n = len(relation_dictionary)
+ self.m = relation_dictionary.base_size()
+ self.build_mask()
+ self.build_base_transition()
self.crossentropy = torch.nn.CrossEntropyLoss(reduction="sum")
self.size = 0
+ self.correct = 0
self.ce_sum = 0
- self.confusion = numpy.zeros((nclass, nclass), numpy.int64)
+ self.confusion = numpy.zeros((self.n, self.n), numpy.int64)
+
+ def build_mask(self):
+ self.mask = numpy.ones(self.n)
+ if self.relation_dictionary.unknown is not None:
+ assert(self.relation_dictionary.decode(0) == self.relation_dictionary.unknown)
+ self.mask[0] = 0
+
+ def build_base_transition(self):
+ self.base_transition = numpy.zeros((self.n, self.m))
+ for id, bid in enumerate(self.relation_dictionary.id_to_bid):
+ self.base_transition[id, bid] = 1
def update(self, predictions, target):
"""
@@ -39,46 +61,190 @@ class Metrics:
"""
self.size += predictions.shape[0]
- self.ce_sum += self.crossentropy(predictions, target).cpu().item()
+ self.ce_sum += self.crossentropy(predictions, target).item()
prediction = predictions.argmax(1)
- for p, t in zip(prediction.cpu(), target.cpu()):
- self.confusion[p.item(), t.item()] += 1
+ for p, t in zip(prediction.tolist(), target.tolist()):
+ self.confusion[p, t] += 1
+ self.correct += (p == t)
+
+ @property
+ def summary(self):
+ return {"accuracy": f"{self.accuracy*100:.2f}",
+ "nll": f"{self.nll:.2f}"}
+
+ @property
+ def all(self):
+ keys = ["accuracy", "nll"] + [
+ f"{direction}_{level}_{metric}"
+ for direction in ["directed", "undirected", "half_directed"]
+ for level in ["macro", "micro"]
+ for metric in ["f1", "precision", "recall"]]
+ return {key: getattr(self, key) for key in keys}
+
+ @property
+ def base_mask(self):
+ return self.mask.dot(self.base_transition).clip(0, 1)
+
+ @property
+ def base_confusion(self):
+ return self.base_transition.T.dot(self.confusion).dot(self.base_transition)
@property
def accuracy(self):
- return self.confusion.diagonal().sum() / (self.confusion.sum() + 1e-12)
+ return math.nan if self.size == 0 else self.correct / self.size
@property
- def class_precision(self):
- return self.confusion.diagonal() / (self.confusion.sum(1) + 1e-12)
+ def nll(self):
+ return math.nan if self.size == 0 else self.ce_sum / self.size
+
+ ##########################
+ # Directed macro metrics #
+ ##########################
@property
- def class_recall(self):
- return self.confusion.diagonal() / (self.confusion.sum(0) + 1e-12)
+ def directed_class_precision(self):
+ norm = self.confusion.sum(1)
+ norm[norm == 0] = 1
+ return self.confusion.diagonal() / norm
@property
- def class_f1(self):
- return 2 * self.class_precision * self.class_recall / (self.class_precision + self.class_recall + 1e-12)
+ def directed_class_recall(self):
+ norm = self.confusion.sum(0)
+ norm[norm == 0] = 1
+ return self.confusion.diagonal() / norm
@property
- def precision(self):
- return self.class_precision.mean()
+ def directed_class_f1(self):
+ norm = self.directed_class_precision + self.directed_class_recall
+ norm[norm == 0] = 1
+ return 2 * self.directed_class_precision * self.directed_class_recall / norm
@property
- def recall(self):
- return self.class_recall.mean()
+ def directed_macro_precision(self):
+ return numpy.sum(self.directed_class_precision * self.mask) / self.mask.sum()
@property
- def f1(self):
- return self.class_f1.mean()
+ def directed_macro_recall(self):
+ return numpy.sum(self.directed_class_recall * self.mask) / self.mask.sum()
@property
- def nll(self):
- return self.ce_sum / self.size
+ def directed_macro_f1(self):
+ return numpy.sum(self.directed_class_f1 * self.mask) / self.mask.sum()
+
+ ############################
+ # Undirected macro metrics #
+ ############################
@property
- def summary(self):
- return {"accuracy": f"{self.accuracy*100:.2f}",
- "f1": f"{self.f1*100:.2f}",
- "nll": f"{self.nll:.2f}"}
+ def undirected_class_precision(self):
+ norm = self.base_confusion.sum(1)
+ norm[norm == 0] = 1
+ return self.base_confusion.diagonal() / norm
+
+ @property
+ def undirected_class_recall(self):
+ norm = self.base_confusion.sum(0)
+ norm[norm == 0] = 1
+ return self.base_confusion.diagonal() / norm
+
+ @property
+ def undirected_class_f1(self):
+ norm = self.undirected_class_precision + self.undirected_class_recall
+ norm[norm == 0] = 1
+ return 2 * self.undirected_class_precision * self.undirected_class_recall / norm
+
+ @property
+ def undirected_macro_precision(self):
+ return numpy.sum(self.undirected_class_precision * self.base_mask) / self.base_mask.sum()
+
+ @property
+ def undirected_macro_recall(self):
+ return numpy.sum(self.undirected_class_recall * self.base_mask) / self.base_mask.sum()
+
+ @property
+ def undirected_macro_f1(self):
+ return numpy.sum(self.undirected_class_f1 * self.base_mask) / self.base_mask.sum()
+
+ ###############################
+ # Half-directed macro metrics #
+ ###############################
+
+ @property
+ def half_directed_class_precision(self):
+ norm = self.base_confusion.sum(1)
+ norm[norm == 0] = 1
+ return self.confusion.diagonal().dot(self.base_transition) / norm
+
+ @property
+ def half_directed_class_recall(self):
+ norm = self.base_confusion.sum(0)
+ norm[norm == 0] = 1
+ return self.confusion.diagonal().dot(self.base_transition) / norm
+
+ @property
+ def half_directed_class_f1(self):
+ norm = self.half_directed_class_precision + self.half_directed_class_recall
+ norm[norm == 0] = 1
+ return 2 * self.half_directed_class_precision * self.half_directed_class_recall / norm
+
+ @property
+ def half_directed_macro_precision(self):
+ return numpy.sum(self.half_directed_class_precision * self.base_mask) / self.base_mask.sum()
+
+ @property
+ def half_directed_macro_recall(self):
+ return numpy.sum(self.half_directed_class_recall * self.base_mask) / self.base_mask.sum()
+
+ @property
+ def half_directed_macro_f1(self):
+ return numpy.sum(self.half_directed_class_f1 * self.base_mask) / self.base_mask.sum()
+
+ #################
+ # Micro metrics #
+ #################
+
+ @property
+ def directed_micro_precision(self):
+ norm = numpy.sum(self.confusion.sum(1) * self.mask)
+ return 0 if norm == 0 else numpy.sum(self.confusion.diagonal() * self.mask) / norm
+
+ @property
+ def directed_micro_recall(self):
+ norm = numpy.sum(self.confusion.sum(0) * self.mask)
+ return 0 if norm == 0 else numpy.sum(self.confusion.diagonal() * self.mask) / norm
+
+ @property
+ def directed_micro_f1(self):
+ norm = self.directed_micro_precision + self.directed_micro_recall
+ return 0 if norm == 0 else 2 * (self.directed_micro_precision * self.directed_micro_recall) / norm
+
+ @property
+ def half_directed_micro_precision(self):
+ norm = numpy.sum(self.confusion.sum(1) * self.mask)
+ return 0 if norm == 0 else numpy.sum(self.confusion.diagonal() * self.mask) / norm
+
+ @property
+ def half_directed_micro_recall(self):
+ norm = numpy.sum(self.confusion.sum(0) * self.mask)
+ return 0 if norm == 0 else numpy.sum(self.confusion.diagonal() * self.mask) / norm
+
+ @property
+ def half_directed_micro_f1(self):
+ norm = self.half_directed_micro_precision + self.half_directed_micro_recall
+ return 0 if norm == 0 else 2 * (self.half_directed_micro_precision * self.half_directed_micro_recall) / norm
+
+ @property
+ def undirected_micro_precision(self):
+ norm = numpy.sum(self.base_confusion.sum(1) * self.base_mask)
+ return 0 if norm == 0 else numpy.sum(self.base_confusion.diagonal() * self.base_mask) / norm
+
+ @property
+ def undirected_micro_recall(self):
+ norm = numpy.sum(self.base_confusion.sum(0) * self.base_mask)
+ return 0 if norm == 0 else numpy.sum(self.base_confusion.diagonal() * self.base_mask) / norm
+
+ @property
+ def undirected_micro_f1(self):
+ norm = self.undirected_micro_precision + self.undirected_micro_recall
+ return 0 if norm == 0 else 2 * (self.undirected_micro_precision * self.undirected_micro_recall) / norm
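To illustrate what the new directionality-aware metrics measure, a worked sketch using the same recipe as build_mask, build_base_transition and base_confusion above (the confusion matrix is made up for the example):

    import numpy

    # Three relation ids: 0 = unknown ("Other"), 1 and 2 = the two directions
    # of a single base relation.
    id_to_bid = [0, 1, 1]
    n, m = 3, 2
    base_transition = numpy.zeros((n, m))
    for id, bid in enumerate(id_to_bid):
        base_transition[id, bid] = 1

    mask = numpy.ones(n)
    mask[0] = 0  # the unknown relation is only scored indirectly
    base_mask = mask.dot(base_transition).clip(0, 1)  # -> [0., 1.]

    # confusion[p, t]; one sample predicted the wrong direction (p=1, t=2).
    confusion = numpy.array([[1, 0, 0],
                             [0, 2, 1],
                             [0, 0, 1]])
    base_confusion = base_transition.T.dot(confusion).dot(base_transition)

    # Undirected scoring credits the direction error as correct (recall 1.0)...
    undirected = base_confusion.diagonal() / base_confusion.sum(0)
    # ...while half-directed keeps the directed diagonal but pools per base
    # (recall 0.75), matching the official SemEval scoring.
    half_directed = confusion.diagonal().dot(base_transition) / base_confusion.sum(0)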
diff --git a/fsre/train.py b/fsre/train.py
@@ -1,6 +1,8 @@
import sys
import os
import math
+import time
+import contextlib
import multiprocessing
import signal
import logging
@@ -19,15 +21,20 @@ class Trainer:
Train a model.
Config:
+ Model: the model class to use for training
batch_size: the number of samples in the batch of data loaded
- true_batch_size: the actual number of sample in a batch, the number of
- sample seen before a backward (must be a multiple of batch_size)
+ bert_model: the transformer model to use
dataset_name: name of the dataset to load
deterministic: run in deterministic mode
+ early_stopping_patience: how many epochs to train after the best validation score has been reached
learning_rate: learning rate
max_epoch: maximum number of epochs
+ no_initial_validation: do not run evaluation on the valid dataset before the first epoch
seed: the seed for the random number generator
sort_per_shuffle_bucket: the number of sort buckets in a shuffle bucket
+ test_output: path to a file where the test predictions will be written
+ true_batch_size: the actual number of samples in a batch, the number of samples seen before a backward pass (must be a multiple of batch_size)
+ validation_metric: metric used for early stopping
"""
def __init__(self, config, logdir, state_dicts=None):
@@ -83,7 +90,11 @@ class Trainer:
logger.info(f"GPU{i} {gp.name} {gp.total_memory} {gp.major}.{gp.minor}")
def info(self):
- print(f"logdir is \033[1m\033[33m{self.logdir}\033[0m")
+ if self.logdir is None:
+ print(f"logdir is \033[1m\033[31mnot set\033[0m, log messages will be discarded")
+ else:
+ print(f"logdir is \033[1m\033[33m{self.logdir}\033[0m")
+
self.environment_check()
self.detect_gpus()
print("")
@@ -103,7 +114,7 @@ class Trainer:
suffix = ""
if self.state_dicts:
suffix = time.strftime("%FT%H:%M:%S")
- fsre.utils.save_patch(self.logdir / "patch{suffix}")
+ fsre.utils.save_patch(self.logdir / f"patch{suffix}")
def initialize_rng(self):
if self.state_dicts:
@@ -120,7 +131,7 @@ class Trainer:
def prepare_dataset(self):
data_dir = fsre.utils.DATA_PATH / self.config.dataset_name / self.config.bert_model
- self.relation_dictionary = fsre.data.RelationDictionary(data_dir / "relations")
+ self.relation_dictionary = fsre.data.RelationDictionary(path=data_dir / "relations")
self.tokenizer = transformers.BertTokenizer.from_pretrained(data_dir / "tokenizer")
self.dataset = {}
@@ -178,6 +189,22 @@ class Trainer:
signal.signal(signal.SIGINT, handler)
+ def eval_context(self, dataset):
+ self.test_output_file = None
+ if dataset == "test" and self.config.get("test_output"):
+ self.test_output_file = open(self.config.test_output, 'w')
+ return self.test_output_file
+ return contextlib.nullcontext()
+
+ def eval_handle_predictions(self, ids, predictions):
+ if self.test_output_file is None:
+ return
+ ids = ids.tolist()
+ predictions = predictions.argmax(1).tolist()
+ for id, prediction in zip(ids, predictions):
+ prediction = self.relation_dictionary.decode(prediction)
+ self.test_output_file.write(f"{id}\t{prediction}\n")
+
def evaluate(self, dataset):
loop = tqdm.tqdm(
iterable=self.iterator[dataset](),
@@ -188,36 +215,45 @@ class Trainer:
leave=False)
self.model.eval()
- with torch.no_grad():
- scorer = fsre.metrics.Metrics(len(self.relation_dictionary))
+ with torch.no_grad(), self.eval_context(dataset):
+ has_target = False
+ scorer = fsre.metrics.Metrics(self.relation_dictionary)
correct_prediction = 0
for batch in loop:
batch = {key: value.to(self.device) for key, value in batch.items()}
# Pop the target to ensure it's not used by the model
- target = batch.pop("relation")
+ if "relation" in batch:
+ target = batch.pop("relation")
+ has_target = True
predictions = self.model(batch)
- scorer.update(predictions, target)
- loop.set_postfix(**scorer.summary, refresh=False)
-
- print(f"Epoch {self.epoch} {dataset:5} accuracy: {scorer.accuracy*100:8.4f}% F1: {scorer.f1*100:8.4f}% (P: {scorer.precision*100:8.4f}% R: {scorer.recall*100:8.4f}%) NLL: {scorer.nll:8.4f}")
- logger.info(f"epoch {self.epoch} {dataset} accuracy {scorer.accuracy} F1 {scorer.f1} precision {scorer.precision} recall {scorer.recall} NLL {scorer.nll}")
- return getattr(scorer, self.config.validation_metric)
-
- def save(self, path, full):
- state_dicts = {"model": self.model.state_dict()}
- if full:
- state_dicts.update({
- "logdir": self.logdir,
- "optimizer": self.optimizer.state_dict(),
- "train_rng": self.dataset["train"].state_dict(),
- "torch_rng": torch.random.get_rng_state(),
- "epoch": self.epoch,
- "best_epoch": self.best_epoch,
- "best_eval": self.best_eval,
- })
- if torch.cuda.is_available():
- state_dicts["cuda_rng"] = torch.cuda.random.get_rng_state_all()
+ self.eval_handle_predictions(batch["id"], predictions)
+ if has_target:
+ scorer.update(predictions, target)
+ loop.set_postfix(**scorer.summary, refresh=False)
+
+ if has_target:
+ print(f"Epoch {self.epoch} {dataset:5} accuracy: {scorer.accuracy*100:8.4f}% Half-directed Macro F1: {scorer.half_directed_macro_f1*100:8.4f}% (P: {scorer.half_directed_macro_precision*100:8.4f}% R: {scorer.half_directed_macro_recall*100:8.4f}%) NLL: {scorer.nll:8.4f}")
+ for metric, value in scorer.all.items():
+ logger.info(f"epoch {self.epoch} {dataset} {metric} {value}")
+ return getattr(scorer, self.config.validation_metric)
+ return None
+
+ def save(self, path):
+ state_dicts = {
+ "logdir": self.logdir,
+ "config": self.config,
+ "model": self.model.state_dict(),
+ "optimizer": self.optimizer.state_dict(),
+ "train_rng": self.dataset["train"].state_dict(),
+ "torch_rng": torch.random.get_rng_state(),
+ "epoch": self.epoch,
+ "best_epoch": self.best_epoch,
+ "best_eval": self.best_eval,
+ }
+ if torch.cuda.is_available():
+ state_dicts["cuda_rng"] = torch.cuda.random.get_rng_state_all()
torch.save(state_dicts, path)
def train(self):
@@ -275,14 +311,14 @@ class Trainer:
print(f"Epoch {self.epoch} train mean loss: {total_loss / total_sample:8.4f}")
logger.info(f"epoch {self.epoch} train mean_loss {total_loss / total_sample}")
- self.save(self.logdir / "checkpoint.new", full=True)
+ self.save(self.logdir / "checkpoint.new")
os.rename(self.logdir / "checkpoint.new", self.logdir / "checkpoint")
candidate = self.evaluate("valid")
if candidate > self.best_eval:
self.best_epoch = self.epoch
self.best_eval = candidate
- self.save(best_path, full=False)
+ self.save(best_path)
logger.info(f"Model saved to {best_path}")
elif self.epoch - self.best_epoch > self.config.get("early_stopping_patience", self.config.max_epoch):
break
@@ -302,10 +338,10 @@ if __name__ == "__main__":
state_dicts = None
if config.get("load"):
- state_dicts = torch.load(config.load, strict=False)
-
- if config.get("load"):
+ state_dicts = torch.load(config.load)
logdir = state_dicts["logdir"]
+ if config.get("reuse_config"):
+ config = state_dicts["config"]
else:
logdir = fsre.utils.logdir_name("FSRE")
assert(not logdir.exists())
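For reference, a config sketch combining the keys documented in the Trainer docstring with the values visible in the SemEval config above; the remaining values are illustrative guesses, not the repository's defaults:

    # fsre/config/example.py (hypothetical)
    dataset_name = "SemEval"
    bert_model = "bert-large-cased"
    learning_rate = 3e-5
    true_batch_size = 64
    batch_size = 16  # true_batch_size must be a multiple of this
    max_epoch = 10
    seed = 0
    validation_metric = "half_directed_macro_f1"
    early_stopping_patience = 2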