gbure

Graph-based approaches to unsupervised relation extraction, evaluated as a few-shot problem
git clone https://esimon.eu/repos/gbure.git

matching_the_blanks.py (5796B)


from typing import Dict, Optional, Tuple

import torch
import transformers

import gbure.model.linguistic_encoder
import gbure.model.topological_encoder
import gbure.model.masked_lm
import gbure.model.similarity
import gbure.utils


class Model(torch.nn.Module):
    """
    Unsupervised pre-training model from Soares et al. (2019), "Matching the
    Blanks: Distributional Similarity for Relation Learning".

    Corresponds to the model described in Section 4 of the paper.

    Config:
        linguistic_weight: factor for the linguistic similarity
        topological_weight: factor for the topological similarity
        neighborhood_size: number of graph neighbors; the topological branch
            is only active when this is greater than 0
        neutral_topological_similarity: fixed similarity used for undefined
            neighborhood slots (a learned scalar is used when unset)
        undefined_poison_whole_meta: when set, undefined slots are not filled
            with a neutral similarity
        reverse_sigmoid: use the sigmoid orientation as printed in the paper
        language_model_weight: factor for the masked language model loss
    """

    def __init__(self, config: gbure.utils.dotdict, tokenizer: transformers.PreTrainedTokenizer, relation_dictionary: None) -> None:
        """
        Instantiate a Soares et al. matching the blanks model.

        Args:
            config: global config object
            tokenizer: tokenizer used to create the vocabulary
            relation_dictionary: dictionary of all relations (unused)
        """
        super().__init__()

        self.config: gbure.utils.dotdict = config
        self.tokenizer: transformers.PreTrainedTokenizer = tokenizer

        # Note: this attribute holds the AutoModelForMaskedLM *class*, not an
        # instance; the actual encoder is instantiated inside MaskedLM below.
        self.transformer: type = transformers.AutoModelForMaskedLM
        self.language_model: torch.nn.Module = gbure.model.masked_lm.MaskedLM(config, tokenizer)
        self.linguistic_encoder: torch.nn.Module = gbure.model.linguistic_encoder.LinguisticEncoder(config, tokenizer, transformer=self.language_model.encoder)
        self.linguistic_similarity: torch.nn.Module = gbure.model.similarity.LinguisticSimilarity(config)

        # The topological branch is only built when graph neighborhoods are used.
        if self.config.get("neighborhood_size", 0) > 0:
            self.topological_encoder: torch.nn.Module = gbure.model.topological_encoder.TopologicalEncoder(config, self.linguistic_encoder)
            self.topological_similarity: torch.nn.Module = gbure.model.similarity.TopologicalSimilarity(config)
            if not self.config.get("undefined_poison_whole_meta"):
                # Similarity assigned to undefined neighborhood slots: either a
                # fixed value from the config or a learned scalar parameter.
                if self.config.get("neutral_topological_similarity") is not None:
                    self.neutral_topological_similarity: float = self.config.neutral_topological_similarity
                else:
                    self.neutral_topological_similarity: torch.nn.Parameter = torch.nn.Parameter(torch.zeros(1, dtype=torch.float32))

    def combine_similarities(self, linguistic: torch.Tensor, topological: torch.Tensor, topological_mask: Optional[torch.Tensor]) -> torch.Tensor:
        """ Combine linguistic and topological similarities into a single value. """
        # topological is of dimension (batch, slot)
        if topological_mask is not None:
            if not self.config.get("undefined_poison_whole_meta"):
                # Fill undefined slots with the neutral similarity value.
                topological += (~topological_mask) * self.neutral_topological_similarity
            topological = topological.mean(1)
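        # Worked example (hypothetical numbers): with both weights left at
        # their default of 1, linguistic = 0.8 and topological = (0.2, 0.4)
        # with both slots defined combine to 1 * 0.8 + 1 * 0.3 = 1.1.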
        return self.config.get("linguistic_weight", 1) * linguistic + self.config.get("topological_weight", 1) * topological

    def forward(self,
                first_text: torch.Tensor,
                first_mask: torch.Tensor,
                first_entity_positions: torch.Tensor,
                second_text: torch.Tensor,
                second_mask: torch.Tensor,
                second_entity_positions: torch.Tensor,
                polarity: torch.Tensor,
                first_mlm_input: torch.Tensor,
                first_mlm_target: torch.Tensor,
                **batch) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], Dict[str, torch.Tensor]]:
        """
        Compute the unsupervised matching the blanks loss between the given pairs.

        polarity is expected to contain +1 for positive pairs (both sentences
        share the same entity pair) and -1 for negative pairs.
        """
        # Encode both sentences with the (shared) linguistic encoder.
        first: torch.Tensor
        first_transformer_out: torch.Tensor
        first, first_transformer_out = self.linguistic_encoder(first_text, first_mask, first_entity_positions)
        second: torch.Tensor = self.linguistic_encoder(second_text, second_mask, second_entity_positions)[0]
        linguistic_similarity: torch.Tensor = self.linguistic_similarity(first, second)
        # Masked language model (reconstruction) loss on the first sentence.
        lm_loss: torch.Tensor = self.language_model(first_mlm_input, first_mlm_target, first_mask)

        similarity: torch.Tensor = linguistic_similarity
        if self.config.get("neighborhood_size", 0) > 0:
            # Encode the graph neighborhoods of both entity pairs and fold the
            # topological similarity into the final score.
            topological_first = self.topological_encoder("first_", first, **batch)
            topological_second = self.topological_encoder("second_", second, **batch)

            topological_similarity, topological_mask = self.topological_similarity(topological_first, topological_second)
            similarity = self.combine_similarities(linguistic_similarity, topological_similarity, topological_mask)

        # There seems to be a mistake in Soares et al. §4.1 in the equation for p(l=1|r,r').
        # The equation uses 1/(1+exp(x)), which seems counter-intuitive since the case where r=r' would lead to a low probability.
        # By default the standard sigmoid 1/(1+exp(-x)) is used, but the equation given in the paper can be selected with --reverse_sigmoid.
        if self.config.get("reverse_sigmoid"):
            scores: torch.Tensor = - torch.nn.functional.logsigmoid(- polarity * similarity)
        else:
            scores: torch.Tensor = - torch.nn.functional.logsigmoid(polarity * similarity)
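        # In both cases scores = -log p(l | r, r'), so positive pairs are
        # pushed toward high similarity and negative pairs toward low similarity.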
        mtb_loss: torch.Tensor = scores.mean()

        # polarity takes values in {-1, +1}; map it to {0, 1} indicators to
        # report the positive and negative parts of the loss separately.
        is_positive: torch.Tensor = (polarity + 1) // 2
        is_negative: torch.Tensor = 1 - is_positive
        positive: torch.Tensor = (scores * is_positive).sum() / is_positive.sum()
        negative: torch.Tensor = (scores * is_negative).sum() / is_negative.sum()

        losses: Dict[str, torch.Tensor] = {
                "positive": positive,
                "negative": negative,
                "mtb": mtb_loss,
                "reconstruction": lm_loss}
        loss: torch.Tensor = mtb_loss + self.config.language_model_weight * lm_loss
        return loss, losses, {"similarity": similarity}