""" |
|
|
Maaza Nano 9.6M - The 99ms Brain |
|
|
Simple inference script for tool routing. |
|
|
|
|
|
Usage: |
|
|
python inference.py "search for cats" |
|
|
python inference.py "read the file config.json" |
|
|
python inference.py "send an email to bob@example.com" |
|
|
""" |

import json
import sys
import time
from pathlib import Path

import torch

# Make the local model/tokenizer modules importable even when this script is
# run from another working directory.
sys.path.insert(0, str(Path(__file__).parent))

from model import MaazaNanoModel, MaazaNanoConfig
from tokenizer import BPETokenizer


def load_model(model_dir: str = "."):
    """Load the Maaza Nano model, its tokenizer, and the target device."""
    model_dir = Path(model_dir)

    tokenizer = BPETokenizer.load(str(model_dir / "tokenizer.json"))

    with open(model_dir / "config.json") as f:
        cfg = json.load(f)

    config = MaazaNanoConfig(
        vocab_size=cfg["vocab_size"],
        hidden_size=cfg["hidden_size"],
        num_layers=cfg["num_layers"],
        num_heads=cfg["num_heads"],
        intermediate_size=cfg["intermediate_size"],
        max_position_embeddings=cfg["max_position_embeddings"],
    )
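
    # For reference, config.json must supply exactly the six fields read
    # above. A minimal sketch of its shape; the values here are illustrative
    # assumptions, not the shipped checkpoint's actual hyperparameters:
    #
    #   {
    #     "vocab_size": 8192,
    #     "hidden_size": 192,
    #     "num_layers": 6,
    #     "num_heads": 6,
    #     "intermediate_size": 768,
    #     "max_position_embeddings": 512
    #   }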

    model = MaazaNanoModel(config)
    # weights_only=True restricts torch.load to tensors and primitive types,
    # so an untrusted checkpoint cannot execute arbitrary pickled code.
    model.load_state_dict(torch.load(model_dir / "model.pt", weights_only=True))

    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = model.to(device)
    model.eval()

    return model, tokenizer, device


def route_tool(prompt: str, model, tokenizer, device, max_tokens: int = 100):
    """Route a natural-language prompt to a JSON tool call."""
    # Wrap the prompt in the chat template the model was trained on.
    full_prompt = f"<|user|>{prompt}<|assistant|>"
    tokens = tokenizer.encode(full_prompt)
    input_ids = torch.tensor([tokens], device=device)

    start_time = time.time()
    with torch.no_grad():
        # Greedy decoding: re-run the full sequence each step (no KV cache),
        # append the argmax token, and stop at <|eos|>. If the tokenizer has
        # no <|eos|> entry, eos_id is None and we simply run to max_tokens.
        eos_id = tokenizer.vocab.get("<|eos|>")
        for _ in range(max_tokens):
            outputs = model(input_ids)
            logits = outputs["logits"]
            next_token = logits[0, -1].argmax().item()
            input_ids = torch.cat(
                [input_ids, torch.tensor([[next_token]], device=device)], dim=1
            )
            if next_token == eos_id:
                break

    latency_ms = (time.time() - start_time) * 1000

    generated = tokenizer.decode(input_ids[0].tolist())

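    # The decoded text is the echoed prompt followed by the completion, which
    # should embed a JSON list of tool calls. Illustrative shape only; the
    # actual tool names and argument keys depend on the training data:
    #
    #   <|user|>search for cats<|assistant|>
    #   [{"name": "search", "arguments": {"query": "cats"}}]<|eos|>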
    # Extract the first JSON array from the completion. A 9.6M-parameter
    # model can emit malformed JSON, so any parsing failure returns None.
    try:
        json_start = generated.find('[{')
        if json_start >= 0:
            json_end = generated.find('}]', json_start)
            if json_end >= 0:
                tool_calls = json.loads(generated[json_start:json_end + 2])
                return tool_calls, latency_ms
    except json.JSONDecodeError:
        pass

    return None, latency_ms
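
# Example of calling route_tool directly; the printed result is illustrative
# and depends on the checkpoint:
#   >>> model, tok, dev = load_model()
#   >>> calls, ms = route_tool("search for cats", model, tok, dev)
#   >>> calls
#   [{'name': 'search', 'arguments': {'query': 'cats'}}]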


def main():
    if len(sys.argv) < 2:
        print('Usage: python inference.py "your prompt here"')
        print("\nExamples:")
        print('  python inference.py "search for cats"')
        print('  python inference.py "read config.json"')
        print('  python inference.py "send email to bob@example.com"')
        sys.exit(1)

    prompt = sys.argv[1]

    print("Loading Maaza Nano 9.6M...")
    model, tokenizer, device = load_model()
    print(f"Model loaded on {device}")
    print()

    print(f"Prompt: {prompt}")
    tool_calls, latency = route_tool(prompt, model, tokenizer, device)

    print(f"Latency: {latency:.1f}ms")
    print()

    if tool_calls:
        print("Tool call:")
        print(json.dumps(tool_calls, indent=2))
    else:
        print("Failed to parse tool call")


if __name__ == "__main__":
    main()