| | import os |
| | import gc |
| | import sys |
| | import torch |
| | import wandb |
| | import torch.nn as nn |
| | import lightning.pytorch as pl |
| |
|
| | from omegaconf import OmegaConf |
| | from lightning.pytorch.strategies import DDPStrategy |
| | from lightning.pytorch.loggers import WandbLogger |
| | from lightning.pytorch.callbacks import ModelCheckpoint, LearningRateMonitor |
| |
|
| | from src.lm.memdlm.diffusion_module import MembraneDiffusion |
| | from src.lm.memdlm.dataloader import MembraneDataModule, get_datasets |
| | from src.utils.model_utils import apply_rdm_freezing |
| |
|
# Authenticate with Weights & Biases.
# SECURITY FIX: the API key was hard-coded in source (a leaked credential);
# it must come from the environment instead. With key=None, wandb.login()
# falls back to the WANDB_API_KEY env var / ~/.netrc, so behavior is
# unchanged for a correctly configured machine.
wandb.login(key=os.environ.get("WANDB_API_KEY"))

# Load the experiment configuration (wandb, training, checkpointing sections
# are read below). NOTE(review): absolute scratch path — consider making this
# a CLI argument.
config = OmegaConf.load("/scratch/pranamlab/sgoel/MeMDLM_v2/src/configs/lm.yaml")
| |
|
| | |
# Materialize the train/val/test splits and hand them to the Lightning
# data module, which owns batching and dataloader construction.
splits = get_datasets(config)
data_module = MembraneDataModule(
    config=config,
    train_dataset=splits["train"],
    val_dataset=splits["val"],
    test_dataset=splits["test"],
)
| |
|
| | |
# Start the W&B run explicitly, then build the Lightning logger from the
# same config section so both agree on run identity.
wandb.init(name=config.wandb.name, project=config.wandb.project)
wandb_logger = WandbLogger(**config.wandb)
| |
|
| | |
# Log the learning rate every optimizer step, and keep only the single best
# checkpoint (lowest validation loss), refreshed on a fixed step interval.
lr_monitor = LearningRateMonitor(logging_interval="step")
checkpoint_callback = ModelCheckpoint(
    dirpath=config.checkpointing.save_dir,
    filename="best_model",
    monitor="val/loss",
    mode="min",
    save_top_k=1,
    every_n_train_steps=config.checkpointing.save_every_n_steps,
)
| |
|
| | |
# Step-bounded training: max_epochs=None so max_steps is the sole stop
# criterion. DDP with find_unused_parameters=True because parts of the
# backbone are frozen (see apply_rdm_freezing below).
on_gpu = torch.cuda.is_available()
trainer = pl.Trainer(
    max_steps=config.training.max_steps,
    max_epochs=None,
    accelerator="cuda" if on_gpu else "cpu",
    # Use all configured devices for training; non-train modes pin to one device.
    devices=config.training.devices if config.training.mode == "train" else [0],
    strategy=DDPStrategy(find_unused_parameters=True),
    callbacks=[checkpoint_callback, lr_monitor],
    logger=wandb_logger,
    log_every_n_steps=config.training.log_every_n_steps,
)
| |
|
| |
|
| | |
# Ensure the checkpoint directory exists. exist_ok=True replaces the
# try/except FileExistsError dance — same outcome, race-free, idiomatic.
ckpt_path = config.checkpointing.save_dir
os.makedirs(ckpt_path, exist_ok=True)
| |
|
| | |
# Build the diffusion LightningModule from the loaded config and run its
# own config sanity checks before any training/eval work starts.
diffusion = MembraneDiffusion(config)
diffusion.validate_config()
| |
|
| | |
# Dispatch on the configured run mode. The backbone is partially frozen
# (RDM recipe) before any gradient-based training.
model_type = "evoflow"

if config.training.mode == "train":
    apply_rdm_freezing(diffusion.model, config.training.n_layers, model_type)
    trainer.fit(diffusion, datamodule=data_module)

elif config.training.mode == "test":
    # Evaluate the best checkpoint: load its weights into the module and
    # point the trainer at the same checkpoint file.
    state_dict = diffusion.get_state_dict(config.checkpointing.best_ckpt_path)
    diffusion.load_state_dict(state_dict)
    trainer.test(diffusion, datamodule=data_module, ckpt_path=config.checkpointing.best_ckpt_path)

elif config.training.mode == "resume_from_checkpoint":
    state_dict = diffusion.get_state_dict(config.training.resume_ckpt_path)
    diffusion.load_state_dict(state_dict)
    apply_rdm_freezing(diffusion.model, config.training.n_layers, model_type)
    # BUG FIX: resume from the configured checkpoint *file*. The original
    # passed ckpt_path (the save directory from config.checkpointing.save_dir),
    # but Trainer.fit's ckpt_path expects a checkpoint file path — the rest of
    # this branch already loads config.training.resume_ckpt_path.
    trainer.fit(diffusion, datamodule=data_module, ckpt_path=config.training.resume_ckpt_path)

# Close out the W&B run regardless of mode.
wandb.finish()
| |
|
| |
|