File size: 9,705 Bytes
5ab87e0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
# single_question_recall.py
from __future__ import annotations
import re
import os
from typing import Any, Dict, Optional

from re_call import ReCall
from transformers import AutoTokenizer


import re
from typing import Optional, Any, Dict, Tuple, List

def _extract_answer_boxed(s: str) -> Optional[str]:
    """
    Return the content of the *last* \\boxed{...} or \\fbox{...} in `s`,
    with proper matching of nested braces. Escaped braces (\\{, \\}) are ignored
    for counting. If no balanced block is found, returns None.
    """
    def _iter_box_like_spans(text: str):
        # Find openings for \boxed{ and \fbox{
        openings: List[Tuple[str, int, int]] = []
        for m in re.finditer(r'\\boxed\s*\{', text):
            openings.append(("boxed", m.start(), m.end()))
        for m in re.finditer(r'\\fbox\s*\{', text):
            openings.append(("fbox", m.start(), m.end()))
        openings.sort(key=lambda x: x[1])
        # For each opening, scan forward to find its matching closing brace
        for kind, start, open_end in openings:
            depth = 1
            i = open_end
            n = len(text)
            while i < n:
                ch = text[i]
                # Skip escaped characters: backslash escapes the next char (including { or })
                if ch == '\\' and i + 1 < n:
                    i += 2
                    continue
                if ch == '{':
                    depth += 1
                elif ch == '}':
                    depth -= 1
                    if depth == 0:
                        # content is text[open_end:i]
                        yield (kind, start, open_end, i)
                        break
                i += 1

    last_content: Optional[str] = None
    for _, _start, open_end, close_idx in _iter_box_like_spans(s):
        last_content = s[open_end:close_idx]  # keep the *last* one

    return last_content.strip() if last_content is not None else None


def _extract_answer_tagged(s: str) -> Optional[str]:
    answer_tag_re = re.compile(r"<answer>(.*?)</answer>", re.S)
    m = answer_tag_re.findall(s)
    return m[-1].strip() if m else None


def _parse_answer_from_transcript(transcript: str) -> Optional[str]:
    """
    Extract the final answer from a model transcript.

    Prefers balanced \\boxed{...}/\\fbox{...} content, then the last
    <answer>...</answer> block. Returns None when neither is present.

    NOTE: a previous raw-tail fallback (last 200 chars) was intentionally
    disabled; the return annotation is Optional[str] to reflect that —
    the old `-> str` annotation was wrong once the fallback was removed.
    """
    return (
        _extract_answer_boxed(transcript)
        or _extract_answer_tagged(transcript)
    )


# --- main API: recall only ---
def answer_question_recall(
    question: str,
    *,
    model_url: Optional[str] = None,         # thinker endpoint; falls back to $HF_MODEL_URL
    executor_url: Optional[str] = None,      # tool executor; falls back to $HOST_SERPER_URL
    tokenizer_dir: str = "./tokenizer-info",
    temperature: float = 0.6,
    max_new_tokens: int = 40960,
    top_p: float = 0.95,
    search_env: str = "from search_api import search_urls, open_url, search_and_parse_query, query_url",
    func_schemas: Optional[List[Dict[str, Any]]] = None,
    deepseek_name: str = "deepseek-ai/DeepSeek-R1",
    old_prompt: Optional[str] = None,
    deepresearch_on: bool = True,
    summary_llm: str = "gpt-4.1-mini",
    ):
    """
    Run a single question through ReCall as a generator.

    Yields (tag, payload) tuples while the agent works:
      ("assistant_resp", (text, tool_calls)),
      ("tool_results", (results,)),
      ("end", (chat_str,)).
    The final answer is *returned* as ("answer", (answer,)), i.e. it is
    delivered to the caller via StopIteration.value when the generator
    is exhausted.

    Raises:
      KeyError: if `executor_url`/`model_url` are omitted and the
        HOST_SERPER_URL / HF_MODEL_URL environment variables are unset.
    """
    # BUGFIX: the default schema list used to be a mutable default argument
    # (shared across all calls — the classic Python pitfall). It is now
    # built fresh per call when the caller does not supply one.
    if func_schemas is None:
        func_schemas = [
            {
                "name": "search_urls",
                "description": "Google search and return links to web-pages with a brief snippet given a text query",
                "parameters": {
                    "type": "object",
                    "properties": {"query": {"type": "string"}, "top_k": {"type": "integer", "default": 10}},
                    "required": ["query"],
                },
            },
            {
                "name": "query_url",
                "description": "Visit webpage and return evidence based retrival for the provided goal",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "url": {"type": "string", "description": "The URL of the webpage to visit. Must be a single URL"},
                        "goal": {"type": "string", "description": "The specific information goal for visiting webpage"},
                    },
                    "required": ["url", "goal"],
                },
            },
        ]

    if executor_url is None:
        executor_url = os.environ["HOST_SERPER_URL"]

    if model_url is None:
        model_url = os.environ["HF_MODEL_URL"]

    # 1) tokenizer (REQUIRED by ReCall.run)
    tok = AutoTokenizer.from_pretrained(tokenizer_dir, trust_remote_code=True)

    # 2) build agent
    agent = ReCall(executor_url=executor_url)

    last_out = ""  # latest assistant text; parsed for the final answer below

    # 3) call the correct entrypoint
    if model_url == deepseek_name:
        # some setups use a special deepseek path that returns (transcript, tool_calls)
        out = agent.run_deepseek(
            env=search_env,
            func_schemas=func_schemas,
            question=question,
            model_name=deepseek_name,
            temperature=temperature,
            max_tokens=max_new_tokens,
            top_p=top_p,
        )
        transcript, tool_calls, chat = _normalize_out(out, expect_chat=False)
        last_out = transcript
    else:
        # standard ReCall.run MUST receive the tokenizer
        agent_generator = agent.run(
            env=search_env,
            func_schemas=func_schemas,
            question=question,
            model_url=model_url,
            temperature=temperature,
            max_new_tokens=max_new_tokens,
            tokenizer=tok,  # <- fixes the "missing tokenizer" error
            top_p=top_p,
            # When old_prompt is not None, ReCall appends the question to that
            # raw chat-history string instead of starting a fresh prompt.
            old_prompt=old_prompt,
            deepresearch_on=deepresearch_on,
            summary_llm=summary_llm,
        )

        # Forward every event to our caller, remembering the latest
        # assistant response so the final answer can be parsed from it.
        while True:
            try:
                tag, out = next(agent_generator)
                if tag == "assistant_resp":
                    last_out = out[0]
                yield tag, out
                if tag == "end":
                    break
            except StopIteration as e:
                # ReCall returned instead of yielding "end": its return value
                # carries the whole conversation as one raw string, including
                # special tokens such as "<|im_start|>system\n", "<|im_end|>",
                # "<tool_response>", "\n</tool_response>\n", etc.
                chat_str: str = e.value[1][0]
                yield "end", (chat_str,)
                break

    # 4) parse final answer from the last assistant message
    answer = _parse_answer_from_transcript(last_out)

    return "answer", (answer,)


