topological_encoder.py (3592B)
from typing import List, Optional, Tuple, Union
import functools
import math
import operator

import torch

import gbure.utils


class TopologicalEncoder(torch.nn.Module):
    """
    Encoder for neighborhoods.

    Config:
        gcn_aggregator: aggregator used to pool the representations of several neighbors into a fixed-size one.
    """

    def __init__(self, config: gbure.utils.dotdict, linguistic_encoder: torch.nn.Module) -> None:
        """
        Instantiate a topological encoder.

        Args:
            config: global config object
            linguistic_encoder: the model used to get a fixed-size representation of text
        """
        super().__init__()
        self.config: gbure.utils.dotdict = config
        self.linguistic_encoder: torch.nn.Module = linguistic_encoder

        if self.config.get("gcn_aggregator", "") in ["mean", "chebyshev"]:
            self.gcn_layer: torch.nn.Module = torch.nn.Linear(in_features=self.linguistic_encoder.output_size, out_features=self.linguistic_encoder.output_size)

    def forward(self, prefix: str, loop: torch.Tensor, degree_delta: int = 0, **batch) -> Union[Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor]:
        """
        Encode the neighborhood of the given prefix.

        When a gcn_aggregator is defined, this results in a fixed-size representation; otherwise it returns clouds of points to be compared using optimal transport.
        The degree_delta parameter changes the degrees used to compute the various GCN weightings. This can be useful when the sample comes from outside the graph, in which case 1 should be added to the degrees.
        """
        linguistic_embeddings: List[torch.Tensor] = []
        masks: List[torch.Tensor] = []
        for slot in [1, 2]:
            # Flatten the batch and neighborhood dimensions so the linguistic encoder sees a plain batch of sentences.
            fake_batch_size: int = functools.reduce(operator.mul, batch[f"{prefix}e{slot}_neighborhood_text"].shape[:-1])
            # The first mask token indicates whether the neighbor slot actually contains a sentence; it zeroes out padding neighbors.
            mask: torch.Tensor = batch[f"{prefix}e{slot}_neighborhood_mask"].view(fake_batch_size, -1)[:, 0].unsqueeze(1)
            linguistic_embeddings.append((self.linguistic_encoder(
                batch[f"{prefix}e{slot}_neighborhood_text"].view(fake_batch_size, -1),
                batch[f"{prefix}e{slot}_neighborhood_mask"].view(fake_batch_size, -1),
                batch[f"{prefix}e{slot}_neighborhood_entity_positions"].view(fake_batch_size, 2)
            )[0] * mask).view(*batch[f"{prefix}e{slot}_neighborhood_text"].shape[:-1], self.linguistic_encoder.output_size))
            masks.append(mask.view(*batch[f"{prefix}e{slot}_neighborhood_text"].shape[:-1]))

        if self.config.get("gcn_aggregator", "") == "mean":
            # Average the self-loop representation with the neighbors of both entities.
            head: torch.Tensor = linguistic_embeddings[0].sum(-2)
            tail: torch.Tensor = linguistic_embeddings[1].sum(-2)
            neighborhood_size: torch.Tensor = sum(mask.sum(-1, keepdim=True) for mask in masks)
            return self.gcn_layer((loop + head + tail) / (neighborhood_size + 1))
        elif self.config.get("gcn_aggregator", "") == "chebyshev":
            # Weight each representation by the inverse square root of the entities' degrees, as in a Chebyshev GCN layer.
            pre_embedding: torch.Tensor = loop / torch.sqrt(2 * (batch[f"{prefix}entity_degrees"].sum(-1, keepdim=True) - 1 + degree_delta))
            for slot in [1, 2]:
                weights = 1 / torch.sqrt(batch[f"{prefix}e{slot}_neighborhood_entity_degrees"].sum(-1, keepdim=True) - 1 + degree_delta)
                pre_embedding += torch.sum(weights * linguistic_embeddings[slot-1], dim=-2)
            return self.gcn_layer(pre_embedding)
        else:
            # No aggregator: return the raw clouds of neighbor embeddings and their masks for optimal-transport comparison.
            return (linguistic_embeddings[0], linguistic_embeddings[1], masks[0], masks[1])
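
# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original module; names and shapes are
# illustrative assumptions). It fakes the linguistic encoder and the batch
# keys read by forward() to show the expected tensor layout; the real gbure
# pipeline builds these from its data loaders and a gbure.utils.dotdict
# config, for which a plain dict stands in here.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    class DummyLinguisticEncoder(torch.nn.Module):
        """Stand-in encoder returning one random fixed-size vector per sentence."""
        output_size: int = 8

        def forward(self, text: torch.Tensor, mask: torch.Tensor, entity_positions: torch.Tensor) -> Tuple[torch.Tensor, None]:
            return torch.randn(text.shape[0], self.output_size), None

    batch_size, num_neighbors, seq_len = 2, 3, 5
    encoder = TopologicalEncoder({"gcn_aggregator": "mean"}, DummyLinguisticEncoder())

    batch = {}
    for slot in [1, 2]:
        # One neighborhood of num_neighbors sentences per entity slot.
        batch[f"e{slot}_neighborhood_text"] = torch.ones(batch_size, num_neighbors, seq_len, dtype=torch.long)
        batch[f"e{slot}_neighborhood_mask"] = torch.ones(batch_size, num_neighbors, seq_len, dtype=torch.long)
        batch[f"e{slot}_neighborhood_entity_positions"] = torch.zeros(batch_size, num_neighbors, 2, dtype=torch.long)

    # Self-loop embedding of the query sentence, here simply zeros.
    loop = torch.zeros(batch_size, DummyLinguisticEncoder.output_size)
    print(encoder("", loop, **batch).shape)  # torch.Size([2, 8])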