masked_lm.py (2472B)
import torch
import transformers

import gbure.utils


class MaskedLM(torch.nn.Module):
    """
    Masked language model to be used on top of a transformer.

    This class is only useful for unsupervised pre-training; Soares et al. keep the BERT loss alongside their "matching the blanks" loss.

    Config:
        transformer_model: Which transformer to use (e.g. bert-large-uncased).
    """

    def __init__(self, config: gbure.utils.dotdict, tokenizer: transformers.PreTrainedTokenizer) -> None:
        """
        Instantiate a masked language model.

        Args:
            config: global config object
            tokenizer: tokenizer used to create the vocabulary
        """
        super().__init__()

        self.config: gbure.utils.dotdict = config
        self.tokenizer: transformers.PreTrainedTokenizer = tokenizer

        self.transformer: transformers.PreTrainedModel
        if self.config.get("load") or self.config.get("pretrained"):
            # Only the architecture is needed here; the weights are expected
            # to be filled in from a saved checkpoint afterwards.
            # TODO introduce a config parameter to change the initialization of <tags> embeddings
            transformer_config = transformers.AutoConfig.from_pretrained(self.config.transformer_model)
            transformer_config.vocab_size = len(tokenizer)
            self.transformer = transformers.AutoModelForMaskedLM.from_config(transformer_config)
        else:
            self.transformer = transformers.AutoModelForMaskedLM.from_pretrained(self.config.transformer_model)
            self.transformer.resize_token_embeddings(len(tokenizer))

    @property
    def encoder(self) -> transformers.PreTrainedModel:
        """ Return the underlying transformer without its language model head. """
        if isinstance(self.transformer, transformers.BertForMaskedLM):
            return self.transformer.bert
        elif isinstance(self.transformer, transformers.DistilBertForMaskedLM):
            return self.transformer.distilbert
        else:
            raise RuntimeError("Unknown transformer model, can't split masked language model off")

    def forward(self, input: torch.Tensor, target: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        """
        Compute the masked language model loss from the transformer's output.

        Sadly, huggingface does not provide a unified interface, so we need to do some copy-pasting.

        Args:
            input: batch of token ids, shape (batch size, sequence length)
            target: token ids to reconstruct, shape (batch size, sequence length)
            mask: attention mask, 1 for real tokens and 0 for padding
        """
        # Padded positions get the label -100, the default ignore_index of the
        # cross-entropy loss used by huggingface models.
        masked_target: torch.Tensor = target * mask.long() - 100 * (1 - mask.long())
        output = self.transformer(input_ids=input, attention_mask=mask, labels=masked_target, return_dict=True)
        return output.loss
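
A minimal usage sketch follows. It is not part of masked_lm.py: the import path of MaskedLM, the bert-base-uncased checkpoint, and the assumption that gbure.utils.dotdict can wrap a plain dict are all illustrative.

# Usage sketch -- hypothetical, not part of the original file.
# Assumptions: gbure.utils.dotdict wraps a plain dict, and MaskedLM is
# importable from this module's path in the repository.
import transformers

import gbure.utils
from masked_lm import MaskedLM  # adjust to the actual package path

config = gbure.utils.dotdict({"transformer_model": "bert-base-uncased"})
tokenizer = transformers.AutoTokenizer.from_pretrained(config.transformer_model)
model = MaskedLM(config, tokenizer)

# For illustration, the inputs serve as their own reconstruction target; a
# real pre-training loop would corrupt `input` with [MASK] tokens while
# keeping `target` intact.
batch = tokenizer(["Paris is the capital of [MASK]."], return_tensors="pt", padding=True)
loss = model(batch["input_ids"], batch["input_ids"], batch["attention_mask"])
loss.backward()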