#!/usr/bin/env python3
"""
Maaza Nano 9.6M - The 99ms Brain
Simple inference script for tool routing.
Usage:
python inference.py "search for cats"
python inference.py "read the file config.json"
python inference.py "send an email to bob@example.com"
"""
import json
import sys
import time
from pathlib import Path

import torch

# Add current directory to path for imports
sys.path.insert(0, str(Path(__file__).parent))

from model import MaazaNanoModel, MaazaNanoConfig
from tokenizer import BPETokenizer

def load_model(model_dir: str = "."):
    """Load the Maaza Nano model, tokenizer, and target device."""
    model_dir = Path(model_dir)

    # Load tokenizer
    tokenizer = BPETokenizer.load(str(model_dir / "tokenizer.json"))

    # Load model config
    with open(model_dir / "config.json") as f:
        cfg = json.load(f)
    config = MaazaNanoConfig(
        vocab_size=cfg["vocab_size"],
        hidden_size=cfg["hidden_size"],
        num_layers=cfg["num_layers"],
        num_heads=cfg["num_heads"],
        intermediate_size=cfg["intermediate_size"],
        max_position_embeddings=cfg["max_position_embeddings"],
    )

    # Load model weights (mapped to CPU first so a checkpoint saved on GPU
    # still loads on CPU-only machines)
    model = MaazaNanoModel(config)
    state_dict = torch.load(model_dir / "model.pt", map_location="cpu", weights_only=True)
    model.load_state_dict(state_dict)

    # Use GPU if available
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = model.to(device)
    model.eval()
    return model, tokenizer, device
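

# Optional warm-up helper (a sketch, not part of the released API): the first
# forward pass pays one-time costs (CUDA context creation, kernel selection,
# allocator growth) that would otherwise inflate the latency printed by
# main(). The dummy input of token id 0 is an assumption; any valid id works.
def warmup(model, device: str, seq_len: int = 8, steps: int = 2):
    dummy = torch.zeros((1, seq_len), dtype=torch.long, device=device)
    with torch.no_grad():
        for _ in range(steps):
            model(dummy)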


def route_tool(prompt: str, model, tokenizer, device, max_tokens: int = 100):
    """Route a natural language prompt to a tool call."""
    # Format input with the chat template the model was trained on
    full_prompt = f"<|user|>{prompt}<|assistant|>"
    tokens = tokenizer.encode(full_prompt)
    input_ids = torch.tensor([tokens]).to(device)

    # Resolve the EOS id once; if the tokenizer has no <|eos|> token,
    # generation simply runs until max_tokens
    eos_id = tokenizer.vocab.get("<|eos|>")

    # Generate greedily (argmax at each step; the full sequence is re-run
    # through the model every iteration since there is no KV cache)
    start_time = time.time()
    with torch.no_grad():
        for _ in range(max_tokens):
            outputs = model(input_ids)
            logits = outputs["logits"]
            next_token = logits[0, -1].argmax().item()
            input_ids = torch.cat(
                [input_ids, torch.tensor([[next_token]]).to(device)], dim=1
            )
            # Stop at EOS
            if eos_id is not None and next_token == eos_id:
                break
    latency_ms = (time.time() - start_time) * 1000

    # Decode output
    generated = tokenizer.decode(input_ids[0].tolist())

    # Extract the JSON tool-call array from the generated text
    json_start = generated.find('[{')
    if json_start >= 0:
        json_end = generated.find('}]', json_start)
        if json_end > json_start:
            try:
                tool_calls = json.loads(generated[json_start:json_end + 2])
                return tool_calls, latency_ms
            except json.JSONDecodeError:
                pass
    return None, latency_ms
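

# Optional: temperature sampling in place of the greedy argmax above. This is
# an illustrative sketch only; for deterministic tool routing the greedy
# decode in route_tool() is the sensible default.
def sample_next_token(logits: torch.Tensor, temperature: float = 0.7) -> int:
    """Sample one token id from the last-position logits."""
    probs = torch.softmax(logits[0, -1] / temperature, dim=-1)
    return int(torch.multinomial(probs, num_samples=1).item())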


def main():
    if len(sys.argv) < 2:
        print("Usage: python inference.py \"your prompt here\"")
        print("\nExamples:")
        print("  python inference.py \"search for cats\"")
        print("  python inference.py \"read config.json\"")
        print("  python inference.py \"send email to bob@example.com\"")
        sys.exit(1)

    prompt = sys.argv[1]
    print("Loading Maaza Nano 9.6M...")
    model, tokenizer, device = load_model()
    print(f"Model loaded on {device}")
    print()
    print(f"Prompt: {prompt}")

    tool_calls, latency = route_tool(prompt, model, tokenizer, device)
    print(f"Latency: {latency:.1f}ms")
    print()
    if tool_calls:
        print("Tool call:")
        print(json.dumps(tool_calls, indent=2))
    else:
        print("Failed to parse tool call")


if __name__ == "__main__":
    main()