gbure

Graph-based approaches on unsupervised relation extraction evaluated as a few-shot problem
git clone https://esimon.eu/repos/gbure.git
Log | Files | Refs | README | LICENSE

dictionary.py (3608B)


      1 from typing import Any, Dict, List, Optional
      2 import pathlib
      3 import pickle
      4 
      5 
      6 class Dictionary:
      7     keys: List[str] = ["keys", "unknown", "decoder", "encoder"]
      8 
      9     def __init__(self, *, unknown: Optional[str] = None, path: Optional[pathlib.Path] = None) -> None:
     10         self.encoder: Dict[str, int] = {}
     11         self.decoder: List[str] = []
     12 
     13         self.unknown: Optional[str] = unknown
     14         if unknown is not None:
     15             self.encoder[unknown] = 0
     16             self.decoder.append(unknown)
     17 
     18         if path is not None:
     19             self.load(path)
     20 
     21     def __len__(self) -> int:
     22         """ Number of tokens in the dictionary. """
     23         return len(self.decoder)
     24 
     25     def encode(self, token: str) -> int:
     26         """ Returns the id corresponding to a token. """
     27         id: Optional[int] = self.encoder.get(token)
     28         if id is not None:
     29             return id
     30 
     31         id = len(self.decoder)
     32         self.encoder[token] = id
     33         self.decoder.append(token)
     34         return id
     35 
     36     def decode(self, id: int) -> str:
     37         """ Returns the token corresponding to an id.  """
     38         return self.decoder[id]
     39 
     40     def save(self, path: pathlib.Path) -> None:
     41         with path.open("wb") as file:
     42             pickle.dump({key: getattr(self, key) for key in self.keys}, file)
     43 
     44     def load(self, path: pathlib.Path) -> None:
     45         with path.open("rb") as file:
     46             data: Dict[str, Any] = pickle.load(file)
     47             for key, value in data.items():
     48                 setattr(self, key, value)
     49 
     50 
     51 class RelationDictionary(Dictionary):
     52     """
     53     A dictionary to be used for relations.
     54 
     55     The tokens held by this class are divided between:
     56         - *relation* such as "Entity-Destination(e1,e2)"
     57         - *base* such as "Entity-Destination"
     58     """
     59 
     60     keys = ["keys", "unknown", "decoder", "encoder", "base_encoder", "base_decoder", "id_to_bid"]
     61 
     62     def __init__(self, *, unknown: Optional[str] = None, path: Optional[pathlib.Path] = None) -> None:
     63         self.base_encoder: Dict[str, int] = {}
     64         self.base_decoder: List[str] = []
     65         self.id_to_bid: List[int] = []
     66 
     67         if unknown is not None:
     68             self.base_encoder[unknown] = 0
     69             self.base_decoder.append(unknown)
     70             self.id_to_bid.append(0)
     71 
     72         super().__init__(unknown=unknown, path=path)
     73 
     74     def base_size(self) -> int:
     75         """ Number of bases in the dictionary. """
     76         return len(self.base_decoder)
     77 
     78     def encode(self, relation: str, base: Optional[str] = None) -> int:
     79         """
     80         Returns the id corresponding to a relation string.
     81 
     82         If base is none, do not attempt to insert a new id and returns the id of relation immediately
     83 
     84         Args:
     85             relation: the string of the relation (e.g. "Entity-Destination(e1,e2)")
     86             base: the string of the base relation (e.g. "Entity-Destination")
     87         """
     88         if relation is None:
     89             return None
     90 
     91         if base is None:
     92             return self.encoder[relation]
     93 
     94         id: Optional[int] = self.encoder.get(relation)
     95         if id is not None:
     96             return id
     97 
     98         bid: Optional[int] = self.base_encoder.get(base)
     99         if bid is None:
    100             bid = len(self.base_decoder)
    101             self.base_encoder[base] = bid
    102             self.base_decoder.append(base)
    103 
    104         id = len(self.decoder)
    105         self.encoder[relation] = id
    106         self.decoder.append(relation)
    107         self.id_to_bid.append(bid)
    108         return id
    109 
    110     def base_id(self, id: int) -> int:
    111         """ Returns the base id corresponding to a relation id.  """
    112         return self.id_to_bid[id]