| | import argparse |
| | import math |
| | import os |
| | from collections import defaultdict |
| |
|
| | import torch |
| | import torch.nn as nn |
| | from tqdm import tqdm |
| | from datasets import Dataset, DatasetDict |
| |
|
| | |
def modulate(x, shift, scale):
    """Apply a per-sample affine modulation ``x * (1 + scale) + shift``.

    ``shift`` and ``scale`` get a new axis inserted at dim 1, so each sample's
    shift/scale vector is broadcast across dim 1 of ``x`` (the sequence axis
    when called from DiTBlock).
    """
    scale_b = scale.unsqueeze(1)
    shift_b = shift.unsqueeze(1)
    return x * (1 + scale_b) + shift_b
| |
|
class TimestepEmbedder(nn.Module):
    """Embed scalar time values with a two-layer MLP (Linear -> SiLU -> Linear).

    Args:
        hidden_size: width of the produced embedding.
    """

    def __init__(self, hidden_size):
        super().__init__()
        layers = [
            nn.Linear(1, hidden_size, bias=True),
            nn.SiLU(),
            nn.Linear(hidden_size, hidden_size, bias=True),
        ]
        self.mlp = nn.Sequential(*layers)

    def forward(self, t):
        # Add a trailing size-1 feature axis so each time value is a 1-d input
        # to the first Linear layer.
        return self.mlp(t[..., None])
| |
|
class DiTBlock(nn.Module):
    """Pre-norm transformer block modulated by a conditioning vector (adaLN style).

    The conditioning vector ``c`` is projected to six per-sample vectors: a
    shift, scale and gate for the attention branch and the same three for the
    MLP branch. Both branches are residual, and their outputs are multiplied
    by the corresponding gate before being added back.

    Args:
        hidden_size: model width.
        n_heads: number of attention heads.
    """

    def __init__(self, hidden_size, n_heads):
        super().__init__()
        ffn_width = 4 * hidden_size
        # The norms carry no affine parameters; per-sample shift/scale come
        # from adaLN_modulation instead.
        self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.attn = nn.MultiheadAttention(hidden_size, n_heads, batch_first=True)
        self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.mlp = nn.Sequential(
            nn.Linear(hidden_size, ffn_width),
            nn.GELU(),
            nn.Linear(ffn_width, hidden_size),
        )
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(hidden_size, 6 * hidden_size, bias=True),
        )

    def forward(self, x, c):
        """Apply the block to ``x`` conditioned on ``c``.

        ``c`` is split into six chunks along dim 1, and each chunk gets a new
        axis at dim 1 to broadcast over ``x``. Presumably x is
        (batch, seq, hidden) and c is (batch, hidden); confirm against callers.
        """
        (attn_shift, attn_scale, attn_gate,
         ffn_shift, ffn_scale, ffn_gate) = self.adaLN_modulation(c).chunk(6, dim=1)

        # Attention branch.
        h = modulate(self.norm1(x), attn_shift, attn_scale)
        attn_out = self.attn(h, h, h)[0]
        x = x + attn_gate.unsqueeze(1) * attn_out

        # Feed-forward branch.
        h = modulate(self.norm2(x), ffn_shift, ffn_scale)
        ffn_out = self.mlp(h)
        x = x + ffn_gate.unsqueeze(1) * ffn_out
        return x
| |
|
class MDLM(nn.Module):
    """Transformer over token sequences, conditioned on a per-sample time value.

    Token id ``vocab_size`` is reserved as a mask token. The embedding table
    has ``vocab_size + 1`` rows, but the output head only scores the
    ``vocab_size`` real tokens, so the mask token is never predicted.

    Args:
        vocab_size: number of real (non-mask) tokens.
        seq_len: maximum supported sequence length (size of the learned
            positional table).
        model_dim: hidden width of every layer.
        n_heads: attention heads per DiTBlock.
        n_layers: number of stacked DiTBlocks.
    """

    def __init__(self, vocab_size, seq_len, model_dim, n_heads, n_layers):
        super().__init__()
        self.vocab_size = vocab_size
        self.seq_len = seq_len
        self.model_dim = model_dim
        self.mask_token_id = vocab_size
        self.token_embedder = nn.Embedding(vocab_size + 1, model_dim)
        self.pos_embedder = nn.Parameter(torch.randn(1, seq_len, model_dim))
        self.time_embedder = TimestepEmbedder(model_dim)
        self.transformer_blocks = nn.ModuleList([DiTBlock(model_dim, n_heads) for _ in range(n_layers)])
        self.final_norm = nn.LayerNorm(model_dim)
        self.lm_head = nn.Linear(model_dim, vocab_size)

    def forward(self, x, t):
        """Return logits of shape (batch, seq_len, vocab_size).

        Args:
            x: integer token ids of shape (batch, seq_len) with
                ``seq_len <= self.seq_len``. Ids may include ``self.mask_token_id``.
            t: per-sample time values. Presumably shape (batch,): the time
                embedder appends a feature axis, and DiTBlock splits the
                embedding along dim 1.

        Raises:
            ValueError: if ``x`` is longer than the positional table.
        """
        seq_len = x.shape[1]
        if seq_len > self.seq_len:
            # Without this check the positional slice is silently shorter than
            # x and the addition below fails with an opaque broadcast error.
            raise ValueError(
                f"input length {seq_len} exceeds the model's maximum sequence length {self.seq_len}"
            )
        x_embed = self.token_embedder(x) + self.pos_embedder[:, :seq_len, :]
        t_embed = self.time_embedder(t)
        for block in self.transformer_blocks:
            x_embed = block(x_embed, t_embed)
        x_embed = self.final_norm(x_embed)
        return self.lm_head(x_embed)
