
# Inference Guide

## Quick Start

```bash
python generate.py
```

Default output:

```
Prompt: 'The history of '

The history of Tomadination of the [[New Gouple de aparty]]...
```
---

## Basic Usage

### 1. Load Model
```python
from src.models.agiformer import AGIFORMER
import torch

model = AGIFORMER(d_model=512, n_layers=6, patch_size=4, thinking_steps=3)
model.load_state_dict(torch.load("best_model.pth"))
model.eval()
```
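
If the checkpoint was saved on a GPU and you are loading on a CPU-only machine, `map_location` avoids a device mismatch error (standard PyTorch, not specific to AGIFORMER):

```python
# Load the checkpoint onto the CPU regardless of where it was saved
state_dict = torch.load("best_model.pth", map_location="cpu")
model.load_state_dict(state_dict)
model.eval()
```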

### 2. Prepare Input

prompt = "The history of artificial intelligence"
input_bytes = [ord(c) for c in prompt]

# Pad to patch_size boundary
pad = (4 - len(input_bytes) % 4) % 4
input_bytes.extend([32] * pad)

x = torch.tensor(input_bytes).unsqueeze(0)  # (1, seq_len)

### 3. Generate

```python
with torch.no_grad():
    output = model(x, temperature=0.7)  # (1, num_patches, patch_size)

# Decode the last generated patch, keeping only printable ASCII
generated_bytes = output[0, -1, :].tolist()
text = ''.join(chr(b) for b in generated_bytes if 32 <= b <= 126)
```
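
The snippet above decodes only the final patch. To decode every predicted patch at once, the same filter can be applied to the flattened output (a sketch, assuming the `(1, num_patches, patch_size)` layout shown above):

```python
# Flatten all patches into one byte sequence and keep printable ASCII
all_bytes = output[0].flatten().tolist()   # num_patches * patch_size byte ids
full_text = ''.join(chr(b) for b in all_bytes if 32 <= b <= 126)
print(full_text)
```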

## Temperature Sampling
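
Conceptually, temperature rescales the byte logits before softmax: values below 1 sharpen the distribution, values above 1 flatten it. A minimal, model-agnostic sketch of the idea (AGIFORMER's internal sampling may differ in detail):

```python
import torch
import torch.nn.functional as F

def sample_byte(logits: torch.Tensor, temperature: float = 0.7) -> int:
    """Sample one byte id from raw logits of shape (256,)."""
    if temperature == 0.0:
        return logits.argmax().item()                  # greedy: most likely byte
    probs = F.softmax(logits / temperature, dim=-1)    # temperature-scaled distribution
    return torch.multinomial(probs, 1).item()          # stochastic sample
```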

### Greedy (Temperature = 0)

```python
output = model(x, temperature=0.0)
```

- Picks the most likely byte at every step
- Deterministic (same output on each run)
- Prone to repetition loops

Example:

```
The history of of of of of...
```

### Low Temperature (0.3 - 0.5)

```python
output = model(x, temperature=0.3)
```

- Slightly random, still conservative
- Good for coherent text
- Reduces repetition

Example:

```
The history of the computer system...
```

### Medium Temperature (0.7 - 0.9)

```python
output = model(x, temperature=0.7)  # Default
```

- Balanced creativity/coherence
- Recommended for exploration

Example:

```
The history of Tomadination of the [[New Gouple]]...
```

### High Temperature (1.0+)

```python
output = model(x, temperature=1.2)
```

- Very random
- Incoherent but diverse
- Good for debugging model knowledge

Example:

```
The history qw8#$x [[zap]] nullification...
```

## Advanced: Token-by-Token Generation

For streaming output:

```python
import torch

def generate_stream(model, prompt, max_tokens=200, temperature=0.7):
    # Encode the prompt as raw bytes and pad to a patch boundary
    context = [ord(c) for c in prompt]
    pad = (4 - len(context) % 4) % 4
    context.extend([32] * pad)

    for _ in range(max_tokens // 4):  # Generate patch-by-patch
        x = torch.tensor(context[-1024:]).unsqueeze(0)  # Sliding window over the last 1024 bytes

        with torch.no_grad():
            pred = model(x, temperature=temperature)

        # Take the last predicted patch and append it to the context
        new_bytes = pred[0, -1, :].cpu().tolist()
        context.extend(new_bytes)

        # Decode the printable ASCII portion and stream it out
        chunk = ''.join(chr(b) for b in new_bytes if 32 <= b <= 126)
        print(chunk, end='', flush=True)
```

Usage:

```python
generate_stream(model, "The history of ", max_tokens=200)
```

## System 2 Control

### Disable Thinking (Baseline)

```python
model = AGIFORMER(thinking_steps=0)  # Skip System 2
```

- Faster inference (~2× speedup)
- Lower quality output

### Increase Thinking

```python
model = AGIFORMER(thinking_steps=5)  # More refinement
```

- Slower inference
- Potentially better coherence

### Runtime Control

System 2 is built into the model, so `thinking_steps` is fixed at construction time:

```python
# thinking_steps cannot be changed after model creation;
# create a new model with the desired config instead.
```
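
A minimal sketch of switching to a faster configuration at load time, reusing the same checkpoint. Whether weights trained with `thinking_steps=3` load cleanly into a different configuration depends on how the System 2 parameters are stored, hence the hypothetical `strict=False`:

```python
# Hypothetical: rebuild with thinking_steps=0 and reuse the existing checkpoint
fast_model = AGIFORMER(d_model=512, n_layers=6, patch_size=4, thinking_steps=0)
state_dict = torch.load("best_model.pth", map_location="cpu")
fast_model.load_state_dict(state_dict, strict=False)  # tolerate any unused System 2 params
fast_model.eval()
```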

## Batch Inference

Process multiple prompts:

prompts = ["The history of ", "In the year 2050, ", "Once upon a time, "]
batch = []

for prompt in prompts:
    bytes = [ord(c) for c in prompt]
    pad = (4 - len(bytes) % 4) % 4
    bytes.extend([32] * pad)
    batch.append(torch.tensor(bytes))

# Pad to same length
max_len = max(t.size(0) for t in batch)
batch_tensor = torch.stack([
    F.pad(t, (0, max_len - t.size(0)))
    for t in batch
])

# Generate
with torch.no_grad():
    outputs = model(batch_tensor, temperature=0.7)
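
Each row of `outputs` can then be decoded the same way as in the single-prompt example (a minimal sketch, again keeping only printable ASCII from the last patch):

```python
for prompt, out in zip(prompts, outputs):
    last_patch = out[-1, :].tolist()
    completion = ''.join(chr(b) for b in last_patch if 32 <= b <= 126)
    print(f"{prompt!r} -> {completion!r}")
```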

## Debugging Output

### Check Raw Bytes

```python
generated = model(x, temperature=0.0)
raw_bytes = generated[0, -1, :].tolist()
print(f"Raw: {raw_bytes}")  # e.g., [116, 104, 101, 32]
```

### Detect Non-Printables

```python
for b in raw_bytes:
    if not (32 <= b <= 126):
        print(f"Warning: non-printable byte {b}")
```

### Measure Entropy

```python
import torch.nn.functional as F

# `latents` are the hidden states fed to the output head
logits = model.head(latents)  # Raw logits over the 256 byte values
probs = F.softmax(logits, dim=-1)
entropy = -(probs * torch.log2(probs + 1e-10)).sum(dim=-1).mean()

print(f"Avg Entropy: {entropy.item():.2f} bits")
# Low (<2 bits): confident, may repeat
# High (>6 bits): uncertain, output will look random
```

## Common Issues

### Repetition Loops

Problem:

```
of of of of of...
```

Solutions:

1. Increase the temperature: 0.0 → 0.7
2. Use nucleus sampling (top-p):

   ```python
   probs = F.softmax(logits / temp, dim=-1)
   sorted_probs, indices = torch.sort(probs, descending=True)
   cumsum = torch.cumsum(sorted_probs, dim=-1)
   # Drop tokens once cumulative prob exceeds 0.9, always keeping the top token
   sorted_probs[cumsum - sorted_probs > 0.9] = 0.0
   # Sample in sorted order, then map back to the original byte ids
   next_byte = indices.gather(-1, torch.multinomial(sorted_probs, 1))
   ```

### Gibberish Output

Problem:

```
xq#$8z [[nullification]]...
```

Causes:

- Temperature too high
- Model undertrained

Solutions:

- Lower the temperature: 1.2 → 0.5
- Train longer (20k+ steps)
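
A quick diagnostic is to measure how much of the generated output is printable ASCII before and after lowering the temperature (an illustrative helper, not part of `generate.py`):

```python
def printable_ratio(byte_ids):
    """Fraction of bytes in the printable ASCII range (32-126)."""
    return sum(32 <= b <= 126 for b in byte_ids) / max(len(byte_ids), 1)

with torch.no_grad():
    for temp in (1.2, 0.5):
        raw = model(x, temperature=temp)[0, -1, :].tolist()
        print(f"temp={temp}: printable ratio {printable_ratio(raw):.2f}")
```

If the ratio stays low even at a low temperature, the model is more likely undertrained than over-sampled.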

### Slow Inference

Problem: more than 1 s per token

Solutions:

- Use a GPU: `model.cuda()` (see the sketch below)
- Reduce `thinking_steps`: 3 → 1
- Disable System 2: `thinking_steps=0`
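
For the GPU option, both the model and its inputs must live on the same device (standard PyTorch, nothing AGIFORMER-specific):

```python
# Move the model and input tensor to the GPU when one is available
device = "cuda" if torch.cuda.is_available() else "cpu"
model = model.to(device).eval()
x = x.to(device)

with torch.no_grad():
    output = model(x, temperature=0.7)
```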

## Performance Benchmarks

Setup:

- GPU: NVIDIA T4
- Prompt length: 100 bytes
- Generation length: 200 bytes

| Config | Latency | Throughput |
|---|---|---|
| Greedy (temp=0) | 45 ms | 22 tokens/s |
| Sampling (temp=0.7) | 52 ms | 19 tokens/s |
| System 2 disabled | 28 ms | 36 tokens/s |
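
A rough way to reproduce a latency figure on your own hardware (a sketch, not the original benchmark script):

```python
import time
import torch

with torch.no_grad():
    model(x, temperature=0.7)              # warm-up pass
    if torch.cuda.is_available():
        torch.cuda.synchronize()           # make GPU timing accurate
    start = time.perf_counter()
    for _ in range(20):
        model(x, temperature=0.7)
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    latency_ms = (time.perf_counter() - start) / 20 * 1000

print(f"Avg latency: {latency_ms:.1f} ms per forward pass")
```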

## API Reference

### Model Forward

```python
def forward(
    x: torch.Tensor,                               # (Batch, Seq_Len) bytes
    target_bytes: Optional[torch.Tensor] = None,   # For training
    temperature: float = 0.0                       # Sampling temp (0 = greedy)
) -> torch.Tensor:
    # Returns: (Batch, Num_Patches, Patch_Size, 256) logits if training
    #          (Batch, Num_Patches, Patch_Size) sampled bytes if inference
    ...
```

### Generation Utilities

See `generate.py` for the full implementation:

- `generate_text(model_path, prompt, max_tokens, temperature)`
- Automatic padding and decoding
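
A hypothetical call to the helper listed above, assuming `generate_text` is importable from `generate.py`, accepts these arguments as keywords, and returns the decoded string:

```python
from generate import generate_text  # assumed import path

text = generate_text(
    model_path="best_model.pth",
    prompt="The history of ",
    max_tokens=200,
    temperature=0.7,
)
print(text)
```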

## Next Steps

1. Experiment with Prompts: Try different domains
2. Tune Temperature: Find the sweet spot for your use case
3. Extend Context: Modify `generate.py` to use longer contexts
4. Fine-tune: Retrain on domain-specific data