Spaces:
Configuration error
Configuration error
| import sys | |
| import argparse | |
| import os | |
| import time | |
| import logging | |
| from datetime import datetime | |
def _build_arg_parser() -> argparse.ArgumentParser:
    """Create the command-line parser for the launcher.

    ``--config`` is mandatory, and exactly one of ``--train``, ``--validate``,
    ``--test`` or ``--predict`` must be selected.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", required=True, help="path to config file")
    parser.add_argument("--gpu", default="0", help="GPU(s) to be used")
    parser.add_argument(
        "--resume", default=None, help="path to the weights to be resumed"
    )
    parser.add_argument(
        "--resume_weights_only",
        action="store_true",
        help="specify this argument to restore only the weights (w/o training states), e.g. --resume path/to/resume --resume_weights_only",
    )

    # The run mode: the four flags exclude each other and one is required.
    mode = parser.add_mutually_exclusive_group(required=True)
    mode.add_argument("--train", action="store_true")
    mode.add_argument("--validate", action="store_true")
    mode.add_argument("--test", action="store_true")
    mode.add_argument("--predict", action="store_true")
    # TODO: a separate export action (--export)

    parser.add_argument("--exp_dir", default="./exp")
    parser.add_argument("--runs_dir", default="./runs")
    parser.add_argument(
        "--verbose", action="store_true", help="if true, set logging level to DEBUG"
    )
    return parser


def main() -> None:
    """Parse the command line and run the selected Lightning stage.

    Steps, in order:

    1. Parse arguments; arguments the parser does not know are collected and
       handed to ``load_config`` as ``cli_args``.
    2. Export ``CUDA_DEVICE_ORDER`` / ``CUDA_VISIBLE_DEVICES`` and only then
       import pytorch-lightning and the project packages.
    3. Load the config and fill in the trial name and output directories
       that the config does not already provide.
    4. Seed, build the datamodule and the system, create the ``Trainer`` and
       dispatch to ``fit`` (followed by ``test``), ``validate``, ``test`` or
       ``predict``.

    Exits with status 2 (via ``parser.error``) when the command line is
    invalid, including a multi-GPU request on Windows.
    """
    parser = _build_arg_parser()
    args, extras = parser.parse_known_args()

    n_gpus = len(args.gpu.split(","))
    on_windows = sys.platform == "win32"
    if on_windows and n_gpus != 1:
        # Explicit check instead of `assert`: assertions vanish under
        # `python -O`, and failing here avoids building the dataset and the
        # system before rejecting the request.
        parser.error(
            "multiple GPUs are not supported on Windows; "
            "pass a single device id to --gpu"
        )

    # set CUDA_VISIBLE_DEVICES then import pytorch-lightning
    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu

    import datasets
    import systems
    import pytorch_lightning as pl
    from pytorch_lightning import Trainer
    from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor
    from pytorch_lightning.loggers import TensorBoardLogger, CSVLogger
    from utils.callbacks import (
        CodeSnapshotCallback,
        ConfigSnapshotCallback,
        CustomProgressBar,
    )
    from utils.misc import load_config

    # parse YAML config to OmegaConf
    config = load_config(args.config, cli_args=extras)
    config.cmd_args = vars(args)

    # Values already present in the config win; otherwise derive defaults.
    config.trial_name = config.get("trial_name") or (
        config.tag + datetime.now().strftime("@%Y%m%d-%H%M%S")
    )
    config.exp_dir = config.get("exp_dir") or os.path.join(args.exp_dir, config.name)
    # save_dir / ckpt_dir / code_dir / config_dir default to
    # <exp_dir>/<trial_name>/<save|ckpt|code|config>.
    for kind in ("save", "ckpt", "code", "config"):
        key = f"{kind}_dir"
        config[key] = config.get(key) or os.path.join(
            config.exp_dir, config.trial_name, kind
        )

    logger = logging.getLogger("pytorch_lightning")
    if args.verbose:
        logger.setLevel(logging.DEBUG)

    if "seed" not in config:
        # No seed configured: take one in [0, 1000) from the current time in ms.
        config.seed = int(time.time() * 1000) % 1000
    pl.seed_everything(config.seed)

    dm = datasets.make(config.dataset.name, config.dataset)
    # The checkpoint path is given to the system only in weights-only mode;
    # otherwise it is passed to the trainer calls below as `ckpt_path`.
    system = systems.make(
        config.system.name,
        config,
        load_from_checkpoint=args.resume if args.resume_weights_only else None,
    )

    # Callbacks and loggers are attached for training runs only.
    callbacks = []
    loggers = []
    if args.train:
        callbacks += [
            ModelCheckpoint(dirpath=config.ckpt_dir, **config.checkpoint),
            LearningRateMonitor(logging_interval="step"),
            # CodeSnapshotCallback(
            #     config.code_dir, use_version=False
            # ),
            ConfigSnapshotCallback(config, config.config_dir, use_version=False),
            CustomProgressBar(refresh_rate=1),
        ]
        loggers += [
            TensorBoardLogger(
                args.runs_dir, name=config.name, version=config.trial_name
            ),
            CSVLogger(config.exp_dir, name=config.trial_name, version="csv_logs"),
        ]

    # does not support multi-gpu on windows (validated right after parsing)
    strategy = "dp" if on_windows else "ddp_find_unused_parameters_false"

    trainer = Trainer(
        devices=n_gpus,
        accelerator="gpu",
        callbacks=callbacks,
        logger=loggers,
        strategy=strategy,
        **config.trainer,
    )

    if args.train:
        if args.resume and not args.resume_weights_only:
            # FIXME: different behavior in pytorch-lighting>1.9 ?
            trainer.fit(system, datamodule=dm, ckpt_path=args.resume)
        else:
            trainer.fit(system, datamodule=dm)
        # A training run is always followed by a test pass.
        trainer.test(system, datamodule=dm)
    elif args.validate:
        trainer.validate(system, datamodule=dm, ckpt_path=args.resume)
    elif args.test:
        trainer.test(system, datamodule=dm, ckpt_path=args.resume)
    elif args.predict:
        trainer.predict(system, datamodule=dm, ckpt_path=args.resume)
# Run the launcher only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()