# maaza-nlm-orchestrator-9.6m / production_router.py
# Uploaded by CycleCore-Technologies with huggingface_hub (b1a8e54, verified)
# production_router.py — Production wrapper for maaza-nlm-orchestrator-9.6m
# Adds spell-correction and retry logic for 92-94% end-to-end success
#
# Install: pip install symspellpy orjson
# Download dictionary: https://github.com/mammothb/symspellpy#dictionary-data
import json
import torch
# Optional dependency: when orjson is importable it is used for JSON parsing,
# otherwise the module falls back to the standard-library json module.
try:
    import orjson
except ImportError:
    USE_ORJSON = False
else:
    USE_ORJSON = True
# Optional spell-correction backend. USE_SYMSPELL is True only when symspellpy
# can be imported AND the frequency dictionary file sits next to this module.
try:
    from symspellpy import SymSpell, Verbosity

    sym_spell = SymSpell(max_dictionary_edit_distance=2, prefix_length=7)

    # Load dictionary - download it from the symspellpy repo if not present.
    import os

    dict_path = os.path.join(
        os.path.dirname(__file__),
        "frequency_dictionary_en_82_765.txt",
    )
    USE_SYMSPELL = os.path.exists(dict_path)
    if USE_SYMSPELL:
        sym_spell.load_dictionary(dict_path, term_index=0, count_index=1)
except ImportError:
    # symspellpy is not installed: spell correction stays disabled.
    USE_SYMSPELL = False
def spell_correct(text: str) -> str:
    """Return a spell-corrected version of ``text`` using SymSpell.

    The input is returned unchanged when SymSpell is disabled
    (``USE_SYMSPELL`` is false) or when the compound lookup yields no
    suggestion; otherwise the term of the first suggestion is returned.
    """
    if USE_SYMSPELL:
        candidates = sym_spell.lookup_compound(text, max_edit_distance=2)
        if candidates:
            return candidates[0].term
    return text
def parse_json(text: str) -> dict | list | None:
    """Extract and parse the JSON payload embedded in model output.

    If ``text`` contains ``"[{"`` it is treated as carrying a JSON list and
    the span ends at the last ``"]"``; otherwise, if it contains ``"{"``, it is
    treated as carrying a JSON object and the span ends at the last ``"}"``.

    For the list case two start positions are tried in order: the first
    ``"["`` (the widest span), then the first ``"[{"``. The second attempt
    recovers output where an unrelated bracket precedes the payload, e.g.
    ``'see [1]: [{"tool": "x"}]'``, which would otherwise fail to parse.

    Args:
        text: Decoded model output, possibly with text around the JSON.

    Returns:
        The parsed list or dict, or ``None`` when no candidate span exists or
        none of the candidates is valid JSON.
    """
    loads = orjson.loads if USE_ORJSON else json.loads

    if "[{" in text:
        closer = "]"
        # dict.fromkeys de-duplicates while keeping the try-order stable.
        starts = dict.fromkeys((text.find("["), text.find("[{")))
    elif "{" in text:
        closer = "}"
        starts = (text.find("{"),)
    else:
        return None

    # rfind gives -1 when the closer is missing, which makes end == 0 and
    # causes every candidate below to be skipped.
    end = text.rfind(closer) + 1
    for start in starts:
        if end <= start:
            continue
        try:
            return loads(text[start:end])
        except ValueError:
            # json.JSONDecodeError subclasses ValueError, and
            # orjson.JSONDecodeError subclasses both, so this covers either
            # parser. Fall through to the next candidate span.
            continue
    return None
