# gbure

Graph-based approaches to unsupervised relation extraction, evaluated as a few-shot problem
git clone https://esimon.eu/repos/gbure.git

metrics.py (13389B)

```      1 from typing import Dict, Optional, Union
2 import math
3
4 import torch
5 import transformers
6
7 from gbure.data.dictionary import RelationDictionary
8 import gbure.data.graph
9
10
11 class Metrics:
12     """
13     Class for computing metrics.
14
15     Twenty metrics are computed:
16         - Optimized loss (usually negative log likelihood)
17         - Accuracy
18         - {directed, undirected, half_directed} {micro, macro} {f1, precision, recall}
19     Note that the Accuracy is the true accuracy, taking directionality into account and scoring the unknown relation as any other relation.
20     The last 18 metrics follow the SemEval scorer:
21         - The unknown ("Other") relation is only scored indirectly
22         - Directed is equivalent to the metrics "USING DIRECTIONALITY"
23         - Undirected is equivalent to the metrics "IGNORING DIRECTIONALITY"
24         - Half-directed is equivalent to the metrics "TAKING DIRECTIONALITY INTO ACCOUNT -- OFFICIAL"
25     Note that the directed and half_directed micro metrics are equivalents.
26     """
27
28     def __init__(self, config: gbure.utils.dotdict, tokenizer: transformers.PreTrainedTokenizer, relation_dictionary: RelationDictionary, graph: Optional[gbure.data.graph.Graph]) -> None:
29         """ Initialize all metrics. """
30         self.config: gbure.utils.dotdict = config
31         self.tokenizer: transformers.PreTrainedTokenizer = tokenizer
32         self.relation_dictionary: RelationDictionary = relation_dictionary
33         self.graph: Optional[gbure.data.graph.Graph] = graph
34         self.num_relations: int = len(relation_dictionary)
35         self.num_base_relations: int = relation_dictionary.base_size()
36
38         self.base_transition: torch.Tensor = self.build_base_transition()
40
41         self.loss_sum: float = 0.0
42         self.relation_buckets: int = 1 + self.config.get("neighborhood_size", 0)
43         self.per_bucket_confusion: torch.Tensor = torch.zeros((self.relation_buckets, self.num_relations, self.num_relations), dtype=torch.int32)
44
45     @property
46     def confusion(self):
47         return self.per_bucket_confusion.sum(0)
48
50         """ Return the relation mask used by semeval scorer (which partly ignore the unknown relation). """
52         if self.relation_dictionary.unknown is not None:
53             assert(self.relation_dictionary.decode(0) == self.relation_dictionary.unknown)
56
57     def build_base_transition(self) -> torch.Tensor:
58         """ Return the transition matrix from "directed relations" to "undirected relations". """
59         base_transition: torch.Tensor = torch.zeros((self.num_relations, self.num_base_relations))
60         for id, bid in enumerate(self.relation_dictionary.id_to_bid):
61             base_transition[id, bid] = 1
62         return base_transition
63
64     def compute_neighborhood_bucket(self, batch: Dict[str, torch.Tensor], index: int) -> int:
65         """ Return an index between 0 and self.config.neighborhood_size corresponding to the minimum number of neighbors in the sample. """
66         if self.graph is None:
67             return 0
68         neighborhood_size: Union[float, int] = math.inf
69         for feature, value in batch.items():
70             if "degree" in feature and "neighborhood" not in feature:
71                 neighborhood_size = min(neighborhood_size, value[index].min().item())
72                 # TODO here substract 1 for unsupervised (not that important since we don't care about unsupervised accuracies)
73         return min(neighborhood_size, self.relation_buckets-1) if neighborhood_size != math.inf else 0  # pytype: disable=bad-return-type
74
75     def update(self, batch: Dict[str, torch.Tensor], loss: torch.Tensor, losses: Dict[str, torch.Tensor], variables: Dict[str, torch.Tensor]) -> None:
76         """
77         Update metrics according to the given batch and the outputs of the model on this batch.
78
79         The variables dictionary returned by the model should usually contain a predicted_relation tensor.
80
81         Args:
82             batch: the input values used for evaluation
83             loss: the loss optimized by the model
84             losses: intermediary (unweighted) losses
85             variables: internal variables used by the model to compute the loss
86         """
87         predictions: torch.Tensor = variables.get("predicted_relation")
88         targets: torch.Tensor = batch.get("relation")
89         if targets is None:
90             targets = batch.get("query_relation")
91
92         if predictions is None and targets is None:
93             predictions = variables.get("prediction_relative")
95
96         for i, (prediction, target) in enumerate(zip(predictions, targets)):
97             neighborhood_bucket: int = self.compute_neighborhood_bucket(batch, i)
98             self.per_bucket_confusion[neighborhood_bucket, prediction, target] += 1
99
100         batch_size: int = predictions.shape[0]
101         self.loss_sum += loss.item() * batch_size
102
103     @property
104     def summary(self) -> Dict[str, str]:
105         """ Return a summary of metrics to be quickly displayed. """
106         metrics: Dict[str, str] = {"accuracy": f"{self.accuracy*100:.2f}",
107                                    "loss": f"{self.loss:.2f}"}
108         if self.relation_buckets > 1:
109             metrics.update({"accuracy_non_empty": f"{self.accuracy_non_empty*100:.2f}",
110                             "accuracy_full": f"{self.accuracy_full*100:.2f}"})
111         return metrics
112
113     @property
114     def all(self) -> Dict[str, float]:
115         """ Return a dictionary of all metrics. """
116         keys = ["accuracy", "accuracy_non_empty", "accuracy_full", "loss"] + [
117                 f"{direction}_{level}_{metric}"
118                 for direction in ["directed", "undirected", "half_directed"]
119                 for level in ["macro", "micro"]
120                 for metric in ["f1", "precision", "recall"]]
121         return {key: getattr(self, key) for key in keys}
122
123     @property
124     def base_confusion(self) -> torch.Tensor:
125         """ Confusion matrix between "undirected" relation classes. """
126         return self.base_transition.t().mm(self.confusion.type_as(self.base_transition)).mm(self.base_transition)
127
128     @property
129     def accuracy(self) -> float:
130         return math.nan if self.confusion.sum() == 0 else self.confusion.diagonal().sum() / self.confusion.sum().type(torch.float32)
131
132     @property
133     def accuracy_non_empty(self) -> float:
134         non_empty: torch.Tensor = self.per_bucket_confusion[1:].sum(0)
135         return math.nan if non_empty.sum() == 0 else non_empty.diagonal().sum() / non_empty.sum().type(torch.float32)
136
137     @property
138     def accuracy_full(self) -> float:
139         full: torch.Tensor = self.per_bucket_confusion[-1]
140         return math.nan if full.sum() == 0 else full.diagonal().sum() / full.sum().type(torch.float32)
141
142     @property
143     def loss(self) -> float:
144         return math.nan if self.confusion.sum() == 0 else self.loss_sum / self.confusion.sum().type(torch.float32)
145
146     ##########################
147     # Directed macro metrics #
148     ##########################
149
150     @property
151     def directed_class_precision(self) -> torch.Tensor:
152         norm: torch.Tensor = self.confusion.sum(1)
153         norm[norm == 0] = 1
154         return self.confusion.diagonal() / norm.type(torch.float32)
155
156     @property
157     def directed_class_recall(self) -> torch.Tensor:
158         norm: torch.Tensor = self.confusion.sum(0)
159         norm[norm == 0] = 1
160         return self.confusion.diagonal() / norm.type(torch.float32)
161
162     @property
163     def directed_class_f1(self) -> torch.Tensor:
164         norm: torch.Tensor = self.directed_class_precision + self.directed_class_recall
165         norm[norm == 0] = 1
166         return 2 * self.directed_class_precision * self.directed_class_recall / norm
167
168     @property
169     def directed_macro_precision(self) -> float:
171
172     @property
173     def directed_macro_recall(self) -> float:
175
176     @property
177     def directed_macro_f1(self) -> float:
179
180     ############################
181     # Undirected macro metrics #
182     ############################
183
184     @property
185     def undirected_class_precision(self) -> torch.Tensor:
186         norm: torch.Tensor = self.base_confusion.sum(1)
187         norm[norm == 0] = 1
188         return self.base_confusion.diagonal() / norm
189
190     @property
191     def undirected_class_recall(self) -> torch.Tensor:
192         norm: torch.Tensor = self.base_confusion.sum(0)
193         norm[norm == 0] = 1
194         return self.base_confusion.diagonal() / norm
195
196     @property
197     def undirected_class_f1(self) -> torch.Tensor:
198         norm: torch.Tensor = self.undirected_class_precision + self.undirected_class_recall
199         norm[norm == 0] = 1
200         return 2 * self.undirected_class_precision * self.undirected_class_recall / norm
201
202     @property
203     def undirected_macro_precision(self) -> float:
205
206     @property
207     def undirected_macro_recall(self) -> float:
209
210     @property
211     def undirected_macro_f1(self) -> float:
213
214     ###############################
215     # Half-directed macro metrics #
216     ###############################
217
218     @property
219     def half_directed_class_precision(self) -> torch.Tensor:
220         norm: torch.Tensor = self.base_confusion.sum(1)
221         norm[norm == 0] = 1
222         return self.base_transition.t().mv(self.confusion.diagonal().type_as(self.base_transition)) / norm
223
224     @property
225     def half_directed_class_recall(self) -> torch.Tensor:
226         norm: torch.Tensor = self.base_confusion.sum(0)
227         norm[norm == 0] = 1
228         return self.base_transition.t().mv(self.confusion.diagonal().type_as(self.base_transition)) / norm
229
230     @property
231     def half_directed_class_f1(self) -> torch.Tensor:
232         norm: torch.Tensor = self.half_directed_class_precision + self.half_directed_class_recall
233         norm[norm == 0] = 1
234         return 2 * self.half_directed_class_precision * self.half_directed_class_recall / norm
235
236     @property
237     def half_directed_macro_precision(self) -> float:
239
240     @property
241     def half_directed_macro_recall(self) -> float:
243
244     @property
245     def half_directed_macro_f1(self) -> float:
247
248     #################
249     # Micro metrics #
250     #################
251
252     @property
253     def directed_micro_precision(self) -> float:
254         norm: torch.Tensor = (self.confusion.sum(1) * self.mask).sum()
255         return 0 if norm == 0 else ((self.confusion.diagonal() * self.mask).sum() / norm).item()
256
257     @property
258     def directed_micro_recall(self) -> float:
259         norm: torch.Tensor = (self.confusion.sum(0) * self.mask).sum()
260         return 0 if norm == 0 else ((self.confusion.diagonal() * self.mask).sum() / norm).item()
261
262     @property
263     def directed_micro_f1(self) -> float:
264         norm: float = self.directed_micro_precision + self.directed_micro_recall
265         return 0 if norm == 0 else 2 * (self.directed_micro_precision * self.directed_micro_recall) / norm
266
267     @property
268     def half_directed_micro_precision(self) -> float:
269         norm: torch.Tensor = (self.confusion.sum(1) * self.mask).sum()
270         return 0 if norm == 0 else ((self.confusion.diagonal() * self.mask).sum() / norm).item()
271
272     @property
273     def half_directed_micro_recall(self) -> float:
274         norm: torch.Tensor = (self.confusion.sum(0) * self.mask).sum()
275         return 0 if norm == 0 else ((self.confusion.diagonal() * self.mask).sum() / norm).item()
276
277     @property
278     def half_directed_micro_f1(self) -> float:
279         norm: float = self.half_directed_micro_precision + self.half_directed_micro_recall
280         return 0 if norm == 0 else 2 * (self.half_directed_micro_precision * self.half_directed_micro_recall) / norm
281
282     @property
283     def undirected_micro_precision(self) -> float:
284         norm: torch.Tensor = (self.base_confusion.sum(1) * self.base_mask).sum()
285         return 0 if norm == 0 else ((self.base_confusion.diagonal() * self.base_mask).sum() / norm).item()
286
287     @property
288     def undirected_micro_recall(self) -> float:
289         norm: torch.Tensor = (self.base_confusion.sum(0) * self.base_mask).sum()
290         return 0 if norm == 0 else ((self.base_confusion.diagonal() * self.base_mask).sum() / norm).item()
291
292     @property
293     def undirected_micro_f1(self) -> float:
294         norm: float = self.undirected_micro_precision + self.undirected_micro_recall
295         return 0 if norm == 0 else 2 * (self.undirected_micro_precision * self.undirected_micro_recall) / norm
```