# Install dependencies when running on Colab/Kaggle
# !pip install datasets transformers wandb -q
# !pip install pytorch-lightning lightning tiktoken -q
import os
import math
from dataclasses import dataclass

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader

from datasets import load_dataset
from transformers import GPT2Tokenizer

import pytorch_lightning as pl
from pytorch_lightning.callbacks import (
    LearningRateMonitor,
    ModelCheckpoint,
    RichProgressBar,
)
from pytorch_lightning.callbacks.progress.rich_progress import RichProgressBarTheme
from pytorch_lightning.loggers import WandbLogger

# Training hyperparameters
block_size = 512
batch_size = 8
max_lr = 1e-3
warmup_steps = 10
max_steps = 25000
log_every_n_steps = 100
save_checkpoints_every_n_steps = 10
effective_batch_size = 32

tokenizer: GPT2Tokenizer = GPT2Tokenizer.from_pretrained(
    "HuggingFaceTB/cosmo2-tokenizer"
)
tokenizer.pad_token = tokenizer.eos_token
vocab_size = tokenizer.vocab_size
def load_cosmopedia_dataset(batch_size=8, seq_length=1024):
    """
    Returns a torch DataLoader for the cosmopedia dataset (streamed).
    """
    try:
        dataset = load_dataset(
            "HuggingFaceTB/smollm-corpus",
            name="cosmopedia-v2",
            split="train",
            streaming=True,
        )

        def encode(examples):
            tokens = tokenizer(
                examples["text"],
                truncation=True,
                padding="max_length",
                max_length=seq_length + 1,
                return_tensors="pt",
            )
            input_ids = tokens["input_ids"].squeeze(0).clone().detach()
            input_ids = torch.clamp(input_ids, min=0, max=tokenizer.vocab_size - 1)
            labels = input_ids.clone().detach()
            labels = labels[1:].to(torch.int64)
            input_ids = input_ids[:-1].to(torch.int64)
            return {"input_ids": input_ids, "labels": labels}

        dataset = dataset.map(encode, remove_columns=["text"], batched=False)
        dataset = dataset.with_format("torch")
        dataloader = DataLoader(dataset, batch_size=batch_size)
        return dataloader
    except Exception as e:
        print(e)
        return None
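# What one sample from the loader looks like (shapes assume seq_length = block_size = 512):
#   batch["input_ids"]: [batch_size, 512]  tokens t_0 ... t_511
#   batch["labels"]:    [batch_size, 512]  tokens t_1 ... t_512
# i.e. the labels are the inputs shifted left by one position, so the model is
# trained to predict the next token at every position.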
class SmolLMConfig:
    block_size = 1024
    vocab_size = 49152
    n_layers = 30
    n_heads = 9
    n_embed = 576
    dropout = 0.1
    mlp_hidden_dim = 1536
    attention_dropout = 0.0
    n_key_value_heads = 3
    rms_norm_eps = 1e-5
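# Derived sizes for this config (a quick sanity check, not used by the code):
#   head_dim          = n_embed // n_heads            = 576 // 9 = 64
#   kv_dim            = n_key_value_heads * head_dim  = 3 * 64   = 192
#   n_rep (GQA ratio) = n_heads // n_key_value_heads  = 9 // 3   = 3
# i.e. every key/value head is shared by 3 query heads (grouped-query attention).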
## Grouped-query attention helper: K and V have fewer heads than Q,
## so their heads are repeated n_rep times to match the query heads.
def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
    """Equivalent to torch.repeat_interleave(x, repeats=n_rep, dim=1)."""
    bs, n_kv_heads, slen, head_dim = x.shape
    if n_rep == 1:
        return x
    return (
        x[:, :, :, None, :]
        .expand(bs, n_kv_heads, slen, n_rep, head_dim)
        .reshape(bs, n_kv_heads * n_rep, slen, head_dim)
    )
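# Shape example (hypothetical tensors, only to illustrate the expansion):
#   x = torch.zeros(1, 3, 8, 64)   # [batch, kv_heads, seq, head_dim]
#   repeat_kv(x, 3).shape          # torch.Size([1, 9, 8, 64])
# The kv-head axis grows from 3 to 9 so it lines up with the 9 query heads.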
class RMSNorm(torch.nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6):
        """
        Initialize the RMSNorm normalization layer.

        Args:
            dim (int): The dimension of the input tensor.
            eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6.

        Attributes:
            eps (float): A small value added to the denominator for numerical stability.
            weight (nn.Parameter): Learnable scaling parameter.
        """
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        """
        Apply the RMSNorm normalization to the input tensor.

        Args:
            x (torch.Tensor): The input tensor.

        Returns:
            torch.Tensor: The normalized tensor.
        """
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        """
        Forward pass through the RMSNorm layer.

        Args:
            x (torch.Tensor): The input tensor.

        Returns:
            torch.Tensor: The output tensor after applying RMSNorm.
        """
        output = self._norm(x.float()).type_as(x)
        return output * self.weight
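# RMSNorm computes, per feature vector x of size `dim`:
#   y_i = x_i / sqrt(mean_j(x_j^2) + eps) * weight_i
# Unlike LayerNorm it does not subtract the mean and has no bias term, which is
# slightly cheaper and matches what the Llama family of models uses.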
class CausalMultiHeadAttention(nn.Module):
    def __init__(self, config: SmolLMConfig):
        super().__init__()
        self.config = config
        self.n_head = config.n_heads
        self.n_embd = config.n_embed
        # Linear projections for Q, K, V
        # self.c_attn = nn.Linear(config.n_embed, 3 * config.n_embed)  # [n_embd, 3 * n_embd]
        self.w_q = nn.Linear(config.n_embed, config.n_embed, bias=False)
        # K and V each project to n_key_value_heads * head_dim = 3 * 64 = 192 dims
        self.w_k = nn.Linear(
            config.n_embed,
            config.n_key_value_heads * (config.n_embed // config.n_heads),
            bias=False,
        )
        self.w_v = nn.Linear(
            config.n_embed,
            config.n_key_value_heads * (config.n_embed // config.n_heads),
            bias=False,
        )
        self.c_proj = nn.Linear(
            config.n_embed, config.n_embed, bias=False
        )  # [n_embd, n_embd]
        self.c_proj.NANOGPT_SCALE_INIT = 1
        self.n_rep = self.config.n_heads // self.config.n_key_value_heads
        self.resid_dropout = nn.Dropout(config.dropout)
        self.register_buffer(
            "bias",
            torch.tril(torch.ones(config.block_size, config.block_size)).view(
                1, 1, config.block_size, config.block_size
            ),
        )
    def forward(self, x):
        B, T, C = x.size()  # [B, T, n_embd]
        # Linear projection and split into Q, K, V
        # q, k, v = self.c_attn(x).split(self.n_embd, dim=2)  # [B, T, n_embd] each
        q = self.w_q(x)  # [B, T, 576]
        k = self.w_k(x)  # [B, T, 192]
        v = self.w_v(x)  # [B, T, 192]
        # Reshape for multi-head attention
        k = k.view(
            B,
            T,
            self.config.n_key_value_heads,
            k.size(-1) // self.config.n_key_value_heads,
        ).transpose(
            1, 2
        )  # [B, 3, T, 64]
        q = q.view(
            B, T, self.config.n_heads, q.size(-1) // self.config.n_heads
        ).transpose(
            1, 2
        )  # [B, 9, T, 64]
        v = v.view(
            B,
            T,
            self.config.n_key_value_heads,
            v.size(-1) // self.config.n_key_value_heads,
        ).transpose(
            1, 2
        )  # [B, 3, T, 64]
        # repeat k and v for each head
        k = repeat_kv(k, self.n_rep)
        v = repeat_kv(v, self.n_rep)
        # # Attention scores
        # att = (q @ k.transpose(-2, -1)) * (
        #     1.0 / (k.size(-1) ** 0.5)
        # )  # [B, n_head, T, T]
        # att = att.masked_fill(
        #     self.bias[:, :, :T, :T] == 0, float("-inf")
        # )  # [B, n_head, T, T]
        # att = F.softmax(att, dim=-1)  # [B, n_head, T, T]
        # # Weighted sum of values
        # y = att @ v  # [B, n_head, T, n_embd/n_head]
        # Flash attention
        y = F.scaled_dot_product_attention(q, k, v, is_causal=True)  # Flash attention
        # Reshape and project
        y = y.transpose(1, 2).contiguous().view(B, T, C)  # [B, T, n_embd]
        y = self.c_proj(y)  # [B, T, n_embd]
        y = self.resid_dropout(y)  # [B, T, n_embd]
        return y
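# Note on the mask buffer: with F.scaled_dot_product_attention(..., is_causal=True)
# the causal masking is applied inside the fused kernel (assuming PyTorch >= 2.0),
# so the `bias` buffer registered in __init__ is only needed by the commented-out
# manual attention path kept above for reference.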
class MLP(nn.Module):
    def __init__(self, config: SmolLMConfig):
        super().__init__()
        self.c_fc = nn.Linear(config.n_embed, config.mlp_hidden_dim, bias=False)
        self.silu = nn.SiLU()
        self.c_proj = nn.Linear(config.mlp_hidden_dim, config.n_embed, bias=False)
        self.c_proj.NANOGPT_SCALE_INIT = 1

    def forward(self, x):
        x = self.c_fc(x)
        x = self.silu(x)
        x = self.c_proj(x)
        return x
class LlamaMLP(nn.Module):
    def __init__(self, config: SmolLMConfig):
        super().__init__()
        self.hidden_dim = config.mlp_hidden_dim  # 1536
        self.w1 = nn.Linear(config.n_embed, self.hidden_dim, bias=False)
        self.w2 = nn.Linear(self.hidden_dim, config.n_embed, bias=False)
        self.w3 = nn.Linear(config.n_embed, self.hidden_dim, bias=False)

    def forward(self, x):
        return self.w2(F.silu(self.w1(x)) * self.w3(x))
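# LlamaMLP is a SwiGLU feed-forward block:
#   out = W2( SiLU(W1 x) * (W3 x) )
# where w1 is the "gate" projection, w3 the "up" projection and w2 the "down"
# projection; the elementwise product gates the up-projected features before
# projecting back down to n_embed.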
class DecoderBlockWithRMSNorm(nn.Module):
    def __init__(self, config: SmolLMConfig):
        super().__init__()
        self.config = config
        self.rms_1 = RMSNorm(self.config.n_embed, eps=self.config.rms_norm_eps)
        self.attn = CausalMultiHeadAttention(config)
        self.rms_2 = RMSNorm(self.config.n_embed, eps=self.config.rms_norm_eps)
        self.mlp = LlamaMLP(config)

    def forward(self, x):
        x = x + self.attn(self.rms_1(x))
        x = x + self.mlp(self.rms_2(x))
        return x


class DecoderBlockWithLayerNorm(nn.Module):
    def __init__(self, config: SmolLMConfig):
        super().__init__()
        self.ln_1 = nn.LayerNorm(config.n_embed)
        self.attn = CausalMultiHeadAttention(config)
        self.ln_2 = nn.LayerNorm(config.n_embed)
        self.mlp = MLP(config)

    def forward(self, x):
        x = x + self.attn(self.ln_1(x))
        x = x + self.mlp(self.ln_2(x))
        return x
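# Both blocks use the pre-norm residual layout (normalize, then attend/MLP, then
# add the residual), which generally trains more stably at depth than post-norm.
# Only DecoderBlockWithRMSNorm is used by SmolLM below; the LayerNorm + GELU-free
# MLP variant is kept as a GPT-2-style alternative.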
class SmolLM(nn.Module):
    def __init__(self, config: SmolLMConfig):
        super().__init__()
        self.config = config
        self.wte = nn.Embedding(
            config.vocab_size, config.n_embed
        )  # [vocab_size, n_embd]
        self.wpe = nn.Embedding(
            config.block_size, config.n_embed
        )  # [max_seq_len, n_embd]
        self.drop = nn.Dropout(config.dropout)
        self.blocks = nn.ModuleList(
            [DecoderBlockWithRMSNorm(config) for _ in range(config.n_layers)]
        )
        self.rms_norm = RMSNorm(config.n_embed, eps=config.rms_norm_eps)  # [n_embd]
        self.lm_head = nn.Linear(
            config.n_embed, config.vocab_size, bias=False
        )  # [n_embd, vocab_size]
        # weight sharing between the token embedding and the output head
        self.wte.weight = self.lm_head.weight
        self.apply(self._init_weights)

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            std = 0.02
            if hasattr(module, "NANOGPT_SCALE_INIT"):
                std *= (2 * self.config.n_layers) ** -0.5
            torch.nn.init.normal_(module.weight, mean=0.0, std=std)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
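    # The NANOGPT_SCALE_INIT flag marks residual-path output projections: their
    # init std is scaled by 1/sqrt(2 * n_layers) because each of the 30 blocks
    # adds two residual contributions, and the scaling keeps the residual stream's
    # variance roughly constant with depth (the same trick as in GPT-2 / nanoGPT).
    # Note that LlamaMLP.w2 is not flagged, so in this model only the attention
    # output projection receives the scaled init.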
    def forward(self, idx, targets=None):
        # idx is of shape (B, T)
        B, T = idx.size()
        assert (
            T <= self.config.block_size
        ), f"Cannot forward sequence of length {T}, block size is only {self.config.block_size}"
        pos = torch.arange(0, T, dtype=torch.long, device=idx.device)  # shape (T)
        pos_emb = self.wpe(pos)  # position embeddings of shape (T, n_embd)
        x = self.wte(idx)  # token embeddings of shape (B, T, n_embd)
        x = x + pos_emb
        # forward the blocks of the transformer
        for block in self.blocks:
            x = block(x)
        # forward the final RMSNorm and the classifier
        x = self.rms_norm(x)
        logits = self.lm_head(x)  # (B, T, vocab_size)
        loss = None
        if targets is not None:
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
        return logits, loss
    def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None):
        """
        Generate text given a starting sequence of tokens.

        Args:
            idx (torch.Tensor): Starting token indices, shape (B, T)
            max_new_tokens (int): Number of tokens to generate
            temperature (float): Sampling temperature (1.0 = no change, < 1.0 = less random, > 1.0 = more random)
            top_k (int): If specified, only sample from the top k most probable tokens
        """
        for _ in range(max_new_tokens):
            # if the sequence context is growing too long we must crop it at block_size
            idx_cond = (
                idx
                if idx.size(1) <= self.config.block_size
                else idx[:, -self.config.block_size :]
            )
            # forward the model to get the logits for the index in the sequence
            logits, _ = self(idx_cond)
            # pluck the logits at the final step and scale by desired temperature
            logits = logits[:, -1, :] / temperature
            # optionally crop the logits to only the top k options
            if top_k is not None:
                v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                logits[logits < v[:, [-1]]] = -float("Inf")
            # apply softmax to convert logits to (normalized) probabilities
            probs = F.softmax(logits, dim=-1)
            # sample from the distribution
            idx_next = torch.multinomial(probs, num_samples=1)
            # append sampled index to the running sequence
            idx = torch.cat((idx, idx_next), dim=1)
        return idx
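# Rough parameter count for this config (approximate; the lm_head is tied to the
# token embedding, and norm weights add a negligible amount):
#   token embeddings: 49152 * 576                              ~= 28.3M
#   per block: attention 576*576 + 2*(576*192) + 576*576       ~= 0.88M
#              SwiGLU MLP 3 * (576 * 1536)                     ~= 2.65M
#   30 blocks:                                                 ~= 106.2M
#   position embeddings: 1024 * 576                            ~= 0.6M
#   total                                                      ~= 135M parameters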
class SmolLMLightning(pl.LightningModule):
    def __init__(self, config: SmolLMConfig, lr, warmup_steps, max_steps):
        super().__init__()
        self.save_hyperparameters()
        self.config = config
        self.model = SmolLM(self.config)
        self.criterion = nn.CrossEntropyLoss()
        self.tokenizer = tokenizer
        self.generation_prompt = "Once upon a time"
        self._generating = False

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        input_ids = batch["input_ids"]
        target_ids = batch["labels"]
        logits, _ = self(input_ids)
        loss = self.criterion(logits.view(-1, logits.size(-1)), target_ids.view(-1))
        # Log the training loss
        self.log(
            "train_loss", loss, prog_bar=True, on_step=True, on_epoch=False, logger=True
        )
        # Generate text every n steps, but only if we're not already generating
        if (self.global_step) % log_every_n_steps == 0 and not self._generating:
            self._generating = True
            self.generate_and_log_sample()
            self._generating = False
        return loss
    def generate_and_log_sample(self):
        """Generate and log a sample of text from the model"""
        try:
            # Encode the prompt
            prompt_ids = self.tokenizer.encode(
                self.generation_prompt, return_tensors="pt"
            ).to(self.device)
            # Generate new tokens
            generated_ids = self.model.generate(
                prompt_ids, max_new_tokens=50, temperature=0.8, top_k=40
            )
            # Decode the generated tokens
            generated_text = self.tokenizer.decode(generated_ids[0].tolist())
            # Create a formatted message
            message = (
                f"\n{'='*40}\n"
                f"Step {self.global_step} generation:\n"
                f"Prompt: {self.generation_prompt}\n"
                f"Generated: {generated_text}\n"
                f"{'='*40}\n"
            )
            print(message)
            # Log to WandB
            if hasattr(self.logger, "experiment"):
                self.logger.experiment.log(
                    {"generated_text": generated_text, "global_step": self.global_step}
                )
        except Exception as e:
            print(f"Generation failed with error: {str(e)}")
    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(self.parameters(), lr=self.hparams.lr)

        def lr_lambda(current_step):
            # LambdaLR multiplies the base lr by the value returned here,
            # so return a factor in [0, 1], not an absolute learning rate.
            if current_step < self.hparams.warmup_steps:
                # linear warmup up to the base lr
                return (current_step + 1) / self.hparams.warmup_steps
            if current_step > self.hparams.max_steps:
                return 0.1
            # cosine decay from lr down to 0.1 * lr
            decay_ratio = (current_step - self.hparams.warmup_steps) / (
                self.hparams.max_steps - self.hparams.warmup_steps
            )
            coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))
            return 0.1 + coeff * 0.9

        scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
        # step the scheduler every optimizer step rather than once per epoch
        return [optimizer], [{"scheduler": scheduler, "interval": "step"}]
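# With max_lr = 1e-3, warmup_steps = 10 and max_steps = 25000 the schedule gives
# roughly:
#   step 0      -> 1e-4    (first warmup step, factor 0.1)
#   step 10     -> 1e-3    (end of warmup, cosine factor 1.0)
#   step 12505  -> ~5.5e-4 (mid-decay, factor 0.1 + 0.5 * 0.9)
#   step 25000  -> 1e-4    (floor at 0.1 * max_lr)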
if __name__ == "__main__":
    torch.set_float32_matmul_precision("high")
    dataloader = load_cosmopedia_dataset(batch_size=batch_size, seq_length=block_size)

    # Resume from a checkpoint if one exists
    checkpoint_path = "checkpoints/best-checkpoint.ckpt"
    if os.path.exists(checkpoint_path):
        print(f"Loading model from checkpoint: {checkpoint_path}")
        model = SmolLMLightning.load_from_checkpoint(
            checkpoint_path,
            config=SmolLMConfig(),
            lr=max_lr,
            warmup_steps=warmup_steps,
            max_steps=max_steps,
        )
    else:
        print("Starting training from scratch")
        model = SmolLMLightning(SmolLMConfig(), max_lr, warmup_steps, max_steps)

    # Use a WandB logger for experiment tracking
    wandb_logger = WandbLogger(
        project="smollm",  # your project name
        name="transformer_experiment",  # name of the run
        log_model=True,  # log model checkpoints
    )

    os.makedirs("checkpoints", exist_ok=True)
    checkpoint_callback = ModelCheckpoint(
        dirpath="checkpoints/",
        filename="best-checkpoint",
        verbose=True,
        every_n_train_steps=save_checkpoints_every_n_steps,
    )

    device = "cpu"
    if torch.cuda.is_available():
        device = "cuda"
    elif torch.backends.mps.is_available():
        device = "mps"
    print(f"using device: {device}")
    progress_bar = RichProgressBar(
        refresh_rate=1,
        leave=False,
        theme=RichProgressBarTheme(
            description="",
            progress_bar="#6206E0",
            progress_bar_finished="#6206E0",
            progress_bar_pulse="#6206E0",
            batch_progress="",
            time="dim",
            processing_speed="dim underline",
            metrics="italic",
            metrics_text_delimiter=" ",
            metrics_format=".3f",
        ),
        console_kwargs=None,
    )
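    # Gradient accumulation: with batch_size = 8 and effective_batch_size = 32 the
    # trainer accumulates 32 // 8 = 4 micro-batches per optimizer step, so each
    # update sees roughly 32 * 512 = 16,384 tokens.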
    trainer = pl.Trainer(
        max_steps=max_steps,
        accelerator=device,
        devices=1,
        callbacks=[
            LearningRateMonitor(logging_interval="step"),
            progress_bar,
            checkpoint_callback,
        ],
        precision="bf16-mixed",
        log_every_n_steps=1,
        enable_progress_bar=True,
        enable_model_summary=True,
        logger=wandb_logger,
        accumulate_grad_batches=effective_batch_size // batch_size,
    )
    trainer.fit(model, dataloader)