fewshot.py (6423B)
from typing import Dict, Optional, Tuple

import torch
import transformers

import gbure.data.dictionary
import gbure.model.linguistic_encoder
import gbure.model.similarity
import gbure.model.topological_encoder
import gbure.utils


class Model(torch.nn.Module):
    """
    Few shot model from Soares et al.

    Corresponds to the right subfigure of Figure 2.

    A query is scored against way x shot candidates with a linguistic
    similarity and, when the config enables neighborhoods
    (``neighborhood_size > 0``), an additional topological similarity.
    """

    def __init__(self, config: gbure.utils.dotdict, tokenizer: transformers.PreTrainedTokenizer, relation_dictionary: gbure.data.dictionary.RelationDictionary, train_model: Optional[torch.nn.Module] = None) -> None:
        """
        Instantiate a Soares et al. few shot model.

        Args:
            config: global config object
            tokenizer: tokenizer used to create the vocabulary
            relation_dictionary: dictionary of all relations
            train_model: unsupervised model used to initialize the few shot model.
                When given, its encoder and similarity submodules are reused
                (shared, not copied) instead of creating new ones.
        """
        super().__init__()

        self.config: gbure.utils.dotdict = config
        self.tokenizer: transformers.PreTrainedTokenizer = tokenizer
        self.relation_dictionary: gbure.data.dictionary.RelationDictionary = relation_dictionary

        self.linguistic_encoder: torch.nn.Module
        self.linguistic_similarity: torch.nn.Module
        if train_model is None:
            self.linguistic_encoder = gbure.model.linguistic_encoder.LinguisticEncoder(config, tokenizer)
            self.linguistic_similarity = gbure.model.similarity.LinguisticSimilarity(config)
        else:
            self.linguistic_encoder = train_model.linguistic_encoder
            self.linguistic_similarity = train_model.linguistic_similarity
        self.loss_module = torch.nn.NLLLoss(reduction="mean")

        if self.config.get("neighborhood_size", 0) > 0:
            self.topological_encoder: torch.nn.Module
            self.topological_similarity: torch.nn.Module
            if train_model is None:
                self.topological_encoder = gbure.model.topological_encoder.TopologicalEncoder(config, self.linguistic_encoder)
                self.topological_similarity = gbure.model.similarity.TopologicalSimilarity(config)
            else:
                self.topological_encoder = train_model.topological_encoder
                self.topological_similarity = train_model.topological_similarity

            # The neutral similarity is only needed when an undefined
            # neighborhood does not poison the whole meta-sample.
            if not self.config.get("undefined_poison_whole_meta"):
                if train_model is not None:
                    # Reuse whatever the training model holds.
                    self.neutral_topological_similarity = train_model.neutral_topological_similarity
                elif self.config.get("neutral_topological_similarity") is not None:
                    # Fixed value taken from the config.
                    self.neutral_topological_similarity = self.config.neutral_topological_similarity
                else:
                    # No value configured: learn it, starting from zero.
                    self.neutral_topological_similarity = torch.nn.Parameter(torch.zeros(1, dtype=torch.float32))

    def combine_similarities(self, linguistic: torch.Tensor, topological: torch.Tensor, topological_mask: Optional[torch.Tensor]) -> torch.Tensor:
        """
        Combine linguistic and topological similarities into a single value.

        Args:
            linguistic: linguistic similarity scores
            topological: topological similarity scores; this tensor is not modified
            topological_mask: optional mask over `topological`; it is inverted
                with `~`, so it is presumably boolean with True marking
                defined entries (TODO confirm against TopologicalSimilarity)

        Returns:
            The weighted sum of `linguistic` and of `topological` averaged over
            its last (fourth) dimension, using the config's `linguistic_weight`
            and `topological_weight` (both default to 1).
        """
        # topological is of dimension (batch, way, shot, slot)
        if topological_mask is not None:
            if self.config.get("undefined_poison_whole_meta"):
                # The product over way and shot is zero as soon as one entry
                # of the mask is zero, which cancels every candidate.
                poison = topological_mask.prod(1, keepdim=True).prod(2, keepdim=True)
                # Out-of-place so that the caller's tensor (and anything
                # autograd saved from it) is left untouched.
                topological = topological * poison
            else:
                # Masked-out entries receive the neutral similarity as an offset.
                offset = (~topological_mask) * self.neutral_topological_similarity
                # Cast so the sum keeps the dtype of `topological`.
                topological = topological + offset.to(topological.dtype)
        topological = topological.mean(3)
        return self.config.get("linguistic_weight", 1) * linguistic + self.config.get("topological_weight", 1) * topological

    def forward(self,
                query_text: torch.Tensor,
                query_mask: torch.Tensor,
                query_entity_positions: torch.Tensor,
                candidates_text: torch.Tensor,
                candidates_mask: torch.Tensor,
                candidates_entity_positions: torch.Tensor,
                answer: torch.Tensor,
                query_relation: Optional[torch.Tensor] = None,
                candidates_relation: Optional[torch.Tensor] = None,
                **batch) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], Dict[str, torch.Tensor]]:
        """
        Compute the fewshot loss on the given query and candidates.

        Args:
            query_text: query input, first dimension is the batch
            query_mask: mask passed to the linguistic encoder with `query_text`
            query_entity_positions: entity positions of the queries
            candidates_text: candidate input, flattened to
                (batch * way * shot, -1) before encoding
            candidates_mask: mask for `candidates_text`, flattened the same way
            candidates_entity_positions: entity positions of the candidates,
                flattened to (batch * way * shot, 2)
            answer: target index passed to the NLL loss
            query_relation: unused by this method
            candidates_relation: optional relation ids of the candidates,
                indexed as [batch, way, shot]
            **batch: remaining batch fields, forwarded to the topological encoder

        Returns:
            A tuple `(loss, {}, variables)` where `variables` holds
            "prediction_logits", "prediction_relative" and, when
            `candidates_relation` is given, "predicted_relation".
        """
        batch_size: int = query_text.shape[0]
        way: int = self.config.way
        shot: int = self.config.shot
        # All candidates of all samples are encoded as one large batch.
        fake_batch_size: int = batch_size * shot * way

        query: torch.Tensor = self.linguistic_encoder(query_text, query_mask, query_entity_positions)[0]
        # reshape (rather than view) also accepts non-contiguous inputs.
        candidates: torch.Tensor = self.linguistic_encoder(
            candidates_text.reshape(fake_batch_size, -1),
            candidates_mask.reshape(fake_batch_size, -1),
            candidates_entity_positions.reshape(fake_batch_size, 2)
        )[0].reshape(batch_size, way, shot, -1)
        # The query gets singleton way and shot dimensions.
        logits: torch.Tensor = self.linguistic_similarity(candidates, query.unsqueeze(1).unsqueeze(2))

        if self.config.get("neighborhood_size", 0) > 0:
            topological_query = self.topological_encoder("query_", query, degree_delta=1, **batch)
            topological_candidates = self.topological_encoder("candidates_", candidates, degree_delta=1, **batch)
            # The encoder may return a single tensor or a collection of tensors;
            # either way, insert singleton way and shot dimensions.
            if isinstance(topological_query, torch.Tensor):
                topological_query = topological_query.view(topological_query.shape[0], 1, 1, -1)
            else:
                topological_query = tuple(x.view(x.shape[0], 1, 1, *x.shape[1:]) for x in topological_query)

            topological_similarity, topological_mask = self.topological_similarity(topological_query, topological_candidates)
            logits = self.combine_similarities(logits, topological_similarity, topological_mask)

        # Softmax over all way * shot candidates, then sum the probabilities
        # of the shots of each way (in log space).
        log_probabilities = torch.nn.functional.log_softmax(
            logits.reshape(batch_size, way * shot),
            dim=1).view(batch_size, way, shot)
        log_probabilities = log_probabilities.logsumexp(2)

        loss: torch.Tensor = self.loss_module(log_probabilities, answer)
        prediction: torch.Tensor = log_probabilities.argmax(1)

        variables = {"prediction_logits": logits, "prediction_relative": prediction}
        if candidates_relation is not None:
            batch_ids: torch.Tensor = torch.arange(batch_size, device=loss.device)
            # Relation of the first shot of the predicted way.
            variables["predicted_relation"] = candidates_relation[batch_ids, prediction, 0]

        return loss, {}, variables