| | import argparse |
| | import math |
| | import os |
| | from functools import partial |
| | from collections import Counter |
| |
|
| | import torch |
| | import torch.nn as nn |
| | import torch.nn.functional as F |
| | from datasets import load_from_disk |
| | from torch.optim import AdamW |
| | from torch.optim.lr_scheduler import LambdaLR |
| | from torch.utils.data import DataLoader |
| | import pytorch_lightning as pl |
| | from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor |
| | from pytorch_lightning.loggers import WandbLogger |
| | from pytorch_lightning.strategies import DDPStrategy |
| | from rdkit import Chem |
| |
|
| | from smiles_tokenizer.my_tokenizers import SMILES_SPE_Tokenizer |
| | from peptide_analyzer import PeptideAnalyzer |
| | import dataloading_for_dynamic_batching as dynamic_dataloader |
| |
|
| |
|
class RotaryPositionalEmbedding(nn.Module):
    """Build the cosine/sine tables used by rotary position embeddings (RoPE).

    Args:
        dim: Size of the feature axis the rotation is applied to.
        max_position_embeddings: Stored on the module; ``forward`` does not
            consult it.
        base: Base of the geometric progression of inverse frequencies.
    """

    def __init__(self, dim, max_position_embeddings=2048, base=10000):
        super().__init__()
        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base

        # One inverse frequency per pair of feature channels.
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq)

    def forward(self, x, seq_len=None):
        """Return ``(cos, sin)`` tables of shape ``(1, seq_len, dim)``.

        Args:
            x: Tensor used for its device and, when ``seq_len`` is omitted,
                for its sequence length.
            seq_len: Number of positions to generate. Defaults to the
                second-to-last axis of ``x``.
        """
        if seq_len is None:
            # Second-to-last axis is the sequence axis for both
            # (batch, seq, dim) and (batch, heads, seq, dim) layouts;
            # ``shape[1]`` would select the head axis for the latter.
            seq_len = x.shape[-2]

        positions = torch.arange(seq_len, device=x.device, dtype=self.inv_freq.dtype)
        freqs = torch.einsum("i,j->ij", positions, self.inv_freq)
        # Duplicate the angles so both halves of the feature axis share them,
        # matching the half-split layout used by ``rotate_half``.
        emb = torch.cat((freqs, freqs), dim=-1)

        cos_emb = emb.cos()[None, :, :]
        sin_emb = emb.sin()[None, :, :]

        return cos_emb, sin_emb
| |
|
def rotate_half(x):
    """Swap the two halves of the last axis, negating the one moved to the front.

    The last axis is split at ``x.shape[-1] // 2``; the result is
    ``concat(-second_part, first_part)`` along that axis.
    """
    split_at = x.shape[-1] // 2
    front = x[..., :split_at]
    back = x[..., split_at:]
    return torch.cat([back.neg(), front], dim=-1)
| |
|
def apply_rotary_pos_emb(q, k, cos, sin):
    """Apply the rotary rotation described by ``cos``/``sin`` to ``q`` and ``k``.

    Each tensor ``t`` is mapped to ``t * cos + rotate_half(t) * sin``.

    Returns:
        Tuple ``(q_embed, k_embed)`` of the rotated query and key tensors.
    """
    q_embed, k_embed = (
        (tensor * cos) + (rotate_half(tensor) * sin) for tensor in (q, k)
    )
    return q_embed, k_embed
| |
|
| | |
def modulate(x, shift, scale):
    """Apply an affine modulation ``x * (1 + scale) + shift``.

    ``shift`` and ``scale`` each get a new axis inserted at position 1 so they
    broadcast across axis 1 of ``x``.
    """
    gain = 1 + scale.unsqueeze(1)
    offset = shift.unsqueeze(1)
    return x * gain + offset
| |
|
class TimestepEmbedder(nn.Module):
    """Embed a scalar timestep into a ``hidden_size``-dimensional vector.

    The embedding is a two-layer MLP (Linear -> SiLU -> Linear) applied to the
    timestep after a trailing feature axis of size 1 is added.
    """

    def __init__(self, hidden_size):
        super().__init__()
        layers = [
            nn.Linear(1, hidden_size, bias=True),
            nn.SiLU(),
            nn.Linear(hidden_size, hidden_size, bias=True),
        ]
        self.mlp = nn.Sequential(*layers)

    def forward(self, t):
        """Return the embedding of ``t``; output gains a last axis of ``hidden_size``."""
        t_column = t[..., None]
        return self.mlp(t_column)