def generate(model, tokenizer, prompt: str, max_tokens: int = 64, device: str = "cuda") -> str:
    """Greedily decode up to ``max_tokens`` tokens for ``prompt``.

    The prompt is wrapped as ``<|user|>{prompt}<|assistant|>``, encoded with
    ``tokenizer.encode`` and extended one argmax token at a time. Decoding
    stops early as soon as the id stored under ``"<|eos|>"`` in
    ``tokenizer.vocab`` is produced (that token is still appended).

    Args:
        model: Callable taking the token-id tensor and returning a mapping
            with a ``"logits"`` entry that is indexed as ``logits[0, -1]``.
        tokenizer: Object exposing ``encode``, ``decode`` and a ``vocab``
            mapping.
        prompt: User text to route.
        max_tokens: Maximum number of tokens to generate.
        device: Torch device on which the token-id tensor is created.

    Returns:
        ``tokenizer.decode`` applied to the whole id sequence, i.e. the
        encoded prompt followed by the generated tokens, not the generated
        tokens alone.
    """
    full_prompt = f"<|user|>{prompt}<|assistant|>"
    # Build tensors directly on the target device instead of on CPU + copy.
    input_ids = torch.tensor([tokenizer.encode(full_prompt)], device=device)
    # Loop-invariant: look the EOS id up once. It is None when the vocab has
    # no such entry, in which case the loop always runs to max_tokens.
    eos_id = tokenizer.vocab.get("<|eos|>")

    with torch.no_grad():
        for _ in range(max_tokens):
            logits = model(input_ids)["logits"]
            next_token = logits[0, -1].argmax().item()
            step = torch.tensor([[next_token]], device=device)
            input_ids = torch.cat([input_ids, step], dim=1)
            if next_token == eos_id:
                break

    return tokenizer.decode(input_ids[0].tolist())
def route_with_retry(
    prompt: str,
    model,
    tokenizer,
    max_attempts: int = 2,
    device: str = "cuda"
) -> dict:
    """
    Route a natural language prompt to a tool call with spell-correction and retry.

    The first attempt uses the spell-corrected prompt. Every later attempt
    uses the original (uncorrected) prompt prefixed with an explicit
    "Return only valid JSON tool call:" instruction.

    A parsed result is accepted when it is a dict containing the key
    ``"tool"``, or a non-empty list whose first element is such a dict.

    Args:
        prompt: Natural language request.
        model: Model passed through to ``generate``.
        tokenizer: Tokenizer passed through to ``generate``.
        max_attempts: Maximum number of generation attempts. With a value
            of 0 or less no generation happens and the failure result is
            returned with ``attempts == 0`` and ``raw is None``.
        device: Torch device passed through to ``generate``.

    Returns:
        dict with keys:
        - tool_call: parsed tool call (list or dict) or None
        - attempts: number of attempts made
        - raw: raw model output of the last attempt (None if none was made)
        - error: error message, present only on failure
    """
    original_prompt = prompt
    prompt = spell_correct(prompt)

    # Pre-bind so the failure result is well-defined even when the loop body
    # never runs (max_attempts <= 0).
    raw_output = None
    attempts_made = 0

    for attempts_made in range(1, max_attempts + 1):
        raw_output = generate(model, tokenizer, prompt, device=device)
        tool_call = parse_json(raw_output)

        # For a list, only the first element is inspected.
        if isinstance(tool_call, list) and tool_call:
            candidate = tool_call[0]
        else:
            candidate = tool_call

        # Require a dict: a bare `"tool" in x` would raise TypeError on
        # numbers and do a substring match on strings.
        if isinstance(candidate, dict) and "tool" in candidate:
            return {
                "tool_call": tool_call,
                "attempts": attempts_made,
                "raw": raw_output,
            }

        # Retry with stronger prompt
        prompt = f"Return only valid JSON tool call: {original_prompt}"

    return {
        "tool_call": None,
        "attempts": attempts_made,
        "raw": raw_output,
        "error": "Failed to parse valid tool call after retries",
    }
# Example usage:
if __name__ == "__main__":
    from model import MaazaNanoModel, MaazaNanoConfig
    from tokenizer import BPETokenizer

    # Use the GPU when there is one, otherwise run the demo on CPU.
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Load model (the context manager closes config.json deterministically).
    with open("config.json", encoding="utf-8") as config_file:
        config = MaazaNanoConfig(**json.load(config_file))
    model = MaazaNanoModel(config)
    # map_location lets a checkpoint saved on GPU load on a CPU-only machine.
    model.load_state_dict(
        torch.load("model.pt", map_location=device, weights_only=True)
    )
    model.eval().to(device)
    tokenizer = BPETokenizer.load("tokenizer.json")

    # Test with typos
    test_prompts = [
        "serch for cats on teh interent",
        "whats teh wether in tokyo",
        "reed the config.json fiel",
    ]
    for prompt in test_prompts:
        result = route_with_retry(prompt, model, tokenizer, device=device)
        print(f"Input: {prompt}")
        print(f"Tool: {result['tool_call']}")
        print(f"Attempts: {result['attempts']}")
        print()