File size: 3,551 Bytes
833cfe9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
#!/usr/bin/env python3
"""
Maaza Nano 9.6M - The 99ms Brain
Simple inference script for tool routing.

Usage:
    python inference.py "search for cats"
    python inference.py "read the file config.json"
    python inference.py "send an email to bob@example.com"
"""

import torch
import json
import sys
import time
from pathlib import Path

# Add current directory to path for imports
sys.path.insert(0, str(Path(__file__).parent))

from model import MaazaNanoModel, MaazaNanoConfig
from tokenizer import BPETokenizer


def load_model(model_dir: str = "."):
    """Load the Maaza Nano model, its tokenizer, and pick a device.

    Args:
        model_dir: Directory containing ``tokenizer.json``, ``config.json``
            and ``model.pt``. Defaults to the current working directory.

    Returns:
        A ``(model, tokenizer, device)`` tuple. ``device`` is the string
        ``"cuda"`` when CUDA is available, otherwise ``"cpu"``; the model
        has been moved to that device and switched to eval mode.

    Raises:
        FileNotFoundError: If ``config.json`` is missing from ``model_dir``.
        KeyError: If ``config.json`` lacks one of the required fields.
    """
    model_dir = Path(model_dir)

    # Load tokenizer
    tokenizer = BPETokenizer.load(str(model_dir / "tokenizer.json"))

    # Load model config (explicit encoding so the result does not depend on
    # the platform's locale default)
    with open(model_dir / "config.json", encoding="utf-8") as f:
        cfg = json.load(f)

    config = MaazaNanoConfig(
        vocab_size=cfg["vocab_size"],
        hidden_size=cfg["hidden_size"],
        num_layers=cfg["num_layers"],
        num_heads=cfg["num_heads"],
        intermediate_size=cfg["intermediate_size"],
        max_position_embeddings=cfg["max_position_embeddings"],
    )

    # Load model weights. map_location="cpu" lets a checkpoint that was saved
    # from CUDA tensors deserialize on a machine without a GPU; the model is
    # moved to the selected device afterwards.
    model = MaazaNanoModel(config)
    state_dict = torch.load(
        model_dir / "model.pt", map_location="cpu", weights_only=True
    )
    model.load_state_dict(state_dict)

    # Use GPU if available
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = model.to(device)
    model.eval()

    return model, tokenizer, device


def _extract_tool_calls(text: str):
    """Return the first JSON value that starts at a ``[{`` in *text*, or None."""
    decoder = json.JSONDecoder()
    start = text.find("[{")
    while start >= 0:
        try:
            # raw_decode parses exactly one complete JSON value and ignores
            # whatever follows it, so nested "}]" sequences and trailing
            # tokens (e.g. the EOS marker) do not break extraction.
            tool_calls, _end = decoder.raw_decode(text, start)
        except json.JSONDecodeError:
            start = text.find("[{", start + 1)
        else:
            return tool_calls
    return None


def route_tool(prompt: str, model, tokenizer, device, max_tokens: int = 100):
    """Route a natural language prompt to a tool call.

    The prompt is wrapped as ``<|user|>{prompt}<|assistant|>``, then tokens
    are generated greedily (argmax) one at a time until the ``<|eos|>`` token
    is produced or ``max_tokens`` tokens have been generated.

    Args:
        prompt: The user's natural language request.
        model: Callable taking a ``(1, seq_len)`` tensor of token ids and
            returning a mapping with a ``"logits"`` entry indexable as
            ``logits[0, -1]``.
        tokenizer: Object providing ``encode(str)``, ``decode(list)`` and a
            ``vocab`` mapping used to look up the ``<|eos|>`` id.
        device: Torch device the input tensors are created on.
        max_tokens: Upper bound on the number of generated tokens.

    Returns:
        A ``(tool_calls, latency_ms)`` tuple. ``tool_calls`` is the parsed
        JSON array found in the generated text, or ``None`` when no valid
        JSON array of objects could be extracted. ``latency_ms`` is the time
        spent in the generation loop, in milliseconds.
    """
    # Format input
    full_prompt = f"<|user|>{prompt}<|assistant|>"
    tokens = tokenizer.encode(full_prompt)
    input_ids = torch.tensor([tokens], dtype=torch.long, device=device)
    prompt_len = input_ids.shape[1]

    # Loop-invariant: look the EOS id up once instead of on every step.
    eos_id = tokenizer.vocab.get("<|eos|>")

    # Generate. perf_counter is monotonic, so the measured interval cannot be
    # skewed by system clock adjustments.
    start_time = time.perf_counter()
    with torch.no_grad():
        for _ in range(max_tokens):
            logits = model(input_ids)["logits"]
            next_token = logits[0, -1].argmax().item()
            next_ids = torch.tensor(
                [[next_token]], dtype=torch.long, device=device
            )
            input_ids = torch.cat([input_ids, next_ids], dim=1)

            # Stop at EOS
            if next_token == eos_id:
                break

    latency_ms = (time.perf_counter() - start_time) * 1000

    # Decode only the newly generated tokens: a "[{" inside the user's own
    # prompt must never be mistaken for the model's tool call.
    generated = tokenizer.decode(input_ids[0, prompt_len:].tolist())

    return _extract_tool_calls(generated), latency_ms


def main():
    """Command-line entry point.

    Reads the prompt from the first command-line argument, loads the model
    from the current directory, routes the prompt, and prints the latency
    followed by the parsed tool call (or a failure notice). Prints usage
    help and exits with status 1 when no prompt is given.
    """
    cli_args = sys.argv[1:]
    if not cli_args:
        usage_lines = (
            "Usage: python inference.py \"your prompt here\"",
            "\nExamples:",
            "  python inference.py \"search for cats\"",
            "  python inference.py \"read config.json\"",
            "  python inference.py \"send email to bob@example.com\"",
        )
        for usage_line in usage_lines:
            print(usage_line)
        sys.exit(1)

    # Only the first argument is used as the prompt.
    prompt = cli_args[0]

    print("Loading Maaza Nano 9.6M...")
    model, tokenizer, device = load_model()
    print(f"Model loaded on {device}")
    print()

    print(f"Prompt: {prompt}")
    calls, elapsed_ms = route_tool(prompt, model, tokenizer, device)

    print(f"Latency: {elapsed_ms:.1f}ms")
    print()

    # Guard clause: report the failure and stop when nothing was parsed.
    if not calls:
        print("Failed to parse tool call")
        return

    print("Tool call:")
    print(json.dumps(calls, indent=2))

# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()