gbure

Graph-based approaches to unsupervised relation extraction evaluated as a few-shot problem
git clone https://esimon.eu/repos/gbure.git

contrastive_alignment.py (4570B)


from typing import Any, Dict, List, Tuple

import torch
import transformers

from gbure.model.linguistic_encoder import LinguisticEncoder
from gbure.model.masked_lm import MaskedLM
from gbure.model.similarity import LinguisticSimilarity, TopologicalSimilarity
from gbure.model.topological_encoder import TopologicalEncoder
import gbure.utils


class Model(torch.nn.Module):
    """
    Unsupervised contrastive alignment pre-training model, building on the setup of Soares et al.

    Corresponds to the model explained in section 4.
    """

    def __init__(self, config: gbure.utils.dotdict, tokenizer: transformers.PreTrainedTokenizer, relation_dictionary: None) -> None:
        """
        Instantiate the contrastive alignment pre-training model.

        Args:
            config: global config object; must provide margin, the maximum enforced meta-distance between positive and negative distances
            tokenizer: tokenizer used to create the vocabulary
            relation_dictionary: dictionary of all relations (unused)
        """
        super().__init__()

        self.config: gbure.utils.dotdict = config
        self.tokenizer: transformers.PreTrainedTokenizer = tokenizer

        self.language_model: torch.nn.Module = MaskedLM(config, tokenizer)
        self.linguistic_encoder: torch.nn.Module = LinguisticEncoder(config, tokenizer, transformer=self.language_model.encoder)
        self.topological_encoder: torch.nn.Module = TopologicalEncoder(config, self.linguistic_encoder)
        self.linguistic_similarity: torch.nn.Module = LinguisticSimilarity(config)
        self.topological_similarity: torch.nn.Module = TopologicalSimilarity(config)

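    # Batch convention used by forward below: each training example is a triplet of
    # sentences carried under the "first_", "second_" and "third_" key prefixes; the
    # first and second sentences act as the positive pair and the third one as the
    # negative, as reflected in the similarity computations.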
    def forward(self, **batch) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], Dict[str, torch.Tensor]]:
        """ Compute the unsupervised contrastive alignment loss on the given sentence triplets. """
        linguistic_embeddings: List[torch.Tensor] = []
        topological_embeddings: List[Any] = []

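        # Encode each element of the triplet twice: a linguistic embedding built from
        # its tokens and entity positions, and a topological embedding built on top of
        # it from the remaining fields passed through **batch.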
        for order in ["first_", "second_", "third_"]:
            linguistic_embeddings.append(self.linguistic_encoder(batch[f"{order}text"], batch[f"{order}mask"], batch[f"{order}entity_positions"])[0])
            topological_embeddings.append(self.topological_encoder(order, linguistic_embeddings[-1], **batch))

        # d(a_1, a_2, [1, 0])
        positive_linguistic_similarity: torch.Tensor = self.linguistic_similarity(linguistic_embeddings[0], linguistic_embeddings[1])
        # d(a_1, a_3, [1, 0])
        negative_linguistic_similarity: torch.Tensor = self.linguistic_similarity(linguistic_embeddings[0], linguistic_embeddings[2])

        # d(a_1, a_2, [0, 1])
        positive_topological_similarity: torch.Tensor = self.topological_similarity(topological_embeddings[0], topological_embeddings[1])[0]
        # d(a_1, a_3, [0, 1])
        negative_topological_similarity: torch.Tensor = self.topological_similarity(topological_embeddings[0], topological_embeddings[2])[0]

        # 2 (d(a_1, a_2, [1, 0]) - d(a_1, a_2, [0, 1]))²
        positive: torch.Tensor = 2 * (positive_linguistic_similarity - positive_topological_similarity)**2
        # (d(a_1, a_2, [1, 0]) - d(a_1, a_3, [0, 1]))² + (d(a_1, a_3, [1, 0]) - d(a_1, a_2, [0, 1]))²
        negative: torch.Tensor = ((positive_linguistic_similarity - negative_topological_similarity)**2 + (negative_linguistic_similarity - positive_topological_similarity)**2)

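        # Margin-based hinge: the aligned (positive) squared gaps must stay at least
        # `margin` below the mismatched (negative) ones, otherwise the violation is
        # penalised linearly.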
        contrastive_loss: torch.Tensor = torch.nn.functional.relu(self.config.margin + positive - negative).mean()
        if self.config.get("language_model_weight", 0) > 0:
            lm_loss: torch.Tensor = self.language_model(batch["first_mlm_input"], batch["first_mlm_target"], batch["first_mask"])
        else:
            lm_loss = torch.zeros_like(contrastive_loss)
        loss: torch.Tensor = contrastive_loss + self.config.get("language_model_weight", 0) * lm_loss

        losses: Dict[str, torch.Tensor] = {
                "positive": positive.mean(),
                "negative": negative.mean(),
                "contrastive": contrastive_loss,
                "reconstruction": lm_loss}
        variables: Dict[str, torch.Tensor] = {
                "positive_linguistic_similarity": positive_linguistic_similarity,
                "negative_linguistic_similarity": negative_linguistic_similarity,
                "positive_topological_similarity": positive_topological_similarity,
                "negative_topological_similarity": negative_topological_similarity
            }

        return loss, losses, variables
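
The objective assembled in forward can be reproduced in isolation for reference. The sketch below is not part of the repository: it substitutes random similarity scores for the encoder outputs, and the batch size, margin and variable names are arbitrary stand-ins, only meant to show how the positive and negative terms enter the hinge.

import torch

batch_size = 8      # arbitrary illustration value
margin = 1.0        # stands in for config.margin

# Stand-ins for the four similarity tensors computed by the encoders.
positive_linguistic = torch.randn(batch_size)
negative_linguistic = torch.randn(batch_size)
positive_topological = torch.randn(batch_size)
negative_topological = torch.randn(batch_size)

# Aligned (positive) squared gaps between the linguistic and topological views.
positive = 2 * (positive_linguistic - positive_topological) ** 2
# Mismatched (negative) squared gaps.
negative = (positive_linguistic - negative_topological) ** 2 \
    + (negative_linguistic - positive_topological) ** 2

# Hinge: positive gaps should undercut negative gaps by at least the margin.
loss = torch.nn.functional.relu(margin + positive - negative).mean()
print(float(loss))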