| |
|
| | |
| |
|
def generate_x1_from_x0(model, device, x0_batch, steps, temperature):
    """Iteratively resample a batch of token sequences, committing the most confident tokens.

    Each step does the following, starting from a copy of ``x0_batch``:
    1. Query the model at time ``t = 1 - i / steps``.
    2. Sample every position from the temperature-scaled softmax.
    3. Overwrite, per row, the ``keep_schedule[i]`` positions whose sampled
       token had the highest probability. All other positions keep their
       current value.

    ``keep_schedule`` is a cosine ramp rounded to ints, rising from ~0 to
    ``seq_len``. On the final step every position is replaced by its fresh
    sample.

    Args:
        model: module called as ``model(x, t)`` that returns logits of shape
            (batch, seq_len, vocab). Its train/eval mode is restored on return.
        device: device on which the time tensor and schedule are created.
        x0_batch: integer tensor (batch, seq_len) of starting tokens. It is not
            modified.
        steps: number of refinement iterations; must be >= 1.
        temperature: softmax temperature; must be > 0.

    Returns:
        Integer tensor of shape (batch, seq_len) holding the final samples.

    Raises:
        ValueError: if ``steps < 1`` or ``temperature <= 0``.
    """
    if steps < 1:
        raise ValueError(f"steps must be >= 1, got {steps}")
    if temperature <= 0:
        # A zero/negative temperature would yield inf/NaN probabilities and a
        # cryptic failure inside torch.multinomial.
        raise ValueError(f"temperature must be > 0, got {temperature}")

    was_training = model.training
    model.eval()
    x = x0_batch.clone()
    num_samples, seq_len = x.shape
    keep_schedule = torch.cos(torch.linspace(math.pi / 2, 0, steps, device=device)) * seq_len
    # Plain Python ints: torch.topk expects an int k.
    keep_schedule = torch.round(keep_schedule).long().tolist()
    try:
        with torch.no_grad():
            for i in range(steps):
                t_continuous = torch.full((num_samples,), 1.0 - (i / steps), device=device)
                logits = model(x, t_continuous)
                probs = torch.softmax(logits / temperature, dim=-1)
                # Flatten to (batch * seq_len, vocab) for multinomial, then restore shape.
                sampled_tokens = torch.multinomial(
                    probs.reshape(-1, probs.size(-1)), 1
                ).view(x.shape)
                if i == steps - 1:
                    x = sampled_tokens
                    break
                # Probability the model assigned to the token it actually sampled.
                confidence = torch.gather(probs, 2, sampled_tokens.unsqueeze(-1)).squeeze(-1)
                _, indices_to_keep = torch.topk(confidence, keep_schedule[i], largest=True, dim=-1)
                keep_mask = torch.zeros_like(x, dtype=torch.bool).scatter_(1, indices_to_keep, True)
                x = torch.where(keep_mask, sampled_tokens, x)
    finally:
        # Don't leave a caller's model stuck in eval mode.
        model.train(was_training)
    return x
| |
|
def is_sample_valid(sample_x1, invalid_tokens=frozenset({0, 1, 2, 3})):
    """Return True if no invalid token appears strictly inside the sequence.

    The first and last positions are excluded from the check, so special
    tokens there are allowed. Sequences of length <= 2 therefore always pass.

    Args:
        sample_x1: sequence of token ids (e.g. a list of ints).
        invalid_tokens: token ids forbidden in the interior. Defaults to the
            special tokens {0, 1, 2, 3}. Any container supporting ``in`` works.

    Returns:
        bool: False if any interior token is in ``invalid_tokens``.
    """
    return not any(token in invalid_tokens for token in sample_x1[1:-1])
