import math

import torch
import torch.nn as nn
import torch.nn.functional as F

from huggingface_hub import PyTorchModelHubMixin  # adds save_pretrained / from_pretrained / push_to_hub

# --- Hyperparameters (You can adjust these later) ---
# For a "Tiny" LLM, we keep the size very small.
n_embed = 64      # C: Embedding dimension (size of the vector representing a character)
n_head = 4        # H: Number of attention heads
n_layer = 4       # Number of repeating Transformer blocks
dropout = 0.1     # Dropout rate
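
# Quick sanity check on the values above: each attention head works on
# n_embed // n_head = 16 dimensions. block_size and vocab_size are not fixed
# here; they are passed to the model at construction time because they depend
# on the dataset and tokenizer.
assert n_embed % n_head == 0, "n_embed must be divisible by n_head"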

# --- 1. Causal Self-Attention (The "Attention is All You Need" Component) ---

class CausalSelfAttention(nn.Module):
    """A multi-head masked self-attention module."""

    def __init__(self, n_embed, n_head, block_size, dropout):
        super().__init__()
        
        self.n_embed = n_embed
        self.n_head = n_head
        self.head_size = n_embed // n_head
        
        # Combined projection for Q, K, and V (more efficient)
        self.c_attn = nn.Linear(n_embed, 3 * n_embed, bias=False)
        # Output projection
        self.c_proj = nn.Linear(n_embed, n_embed, bias=False)
        self.attn_dropout = nn.Dropout(dropout)
        self.resid_dropout = nn.Dropout(dropout)

        # Causal Mask (tril = lower triangular matrix)
        # This mask prevents a token from attending to future tokens (autoregressive)
        self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size))
                                    .view(1, 1, block_size, block_size))

    def forward(self, x):
        B, T, C = x.shape  # Batch size, Sequence length (Time), Embedding dimension (Channel)

        # 1. Compute Q, K, V and split (efficiently)
        # q, k, v are (B, T, C)
        qkv = self.c_attn(x)
        q, k, v = qkv.split(self.n_embed, dim=2) 
        
        # 2. Reshape for Multi-Head Attention (B, T, C) -> (B, H, T, Head_size)
        # We prepare the tensors so that each head processes a smaller chunk of the dimension C
        k = k.view(B, T, self.n_head, self.head_size).transpose(1, 2) 
        q = q.view(B, T, self.n_head, self.head_size).transpose(1, 2)
        v = v.view(B, T, self.n_head, self.head_size).transpose(1, 2)
        
        # 3. Scaled Dot-Product Attention: (B, H, T, T)
        # wei = (q @ k.transpose(-2, -1)) / sqrt(Head_size)
        wei = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(self.head_size))
        
        # 4. Apply Causal Mask
        # Set attention scores to -inf for future tokens (where tril == 0)
        wei = wei.masked_fill(self.tril[:,:,:T,:T] == 0, float('-inf'))
        
        # 5. Softmax and Dropout
        wei = F.softmax(wei, dim=-1)
        wei = self.attn_dropout(wei)

        # 6. Compute Weighted Sum of Values: (B, H, T, Head_size)
        out = wei @ v 

        # 7. Re-assemble heads: (B, H, T, Head_size) -> (B, T, C)
        out = out.transpose(1, 2).contiguous().view(B, T, C)

        # 8. Final Linear Projection
        out = self.resid_dropout(self.c_proj(out))
        return out
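
# Illustrative smoke test for the attention module (not part of the model itself);
# block_size=32 is just a demo value here, and the input is random noise.
if __name__ == "__main__":
    _attn = CausalSelfAttention(n_embed=64, n_head=4, block_size=32, dropout=0.1)
    _out = _attn(torch.randn(2, 32, 64))   # input shape (B=2, T=32, C=64)
    print(_out.shape)                      # torch.Size([2, 32, 64]) -- shape is preserved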

# --- 2. Feed Forward Network (FFN) ---

class FeedForward(nn.Module):
    """A two-layer MLP for processing attention output."""
    def __init__(self, n_embed, dropout):
        super().__init__()
        self.net = nn.Sequential(
            # Standard ratio is 4x the embedding size
            nn.Linear(n_embed, 4 * n_embed),
            nn.GELU(), # Modern activation function (smoother than ReLU)
            nn.Linear(4 * n_embed, n_embed),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        return self.net(x)
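
# With the default n_embed=64, this MLP expands 64 -> 256 -> 64, so it holds
# 64*256 + 256 (first layer, with bias) + 256*64 + 64 (second layer) = 33,088
# parameters. Because nn.Linear acts on the last dimension, the FFN is applied
# to every token position independently, mapping (B, T, C) back to (B, T, C).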


# --- 3. Transformer Block (The Repeating Unit) ---

class TransformerBlock(nn.Module):
    """A standard Transformer decoder block with Attention and FFN."""

    def __init__(self, n_embed, n_head, block_size, dropout):
        super().__init__()
        # LayerNorm applied BEFORE the sub-layer (Pre-Norm style)
        self.ln_1 = nn.LayerNorm(n_embed) 
        self.attn = CausalSelfAttention(n_embed, n_head, block_size, dropout)
        self.ln_2 = nn.LayerNorm(n_embed)
        self.ffn = FeedForward(n_embed, dropout)

    def forward(self, x):
        # 1. Attention with Residual Connection and LayerNorm
        x = x + self.attn(self.ln_1(x))
        # 2. FFN with Residual Connection and LayerNorm
        x = x + self.ffn(self.ln_2(x))
        return x
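
# Rough size check for one block (illustrative only; block_size=32 is a demo value).
# With n_embed=64 and n_head=4 this comes to 49,728 parameters:
# 16,384 for attention (no biases), 33,088 for the FFN, 256 for the two LayerNorms.
# The causal mask is a registered buffer, so it is not counted as a parameter.
if __name__ == "__main__":
    _block = TransformerBlock(n_embed=64, n_head=4, block_size=32, dropout=0.1)
    print(sum(p.numel() for p in _block.parameters()))   # 49728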

# --- 4. The Final TinyLLM Model ---

class TinyLLM(nn.Module, PyTorchModelHubMixin):
    """The complete Decoder-Only Transformer model."""
    
    def __init__(self, vocab_size, n_embed, n_head, n_layer, block_size, dropout):
        super().__init__()
        
        self.block_size = block_size
        
        self.token_embedding_table = nn.Embedding(vocab_size, n_embed)
        # Learned positional embeddings: one trainable vector per position (0 .. block_size-1)
        self.position_embedding_table = nn.Embedding(block_size, n_embed)
        
        # Stack of Transformer Blocks
        self.blocks = nn.Sequential(*[
            TransformerBlock(n_embed, n_head, block_size, dropout) 
            for _ in range(n_layer)
        ])
        
        self.ln_f = nn.LayerNorm(n_embed) # Final LayerNorm
        # Linear layer to map the embedding vector back to the vocabulary space
        self.lm_head = nn.Linear(n_embed, vocab_size)

    def forward(self, idx, targets=None):
        # idx is the input tensor X of shape (B, T)
        B, T = idx.shape
        
        # 1. Token and Positional Embeddings
        # Token embedding: (B, T, C)
        tok_emb = self.token_embedding_table(idx) 
        # Position embedding: (T, C), broadcast to (B, T, C) when added below
        pos = torch.arange(T, device=idx.device) 
        pos_emb = self.position_embedding_table(pos) 
        
        # 2. Combine (Add) Embeddings
        x = tok_emb + pos_emb # (B, T, C)
        
        # 3. Pass through Transformer Blocks
        x = self.blocks(x) # (B, T, C)
        
        # 4. Final LayerNorm and Linear Head
        x = self.ln_f(x)
        logits = self.lm_head(x) # (B, T, vocab_size)

        loss = None
        if targets is not None:
            # Reshape for F.cross_entropy: (B*T, vocab_size) logits and (B*T,) targets.
            # Note: this also flattens the logits that are returned when targets are given.
            B, T, C = logits.shape
            logits = logits.view(B*T, C)
            targets = targets.view(B*T)
            
            # Compute the negative log-likelihood loss
            loss = F.cross_entropy(logits, targets)

        return logits, loss
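

# --- 5. Quick usage sketch (illustrative only) ---
# A minimal smoke test, assuming a character-level setup: vocab_size=65 and
# block_size=32 are placeholder values, not values fixed by this file. It runs
# one forward pass with random token ids to get a loss, then a short greedy
# generation loop to show how the causal model is sampled autoregressively.
if __name__ == "__main__":
    vocab_size, block_size = 65, 32
    model = TinyLLM(vocab_size, n_embed, n_head, n_layer, block_size, dropout)

    # Forward pass with targets -> logits and cross-entropy loss
    idx = torch.randint(0, vocab_size, (4, block_size))      # (B=4, T=32) of token ids
    targets = torch.randint(0, vocab_size, (4, block_size))
    _, loss = model(idx, targets)
    print(loss.item())  # roughly ln(65) ~= 4.17 for an untrained model

    # Greedy generation: repeatedly feed the last block_size tokens back in
    model.eval()
    context = torch.zeros((1, 1), dtype=torch.long)          # start from token id 0
    with torch.no_grad():
        for _ in range(20):
            logits, _ = model(context[:, -block_size:])      # crop to the context window
            next_id = logits[:, -1, :].argmax(dim=-1, keepdim=True)
            context = torch.cat([context, next_id], dim=1)
    print(context.tolist())

    # Thanks to PyTorchModelHubMixin, model.save_pretrained("some-dir") and
    # TinyLLM.from_pretrained("some-dir") are also available (exact config
    # handling depends on the installed huggingface_hub version).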