gbure

Graph-based approaches to unsupervised relation extraction, evaluated as a few-shot problem
git clone https://esimon.eu/repos/gbure.git
Log | Files | Refs | README | LICENSE

train.py (21628B)


      1 from typing import Any, Callable, Dict, Iterable, List, Optional
      2 import contextlib
      3 import gc
      4 import logging
      5 import math
      6 import os
      7 import pathlib
      8 import shutil
      9 
     10 from torch.utils.tensorboard import SummaryWriter
     11 import torch
     12 import torch.utils
     13 import tqdm
     14 import transformers
     15 
     16 import gbure.data.batcher
     17 import gbure.data.dataset
     18 import gbure.data.dictionary
     19 import gbure.data.graph
     20 import gbure.metrics
     21 import gbure.outputs
     22 import gbure.utils
     23 
     24 logger = logging.getLogger(__name__)
     25 
     26 
     27 class Trainer(gbure.utils.Experiment):
     28     """
     29     Train a model.
     30 
     31     Config:
     32         Model: the model class to use for training and evaluation
     33         Optimizer: the optimizer class to use for training
     34         Scheduler: the learning rate scheduler class to use
        accumulated_batch_size: the actual number of samples in a batch after accumulation, that is the number of samples seen before an optimizer step (must be a multiple of batch_size)
     36         amp: enable automatic mixed precision and gradient scaler
     37         batch_size: the number of samples in the batch of data loaded
     38         eval_batch_size: batch size used for evaluation
     39         clip_gradient: the maximum norm of the gradient
     40         dataset_name: name of the dataset to load
     41         dataset_spec: dataset specification, usually None, can be used to select a (smaller) test version
     42         eval_dataset_name: overwrite evaluation dataset
     43         eval_dataset_spec: overwrite evaluation dataset specification
        early_stopping_patience: how many epochs to train after the best validation score has been reached
        learning_rate: learning rate
        max_epoch: maximum number of epochs
     47         no_initial_validation: do not run evaluation on the valid dataset before first epoch
     48         optimizer_hyperparameters: hyperparameters for the optimizer (e.g. weight decay, etc)
     49         pretrained: path to a pretrained model to load
     50         scheduler_parameters: the parameters for initializing the Scheduler
     51         test_output: path to a file where the test predictions will be written
     52         transformer_model: the model of transformer to use
     53         unsupervised: train an unsupervised model
     54         validation_metric: metric used for early stopping
     55         workers: number of data generating workers to spawn
     56     """
     57 
     58     def init(self) -> None:
     59         """ Prepare training. """
     60         self.prepare_dataset()
     61         self.build_model()
     62         self.count_parameters()
     63         self.setup_optimizer()
     64         self.init_writer()
     65         self.init_epochs()
     66 
     67     def main(self) -> None:
     68         """ Run the experiment (i.e. here, train). """
     69         self.train()
     70 
     71     def prepare_dataset(self) -> None:
     72         """ Load datasets and create iterators. """
     73         dataset_spec: str = self.config.transformer_model + self.config.get("dataset_spec", "")
     74         self.data_dir: pathlib.Path = gbure.utils.DATA_PATH / self.config.dataset_name / dataset_spec
     75 
     76         self.tokenizer: transformers.PreTrainedTokenizer = transformers.AutoTokenizer.from_pretrained(str(self.data_dir / "tokenizer"))
     77         self.batcher = gbure.data.batcher.Batcher(self.tokenizer.pad_token_id)
     78         self.relation_dictionary: Optional[gbure.data.dictionary.RelationDictionary] = None if self.config.get("unsupervised") else gbure.data.dictionary.RelationDictionary(path=self.data_dir / "relations")
     79         entities_path: pathlib.Path = self.data_dir / "train" / "entities" if self.config.get("unsupervised") else self.data_dir / "entities"
     80         self.entity_dictionary: gbure.data.dictionary.Dictionary = gbure.data.dictionary.Dictionary(path=entities_path)
     81 
     82         if self.config.get("eval_dataset_name"):
     83             eval_dataset_spec: str = self.config.transformer_model + self.config.get("eval_dataset_spec", "")
     84             self.eval_data_dir: pathlib.Path = gbure.utils.DATA_PATH / self.config.eval_dataset_name / eval_dataset_spec
     85             # We assume the tokenizer is the same.
     86             self.eval_relation_dictionary: gbure.data.dictionary.RelationDictionary = gbure.data.dictionary.RelationDictionary(path=self.eval_data_dir / "relations")
     87             # FIXME doesn't work when a test dataset is specified …
     88             entities_path: pathlib.Path = self.data_dir / f"{self.config.valid_name}.entities" if self.config.get("valid_name") else self.data_dir / "entities"
     89             self.eval_entity_dictionary: gbure.data.dictionary.Dictionary = gbure.data.dictionary.Dictionary(path=self.eval_data_dir / "entities")
     90         else:
     91             self.eval_data_dir: pathlib.Path = self.data_dir
     92             self.eval_relation_dictionary: Optional[gbure.data.dictionary.RelationDictionary] = self.relation_dictionary
     93             self.eval_entity_dictionary: gbure.data.dictionary.Dictionary = self.entity_dictionary
     94 
     95         self.dataset: Dict[str, torch.utils.data.Dataset] = {}
     96         self.iterator: Dict[str, Callable[[], Any]] = {}
     97         graph_buffer: Optional[gbure.data.graph.Graph] = None
     98 
     99         for split in ["train", "valid", "test"]:
    100             kwargs: Dict[str, Any] = {"rng": self.state_dicts.get("train_rng")} if split == "train" and self.state_dicts else {}
    101             split_path = pathlib.Path(split)
    102             data_dir = self.data_dir if split == "train" else self.eval_data_dir
    103             if f"{split}_name" in self.config:
    104                 split_path = pathlib.Path(self.config[f"{split}_name"])
    105 
    106             try:
    107                 self.dataset[split] = gbure.data.dataset.load_dataset(
    108                     config=self.config,
    109                     split=split,
    110                     path=data_dir / split_path,
    111                     tokenizer=self.tokenizer,
    112                     evaluation=(split != "train"),
    113                     **kwargs)
    114             except FileNotFoundError:
    115                 if split == "train":
    116                     raise
    117                 continue
    118 
    119             if self.config.get("graph_name"):
    120                 graph_spec: str = self.config.transformer_model + self.config.get("graph_spec", "")
    121                 graph_dir: pathlib.Path = gbure.utils.DATA_PATH / self.config.graph_name / graph_spec
    122                 entity_dictionary = self.entity_dictionary if split == "train" else self.eval_entity_dictionary
    123                 self.dataset[split] = gbure.data.dataset.GraphAdapter(
    124                         self.dataset[split],
    125                         entity_dictionary,
    126                         path=graph_dir / "train",
    127                         graph=graph_buffer)
    128                 graph_buffer = self.dataset[split].graph
    129 
    130             self.iterator[split] = lambda trainer=self, split=split: tqdm.tqdm(
    131                     iterable=torch.utils.data.DataLoader(
    132                         dataset=trainer.dataset[split],
    133                         collate_fn=trainer.batcher,
    134                         batch_size=trainer.config.batch_size if split == "train" else trainer.config.get("eval_batch_size", trainer.config.batch_size),
    135                         shuffle=(split == "train" and trainer.dataset[split].shuffleable),
    136                         num_workers=trainer.config.workers,
    137                         worker_init_fn=getattr(trainer.dataset[split], "init_seed", None),
    138                         pin_memory=(trainer.device.type == "cuda")),
    139                     desc=f"Epoch {trainer.epoch:2} {split:5}",
    140                     unit="samples",
    141                     unit_scale=trainer.config.batch_size if split == "train" else trainer.config.get("eval_batch_size", trainer.config.batch_size),
    142                     total=math.ceil(len(trainer.dataset[split]) / (trainer.config.batch_size if split == "train" else trainer.config.get("eval_batch_size", trainer.config.batch_size))),
    143                     leave=False)
    144 
    def build_model(self) -> None:
        """ Create the training model (restoring saved weights when available), the evaluation model and the autocast context. """
        self.model: torch.nn.Module = self.config.Model(self.config, self.tokenizer, self.relation_dictionary)
        if self.state_dicts:
            # Non-strict loading: mismatching keys are reported below instead of raising.
            missing: List[str]
            unexpected: List[str]
            missing, unexpected = self.model.load_state_dict(self.state_dicts["model"], strict=False)
            if missing or unexpected:
                shown: List[str] = unexpected
                if self.config.get("pretrained"):
                    # Collapse all the language model keys into a single wildcard entry.
                    shown = [key for key in unexpected if not key.startswith("language_model.")]
                    if len(shown) < len(unexpected):
                        shown.append("language_model.*")
                print("\033[31mLoading state_dict mismatch.\033[0m\n")
                print(f"\033[31mMissing keys: {' '.join(missing)}.\033[0m\n")
                print(f"\033[31mUnexpected keys: {' '.join(shown)}.\033[0m\n")
                # A mismatch is only tolerated when loading a pretrained model.
                assert(self.config.get("pretrained"))
        self.model.to(self.device)

        if self.config.get("EvalModel"):
            self.eval_model: torch.nn.Module = self.config.EvalModel(self.config, self.tokenizer, self.relation_dictionary, train_model=self.model)
        else:
            # Without a dedicated evaluation model, the training model is evaluated directly.
            self.eval_model: torch.nn.Module = self.model

        if self.config.get("amp"):
            self.mixed_precision = torch.cuda.amp.autocast
        else:
            self.mixed_precision = contextlib.nullcontext
    168 
    def count_parameters(self) -> None:
        """ Print the total number of scalar parameters of the training model. """
        total: int = sum(parameter.shape.numel() for parameter in self.model.parameters())
        print(f"\033[33mNumber of parameters: {total:,}\033[0m")
    175 
    def setup_optimizer(self) -> None:
        """
        Instantiate the optimizer and, when configured, the learning rate scheduler and the gradient scaler.

        The scheduler is given the total number of optimizer steps of the whole training as `total_iters`.
        """
        # Work on a copy: writing "lr" into the dictionary held by the configuration would alter the config that is saved and logged.
        optimizer_hyperparameters: Dict[str, Any] = dict(self.config.get("optimizer_hyperparameters", {}))
        optimizer_hyperparameters["lr"] = self.config.learning_rate
        self.optimizer: torch.optim.optimizer.Optimizer = self.config.Optimizer(self.model.parameters(), **optimizer_hyperparameters)
        # if self.state_dicts and "optimizer" in self.state_dicts:
        #     self.optimizer.load_state_dict(self.state_dicts["optimizer"])
        if self.config.get("Scheduler"):
            # `train` applies the leftover accumulated gradients at the end of each epoch,
            # so the number of steps must be rounded up per epoch, not over the whole training.
            steps_per_epoch: int = math.ceil(len(self.dataset["train"]) / self.config.accumulated_batch_size)
            total_iters: int = self.config.max_epoch * steps_per_epoch
            self.scheduler = self.config.Scheduler(self.optimizer, total_iters=total_iters, **self.config.get("scheduler_parameters", {}))
        if self.config.get("amp"):
            self.scaler = torch.cuda.amp.GradScaler(self.config.get("initial_grad_scale", 65536))
    188 
    def save(self, path: pathlib.Path) -> None:
        """ Write to `path` the model and optimizer weights together with the configuration, progress counters, results and RNG states. """
        snapshot: Dict[str, Any] = dict(
                logdir=self.logdir,
                config=self.config,
                model=self.model.state_dict(),
                optimizer=self.optimizer.state_dict(),
                torch_rng=torch.random.get_rng_state(),
                epoch=self.epoch,
                best_epoch=self.best_epoch,
                best_eval=self.best_eval,
                global_batch_id=self.global_batch_id,
                results=self.results,
            )
        # The dataset RNG is only stored when no data worker is used
        # (presumably because workers hold their own copy of it — TODO confirm).
        if self.config.workers == 0:
            snapshot["train_rng"] = getattr(self.dataset["train"], "rng", None)
        if torch.cuda.is_available():
            snapshot["cuda_rng"] = torch.cuda.random.get_rng_state_all()
        torch.save(snapshot, path)
    208 
    def init_writer(self) -> None:
        """ Open the tensorboard writer in the log directory and restore previously saved results, if any. """
        self.writer = SummaryWriter(log_dir=self.logdir)
        fresh: Dict[str, float] = {}
        self.results: Dict[str, float] = self.state_dicts.get("results", fresh) if self.state_dicts else fresh
    215 
    def close(self) -> None:
        """ Release the tensorboard summary writer. """
        self.writer.close()
    219 
    def init_epochs(self) -> None:
        """
        Set the epoch counters and early stopping state, restoring them from the saved state when there is one.

        The best score starts at minus infinity, which assumes a "higher is better" metric.
        """
        # Every value missing from the saved state falls back to its initial value.
        restored: Dict[str, Any] = self.state_dicts or {}
        self.epoch: int = restored.get("epoch", 0)
        self.best_epoch: int = restored.get("best_epoch", 0)
        self.best_eval: float = restored.get("best_eval", -math.inf)
        self.global_batch_id: int = restored.get("global_batch_id", 0)
    236 
    def checkpoint(self) -> None:
        """
        Save the current state to the "checkpoint" file of the log directory.

        The state is first written to "checkpoint.new" and then moved over "checkpoint", so an interrupted save does not leave a truncated checkpoint behind.
        """
        new_path: pathlib.Path = self.logdir / "checkpoint.new"
        self.save(new_path)
        # os.replace overwrites an existing destination on every platform, whereas os.rename raises on Windows when "checkpoint" already exists.
        os.replace(new_path, self.logdir / "checkpoint")
    241 
    def update_best_model(self, candidate: float) -> bool:
        """
        Update current best model and return whether we should early stop.

        A NaN candidate (returned by `evaluate` when there is no such split or no validation metric) neither updates the best model nor triggers early stopping.
        When the candidate is strictly greater than the best score so far, the "checkpoint" file is copied to "best"; `checkpoint` must therefore have been called for the current epoch.
        Early stopping is requested once more than `early_stopping_patience` epochs (default: max_epoch) have passed since the best epoch.
        """
        # NaN must be detected by value: `candidate is math.nan` only matches that exact float object and misses a NaN computed elsewhere.
        if math.isnan(candidate):
            return False

        if candidate > self.best_eval:
            logger.info(f"Updating best model at epoch {self.epoch} (metric {candidate} > {self.best_eval})")
            self.best_epoch = self.epoch
            self.best_eval = candidate
            shutil.copy(self.logdir / "checkpoint", self.logdir / "best")
        return (self.epoch - self.best_epoch > self.config.get("early_stopping_patience", self.config.max_epoch))
    253 
    def load_best_model(self) -> None:
        """
        Load the model with the best validation score.

        Nothing is loaded when no best score has been recorded or when the "best" file does not exist; the current weights are then kept.
        """
        best_path: pathlib.Path = self.logdir / "best"
        if self.best_eval != -math.inf and best_path.exists():
            print(f"Loading best model from {best_path}…", end="", flush=True)
            # Deserialize on CPU: the file also holds the optimizer state, and load_state_dict copies the weights into the existing parameters anyway.
            self.model.load_state_dict(torch.load(best_path, map_location=torch.device("cpu"))["model"])
            print(" done")
            logger.info(f"{best_path} loaded")
    262 
    def train_apply_grads(self) -> None:
        """ Perform one optimizer step from the accumulated gradients, then reset them and advance the scheduler. """
        use_amp = self.config.get("amp")

        if self.config.get("clip_gradient"):
            if use_amp:
                # Gradients are unscaled before their norm is clipped.
                self.scaler.unscale_(self.optimizer)
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.config.clip_gradient)

        if use_amp:
            self.scaler.step(self.optimizer)
            self.scaler.update()
        else:
            self.optimizer.step()

        self.optimizer.zero_grad()
        if self.config.get("Scheduler"):
            self.scheduler.step()
    278 
    def train(self) -> None:
        """
        Train the model with intermediate evaluation on the validation dataset and final evaluation on the test dataset.

        Gradients are accumulated over accumulated_batch_size / batch_size loaded batches before each optimizer step.
        After each epoch the model is checkpointed and evaluated on the valid split; training stops when `update_best_model` requests early stopping, when max_epoch is reached or when the experiment is interrupted.
        The best model is then reloaded (if a test split exists), evaluated on the test split, and the results are written.
        """
        assert(self.config.accumulated_batch_size % self.config.batch_size == 0)
        data_batch_per_true_batch: int = self.config.accumulated_batch_size // self.config.batch_size

        if not self.config.get("no_initial_validation"):
            initial_eval: float = self.evaluate("valid")
            # Only raise the best score: a NaN (no valid split or metric) or a score lower than
            # the one restored from a checkpoint must not replace it, otherwise a later and
            # worse model could overwrite the "best" file.
            if initial_eval > self.best_eval:
                self.best_eval = initial_eval

        for self.epoch in range(self.epoch+1, self.config.max_epoch+1):
            if self.interrupted:
                break
            self.model.train()
            self.optimizer.zero_grad()
            total_loss: float = 0.0
            total_sample: int = 0
            # Keeps batch_id defined when the iterator is empty; -1 means "no pending gradients" in the test after the loop.
            batch_id: int = -1

            epoch_loop = self.iterator["train"]()
            for batch_id, batch in enumerate(epoch_loop):
                batch: Dict[str, torch.Tensor] = {key: value.to(self.device) for key, value in batch.items()}

                loss: torch.Tensor
                losses: Dict[str, torch.Tensor]
                variables: Dict[str, torch.Tensor]
                with self.mixed_precision():
                    loss, losses, variables = self.model(**batch)
                if self.config.get("amp"):
                    self.scaler.scale(loss).backward()
                else:
                    loss.backward()

                # Non finite losses are excluded from the running mean and flagged in the progress bar.
                if torch.isfinite(loss):
                    total_loss += loss.item()
                    total_sample += next(iter(batch.values())).shape[0]
                    epoch_loop.set_postfix(loss=f"{total_loss / total_sample:.2f}   .", refresh=False)
                else:
                    epoch_loop.set_postfix(loss=f"{total_loss / (total_sample if total_sample else 1):.2f} NaN", refresh=False)

                self.writer.add_scalar("Loss/train", loss, self.global_batch_id)
                for key, value in losses.items():
                    self.writer.add_scalar(f"{key}/train", value, self.global_batch_id)

                if batch_id % 1000 == 0:
                    for key, value in variables.items():
                        self.writer.add_histogram(f"{key}/train", value, self.global_batch_id)
                # Step the optimizer once enough batches have been accumulated.
                if batch_id % data_batch_per_true_batch == data_batch_per_true_batch - 1:
                    self.train_apply_grads()
                self.global_batch_id += 1

            if total_sample > 0:
                self.results["train.loss"] = float(total_loss / total_sample)
            else:
                # Avoid a division by zero in the messages below.
                total_sample = 1

            # Apply the gradients of the last, incomplete, accumulated batch.
            if batch_id % data_batch_per_true_batch != data_batch_per_true_batch - 1:
                self.train_apply_grads()

            print(f"Epoch {self.epoch:2} train mean loss: {total_loss / total_sample:8.4f}")
            logger.info(f"epoch {self.epoch} train mean_loss {total_loss / total_sample}")

            # The checkpoint must be written first: update_best_model copies it to "best".
            self.checkpoint()
            candidate: float = self.evaluate("valid")
            if self.update_best_model(candidate):
                break

        if "test" in self.iterator:
            self.load_best_model()
        self.evaluate("test")
        self.write_results()
    347 
    def metric_message(self, split: str) -> str:
        """
        Build the template of the message displayed after an evaluation on the given split.

        The returned string still contains `str.format` fields referring to `self`, `split` and `metrics`.
        """
        dataset = self.dataset[split]
        unsupervised_class = gbure.data.dataset.UnsupervisedDataset
        parts: List[str] = ["Epoch {self.epoch:2} {split:5} loss: {metrics.loss:8.4f}"]

        # NOTE(review): any dataset that is not itself unsupervised gets the accuracy field,
        # including a wrapper around an unsupervised dataset — confirm this is intended.
        if (not isinstance(dataset, unsupervised_class)) or (hasattr(dataset, "dataset") and not isinstance(dataset.dataset, unsupervised_class)):
            parts.append(" accuracy: {metrics.accuracy:8.6f}")

        if isinstance(dataset, gbure.data.dataset.SupervisedDataset):
            parts.append(" half-directed Macro F1: {metrics.half_directed_macro_f1:8.6f} (P: {metrics.half_directed_macro_precision:8.6f} R: {metrics.half_directed_macro_recall:8.6f})")

        if isinstance(dataset, gbure.data.dataset.GraphAdapter):
            few_shot_classes = (gbure.data.dataset.FewShotDataset, gbure.data.dataset.SampledFewShotDataset)
            if isinstance(dataset.dataset, few_shot_classes):
                parts.append(" accuracy non-empty: {metrics.accuracy_non_empty:8.6f} accuracy full: {metrics.accuracy_full:8.6f}")

        return "".join(parts)
    358 
    def evaluate(self, split: str) -> float:
        """
        Evaluate the model on the given split and return the early stopping metric.

        Returns NaN when the split has no iterator (a warning is printed unless the model is unsupervised) or when no `validation_metric` is configured.
        All the metrics are logged; they are also stored in `self.results` for the test split, and for the valid split when the returned metric is greater than the best score so far.
        The evaluation model is left in eval mode; `train` switches the training model back to train mode at the start of each epoch.
        """
        if split not in self.iterator:
            if not self.config.get("unsupervised"):
                print(f"\033[31m{split} dataset does not exist, evaluation skipped.\033[0m")
            return math.nan

        # Drop the gradients and collect garbage before evaluating (presumably to free memory — TODO confirm).
        self.model.zero_grad(set_to_none=True)
        gc.collect()
        self.eval_model.eval()
        # NOTE: the forward pass below is not wrapped in self.mixed_precision().
        with torch.no_grad(), gbure.outputs.Outputs(self.logdir, self.tokenizer, self.eval_relation_dictionary) as outputs:
            metrics = gbure.metrics.Metrics(self.config, self.tokenizer, self.eval_relation_dictionary, getattr(self.dataset[split], "graph", None))
            loop = self.iterator[split]()
            for batch in loop:
                batch: Dict[str, torch.Tensor] = {key: value.to(self.device) for key, value in batch.items()}
                loss, losses, variables = self.eval_model(**batch)
                metrics.update(batch, loss, losses, variables)
                outputs.update(batch, loss, losses, variables)
                loop.set_postfix(**metrics.summary, refresh=False)
            # The template is filled from the local variables: the names `self`, `split` and `metrics` must not be renamed.
            print(self.metric_message(split).format(**locals()))
            for metric, value in metrics.all.items():
                logger.info(f"epoch {self.epoch} {split} {metric} {value}")

            candidate: float = getattr(metrics, self.config.validation_metric) if self.config.get("validation_metric") else math.nan
            # A NaN candidate never compares greater, so valid results are not stored without a validation metric.
            if split == "test" or (split == "valid" and candidate > self.best_eval):
                for key, value in metrics.all.items():
                    self.results[f"{split}.{key}"] = float(value)
            return candidate
    387 
    def write_results(self):
        """ Log the flattened configuration and the collected results as a tensorboard hparams entry. """
        hyperparameters = gbure.utils.flatten_dict(self.config)
        self.writer.add_hparams(hparam_dict=hyperparameters, metric_dict=self.results, run_name="hparams")
    394 
    395 
    396 if __name__ == "__main__":
    397     gbure.utils.fix_transformers_logging_handler()
    398 
    399     config: gbure.utils.dotdict = gbure.utils.parse_args()
    400     logdir: pathlib.Path
    401     state_dicts: Optional[Dict[str, Any]] = None
    402     new_log_dir: bool = True
    403 
    404     if config.get("load"):
    405         state_dicts = torch.load(config.load, map_location=torch.device('cpu'))
    406         if not config.get("overwrite_config"):
    407             config = state_dicts["config"]
    408             logdir = state_dicts["logdir"]
    409             new_log_dir = False
    410 
    411     if config.get("pretrained"):
    412         state_dicts = torch.load(config.pretrained, map_location=torch.device('cpu'))
    413         state_dicts = {"model": state_dicts["model"]}
    414 
    415     if new_log_dir:
    416         logdir = gbure.utils.LOG_PATH / gbure.utils.logdir_name("GBURE")
    417         assert(not logdir.exists())
    418         logdir.mkdir()
    419 
    420     gbure.utils.add_logging_handler(logdir)
    421     Trainer(config, logdir, state_dicts).run()