gbure

Graph-based approaches to unsupervised relation extraction evaluated as a few-shot problem
git clone https://esimon.eu/repos/gbure.git

masked_lm.py (2472B)


from typing import Callable

import torch
import transformers

import gbure.utils


class MaskedLM(torch.nn.Module):
    """
    Masked language model to be used on top of a transformer.

    This class is only useful for unsupervised pre-training: Soares et al. keep the BERT loss alongside their "matching the blanks" loss.

    Config:
        transformer_model: Which transformer to use (e.g. bert-large-uncased).
    """

    def __init__(self, config: gbure.utils.dotdict, tokenizer: transformers.PreTrainedTokenizer) -> None:
        """
        Instantiate a masked language model.

        Args:
            config: global config object
            tokenizer: tokenizer used to create the vocabulary
        """
        super().__init__()

        self.config: gbure.utils.dotdict = config
        self.tokenizer: transformers.PreTrainedTokenizer = tokenizer

        self.transformer: transformers.PreTrainedModel
        if self.config.get("load") or self.config.get("pretrained"):
            # TODO introduce a config parameter to change the initialization of <tags> embeddings
            # Only the architecture is built here (random initialization); weights are expected to be loaded separately.
            transformer_config = transformers.AutoConfig.from_pretrained(self.config.transformer_model)
            transformer_config.vocab_size = len(tokenizer)
            self.transformer = transformers.AutoModelForMaskedLM.from_config(transformer_config)
        else:
            # Start from the pretrained weights and extend the embeddings to the enlarged vocabulary.
            self.transformer = transformers.AutoModelForMaskedLM.from_pretrained(self.config.transformer_model)
            self.transformer.resize_token_embeddings(len(tokenizer))

    @property
    def encoder(self) -> transformers.PreTrainedModel:
        """Return the underlying transformer encoder, without its masked language model head."""
        if isinstance(self.transformer, transformers.BertForMaskedLM):
            return self.transformer.bert
        elif isinstance(self.transformer, transformers.DistilBertForMaskedLM):
            return self.transformer.distilbert
        else:
            raise RuntimeError("Unknown transformer model, can't split masked language model off")

    def forward(self, input: torch.Tensor, target: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        """
        Compute the masked language model loss from the transformer's output.

        Sadly, huggingface does not provide a unified interface, so we need to do some copy-pasting.
        """
        # Positions outside the attention mask are set to -100, the label value ignored by the loss.
        masked_target: torch.Tensor = target * mask - 100 * (1 - mask.long())
        output = self.transformer(input_ids=input, attention_mask=mask, labels=masked_target, return_dict=True)
        return output.loss
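
For reference, a minimal usage sketch, not part of the repository. It assumes the file is importable as gbure.model.masked_lm (adjust to the actual path), that gbure.utils.dotdict can be built from a plain dict, and that bert-base-uncased weights are available; it masks one token and backpropagates the MLM loss.

import transformers

import gbure.utils
from gbure.model.masked_lm import MaskedLM  # import path assumed; adjust to where masked_lm.py lives

tokenizer = transformers.AutoTokenizer.from_pretrained("bert-base-uncased")
config = gbure.utils.dotdict({"transformer_model": "bert-base-uncased"})  # assumes dotdict wraps a plain dict
model = MaskedLM(config, tokenizer)

encoded = tokenizer("Paris is the capital of France.", return_tensors="pt")
target = encoded["input_ids"]           # original token ids serve as labels
mask = encoded["attention_mask"]        # also used to ignore padding in the loss
input_ids = target.clone()
input_ids[0, 5] = tokenizer.mask_token_id  # corrupt one arbitrary position with [MASK]

loss = model(input_ids, target, mask)
loss.backward()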