| | import argparse |
| | from pathlib import Path |
| |
|
| | import torch |
| | import torch.nn.functional as F |
| | from tqdm import tqdm |
| |
|
| | |
| | from smiles_train import MDLMLightningModule, PeptideAnalyzer |
| | from smiles_tokenizer.my_tokenizers import SMILES_SPE_Tokenizer |
| |
|
| | import pdb |
| |
|
| |
|
def generate_smiles(model, tokenizer, args):
    """
    Generate peptide SMILES strings with the trained MDLM model using a
    forward (t=0 to t=1) flow matching process.

    Sampling starts from uniformly random token ids. At each of the
    ``args.n_steps`` steps the model is evaluated at the current time, the
    most likely token is taken at every position, and each position is then
    independently replaced by a fresh uniformly random token with probability
    ``1 - t_next``. The last step keeps the model prediction unchanged.

    Args:
        model (MDLMLightningModule): The trained PyTorch Lightning model.
            It is called as ``model(x, t)`` and must expose
            ``model.model.vocab_size``.
        tokenizer (SMILES_SPE_Tokenizer): The tokenizer used for training;
            only ``batch_decode`` is used here.
        args (argparse.Namespace): Sampling parameters. The attributes read
            are ``device``, ``n_samples``, ``seq_len``, ``n_steps`` and
            ``temperature``.

    Returns:
        tuple[list[str], float]: The decoded sequences accepted by
        ``PeptideAnalyzer.is_peptide`` (invalid ones are dropped), and the
        fraction of all decoded sequences that were accepted. The fraction
        is ``0.0`` when no sequences were decoded.
    """
    print("Starting SMILES generation with forward flow matching (t=0 to t=1)...")
    model.eval()
    device = args.device
    vocab_size = model.model.vocab_size

    # Initial state at t=0: every position holds a uniformly random token id.
    x = torch.randint(
        0,
        vocab_size,
        (args.n_samples, args.seq_len),
        device=device,
    )

    # n_steps + 1 evenly spaced time points give n_steps (t_curr, t_next) pairs.
    time_steps = torch.linspace(0.0, 1.0, args.n_steps + 1, device=device)

    with torch.no_grad():
        for i in tqdm(range(args.n_steps), desc="Flow Matching Steps"):
            t_curr = time_steps[i]
            t_next = time_steps[i + 1]

            # One time value per sample in the batch.
            t_tensor = torch.full((args.n_samples,), t_curr, device=device)

            # Presumably per-token logits with the vocabulary on the last
            # axis -- TODO confirm against the model's forward().
            logits = model(x, t_tensor)
            # NOTE(review): dividing by a positive temperature does not change
            # the argmax below, so --temperature currently has no effect on
            # the chosen tokens. Kept as-is to preserve existing behavior.
            logits = logits / args.temperature

            pred_x1 = torch.argmax(logits, dim=-1)

            # Final step: keep the clean prediction, no re-noising.
            if i == args.n_steps - 1:
                x = pred_x1
                break

            # Re-noise: each position is swapped for a random token with
            # probability (1 - t_next), which shrinks to 0 as t approaches 1.
            noise_prob = 1.0 - t_next
            mask = torch.rand(x.shape, device=device) < noise_prob
            noise = torch.randint(
                0,
                vocab_size,
                x.shape,
                device=device,
            )
            x = torch.where(mask, noise, pred_x1)

    generated_sequences = tokenizer.batch_decode(x)

    # Keep only the sequences the analyzer recognises as peptides.
    peptide_analyzer = PeptideAnalyzer()
    valid_smiles = [
        seq for seq in generated_sequences if peptide_analyzer.is_peptide(seq)
    ]

    # Guard against an empty batch (e.g. n_samples == 0) instead of raising
    # ZeroDivisionError.
    total = len(generated_sequences)
    validity_rate = len(valid_smiles) / total if total else 0.0

    print(f"\nGeneration complete. Validity rate: {validity_rate:.2%}")
    return valid_smiles, validity_rate
| |
|
| |
|
def main():
    """
    Command-line entry point: load a checkpoint, sample SMILES, save results.

    Parses the sampling options, builds the tokenizer, restores the
    ``MDLMLightningModule`` from ``--checkpoint_path``, runs
    ``generate_smiles`` and appends the valid SMILES (one per line) to
    ``--output_file``. The validity rate is printed at the end.
    """
    parser = argparse.ArgumentParser(description="Sample from a trained ReDi model.")

    # Model checkpoint.
    parser.add_argument("--checkpoint_path", type=str, required=True, help="Path to the model checkpoint (.ckpt file).")

    # Sampling parameters.
    parser.add_argument("--n_samples", type=int, default=16, help="Number of SMILES strings to generate.")
    parser.add_argument("--seq_len", type=int, default=256, help="Maximum sequence length for generated SMILES.")
    parser.add_argument("--n_steps", type=int, default=100, help="Number of denoising steps for sampling.")
    parser.add_argument("--temperature", type=float, default=1.0, help="Sampling temperature. Higher values increase diversity.")

    # Tokenizer files and output location.
    parser.add_argument("--vocab_path", type=str, default='/scratch/pranamlab/tong/ReDi_discrete/smiles/smiles_tokenizer/new_vocab.txt', help="Path to tokenizer vocabulary file.")
    parser.add_argument("--splits_path", type=str, default='/scratch/pranamlab/tong/ReDi_discrete/smiles/smiles_tokenizer/new_splits.txt', help="Path to tokenizer splits file.")
    parser.add_argument("--output_file", type=str, default="generated_smiles.txt", help="File to save the valid generated SMILES.")

    args = parser.parse_args()

    # generate_smiles reads the device from the namespace.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    args.device = device
    print(f"Using device: {device}")

    print("Loading tokenizer...")
    tokenizer = SMILES_SPE_Tokenizer(args.vocab_path, args.splits_path)

    print(f"Loading model from checkpoint: {args.checkpoint_path}")

    # The raw checkpoint is read first only to recover the training-time
    # hyper-parameters that the module constructor needs.
    # SECURITY: weights_only=False unpickles arbitrary Python objects; only
    # load checkpoints from a trusted source.
    checkpoint = torch.load(args.checkpoint_path, map_location=device, weights_only=False)
    model_hparams = checkpoint["hyper_parameters"]["args"]

    model = MDLMLightningModule.load_from_checkpoint(
        args.checkpoint_path,
        args=model_hparams,
        tokenizer=tokenizer,
        map_location=device,
        strict=False,
    )
    model.to(device)

    valid_smiles, validity_rate = generate_smiles(model, tokenizer, args)

    # Honor --output_file (it was previously parsed but ignored in favor of a
    # hard-coded path). Append mode is kept so repeated runs accumulate
    # samples in the same file.
    with open(args.output_file, "a", encoding="utf-8") as f:
        f.writelines(f"{smiles}\n" for smiles in valid_smiles)
    print(validity_rate)
| |
|
# Run the sampler only when this file is executed as a script, so that
# importing the module does not trigger argument parsing or model loading.
if __name__ == "__main__":
    main()