dictionary.py (3608B)
from typing import Any, ClassVar, Dict, List, Optional, Union
import pathlib
import pickle

# Anything accepted where a filesystem path is expected.
_PathLike = Union[str, pathlib.Path]


class Dictionary:
    """
    Bidirectional mapping between string tokens and contiguous integer ids.

    Ids are assigned in insertion order. When ``unknown`` is given, that token
    is reserved at id 0. Note that :meth:`encode` still assigns a fresh id to
    every unseen token; it never maps unseen tokens to ``unknown``.
    """

    # Names of the attributes written by save(). "keys" itself is kept in the
    # list for on-disk format compatibility, but load() never restores it.
    keys: ClassVar[List[str]] = ["keys", "unknown", "decoder", "encoder"]

    def __init__(self, *, unknown: Optional[str] = None, path: Optional[_PathLike] = None) -> None:
        """
        Args:
            unknown: optional token reserved at id 0.
            path: optional file previously written by save(); its contents
                replace the freshly initialised state.
        """
        self.encoder: Dict[str, int] = {}
        self.decoder: List[str] = []

        self.unknown: Optional[str] = unknown
        if unknown is not None:
            self.encoder[unknown] = 0
            self.decoder.append(unknown)

        if path is not None:
            self.load(path)

    def __len__(self) -> int:
        """Number of tokens in the dictionary."""
        return len(self.decoder)

    def encode(self, token: str) -> int:
        """Return the id of ``token``, assigning the next free id if it is new."""
        idx: Optional[int] = self.encoder.get(token)
        if idx is not None:
            return idx

        idx = len(self.decoder)
        self.encoder[token] = idx
        self.decoder.append(token)
        return idx

    def decode(self, id: int) -> str:  # parameter name kept for keyword callers
        """Return the token stored at ``id`` (a direct index into ``decoder``)."""
        return self.decoder[id]

    def save(self, path: _PathLike) -> None:
        """
        Pickle the attributes named in ``keys`` to ``path``.

        Args:
            path: destination file (``str`` or ``pathlib.Path``); overwritten.
        """
        with pathlib.Path(path).open("wb") as file:
            pickle.dump({key: getattr(self, key) for key in self.keys}, file)

    def load(self, path: _PathLike) -> None:
        """
        Restore attributes from a file written by save().

        The stored "keys" entry is ignored. Restoring it could replace this
        class's attribute list with another class's (e.g. a plain Dictionary
        file loaded into a RelationDictionary), and a later save() would then
        silently drop attributes.

        WARNING: this uses pickle, which can execute arbitrary code while
        loading. Only load files from a trusted source.

        Args:
            path: source file (``str`` or ``pathlib.Path``).
        """
        with pathlib.Path(path).open("rb") as file:
            data: Dict[str, Any] = pickle.load(file)
        for key, value in data.items():
            if key == "keys":
                continue
            setattr(self, key, value)


class RelationDictionary(Dictionary):
    """
    A dictionary to be used for relations.

    The tokens held by this class are divided between:
    - *relation* such as "Entity-Destination(e1,e2)"
    - *base* such as "Entity-Destination"

    Relations and bases have separate id spaces. ``id_to_bid[rid]`` gives the
    base id of relation id ``rid``.
    """

    keys: ClassVar[List[str]] = [
        "keys",
        "unknown",
        "decoder",
        "encoder",
        "base_encoder",
        "base_decoder",
        "id_to_bid",
    ]

    def __init__(self, *, unknown: Optional[str] = None, path: Optional[_PathLike] = None) -> None:
        """
        Args:
            unknown: optional token reserved at relation id 0 and base id 0.
            path: optional file previously written by save().
        """
        # Set before super().__init__(), which may call load() and overwrite them.
        self.base_encoder: Dict[str, int] = {}
        self.base_decoder: List[str] = []
        self.id_to_bid: List[int] = []

        # The unknown token is also its own base, so relation id 0 maps to base id 0.
        if unknown is not None:
            self.base_encoder[unknown] = 0
            self.base_decoder.append(unknown)
            self.id_to_bid.append(0)

        super().__init__(unknown=unknown, path=path)

    def base_size(self) -> int:
        """Number of bases in the dictionary."""
        return len(self.base_decoder)

    def encode(self, relation: Optional[str], base: Optional[str] = None) -> Optional[int]:
        """
        Return the id corresponding to a relation string.

        If ``base`` is None, no new id is created: the existing id of
        ``relation`` is returned. If ``base`` is given and ``relation`` is new,
        ``relation`` gets the next relation id. ``base`` is registered (if new)
        and recorded as that relation's base.

        Args:
            relation: the string of the relation (e.g. "Entity-Destination(e1,e2)").
            base: the string of the base relation (e.g. "Entity-Destination").

        Returns:
            The relation id, or None when ``relation`` is None.

        Raises:
            KeyError: ``base`` is None and ``relation`` has not been encoded yet.
        """
        if relation is None:
            return None

        if base is None:
            return self.encoder[relation]

        idx: Optional[int] = self.encoder.get(relation)
        if idx is not None:
            # An existing relation keeps its original base; ``base`` is not re-checked.
            return idx

        bid: Optional[int] = self.base_encoder.get(base)
        if bid is None:
            bid = len(self.base_decoder)
            self.base_encoder[base] = bid
            self.base_decoder.append(base)

        idx = len(self.decoder)
        self.encoder[relation] = idx
        self.decoder.append(relation)
        self.id_to_bid.append(bid)
        return idx

    def base_id(self, id: int) -> int:  # parameter name kept for keyword callers
        """Return the base id corresponding to a relation id."""
        return self.id_to_bid[id]