API Reference
Module: src.models.encoder
Class: ByteLatentEncoder
Converts byte sequences into latent patches with positional embeddings.
class ByteLatentEncoder(nn.Module):
    def __init__(
        self,
        d_model: int = 512,
        patch_size: int = 4,
        dropout: float = 0.1
    )
Parameters:
d_model (int): Latent dimension size
patch_size (int): Number of bytes per patch
dropout (float): Dropout probability
Methods:
def forward(self, x: torch.Tensor) -> torch.Tensor:
    """
    Args:
        x: (Batch, Seq_Len) - Input bytes [0-255]
    Returns:
        (Batch, Num_Patches, d_model) - Latent patches
    """
Module: src.models.layers
Class: LinearAttention
$O(N)$ causal attention using ELU feature maps.
class LinearAttention(nn.Module):
    def __init__(
        self,
        d_model: int,
        num_heads: int = 8,
        dropout: float = 0.1
    )
Methods:
def forward(self, x: torch.Tensor) -> torch.Tensor:
    """
    Args:
        x: (Batch, Seq_Len, d_model)
    Returns:
        (Batch, Seq_Len, d_model)
    """
Algorithm:
    Q, K, V = elu(Wq x) + 1, elu(Wk x) + 1, Wv x
    Attention = (Q @ cumsum(K ⊗ V)) / (Q @ cumsum(K) + ε)
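For reference, a single-head sketch of this recurrence. The weight matrices Wq, Wk, Wv and the epsilon are illustrative stand-ins, and the explicit cumulative outer product trades memory for clarity; the real module works per head:

import torch
import torch.nn.functional as F

def linear_attention_sketch(x, Wq, Wk, Wv, eps=1e-6):
    # x: (B, L, D); Wq, Wk, Wv: (D, D)
    q = F.elu(x @ Wq) + 1.0                  # positive feature map for queries
    k = F.elu(x @ Wk) + 1.0                  # positive feature map for keys
    v = x @ Wv
    # Causal prefix sums over the time dimension give O(N) attention.
    kv = torch.cumsum(k.unsqueeze(-1) * v.unsqueeze(-2), dim=1)   # (B, L, D, D): cumsum(K ⊗ V)
    z = torch.cumsum(k, dim=1)                                    # (B, L, D):    cumsum(K)
    num = torch.einsum("bld,bldm->blm", q, kv)                    # Q @ cumsum(K ⊗ V)
    den = (q * z).sum(dim=-1, keepdim=True) + eps                 # Q @ cumsum(K) + ε
    return num / den                                              # (B, L, D)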
Class: SlidingWindowAttention
Causal attention with fixed window size.
class SlidingWindowAttention(nn.Module):
    def __init__(
        self,
        d_model: int,
        num_heads: int,
        window_size: int
    )
Parameters:
window_size (int): Maximum distance for attention (default: 128)
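An illustrative way to express the implied banded causal mask, under one common convention where each position sees itself and the previous window_size - 1 positions (not necessarily how the module builds it internally):

import torch

def sliding_window_mask(seq_len: int, window_size: int) -> torch.Tensor:
    i = torch.arange(seq_len).unsqueeze(1)   # query positions
    j = torch.arange(seq_len).unsqueeze(0)   # key positions
    # Causal (j <= i) and within the local window.
    return (j <= i) & (i - j < window_size)  # (seq_len, seq_len) boolean mask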
Class: HybridBlock
Combines LinearAttention + SlidingWindowAttention in parallel.
class HybridBlock(nn.Module):
    def __init__(
        self,
        d_model: int,
        num_heads: int,
        window_size: int,
        dropout: float
    )
Methods:
def forward(self, x: torch.Tensor) -> torch.Tensor:
    """
    Args:
        x: (Batch, Seq_Len, d_model)
    Returns:
        (Batch, Seq_Len, d_model)
    Algorithm:
        attn_out = SlidingWindowAttention(norm(x))
        ssm_out = LinearAttention(norm(x))
        x = x + out_proj(attn_out + ssm_out)
        x = x + MLP(norm(x))
    """
Module: src.models.reasoning
Class: RecurrentReasoningBlock
System 2 thinking loop with gated residual updates.
class RecurrentReasoningBlock(nn.Module):
    def __init__(
        self,
        d_model: int,
        thinking_steps: int = 3,
        dropout: float = 0.1
    )
Parameters:
thinking_steps (int): Number of refinement iterations
Methods:
def forward(self, x: torch.Tensor) -> torch.Tensor:
    """
    Args:
        x: (Batch, Seq_Len, d_model) - Initial latent
    Returns:
        (Batch, Seq_Len, d_model) - Refined latent
    Algorithm:
        for t in range(thinking_steps):
            update = MLP(norm(x))
            gate = sigmoid(W_gate @ norm(x))
            x = x + gate * update
    """
Module: src.models.agiformer
Class: LocalAutoregressiveHead
GRU-based byte decoder with teacher forcing.
class LocalAutoregressiveHead(nn.Module):
    def __init__(
        self,
        d_model: int,
        patch_size: int,
        hidden_dim: int = 256
    )
Methods:
def forward(
    self,
    latents: torch.Tensor,
    target_bytes: Optional[torch.Tensor] = None,
    temperature: float = 0.0
) -> torch.Tensor:
    """
    Args:
        latents: (Batch, Num_Patches, d_model)
        target_bytes: (Batch, Num_Patches * patch_size) - For training
        temperature: Sampling temperature (0 = greedy)
    Returns:
        Training: (Batch, Num_Patches, patch_size, 256) - Logits
        Inference: (Batch, Num_Patches, patch_size) - Byte IDs
    """
Class: AGIFORMER
Main model class.
class AGIFORMER(nn.Module):
    def __init__(
        self,
        d_model: int = 512,
        n_layers: int = 6,
        num_heads: int = 8,
        patch_size: int = 4,
        window_size: int = 128,
        vocab_size: int = 256,
        dropout: float = 0.1,
        thinking_steps: int = 3
    )
Parameters:
d_model: Latent dimension
n_layers: Number of HybridBlocks
num_heads: Attention heads per layer
patch_size: Bytes per patch
window_size: Local attention window
vocab_size: Always 256 (bytes)
dropout: Dropout probability
thinking_steps: System 2 iterations
Methods:
def forward(
    self,
    x: torch.Tensor,
    target_bytes: Optional[torch.Tensor] = None,
    temperature: float = 0.0
) -> torch.Tensor:
    """
    Full forward pass: Encoder → Backbone → Reasoning → Decoder
    Args:
        x: (Batch, Seq_Len) - Input bytes
        target_bytes: (Batch, Seq_Len_Target) - For training
        temperature: Sampling temperature
    Returns:
        Training: (Batch, Num_Patches, patch_size, 256)
        Inference: (Batch, Num_Patches, patch_size)
    """
Module: src.data.real_data
Class: Enwik8Dataset
PyTorch dataset for enwik8.
class Enwik8Dataset(torch.utils.data.Dataset):
    def __init__(
        self,
        data_dir: str = "./data",
        split: str = "train",
        seq_len: int = 1024
    )
Parameters:
split: "train", "val", or "test"seq_len: Sequence length per sample
Methods:
def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Returns:
        input: (seq_len,) - Context bytes
        target: (seq_len,) - Next-patch bytes
    """
Function: get_enwik8_dataloader
Creates DataLoader with automatic download.
def get_enwik8_dataloader(
    batch_size: int,
    seq_len: int,
    split: str = "train"
) -> torch.utils.data.DataLoader:
    """
    Args:
        batch_size: Batch size
        seq_len: Sequence length
        split: "train", "val", or "test"
    Returns:
        DataLoader yielding (input, target) batches
    """
Utility Scripts
train.py
Main training loop.
Key Functions:
def train_step(model, batch, optimizer, criterion):
    """Single training step"""

def validate(model, val_loader, criterion):
    """Validation loop"""
generate.py
Inference with temperature sampling.
Key Function:
def generate_text(
    model_path: str,
    prompt_text: str,
    max_new_tokens: int = 200,
    temperature: float = 0.7
) -> None:
    """Generate text from prompt"""
inspect_reasoning.py
System 2 diagnostics.
Key Function:
def inspect_system_2(model_path: str) -> None:
    """
    Measures:
    - Latent refinement (Δz)
    - Gate biases
    - Parameter health
    """
Example Usage
Training from Scratch
import torch
import torch.nn.functional as F
import torch.optim as optim

from src.models.agiformer import AGIFORMER
from src.data.real_data import get_enwik8_dataloader

model = AGIFORMER(d_model=512, n_layers=6, thinking_steps=3)
train_loader = get_enwik8_dataloader(batch_size=4, seq_len=1024)
optimizer = optim.AdamW(model.parameters(), lr=3e-4)

for batch in train_loader:
    x, target = batch
    logits = model(x, target_bytes=target)                        # (B, N, P, 256)
    loss = F.cross_entropy(logits.view(-1, 256), target.view(-1))
    optimizer.zero_grad()
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)       # gradient clipping
    optimizer.step()
Custom Inference
import torch
from src.models.agiformer import AGIFORMER

model = AGIFORMER()
model.load_state_dict(torch.load("best_model.pth"))
model.eval()

prompt_bytes = torch.tensor([ord(c) for c in "Hello world"])      # byte IDs for the prompt
with torch.no_grad():
    output = model(prompt_bytes.unsqueeze(0), temperature=0.7)    # (1, Num_Patches, patch_size)

generated = output[0, -1, :].tolist()                             # bytes of the last patch
text = ''.join([chr(b) for b in generated if 32 <= b <= 126])     # keep printable ASCII
print(text)
Type Hints Summary
# Common types
Tensor = torch.Tensor
IntTensor = torch.LongTensor
FloatTensor = torch.FloatTensor
# Shapes (notation)
B = Batch size
L = Sequence length
N = Number of patches (L / patch_size)
P = Patch size
D = d_model
H = num_heads
V = Vocabulary size (256)
# Input/Output shapes
Input: (B, L) IntTensor
Latent: (B, N, D) FloatTensor
Logits: (B, N, P, V) FloatTensor
Output: (B, N, P) IntTensor