| |
|
class MultiHeadAttentionWithRoPE(nn.Module):
    """Multi-head self-attention with rotary position embeddings on q and k.

    No attention mask is applied: every position attends to every other
    position in the sequence.
    """

    def __init__(self, hidden_size, n_heads):
        super().__init__()
        self.hidden_size = hidden_size
        self.n_heads = n_heads
        self.head_dim = hidden_size // n_heads

        assert self.head_dim * n_heads == hidden_size, "hidden_size must be divisible by n_heads"

        # Projections are created in q, k, v, out order; the q/k/v projections
        # carry no bias while the output projection does.
        self.q_proj = nn.Linear(hidden_size, hidden_size, bias=False)
        self.k_proj = nn.Linear(hidden_size, hidden_size, bias=False)
        self.v_proj = nn.Linear(hidden_size, hidden_size, bias=False)
        self.out_proj = nn.Linear(hidden_size, hidden_size)

        self.rope = RotaryPositionalEmbedding(self.head_dim)

    def forward(self, x):
        """Attend over ``x`` of shape ``(batch, seq, hidden)`` and return the same shape."""
        batch_size, seq_len, hidden_size = x.shape
        per_head_shape = (batch_size, seq_len, self.n_heads, self.head_dim)

        def to_heads(projected):
            # (batch, seq, hidden) -> (batch, heads, seq, head_dim)
            return projected.view(*per_head_shape).transpose(1, 2)

        q = to_heads(self.q_proj(x))
        k = to_heads(self.k_proj(x))
        v = to_heads(self.v_proj(x))

        cos, sin = self.rope(q, seq_len)
        q, k = apply_rotary_pos_emb(q, k, cos, sin)

        # Scaled dot-product attention.
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)
        attn_weights = F.softmax(scores, dim=-1)
        context = torch.matmul(attn_weights, v)

        # (batch, heads, seq, head_dim) -> (batch, seq, hidden)
        merged = context.transpose(1, 2).contiguous().view(batch_size, seq_len, hidden_size)
        return self.out_proj(merged)
| |
|
class DiTBlock(nn.Module):
    """Transformer block whose norms and residual branches are modulated by a conditioning vector.

    A SiLU + Linear head maps the conditioning vector ``c`` to six
    ``hidden_size`` chunks: shift/scale/gate for the attention branch and
    shift/scale/gate for the MLP branch.
    """

    def __init__(self, hidden_size, n_heads):
        super().__init__()
        # The LayerNorms carry no learnable affine; shift and scale come from
        # the conditioning head instead.
        self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.attn = MultiHeadAttentionWithRoPE(hidden_size, n_heads)
        self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)

        feedforward_layers = [
            nn.Linear(hidden_size, 4 * hidden_size),
            nn.GELU(),
            nn.Linear(4 * hidden_size, hidden_size),
        ]
        self.mlp = nn.Sequential(*feedforward_layers)

        conditioning_layers = [
            nn.SiLU(),
            nn.Linear(hidden_size, 6 * hidden_size, bias=True),
        ]
        self.adaLN_modulation = nn.Sequential(*conditioning_layers)

    def forward(self, x, c):
        """Run the block on ``x`` conditioned on ``c`` and return the updated ``x``."""
        params = self.adaLN_modulation(c).chunk(6, dim=1)
        shift_msa, scale_msa, gate_msa = params[0], params[1], params[2]
        shift_mlp, scale_mlp, gate_mlp = params[3], params[4], params[5]

        # Attention branch: modulated norm -> attention -> gated residual.
        attn_in = modulate(self.norm1(x), shift_msa, scale_msa)
        x = x + gate_msa.unsqueeze(1) * self.attn(attn_in)

        # Feed-forward branch: modulated norm -> MLP -> gated residual.
        mlp_in = modulate(self.norm2(x), shift_mlp, scale_mlp)
        x = x + gate_mlp.unsqueeze(1) * self.mlp(mlp_in)
        return x
