gbure

Graph-based approaches on unsupervised relation extraction evaluated as a few-shot problem
git clone https://esimon.eu/repos/gbure.git
Log | Files | Refs | README | LICENSE

metrics.py (13389B)


      1 from typing import Dict, Optional, Union
      2 import math
      3 
      4 import torch
      5 import transformers
      6 
      7 from gbure.data.dictionary import RelationDictionary
      8 import gbure.data.graph
      9 
     10 
     11 class Metrics:
     12     """
     13     Class for computing metrics.
     14 
     15     Twenty metrics are computed:
     16         - Optimized loss (usually negative log likelihood)
     17         - Accuracy
     18         - {directed, undirected, half_directed} {micro, macro} {f1, precision, recall}
     19     Note that the Accuracy is the true accuracy, taking directionality into account and scoring the unknown relation as any other relation.
     20     The last 18 metrics follow the SemEval scorer:
     21         - The unknown ("Other") relation is only scored indirectly
     22         - Directed is equivalent to the metrics "USING DIRECTIONALITY"
     23         - Undirected is equivalent to the metrics "IGNORING DIRECTIONALITY"
     24         - Half-directed is equivalent to the metrics "TAKING DIRECTIONALITY INTO ACCOUNT -- OFFICIAL"
     25     Note that the directed and half_directed micro metrics are equivalents.
     26     """
     27 
     28     def __init__(self, config: gbure.utils.dotdict, tokenizer: transformers.PreTrainedTokenizer, relation_dictionary: RelationDictionary, graph: Optional[gbure.data.graph.Graph]) -> None:
     29         """ Initialize all metrics. """
     30         self.config: gbure.utils.dotdict = config
     31         self.tokenizer: transformers.PreTrainedTokenizer = tokenizer
     32         self.relation_dictionary: RelationDictionary = relation_dictionary
     33         self.graph: Optional[gbure.data.graph.Graph] = graph
     34         self.num_relations: int = len(relation_dictionary)
     35         self.num_base_relations: int = relation_dictionary.base_size()
     36 
     37         self.mask: torch.Tensor = self.build_mask()
     38         self.base_transition: torch.Tensor = self.build_base_transition()
     39         self.base_mask: torch.Tensor = self.base_transition.t().mv(self.mask).clamp(0, 1)
     40 
     41         self.loss_sum: float = 0.0
     42         self.relation_buckets: int = 1 + self.config.get("neighborhood_size", 0)
     43         self.per_bucket_confusion: torch.Tensor = torch.zeros((self.relation_buckets, self.num_relations, self.num_relations), dtype=torch.int32)
     44 
     45     @property
     46     def confusion(self):
     47         return self.per_bucket_confusion.sum(0)
     48 
     49     def build_mask(self) -> torch.Tensor:
     50         """ Return the relation mask used by semeval scorer (which partly ignore the unknown relation). """
     51         mask: torch.Tensor = torch.ones(self.num_relations)
     52         if self.relation_dictionary.unknown is not None:
     53             assert(self.relation_dictionary.decode(0) == self.relation_dictionary.unknown)
     54             mask[0] = 0
     55         return mask
     56 
     57     def build_base_transition(self) -> torch.Tensor:
     58         """ Return the transition matrix from "directed relations" to "undirected relations". """
     59         base_transition: torch.Tensor = torch.zeros((self.num_relations, self.num_base_relations))
     60         for id, bid in enumerate(self.relation_dictionary.id_to_bid):
     61             base_transition[id, bid] = 1
     62         return base_transition
     63 
     64     def compute_neighborhood_bucket(self, batch: Dict[str, torch.Tensor], index: int) -> int:
     65         """ Return an index between 0 and self.config.neighborhood_size corresponding to the minimum number of neighbors in the sample. """
     66         if self.graph is None:
     67             return 0
     68         neighborhood_size: Union[float, int] = math.inf
     69         for feature, value in batch.items():
     70             if "degree" in feature and "neighborhood" not in feature:
     71                 neighborhood_size = min(neighborhood_size, value[index].min().item())
     72                 # TODO here substract 1 for unsupervised (not that important since we don't care about unsupervised accuracies)
     73         return min(neighborhood_size, self.relation_buckets-1) if neighborhood_size != math.inf else 0  # pytype: disable=bad-return-type
     74 
     75     def update(self, batch: Dict[str, torch.Tensor], loss: torch.Tensor, losses: Dict[str, torch.Tensor], variables: Dict[str, torch.Tensor]) -> None:
     76         """
     77         Update metrics according to the given batch and the outputs of the model on this batch.
     78 
     79         The variables dictionary returned by the model should usually contain a predicted_relation tensor.
     80 
     81         Args:
     82             batch: the input values used for evaluation
     83             loss: the loss optimized by the model
     84             losses: intermediary (unweighted) losses
     85             variables: internal variables used by the model to compute the loss
     86         """
     87         predictions: torch.Tensor = variables.get("predicted_relation")
     88         targets: torch.Tensor = batch.get("relation")
     89         if targets is None:
     90             targets = batch.get("query_relation")
     91 
     92         if predictions is None and targets is None:
     93             predictions = variables.get("prediction_relative")
     94             targets = batch.get("answer")
     95 
     96         for i, (prediction, target) in enumerate(zip(predictions, targets)):
     97             neighborhood_bucket: int = self.compute_neighborhood_bucket(batch, i)
     98             self.per_bucket_confusion[neighborhood_bucket, prediction, target] += 1
     99 
    100         batch_size: int = predictions.shape[0]
    101         self.loss_sum += loss.item() * batch_size
    102 
    103     @property
    104     def summary(self) -> Dict[str, str]:
    105         """ Return a summary of metrics to be quickly displayed. """
    106         metrics: Dict[str, str] = {"accuracy": f"{self.accuracy*100:.2f}",
    107                                    "loss": f"{self.loss:.2f}"}
    108         if self.relation_buckets > 1:
    109             metrics.update({"accuracy_non_empty": f"{self.accuracy_non_empty*100:.2f}",
    110                             "accuracy_full": f"{self.accuracy_full*100:.2f}"})
    111         return metrics
    112 
    113     @property
    114     def all(self) -> Dict[str, float]:
    115         """ Return a dictionary of all metrics. """
    116         keys = ["accuracy", "accuracy_non_empty", "accuracy_full", "loss"] + [
    117                 f"{direction}_{level}_{metric}"
    118                 for direction in ["directed", "undirected", "half_directed"]
    119                 for level in ["macro", "micro"]
    120                 for metric in ["f1", "precision", "recall"]]
    121         return {key: getattr(self, key) for key in keys}
    122 
    123     @property
    124     def base_confusion(self) -> torch.Tensor:
    125         """ Confusion matrix between "undirected" relation classes. """
    126         return self.base_transition.t().mm(self.confusion.type_as(self.base_transition)).mm(self.base_transition)
    127 
    128     @property
    129     def accuracy(self) -> float:
    130         return math.nan if self.confusion.sum() == 0 else self.confusion.diagonal().sum() / self.confusion.sum().type(torch.float32)
    131 
    132     @property
    133     def accuracy_non_empty(self) -> float:
    134         non_empty: torch.Tensor = self.per_bucket_confusion[1:].sum(0)
    135         return math.nan if non_empty.sum() == 0 else non_empty.diagonal().sum() / non_empty.sum().type(torch.float32)
    136 
    137     @property
    138     def accuracy_full(self) -> float:
    139         full: torch.Tensor = self.per_bucket_confusion[-1]
    140         return math.nan if full.sum() == 0 else full.diagonal().sum() / full.sum().type(torch.float32)
    141 
    142     @property
    143     def loss(self) -> float:
    144         return math.nan if self.confusion.sum() == 0 else self.loss_sum / self.confusion.sum().type(torch.float32)
    145 
    146     ##########################
    147     # Directed macro metrics #
    148     ##########################
    149 
    150     @property
    151     def directed_class_precision(self) -> torch.Tensor:
    152         norm: torch.Tensor = self.confusion.sum(1)
    153         norm[norm == 0] = 1
    154         return self.confusion.diagonal() / norm.type(torch.float32)
    155 
    156     @property
    157     def directed_class_recall(self) -> torch.Tensor:
    158         norm: torch.Tensor = self.confusion.sum(0)
    159         norm[norm == 0] = 1
    160         return self.confusion.diagonal() / norm.type(torch.float32)
    161 
    162     @property
    163     def directed_class_f1(self) -> torch.Tensor:
    164         norm: torch.Tensor = self.directed_class_precision + self.directed_class_recall
    165         norm[norm == 0] = 1
    166         return 2 * self.directed_class_precision * self.directed_class_recall / norm
    167 
    168     @property
    169     def directed_macro_precision(self) -> float:
    170         return ((self.directed_class_precision * self.mask).sum() / self.mask.sum()).item()
    171 
    172     @property
    173     def directed_macro_recall(self) -> float:
    174         return ((self.directed_class_recall * self.mask).sum() / self.mask.sum()).item()
    175 
    176     @property
    177     def directed_macro_f1(self) -> float:
    178         return ((self.directed_class_f1 * self.mask).sum() / self.mask.sum()).item()
    179 
    180     ############################
    181     # Undirected macro metrics #
    182     ############################
    183 
    184     @property
    185     def undirected_class_precision(self) -> torch.Tensor:
    186         norm: torch.Tensor = self.base_confusion.sum(1)
    187         norm[norm == 0] = 1
    188         return self.base_confusion.diagonal() / norm
    189 
    190     @property
    191     def undirected_class_recall(self) -> torch.Tensor:
    192         norm: torch.Tensor = self.base_confusion.sum(0)
    193         norm[norm == 0] = 1
    194         return self.base_confusion.diagonal() / norm
    195 
    196     @property
    197     def undirected_class_f1(self) -> torch.Tensor:
    198         norm: torch.Tensor = self.undirected_class_precision + self.undirected_class_recall
    199         norm[norm == 0] = 1
    200         return 2 * self.undirected_class_precision * self.undirected_class_recall / norm
    201 
    202     @property
    203     def undirected_macro_precision(self) -> float:
    204         return ((self.undirected_class_precision * self.base_mask).sum() / self.base_mask.sum()).item()
    205 
    206     @property
    207     def undirected_macro_recall(self) -> float:
    208         return ((self.undirected_class_recall * self.base_mask).sum() / self.base_mask.sum()).item()
    209 
    210     @property
    211     def undirected_macro_f1(self) -> float:
    212         return ((self.undirected_class_f1 * self.base_mask).sum() / self.base_mask.sum()).item()
    213 
    214     ###############################
    215     # Half-directed macro metrics #
    216     ###############################
    217 
    218     @property
    219     def half_directed_class_precision(self) -> torch.Tensor:
    220         norm: torch.Tensor = self.base_confusion.sum(1)
    221         norm[norm == 0] = 1
    222         return self.base_transition.t().mv(self.confusion.diagonal().type_as(self.base_transition)) / norm
    223 
    224     @property
    225     def half_directed_class_recall(self) -> torch.Tensor:
    226         norm: torch.Tensor = self.base_confusion.sum(0)
    227         norm[norm == 0] = 1
    228         return self.base_transition.t().mv(self.confusion.diagonal().type_as(self.base_transition)) / norm
    229 
    230     @property
    231     def half_directed_class_f1(self) -> torch.Tensor:
    232         norm: torch.Tensor = self.half_directed_class_precision + self.half_directed_class_recall
    233         norm[norm == 0] = 1
    234         return 2 * self.half_directed_class_precision * self.half_directed_class_recall / norm
    235 
    236     @property
    237     def half_directed_macro_precision(self) -> float:
    238         return ((self.half_directed_class_precision * self.base_mask).sum() / self.base_mask.sum()).item()
    239 
    240     @property
    241     def half_directed_macro_recall(self) -> float:
    242         return ((self.half_directed_class_recall * self.base_mask).sum() / self.base_mask.sum()).item()
    243 
    244     @property
    245     def half_directed_macro_f1(self) -> float:
    246         return ((self.half_directed_class_f1 * self.base_mask).sum() / self.base_mask.sum()).item()
    247 
    248     #################
    249     # Micro metrics #
    250     #################
    251 
    252     @property
    253     def directed_micro_precision(self) -> float:
    254         norm: torch.Tensor = (self.confusion.sum(1) * self.mask).sum()
    255         return 0 if norm == 0 else ((self.confusion.diagonal() * self.mask).sum() / norm).item()
    256 
    257     @property
    258     def directed_micro_recall(self) -> float:
    259         norm: torch.Tensor = (self.confusion.sum(0) * self.mask).sum()
    260         return 0 if norm == 0 else ((self.confusion.diagonal() * self.mask).sum() / norm).item()
    261 
    262     @property
    263     def directed_micro_f1(self) -> float:
    264         norm: float = self.directed_micro_precision + self.directed_micro_recall
    265         return 0 if norm == 0 else 2 * (self.directed_micro_precision * self.directed_micro_recall) / norm
    266 
    267     @property
    268     def half_directed_micro_precision(self) -> float:
    269         norm: torch.Tensor = (self.confusion.sum(1) * self.mask).sum()
    270         return 0 if norm == 0 else ((self.confusion.diagonal() * self.mask).sum() / norm).item()
    271 
    272     @property
    273     def half_directed_micro_recall(self) -> float:
    274         norm: torch.Tensor = (self.confusion.sum(0) * self.mask).sum()
    275         return 0 if norm == 0 else ((self.confusion.diagonal() * self.mask).sum() / norm).item()
    276 
    277     @property
    278     def half_directed_micro_f1(self) -> float:
    279         norm: float = self.half_directed_micro_precision + self.half_directed_micro_recall
    280         return 0 if norm == 0 else 2 * (self.half_directed_micro_precision * self.half_directed_micro_recall) / norm
    281 
    282     @property
    283     def undirected_micro_precision(self) -> float:
    284         norm: torch.Tensor = (self.base_confusion.sum(1) * self.base_mask).sum()
    285         return 0 if norm == 0 else ((self.base_confusion.diagonal() * self.base_mask).sum() / norm).item()
    286 
    287     @property
    288     def undirected_micro_recall(self) -> float:
    289         norm: torch.Tensor = (self.base_confusion.sum(0) * self.base_mask).sum()
    290         return 0 if norm == 0 else ((self.base_confusion.diagonal() * self.base_mask).sum() / norm).item()
    291 
    292     @property
    293     def undirected_micro_f1(self) -> float:
    294         norm: float = self.undirected_micro_precision + self.undirected_micro_recall
    295         return 0 if norm == 0 else 2 * (self.undirected_micro_precision * self.undirected_micro_recall) / norm