contrastive_alignment.py (4570B)
from typing import Any, Dict, List, Tuple, Union

import torch
import transformers

from gbure.model.linguistic_encoder import LinguisticEncoder
from gbure.model.masked_lm import MaskedLM
from gbure.model.similarity import LinguisticSimilarity, TopologicalSimilarity
from gbure.model.topological_encoder import TopologicalEncoder
import gbure.utils


class Model(torch.nn.Module):
    """
    Unsupervised pre-training model aligning linguistic and topological similarities.

    Each sample is a triplet of sentences identified by the prefixes ``first_``,
    ``second_`` and ``third_``. The (first, second) pair is treated as the positive
    pair and the (first, third) pair as the negative pair. For each pair, a linguistic
    similarity and a topological similarity are computed. A margin-based contrastive
    loss then requires the positive pair's two similarities to disagree less, by at
    least ``config.margin``, than the similarities obtained by crossing the positive
    and negative pairs. An optional masked language modeling loss can be added.

    NOTE(review): the previous docstring described this class as the Soares et al.
    "matching the blanks" model (section 4). That text looks copied from another model
    file; confirm the intended paper reference.
    """

    def __init__(self, config: gbure.utils.dotdict, tokenizer: transformers.PreTrainedTokenizer, relation_dictionary: Any = None) -> None:
        """
        Instantiate the contrastive alignment model.

        Args:
            config: global config object. ``forward`` reads ``config.margin`` (the
                margin of the contrastive loss) and, optionally,
                ``config.language_model_weight`` (weight of the masked LM loss,
                defaulting to 0, i.e. disabled).
            tokenizer: tokenizer used to create the vocabulary
            relation_dictionary: dictionary of all relations (unused)
        """
        super().__init__()

        self.config: gbure.utils.dotdict = config
        self.tokenizer: transformers.PreTrainedTokenizer = tokenizer

        # NOTE(review): this stores the AutoModelForMaskedLM *class*, not a model
        # instance, so it is not registered as a submodule. It is not used in this
        # file; it is kept only in case external code reads ``model.transformer``.
        self.transformer: type = transformers.AutoModelForMaskedLM
        self.language_model: torch.nn.Module = MaskedLM(config, tokenizer)
        # The linguistic encoder is built on the masked LM's encoder module, so the
        # contrastive and language modeling losses act on the same transformer.
        self.linguistic_encoder: torch.nn.Module = LinguisticEncoder(config, tokenizer, transformer=self.language_model.encoder)
        # The topological encoder wraps the linguistic encoder.
        self.topological_encoder: torch.nn.Module = TopologicalEncoder(config, self.linguistic_encoder)
        self.linguistic_similarity: torch.nn.Module = LinguisticSimilarity(config)
        self.topological_similarity: torch.nn.Module = TopologicalSimilarity(config)

    def forward(self, **batch) -> Tuple[torch.Tensor, Dict[str, Union[torch.Tensor, int]], Dict[str, torch.Tensor]]:
        """
        Compute the contrastive alignment loss over a batch of sentence triplets.

        Args:
            batch: for each prefix ``order`` in ``first_``, ``second_`` and ``third_``,
                the keys ``{order}text``, ``{order}mask`` and ``{order}entity_positions``
                consumed by the linguistic encoder. The whole batch is also forwarded as
                keyword arguments to the topological encoder. When
                ``config.language_model_weight > 0``, the keys ``first_mlm_input`` and
                ``first_mlm_target`` are required as well.

        Returns:
            A tuple ``(loss, losses, variables)``:
                loss: contrastive loss plus the weighted masked LM loss.
                losses: ``positive`` and ``negative`` (means of the two loss terms),
                    ``contrastive`` and ``reconstruction``. ``reconstruction`` is the
                    integer ``0`` when the language model loss is disabled.
                variables: the four raw similarity tensors.
        """
        linguistic_embeddings: List[torch.Tensor] = []
        topological_embeddings: List[Any] = []

        for order in ("first_", "second_", "third_"):
            # Only the first output of the linguistic encoder is used as the embedding.
            linguistic_embedding: torch.Tensor = self.linguistic_encoder(batch[f"{order}text"], batch[f"{order}mask"], batch[f"{order}entity_positions"])[0]
            linguistic_embeddings.append(linguistic_embedding)
            topological_embeddings.append(self.topological_encoder(order, linguistic_embedding, **batch))

        # d(a_1, a_2, [1, 0])
        positive_linguistic_similarity: torch.Tensor = self.linguistic_similarity(linguistic_embeddings[0], linguistic_embeddings[1])
        # d(a_1, a_3, [1, 0])
        negative_linguistic_similarity: torch.Tensor = self.linguistic_similarity(linguistic_embeddings[0], linguistic_embeddings[2])

        # d(a_1, a_2, [0, 1]); only the first output of the similarity module is used.
        positive_topological_similarity: torch.Tensor = self.topological_similarity(topological_embeddings[0], topological_embeddings[1])[0]
        # d(a_1, a_3, [0, 1])
        negative_topological_similarity: torch.Tensor = self.topological_similarity(topological_embeddings[0], topological_embeddings[2])[0]

        # 2 * (d(a_1, a_2, [1, 0]) - d(a_1, a_2, [0, 1]))²
        # The factor 2 presumably puts this single squared difference on the same
        # scale as the two-term negative below.
        positive: torch.Tensor = 2 * (positive_linguistic_similarity - positive_topological_similarity)**2
        # (d(a_1, a_2, [1, 0]) - d(a_1, a_3, [0, 1]))² + (d(a_1, a_3, [1, 0]) - d(a_1, a_2, [0, 1]))²
        negative: torch.Tensor = ((positive_linguistic_similarity - negative_topological_similarity)**2 + (negative_linguistic_similarity - positive_topological_similarity)**2)

        # Hinge: zero once the negative discrepancy exceeds the positive one by the margin.
        contrastive_loss: torch.Tensor = torch.nn.functional.relu(self.config.margin + positive - negative).mean()

        # Read the weight once so the gate and the weighting cannot disagree.
        language_model_weight = self.config.get("language_model_weight", 0)
        lm_loss: Union[torch.Tensor, int] = 0
        if language_model_weight > 0:
            lm_loss = self.language_model(batch["first_mlm_input"], batch["first_mlm_target"], batch["first_mask"])
        loss: torch.Tensor = contrastive_loss + language_model_weight * lm_loss

        losses: Dict[str, Union[torch.Tensor, int]] = {
            "positive": positive.mean(),
            "negative": negative.mean(),
            "contrastive": contrastive_loss,
            "reconstruction": lm_loss}
        variables: Dict[str, torch.Tensor] = {
            "positive_linguistic_similarity": positive_linguistic_similarity,
            "negative_linguistic_similarity": negative_linguistic_similarity,
            "positive_topological_similarity": positive_topological_similarity,
            "negative_topological_similarity": negative_topological_similarity
        }

        return loss, losses, variables