| |
|
def create_prebatched_dataset(dataset, max_tokens_per_batch=500):
    """Group equal-length samples into batches; each row of the result is one batch.

    Samples are bucketed by ``len(sample['input_ids_x1'])``. Each bucket is cut
    into consecutive chunks of ``max(1, max_tokens_per_batch // length)``
    samples. A batch therefore holds at most ``max_tokens_per_batch`` x1
    tokens, or a single sample when one sample alone exceeds the budget.
    Buckets appear in order of first occurrence, and samples keep their input
    order within a bucket.

    Args:
        dataset: iterable of mappings with 'input_ids_x0' and 'input_ids_x1'
            token lists (e.g. a ``datasets.Dataset`` split).
        max_tokens_per_batch: token budget per batch, counted on x1 lengths.

    Returns:
        ``Dataset`` with columns 'input_ids_x0' and 'input_ids_x1'. Each row
        holds the list of sequences forming one batch.
    """
    samples_by_length = defaultdict(list)
    for sample in dataset:
        samples_by_length[len(sample['input_ids_x1'])].append(sample)

    batched_data = {'input_ids_x0': [], 'input_ids_x1': []}
    for length, samples in samples_by_length.items():
        # max(1, length) guards empty sequences, which would otherwise raise
        # ZeroDivisionError here.
        samples_per_batch = max(1, max_tokens_per_batch // max(1, length))
        for start in range(0, len(samples), samples_per_batch):
            batch_samples = samples[start:start + samples_per_batch]
            batched_data['input_ids_x0'].append([s['input_ids_x0'] for s in batch_samples])
            batched_data['input_ids_x1'].append([s['input_ids_x1'] for s in batch_samples])

    return Dataset.from_dict(batched_data)
| |
|
| | |
| |
|
def _load_model(checkpoint_path, device):
    """Load the checkpoint at ``checkpoint_path`` and build the MDLM it describes.

    Returns the model on ``device``, or None if the checkpoint could not be
    read (the error is printed). Errors while building the model or loading
    its state dict propagate.
    """
    print(f"Loading checkpoint from {checkpoint_path}...")
    try:
        # NOTE(security): weights_only=False unpickles arbitrary Python objects
        # (presumably needed because 'args' is a non-tensor object read by
        # attribute below). A malicious checkpoint can execute code, so only
        # load trusted files.
        checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)
        model_args = checkpoint['args']
    except Exception as e:
        print(f"Error loading checkpoint: {e}")
        return None

    print("Initializing model...")
    model = MDLM(
        vocab_size=model_args.vocab_size,
        seq_len=model_args.seq_len,
        model_dim=model_args.model_dim,
        n_heads=model_args.n_heads,
        n_layers=model_args.n_layers
    ).to(device)
    model.load_state_dict(checkpoint['model_state_dict'])
    print("Model loaded successfully.")
    return model


def _generate_pairs_for_length(model, device, length, args):
    """Collect up to ``args.samples_per_len`` valid (x0, x1) pairs of one length.

    Random x0 sequences are drawn uniformly over ``model.vocab_size`` and
    mapped to x1 with ``generate_x1_from_x0``. Only pairs whose x1 passes
    ``is_sample_valid`` are kept. If ``args.max_batches_per_len`` is set (not
    None), generation stops after that many batches even when short of the
    target, so a model that rarely yields valid samples cannot loop forever.

    Returns:
        (x0_rows, x1_rows): two equal-length lists of token-id lists.
    """
    target = args.samples_per_len
    max_batches = getattr(args, "max_batches_per_len", None)
    x0_rows, x1_rows = [], []
    batches_drawn = 0
    with tqdm(total=target) as pbar:
        while len(x0_rows) < target:
            if max_batches is not None and batches_drawn >= max_batches:
                print(f"Warning: only {len(x0_rows)}/{target} valid samples for length "
                      f"{length} after {batches_drawn} batches; moving on.")
                break
            # batch_size <= remaining, so one batch can never overshoot the target.
            batch_size = min(args.batch_size, target - len(x0_rows))
            x0_batch = torch.randint(0, model.vocab_size, (batch_size, length),
                                     dtype=torch.long, device=device)
            x1_batch = generate_x1_from_x0(model, device, x0_batch, args.gen_steps, args.temperature)
            batches_drawn += 1

            # One device->host transfer per batch instead of one per row.
            for x0, x1 in zip(x0_batch.cpu().tolist(), x1_batch.cpu().tolist()):
                if is_sample_valid(x1):
                    x0_rows.append(x0)
                    x1_rows.append(x1)
                    pbar.update(1)
    return x0_rows, x1_rows


