commit bafbf5ff3a8f8f39573719da56083cacac5345f9
parent a2553cee69c02fca47496c1cb5c23dc99b5ac5c8
Author: Étienne Simon <esimon@esimon.eu>
Date:   Thu, 21 Nov 2019 12:38:01 +0100
Fix metrics
Diffstat:
11 files changed, 409 insertions(+), 151 deletions(-)
diff --git a/README b/README
@@ -1,23 +1,35 @@
 Reproduction of Matching the Blanks: Distributional Similarity for Relation Learning by Livio Baldini Soares, Nicholas FitzGerald, Jeffrey Ling, and Tom Kwiatkowski.
 
-This repository currently contains the supervised model for the Semeval and KBP37 datasets, the unsupervised MTB model and the FewRel dataset will be added latter on.
-In order to reproduce the results, you first need to manually download and extract the datasets (e.g. in /tmp), then do the following:
+This repository currently contains the supervised "entity markers-entity start" model for the SemEval and KBP37 datasets; the unsupervised MTB model and the FewRel dataset will be added later on.
+In order to reproduce the results, you first need to manually download and extract the datasets (e.g. in /tmp), then execute the following:
 
 $ export DATA_PATH="/tmp" # the directory containing the extracted datasets
 $ export LOG_PATH="/tmp" # the directory where the models will be saved
 $ git clone https://gitlab.lip6.fr/esimon/few-shot-relation-extraction
 $ cd few-shot-relation-extraction
-$ python -m fsre.data.prepare_semeval bert-large-cased
+$ python -m fsre.data.prepare_semeval
 $ python -m fsre.train fsre/config/soares_supervised_semeval.py
-$ python -m fsre.data.prepare_kbp37 bert-large-cased
+$ python -m fsre.data.prepare_kbp37
 $ python -m fsre.train fsre/config/soares_supervised_kbp37.py
 
-I must be missing something since I'm still far from the results reported in the paper:
+I must be missing something since my results still differ somewhat from those reported in the paper.
+Here are the scores reached by this repository (official SemEval macro F1, "taking directionality into account"):
 
 On Semeval:
-	valid macro F1: paper 82.1 vs us 81.5 (accuracy 84.7)
-	test  macro F1: paper 89.2 vs us 81.8 (accuracy 84.8)
+	paper valid: 82.1
+	BERT cased valid: 88.55 (std: 0.70)
+	BERT uncased valid: 87.49 (std: 0.87)
+	paper test: 89.2
+	BERT cased test: 88.24 (std: 0.37)
+	BERT uncased test: 88.02 (std: 0.65)
 
 On KBP37:
-	valid macro F1: paper 70.0 vs us 65.7 (accuracy 65.5)
-	test  macro F1: paper 68.3 vs us 63.2 (accuracy 64.2)
+	paper valid: 70.0
+	BERT cased valid: 66.92 (std: 0.61)
+	BERT uncased valid: 66.48 (std: 0.46)
+	paper test: 68.3
+	BERT cased test: 67.18 (std: 0.51)
+	BERT uncased test: 66.21 (std: 0.62)
+
+The mean and std are computed over 5 runs; all reported results are for "large" BERT models.
+Detailed results can be found at http://www-ia.lip6.fr/~esimon/results.xhtml (temporary link).
diff --git a/fsre/__init__.py b/fsre/__init__.py
@@ -1,3 +1,4 @@
 import fsre.data
 import fsre.utils
 import fsre.metrics
+import fsre.train
diff --git a/fsre/config/soares_supervised_kbp37.py b/fsre/config/soares_supervised_kbp37.py
@@ -11,7 +11,7 @@ learning_rate = 3e-5
 true_batch_size = 64
 
 # Guessed
-validation_metric = "f1"
+validation_metric = "half_directed_macro_f1"
 early_stopping_patience = 2
 
 # Implementation details
diff --git a/fsre/config/soares_supervised_semeval.py b/fsre/config/soares_supervised_semeval.py
@@ -11,7 +11,7 @@ learning_rate = 3e-5
 true_batch_size = 64
 
 # Guessed
-validation_metric = "f1"
+validation_metric = "half_directed_macro_f1"
 early_stopping_patience = 2
 
 # Implementation details
diff --git a/fsre/data/prepare_fewrel.py b/fsre/data/prepare_fewrel.py
@@ -1,61 +0,0 @@
-import argparse
-import numpy
-import transformers
-import tqdm
-
-from fsre.utils import DATA_PATH
-from fsre.data.relation_dictionary import RelationDictionary
-
-
-def load_fewrel_dataset(path, tokenizer, relation_dictionary):
-    pass
-
-
-def prepare_fewrel(args):
-    rng = numpy.random.RandomState(args.seed)
-    fewrel_path = DATA_PATH / "FewRel"
-    output_path = fewrel_path / args.tokenizer
-
-    if not output_path.is_dir():
-        output_path.mkdir()
-
-    relation_dictionary = RelationDictionary()
-
-    tokenizer = transformers.BertTokenizer.from_pretrained(args.tokenizer)
-    tokenizer.add_special_tokens({"additional_special_tokens": ["<e1>", "</e1>", "<e2>", "</e2>"]})
-    tokenizer_path = output_path / "tokenizer"
-    if not tokenizer_path.is_dir():
-        tokenizer_path.mkdir()
-    tokenizer.save_pretrained(tokenizer_path)
-
-    train = load_fewrel_dataset(
-            fewrel_path / "train.json",
-            tokenizer,
-            relation_dictionary)
-    rng.shuffle(train)
-
-    valid = load_fewrel_dataset(
-            fewrel_path / "val.json",
-            tokenizer,
-            relation_dictionary)
-
-    numpy.save(output_path / "train.npy", numpy.array(train))
-    numpy.save(output_path / "valid.npy", numpy.array(valid))
-
-    relation_dictionary.save(output_path / "relations")
-
-
-if __name__ == "__main__":
-    parser = argparse.ArgumentParser(
-            description="Prepare the FewRel dataset.")
-    parser.add_argument("tokenizer",
-                        type=str,
-                        nargs='?',
-                        default="bert-large-uncased",
-                        help="Name of the transformers tokenizer")
-    parser.add_argument("-s", "--seed",
-                        type=int,
-                        default=0,
-                        help="Seed of the RNG for shuffling the dataset")
-
-    prepare_fewrel(parser.parse_args())
diff --git a/fsre/data/prepare_kbp37.py b/fsre/data/prepare_kbp37.py
@@ -10,6 +10,7 @@ from fsre.data.prepare_semeval import load_semeval_dataset
 TRAIN_SIZE = 15917
 VALID_SIZE = 1724
 TEST_SIZE = 3405
+UNKNOWN_RELATION = "no_relation"
 
 
 def prepare_kbp37(args):
@@ -20,7 +21,7 @@ def prepare_kbp37(args):
     if not output_path.is_dir():
         output_path.mkdir()
 
-    relation_dictionary = RelationDictionary()
+    relation_dictionary = RelationDictionary(unknown=UNKNOWN_RELATION)
 
     tokenizer = transformers.BertTokenizer.from_pretrained(args.tokenizer)
     tokenizer.add_special_tokens({"additional_special_tokens": ["<e1>", "</e1>", "<e2>", "</e2>"]})
diff --git a/fsre/data/prepare_semeval.py b/fsre/data/prepare_semeval.py
@@ -8,6 +8,7 @@ from fsre.data.relation_dictionary import RelationDictionary
 
 TRAIN_SIZE = 8000
 TEST_SIZE = 2717
+UNKNOWN_RELATION = "Other"
 
 
 def load_semeval_dataset(path, tokenizer, relation_dictionary, size):
@@ -24,12 +25,21 @@ def load_semeval_dataset(path, tokenizer, relation_dictionary, size):
 
             id, raw_text = idtext_line.rstrip().split('\t')
             id = int(id)
+
             raw_text = raw_text[1:-1]  # remove quotes around text
             text = tokenizer.encode(raw_text, add_special_tokens=True)
             e1_pos = text.index(be1_id)
             e2_pos = text.index(be2_id)
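+            # Truncate sequences longer than BERT's maximum input length and
+            # clamp the entity marker positions into the truncated range.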
+            if len(text) > tokenizer.max_len:
+                text = text[:tokenizer.max_len]
+                e1_pos = min(tokenizer.max_len-1, e1_pos)
+                e2_pos = min(tokenizer.max_len-1, e2_pos)
             text = numpy.array(text, dtype=numpy.int32)
-            relation = relation_dictionary.encode(relation_line.rstrip())
+
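+            # Strip the direction suffix: "Entity-Destination(e1,e2)" has the
+            # base relation "Entity-Destination"; "Other" has no direction.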
+            relation_line = relation_line.rstrip()
+            dir_start = relation_line.find('(')
+            relation_base = relation_line[:dir_start] if dir_start >= 0 else relation_line
+            relation = relation_dictionary.encode(relation_line, relation_base)
 
             dataset.append([id, text, e1_pos, e2_pos, relation])
             infile.readline()  # Ignore Comment line
@@ -45,7 +55,7 @@ def prepare_semeval(args):
     if not output_path.is_dir():
         output_path.mkdir()
 
-    relation_dictionary = RelationDictionary()
+    relation_dictionary = RelationDictionary(unknown=UNKNOWN_RELATION)
 
     tokenizer = transformers.BertTokenizer.from_pretrained(args.tokenizer)
     tokenizer.add_special_tokens({"additional_special_tokens": ["<e1>", "</e1>", "<e2>", "</e2>"]})
diff --git a/fsre/data/relation_dictionary.py b/fsre/data/relation_dictionary.py
@@ -1,31 +1,94 @@
+import pickle
+
+
 class RelationDictionary:
-    """ A very simple dictionary to be used for relations. """
-    def __init__(self, path=None):
+    """
+    A dictionary to be used for relations.
+
+    The tokens held by this class are divided between:
+        - *relation* such as "Entity-Destination(e1,e2)"
+        - *base* such as "Entity-Destination"
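+
+    For SemEval, "Entity-Destination(e1,e2)" and "Entity-Destination(e2,e1)"
+    are two relations sharing the base "Entity-Destination"; the unknown
+    relation ("Other") is its own base.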
+    """
+
+    def __init__(self, *, unknown=None, path=None):
         self.encoder = {}
         self.decoder = []
+
+        self.base_encoder = {}
+        self.base_decoder = []
+        self.id_to_bid = []
+
+        self.unknown = unknown
+        if unknown is not None:
+            self.encoder[unknown] = 0
+            self.decoder.append(unknown)
+            self.base_encoder[unknown] = 0
+            self.base_decoder.append(unknown)
+            self.id_to_bid.append(0)
+
         if path is not None:
             self.load(path)
 
     def __len__(self):
+        """ Number of relations in the dictionary. """
         return len(self.decoder)
 
-    def encode(self, token):
-        id = self.encoder.get(token)
+    def base_size(self):
+        """ Number of bases in the dictionary. """
+        return len(self.base_decoder)
+
+    def encode(self, relation, base=None):
+        """
+        Returns the id corresponding to a relation string.
+
+        Args:
+            relation: the string of the relation (e.g. "Entity-Destination(e1,e2)")
+            base: the string of the base relation (e.g. "Entity-Destination")
+        """
+
+        if relation is None:
+            return None
+
+        if base is None:
+            return self.encoder[relation]
+
+        id = self.encoder.get(relation)
         if id is not None:
             return id
+
+        bid = self.base_encoder.get(base)
+        if bid is None:
+            bid = len(self.base_decoder)
+            self.base_encoder[base] = bid
+            self.base_decoder.append(base)
+
         id = len(self.decoder)
-        self.encoder[token] = id
-        self.decoder.append(token)
+        self.encoder[relation] = id
+        self.decoder.append(relation)
+        self.id_to_bid.append(bid)
         return id
 
     def decode(self, id):
+        """ Returns the string corresponding to a relation id.  """
         return self.decoder[id]
 
+    def base_id(self, id):
+        """ Returns the base id corresponding to a relation id.  """
+        return self.id_to_bid[id]
+
     def save(self, path):
-        with open(path, 'w') as file:
-            file.writelines(map(lambda x: f"{x}\n", self.decoder))
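+        # Pickle every mapping so the base tables survive a save/load round-trip.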
+        with open(path, "wb") as file:
+            pickle.dump({
+                    "unknown": self.unknown,
+                    "decoder": self.decoder,
+                    "encoder": self.encoder,
+                    "base_encoder": self.base_encoder,
+                    "base_decoder": self.base_decoder,
+                    "id_to_bid": self.id_to_bid,
+                }, file)
 
     def load(self, path):
-        with open(path, 'r') as file:
-            self.decoder = list(map(str.rstrip, file.readlines()))
-        self.encoder = dict(zip(self.decoder, range(len(self.decoder))))
+        with open(path, "rb") as file:
+            data = pickle.load(file)
+            for key, value in data.items():
+                setattr(self, key, value)
diff --git a/fsre/eval.py b/fsre/eval.py
@@ -0,0 +1,30 @@
+import torch
+
+import fsre
+
+
+class Evaluator(fsre.train.Trainer):
+    """
+    Evaluate a model.
+    """
+
+    def __init__(self, config, state_dicts):
+        self.eval_config = config
+        super().__init__(state_dicts["config"], None, state_dicts)
+
+    def run(self):
+        self.epoch = self.state_dicts["epoch"]
+        self.info()
+        self.initialize_rng()
+        self.prepare_dataset()
+        self.build_model()
+        self.count_parameters()
+        self.evaluate("test")
+
+
+if __name__ == "__main__":
+    fsre.utils.fix_transformers_logging_handler()
+    config = fsre.utils.parse_args()
+
+    state_dicts = torch.load(config.load)
+    Evaluator(config, state_dicts).run()
diff --git a/fsre/metrics.py b/fsre/metrics.py
@@ -1,3 +1,4 @@
+import math
 import numpy
 import torch
 
@@ -6,28 +7,49 @@ class Metrics:
     """
     Class for computing metrics.
 
-    Five metrics are computed:
+    Twenty metrics are computed:
         - Accuracy
-        - Macro F1
-        - Macro Precision
-        - Macro Recall
-        - Negative Log Likelihood
+        - Negative Log Likelihood (nll)
+        - {directed, undirected, half_directed} {micro, macro} {f1, precision, recall}
+    Note that the Accuracy is the true accuracy, taking directionality into account and scoring the unknown relation as any other relation.
+    The last 18 metrics follow the SemEval scorer:
+        - The unknown ("Other") relation is only scored indirectly
+        - Directed is equivalent to the metrics "USING DIRECTIONALITY"
+        - Undirected is equivalent to the metrics "IGNORING DIRECTIONALITY"
+        - Half-directed is equivalent to the metrics "TAKING DIRECTIONALITY INTO ACCOUNT -- OFFICIAL"
+    Note that the directed and half_directed micro metrics are equivalent.
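+    For example, predicting "Entity-Destination(e2,e1)" when the gold relation
+    is "Entity-Destination(e1,e2)" counts as an error for the directed and
+    half_directed metrics but as a hit for the undirected ones; half_directed
+    differs from directed in that precision and recall are computed per base
+    relation rather than per directed relation before being averaged.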
     """
 
-    def __init__(self, nclass):
+    def __init__(self, relation_dictionary):
         """
         Initialize all metrics.
 
         Args:
-            nclass: number of relations
+            relation_dictionary: see class RelationDictionary
         """
 
-        self.n = nclass
+        self.relation_dictionary = relation_dictionary
+        self.n = len(relation_dictionary)
+        self.m = relation_dictionary.base_size()
+        self.build_mask()
+        self.build_base_transition()
         self.crossentropy = torch.nn.CrossEntropyLoss(reduction="sum")
 
         self.size = 0
+        self.correct = 0
         self.ce_sum = 0
-        self.confusion = numpy.zeros((nclass, nclass), numpy.int64)
+        self.confusion = numpy.zeros((self.n, self.n), numpy.int64)
+
+    def build_mask(self):
+        self.mask = numpy.ones(self.n)
+        if self.relation_dictionary.unknown is not None:
+            assert(self.relation_dictionary.decode(0) == self.relation_dictionary.unknown)
+            self.mask[0] = 0
+
+    def build_base_transition(self):
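+        # 0/1 matrix with base_transition[id, bid] == 1 when relation id has
+        # base bid; dotting with it aggregates per-relation counts per base.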
+        self.base_transition = numpy.zeros((self.n, self.m))
+        for id, bid in enumerate(self.relation_dictionary.id_to_bid):
+            self.base_transition[id, bid] = 1
 
     def update(self, predictions, target):
         """
@@ -39,46 +61,190 @@ class Metrics:
         """
 
         self.size += predictions.shape[0]
-        self.ce_sum += self.crossentropy(predictions, target).cpu().item()
+        self.ce_sum += self.crossentropy(predictions, target).item()
 
         prediction = predictions.argmax(1)
-        for p, t in zip(prediction.cpu(), target.cpu()):
-            self.confusion[p.item(), t.item()] += 1
+        for p, t in zip(prediction.tolist(), target.tolist()):
+            self.confusion[p, t] += 1
+            self.correct += (p == t)
+
+    @property
+    def summary(self):
+        return {"accuracy": f"{self.accuracy*100:.2f}",
+                "nll": f"{self.nll:.2f}"}
+
+    @property
+    def all(self):
+        keys = ["accuracy", "nll"] + [
+                f"{direction}_{level}_{metric}"
+                for direction in ["directed", "undirected", "half_directed"]
+                for level in ["macro", "micro"]
+                for metric in ["f1", "precision", "recall"]]
+        return {key: getattr(self, key) for key in keys}
+
+    @property
+    def base_mask(self):
+        return self.mask.dot(self.base_transition).clip(0, 1)
+
+    @property
+    def base_confusion(self):
+        return self.base_transition.T.dot(self.confusion).dot(self.base_transition)
 
     @property
     def accuracy(self):
-        return self.confusion.diagonal().sum() / (self.confusion.sum() + 1e-12)
+        return math.nan if self.size == 0 else self.correct / self.size
 
     @property
-    def class_precision(self):
-        return self.confusion.diagonal() / (self.confusion.sum(1) + 1e-12)
+    def nll(self):
+        return math.nan if self.size == 0 else self.ce_sum / self.size
+
+    ##########################
+    # Directed macro metrics #
+    ##########################
 
     @property
-    def class_recall(self):
-        return self.confusion.diagonal() / (self.confusion.sum(0) + 1e-12)
+    def directed_class_precision(self):
+        norm = self.confusion.sum(1)
+        norm[norm == 0] = 1
+        return self.confusion.diagonal() / norm
 
     @property
-    def class_f1(self):
-        return 2 * self.class_precision * self.class_recall / (self.class_precision + self.class_recall + 1e-12)
+    def directed_class_recall(self):
+        norm = self.confusion.sum(0)
+        norm[norm == 0] = 1
+        return self.confusion.diagonal() / norm
 
     @property
-    def precision(self):
-        return self.class_precision.mean()
+    def directed_class_f1(self):
+        norm = self.directed_class_precision + self.directed_class_recall
+        norm[norm == 0] = 1
+        return 2 * self.directed_class_precision * self.directed_class_recall / norm
 
     @property
-    def recall(self):
-        return self.class_recall.mean()
+    def directed_macro_precision(self):
+        return numpy.sum(self.directed_class_precision * self.mask) / self.mask.sum()
 
     @property
-    def f1(self):
-        return self.class_f1.mean()
+    def directed_macro_recall(self):
+        return numpy.sum(self.directed_class_recall * self.mask) / self.mask.sum()
 
     @property
-    def nll(self):
-        return self.ce_sum / self.size
+    def directed_macro_f1(self):
+        return numpy.sum(self.directed_class_f1 * self.mask) / self.mask.sum()
+
+    ############################
+    # Undirected macro metrics #
+    ############################
 
     @property
-    def summary(self):
-        return {"accuracy": f"{self.accuracy*100:.2f}",
-                "f1": f"{self.f1*100:.2f}",
-                "nll": f"{self.nll:.2f}"}
+    def undirected_class_precision(self):
+        norm = self.base_confusion.sum(1)
+        norm[norm == 0] = 1
+        return self.base_confusion.diagonal() / norm
+
+    @property
+    def undirected_class_recall(self):
+        norm = self.base_confusion.sum(0)
+        norm[norm == 0] = 1
+        return self.base_confusion.diagonal() / norm
+
+    @property
+    def undirected_class_f1(self):
+        norm = self.undirected_class_precision + self.undirected_class_recall
+        norm[norm == 0] = 1
+        return 2 * self.undirected_class_precision * self.undirected_class_recall / norm
+
+    @property
+    def undirected_macro_precision(self):
+        return numpy.sum(self.undirected_class_precision * self.base_mask) / self.base_mask.sum()
+
+    @property
+    def undirected_macro_recall(self):
+        return numpy.sum(self.undirected_class_recall * self.base_mask) / self.base_mask.sum()
+
+    @property
+    def undirected_macro_f1(self):
+        return numpy.sum(self.undirected_class_f1 * self.base_mask) / self.base_mask.sum()
+
+    ###############################
+    # Half-directed macro metrics #
+    ###############################
+
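+    # Half-directed scores: only exact directed matches count as correct
+    # (directed diagonal), but totals are tallied per base relation.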
+    @property
+    def half_directed_class_precision(self):
+        norm = self.base_confusion.sum(1)
+        norm[norm == 0] = 1
+        return self.confusion.diagonal().dot(self.base_transition) / norm
+
+    @property
+    def half_directed_class_recall(self):
+        norm = self.base_confusion.sum(0)
+        norm[norm == 0] = 1
+        return self.confusion.diagonal().dot(self.base_transition) / norm
+
+    @property
+    def half_directed_class_f1(self):
+        norm = self.half_directed_class_precision + self.half_directed_class_recall
+        norm[norm == 0] = 1
+        return 2 * self.half_directed_class_precision * self.half_directed_class_recall / norm
+
+    @property
+    def half_directed_macro_precision(self):
+        return numpy.sum(self.half_directed_class_precision * self.base_mask) / self.base_mask.sum()
+
+    @property
+    def half_directed_macro_recall(self):
+        return numpy.sum(self.half_directed_class_recall * self.base_mask) / self.base_mask.sum()
+
+    @property
+    def half_directed_macro_f1(self):
+        return numpy.sum(self.half_directed_class_f1 * self.base_mask) / self.base_mask.sum()
+
+    #################
+    # Micro metrics #
+    #################
+
+    @property
+    def directed_micro_precision(self):
+        norm = numpy.sum(self.confusion.sum(1) * self.mask)
+        return 0 if norm == 0 else numpy.sum(self.confusion.diagonal() * self.mask) / norm
+
+    @property
+    def directed_micro_recall(self):
+        norm = numpy.sum(self.confusion.sum(0) * self.mask)
+        return 0 if norm == 0 else numpy.sum(self.confusion.diagonal() * self.mask) / norm
+
+    @property
+    def directed_micro_f1(self):
+        norm = self.directed_micro_precision + self.directed_micro_recall
+        return 0 if norm == 0 else 2 * (self.directed_micro_precision * self.directed_micro_recall) / norm
+
+    @property
+    def half_directed_micro_precision(self):
+        norm = numpy.sum(self.confusion.sum(1) * self.mask)
+        return 0 if norm == 0 else numpy.sum(self.confusion.diagonal() * self.mask) / norm
+
+    @property
+    def half_directed_micro_recall(self):
+        norm = numpy.sum(self.confusion.sum(0) * self.mask)
+        return 0 if norm == 0 else numpy.sum(self.confusion.diagonal() * self.mask) / norm
+
+    @property
+    def half_directed_micro_f1(self):
+        norm = self.half_directed_micro_precision + self.half_directed_micro_recall
+        return 0 if norm == 0 else 2 * (self.half_directed_micro_precision * self.half_directed_micro_recall) / norm
+
+    @property
+    def undirected_micro_precision(self):
+        norm = numpy.sum(self.base_confusion.sum(1) * self.base_mask)
+        return 0 if norm == 0 else numpy.sum(self.base_confusion.diagonal() * self.base_mask) / norm
+
+    @property
+    def undirected_micro_recall(self):
+        norm = numpy.sum(self.base_confusion.sum(0) * self.base_mask)
+        return 0 if norm == 0 else numpy.sum(self.base_confusion.diagonal() * self.base_mask) / norm
+
+    @property
+    def undirected_micro_f1(self):
+        norm = self.undirected_micro_precision + self.undirected_micro_recall
+        return 0 if norm == 0 else 2 * (self.undirected_micro_precision * self.undirected_micro_recall) / norm
diff --git a/fsre/train.py b/fsre/train.py
@@ -1,6 +1,8 @@
 import sys
 import os
 import math
+import time
+import contextlib
 import multiprocessing
 import signal
 import logging
@@ -19,15 +21,20 @@ class Trainer:
     Train a model.
 
     Config:
+        Model: the model class to use for training
         batch_size: the number of samples in the batch of data loaded
-        true_batch_size: the actual number of sample in a batch, the number of
-            sample seen before a backward (must be a multiple of batch_size)
+        bert_model: the model of transformer to use
         dataset_name: name of the dataset to load
         deterministic: run in deterministic mode
+        early_stopping_patience: how many epochs to train after the best validation score has been reached
         learning_rate: learning rate
         max_epoch: maximum number of epoch
+        no_initial_validation: do not run evaluation on the valid dataset before the first epoch
         seed: the seed for the random number generator
         sort_per_shuffle_bucket: the number of sort buckets in a shuffle bucket
+        test_output: path to a file where the test predictions will be written
+        true_batch_size: the actual number of samples in a batch, i.e. the number of samples seen before a backward pass (must be a multiple of batch_size)
+        validation_metric: metric used for early stopping
     """
 
     def __init__(self, config, logdir, state_dicts=None):
@@ -83,7 +90,11 @@ class Trainer:
             logger.info(f"GPU{i} {gp.name} {gp.total_memory} {gp.major}.{gp.minor}")
 
     def info(self):
-        print(f"logdir is \033[1m\033[33m{self.logdir}\033[0m")
+        if self.logdir is None:
+            print(f"logdir is \033[1m\033[31mnot set\033[0m, log messages will be discarded")
+        else:
+            print(f"logdir is \033[1m\033[33m{self.logdir}\033[0m")
+
         self.environment_check()
         self.detect_gpus()
         print("")
@@ -103,7 +114,7 @@ class Trainer:
             suffix = ""
             if self.state_dicts:
                 suffix = time.strftime("%FT%H:%M:%S")
-            fsre.utils.save_patch(self.logdir / "patch{suffix}")
+            fsre.utils.save_patch(self.logdir / f"patch{suffix}")
 
     def initialize_rng(self):
         if self.state_dicts:
@@ -120,7 +131,7 @@ class Trainer:
 
     def prepare_dataset(self):
         data_dir = fsre.utils.DATA_PATH / self.config.dataset_name / self.config.bert_model
-        self.relation_dictionary = fsre.data.RelationDictionary(data_dir / "relations")
+        self.relation_dictionary = fsre.data.RelationDictionary(path=data_dir / "relations")
         self.tokenizer = transformers.BertTokenizer.from_pretrained(data_dir / "tokenizer")
 
         self.dataset = {}
@@ -178,6 +189,22 @@ class Trainer:
 
         signal.signal(signal.SIGINT, handler)
 
+    def eval_context(self, dataset):
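+        # Returns a context manager for the evaluation loop: the opened
+        # prediction file when evaluating on test with test_output set,
+        # otherwise a no-op context.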
+        self.test_output_file = None
+        if dataset == "test" and self.config.get("test_output"):
+            self.test_output_file = open(self.config.test_output, 'w')
+            return self.test_output_file
+        return contextlib.nullcontext()
+
+    def eval_handle_predictions(self, ids, predictions):
+        if self.test_output_file is None:
+            return
+        ids = ids.tolist()
+        predictions = predictions.argmax(1).tolist()
+        for id, prediction in zip(ids, predictions):
+            prediction = self.relation_dictionary.decode(prediction)
+            self.test_output_file.write(f"{id}\t{prediction}\n")
+
     def evaluate(self, dataset):
         loop = tqdm.tqdm(
                 iterable=self.iterator[dataset](),
@@ -188,36 +215,45 @@ class Trainer:
                 leave=False)
 
         self.model.eval()
-        with torch.no_grad():
-            scorer = fsre.metrics.Metrics(len(self.relation_dictionary))
+        with torch.no_grad(), self.eval_context(dataset):
+            has_target = False
+            scorer = fsre.metrics.Metrics(self.relation_dictionary)
             correct_prediction = 0
             for batch in loop:
                 batch = {key: value.to(self.device) for key, value in batch.items()}
 
                 # Pop the target to ensure it's not used by the model
-                target = batch.pop("relation")
+                if "relation" in batch:
+                    target = batch.pop("relation")
+                    has_target = True
                 predictions = self.model(batch)
-                scorer.update(predictions, target)
-                loop.set_postfix(**scorer.summary, refresh=False)
-
-            print(f"Epoch {self.epoch} {dataset:5} accuracy: {scorer.accuracy*100:8.4f}%   F1: {scorer.f1*100:8.4f}%   (P: {scorer.precision*100:8.4f}% R: {scorer.recall*100:8.4f}%)    NLL: {scorer.nll:8.4f}")
-            logger.info(f"epoch {self.epoch} {dataset} accuracy {scorer.accuracy} F1 {scorer.f1} precision {scorer.precision} recall {scorer.recall} NLL {scorer.nll}")
-            return getattr(scorer, self.config.validation_metric)
-
-    def save(self, path, full):
-        state_dicts = {"model": self.model.state_dict()}
-        if full:
-            state_dicts.update({
-                    "logdir": self.logdir,
-                    "optimizer": self.optimizer.state_dict(),
-                    "train_rng": self.dataset["train"].state_dict(),
-                    "torch_rng": torch.random.get_rng_state(),
-                    "epoch": self.epoch,
-                    "best_epoch": self.best_epoch,
-                    "best_eval": self.best_eval,
-                })
-            if torch.cuda.is_available():
-                state_dicts["cuda_rng"] = torch.cuda.random.get_rng_state_all()
+                self.eval_handle_predictions(batch["id"], predictions)
+                if has_target:
+                    scorer.update(predictions, target)
+                    loop.set_postfix(**scorer.summary, refresh=False)
+
+            if has_target:
+                print(f"Epoch {self.epoch} {dataset:5} accuracy: {scorer.accuracy*100:8.4f}%   Half-directed Macro F1: {scorer.half_directed_macro_f1*100:8.4f}%   (P: {scorer.half_directed_macro_precision*100:8.4f}% R: {scorer.half_directed_macro_recall*100:8.4f}%)    NLL: {scorer.nll:8.4f}")
+                for metric, value in scorer.all.items():
+                    logger.info(f"epoch {self.epoch} {dataset} {metric} {value}")
+                return getattr(scorer, self.config.validation_metric)
+            return None
+
+    def save(self, path):
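+        # Checkpoints always carry the full training state, including the
+        # config, so that fsre.eval can rebuild the Trainer from them.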
+        state_dicts = {
+                "logdir": self.logdir,
+                "config": self.config,
+                "model": self.model.state_dict(),
+                "optimizer": self.optimizer.state_dict(),
+                "train_rng": self.dataset["train"].state_dict(),
+                "torch_rng": torch.random.get_rng_state(),
+                "epoch": self.epoch,
+                "best_epoch": self.best_epoch,
+                "best_eval": self.best_eval,
+            }
+        if torch.cuda.is_available():
+            state_dicts["cuda_rng"] = torch.cuda.random.get_rng_state_all()
         torch.save(state_dicts, path)
 
     def train(self):
@@ -275,14 +311,14 @@ class Trainer:
             print(f"Epoch {self.epoch} train mean loss: {total_loss / total_sample:8.4f}")
             logger.info(f"epoch {self.epoch} train mean_loss {total_loss / total_sample}")
 
-            self.save(self.logdir / "checkpoint.new", full=True)
+            self.save(self.logdir / "checkpoint.new")
             os.rename(self.logdir / "checkpoint.new", self.logdir / "checkpoint")
 
             candidate = self.evaluate("valid")
             if candidate > self.best_eval:
                 self.best_epoch = self.epoch
                 self.best_eval = candidate
-                self.save(best_path, full=False)
+                self.save(best_path)
                 logger.info(f"Model saved to {best_path}")
             elif self.epoch - self.best_epoch > self.config.get("early_stopping_patience", self.config.max_epoch):
                 break
@@ -302,10 +338,10 @@ if __name__ == "__main__":
 
     state_dicts = None
     if config.get("load"):
-        state_dicts = torch.load(config.load, strict=False)
-
-    if config.get("load"):
+        state_dicts = torch.load(config.load)
         logdir = state_dicts["logdir"]
+        if config.get("reuse_config"):
+            config = state_dicts["config"]
     else:
         logdir = fsre.utils.logdir_name("FSRE")
         assert(not logdir.exists())