gbure

Graph-based approaches to unsupervised relation extraction evaluated as a few-shot problem
git clone https://esimon.eu/repos/gbure.git

linguistic_encoder.py (4018B)


from typing import Callable, Optional, Tuple

import torch
import transformers

import gbure.utils


class LinguisticEncoder(torch.nn.Module):
    """
    Transformer encoder from Soares et al.

    Corresponds to the left part of each subfigure of Figure 2 (Deep Transformer Encoder and the green layer above).
    We only implement the "entity markers, entity start" variant (which is the one with the best performance).

    Config:
        transformer_model: Which transformer to use (e.g. bert-large-uncased).
        post_transformer_layer: The transformation applied after the transformer (must be "linear", "layer_norm" or "none").
    """

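    # Sketch of the expected input format (the marker strings below are an assumption;
    # the actual special tokens are added to the tokenizer elsewhere in gbure):
    #     <e1> Marie Curie </e1> was born in <e2> Warsaw </e2> .
    # entity_positions then contains, for each sample, the indices of the two entity
    # start markers; their hidden states are concatenated into the 2 * hidden_size
    # relation representation (see forward below).
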
    def __init__(self, config: gbure.utils.dotdict, tokenizer: transformers.PreTrainedTokenizer, transformer: Optional[transformers.PreTrainedModel] = None) -> None:
        """
        Instantiate a Soares et al. encoder.

        Args:
            config: global config object
            tokenizer: tokenizer used to create the vocabulary
            transformer: the transformer to use instead of loading a pre-trained one
        """
        super().__init__()

        self.config: gbure.utils.dotdict = config
        self.tokenizer: transformers.PreTrainedTokenizer = tokenizer

        self.transformer: transformers.PreTrainedModel
        if transformer is not None:
            self.transformer = transformer
        elif self.config.get("load") or self.config.get("pretrained"):
            # The transformer weights are expected to be loaded from a checkpoint afterwards,
            # so only the architecture is built here, with the vocabulary already sized for
            # the marker-augmented tokenizer.
            # TODO introduce a config parameter to change the initialization of <tags> embeddings
            transformer_config = transformers.AutoConfig.from_pretrained(self.config.transformer_model)
            transformer_config.vocab_size = len(tokenizer)
            self.transformer = transformers.AutoModel.from_config(transformer_config)
        else:
            self.transformer = transformers.AutoModel.from_pretrained(self.config.transformer_model)
            self.transformer.resize_token_embeddings(len(tokenizer))

        self.post_transformer: Callable[[torch.Tensor], torch.Tensor]
        if self.config.post_transformer_layer == "linear":
            self.post_transformer_linear = torch.nn.Linear(in_features=self.output_size, out_features=self.output_size)
            self.post_transformer = lambda x: self.post_transformer_linear(x)
        elif self.config.post_transformer_layer == "layer_norm":
            # It is not clear whether a Linear should be added before the layer_norm, see Soares et al. section 3.3
            # Setting elementwise_affine to True (the default) makes little sense when computing similarity scores.
            self.post_transformer_linear = torch.nn.Linear(in_features=self.output_size, out_features=self.output_size)
            self.post_transformer_activation = torch.nn.LayerNorm(self.output_size, elementwise_affine=False)
            self.post_transformer = lambda x: self.post_transformer_activation(self.post_transformer_linear(x))
        elif self.config.post_transformer_layer == "none":
            self.post_transformer = lambda x: x
        else:
            raise RuntimeError(f"Unsupported config value for post_transformer_layer: {self.config.post_transformer_layer}")

    @property
    def output_size(self) -> int:
        """ Dimension of the representation returned by the model. """
        # Concatenation of the hidden states at the two entity start markers.
        return 2 * self.transformer.config.hidden_size

    def forward(self, text: torch.Tensor, mask: torch.Tensor, entity_positions: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """ Encode the given text into a fixed size representation. """
        batch_size: int = text.shape[0]
        batch_ids: torch.Tensor = torch.arange(batch_size, device=text.device, dtype=torch.int64).unsqueeze(1)

        # The first element of the tuple is the Batch×Sentence×Hidden output matrix.
        transformer_out: torch.Tensor = self.transformer(text, attention_mask=mask)[0]
        # Gather the hidden states at the two entity start markers and concatenate them.
        sentence: torch.Tensor = transformer_out[batch_ids, entity_positions].view(batch_size, self.output_size)
        return self.post_transformer(sentence), transformer_out
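
A minimal usage sketch of this encoder, assuming gbure.utils.dotdict can be built from a plain dict, an assumed module path for the import, and hypothetical entity marker strings (the real markers are defined elsewhere in gbure):

import torch
import transformers

import gbure.utils
from gbure.model.linguistic_encoder import LinguisticEncoder  # assumed module path

config = gbure.utils.dotdict({
    "transformer_model": "bert-base-cased",
    "post_transformer_layer": "layer_norm",
})

tokenizer = transformers.AutoTokenizer.from_pretrained(config.transformer_model)
tokenizer.add_tokens(["<e1>", "</e1>", "<e2>", "</e2>"])  # hypothetical marker tokens

encoder = LinguisticEncoder(config, tokenizer)

batch = tokenizer("<e1> Marie Curie </e1> was born in <e2> Warsaw </e2> .", return_tensors="pt")
text, mask = batch["input_ids"], batch["attention_mask"]

# Indices of the two entity start markers, shape (batch, 2).
e1_start = (text[0] == tokenizer.convert_tokens_to_ids("<e1>")).nonzero()[0]
e2_start = (text[0] == tokenizer.convert_tokens_to_ids("<e2>")).nonzero()[0]
entity_positions = torch.stack([e1_start, e2_start], dim=1)

relation_representation, token_states = encoder(text, mask, entity_positions)
print(relation_representation.shape)  # (1, 2 * hidden_size)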