def _split_dataset(all_x0, all_x1):
    """Split paired samples 80/10/10 into train/validation/test (seed 42)."""
    dataset = Dataset.from_dict({'input_ids_x0': all_x0, 'input_ids_x1': all_x1})
    train_test_split = dataset.train_test_split(test_size=0.2, seed=42)
    valid_test_split = train_test_split['test'].train_test_split(test_size=0.5, seed=42)
    return DatasetDict({
        'train': train_test_split['train'],
        'validation': valid_test_split['train'],
        'test': valid_test_split['test']
    })


def main(args):
    """Generate a rectified, pre-batched (x0, x1) dataset from an MDLM checkpoint and save it.

    For every length in [args.min_len, args.max_len], random x0 sequences are
    mapped to x1 until ``args.samples_per_len`` valid pairs are collected. The
    pairs are split 80/10/10 into train/validation/test, pre-batched by length,
    and written with ``save_to_disk`` to ``<output_path>/v<version>``.

    Invalid configurations and unreadable checkpoints are reported, and the
    function returns without writing anything.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    if args.batch_size < 1:
        # A zero batch size would make the generation loop spin forever.
        print(f"Error: batch_size must be >= 1, got {args.batch_size}.")
        return

    model = _load_model(args.checkpoint, device)
    if model is None:
        return

    if not 1 <= args.min_len <= args.max_len <= model.seq_len:
        print(f"Error: need 1 <= min_len <= max_len <= {model.seq_len} (model seq_len); "
              f"got min_len={args.min_len}, max_len={args.max_len}.")
        return

    all_x0, all_x1 = [], []
    for length in range(args.min_len, args.max_len + 1):
        print(f"Generating {args.samples_per_len} valid samples for length {length}...")
        x0_rows, x1_rows = _generate_pairs_for_length(model, device, length, args)
        all_x0.extend(x0_rows)
        all_x1.extend(x1_rows)

    if not all_x0:
        print("No valid samples were generated; nothing to save.")
        return

    print("Splitting dataset...")
    final_dataset_dict = _split_dataset(all_x0, all_x1)

    print("Pre-batching splits...")
    max_tokens = getattr(args, "max_tokens_per_batch", 500)
    batched_dataset_dict = DatasetDict()
    for split_name, split_dataset in final_dataset_dict.items():
        print(f"Processing {split_name} split...")
        batched_dataset_dict[split_name] = create_prebatched_dataset(
            split_dataset, max_tokens_per_batch=max_tokens
        )

    # os.path.join keeps an empty --output_path relative. Plain "/" string
    # concatenation would point such a path at the filesystem root.
    output_path = os.path.join(args.output_path, f"v{args.version}")
    print(f"Saving new batched dataset to {output_path}...")
    batched_dataset_dict.save_to_disk(output_path)

    print("Rectification complete.")
    print(f"Train on this by updating your training script's dataset path to '{output_path}'.")
| |
|
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Generate a rectified dataset with variable lengths and pre-batching.")

    parser.add_argument("--checkpoint", type=str, required=True,
                        help="Path to the trained model checkpoint to sample from.")
    parser.add_argument("--output_path", type=str, default="./rectified_datasets",
                        help="Directory under which the versioned dataset is saved.")
    parser.add_argument("--version", type=str, default='1',
                        help="Dataset version; output goes to <output_path>/v<version>.")
    parser.add_argument("--samples_per_len", type=int, default=10000,
                        help="Number of valid samples to collect per sequence length.")
    parser.add_argument("--min_len", type=int, default=6,
                        help="Shortest sequence length to generate (inclusive).")
    parser.add_argument("--max_len", type=int, default=49,
                        help="Longest sequence length to generate (inclusive).")
    parser.add_argument("--gen_steps", type=int, default=16,
                        help="Refinement steps per generated batch.")
    parser.add_argument("--temperature", type=float, default=1.0,
                        help="Sampling temperature (> 0).")
    parser.add_argument("--batch_size", type=int, default=128,
                        help="Generation batch size (>= 1).")
    parser.add_argument("--max_tokens_per_batch", type=int, default=500,
                        help="Token budget per pre-built batch in the saved dataset.")
    parser.add_argument("--max_batches_per_len", type=int, default=None,
                        help="Stop generating a length after this many batches (default: no limit).")

    args = parser.parse_args()
    main(args)
| |
|