| |
|
class MDLM(nn.Module):
    """Token-level denoiser: token embedding -> time-conditioned DiT blocks -> vocabulary logits.

    Args:
        vocab_size: Number of token ids; sizes the embedding table and LM head.
        model_dim: Width of the token and time embeddings.
        n_heads: Attention heads per block.
        n_layers: Number of ``DiTBlock`` layers.
    """

    def __init__(self, vocab_size, model_dim, n_heads, n_layers):
        super().__init__()
        self.vocab_size = vocab_size
        self.model_dim = model_dim
        # NOTE(review): this id equals ``vocab_size``, which is one past the
        # last row of ``token_embedder``; nothing in this class embeds it.
        self.mask_token_id = vocab_size

        self.token_embedder = nn.Embedding(vocab_size, model_dim)
        self.time_embedder = TimestepEmbedder(model_dim)

        blocks = []
        for _ in range(n_layers):
            blocks.append(DiTBlock(model_dim, n_heads))
        self.transformer_blocks = nn.ModuleList(blocks)

        self.final_norm = nn.LayerNorm(model_dim)
        self.lm_head = nn.Linear(model_dim, vocab_size)

        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Initialise one submodule; invoked for every submodule via ``self.apply``.

        Linear and Embedding weights are drawn from N(0, 0.02); Linear biases
        are zeroed. LayerNorm gets weight 1 and bias 0 when those parameters
        exist (they are ``None`` for non-affine norms).
        """
        if isinstance(module, nn.LayerNorm):
            if module.bias is not None:
                module.bias.data.zero_()
            if module.weight is not None:
                module.weight.data.fill_(1.0)
            return

        if not isinstance(module, (nn.Linear, nn.Embedding)):
            return

        module.weight.data.normal_(mean=0.0, std=0.02)
        has_bias = isinstance(module, nn.Linear) and module.bias is not None
        if has_bias:
            module.bias.data.zero_()

    def forward(self, x, t):
        """Return logits over the vocabulary for token ids ``x`` at timestep ``t``."""
        hidden = self.token_embedder(x)
        conditioning = self.time_embedder(t)
        for layer in self.transformer_blocks:
            hidden = layer(hidden, conditioning)
        return self.lm_head(self.final_norm(hidden))
| |
|
| | |
class MDLMLightningModule(pl.LightningModule):
    """Lightning wrapper that trains ``MDLM`` with cross-entropy plus a staged validity penalty.

    Each step draws a per-sample time ``t``, builds an input that mixes
    ``source_ids`` and ``target_ids`` token-wise (target tokens are kept with
    probability ``t``, or ``t ** gamma`` where ``bond_mask`` is set), and
    trains the model to predict ``target_ids``. From ``validity_start_epoch``
    onwards a penalty for decoded sequences rejected by ``PeptideAnalyzer``
    is added with a scheduled weight.

    Args:
        args: Parsed command-line namespace holding model and training settings.
        tokenizer: Tokenizer exposing ``vocab_size`` and ``batch_decode``.
    """

    def __init__(self, args, tokenizer):
        super().__init__()
        self.save_hyperparameters(ignore=['tokenizer'])
        self.args = args
        self.tokenizer = tokenizer
        self.peptide_analyzer = PeptideAnalyzer()

        self.model = MDLM(
            vocab_size=tokenizer.vocab_size,
            model_dim=args.model_dim,
            n_heads=args.n_heads,
            n_layers=args.n_layers
        )

        self.automatic_optimization = True
        self.validation_step_outputs = []

        # Set by ``training_step`` to the batch size of a step whose gradient
        # norm should be logged; consumed by ``on_after_backward``.
        self._grad_norm_batch_size = None

        # Kept registered so the state-dict layout stays unchanged.
        self.register_buffer('epoch_progress', torch.tensor(0.0))

    def forward(self, x, t):
        """Return model logits for token ids ``x`` at timestep ``t``."""
        return self.model(x, t)

    def _compute_invalid_loss(self, logits, t_continuous=None):
        """
        Original invalid loss computation from PepTune
        with optional time-dependent weighting.

        The argmax sequence of each sample is decoded; samples that
        ``PeptideAnalyzer.is_peptide`` rejects get penalty 1, others 0. The
        penalty is multiplied by the softmax probability of each argmax token,
        which is what carries gradient back to the logits.

        Returns:
            Tensor with one penalty value per token position.
        """
        batch_token_ids = torch.argmax(logits, dim=-1)
        sampled_sequences = self.tokenizer.batch_decode(batch_token_ids)

        penalties = torch.tensor(
            [1.0 if not self.peptide_analyzer.is_peptide(seq) else 0.0 for seq in sampled_sequences],
            dtype=torch.float32,
            device=self.device
        )

        if t_continuous is not None and self.args.time_dependent_validity:
            # Scale each sample's penalty by t ** validity_time_power.
            time_weight = t_continuous ** self.args.validity_time_power
            penalties = penalties * time_weight

        sampled_probs = torch.softmax(logits, dim=-1).gather(
            dim=-1, index=batch_token_ids.unsqueeze(-1)
        ).squeeze(-1).to(self.device)

        return penalties[:, None] * sampled_probs

    def get_validity_weight(self):
        """
        Compute annealed validity weight based on training progress.

        Returns ``0.0`` before ``validity_start_epoch``. Afterwards the weight
        moves from ``validity_weight_min`` towards ``validity_weight_max``
        according to ``validity_schedule``; any schedule other than linear,
        exponential, cosine or step (i.e. ``'constant'``) yields
        ``validity_weight_max``.

        Raises:
            ValueError: For the exponential schedule when
                ``validity_weight_min`` is not positive or
                ``validity_weight_max`` is negative.
        """
        args = self.args
        current_epoch = self.current_epoch

        if current_epoch < args.validity_start_epoch:
            return 0.0

        w_min = args.validity_weight_min
        w_max = args.validity_weight_max
        schedule = args.validity_schedule

        epochs_with_validity = current_epoch - args.validity_start_epoch
        max_epochs_with_validity = args.epochs - args.validity_start_epoch
        # A non-positive span means there is nothing to anneal over; treat the
        # schedule as finished instead of dividing by zero.
        if max_epochs_with_validity > 0:
            progress = epochs_with_validity / max_epochs_with_validity
        else:
            progress = 1.0

        if schedule == 'linear':
            return w_min + (w_max - w_min) * progress

        if schedule == 'exponential':
            if w_min <= 0 or w_max < 0:
                raise ValueError(
                    "exponential validity schedule requires validity_weight_min > 0 "
                    "and validity_weight_max >= 0"
                )
            return w_min * (w_max / w_min) ** progress

        if schedule == 'cosine':
            cosine_factor = 0.5 * (1 - math.cos(math.pi * progress))
            return w_min + (w_max - w_min) * cosine_factor

        if schedule == 'step':
            thresholds = [0.25, 0.5, 0.75, 1.0]
            step_weights = [w_min, w_min * 2, w_min * 5, w_max]
            for threshold, step_weight in zip(thresholds, step_weights):
                if progress <= threshold:
                    return step_weight
            # Progress beyond the last threshold: stay at the final weight.
            return w_max

        # 'constant' (and any unrecognised schedule): fixed maximum weight.
        return w_max

    def _loss(self, logits, x_1, attn_mask, t_continuous=None):
        """
        Combined loss with staged validity loss.

        Returns:
            Tuple ``(token_nll, ce_token_loss, invalid_token_loss,
            validity_weight)`` where the first three are means over the
            positions selected by ``attn_mask``.
        """
        ce_loss = F.cross_entropy(
            logits.reshape(-1, self.model.vocab_size),
            x_1.reshape(-1),
            reduction='none'
        ).view(x_1.shape[0], -1)

        validity_weight = self.get_validity_weight()

        # Decoding and analysing sequences is skipped while the weight is 0.
        if validity_weight > 0:
            invalid_loss = self._compute_invalid_loss(logits, t_continuous)
        else:
            invalid_loss = torch.zeros_like(ce_loss)

        total_loss = ce_loss + validity_weight * invalid_loss

        # Clamp so an all-zero mask yields 0 rather than NaN.
        num_tokens = attn_mask.sum().clamp(min=1)
        token_nll = (total_loss * attn_mask).sum() / num_tokens
        ce_token_loss = (ce_loss * attn_mask).sum() / num_tokens
        invalid_token_loss = (invalid_loss * attn_mask).sum() / num_tokens

        return token_nll, ce_token_loss, invalid_token_loss, validity_weight

    def _unpack_batch(self, batch):
        """Move the batch tensors to the module device and build the all-ones loss mask."""
        x_0 = batch['source_ids'].to(self.device)
        x_1 = batch['target_ids'].to(self.device)
        attn_mask = torch.ones_like(x_1).to(self.device)
        bond_mask = batch['bond_mask'].to(self.device).bool()
        return x_0, x_1, attn_mask, bond_mask

    def _mix_source_and_target(self, x_0, x_1, bond_mask, t_continuous):
        """Return ``x_t``: target tokens kept with prob ``t`` (``t ** gamma`` on bond positions), else source tokens."""
        t_column = t_continuous.view(-1, 1)
        peptide_bond_prob = t_column ** self.args.gamma
        masking_prob = torch.where(bond_mask, peptide_bond_prob, t_column)
        keep_target = torch.rand(x_1.shape, device=self.device) < masking_prob
        return torch.where(keep_target, x_1, x_0)

    def training_step(self, batch, batch_idx):
        """Run one training step and return the combined per-token loss."""
        x_0, x_1, attn_mask, bond_mask = self._unpack_batch(batch)
        batch_size, _ = x_1.shape

        t_continuous = torch.rand(batch_size, device=self.device)
        x_t = self._mix_source_and_target(x_0, x_1, bond_mask, t_continuous)

        logits = self.model(x_t, t_continuous)

        token_nll, ce_loss, invalid_loss, validity_weight = self._loss(
            logits, x_1, attn_mask, t_continuous
        )

        # Detached tensors are logged directly; ``.item()`` would force a
        # device-to-host sync for every metric on every step.
        self.log('train/token_nll', token_nll.detach(), on_step=True, on_epoch=True, prog_bar=True, batch_size=batch_size, sync_dist=True)
        self.log('train/ce_loss', ce_loss.detach(), on_step=True, on_epoch=True, batch_size=batch_size, sync_dist=True)
        self.log('train/invalid_loss', invalid_loss.detach(), on_step=True, on_epoch=True, batch_size=batch_size, sync_dist=True)
        self.log('train/validity_weight', validity_weight, on_step=False, on_epoch=True, batch_size=batch_size, sync_dist=True)

        # The backward pass for this batch has not run yet, so gradients read
        # here would not belong to it. Flag the step and measure the norm in
        # ``on_after_backward`` instead.
        self._grad_norm_batch_size = batch_size if batch_idx % 1000 == 0 else None

        return token_nll

    def on_after_backward(self):
        """Log the global L2 gradient norm for steps flagged by ``training_step``."""
        batch_size = self._grad_norm_batch_size
        if batch_size is None:
            return
        self._grad_norm_batch_size = None

        per_param_norms = [
            p.grad.detach().norm(2)
            for p in self.model.parameters()
            if p.grad is not None
        ]
        if not per_param_norms:
            return

        total_norm = torch.stack(per_param_norms).norm(2)
        self.log('train/grad_norm', total_norm, batch_size=batch_size, sync_dist=True)

    def validation_step(self, batch, batch_idx):
        """Log validation losses; on the first batch also log validity rates at fixed ``t``."""
        x_0, x_1, attn_mask, bond_mask = self._unpack_batch(batch)
        batch_size, _ = x_1.shape

        t_continuous = torch.rand(batch_size, device=self.device)
        x_t = self._mix_source_and_target(x_0, x_1, bond_mask, t_continuous)

        logits = self.model(x_t, t_continuous)

        token_nll, ce_loss, invalid_loss, validity_weight = self._loss(
            logits, x_1, attn_mask, t_continuous
        )

        self.log('val/token_nll', token_nll.detach(), on_step=True, on_epoch=True, prog_bar=True, batch_size=batch_size, sync_dist=True)
        self.log('val/ce_loss', ce_loss.detach(), on_step=True, on_epoch=True, batch_size=batch_size, sync_dist=True)
        self.log('val/invalid_loss', invalid_loss.detach(), on_step=True, on_epoch=True, batch_size=batch_size, sync_dist=True)

        if batch_idx == 0:
            with torch.no_grad():
                # Probe validity at fixed times; this probe keeps target
                # tokens with plain probability ``t_val`` (no gamma exponent).
                for t_val in [0.9, 0.5, 0.1]:
                    t_test = torch.full((batch_size,), t_val, device=self.device)
                    test_mask = torch.rand(x_1.shape, device=self.device) < t_val
                    x_test = torch.where(test_mask, x_1, x_0)

                    test_logits = self.model(x_test, t_test)
                    test_preds = torch.argmax(test_logits, dim=-1)

                    sequences = self.tokenizer.batch_decode(test_preds)
                    valid_count = sum(1 for seq in sequences if self.peptide_analyzer.is_peptide(seq))
                    validity_rate = valid_count / len(sequences)

                    self.log(f'val/validity_rate_t{t_val}', validity_rate, batch_size=batch_size, sync_dist=True)

    def configure_optimizers(self):
        """Create AdamW with a per-step schedule: 10% linear warmup, then cosine decay to 10% of the base LR."""
        optimizer = AdamW(
            self.parameters(),
            lr=self.args.learning_rate,
            weight_decay=self.args.weight_decay
        )

        if hasattr(self.trainer, 'estimated_stepping_batches'):
            num_training_steps = self.trainer.estimated_stepping_batches
        else:
            num_training_steps = len(self.trainer.datamodule.train_dataloader()) * self.trainer.max_epochs

        warmup_steps = int(num_training_steps * 0.1)
        # Never let the decay span reach zero (e.g. a zero-step estimate).
        decay_steps = max(1, num_training_steps - warmup_steps)
        min_lr_ratio = 0.1

        def lr_lambda(current_step):
            if current_step < warmup_steps:
                return current_step / warmup_steps
            # Clamp so the LR stays at the floor if training runs past the
            # estimated number of steps instead of the cosine rising again.
            progress = min(1.0, (current_step - warmup_steps) / decay_steps)
            cosine_decay = 0.5 * (1.0 + math.cos(math.pi * progress))
            return min_lr_ratio + (1 - min_lr_ratio) * cosine_decay

        scheduler = LambdaLR(optimizer, lr_lambda)

        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": scheduler,
                "interval": "step",
                "frequency": 1,
            },
        }
| |
|
def main(args):
    """Resume training of ``MDLMLightningModule`` from ``args.checkpoint``.

    Creates a run-specific checkpoint directory, loads the tokenizer and data
    module, restores the model from the given checkpoint, and runs
    ``Trainer.fit`` with Weights & Biases logging and checkpoint callbacks.

    Args:
        args: Parsed command-line namespace (see the argument parser at the
            bottom of this file).
    """
    # Join with a path separator so the run directory is created *inside*
    # ``args.checkpoint_dir`` rather than as a sibling whose name is the two
    # strings glued together.
    run_name = (f"new_lr{args.learning_rate}_layer{args.n_layers}_"
                f"head{args.n_heads}_{args.validity_schedule}")
    checkpoint_dir = os.path.join(args.checkpoint_dir, run_name)
    print(f"Saving to {checkpoint_dir}")
    os.makedirs(checkpoint_dir, exist_ok=True)

    print("Loading tokenizer...")
    tokenizer = SMILES_SPE_Tokenizer('/scratch/pranamlab/tong/ReDi_discrete/smiles/smiles_tokenizer/new_vocab.txt',
                                     '/scratch/pranamlab/tong/ReDi_discrete/smiles/smiles_tokenizer/new_splits.txt')
    print(f"Tokenizer loaded. Vocab size: {tokenizer.vocab_size}")

    data_module = dynamic_dataloader.RectifyDataModule('/scratch/pranamlab/tong/data/smiles/v1')

    # The model is restored directly from the checkpoint; constructing a
    # fresh instance first would only allocate a second copy to discard.
    model = MDLMLightningModule.load_from_checkpoint(
        checkpoint_path=args.checkpoint,
        args=args,
        tokenizer=tokenizer
    )

    logger = WandbLogger(
        project="smiles-redi-staged-training",
        entity="programmablebio",
        name=f"v1_lr{args.learning_rate}_epochs{args.validity_start_epoch}_{args.validity_schedule}",
        save_dir=checkpoint_dir
    )

    callbacks = [
        # Best-by-validation checkpoint, plus the most recent one.
        ModelCheckpoint(
            dirpath=checkpoint_dir,
            filename='best',
            monitor='val/token_nll',
            mode='min',
            save_top_k=1,
            save_last=True,
        ),
        # One checkpoint per training epoch, all retained.
        ModelCheckpoint(
            dirpath=checkpoint_dir,
            filename='{epoch:02d}',
            save_top_k=-1,
            every_n_epochs=1,
            save_on_train_epoch_end=True
        ),
        LearningRateMonitor(logging_interval='step')
    ]

    trainer = pl.Trainer(
        max_epochs=args.epochs,
        devices=torch.cuda.device_count(),
        accelerator='gpu',
        strategy=DDPStrategy(find_unused_parameters=False),
        num_nodes=int(os.environ.get("SLURM_NNODES", 1)),
        precision="bf16",
        gradient_clip_val=args.grad_clip if args.grad_clip > 0 else None,
        callbacks=callbacks,
        logger=logger,
        log_every_n_steps=100,
        # NOTE(review): with ``None`` Lightning schedules validation by batch
        # count and documents that ``val_check_interval`` should then be an
        # integer — confirm this behaves as intended on the installed version.
        check_val_every_n_epoch=None,
        accumulate_grad_batches=1,
        enable_progress_bar=True,
        enable_model_summary=True
    )

    print(f"Model initialized with {sum(p.numel() for p in model.parameters()):,} parameters.")
    print(f"Training strategy: CE-only for {args.validity_start_epoch} epochs, then staged validity loss")
    print("Starting training...")

    trainer.fit(model, data_module)

    print("Training complete.")
    print(f"Best checkpoint saved at: {trainer.checkpoint_callback.best_model_path}")
| |
|
def str2bool(value):
    """Parse a command-line boolean such as ``True``/``false``/``1``/``no``.

    ``argparse`` with ``type=bool`` treats every non-empty string (including
    ``"False"``) as ``True``; this converter interprets the text instead.

    Raises:
        argparse.ArgumentTypeError: If ``value`` is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in ("true", "t", "yes", "y", "1"):
        return True
    if text in ("false", "f", "no", "n", "0"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Train ReDi model with staged validity loss")

    # Model architecture
    parser.add_argument("--model_dim", type=int, default=1024)
    parser.add_argument("--n_heads", type=int, default=8)
    parser.add_argument("--n_layers", type=int, default=6)

    # Optimisation
    parser.add_argument("--epochs", type=int, default=5)
    parser.add_argument("--learning_rate", type=float, default=1e-4)
    parser.add_argument("--weight_decay", type=float, default=1e-5)
    parser.add_argument("--label_smoothing", type=float, default=0)
    parser.add_argument("--grad_clip", type=float, default=1.0)
    parser.add_argument("--gamma", type=float, default=2.0)

    # Staged validity loss
    parser.add_argument("--validity_start_epoch", type=int, default=2, help="Epoch to start adding validity loss (0-indexed)")
    parser.add_argument("--validity_weight_min", type=float, default=10.0, help="Initial validity weight when starting")
    parser.add_argument("--validity_weight_max", type=float, default=200.0, help="Maximum validity weight")
    parser.add_argument("--validity_schedule", type=str, default="linear", choices=['linear', 'exponential', 'cosine', 'step', 'constant'], help="Schedule for increasing validity weight")
    # Parsed with ``str2bool`` so that ``--time_dependent_validity False``
    # really disables the option.
    parser.add_argument("--time_dependent_validity", type=str2bool, default=False, help="Whether to apply time-dependent scaling to validity loss")
    parser.add_argument("--validity_time_power", type=float, default=0.5, help="Power for time-dependent validity scaling")

    # Checkpointing
    parser.add_argument("--checkpoint_dir", type=str, default="./checkpoints_smiles")
    parser.add_argument("--checkpoint", type=str, required=True)

    args = parser.parse_args()
    main(args)