import re
import torch

from .utils import utils

from torch.utils.data import Dataset, DataLoader
import lightning.pytorch as pl
from functools import partial
import sys


class CustomDataset(Dataset):
    """Thin wrapper that exposes only the entries of `dataset` selected by `indices`."""

    def __init__(self, dataset, indices):
        self.dataset = dataset
        self.indices = indices

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, idx):
        # Map the local index to the index in the underlying dataset.
        actual_idx = int(self.indices[idx])
        item = self.dataset[actual_idx]
        return item
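
# Illustrative sketch (not part of the pipeline): CustomDataset can be used to build
# train/validation splits over one underlying dataset. The dataset and index ranges
# below are hypothetical.
#
#   full_data = load_my_dataset()                        # hypothetical helper
#   train_set = CustomDataset(full_data, range(0, 800))
#   val_set = CustomDataset(full_data, range(800, 1000))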


def peptide_bond_mask(smiles_list):
    """
    Returns a mask of shape (batch_size, max_seq_length) with 1 at the character
    positions of recognized bond motifs (peptide, ester, N-methyl) and 0 elsewhere.

    Args:
        smiles_list: List of peptide SMILES strings (batch of SMILES strings).

    Returns:
        torch.Tensor: A mask of shape (batch_size, max_seq_length) with 1s at bond positions.
    """

    batch_size = len(smiles_list)
    max_seq_length = max(len(smiles) for smiles in smiles_list)
    mask = torch.zeros((batch_size, max_seq_length), dtype=torch.int)

    # Regex patterns for the bond motifs, ordered so that more specific patterns
    # are matched before more general ones.
    bond_patterns = [
        (r'OC\(=O\)', 'ester'),
        (r'N\(C\)C\(=O\)', 'n_methyl'),
        (r'N[12]C\(=O\)', 'peptide'),
        (r'NC\(=O\)', 'peptide'),
        (r'C\(=O\)N\(C\)', 'n_methyl'),
        (r'C\(=O\)N[12]?', 'peptide')
    ]

    for batch_idx, smiles in enumerate(smiles_list):
        positions = []
        used = set()

        # Collect matches for each pattern, skipping any match that overlaps
        # characters already claimed by an earlier (more specific) pattern.
        for pattern, bond_type in bond_patterns:
            for match in re.finditer(pattern, smiles):
                if not any(p in range(match.start(), match.end()) for p in used):
                    positions.append({
                        'start': match.start(),
                        'end': match.end(),
                        'type': bond_type,
                        'pattern': match.group()
                    })
                    used.update(range(match.start(), match.end()))

        # Flag every character covered by a recognized bond.
        for pos in positions:
            mask[batch_idx, pos['start']:pos['end']] = 1

    return mask
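
# Illustrative example (hedged; traced by hand for a hypothetical glycylglycine-like
# SMILES): the six characters of the matched "C(=O)N" amide motif are flagged with 1,
# everything else stays 0.
#
#   peptide_bond_mask(["NCC(=O)NCC(=O)O"])
#   # -> tensor([[0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0]], dtype=torch.int32)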


def peptide_token_mask(smiles_list, token_lists):
    """
    Returns a mask of shape (batch_size, num_tokens) that has 1 for tokens
    where any part of the token overlaps with a peptide bond, and 0 elsewhere.

    Args:
        smiles_list: List of peptide SMILES strings (batch of SMILES strings).
        token_lists: List of tokenized SMILES strings (each split into tokens).

    Returns:
        torch.Tensor: A mask of shape (batch_size, num_tokens) with 1s for peptide bond tokens.
    """

    batch_size = len(smiles_list)
    token_seq_length = max(len(tokens) for tokens in token_lists)
    tokenized_masks = torch.zeros((batch_size, token_seq_length), dtype=torch.int)
    atomwise_masks = peptide_bond_mask(smiles_list)

    # Project the character-level mask onto the token level: a token is flagged
    # if any of the characters it covers lies on a recognized bond.
    for batch_idx, atomwise_mask in enumerate(atomwise_masks):
        token_seq = token_lists[batch_idx]
        atom_idx = 0

        for token_idx, token in enumerate(token_seq):
            # Skip the first and last tokens (special tokens that do not
            # correspond to characters in the SMILES string).
            if token_idx != 0 and token_idx != len(token_seq) - 1:
                if torch.sum(atomwise_mask[atom_idx:atom_idx + len(token)]) >= 1:
                    tokenized_masks[batch_idx][token_idx] = 1
                atom_idx += len(token)

    return tokenized_masks
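
# Illustrative sketch (hedged): the token split below is hypothetical, with the first
# and last entries standing in for special tokens that are skipped. A token is flagged
# when any of its characters falls on a bond position.
#
#   smiles = ["NCC(=O)NCC(=O)O"]
#   tokens = [["<s>", "N", "C", "C(=O)", "N", "C", "C(=O)", "O", "</s>"]]
#   peptide_token_mask(smiles, tokens)
#   # -> tensor([[0, 0, 0, 1, 1, 0, 0, 0, 0]], dtype=torch.int32) for these tokens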


def extract_amino_acid_sequence(helm_string):
    """
    Extracts the amino acid sequence from a HELM peptide notation and returns it
    as a list, removing any brackets around each amino acid.

    Args:
        helm_string (str): The HELM notation string for a peptide.

    Returns:
        list: The amino acids in sequence, without brackets, or an error string
        if the input contains no recognizable peptide sequence.
    """

    matches = re.findall(r'PEPTIDE\d+\{([^}]+)\}', helm_string)

    if matches:
        amino_acid_sequence = []
        for match in matches:
            # Drop the brackets around non-standard residues and split on '.'.
            sequence = match.replace('[', '').replace(']', '').split('.')
            amino_acid_sequence.extend(sequence)
        return amino_acid_sequence
    else:
        return "Invalid HELM notation or no peptide sequence found."


def helm_collate_fn(batch, tokenizer):
    """Collates HELM strings into padded input_ids and attention_mask tensors."""
    sequences = [item['HELM'] for item in batch]

    tokens = tokenizer(sequences, return_tensors='pt', padding=True, truncation=True, max_length=1024)

    return {
        'input_ids': tokens['input_ids'],
        'attention_mask': tokens['attention_mask']
    }


def collate_fn(batch, tokenizer):
    """SMILES data collator: pads/truncates sequences and adds a peptide bond mask."""
    valid_sequences = []
    valid_items = []

    # Drop any sequence the tokenizer cannot handle instead of failing the whole batch.
    for item in batch:
        try:
            tokenizer([item['SMILES']], return_tensors='pt', padding=False, truncation=True, max_length=1035)
            valid_sequences.append(item['SMILES'])
            valid_items.append(item)
        except Exception as e:
            print(f"Skipping sequence due to: {str(e)}")
            continue

    tokens = tokenizer(valid_sequences, return_tensors='pt', padding=True, truncation=True, max_length=1035)

    # Token-level mask marking which tokens belong to peptide bonds.
    token_array = tokenizer.get_token_split(tokens['input_ids'])
    bond_mask = peptide_token_mask(valid_sequences, token_array)

    return {
        'input_ids': tokens['input_ids'],
        'attention_mask': tokens['attention_mask'],
        'bond_mask': bond_mask
    }
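
# Illustrative wiring sketch (hedged): `collate_fn` expects items with a 'SMILES' key
# and a tokenizer exposing `get_token_split`; the dataset and tokenizer below are
# hypothetical placeholders.
#
#   loader = DataLoader(
#       my_smiles_dataset,                               # hypothetical dataset of {'SMILES': ...} dicts
#       batch_size=32,
#       collate_fn=partial(collate_fn, tokenizer=my_tokenizer),
#   )
#   batch = next(iter(loader))
#   batch['input_ids'], batch['attention_mask'], batch['bond_mask']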


class CustomDataModule(pl.LightningDataModule):
    def __init__(self, train_dataset, val_dataset, test_dataset, tokenizer, batch_size, collate_fn=collate_fn):
        super().__init__()
        self.train_dataset = train_dataset
        self.val_dataset = val_dataset
        self.test_dataset = test_dataset

        self.batch_size = batch_size
        self.tokenizer = tokenizer
        self.collate_fn = collate_fn

    def train_dataloader(self):
        return DataLoader(self.train_dataset,
                          batch_size=self.batch_size,
                          collate_fn=partial(self.collate_fn, tokenizer=self.tokenizer),
                          num_workers=8,
                          pin_memory=True
                          )

    def val_dataloader(self):
        return DataLoader(self.val_dataset,
                          batch_size=self.batch_size,
                          collate_fn=partial(self.collate_fn, tokenizer=self.tokenizer),
                          num_workers=8,
                          pin_memory=True
                          )

    """def test_dataloader(self):
        return DataLoader(self.test_dataset, batch_size=self.batch_size,
                          collate_fn=partial(self.collate_fn, tokenizer=self.tokenizer),
                          num_workers=8, pin_memory=True)"""
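

# Illustrative usage sketch (hedged): the datasets, tokenizer, and model below are
# hypothetical placeholders; the DataModule simply pairs them with the collate
# functions defined above.
#
#   dm = CustomDataModule(train_set, val_set, None, my_tokenizer, batch_size=32)
#   trainer = pl.Trainer(max_epochs=1)
#   trainer.fit(my_model, datamodule=dm)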