train.py (21628B)
from typing import Any, Callable, Dict, Iterable, List, Optional
import contextlib
import gc
import logging
import math
import os
import pathlib
import shutil

from torch.utils.tensorboard import SummaryWriter
import torch
import torch.utils
import tqdm
import transformers

import gbure.data.batcher
import gbure.data.dataset
import gbure.data.dictionary
import gbure.data.graph
import gbure.metrics
import gbure.outputs
import gbure.utils

logger = logging.getLogger(__name__)


class Trainer(gbure.utils.Experiment):
    """
    Train a model.

    Config:
        Model: the model class to use for training and evaluation
        EvalModel: optional model class used for evaluation instead of Model (built with train_model=<the training model>)
        Optimizer: the optimizer class to use for training
        Scheduler: the learning rate scheduler class to use
        accumulated_batch_size: the actual number of samples in a batch after accumulation, that is the number of samples seen before an optimizer step (must be a multiple of batch_size)
        amp: enable automatic mixed precision and gradient scaler
        initial_grad_scale: initial scale of the AMP gradient scaler (default 65536)
        batch_size: the number of samples in the batch of data loaded
        eval_batch_size: batch size used for evaluation (defaults to batch_size)
        clip_gradient: the maximum norm of the gradient
        dataset_name: name of the dataset to load
        dataset_spec: dataset specification, usually None, can be used to select a (smaller) test version
        eval_dataset_name: overwrite evaluation dataset
        eval_dataset_spec: overwrite evaluation dataset specification
        train_name, valid_name, test_name: overwrite the directory name of the corresponding split
        graph_name: when set, every split is wrapped in a GraphAdapter using this graph dataset
        graph_spec: graph dataset specification
        early_stopping_patience: how many epochs to train after best validation score has been reached
        learning_rate: learning rate
        max_epoch: maximum number of epochs
        no_initial_validation: do not run evaluation on the valid dataset before first epoch
        optimizer_hyperparameters: hyperparameters for the optimizer (e.g. weight decay, etc)
        pretrained: path to a pretrained model to load
        scheduler_parameters: the parameters for initializing the Scheduler
        test_output: path to a file where the test predictions will be written
        transformer_model: the model of transformer to use
        unsupervised: train an unsupervised model
        validation_metric: metric used for early stopping
        workers: number of data generating workers to spawn
    """

    def init(self) -> None:
        """ Prepare training: datasets, model, optimizer, tensorboard writer and epoch counters (in that order). """
        self.prepare_dataset()
        self.build_model()
        self.count_parameters()
        self.setup_optimizer()
        self.init_writer()
        self.init_epochs()

    def main(self) -> None:
        """ Run the experiment (i.e. here, train). """
        self.train()

    def prepare_dataset(self) -> None:
        """
        Load the tokenizer, dictionaries and datasets, and create one iterator factory per available split.

        The train split is mandatory; a valid or test split whose files are missing is skipped and is then
        absent from both ``self.dataset`` and ``self.iterator``.
        """
        dataset_spec: str = self.config.transformer_model + self.config.get("dataset_spec", "")
        self.data_dir: pathlib.Path = gbure.utils.DATA_PATH / self.config.dataset_name / dataset_spec

        self.tokenizer: transformers.PreTrainedTokenizer = transformers.AutoTokenizer.from_pretrained(str(self.data_dir / "tokenizer"))
        self.batcher = gbure.data.batcher.Batcher(self.tokenizer.pad_token_id)
        # Unsupervised training has no relation labels, hence no relation dictionary.
        self.relation_dictionary: Optional[gbure.data.dictionary.RelationDictionary] = None if self.config.get("unsupervised") else gbure.data.dictionary.RelationDictionary(path=self.data_dir / "relations")
        entities_path: pathlib.Path = self.data_dir / "train" / "entities" if self.config.get("unsupervised") else self.data_dir / "entities"
        self.entity_dictionary: gbure.data.dictionary.Dictionary = gbure.data.dictionary.Dictionary(path=entities_path)

        if self.config.get("eval_dataset_name"):
            eval_dataset_spec: str = self.config.transformer_model + self.config.get("eval_dataset_spec", "")
            self.eval_data_dir: pathlib.Path = gbure.utils.DATA_PATH / self.config.eval_dataset_name / eval_dataset_spec
            # We assume the tokenizer is the same.
            self.eval_relation_dictionary: gbure.data.dictionary.RelationDictionary = gbure.data.dictionary.RelationDictionary(path=self.eval_data_dir / "relations")
            # FIXME split-specific entity files (e.g. "<valid_name>.entities") are not supported here:
            # the evaluation entity dictionary is always read from "<eval_data_dir>/entities".
            self.eval_entity_dictionary: gbure.data.dictionary.Dictionary = gbure.data.dictionary.Dictionary(path=self.eval_data_dir / "entities")
        else:
            self.eval_data_dir: pathlib.Path = self.data_dir
            self.eval_relation_dictionary: Optional[gbure.data.dictionary.RelationDictionary] = self.relation_dictionary
            self.eval_entity_dictionary: gbure.data.dictionary.Dictionary = self.entity_dictionary

        self.dataset: Dict[str, torch.utils.data.Dataset] = {}
        self.iterator: Dict[str, Callable[[], Any]] = {}
        # The graph is loaded by the first GraphAdapter and then shared with the adapters of the other splits.
        graph_buffer: Optional[gbure.data.graph.Graph] = None

        for split in ["train", "valid", "test"]:
            # When resuming, restore the training data RNG (only saved when workers == 0, see save()).
            kwargs: Dict[str, Any] = {"rng": self.state_dicts.get("train_rng")} if split == "train" and self.state_dicts else {}
            split_path = pathlib.Path(split)
            data_dir = self.data_dir if split == "train" else self.eval_data_dir
            if f"{split}_name" in self.config:
                split_path = pathlib.Path(self.config[f"{split}_name"])

            try:
                self.dataset[split] = gbure.data.dataset.load_dataset(
                    config=self.config,
                    split=split,
                    path=data_dir / split_path,
                    tokenizer=self.tokenizer,
                    evaluation=(split != "train"),
                    **kwargs)
            except FileNotFoundError:
                if split == "train":
                    raise
                continue

            if self.config.get("graph_name"):
                graph_spec: str = self.config.transformer_model + self.config.get("graph_spec", "")
                graph_dir: pathlib.Path = gbure.utils.DATA_PATH / self.config.graph_name / graph_spec
                entity_dictionary = self.entity_dictionary if split == "train" else self.eval_entity_dictionary
                # Every split reads its graph from the "train" directory of the graph dataset.
                self.dataset[split] = gbure.data.dataset.GraphAdapter(
                    self.dataset[split],
                    entity_dictionary,
                    path=graph_dir / "train",
                    graph=graph_buffer)
                graph_buffer = self.dataset[split].graph

            # Default arguments bind the current trainer and split (avoids the late-binding closure pitfall).
            self.iterator[split] = lambda trainer=self, split=split: trainer._make_iterator(split)

    def _split_batch_size(self, split: str) -> int:
        """ Return the batch size used to load the given split: batch_size for train, eval_batch_size (defaulting to batch_size) otherwise. """
        if split == "train":
            return self.config.batch_size
        return self.config.get("eval_batch_size", self.config.batch_size)

    def _make_iterator(self, split: str) -> tqdm.tqdm:
        """ Build a fresh DataLoader over ``self.dataset[split]`` wrapped in a tqdm progress bar counting samples. """
        dataset = self.dataset[split]
        batch_size: int = self._split_batch_size(split)
        loader = torch.utils.data.DataLoader(
            dataset=dataset,
            collate_fn=self.batcher,
            batch_size=batch_size,
            shuffle=(split == "train" and dataset.shuffleable),
            num_workers=self.config.workers,
            worker_init_fn=getattr(dataset, "init_seed", None),
            pin_memory=(self.device.type == "cuda"))
        return tqdm.tqdm(
            iterable=loader,
            desc=f"Epoch {self.epoch:2} {split:5}",
            unit="samples",
            unit_scale=batch_size,
            total=math.ceil(len(dataset) / batch_size),
            leave=False)

    def build_model(self) -> None:
        """
        Instantiate the model (and the evaluation model), loading ``state_dicts["model"]`` when available.

        A state_dict that does not exactly match the model is only accepted when loading a pretrained model;
        the mismatching keys are displayed in any case.
        """
        self.model: torch.nn.Module = self.config.Model(self.config, self.tokenizer, self.relation_dictionary)
        if self.state_dicts:
            missing_keys: List[str]
            unexpected_keys: List[str]
            missing_keys, unexpected_keys = self.model.load_state_dict(self.state_dicts["model"], strict=False)
            if missing_keys or unexpected_keys:
                unexpected_display: List[str] = unexpected_keys
                if self.config.get("pretrained"):
                    # Collapse the (expected, numerous) unused language model weights into a single entry.
                    unexpected_display = list(filter(lambda key: not key.startswith("language_model."), unexpected_keys))
                    if len(unexpected_display) < len(unexpected_keys):
                        unexpected_display.append("language_model.*")
                print("\033[31mLoading state_dict mismatch.\033[0m\n")
                print(f"\033[31mMissing keys: {' '.join(missing_keys)}.\033[0m\n")
                print(f"\033[31mUnexpected keys: {' '.join(unexpected_display)}.\033[0m\n")
                assert self.config.get("pretrained")
        self.model.to(self.device)
        if self.config.get("EvalModel"):
            self.eval_model: torch.nn.Module = self.config.EvalModel(self.config, self.tokenizer, self.relation_dictionary, train_model=self.model)
        else:
            self.eval_model: torch.nn.Module = self.model
        self.mixed_precision = torch.cuda.amp.autocast if self.config.get("amp") else contextlib.nullcontext

    def count_parameters(self) -> None:
        """ Display the number of parameters used by the model. """
        total: int = sum(parameter.numel() for parameter in self.model.parameters())
        print(f"\033[33mNumber of parameters: {total:,}\033[0m")

    def setup_optimizer(self) -> None:
        """ Instantiate the optimizer, and the learning rate scheduler and AMP gradient scaler when configured. """
        # Copy so that adding "lr" does not mutate the experiment configuration (which is saved and logged).
        optimizer_hyperparameters: Dict[str, Any] = dict(self.config.get("optimizer_hyperparameters") or {})
        optimizer_hyperparameters["lr"] = self.config.learning_rate
        self.optimizer: torch.optim.Optimizer = self.config.Optimizer(self.model.parameters(), **optimizer_hyperparameters)
        # NOTE the optimizer state stored by save() is currently not restored when resuming:
        # self.optimizer.load_state_dict(self.state_dicts["optimizer"]) is disabled.
        if self.config.get("Scheduler"):
            # Approximate number of optimizer steps over the whole training.
            total_iters: int = math.ceil(self.config.max_epoch * len(self.dataset["train"]) / self.config.accumulated_batch_size)
            self.scheduler = self.config.Scheduler(self.optimizer, total_iters=total_iters, **self.config.get("scheduler_parameters", {}))
        if self.config.get("amp"):
            self.scaler = torch.cuda.amp.GradScaler(self.config.get("initial_grad_scale", 65536))

    def save(self, path: pathlib.Path) -> None:
        """ Save the model and all additional state which might be needed to resume training. """
        state_dicts: Dict[str, Any] = {
            "logdir": self.logdir,
            "config": self.config,
            "model": self.model.state_dict(),
            "optimizer": self.optimizer.state_dict(),
            "torch_rng": torch.random.get_rng_state(),
            "epoch": self.epoch,
            "best_epoch": self.best_epoch,
            "best_eval": self.best_eval,
            "global_batch_id": self.global_batch_id,
            "results": self.results,
        }
        # The dataset RNG lives in the main process only when no worker is spawned.
        if self.config.workers == 0:
            state_dicts["train_rng"] = getattr(self.dataset["train"], "rng", None)
        if torch.cuda.is_available():
            state_dicts["cuda_rng"] = torch.cuda.random.get_rng_state_all()
        torch.save(state_dicts, path)

    def init_writer(self) -> None:
        """ Initialize summary writer for tensorboard logging, and the results dictionary (restored when resuming). """
        self.writer = SummaryWriter(log_dir=self.logdir)
        self.results: Dict[str, float] = {}
        if self.state_dicts:
            self.results = self.state_dicts.get("results", self.results)

    def close(self) -> None:
        """ Close the writer. """
        self.writer.close()

    def init_epochs(self) -> None:
        """
        Give initial values for early stopping (restored from ``state_dicts`` when resuming).

        Assume a "higher is better" metric.
        """
        self.epoch: int = 0
        self.best_epoch: int = 0
        self.best_eval: float = -math.inf
        self.global_batch_id: int = 0

        if self.state_dicts:
            self.epoch = self.state_dicts.get("epoch", self.epoch)
            self.best_epoch = self.state_dicts.get("best_epoch", self.best_epoch)
            self.best_eval = self.state_dicts.get("best_eval", self.best_eval)
            self.global_batch_id = self.state_dicts.get("global_batch_id", self.global_batch_id)

    def checkpoint(self) -> None:
        """ Save current model to "<logdir>/checkpoint", through a temporary file so that a crash never leaves a truncated checkpoint. """
        self.save(self.logdir / "checkpoint.new")
        # os.replace (unlike os.rename) also overwrites an existing destination on Windows.
        os.replace(self.logdir / "checkpoint.new", self.logdir / "checkpoint")

    def update_best_model(self, candidate: float) -> bool:
        """
        Update current best model and return whether we should early stop.

        A NaN candidate (no validation metric or no valid split) never updates the best model nor triggers early stopping.
        Otherwise, early stopping is requested once more than ``early_stopping_patience`` (default ``max_epoch``)
        epochs have elapsed since the best epoch.
        """
        # math.isnan rather than "is math.nan": a NaN produced by a computation is not the math.nan object.
        if math.isnan(candidate):
            return False

        if candidate > self.best_eval:
            logger.info(f"Updating best model at epoch {self.epoch} (metric {candidate} > {self.best_eval})")
            self.best_epoch = self.epoch
            self.best_eval = candidate
            shutil.copy(self.logdir / "checkpoint", self.logdir / "best")
        return (self.epoch - self.best_epoch > self.config.get("early_stopping_patience", self.config.max_epoch))

    def load_best_model(self) -> None:
        """ Load the model with the best validation score, if one was saved. """
        best_path: pathlib.Path = self.logdir / "best"
        if self.best_eval != -math.inf and best_path.exists():
            print(f"Loading best model from {best_path}…", end="", flush=True)
            # Load on CPU: the file also holds the optimizer state, and load_state_dict copies onto the model's device.
            self.model.load_state_dict(torch.load(best_path, map_location=torch.device("cpu"))["model"])
            print(" done")
            logger.info(f"{best_path} loaded")

    def train_apply_grads(self) -> None:
        """ Apply the accumulated gradients to the parameters (with optional clipping), reset them and step the scheduler. """
        if self.config.get("amp"):
            if self.config.get("clip_gradient"):
                # Gradients must be unscaled before their norm is clipped.
                self.scaler.unscale_(self.optimizer)
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.config.clip_gradient)
            self.scaler.step(self.optimizer)
            self.scaler.update()
        else:
            if self.config.get("clip_gradient"):
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.config.clip_gradient)
            self.optimizer.step()
        self.optimizer.zero_grad()
        if self.config.get("Scheduler"):
            self.scheduler.step()

    def train(self) -> None:
        """ Train the model with intermediate evaluation on the validation dataset and final evaluation on the test dataset. """
        assert self.config.accumulated_batch_size % self.config.batch_size == 0
        data_batch_per_true_batch: int = self.config.accumulated_batch_size // self.config.batch_size

        if not self.config.get("no_initial_validation"):
            initial_eval: float = self.evaluate("valid")
            # Only adopt an improving score: a NaN would make every later comparison fail, and when resuming
            # the best score restored from the checkpoint must not be lowered.
            if initial_eval > self.best_eval:
                self.best_eval = initial_eval

        for self.epoch in range(self.epoch + 1, self.config.max_epoch + 1):
            if self.interrupted:
                break
            self._train_epoch(data_batch_per_true_batch)

            self.checkpoint()
            candidate: float = self.evaluate("valid")
            if self.update_best_model(candidate):
                break

        if "test" in self.iterator:
            self.load_best_model()
            self.evaluate("test")
        self.write_results()

    def _train_epoch(self, data_batch_per_true_batch: int) -> None:
        """ Run one epoch over the train split, applying the gradients every ``data_batch_per_true_batch`` data batches. """
        self.model.train()
        self.optimizer.zero_grad()
        total_loss: float = 0.0
        total_sample: int = 0
        # -1 so that an epoch without any batch neither crashes below nor triggers a spurious optimizer step.
        batch_id: int = -1

        epoch_loop = self.iterator["train"]()
        for batch_id, batch in enumerate(epoch_loop):
            batch: Dict[str, torch.Tensor] = {key: value.to(self.device) for key, value in batch.items()}

            loss: torch.Tensor
            losses: Dict[str, torch.Tensor]
            variables: Dict[str, torch.Tensor]
            with self.mixed_precision():
                loss, losses, variables = self.model(**batch)

            finite: bool = bool(torch.isfinite(loss))
            if self.config.get("amp"):
                # The gradient scaler detects non-finite gradients and skips the corresponding optimizer step.
                self.scaler.scale(loss).backward()
            elif finite:
                # Without a scaler, back-propagating a NaN/inf loss would poison the accumulated gradients.
                loss.backward()

            if finite:
                total_loss += loss.item()
                # Assumes the first tensor of the batch has the batch size as first dimension — TODO confirm.
                total_sample += next(iter(batch.values())).shape[0]
                epoch_loop.set_postfix(loss=f"{total_loss / total_sample:.2f} .", refresh=False)
            else:
                epoch_loop.set_postfix(loss=f"{total_loss / (total_sample if total_sample else 1):.2f} NaN", refresh=False)

            self.writer.add_scalar("Loss/train", loss, self.global_batch_id)
            for key, value in losses.items():
                self.writer.add_scalar(f"{key}/train", value, self.global_batch_id)

            if batch_id % 1000 == 0:
                for key, value in variables.items():
                    self.writer.add_histogram(f"{key}/train", value, self.global_batch_id)
            if batch_id % data_batch_per_true_batch == data_batch_per_true_batch - 1:
                self.train_apply_grads()
            self.global_batch_id += 1

        if total_sample > 0:
            self.results["train.loss"] = float(total_loss / total_sample)
        else:
            total_sample = 1  # avoid a division by zero in the messages below

        # Apply the gradients of a trailing, incomplete accumulation window.
        if batch_id % data_batch_per_true_batch != data_batch_per_true_batch - 1:
            self.train_apply_grads()

        print(f"Epoch {self.epoch:2} train mean loss: {total_loss / total_sample:8.4f}")
        logger.info(f"epoch {self.epoch} train mean_loss {total_loss / total_sample}")

    def metric_message(self, split: str) -> str:
        """
        Message to be displayed after an evaluation.

        The returned string is a ``str.format`` template referencing the fields ``self``, ``split`` and ``metrics``;
        the metrics shown depend on the type of the dataset of the given split.
        """
        message = "Epoch {self.epoch:2} {split:5} loss: {metrics.loss:8.4f}"
        if not isinstance(self.dataset[split], gbure.data.dataset.UnsupervisedDataset) or hasattr(self.dataset[split], "dataset") and not isinstance(self.dataset[split].dataset, gbure.data.dataset.UnsupervisedDataset):
            message += " accuracy: {metrics.accuracy:8.6f}"
        if isinstance(self.dataset[split], gbure.data.dataset.SupervisedDataset):
            message += " half-directed Macro F1: {metrics.half_directed_macro_f1:8.6f} (P: {metrics.half_directed_macro_precision:8.6f} R: {metrics.half_directed_macro_recall:8.6f})"
        if isinstance(self.dataset[split], gbure.data.dataset.GraphAdapter) and (isinstance(self.dataset[split].dataset, gbure.data.dataset.FewShotDataset) or isinstance(self.dataset[split].dataset, gbure.data.dataset.SampledFewShotDataset)):
            message += " accuracy non-empty: {metrics.accuracy_non_empty:8.6f} accuracy full: {metrics.accuracy_full:8.6f}"
        return message

    def evaluate(self, split: str) -> float:
        """
        Evaluate the model on the given split and return the early stopping metric.

        Returns NaN when the split does not exist or no ``validation_metric`` is configured.
        Results are recorded for the test split, and for the valid split when the metric improves on the best one.
        """
        if split not in self.iterator:
            if not self.config.get("unsupervised"):
                print(f"\033[31m{split} dataset does not exist, evaluation skipped.\033[0m")
            return math.nan

        # Free the training gradients before evaluating.
        self.model.zero_grad(set_to_none=True)
        gc.collect()
        self.eval_model.eval()
        with torch.no_grad(), gbure.outputs.Outputs(self.logdir, self.tokenizer, self.eval_relation_dictionary) as outputs:
            metrics = gbure.metrics.Metrics(self.config, self.tokenizer, self.eval_relation_dictionary, getattr(self.dataset[split], "graph", None))
            loop = self.iterator[split]()
            for batch in loop:
                batch: Dict[str, torch.Tensor] = {key: value.to(self.device) for key, value in batch.items()}
                loss, losses, variables = self.eval_model(**batch)
                metrics.update(batch, loss, losses, variables)
                outputs.update(batch, loss, losses, variables)
                loop.set_postfix(**metrics.summary, refresh=False)
            print(self.metric_message(split).format(self=self, split=split, metrics=metrics))
            for metric, value in metrics.all.items():
                logger.info(f"epoch {self.epoch} {split} {metric} {value}")

        candidate: float = getattr(metrics, self.config.validation_metric) if self.config.get("validation_metric") else math.nan
        if split == "test" or (split == "valid" and candidate > self.best_eval):
            for key, value in metrics.all.items():
                self.results[f"{split}.{key}"] = float(value)
        return candidate

    def write_results(self):
        """ Write best valid and test score as tensorboard events. """
        self.writer.add_hparams(
            hparam_dict=gbure.utils.flatten_dict(self.config),
            metric_dict=self.results,
            run_name="hparams")


