import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from typing import Optional, Tuple, List
# Constants for default configuration
DEFAULT_MAX_SEQ_LEN = 512
DEFAULT_DROPOUT = 0.1
DEFAULT_BASE = 10000.0
DEFAULT_CUTOFFS = [2000, 10000]
DEFAULT_DIV_VAL = 4.0
DEFAULT_PADDING_IDX = 0
class PositionalEncoding(nn.Module):
"""Sinusoidal positional encoding for transformer models."""
def __init__(self, d_model: int, max_seq_len: int = DEFAULT_MAX_SEQ_LEN, dropout: float = DEFAULT_DROPOUT):
"""
Initialize sinusoidal positional encoding.
Args:
d_model (int): Dimension of the model embeddings.
max_seq_len (int): Maximum sequence length for positional encodings.
dropout (float): Dropout rate for regularization.
"""
super().__init__()
self.d_model = d_model
self.dropout = nn.Dropout(dropout)
pe = torch.zeros(max_seq_len, d_model)
position = torch.arange(0, max_seq_len, dtype=torch.float).unsqueeze(1)
div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(DEFAULT_BASE) / d_model))
pe[:, 0::2] = torch.sin(position * div_term)
        # For odd d_model there is one fewer cosine column, so drop the last frequency.
        pe[:, 1::2] = torch.cos(position * (div_term[:-1] if d_model % 2 == 1 else div_term))
self.register_buffer('pe', pe.unsqueeze(0))
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Apply positional encoding to input embeddings.
Args:
x (torch.Tensor): Input tensor of shape (batch_size, seq_len, d_model).
Returns:
torch.Tensor: Tensor with positional encodings applied.
"""
batch_size, seq_len, d_model = x.size()
if d_model != self.d_model:
raise ValueError(f"Input dimension {d_model} does not match d_model {self.d_model}")
x = x + self.pe[:, :seq_len]
return self.dropout(x)
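# Usage sketch (illustrative values, not part of the original module): the sizes
# below are arbitrary; the encoding is added in place and the shape is preserved.
# >>> pos_enc = PositionalEncoding(d_model=512)
# >>> pos_enc(torch.zeros(2, 16, 512)).shape
# torch.Size([2, 16, 512])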
class LearnedPositionalEmbedding(nn.Module):
"""Learned positional embeddings for transformer models."""
def __init__(self, max_seq_len: int, d_model: int, dropout: float = DEFAULT_DROPOUT):
"""
Initialize learned positional embeddings.
Args:
max_seq_len (int): Maximum sequence length.
d_model (int): Dimension of the model embeddings.
dropout (float): Dropout rate for regularization.
"""
super().__init__()
self.max_seq_len = max_seq_len
self.d_model = d_model
self.pos_embedding = nn.Embedding(max_seq_len, d_model)
self.dropout = nn.Dropout(dropout)
nn.init.normal_(self.pos_embedding.weight, std=0.02)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Apply learned positional embeddings to input.
Args:
x (torch.Tensor): Input tensor of shape (batch_size, seq_len, d_model).
Returns:
torch.Tensor: Tensor with positional embeddings applied.
"""
batch_size, seq_len, d_model = x.size()
if seq_len > self.max_seq_len:
raise ValueError(f"Sequence length {seq_len} exceeds maximum {self.max_seq_len}")
if d_model != self.d_model:
raise ValueError(f"Input dimension {d_model} does not match d_model {self.d_model}")
positions = torch.arange(seq_len, device=x.device).unsqueeze(0).expand(batch_size, -1)
pos_emb = self.pos_embedding(positions)
x = x + pos_emb
return self.dropout(x)
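# Usage sketch (illustrative values only): learned positions are bounded by
# max_seq_len, so longer inputs raise a ValueError.
# >>> lpe = LearnedPositionalEmbedding(max_seq_len=128, d_model=512)
# >>> lpe(torch.zeros(2, 64, 512)).shape
# torch.Size([2, 64, 512])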
class RotaryPositionalEmbedding(nn.Module):
"""Rotary Positional Embedding (RoPE) for transformer models."""
def __init__(self, d_model: int, max_seq_len: int = 2048, base: float = DEFAULT_BASE):
"""
Initialize rotary positional embeddings.
Args:
d_model (int): Dimension of the model embeddings.
max_seq_len (int): Maximum sequence length.
base (float): Base for frequency calculation.
"""
super().__init__()
self.d_model = d_model
self.max_seq_len = max_seq_len
self.base = base
inv_freq = 1.0 / (base ** (torch.arange(0, d_model, 2).float() / d_model))
self.register_buffer('inv_freq', inv_freq)
self._seq_len_cached = 0
self._cos_cached = None
self._sin_cached = None
def _update_cos_sin_cache(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> None:
"""Update cached cosine and sine values for RoPE."""
        if (
            self._cos_cached is None
            or seq_len > self._seq_len_cached
            or self._cos_cached.device != device
            or self._cos_cached.dtype != dtype
        ):
            # Rebuild when the sequence grows or the device/dtype changes,
            # since the cached tensors are plain attributes, not buffers.
            self._seq_len_cached = max(seq_len, self._seq_len_cached)
            t = torch.arange(self._seq_len_cached, device=device, dtype=torch.float32)
            freqs = torch.outer(t, self.inv_freq.to(device=device, dtype=torch.float32))
            self._cos_cached = freqs.cos().to(dtype)
            self._sin_cached = freqs.sin().to(dtype)
def _rotate_half(self, x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
"""Apply rotary transformation to half of the tensor."""
x1, x2 = x[..., :x.shape[-1] // 2], x[..., x.shape[-1] // 2:]
return torch.cat([x1 * cos - x2 * sin, x1 * sin + x2 * cos], dim=-1)
def forward(self, q: torch.Tensor, k: torch.Tensor, start_pos: int = 0) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Apply rotary positional embeddings to query and key tensors.
Args:
q (torch.Tensor): Query tensor of shape (batch_size, seq_len, num_heads, head_dim).
k (torch.Tensor): Key tensor of shape (batch_size, seq_len, num_heads, head_dim).
start_pos (int): Starting position for positional encoding.
Returns:
Tuple[torch.Tensor, torch.Tensor]: Rotated query and key tensors.
"""
batch_size, seq_len, num_heads, head_dim = q.shape
self._update_cos_sin_cache(start_pos + seq_len, q.device, q.dtype)
cos = self._cos_cached[start_pos:start_pos + seq_len, :head_dim // 2].view(1, seq_len, 1, -1)
sin = self._sin_cached[start_pos:start_pos + seq_len, :head_dim // 2].view(1, seq_len, 1, -1)
        # cos/sin broadcast directly over (batch, seq_len, num_heads, head_dim // 2);
        # no reshaping of q/k is needed, and the head layout is preserved.
        q_rot = self._rotate_half(q, cos, sin)
        k_rot = self._rotate_half(k, cos, sin)
        return q_rot, k_rot
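# Usage sketch (illustrative shapes only): here d_model is used as the per-head
# dimension, and q/k follow the (batch, seq_len, num_heads, head_dim) layout
# documented in forward(); start_pos supports incremental decoding with a KV cache.
# >>> rope = RotaryPositionalEmbedding(d_model=64)
# >>> q = torch.randn(2, 16, 8, 64); k = torch.randn(2, 16, 8, 64)
# >>> q_rot, k_rot = rope(q, k)                                # full prefix
# >>> q_new, k_new = rope(q[:, -1:], k[:, -1:], start_pos=15)  # one decode step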
class TechEmbeddingLayer(nn.Module):
"""Comprehensive embedding layer with token and positional embeddings."""
def __init__(
self,
vocab_size: int,
d_model: int,
max_seq_len: int = DEFAULT_MAX_SEQ_LEN,
dropout: float = DEFAULT_DROPOUT,
padding_idx: int = DEFAULT_PADDING_IDX,
pos_encoding: str = "learned",
layer_norm: bool = True,
):
"""
Initialize the embedding layer.
Args:
vocab_size (int): Size of the vocabulary.
d_model (int): Dimension of the model embeddings.
max_seq_len (int): Maximum sequence length.
dropout (float): Dropout rate.
padding_idx (int): Index for padding token.
pos_encoding (str): Type of positional encoding ('sinusoidal', 'learned', 'rope').
layer_norm (bool): Whether to apply layer normalization.
"""
super().__init__()
self.d_model = d_model
self.vocab_size = vocab_size
self.padding_idx = padding_idx
self.pos_encoding_type = pos_encoding.lower()
self.token_embedding = nn.Embedding(vocab_size, d_model, padding_idx=padding_idx)
        if self.pos_encoding_type == "sinusoidal":
            self.pos_encoding = PositionalEncoding(d_model, max_seq_len, dropout)
        elif self.pos_encoding_type == "learned":
            self.pos_encoding = LearnedPositionalEmbedding(max_seq_len, d_model, dropout)
        elif self.pos_encoding_type == "rope":
            self.pos_encoding = RotaryPositionalEmbedding(d_model, max_seq_len)
        else:
            raise ValueError(f"Unknown positional encoding type: {pos_encoding}")
self.layer_norm = nn.LayerNorm(d_model) if layer_norm else nn.Identity()
self.dropout = nn.Dropout(dropout)
self._init_weights()
def _init_weights(self) -> None:
"""Initialize weights for token embeddings."""
nn.init.normal_(self.token_embedding.weight, mean=0.0, std=0.02)
if self.padding_idx is not None:
nn.init.constant_(self.token_embedding.weight[self.padding_idx], 0.0)
def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
"""
Forward pass for embedding layer.
Args:
input_ids (torch.Tensor): Input tensor of shape (batch_size, seq_len).
Returns:
torch.Tensor: Embedded tensor of shape (batch_size, seq_len, d_model).
"""
        if (input_ids < 0).any() or (input_ids >= self.vocab_size).any():
            raise ValueError(f"Input IDs must be in the range [0, {self.vocab_size})")
embeddings = self.token_embedding(input_ids)
if self.pos_encoding_type != "rope":
embeddings = self.pos_encoding(embeddings)
embeddings = self.layer_norm(embeddings)
return self.dropout(embeddings)
def get_positional_encoding(self) -> Optional[nn.Module]:
"""Return the positional encoding module if RoPE, else None."""
return self.pos_encoding if self.pos_encoding_type == "rope" else None
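# Usage sketch (illustrative hyperparameters only): with "learned" or "sinusoidal",
# positions are added here; with "rope", retrieve the module via
# get_positional_encoding() and apply it to q/k inside attention.
# >>> emb = TechEmbeddingLayer(vocab_size=1000, d_model=512, max_seq_len=128)
# >>> emb(torch.randint(1, 1000, (4, 64))).shape
# torch.Size([4, 64, 512])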
class AdaptiveEmbedding(nn.Module):
"""Adaptive embedding layer with variable embedding dimensions."""
def __init__(
self,
vocab_size: int,
d_model: int,
cutoffs: List[int] = DEFAULT_CUTOFFS,
div_val: float = DEFAULT_DIV_VAL,
):
"""
Initialize adaptive embedding layer.
Args:
vocab_size (int): Size of the vocabulary.
d_model (int): Dimension of the model embeddings.
cutoffs (List[int]): Cutoff points for vocabulary splits.
div_val (float): Division factor for embedding dimensions.
"""
super().__init__()
self.vocab_size = vocab_size
self.d_model = d_model
self.cutoffs = [0] + cutoffs + [vocab_size]
self.div_val = div_val
self.embeddings = nn.ModuleList()
self.projections = nn.ModuleList()
for i in range(len(self.cutoffs) - 1):
l_idx, r_idx = self.cutoffs[i], self.cutoffs[i + 1]
d_emb = int(d_model / (div_val ** i))
emb = nn.Embedding(r_idx - l_idx, d_emb)
nn.init.normal_(emb.weight, mean=0.0, std=0.02)
self.embeddings.append(emb)
self.projections.append(
nn.Linear(d_emb, d_model, bias=False) if d_emb != d_model else nn.Identity()
)
if d_emb != d_model:
nn.init.normal_(self.projections[-1].weight, mean=0.0, std=0.02)
def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
"""
Forward pass for adaptive embedding.
Args:
input_ids (torch.Tensor): Input tensor of shape (batch_size, seq_len).
Returns:
torch.Tensor: Embedded tensor of shape (batch_size, seq_len, d_model).
"""
if (input_ids >= self.vocab_size).any():
raise ValueError(f"Input IDs contain values >= vocab_size ({self.vocab_size})")
batch_size, seq_len = input_ids.shape
embeddings = torch.zeros(batch_size, seq_len, self.d_model, device=input_ids.device, dtype=torch.float32)
for i in range(len(self.cutoffs) - 1):
l_idx, r_idx = self.cutoffs[i], self.cutoffs[i + 1]
mask = (input_ids >= l_idx) & (input_ids < r_idx)
if mask.any():
indices = (input_ids[mask] - l_idx).clamp(max=r_idx - l_idx - 1)
emb = self.embeddings[i](indices)
embeddings[mask] = self.projections[i](emb)
return embeddings
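# Usage sketch (illustrative cutoffs only): with cutoffs=[200, 500] and div_val=2.0,
# token buckets [0, 200), [200, 500), [500, 1000) get embedding widths 512, 256,
# and 128, each projected back to d_model=512.
# >>> ada = AdaptiveEmbedding(vocab_size=1000, d_model=512, cutoffs=[200, 500], div_val=2.0)
# >>> ada(torch.randint(0, 1000, (4, 64))).shape
# torch.Size([4, 64, 512])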
def create_padding_mask(input_ids: torch.Tensor, padding_idx: int = DEFAULT_PADDING_IDX) -> torch.Tensor:
"""
Create a padding mask for input IDs.
Args:
input_ids (torch.Tensor): Input tensor of shape (batch_size, seq_len).
padding_idx (int): Index for padding token.
Returns:
torch.Tensor: Padding mask of shape (batch_size, seq_len).
"""
return input_ids == padding_idx
def create_causal_mask(seq_len: int, device: torch.device) -> torch.Tensor:
"""
Create a causal mask for attention.
Args:
seq_len (int): Sequence length.
device (torch.device): Device for tensor allocation.
Returns:
torch.Tensor: Causal mask of shape (seq_len, seq_len).
"""
return torch.triu(torch.ones(seq_len, seq_len, device=device), diagonal=1).bool()
def create_attention_mask(input_ids: torch.Tensor, padding_idx: int = DEFAULT_PADDING_IDX, causal: bool = True) -> torch.Tensor:
"""
Create an attention mask combining padding and causal masks.
Args:
input_ids (torch.Tensor): Input tensor of shape (batch_size, seq_len).
padding_idx (int): Index for padding token.
causal (bool): Whether to include causal masking.
Returns:
torch.Tensor: Attention mask of shape (batch_size, seq_len, seq_len).
"""
batch_size, seq_len = input_ids.shape
device = input_ids.device
padding_mask = create_padding_mask(input_ids, padding_idx).unsqueeze(1).expand(batch_size, seq_len, seq_len)
if causal:
causal_mask = create_causal_mask(seq_len, device).unsqueeze(0).expand(batch_size, seq_len, seq_len)
return padding_mask | causal_mask
return padding_mask
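# Usage sketch (illustrative shapes only): True marks positions to block.
# >>> ids = torch.tensor([[5, 6, 7, 0]])              # trailing pad (padding_idx=0)
# >>> create_padding_mask(ids).shape                  # (1, 4)
# >>> create_causal_mask(4, ids.device).shape         # (4, 4)
# >>> create_attention_mask(ids).shape                # (1, 4, 4)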
class EmbeddingAnalyzer:
"""Analyzer for inspecting embedding layer properties."""
def __init__(self, embedding_layer: nn.Module):
"""
Initialize the embedding analyzer.
Args:
embedding_layer (nn.Module): The embedding layer to analyze.
"""
self.embedding_layer = embedding_layer
def get_similarity_matrix(self, tokens: Optional[List[int]] = None) -> torch.Tensor:
"""
Compute the cosine similarity matrix for embeddings.
Args:
tokens (Optional[List[int]]): List of token IDs to compute similarities for.
Returns:
torch.Tensor: Cosine similarity matrix.
"""
if hasattr(self.embedding_layer, 'token_embedding'):
embeddings = self.embedding_layer.token_embedding.weight
elif hasattr(self.embedding_layer, 'embeddings'):
embeddings = torch.cat(
[self.embedding_layer.projections[i](emb.weight) for i, emb in enumerate(self.embedding_layer.embeddings)],
dim=0
)
else:
embeddings = self.embedding_layer.weight
if tokens is not None and len(tokens) > 0:
embeddings = embeddings[tokens]
return torch.mm(F.normalize(embeddings, p=2, dim=1), F.normalize(embeddings, p=2, dim=1).t())
def find_similar_tokens(self, token_id: int, top_k: int = 10) -> List[Tuple[int, float]]:
"""
Find the top-k most similar tokens to a given token ID.
Args:
token_id (int): Token ID to find similar tokens for.
top_k (int): Number of similar tokens to return.
Returns:
List[Tuple[int, float]]: List of (token_id, similarity_score) pairs.
"""
similarity_matrix = self.get_similarity_matrix()
if token_id >= similarity_matrix.shape[0]:
raise ValueError(f"Token ID {token_id} is out of range")
similarities = similarity_matrix[token_id]
        k = min(top_k + 1, similarities.shape[0])
        top_similarities, top_indices = torch.topk(similarities, k)
mask = top_indices != token_id
return list(zip(top_indices[mask][:top_k].tolist(), top_similarities[mask][:top_k].tolist()))
def analyze_embedding_distribution(self) -> dict:
"""
Analyze the statistical properties of the embedding weights.
Returns:
dict: Dictionary containing mean, std, min, max, norm_mean, and norm_std of embeddings.
"""
if hasattr(self.embedding_layer, 'token_embedding'):
weights = self.embedding_layer.token_embedding.weight
elif hasattr(self.embedding_layer, 'embeddings'):
weights = torch.cat([emb.weight for emb in self.embedding_layer.embeddings], dim=0)
else:
weights = self.embedding_layer.weight
return {
'mean': weights.mean().item(),
'std': weights.std().item(),
'min': weights.min().item(),
'max': weights.max().item(),
'norm_mean': weights.norm(dim=1).mean().item(),
'norm_std': weights.norm(dim=1).std().item(),
}
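# Usage sketch (illustrative values only): works with TechEmbeddingLayer,
# AdaptiveEmbedding, or a bare nn.Embedding.
# >>> analyzer = EmbeddingAnalyzer(TechEmbeddingLayer(vocab_size=1000, d_model=512))
# >>> analyzer.analyze_embedding_distribution()['std']    # close to the 0.02 init std
# >>> analyzer.find_similar_tokens(token_id=3, top_k=5)   # five (token_id, cosine) pairs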
def test_embeddings() -> None:
"""Test the embedding layers and related utilities."""
print("Starting embedding layer tests...")
vocab_size = 1000
d_model = 512
max_seq_len = 128
batch_size = 4
seq_len = 64
input_ids = torch.randint(1, vocab_size, (batch_size, seq_len))
embedding_types = [
("Learned Position", "learned"),
("Sinusoidal Position", "sinusoidal"),
("RoPE", "rope"),
]
for name, pos_type in embedding_types:
print(f"\nTesting {name} Embedding:")
embedding_layer = TechEmbeddingLayer(
vocab_size=vocab_size,
d_model=d_model,
max_seq_len=max_seq_len,
pos_encoding=pos_type,
)
embeddings = embedding_layer(input_ids)
assert embeddings.shape == (batch_size, seq_len, d_model), f"Unexpected shape for {name}: {embeddings.shape}"
print(f"Input shape: {input_ids.shape}")
print(f"Output shape: {embeddings.shape}")
print(f"Expected shape: ({batch_size}, {seq_len}, {d_model})")
analyzer = EmbeddingAnalyzer(embedding_layer)
stats = analyzer.analyze_embedding_distribution()
print(f"Embedding statistics:")
for key, value in stats.items():
print(f" {key}: {value:.4f}")
# Test similarity for a sample token
similar_tokens = analyzer.find_similar_tokens(token_id=0, top_k=5)
print(f"Top 5 similar tokens to token 0: {similar_tokens}")
print("\nTesting Adaptive Embeddings:")
adaptive_emb = AdaptiveEmbedding(vocab_size=vocab_size, d_model=d_model, cutoffs=[200, 500], div_val=2.0)
embeddings = adaptive_emb(input_ids)
assert embeddings.shape == (batch_size, seq_len, d_model), f"Unexpected adaptive embedding shape: {embeddings.shape}"
print(f"Adaptive embedding output shape: {embeddings.shape}")
print("\nTesting masking functions:")
input_ids_padded = input_ids.clone()
input_ids_padded[:, -10:] = 0
padding_mask = create_padding_mask(input_ids_padded, padding_idx=0)
causal_mask = create_causal_mask(seq_len, input_ids.device)
attention_mask = create_attention_mask(input_ids_padded, padding_idx=0, causal=True)
assert padding_mask.shape == (batch_size, seq_len), f"Unexpected padding mask shape: {padding_mask.shape}"
assert causal_mask.shape == (seq_len, seq_len), f"Unexpected causal mask shape: {causal_mask.shape}"
assert attention_mask.shape == (batch_size, seq_len, seq_len), f"Unexpected attention mask shape: {attention_mask.shape}"
print(f"Padding mask shape: {padding_mask.shape}")
print(f"Causal mask shape: {causal_mask.shape}")
print(f"Attention mask shape: {attention_mask.shape}")
print(f"Padding positions: {padding_mask.sum().item()}")
print(f"Causal mask positions: {causal_mask.sum().item()}")
print(f"Combined mask positions: {attention_mask.sum().item()}")
print("\nAll embedding tests completed successfully!")
if __name__ == "__main__":
test_embeddings()