matching_the_blanks.py (5796B)
from typing import Dict, Optional, Tuple, Type

import torch
import transformers

import gbure.model.linguistic_encoder
import gbure.model.topological_encoder
import gbure.model.masked_lm
import gbure.model.similarity
import gbure.utils


class Model(torch.nn.Module):
    """
    Unsupervised pre-training model from Soares et al.

    Corresponds to the model explained in Section 4.

    Config:
        linguistic_weight: factor for the linguistic similarity
        topological_weight: factor for the topological similarity
    """

    def __init__(self, config: gbure.utils.dotdict, tokenizer: transformers.PreTrainedTokenizer, relation_dictionary: None) -> None:
        """
        Instantiate a Soares et al. matching-the-blanks model.

        Args:
            config: global config object
            tokenizer: tokenizer used to create the vocabulary
            relation_dictionary: dictionary of all relations (unused)
        """
        super().__init__()

        self.config: gbure.utils.dotdict = config
        self.tokenizer: transformers.PreTrainedTokenizer = tokenizer

        # Auto-class (not an instance); the actual encoder is built by the
        # masked language model below and shared with the linguistic encoder.
        self.transformer: Type[transformers.PreTrainedModel] = transformers.AutoModelForMaskedLM
        self.language_model: torch.nn.Module = gbure.model.masked_lm.MaskedLM(config, tokenizer)
        self.linguistic_encoder: torch.nn.Module = gbure.model.linguistic_encoder.LinguisticEncoder(config, tokenizer, transformer=self.language_model.encoder)
        self.linguistic_similarity: torch.nn.Module = gbure.model.similarity.LinguisticSimilarity(config)

        if self.config.get("neighborhood_size", 0) > 0:
            self.topological_encoder: torch.nn.Module = gbure.model.topological_encoder.TopologicalEncoder(config, self.linguistic_encoder)
            self.topological_similarity: torch.nn.Module = gbure.model.similarity.TopologicalSimilarity(config)
            if not self.config.get("undefined_poison_whole_meta"):
                # Similarity value used for undefined (masked) neighborhood slots:
                # either a fixed value from the config or a learned parameter.
                if self.config.get("neutral_topological_similarity") is not None:
                    self.neutral_topological_similarity: float = self.config.neutral_topological_similarity
                else:
                    self.neutral_topological_similarity: torch.nn.Parameter = torch.nn.Parameter(torch.zeros(1, dtype=torch.float32))

    def combine_similarities(self, linguistic: torch.Tensor, topological: torch.Tensor, topological_mask: Optional[torch.Tensor]) -> torch.Tensor:
        """ Combine linguistic and topological similarities into a single value. """
        # topological is of dimension (batch, slot)
        if topological_mask is not None:
            if not self.config.get("undefined_poison_whole_meta"):
                # Replace undefined slots with the neutral similarity before averaging.
                topological = topological + (~topological_mask) * self.neutral_topological_similarity
            topological = topological.mean(1)
        return self.config.get("linguistic_weight", 1) * linguistic + self.config.get("topological_weight", 1) * topological

    def forward(self,
                first_text: torch.Tensor,
                first_mask: torch.Tensor,
                first_entity_positions: torch.Tensor,
                second_text: torch.Tensor,
                second_mask: torch.Tensor,
                second_entity_positions: torch.Tensor,
                polarity: torch.Tensor,
                first_mlm_input: torch.Tensor,
                first_mlm_target: torch.Tensor,
                **batch) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], Dict[str, torch.Tensor]]:
        """ Compute the unsupervised matching-the-blanks loss between the given pairs. """
        # Encode both sentences; the masked language model loss is computed on
        # the first sentence only.
        first: torch.Tensor
        first_transformer_out: torch.Tensor
        first, first_transformer_out = self.linguistic_encoder(first_text, first_mask, first_entity_positions)
        second: torch.Tensor = self.linguistic_encoder(second_text, second_mask, second_entity_positions)[0]
        linguistic_similarity: torch.Tensor = self.linguistic_similarity(first, second)
        lm_loss: torch.Tensor = self.language_model(first_mlm_input, first_mlm_target, first_mask)

        similarity: torch.Tensor = linguistic_similarity
        if self.config.get("neighborhood_size", 0) > 0:
            topological_first = self.topological_encoder("first_", first, **batch)
            topological_second = self.topological_encoder("second_", second, **batch)

            topological_similarity, topological_mask = self.topological_similarity(topological_first, topological_second)
            similarity = self.combine_similarities(linguistic_similarity, topological_similarity, topological_mask)

        # There seems to be a mistake in Soares et al. §4.1 in the equation of p(l=1|r,r'):
        # it uses 1/(1+exp(x)), which seems counter-intuitive since the case where r=r'
        # would lead to a low probability. By default 1/(1+exp(-x)) is used, but the
        # equation given in the paper can be selected with --reverse_sigmoid.
        if self.config.get("reverse_sigmoid"):
            scores: torch.Tensor = - torch.nn.functional.logsigmoid(- polarity * similarity)
        else:
            scores: torch.Tensor = - torch.nn.functional.logsigmoid(polarity * similarity)
        mtb_loss: torch.Tensor = scores.mean()

        # Polarity is +1 for positive pairs (same entity pair) and -1 for
        # negative ones; also report the mean loss of each group separately.
        is_positive: torch.Tensor = (polarity + 1) // 2
        is_negative: torch.Tensor = 1 - is_positive
        positive: torch.Tensor = (scores * is_positive).sum() / is_positive.sum()
        negative: torch.Tensor = (scores * is_negative).sum() / is_negative.sum()

        losses: Dict[str, torch.Tensor] = {
            "positive": positive,
            "negative": negative,
            "mtb": mtb_loss,
            "reconstruction": lm_loss}
        loss: torch.Tensor = mtb_loss + self.config.language_model_weight * lm_loss
        return loss, losses, {"similarity": similarity}