gbure

Graph-based approaches to unsupervised relation extraction evaluated as a few-shot problem
git clone https://esimon.eu/repos/gbure.git
Log | Files | Refs | README | LICENSE

fewshot.py (6423B)


      1 from typing import Dict, Optional, Tuple
      2 
      3 import torch
      4 import transformers
      5 
      6 import gbure.data.dictionary
      7 import gbure.model.linguistic_encoder
      8 import gbure.model.similarity
      9 import gbure.model.topological_encoder
     10 import gbure.utils
     11 
     12 
     13 class Model(torch.nn.Module):
     14     """
     15     Few shot model from Soares et al.
     16 
     17     Correspond to the right subfigure of Figure 2.
     18     """
     19 
     20     def __init__(self, config: gbure.utils.dotdict, tokenizer: transformers.PreTrainedTokenizer, relation_dictionary: gbure.data.dictionary.RelationDictionary, train_model: Optional[torch.nn.Module] = None) -> None:
     21         """
     22         Instantiate a Soares et al. few shot model.
     23 
     24         Args:
     25             config: global config object
     26             tokenizer: tokenizer used to create the vocabulary
     27             relation_dictionary: dictionary of all relations
     28             train_model: unsupervised model used to initialize the few shot model.
     29         """
     30         super().__init__()
     31 
     32         self.config: gbure.utils.dotdict = config
     33         self.tokenizer: transformers.PreTrainedTokenizer = tokenizer
     34         self.relation_dictionary: gbure.data.dictionary.RelationDictionary = relation_dictionary
     35 
     36         if train_model is None:
     37             self.linguistic_encoder: torch.nn.Module = gbure.model.linguistic_encoder.LinguisticEncoder(config, tokenizer)
     38             self.linguistic_similarity: torch.nn.Module = gbure.model.similarity.LinguisticSimilarity(config)
     39         else:
     40             self.linguistic_encoder: torch.nn.Module = train_model.linguistic_encoder
     41             self.linguistic_similarity: torch.nn.Module = train_model.linguistic_similarity
     42         self.loss_module = torch.nn.NLLLoss(reduction="mean")
     43 
     44         if self.config.get("neighborhood_size", 0) > 0:
     45             if train_model is None:
     46                 self.topological_encoder: torch.nn.Module = gbure.model.topological_encoder.TopologicalEncoder(config, self.linguistic_encoder)
     47                 self.topological_similarity: torch.nn.Module = gbure.model.similarity.TopologicalSimilarity(config)
     48             else:
     49                 self.topological_encoder: torch.nn.Module = train_model.topological_encoder
     50                 self.topological_similarity: torch.nn.Module = train_model.topological_similarity
     51             if not self.config.get("undefined_poison_whole_meta"):
     52                 if train_model is not None:
     53                     self.neutral_topological_similarity = train_model.neutral_topological_similarity
     54                 elif self.config.get("neutral_topological_similarity") is not None:
     55                     self.neutral_topological_similarity: float = self.config.neutral_topological_similarity
     56                 else:
     57                     self.neutral_topological_similarity: torch.nn.Parameter = torch.nn.Parameter(torch.zeros(1, dtype=torch.float32))
     58 
     59     def combine_similarities(self, linguistic: torch.Tensor, topological: torch.Tensor, topological_mask: Optional[torch.Tensor]) -> torch.Tensor:
     60         """ Combine linguistic and topological similarities into a single value. """
     61         # topological is of dimension (batch, way, shot, slot)
     62         if topological_mask is not None:
     63             if self.config.get("undefined_poison_whole_meta"):
     64                 topological *= topological_mask.prod(1, keepdim=True).prod(2, keepdim=True)
     65             else:
     66                 topological += (~topological_mask) * self.neutral_topological_similarity
     67             topological = topological.mean(3)
     68         return self.config.get("linguistic_weight", 1) * linguistic + self.config.get("topological_weight", 1) * topological
     69 
     70     def forward(self,
     71                 query_text: torch.Tensor,
     72                 query_mask: torch.Tensor,
     73                 query_entity_positions: torch.Tensor,
     74                 candidates_text: torch.Tensor,
     75                 candidates_mask: torch.Tensor,
     76                 candidates_entity_positions: torch.Tensor,
     77                 answer: torch.Tensor,
     78                 query_relation: Optional[torch.Tensor] = None,
     79                 candidates_relation: Optional[torch.Tensor] = None,
     80                 **batch) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], Dict[str, torch.Tensor]]:
     81         """ Compute the fewshot loss on the given query and candidates. """
     82         batch_size: int = query_text.shape[0]
     83         fake_batch_size: int = batch_size * self.config.shot * self.config.way
     84 
     85         query: torch.Tensor = self.linguistic_encoder(query_text, query_mask, query_entity_positions)[0]
     86         candidates: torch.Tensor = self.linguistic_encoder(
     87                 candidates_text.view(fake_batch_size, -1),
     88                 candidates_mask.view(fake_batch_size, -1),
     89                 candidates_entity_positions.view(fake_batch_size, 2)
     90                 )[0].view(batch_size, self.config.way, self.config.shot, -1)
     91         logits: torch.Tensor = self.linguistic_similarity(candidates, query.unsqueeze(1).unsqueeze(2))
     92 
     93         if self.config.get("neighborhood_size", 0) > 0:
     94             topological_query = self.topological_encoder("query_", query, degree_delta=1, **batch)
     95             topological_candidates = self.topological_encoder("candidates_", candidates, degree_delta=1, **batch)
     96             if isinstance(topological_query, torch.Tensor):
     97                 topological_query = topological_query.view(topological_query.shape[0], 1, 1, -1)
     98             else:
     99                 topological_query = tuple(x.view(x.shape[0], 1, 1, *x.shape[1:]) for x in topological_query)
    100 
    101             topological_similarity, topological_mask = self.topological_similarity(topological_query, topological_candidates)
    102             logits = self.combine_similarities(logits, topological_similarity, topological_mask)
    103 
    104         log_probabilities = torch.nn.functional.log_softmax(
    105                 logits.view(batch_size, self.config.way*self.config.shot),
    106                 dim=1).view(batch_size, self.config.way, self.config.shot)
    107         log_probabilities = log_probabilities.logsumexp(2)
    108 
    109         loss: torch.Tensor = self.loss_module(log_probabilities, answer)
    110         prediction: torch.Tensor = log_probabilities.argmax(1)
    111 
    112         variables = {"prediction_logits": logits, "prediction_relative": prediction}
    113         if candidates_relation is not None:
    114             batch_ids: torch.Tensor = torch.arange(batch_size, device=loss.device)
    115             variables["predicted_relation"] = candidates_relation[batch_ids, prediction, 0]
    116 
    117         return loss, {}, variables