|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import json |
|
|
import torch |
|
|
|
|
|
try:
    import orjson  # optional, faster JSON decoder
except ImportError:
    # orjson is not installed; parse_json falls back to the stdlib json module.
    USE_ORJSON = False
else:
    USE_ORJSON = True
|
|
|
|
|
try:
    from symspellpy import SymSpell, Verbosity

    sym_spell = SymSpell(max_dictionary_edit_distance=2, prefix_length=7)

    import os

    # The English word-frequency list is looked up next to this module file.
    dict_path = os.path.join(os.path.dirname(__file__), "frequency_dictionary_en_82_765.txt")

    # Spell correction stays disabled unless the dictionary file is present
    # and has been loaded.
    USE_SYMSPELL = False
    if os.path.exists(dict_path):
        sym_spell.load_dictionary(dict_path, term_index=0, count_index=1)
        USE_SYMSPELL = True
except ImportError:
    # symspellpy is not installed: spell_correct() becomes a pass-through.
    USE_SYMSPELL = False
|
|
|
|
|
|
|
|
def spell_correct(text: str) -> str:
    """Return *text* with spelling fixed by SymSpell's compound lookup.

    Uses ``lookup_compound`` with a maximum edit distance of 2 and returns the
    top suggestion's term. If SymSpell is unavailable (package missing or
    dictionary not loaded), or no suggestion comes back, *text* is returned
    unchanged.
    """
    # NOTE(review): lookup_compound may normalise casing/punctuation
    # (e.g. "config.json"); confirm that is acceptable for tool arguments.
    if USE_SYMSPELL:
        candidates = sym_spell.lookup_compound(text, max_edit_distance=2)
        if candidates:
            return candidates[0].term
    return text
|
|
|
|
|
|
|
|
def parse_json(text: str) -> dict | list | None:
    """Extract and parse the first JSON object (or array of objects) in *text*.

    Model output may surround the JSON with other text, so the payload start
    is located heuristically:

    * If the first ``{`` in *text* is directly opened by ``[`` (whitespace
      allowed in between), the payload is the array starting at that ``[``.
    * Otherwise the payload is the object starting at the first ``{``.

    This keeps a dict whose arguments contain ``[{...}]`` from being mistaken
    for an array, and ignores any unrelated ``[`` earlier in the text.

    Parsing first tries the span up to the last matching closing bracket,
    using orjson when available. If that fails (for example, trailing text
    contains a stray bracket), it falls back to decoding only the first
    complete JSON value at the start position.

    Args:
        text: Raw model output.

    Returns:
        The parsed dict or list, or None if no parseable payload is found.
    """
    obj_start = text.find("{")
    if obj_start == -1:
        return None

    # Walk back over whitespace to see whether this object opens an array.
    prev = obj_start - 1
    while prev >= 0 and text[prev].isspace():
        prev -= 1
    if prev >= 0 and text[prev] == "[":
        start_idx, closer = prev, "]"
    else:
        start_idx, closer = obj_start, "}"

    # Fast path: slice up to the last closer and parse the whole span.
    end_idx = text.rfind(closer) + 1
    if end_idx > start_idx:
        json_str = text[start_idx:end_idx]
        try:
            return orjson.loads(json_str) if USE_ORJSON else json.loads(json_str)
        except ValueError:
            # Both json.JSONDecodeError and orjson.JSONDecodeError subclass
            # ValueError; fall through to the tolerant decoder.
            pass

    # Tolerant path: decode one JSON value and ignore whatever follows it.
    try:
        value, _ = json.JSONDecoder().raw_decode(text[start_idx:])
    except ValueError:
        return None
    return value
