File size: 4,915 Bytes
b1a8e54 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 |
# production_router.py — Production wrapper for maaza-nlm-orchestrator-9.6m
# Adds spell-correction and retry logic for 92-94% end-to-end success
#
# Install: pip install symspellpy orjson
# Download dictionary: https://github.com/mammothb/symspellpy#dictionary-data
import json
import torch
try:
import orjson
USE_ORJSON = True
except ImportError:
USE_ORJSON = False
try:
from symspellpy import SymSpell, Verbosity
sym_spell = SymSpell(max_dictionary_edit_distance=2, prefix_length=7)
# Load dictionary - download from symspellpy repo if not present
import os
dict_path = os.path.join(os.path.dirname(__file__), "frequency_dictionary_en_82_765.txt")
if os.path.exists(dict_path):
sym_spell.load_dictionary(dict_path, term_index=0, count_index=1)
USE_SYMSPELL = True
else:
USE_SYMSPELL = False
except ImportError:
USE_SYMSPELL = False
def spell_correct(text: str) -> str:
"""Fast spell correction using SymSpell (2ms)."""
if not USE_SYMSPELL:
return text
suggestions = sym_spell.lookup_compound(text, max_edit_distance=2)
return suggestions[0].term if suggestions else text
def parse_json(text: str) -> dict | list | None:
"""Parse JSON from model output."""
try:
# Try to extract JSON from output
if "[{" in text:
start_idx = text.index("[")
end_idx = text.rindex("]") + 1
json_str = text[start_idx:end_idx]
elif "{" in text:
start_idx = text.index("{")
end_idx = text.rindex("}") + 1
json_str = text[start_idx:end_idx]
else:
return None
if USE_ORJSON:
return orjson.loads(json_str)
else:
return json.loads(json_str)
except (ValueError, json.JSONDecodeError):
return None
def generate(model, tokenizer, prompt: str, max_tokens: int = 64, device: str = "cuda") -> str:
"""Generate output from model."""
full_prompt = f"<|user|>{prompt}<|assistant|>"
input_ids = torch.tensor([tokenizer.encode(full_prompt)]).to(device)
with torch.no_grad():
for _ in range(max_tokens):
outputs = model(input_ids)
logits = outputs["logits"]
next_token = logits[0, -1].argmax().item()
input_ids = torch.cat([input_ids, torch.tensor([[next_token]]).to(device)], dim=1)
if next_token == tokenizer.vocab.get("<|eos|>"):
break
return tokenizer.decode(input_ids[0].tolist())
def route_with_retry(
prompt: str,
model,
tokenizer,
max_attempts: int = 2,
device: str = "cuda"
) -> dict:
"""
Route a natural language prompt to a tool call with spell-correction and retry.
Returns:
dict with keys:
- tool_call: parsed tool call (list or dict) or None
- attempts: number of attempts made
- raw: raw model output
- error: error message if failed
"""
original_prompt = prompt
prompt = spell_correct(prompt)
for attempt in range(max_attempts):
raw_output = generate(model, tokenizer, prompt, device=device)
tool_call = parse_json(raw_output)
if tool_call is not None:
# Validate structure
if isinstance(tool_call, list) and len(tool_call) > 0:
if "tool" in tool_call[0]:
return {
"tool_call": tool_call,
"attempts": attempt + 1,
"raw": raw_output
}
elif isinstance(tool_call, dict) and "tool" in tool_call:
return {
"tool_call": tool_call,
"attempts": attempt + 1,
"raw": raw_output
}
# Retry with stronger prompt
prompt = f"Return only valid JSON tool call: {original_prompt}"
return {
"tool_call": None,
"attempts": max_attempts,
"raw": raw_output,
"error": "Failed to parse valid tool call after retries"
}
# Example usage:
if __name__ == "__main__":
from model import MaazaNanoModel, MaazaNanoConfig
from tokenizer import BPETokenizer
# Load model
config = MaazaNanoConfig(**json.load(open("config.json")))
model = MaazaNanoModel(config)
model.load_state_dict(torch.load("model.pt", weights_only=True))
model.eval().cuda()
tokenizer = BPETokenizer.load("tokenizer.json")
# Test with typos
test_prompts = [
"serch for cats on teh interent",
"whats teh wether in tokyo",
"reed the config.json fiel",
]
for prompt in test_prompts:
result = route_with_retry(prompt, model, tokenizer)
print(f"Input: {prompt}")
print(f"Tool: {result['tool_call']}")
print(f"Attempts: {result['attempts']}")
print()
|