import math

import torch
import torch.nn as nn
import torch.nn.functional as F

from huggingface_hub import PyTorchModelHubMixin

# ... (rest of your model code)
# --- Hyperparameters (You can adjust these later) ---
# For a "Tiny" LLM, we keep the size very small.
n_embed = 64 # C: Embedding dimension (size of the vector representing a character)
n_head = 4 # H: Number of attention heads
n_layer = 4 # Number of repeating Transformer blocks
dropout = 0.1 # Dropout rate
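
# NOTE (assumption): block_size (context length) and vocab_size are not defined in
# this snippet; the classes below take them as constructor arguments, so they are
# expected to come from the surrounding training script / tokenizer (e.g. a
# character-level setup might use block_size = 32 and vocab_size = 65).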
# --- 1. Causal Self-Attention (The "Attention is All You Need" Component) ---
class CausalSelfAttention(nn.Module):
    """A multi-head masked self-attention module."""

    def __init__(self, n_embed, n_head, block_size, dropout):
        super().__init__()
        self.n_embed = n_embed
        self.n_head = n_head
        self.head_size = n_embed // n_head
        # Combined projection for Q, K, and V (more efficient than three separate layers)
        self.c_attn = nn.Linear(n_embed, 3 * n_embed, bias=False)
        # Output projection
        self.c_proj = nn.Linear(n_embed, n_embed, bias=False)
        self.attn_dropout = nn.Dropout(dropout)
        self.resid_dropout = nn.Dropout(dropout)
        # Causal mask (tril = lower-triangular matrix).
        # It prevents a token from attending to future tokens (autoregressive).
        self.register_buffer(
            'tril',
            torch.tril(torch.ones(block_size, block_size)).view(1, 1, block_size, block_size)
        )

    def forward(self, x):
        B, T, C = x.shape  # Batch size, sequence length (time), embedding dimension (channel)
        # 1. Compute Q, K, V in one matmul and split into three (B, T, C) tensors
        qkv = self.c_attn(x)
        q, k, v = qkv.split(self.n_embed, dim=2)
        # 2. Reshape for multi-head attention: (B, T, C) -> (B, H, T, head_size)
        #    Each head processes a smaller chunk of the embedding dimension C
        k = k.view(B, T, self.n_head, self.head_size).transpose(1, 2)
        q = q.view(B, T, self.n_head, self.head_size).transpose(1, 2)
        v = v.view(B, T, self.n_head, self.head_size).transpose(1, 2)
        # 3. Scaled dot-product attention scores: (B, H, T, T)
        #    wei = (q @ k^T) / sqrt(head_size)
        wei = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(self.head_size))
        # 4. Apply the causal mask: set scores for future tokens (where tril == 0) to -inf
        wei = wei.masked_fill(self.tril[:, :, :T, :T] == 0, float('-inf'))
        # 5. Softmax and dropout
        wei = F.softmax(wei, dim=-1)
        wei = self.attn_dropout(wei)
        # 6. Weighted sum of values: (B, H, T, head_size)
        out = wei @ v
        # 7. Re-assemble heads: (B, H, T, head_size) -> (B, T, C)
        out = out.transpose(1, 2).contiguous().view(B, T, C)
        # 8. Final linear projection
        out = self.resid_dropout(self.c_proj(out))
        return out
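
# Quick shape check for the attention module (a sketch, not part of the model;
# the block_size and batch values below are just illustrative):
#   attn = CausalSelfAttention(n_embed=64, n_head=4, block_size=32, dropout=0.1)
#   x = torch.randn(2, 32, 64)            # (B=2, T=32, C=64)
#   attn(x).shape                         # -> torch.Size([2, 32, 64]), same (B, T, C)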
# --- 2. Feed Forward Network (FFN) ---
class FeedForward(nn.Module):
    """A two-layer MLP for processing attention output."""

    def __init__(self, n_embed, dropout):
        super().__init__()
        self.net = nn.Sequential(
            # Standard hidden size is 4x the embedding dimension
            nn.Linear(n_embed, 4 * n_embed),
            nn.GELU(),  # Modern activation function (smoother than ReLU)
            nn.Linear(4 * n_embed, n_embed),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        return self.net(x)
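
# Note: the FFN expands C -> 4C -> C (64 -> 256 -> 64 with the defaults above) and
# is applied to each position independently, so it mixes features, not tokens.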
# --- 3. Transformer Block (The Repeating Unit) ---
class TransformerBlock(nn.Module):
    """A standard Transformer decoder block with attention and FFN."""

    def __init__(self, n_embed, n_head, block_size, dropout):
        super().__init__()
        # LayerNorm applied BEFORE each sub-layer (pre-norm style)
        self.ln_1 = nn.LayerNorm(n_embed)
        self.attn = CausalSelfAttention(n_embed, n_head, block_size, dropout)
        self.ln_2 = nn.LayerNorm(n_embed)
        self.ffn = FeedForward(n_embed, dropout)

    def forward(self, x):
        # 1. Attention sub-layer with pre-norm and residual connection
        x = x + self.attn(self.ln_1(x))
        # 2. FFN sub-layer with pre-norm and residual connection
        x = x + self.ffn(self.ln_2(x))
        return x
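
# A block maps (B, T, C) -> (B, T, C), so blocks can simply be stacked with
# nn.Sequential (as TinyLLM does below). The pre-norm layout used here generally
# trains more stably than the original post-norm arrangement at small scale.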
# --- 4. The Final TinyLLM Model ---
class TinyLLM(nn.Module, PyTorchModelHubMixin):
    """The complete decoder-only Transformer model."""

    def __init__(self, vocab_size, n_embed, n_head, n_layer, block_size, dropout):
        super().__init__()
        self.block_size = block_size
        self.token_embedding_table = nn.Embedding(vocab_size, n_embed)
        # Positional encoding: a learned table mapping each position to a vector
        self.position_embedding_table = nn.Embedding(block_size, n_embed)
        # Stack of Transformer blocks
        self.blocks = nn.Sequential(*[
            TransformerBlock(n_embed, n_head, block_size, dropout)
            for _ in range(n_layer)
        ])
        self.ln_f = nn.LayerNorm(n_embed)  # Final LayerNorm
        # Linear layer mapping the embedding vector back to the vocabulary space
        self.lm_head = nn.Linear(n_embed, vocab_size)

    def forward(self, idx, targets=None):
        # idx is the input tensor X of shape (B, T)
        B, T = idx.shape
        # 1. Token and positional embeddings
        tok_emb = self.token_embedding_table(idx)  # (B, T, C)
        pos = torch.arange(T, device=idx.device)
        pos_emb = self.position_embedding_table(pos)  # (T, C), broadcast to (B, T, C)
        # 2. Combine (add) the embeddings
        x = tok_emb + pos_emb  # (B, T, C)
        # 3. Pass through the Transformer blocks
        x = self.blocks(x)  # (B, T, C)
        # 4. Final LayerNorm and linear head
        x = self.ln_f(x)
        logits = self.lm_head(x)  # (B, T, vocab_size)
        loss = None
        if targets is not None:
            # Reshape for cross_entropy: (B*T, vocab_size) and (B*T,)
            B, T, C = logits.shape
            logits = logits.view(B * T, C)
            targets = targets.view(B * T)
            # Cross-entropy = negative log-likelihood of the correct next token
            loss = F.cross_entropy(logits, targets)
        return logits, loss
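

# --- Minimal smoke test (sketch) ---
# Not part of the model: a quick forward-pass check with assumed values for
# vocab_size and block_size; adjust them to match your tokenizer and training setup.
if __name__ == "__main__":
    vocab_size, block_size = 65, 32  # assumed character-level vocab and context length
    model = TinyLLM(vocab_size, n_embed, n_head, n_layer, block_size, dropout)
    idx = torch.randint(0, vocab_size, (2, block_size))      # dummy input tokens (B=2, T=32)
    targets = torch.randint(0, vocab_size, (2, block_size))  # dummy next-token targets
    logits, loss = model(idx, targets)
    print(logits.shape, loss.item())  # expect torch.Size([64, 65]) and a finite scalar loss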