import time import torch import torch.nn as nn import torch.nn.functional as F import math from torch.nn import RMSNorm from config import ModelArgs from tokenizer import Tokenizer # Initialize tokenizer globally as None - will be set later tokenizer = None model_args = ModelArgs() def initialize_tokenizer(hf_token=None): """Initialize the global tokenizer with the provided HF token""" global tokenizer if tokenizer is None: tokenizer_instance = Tokenizer(hf_token=hf_token) tokenizer = tokenizer_instance.ready_tokenizer() return tokenizer class Normalization(nn.Module): def __init__( self, embeddings_dims: int = model_args.embeddings_dims ): super().__init__() self.rmsnorm_layer = RMSNorm(embeddings_dims) def forward(self, x): x = self.rmsnorm_layer(x) return x class Swish(nn.Module): def __init__( self, block_size: int = model_args.block_size, embeddings_dims: int = model_args.embeddings_dims, device = model_args.device ): super().__init__() self.sig = torch.nn.Sigmoid() def forward(self, x): swish = x * self.sig(x) return swish class SWiGLUExpertMoE(nn.Module): def __init__( self, block_size: int = model_args.block_size, embeddings_dims: int = model_args.embeddings_dims, device = model_args.device ): super().__init__() self.hidden_dims = (embeddings_dims * 2) self.swish = Swish(block_size=block_size, embeddings_dims=embeddings_dims, device=device) self.linear_layer1 = nn.Linear(in_features=embeddings_dims, out_features=self.hidden_dims, bias=False, device = device) self.linear_layer2 = nn.Linear(in_features=embeddings_dims, out_features=self.hidden_dims, bias=False, device = device) self.linear_layer3 = nn.Linear(in_features=self.hidden_dims, out_features=embeddings_dims, bias=False, device = device) def forward(self, x): swish_res = self.swish(self.linear_layer1(x)) x_V = self.linear_layer2(x) res = torch.mul(swish_res, x_V) out = self.linear_layer3(res) return out class MoeLayer(nn.Module): def __init__( self, dropout = model_args.dropout, embeddings_size = model_args.embeddings_dims, device = model_args.device, # inner_dimensional_states: int = 3072 ): super().__init__() self.heads = nn.ModuleList([SWiGLUExpertMoE() for _ in range(model_args.experts)]) self.gate = nn.Linear(in_features=embeddings_size, out_features=model_args.experts, device=device, bias=False) # Only create shared expert if enabled if model_args.use_shared_expert: self.shared_expert = SWiGLUExpertMoE() else: self.shared_expert = None if(model_args.noisy_topk is True and model_args.use_checkpointing == False): self.noise = nn.Linear(in_features=embeddings_size, out_features=model_args.experts, device=device, bias=False) self.noisy_router = None # self.outputs = torch.zeros((batch_size,block_size, embeddings_size), device=device) #batch size needs to be defined because we are accessing it explicitly self.device = device # self.shared_expert_out = torch.zeros((model_args.batch_size, model_args.embeddings_dims), device=device) # self.b = torch.zeros((model_args.batch_size, model_args.block_size, model_args.experts), device=device) if model_args.useauxFreeLoadBalancingLoss: self.register_buffer('routing_bias', torch.zeros(model_args.experts, device=self.device)) # self.routing_bias = torch.zeros(model_args.experts, device=self.device) self.bias_update_speed = model_args.aux_free_bias_update_rate def forward(self, x): # mlp_weights_init = self.mlp.apply(weights_init) self.gate_out = self.gate(x) #[bz, seq, num_experts] if(model_args.noisy_topk == True and model_args.use_checkpointing == False): noise = self.noise(x) gaussian_noise = torch.normal(0, 1, size=self.gate_out.shape, device=self.device) self.noisy_router = F.softplus(noise) * gaussian_noise self.gate_out += self.noisy_router shared_output = 0 out = 0 if model_args.useauxFreeLoadBalancingLoss: self.gate_out += self.routing_bias # Adjust top_k based on whether shared expert is used top_k = model_args.top_experts top_k_values, top_k_indices = torch.topk(self.gate_out, k=top_k) #[bs, seq len, top k] # topkmask = torch.ones_like(top_k_values, device=self.device) # [bs, seq len, experts] # indices = torch.arange(top_k_values.size(0), device=self.device).unsqueeze(1).unsqueeze(2) # [bs, 1, 1] # topkvaluesMasked = top_k_values.masked_fill(indices != top_k_indices, float('-inf')) # Mask out negative values masked = torch.full_like(self.gate_out, float('-1e20'), device=self.device) masked_values = masked.scatter_(-1, top_k_indices, top_k_values) probs = torch.nn.functional.softmax(masked_values, dim=-1) #[bs, seq len, top k] out = torch.zeros_like(x) if model_args.use_shared_expert and self.shared_expert is not None: shared_output += self.shared_expert(x) flat_x = x.view(-1, x.size(-1)) # Flatten the input for easier processing for i in range(model_args.experts): # Iterate through each expert index (0 to num_experts-1) # Determine which tokens routed to this expert 'i' # top_k_indices is [bs, seq_len, self.top_k] # We want a mask of shape [bs, seq_len] where True if expert 'i' is in the top_k for that token expert_i_is_chosen_mask = (top_k_indices == i).any(dim=-1) # Check along the top_k dimension # expert_i_is_chosen_mask has shape [bs, seq_len] if not expert_i_is_chosen_mask.any(): # If expert 'i' was not chosen by any token continue # Flatten the mask to apply to flat_x flat_expert_i_is_chosen_mask = expert_i_is_chosen_mask.reshape(-1) # Shape: [bs * seq_len] # Select input tokens for this expert selected_input_tokens = flat_x[flat_expert_i_is_chosen_mask] # Shape: [num_active_for_expert_i, embed_dim] if selected_input_tokens.numel() == 0: # Should be caught by .any() above, but good check continue # Process through the expert expert_output_for_selected = self.heads[i](selected_input_tokens) # Get the routing probabilities for these chosen tokens specifically for expert 'i' # routing_probs is [bs, seq_len, num_experts] # expert_i_probs_original_shape = routing_probs[:, :, i] # Probabilities for expert 'i', shape [bs, seq_len] # flat_expert_i_probs = expert_i_probs_original_shape.reshape(-1) # Shape [bs * seq_len] # active_token_weights = flat_expert_i_probs[flat_expert_i_is_chosen_mask] # Shape: [num_active_for_expert_i] # Alternative way to get weights directly using the mask on routing_probs for expert i: # Get the [bs, seq_len] slice of probabilities for the current expert 'i' probs_for_expert_i = probs[:, :, i] # Shape: [bs, seq_len] # Now use the expert_i_is_chosen_mask (which is also [bs, seq_len]) to select the relevant weights active_token_weights = probs_for_expert_i[expert_i_is_chosen_mask] # Shape: [num_active_for_expert_i] weighted_expert_output = expert_output_for_selected * active_token_weights.unsqueeze(-1) # Add this expert's contribution temp_contribution_for_expert_i = torch.zeros_like(x) # Initialize with zeros temp_contribution_for_expert_i.masked_scatter_( expert_i_is_chosen_mask.unsqueeze(-1).expand_as(x), # Use the original 2D mask, expanded weighted_expert_output ) out = out + temp_contribution_for_expert_i # for expert_idx in range(model_args.experts): # # Create mask for current expert across all top_k positions # expert_mask = (top_k_indices == expert_idx) # # Sum probabilities for current expert # expert_weights = (probs * expert_mask).sum(dim=-1) # [batch, seq_len] # # Get inputs where expert is used # selected = expert_weights > 0 # if not selected.any(): # continue # # print(expert_weights.shape) # # print(x[selected].shape) # # Process all selected inputs through expert # expert_out = self.heads[expert_idx](x[selected]) # # Weight and accumulate outputs # out[selected] += expert_out * expert_weights[selected].unsqueeze(-1) out = out + shared_output # Add shared expert output if enabled if model_args.useauxFreeLoadBalancingLoss and self.training: with torch.no_grad(): ci = probs.sum(dim=(0,1)) # Su of tokens for each expert ci_avg = ci.mean() error_i = ci_avg - ci self.update = self.bias_update_speed * torch.sign(error_i) # Update routing bias self.routing_bias.add_(self.update) # self.routing_bias = self.routing_bias + self.update return out # import numpy as np class SinusoidalPositionalEmbeddings(nn.Module): def __init__( self, device, embeddings_dims: int = model_args.embeddings_dims, block_size: int = model_args.block_size, batch_size: int = model_args.batch_size, ): super().__init__() self.embeddings_dims = embeddings_dims self.block_size = block_size self.batch_size = batch_size self.device = device # Create positional encoding matrix pe = torch.zeros(block_size, embeddings_dims) position = torch.arange(0, block_size, dtype=torch.float).unsqueeze(1) div_term = torch.exp(torch.arange(0, embeddings_dims, 2).float() * (-math.log(10000.0) / embeddings_dims)) pe[:, 0::2] = torch.sin(position * div_term) pe[:, 1::2] = torch.cos(position * div_term) # Register as buffer so it's not a parameter but moves with the model self.register_buffer('pe', pe.unsqueeze(0)) # Shape: [1, block_size, embeddings_dims] def forward(self, x): # x shape: [batch_size, seq_len, embeddings_dims] batch_size, seq_len, _ = x.shape # Add positional embeddings # pe[:, :seq_len] ensures we only use the positional embeddings up to the sequence length pos_emb = self.pe[:, :seq_len].to(x.device) return pos_emb class LatentAttention(nn.Module): def __init__( self, attn_dropout = model_args.attn_dropout, embeddings_dims = model_args.embeddings_dims, no_of_heads = model_args.no_of_heads, device = model_args.device ): super().__init__() self.head_size = embeddings_dims // no_of_heads self.no_of_heads = no_of_heads # if(model_args.use_flash_attention==False): self.latent_dim = model_args.latent_dim self.W_k = nn.Linear(in_features=self.latent_dim, out_features=self.head_size, device=device, bias=False) self.W_v = nn.Linear(in_features=self.latent_dim, out_features=self.head_size, device=device, bias=False) self.W_dkv = nn.Linear(in_features=model_args.embeddings_dims, out_features=self.latent_dim, device=device, bias=False) # 3 for query, key and value self.query = nn.Linear(in_features=embeddings_dims, out_features=self.head_size, device=model_args.device, bias=False) # self.keys = nn.Linear(in_features=embeddings_dims, out_features=self.head_size,device=model_args.device, bias=False) # self.values = nn.Linear(in_features=embeddings_dims, out_features=self.head_size, device=model_args.device,bias=False) # self.dropout = nn.Dropout(p = attn_dropout) self.dropout = nn.Dropout(p = attn_dropout) self.device = device # Use sinusoidal positional embeddings instead of rotary self.pos_embeddings = SinusoidalPositionalEmbeddings(embeddings_dims=self.head_size, device=device) # self.register_buffer('absorbed_q', None) # self.absorbed_q = None def forward(self, x, kv_cache=None, mask=None): batch_size, block_size, embd_dims = x.shape # k = self.keys(x) # q = self.query(x) # v = self.values(x) self.latent_matrix = self.W_dkv(x) # print("q shape: ", q.shape) # print("Shape of latent mat: ", self.query.weight.shape) # print("Shape of compressed_k: ", self.W_k.weight.shape) # if(self.absorbed_q is None): self.absorbed_q = torch.matmul(self.query.weight.T , self.W_k.weight) # weights = q @ torch.transpose(k, dim0=-2, dim1=-1) * (k.shape[-1] ** -0.5) # if kv_cache is None: if kv_cache is None: kv_cache = self.latent_matrix else: # print(kv_cache) # print("Shape of latent matrix: ", self.latent_matrix.shape) # print("Shape of kv_cache: ", kv_cache.shape) kv_cache = torch.cat([kv_cache, self.latent_matrix], dim=1) self.compressed_k = self.W_k(kv_cache) self.compressed_v = self.W_v(kv_cache) q_res = torch.matmul(x , self.absorbed_q) weights = q_res @ torch.transpose(kv_cache, dim0=-2, dim1=-1) * (self.head_size ** -0.5) # [batch_size, block_size, block_size] # print("Shape of weights: ", weights.shape) # print("Shape of kv_cache: ", kv_cache.shape) if(mask is not None): weights = weights.masked_fill(mask == 0, float('-1e20')) #Masking the attention weights masked_table = torch.tril(torch.ones(q_res.shape[1], kv_cache.shape[1], device=model_args.device)) masked_values = weights.masked_fill(masked_table[: q_res.shape[1], : kv_cache.shape[1]] == 0, float('-1e20')) weights_normalized = nn.functional.softmax(masked_values, dim=-1) #Normalize along the embeddings dimension for all the tokens weights_normalized = self.dropout(weights_normalized) # print("Shape of weights_normalized: ", weights_normalized.shape) # Apply positional embeddings to the output # print("Shape of compressed_v: ", self.compressed_v.shape) out = weights_normalized @ self.compressed_v # out = self.pos_embeddings(out) return out, kv_cache # MHA class MHLA(nn.Module): def __init__( self, device, attn_dropout = model_args.attn_dropout, embeddings_dims = model_args.embeddings_dims, no_of_heads = model_args.no_of_heads, ): super().__init__() self.heads = nn.ModuleList([LatentAttention(attn_dropout=attn_dropout, embeddings_dims=embeddings_dims, no_of_heads=no_of_heads) for _ in range(no_of_heads)]) self.dropout = nn.Dropout(p = attn_dropout) self.linear = nn.Linear(in_features=embeddings_dims, out_features=embeddings_dims, device=device, bias=False) # 12 (no of heads) * (batch_size) 64 = 768 -> gives out the text embeddings def forward(self, x, kv_cache=None, mask=None): # concat = torch.cat([head(x, kv_cache=kv_cache, mask=mask) for head in self.heads], dim=-1) res = [] for head in self.heads: head_out, kv_cache = head(x, kv_cache=kv_cache, mask=mask) res.append(head_out) concat = torch.cat(res, dim=-1) # Concatenate along the last dimension linear_layer = self.linear(concat) out = self.dropout(linear_layer) return out, kv_cache class FFN(nn.Module): def __init__(self, device, embeddings_dims: int = model_args.embeddings_dims, block_size: int = model_args.block_size, vocab_size: int = model_args.vocab_size, dropout = model_args.dropout ): super().__init__() self.linear_layer = nn.Linear(in_features=embeddings_dims, out_features=embeddings_dims, dtype=torch.float32, device = device) self.linear_layer2 = nn.Linear(in_features=embeddings_dims, out_features=embeddings_dims, dtype=torch.float32, device = device) self.dropout = nn.Dropout(p = dropout) # Uncommenting the dropout line def forward(self, x): x = self.linear_layer(x) x = F.gelu(x) x = self.linear_layer2(x) x = F.gelu(x) # x = self.dropout(x) # Uncommenting the dropout line return x class DecoderLayer(nn.Module): def __init__(self, device, attn_dropout: float = model_args.attn_dropout, no_of_heads: int = model_args.no_of_heads, embeddings_dims: int = model_args.embeddings_dims, dropout = model_args.dropout, block_size: int = model_args.block_size, vocab_size: int = model_args.vocab_size, ) : super().__init__() # self.base_freq = model_args.base_freq # self.feedforward_network = FFN(embeddings_dims=embeddings_dims, block_size=block_size, vocab_size=vocab_size, device = device) self.mha = MHLA(attn_dropout=attn_dropout, embeddings_dims=embeddings_dims, no_of_heads=no_of_heads, device=device) self.layer_norm1 = Normalization(embeddings_dims=embeddings_dims) self.layer_norm2 = Normalization(embeddings_dims=embeddings_dims) # self.layer_norm3 = Normalization(embeddings_dims=embeddings_dims) self.dropout = nn.Dropout(p = dropout) self.moe_block = MoeLayer(dropout=dropout, embeddings_size=embeddings_dims) def forward(self, x, kv_cache=None, ffn=None, mask=None): out, kv_cache = self.mha(self.layer_norm1(x), kv_cache=kv_cache, mask=mask) #Very important step -> Layer Norm on input and then passes it to the subsequent blocks x = x + out # Fixed: removed in-place operation x = x + self.moe_block(self.layer_norm2(x)) #Very important step return x, kv_cache class Block(nn.Module): def __init__(self, device, embeddings_dims: int = model_args.embeddings_dims, no_of_decoder_layers: int = model_args.no_of_decoder_layers, block_size: int = model_args.block_size, vocab_size: int = model_args.vocab_size, dropout = model_args.dropout ) : super().__init__() self.base_freq = model_args.base_freq # self.embeddings = nn.Embedding(num_embeddings=vocab_size, embedding_dim=embeddings_dims, dtype=torch.float32, device = device) self.decoder = nn.ModuleList(DecoderLayer(embeddings_dims=embeddings_dims, block_size=block_size, vocab_size=vocab_size, dropout=dropout, device = device) for _ in range(no_of_decoder_layers)) # self.linear_layer = nn.Linear(in_features=embeddings_dims, out_features=vocab_size, dtype=torch.float32, device = device) self.dropout = nn.Dropout(p = dropout) self.norm = Normalization(embeddings_dims) #weight tying # self.embeddings.weight = self.linear_layer.weight self.apply(self._init_weights) def _init_weights(self, module): if isinstance(module, nn.Linear): nn.init.normal_(module.weight, mean=0.0, std=0.02) if module.bias is not None: nn.init.zeros_(module.bias) elif isinstance(module, nn.Embedding): nn.init.normal_(module.weight, mean=0.0, std=0.02) def forward(self, x, mask=None, actual_labels = None, inference=False): index = 0 no_of_layers = 0 # x = self.embeddings(x) # # x = self.dropout(x) # if(mask is not None): kv_cache = None # x = x * mask # # mask = mask.unsqueeze(-1) # x = self.decoder(x) for layer in self.decoder: # if no_of_layers % 2 == 0: # if no_of_layers % 4 == 0: # # print("x shape: ", x.shape) # x = layer(x, rope=False, ffn=True, mask=mask) # x = layer(x, rope=True, ffn=True, mask=mask) # # print("x shape: ", x.shape) # else: # # print("x shape local: ", x.shape) # if no_of_layers % 4 == 0: # # print("x shape: ", x.shape) # x = layer(x, rope=False, ffn=False, mask=mask) x, kv_cache = layer(x, kv_cache=kv_cache, ffn=None, mask=mask) # print("x shape local: ", x.shape) # no_of_layers += 1 # print(x.shape) x = self.dropout(x) x = 2 * ((model_args.no_of_decoder_layers) ** -0.5) * x x = self.norm(x) # if(inference): # out = self.linear_layer(x) # return out # if(model_args.use_liger): # # print("yo") # y = x.contiguous().view(-1, model_args.embeddings_dims) # if(actual_labels is not None): # labels = actual_labels.contiguous().view(-1) # # Pass linear layer weights FIRST as required [2][5] # # ignore_index is already set during initialization # loss = self.le_loss(self.linear_layer.weight, y, labels) # return loss # else: # # print("Hi") # out = self.linear_layer(x) # return out return x class DeepSeekV3(nn.Module): def __init__(self, device, embeddings_dims: int = model_args.embeddings_dims, block_size: int = model_args.block_size, vocab_size: int = model_args.vocab_size, dropout = model_args.dropout ): super().__init__() self.decoder = Block(device=device, embeddings_dims=embeddings_dims, no_of_decoder_layers=model_args.no_of_decoder_layers, block_size=block_size, vocab_size=vocab_size, dropout=dropout) self.embedding = nn.Embedding(num_embeddings=vocab_size, embedding_dim=embeddings_dims, dtype=torch.float32, device=device) self.pos_embeddings = SinusoidalPositionalEmbeddings(embeddings_dims=embeddings_dims, device=device) self.linear_layer = nn.Linear(in_features=embeddings_dims, out_features=vocab_size, dtype=torch.float32, device=device, bias=False) # Weight tying - tie embedding and output projection weights self.embedding.weight = self.linear_layer.weight # Initialize the LigerFusedLinearCrossEntropyLoss for optimized training if model_args.use_liger: from liger_kernel.transformers import LigerFusedLinearCrossEntropyLoss # Initialize with ignore_index for padding tokens if enabled if model_args.ignore_pad_token_in_loss: self.le_loss = LigerFusedLinearCrossEntropyLoss( ignore_index=tokenizer.pad_token_id ) else: self.le_loss = LigerFusedLinearCrossEntropyLoss() def forward(self, x, inference=False, mask=None): if(mask is not None): x = x * mask x = self.embedding(x) x = x + self.pos_embeddings(x) # Add positional embeddings B, T, C = x.shape if inference: # For inference, we only need the last token prediction decoder_out = self.decoder(x, mask=mask) logits = self.linear_layer(decoder_out) return logits else: decoder_out = self.decoder(x, mask=mask) logits = self.linear_layer(decoder_out) return logits