def _normalize_out(out: Any, expect_chat: bool) -> tuple[str, Any, Any]:
    """
    Normalize ReCall outputs to (transcript, tool_calls, chat)
    Handles:
      - (transcript, tool_calls, chat)
      - (transcript, tool_calls)
      - "transcript"
      - {"transcript": ..., "tool_calls": ..., "chat": ...} variants
    """
    transcript, tool_calls, chat = "", None, None

    if isinstance(out, tuple):
        if len(out) == 3:
            transcript, tool_calls, chat = out
        elif len(out) == 2:
            transcript, tool_calls = out
        elif len(out) == 1:
            transcript = out[0]
        else:
            transcript = str(out[-1])
    elif isinstance(out, dict):
        transcript = out.get("transcript") or out.get("output") or out.get("response") or ""
        tool_calls = out.get("tool_calls")
        chat = out.get("chat")
    else:
        transcript = str(out)

    # Some implementations return None/empty; keep things predictable
    if chat is None and expect_chat is False:
        chat = None
    return transcript, tool_calls, chat


# quick demo
if __name__ == "__main__":
    old_prompt = None

    answer_generator = answer_question_recall(
        "What is the most popular restraunt in kolkata?",
        old_prompt=old_prompt,
    )

    final_chat_str = ""

    # Drain the generator, printing each event as it arrives. The final
    # answer may arrive either as an explicit ("answer", ...) yield or as
    # the generator's return value (StopIteration.value).
    # NOTE: the try body is narrowed to only the next() call — the original
    # wrapped the whole loop body, which could have masked unrelated errors.
    # (A block of long-dead commented-out result printing was removed.)
    while True:
        try:
            tag, out = next(answer_generator)
        except StopIteration as e:
            print(f"FINAL ANSWER:\n{e.value[1][0]}\n\n")
            break

        if tag == "assistant_resp":
            assistant_text, tool_calls = out
            print(f"ASSISTANT RESPONSE:\n{assistant_text}\n\n")
            print("TOOL CALLS:\n")
            for tool_call in tool_calls:
                print(f"{tool_call}")
            print("\n")
        elif tag == "tool_results":
            results = out[0]
            print("TOOL RESULTS:\n")
            for result in results:
                print(f"{result}")
            print("\n")
        elif tag == "end":
            print(f"{'='*20}\nASSISTANT RESPONSE ENDED\n{'='*20}\n\n")
            final_chat_str = out[0]
        elif tag == "answer":
            answer = out[0]
            print(f"FINAL ANSWER:\n{answer}\n\n")
            break

    print(f"{'='*20}\nEND\n{'='*20}\n\n\nFINAL CHAT STRING:\n{final_chat_str}\n\n")