| | |
| |
|
| | import os |
| | import wandb |
| | import lightning.pytorch as pl |
| |
|
| | from omegaconf import OmegaConf |
| | from lightning.pytorch.strategies import DDPStrategy |
| | from lightning.pytorch.loggers import WandbLogger |
| | from lightning.pytorch.callbacks import ModelCheckpoint, LearningRateMonitor |
| |
|
| | from src.utils.model_utils import _print |
| | from src.guidance.solubility_module import SolubilityClassifier |
| | from src.guidance.dataloader import MembraneDataModule, get_datasets |
| |
|
| |
|
# Load the run configuration; the wandb, checkpointing and training sections
# of this file are read further down in the script.
config = OmegaConf.load("/scratch/sgoel/MeMDLM_v2/src/configs/guidance.yaml")

# Authenticate with Weights & Biases. The API key is a secret and must never
# be committed to source control, so it is taken from the WANDB_API_KEY
# environment variable. When the variable is unset, `key=None` lets wandb fall
# back to its own credential lookup (e.g. a previous `wandb login`).
wandb.login(key=os.environ.get("WANDB_API_KEY"))
| |
|
| | |
# Build the datasets from the config, then hand the train/val/test splits to
# the data module under the keyword names it expects.
datasets = get_datasets(config)
_split_kwargs = {
    f"{split}_dataset": datasets[split] for split in ("train", "val", "test")
}
data_module = MembraneDataModule(config=config, **_split_kwargs)
| |
|
| | |
| | |
# Experiment logger; every option comes from the `wandb` section of the config.
wandb_logger = WandbLogger(**config.wandb)

# Callbacks: record the learning rate at every step, and keep only the single
# checkpoint with the lowest "val/loss", saved as "best_model" in the
# configured checkpoint directory.
lr_monitor = LearningRateMonitor(logging_interval="step")
checkpoint_callback = ModelCheckpoint(
    dirpath=config.checkpointing.save_dir,
    filename="best_model",
    monitor="val/loss",
    mode="min",
    save_top_k=1,
)
| |
|
| | |
# Single-device CUDA trainer. The step budget and logging cadence come from
# the `training` section of the config.
_training_cfg = config.training
trainer = pl.Trainer(
    accelerator="cuda",
    devices=1,
    max_steps=_training_cfg.max_steps,
    log_every_n_steps=_training_cfg.log_every_n_steps,
    logger=wandb_logger,
    callbacks=[checkpoint_callback, lr_monitor],
)
| |
|
| | |
# Make sure the checkpoint directory exists before it is written to or read.
ckpt_dir = config.checkpointing.save_dir
os.makedirs(ckpt_dir, exist_ok=True)

# Instantiate the classifier from the same config.
model = SolubilityClassifier(config)

# Dispatch on the configured run mode; anything other than "train" or "test"
# is rejected.
if config.training.mode not in ("train", "test"):
    raise ValueError(f"{config.training.mode} is invalid. Must be 'train' or 'test'")

if config.training.mode == "train":
    trainer.fit(model, datamodule=data_module)
else:
    # Test mode: restore the weights saved as "best_model.ckpt", then evaluate.
    # NOTE(review): the checkpoint is applied via the model's own
    # `get_state_dict` helper and is also passed to `trainer.test` as
    # `ckpt_path` -- confirm that both loads are intended.
    ckpt_path = os.path.join(ckpt_dir, "best_model.ckpt")
    model.load_state_dict(model.get_state_dict(ckpt_path))
    trainer.test(model, datamodule=data_module, ckpt_path=ckpt_path)

# Close the Weights & Biases run.
wandb.finish()
| |
|