if __name__ == "__main__":
    gbure.utils.fix_transformers_logging_handler()

    config: gbure.utils.dotdict = gbure.utils.parse_args()
    logdir: pathlib.Path
    state_dicts: Optional[Dict[str, Any]] = None
    new_log_dir: bool = True

    if config.get("load"):
        # Resume a previous run: restore its whole state, its log directory and (unless overwrite_config) its config.
        state_dicts = torch.load(config.load, map_location=torch.device('cpu'))
        if not config.get("overwrite_config"):
            config = state_dicts["config"]
        logdir = state_dicts["logdir"]
        new_log_dir = False
    elif config.get("pretrained"):
        # Start a new run from the weights of a pretrained model only. This must not happen when resuming:
        # the restored config of a run started from a pretrained model still holds "pretrained", and loading
        # it would discard the checkpoint (model and epoch counters) being resumed.
        pretrained_state: Dict[str, Any] = torch.load(config.pretrained, map_location=torch.device('cpu'))
        state_dicts = {"model": pretrained_state["model"]}

    if new_log_dir:
        logdir = gbure.utils.LOG_PATH / gbure.utils.logdir_name("GBURE")
        assert not logdir.exists()
        logdir.mkdir()

    gbure.utils.add_logging_handler(logdir)
    Trainer(config, logdir, state_dicts).run()