Spaces:
Runtime error
Runtime error
| from pytorch_lightning.callbacks import Callback | |
| import pytorch_lightning as pl | |
| import os | |
| from omegaconf import OmegaConf | |
| from pytorch_lightning.utilities import rank_zero_only | |
# Switch read by SetupCallback.on_fit_start: when true, rank 0 sleeps for
# 5 seconds before writing the project config, and the other ranks skip
# moving an existing log directory into "child_runs".
| MULTINODE_HACKS = True | |
class SetupCallback(Callback):
    """Callback that prepares a run's output directories and config dumps.

    At fit start, rank 0 creates the log, checkpoint and config directories,
    prints the project and lightning configs, and writes both as YAML files
    into ``cfgdir``. When ``on_exception`` is called, the exception is
    printed and, unless ``debug`` is set, rank 0 saves a checkpoint.

    Args:
        resume: When truthy, non-zero ranks leave an existing ``logdir``
            where it is instead of moving it.
        now: Value used as the prefix of the saved config file names
            (presumably a timestamp string -- confirm with the caller).
        logdir: Log directory of the run.
        ckptdir: Directory that checkpoints are written to.
        cfgdir: Directory that the YAML config dumps are written to.
        config: Project config; printed and saved through ``OmegaConf``.
        lightning_config: Lightning config; saved under a top-level
            ``"lightning"`` key.
        debug: When truthy, ``on_exception`` does not write a checkpoint.
        ckpt_name: File name of the checkpoint written by ``on_exception``;
            ``"last.ckpt"`` is used when this is ``None``.
    """

    def __init__(
        self,
        resume,
        now,
        logdir,
        ckptdir,
        cfgdir,
        config,
        lightning_config,
        debug,
        ckpt_name=None,
    ):
        super().__init__()
        # Output locations.
        self.logdir = logdir
        self.ckptdir = ckptdir
        self.cfgdir = cfgdir
        self.ckpt_name = ckpt_name
        # Configs that get printed and dumped at fit start.
        self.config = config
        self.lightning_config = lightning_config
        # Run metadata and flags.
        self.now = now
        self.resume = resume
        self.debug = debug

    def on_exception(self, trainer: pl.Trainer, pl_module, exception):
        """Print the exception and save a checkpoint from rank 0.

        Nothing is saved when ``self.debug`` is truthy or when this process
        is not global rank 0. ``pl_module`` is not used.
        """
        print("Exception occurred: {}".format(exception))
        if self.debug or trainer.global_rank != 0:
            return

        print("Summoning checkpoint.")
        filename = "last.ckpt" if self.ckpt_name is None else self.ckpt_name
        trainer.save_checkpoint(os.path.join(self.ckptdir, filename))

    def on_fit_start(self, trainer, pl_module):
        """Create the run directories and dump the configs on rank 0.

        On other ranks, an existing ``logdir`` is moved to
        ``<parent>/child_runs/<name>`` unless ``MULTINODE_HACKS`` or
        ``self.resume`` is truthy. ``pl_module`` is not used.
        """
        if trainer.global_rank == 0:
            # Create logdirs and save configs.
            for directory in (self.logdir, self.ckptdir, self.cfgdir):
                os.makedirs(directory, exist_ok=True)

            has_callbacks = "callbacks" in self.lightning_config
            if (
                has_callbacks
                and "metrics_over_trainsteps_checkpoint"
                in self.lightning_config["callbacks"]
            ):
                trainstep_dir = os.path.join(self.ckptdir, "trainstep_checkpoints")
                os.makedirs(trainstep_dir, exist_ok=True)

            print("Project config")
            print(OmegaConf.to_yaml(self.config))
            if MULTINODE_HACKS:
                # NOTE(review): the reason for this pause is not documented
                # here -- presumably it lets a shared filesystem settle on
                # multi-node runs; confirm before changing it.
                import time

                time.sleep(5)
            project_yaml = os.path.join(
                self.cfgdir, "{}-project.yaml".format(self.now)
            )
            OmegaConf.save(self.config, project_yaml)

            print("Lightning config")
            print(OmegaConf.to_yaml(self.lightning_config))
            wrapped_config = OmegaConf.create({"lightning": self.lightning_config})
            lightning_yaml = os.path.join(
                self.cfgdir, "{}-lightning.yaml".format(self.now)
            )
            OmegaConf.save(wrapped_config, lightning_yaml)
        elif not (MULTINODE_HACKS or self.resume) and os.path.exists(self.logdir):
            # Per the original note, the ModelCheckpoint callback has already
            # created the log directory on this rank; move it out of the way
            # into a sibling "child_runs" folder.
            parent, run_name = os.path.split(self.logdir)
            target = os.path.join(parent, "child_runs", run_name)
            target_parent, _ = os.path.split(target)
            os.makedirs(target_parent, exist_ok=True)
            try:
                os.rename(self.logdir, target)
            except FileNotFoundError:
                # The directory vanished between the existence check and the
                # rename; nothing is left to move.
                pass