|
|
|
|
|
|
|
|
def generate(model, tokenizer, prompt: str, max_tokens: int = 64, device: str = "cuda") -> str:
    """Greedily decode up to *max_tokens* new tokens for *prompt*.

    The prompt is wrapped as ``<|user|>{prompt}<|assistant|>``. Each step runs
    the model on the full sequence so far and appends the arg-max token from
    ``outputs["logits"][0, -1]``. Decoding stops early once the
    ``<|eos|>`` token id (looked up in ``tokenizer.vocab``) is produced.

    Args:
        model: Callable returning a mapping with a ``"logits"`` tensor.
        tokenizer: Object exposing ``encode``, ``decode`` and a ``vocab`` dict.
        prompt: User request text.
        max_tokens: Maximum number of tokens to generate.
        device: Torch device the token ids are placed on.

    Returns:
        The decoded text of the whole sequence: prompt template, prompt, and
        generated tokens (including the EOS token if it was produced).
    """
    full_prompt = f"<|user|>{prompt}<|assistant|>"
    input_ids = torch.tensor([tokenizer.encode(full_prompt)], device=device)

    # Loop invariant: resolve the EOS id once. None if the vocab lacks it,
    # in which case generation simply runs for max_tokens steps.
    eos_id = tokenizer.vocab.get("<|eos|>")

    with torch.no_grad():
        for _ in range(max_tokens):
            logits = model(input_ids)["logits"]
            next_token = logits[0, -1].argmax().item()
            # Allocate the new token directly on the sequence's device/dtype
            # instead of building it on the CPU and copying every step.
            next_ids = torch.tensor(
                [[next_token]], device=input_ids.device, dtype=input_ids.dtype
            )
            input_ids = torch.cat([input_ids, next_ids], dim=1)
            if next_token == eos_id:
                break

    return tokenizer.decode(input_ids[0].tolist())
|
|
|
|
|
|
|
|
def _is_tool_call(parsed) -> bool:
    """Return True if *parsed* is a dict with a "tool" key, or a non-empty list
    whose elements are all such dicts."""
    if isinstance(parsed, dict):
        return "tool" in parsed
    if isinstance(parsed, list) and parsed:
        return all(isinstance(item, dict) and "tool" in item for item in parsed)
    return False


def route_with_retry(
    prompt: str,
    model,
    tokenizer,
    max_attempts: int = 2,
    device: str = "cuda"
) -> dict:
    """
    Route a natural language prompt to a tool call with spell-correction and retry.

    The first attempt uses the spell-corrected prompt. Every later attempt
    uses the original, uncorrected prompt prefixed with an explicit
    "Return only valid JSON tool call:" instruction.

    Returns:
        dict with keys:
        - tool_call: parsed tool call (list or dict) or None
        - attempts: number of attempts made (0 if max_attempts < 1)
        - raw: raw model output of the last attempt (None if no attempt ran)
        - error: error message if failed (only present on failure)
    """
    original_prompt = prompt
    prompt = spell_correct(prompt)

    # Initialised so the failure path is well-defined even when the loop
    # body never runs (max_attempts < 1).
    raw_output = None
    attempts = 0

    for attempt in range(max_attempts):
        attempts = attempt + 1
        raw_output = generate(model, tokenizer, prompt, device=device)
        tool_call = parse_json(raw_output)

        if _is_tool_call(tool_call):
            return {
                "tool_call": tool_call,
                "attempts": attempts,
                "raw": raw_output
            }

        # Retry from the uncorrected prompt (presumably so a bad spell
        # correction is not repeated) with a stricter instruction.
        prompt = f"Return only valid JSON tool call: {original_prompt}"

    return {
        "tool_call": None,
        "attempts": attempts,
        "raw": raw_output,
        "error": "Failed to parse valid tool call after retries"
    }
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    from model import MaazaNanoModel, MaazaNanoConfig
    from tokenizer import BPETokenizer

    # Build the model from its JSON config; `with` closes the file promptly.
    with open("config.json", encoding="utf-8") as config_file:
        config = MaazaNanoConfig(**json.load(config_file))
    model = MaazaNanoModel(config)
    # Load weights onto the CPU first; the model is moved to the GPU below, so
    # this avoids holding a second GPU copy of a GPU-saved checkpoint.
    model.load_state_dict(torch.load("model.pt", map_location="cpu", weights_only=True))
    model.eval().cuda()

    tokenizer = BPETokenizer.load("tokenizer.json")

    # Deliberately misspelled prompts exercise the spell-correction path.
    test_prompts = [
        "serch for cats on teh interent",
        "whats teh wether in tokyo",
        "reed the config.json fiel",
    ]

    for prompt in test_prompts:
        result = route_with_retry(prompt, model, tokenizer)
        print(f"Input: {prompt}")
        print(f"Tool: {result['tool_call']}")
        print(f"Attempts: {result['attempts']}")
        if "error" in result:
            print(f"Error: {result['error']}")
        print()
|
|
|