File size: 4,915 Bytes
b1a8e54
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
# production_router.py — Production wrapper for maaza-nlm-orchestrator-9.6m
# Adds spell-correction and retry logic for 92-94% end-to-end success
#
# Install: pip install symspellpy orjson
# Download dictionary: https://github.com/mammothb/symspellpy#dictionary-data

import json
import torch

# Optional fast JSON parser: prefer orjson when installed, otherwise
# fall back to the stdlib json module (see parse_json below).
try:
    import orjson
    USE_ORJSON = True
except ImportError:
    USE_ORJSON = False

# Optional spell correction: enabled only when symspellpy is installed
# AND the frequency dictionary file sits next to this module.  Either
# condition failing degrades spell_correct() to a no-op.
try:
    from symspellpy import SymSpell, Verbosity
    sym_spell = SymSpell(max_dictionary_edit_distance=2, prefix_length=7)
    # Load dictionary - download from symspellpy repo if not present
    import os
    dict_path = os.path.join(os.path.dirname(__file__), "frequency_dictionary_en_82_765.txt")
    if os.path.exists(dict_path):
        sym_spell.load_dictionary(dict_path, term_index=0, count_index=1)
        USE_SYMSPELL = True
    else:
        USE_SYMSPELL = False
except ImportError:
    USE_SYMSPELL = False


def spell_correct(text: str) -> str:
    """Return *text* with misspellings corrected via SymSpell (~2ms).

    Degrades to an identity function when no SymSpell dictionary was
    loaded at import time (USE_SYMSPELL is False).
    """
    if not USE_SYMSPELL:
        return text
    candidates = sym_spell.lookup_compound(text, max_edit_distance=2)
    if not candidates:
        return text
    return candidates[0].term


def parse_json(text: str) -> dict | list | None:
    """Parse JSON from model output."""
    try:
        # Try to extract JSON from output
        if "[{" in text:
            start_idx = text.index("[")
            end_idx = text.rindex("]") + 1
            json_str = text[start_idx:end_idx]
        elif "{" in text:
            start_idx = text.index("{")
            end_idx = text.rindex("}") + 1
            json_str = text[start_idx:end_idx]
        else:
            return None

        if USE_ORJSON:
            return orjson.loads(json_str)
        else:
            return json.loads(json_str)
    except (ValueError, json.JSONDecodeError):
        return None


def generate(model, tokenizer, prompt: str, max_tokens: int = 64, device: str = "cuda") -> str:
    """Greedily decode up to *max_tokens* tokens for *prompt*.

    Wraps the prompt in the <|user|>/<|assistant|> chat template, then
    appends the argmax token one step at a time (full re-forward each
    step, no KV cache), stopping early when <|eos|> is produced.

    Returns:
        The decoded text of the full sequence (template + completion).
    """
    templated = f"<|user|>{prompt}<|assistant|>"
    ids = torch.tensor([tokenizer.encode(templated)]).to(device)

    with torch.no_grad():
        for _ in range(max_tokens):
            step_logits = model(ids)["logits"]
            chosen = step_logits[0, -1].argmax().item()
            ids = torch.cat([ids, torch.tensor([[chosen]]).to(device)], dim=1)

            # Stop as soon as the model emits its end-of-sequence token.
            if chosen == tokenizer.vocab.get("<|eos|>"):
                break

    return tokenizer.decode(ids[0].tolist())


def route_with_retry(
    prompt: str,
    model,
    tokenizer,
    max_attempts: int = 2,
    device: str = "cuda"
) -> dict:
    """
    Route a natural language prompt to a tool call with spell-correction and retry.

    The prompt is spell-corrected once up front; if an attempt does not
    yield a valid tool call, the next attempt wraps the *original*
    (uncorrected) prompt in an explicit "return only valid JSON"
    instruction.

    Args:
        prompt: natural-language user request.
        model, tokenizer: forwarded to generate().
        max_attempts: total generation attempts before giving up.
        device: torch device string forwarded to generate().

    Returns:
        dict with keys:
            - tool_call: parsed tool call (list or dict) or None
            - attempts: number of attempts made
            - raw: raw model output ("" if max_attempts < 1)
            - error: error message if failed (absent on success)
    """
    original_prompt = prompt
    prompt = spell_correct(prompt)
    # Fix: raw_output was unbound (NameError in the final return) when
    # max_attempts < 1.
    raw_output = ""

    for attempt in range(max_attempts):
        raw_output = generate(model, tokenizer, prompt, device=device)
        tool_call = parse_json(raw_output)

        if tool_call is not None:
            # Valid shapes: a non-empty list whose first element is a dict
            # with a "tool" key, or a dict with a "tool" key.  The
            # isinstance check on the first element guards against a
            # TypeError from `"tool" in <non-container>` (e.g. an int).
            if isinstance(tool_call, list) and tool_call:
                if isinstance(tool_call[0], dict) and "tool" in tool_call[0]:
                    return {
                        "tool_call": tool_call,
                        "attempts": attempt + 1,
                        "raw": raw_output
                    }
            elif isinstance(tool_call, dict) and "tool" in tool_call:
                return {
                    "tool_call": tool_call,
                    "attempts": attempt + 1,
                    "raw": raw_output
                }

        # Retry with a stronger, more explicit instruction.
        prompt = f"Return only valid JSON tool call: {original_prompt}"

    return {
        "tool_call": None,
        "attempts": max_attempts,
        "raw": raw_output,
        "error": "Failed to parse valid tool call after retries"
    }


# Example usage:
if __name__ == "__main__":
    # Project-local modules; only needed when running this file directly.
    from model import MaazaNanoModel, MaazaNanoConfig
    from tokenizer import BPETokenizer

    # Load model
    config = MaazaNanoConfig(**json.load(open("config.json")))
    model = MaazaNanoModel(config)
    # weights_only=True avoids unpickling arbitrary objects from model.pt.
    model.load_state_dict(torch.load("model.pt", weights_only=True))
    model.eval().cuda()

    tokenizer = BPETokenizer.load("tokenizer.json")

    # Test with typos
    # Prompts are deliberately misspelled to exercise the SymSpell
    # correction path in route_with_retry.
    test_prompts = [
        "serch for cats on teh interent",
        "whats teh wether in tokyo",
        "reed the config.json fiel",
    ]

    for prompt in test_prompts:
        result = route_with_retry(prompt, model, tokenizer)
        print(f"Input: {prompt}")
        print(f"Tool: {result['tool_call']}")
        print(f"Attempts: {result['attempts']}")
        print()