Update app.py
app.py CHANGED
@@ -3,6 +3,7 @@ import re
 import gc
 import sys
 import time
+import json
 import queue
 import random
 import asyncio
@@ -36,6 +37,8 @@ GPU_LOCK = asyncio.Lock()
 class ChatMessage(BaseModel):
     role: str = Field()
     content: str = Field()
+    name: Optional[str] = Field(None)
+    tool_call_id: Optional[str] = Field(None)

 class Logprob(BaseModel):
     token: str
@@ -76,6 +79,15 @@ class ChatCompletionChunk(BaseModel):
     choices: List[ChatCompletionChoice]
     usage: Optional[Usage]

+class ToolFunction(BaseModel):
+    name: str
+    description: str
+    parameters: Dict[str, Any]
+
+class Tool(BaseModel):
+    type: Literal["function"] = "function"
+    function: ToolFunction
+
 def remove_nested_think_tags_stack(text):
     stack = []
     result = ""
@@ -106,7 +118,17 @@ def cleanMessages(messages: List[ChatMessage], removeThinkingContent: bool = Fal
         role_str = message.role.strip().lower().capitalize()
         if role_str == 'Assistant' and removeThinkingContent:
             content = remove_nested_think_tags_stack(content)
-
+
+        if message.role == "tool":
+            promptStrList.append(f"Tool Output ({message.name}): {content}")
+        elif message.role == "system":
+            promptStrList.append(f"System: {content}")
+        elif message.role == "user":
+            promptStrList.append(f"User: {content}")
+        elif message.role == "assistant":
+            promptStrList.append(f"Assistant: {content}")
+        else:
+            promptStrList.append(f"{role_str}: {content}")
     return "\n\n".join(promptStrList)

 class SamplerConfig(BaseModel):
@@ -252,6 +274,8 @@ class ChatCompletionRequest(BaseModel):
     stream: Optional[bool] = Field(default=False)
     stop: Optional[List[str]] = Field(["\n\n"])
     stop_tokens: Optional[List[int]] = Field([0])
+    tools: Optional[List[Tool]] = Field(default=None)
+    tool_choice: Optional[Union[str, Dict]] = Field(default="auto")

     @model_validator(mode="before")
     @classmethod
@@ -261,57 +285,105 @@ class ChatCompletionRequest(BaseModel):
             raise ValueError("messages and prompt cannot coexist.")
         return data

-class TruthProtocol:
-    STRICT_SYSTEM_PROMPT = """
-
-
-
-
-
-
+class ToolEngine:
+    TOOL_SYSTEM_PROMPT = """
+CAPABILITY: You have access to real-time tools.
+INSTRUCTION: To use a tool, output exactly: <call>tool_name("argument")</call>
+Do not describe the tool, just call it. After the System provides the result, synthesize the answer.
+
+AVAILABLE TOOLS:
+1. google_search(query): Searches Google and DuckDuckGo for real-time information.
+2. visit_page(url): Accesses a specific link, reads the text, and finds sub-links.
 """.strip()

     @staticmethod
-    def
-
-
-
-
-
+    def google_search_request(query: str) -> str:
+        try:
+            headers = {"User-Agent": "Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36"}
+            resp = requests.get("https://www.google.com/search", params={"q": query, "gl": "us", "hl": "en"}, headers=headers, timeout=6)
+
+            if resp.status_code != 200: raise Exception("Google blocked request")
+
+            clean_text = re.sub(r'<script.*?>.*?</script>', '', resp.text, flags=re.DOTALL)
+            clean_text = re.sub(r'<style.*?>.*?</style>', '', clean_text, flags=re.DOTALL)
+
+            headings = re.findall(r'<h3.*?>(.*?)</h3>', clean_text)
+            links = re.findall(r'<a href="/url\?q=(.*?)&', clean_text)
+
+            limit = min(len(headings), len(links), 5)
+            output = "Google Results:\n"
+            for i in range(limit):
+                output += f"{i+1}. {re.sub(r'<.*?>', '', headings[i])} - Link: {links[i]}\n"
+
+            if not headings:
+                return ToolEngine.duckduckgo_fallback(query)
+
+            return output
+        except:
+            return ToolEngine.duckduckgo_fallback(query)
+
+    @staticmethod
+    def duckduckgo_fallback(query: str) -> str:
+        try:
+            if HAS_DDG:
+                res = DDGS().text(query, max_results=5)
+                return "\n".join([f"- {r['title']}: {r['body']} ({r['href']})" for r in res])
+
+            resp = requests.get("https://html.duckduckgo.com/html/", params={"q": query}, headers={"User-Agent": "Mozilla/5.0"}, timeout=5)
+            titles = re.findall(r'<a class="result__a"[^>]*>(.*?)</a>', resp.text)
+            snippets = re.findall(r'<a class="result__snippet"[^>]*>(.*?)</a>', resp.text)
+
+            limit = min(len(titles), len(snippets), 4)
+            out = "DuckDuckGo HTML Results:\n"
+            for i in range(limit):
+                t = re.sub(r'<.*?>', '', titles[i]).strip()
+                s = re.sub(r'<.*?>', '', snippets[i]).strip()
+                out += f"{i+1}. {t}: {s}\n"
+            return out
+        except Exception as e:
+            return f"Search failed: {str(e)}"
+
+    @staticmethod
+    def visit_page(url: str) -> str:
+        try:
+            headers = {"User-Agent": "Mozilla/5.0 (compatible; RWKV-Bot/1.0)"}
+            resp = requests.get(url, headers=headers, timeout=8)
+            resp.encoding = resp.apparent_encoding
+
+            text = re.sub(r'<head.*?>.*?</head>', '', resp.text, flags=re.DOTALL)
+            text = re.sub(r'<script.*?>.*?</script>', '', text, flags=re.DOTALL)
+            text = re.sub(r'<style.*?>.*?</style>', '', text, flags=re.DOTALL)
+            text = re.sub(r'<!--.*?-->', '', text, flags=re.DOTALL)
+            text = re.sub(r'<[^>]+>', ' ', text)
+            text = re.sub(r'\s+', ' ', text).strip()
+
+            links = re.findall(r'href=["\'](http[s]?://[^"\']+)["\']', resp.text)
+            unique_links = list(set(links))[:5]
+
+            content_preview = text[:3000] + ("..." if len(text) > 3000 else "")
+
+            return f"PAGE CONTENT ({url}):\n{content_preview}\n\nFOUND SUB-LINKS:\n" + "\n".join(unique_links)
+        except Exception as e:
+            return f"Error visiting page: {str(e)}"

     @staticmethod
-    def
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-            if check: results.extend(check)
-            if not results: return ""
-            ctx = TruthProtocol.sanitise_search(query, results)
-            if len(search_cache) > 50: search_cache.popitem(last=False)
-            search_cache[query] = ctx
-            return ctx
-        except:
-            return ""
-
-def needs_verification(msg: str, model: str) -> bool:
-    if ":online" in model: return True
-    triggers = ["es verdad", "dato", "precio", "cuando", "quien", "noticia", "actualidad", "verify"]
-    return any(t in msg.lower() for t in triggers)
-
-app = FastAPI(title="RWKV Ultimate Server")
+    def execute(call_str: str) -> str:
+        try:
+            match = re.match(r'(\w+)\(["\'](.*?)["\']\)', call_str)
+            if not match: return "Invalid tool call syntax."
+
+            func, arg = match.groups()
+
+            if func == "google_search":
+                return ToolEngine.google_search_request(arg)
+            elif func == "visit_page":
+                return ToolEngine.visit_page(arg)
+            else:
+                return f"Unknown tool: {func}"
+        except Exception as e:
+            return f"Tool execution error: {e}"
+
+app = FastAPI(title="RWKV Ultimate Agent Server")

 app.add_middleware(
     CORSMiddleware,
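The call grammar the new ToolEngine accepts is narrower than the prompt might suggest: exactly one identifier followed by one quoted string argument. A standalone sketch of that parsing contract, with the regex copied from ToolEngine.execute and hypothetical example calls:

import re

# Same pattern as in ToolEngine.execute: a function name plus one quoted argument.
CALL_RE = re.compile(r'(\w+)\(["\'](.*?)["\']\)')

def parse_call(call_str: str):
    m = CALL_RE.match(call_str)
    return m.groups() if m else None

print(parse_call('google_search("rwkv v7 paper")'))     # ('google_search', 'rwkv v7 paper')
print(parse_call("visit_page('https://example.com')"))  # ('visit_page', 'https://example.com')
print(parse_call('google_search(query="x")'))           # None: keyword arguments are rejected

Because the argument group is a lazy match ended by a quote-plus-parenthesis, an argument that itself contains ") is cut short at that point; the system prompt's insistence on the exact <call>tool_name("argument")</call> shape is what keeps this regex sufficient.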
@@ -374,6 +446,7 @@ def generate(request: ChatCompletionRequest, out, model_tokens: List[int], model
     cache_word_list = []

     stop_sequences = request.stop if request.stop else []
+    stop_sequences.append("<call>")

     for i in range(max_tokens):
         for n in occurrence: out[n] -= args.alpha_presence + occurrence[n] * args.alpha_frequency
@@ -396,38 +469,92 @@ def generate(request: ChatCompletionRequest, out, model_tokens: List[int], model
             out_last = i + 1

         current_buffer = "".join(cache_word_list)
+
+        if "<call>" in current_buffer:
+            pre_call = current_buffer.split("<call>")[0]
+            yield {"content": pre_call, "finish_reason": "tool_start", "state": model_state}
+            del out; gc.collect(); return
+
         for s in stop_sequences:
-            if s in current_buffer:
+            if s in current_buffer and s != "<call>":
                 final_content = current_buffer.split(s)[0]
                 yield {"content": final_content, "finish_reason": "stop", "state": model_state}
                 del out; gc.collect(); return

-        if len(cache_word_list) >
+        if len(cache_word_list) > 2:
             yield {"content": cache_word_list.pop(0), "finish_reason": None}

     yield {"content": "".join(cache_word_list), "finish_reason": "length"}

 async def chatResponseStream(request: ChatCompletionRequest, model_state: any, completionId: str, enableReasoning: bool):
-
-    prompt = f"{clean_msg}\n\nAssistant:{' <think' if enableReasoning else ''}"
-
-
-
-
-
-
-
-
-
-
-
-
-        await asyncio.sleep(0)
-    finally:
-        pass
+    current_messages = request.messages
+
+    for step in range(4):
+        clean_msg = cleanMessages(current_messages, enableReasoning)
+        prompt = f"{clean_msg}\n\nAssistant:{' <think' if enableReasoning else ''}"
+
+        tool_buffer = ""
+        tool_call_mode = False
+
+        async with GPU_LOCK:
+            try:
+                out, model_tokens, model_state = await runPrefill(request, prompt, [0], model_state)
+
+                if step == 0:
+                    yield f"data: {ChatCompletionChunk(id=completionId, created=int(time.time()), model=request.model, choices=[ChatCompletionChoice(index=0, delta=ChatCompletionMessage(role='Assistant', content=''), finish_reason=None)]).model_dump_json()}\n\n"
+
+                for chunk in generate(request, out, model_tokens, model_state, max_tokens=request.max_tokens or 4096):
+                    content = chunk.get("content", "")
+                    finish = chunk.get("finish_reason", None)
+
+                    if finish == "tool_start":
+                        tool_call_mode = True
+                        if content:
+                            yield f"data: {ChatCompletionChunk(id=completionId, created=int(time.time()), model=request.model, choices=[ChatCompletionChoice(index=0, delta=ChatCompletionMessage(content=content), finish_reason=None)]).model_dump_json()}\n\n"
+                        break
+
+                    if content:
+                        yield f"data: {ChatCompletionChunk(id=completionId, created=int(time.time()), model=request.model, choices=[ChatCompletionChoice(index=0, delta=ChatCompletionMessage(content=content), finish_reason=None)]).model_dump_json()}\n\n"
+
+                    if finish:
+                        yield f"data: {ChatCompletionChunk(id=completionId, created=int(time.time()), model=request.model, choices=[ChatCompletionChoice(index=0, delta=ChatCompletionMessage(content=''), finish_reason=finish)]).model_dump_json()}\n\n"
+                        return
+
+            finally:
+                pass
+
+        if tool_call_mode:
+            full_tool_call = ""

+            async with GPU_LOCK:
+                try:
+                    tool_out, tool_tokens, tool_state = await runPrefill(request, "", [0], model_state)
+                    temp_tokens = []
+
+                    current_gen = ""
+
+                    for i in range(200):
+                        args = PIPELINE_ARGS(temperature=0.1, top_p=0.1)
+                        tool_token = MODEL_STORAGE[request.model].pipeline.sample_logits(tool_out, temperature=0.1, top_p=0.1)
+                        tool_out, tool_state = MODEL_STORAGE[request.model].model.forward([tool_token], tool_state)
+
+                        char = MODEL_STORAGE[request.model].pipeline.decode([tool_token])
+                        current_gen += char
+
+                        if "</call>" in current_gen:
+                            full_tool_call = current_gen.split("</call>")[0]
+                            break
+                finally:
+                    pass

+            if full_tool_call:
+                result = ToolEngine.execute(full_tool_call)
+                current_messages.append(ChatMessage(role="assistant", content=f"<call>{full_tool_call}</call>"))
+                current_messages.append(ChatMessage(role="tool", content=result, name="system"))
+            else:
+                break
+        else:
+            break

     yield "data: [DONE]\n\n"
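To make the loop concrete, this is roughly the prompt that step 1 re-prefills after one tool round, given the role rendering in cleanMessages and the Assistant suffix appended per step (reasoning disabled; the user query and search result are hypothetical):

# Hypothetical step-1 prompt; role prefixes follow the branches in cleanMessages.
step1_prompt = """User: What is the latest RWKV release?

Assistant: <call>google_search("latest RWKV release")</call>

Tool Output (system): Google Results:
1. Releases - BlinkDL/RWKV-LM - Link: https://github.com/BlinkDL/RWKV-LM/releases

Assistant:"""

The tool message is appended with name="system", which is why the prefix renders as Tool Output (system). The loop also bounds itself: at most four rounds, and at most 200 low-temperature tokens per tool call.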
@@ -453,21 +580,16 @@ async def chat_completions(request: ChatCompletionRequest):
         if req_data.get(k) is None: req_data[k] = v
     realRequest = ChatCompletionRequest(**req_data)

-    sys_msg = ChatMessage(role="System", content=TruthProtocol.STRICT_SYSTEM_PROMPT)
-    if realRequest.messages:
-        if realRequest.messages[0].role == "System":
-            realRequest.messages[0].content = f"{TruthProtocol.STRICT_SYSTEM_PROMPT}\n\n{realRequest.messages[0].content}"
-        else:
-            realRequest.messages.insert(0, sys_msg)
-
-    last_msg = realRequest.messages[-1]
-    if last_msg.role == "user" and needs_verification(last_msg.content, raw_model):
-        ctx = search_facts(last_msg.content)
-        if ctx:
-            realRequest.messages.insert(-1, ChatMessage(role="System", content=ctx))
-
-    TruthProtocol.enforce_truth_params(realRequest)
+    enable_tools = ":online" in raw_model or realRequest.tools is not None
+
+    if enable_tools:
+        sys_msg = ChatMessage(role="System", content=ToolEngine.TOOL_SYSTEM_PROMPT)
+        if realRequest.messages:
+            if realRequest.messages[0].role == "System":
+                realRequest.messages[0].content += f"\n\n{ToolEngine.TOOL_SYSTEM_PROMPT}"
+            else:
+                realRequest.messages.insert(0, sys_msg)

     realRequest.messages = prune_context(realRequest.messages, target_model, realRequest.max_tokens or 1024)

     return StreamingResponse(chatResponseStream(realRequest, None, completionId, is_reasoning), media_type="text/event-stream")
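On the client side, either the :online model suffix or a non-None tools array enables the agent loop; tool execution stays entirely server-side, so the client just consumes an ordinary SSE stream. A minimal consumer sketch, assuming the handler is mounted at the usual OpenAI-style route (host, port, and model name are placeholders):

import json
import requests

resp = requests.post(
    "http://localhost:8000/v1/chat/completions",  # assumed route for chat_completions
    json={
        "model": "rwkv-7-world:online",           # the ':online' suffix enables tools
        "messages": [{"role": "user", "content": "What is in the news today?"}],
        "stream": True,
    },
    stream=True,
)

for line in resp.iter_lines():
    if not line.startswith(b"data: "):
        continue
    payload = line[len(b"data: "):]
    if payload == b"[DONE]":
        break
    chunk = json.loads(payload)
    print(chunk["choices"][0]["delta"].get("content") or "", end="")

The <call>...</call> text itself is never streamed: generate() treats <call> as a stop marker and yields everything before it with finish_reason "tool_start", so the client only sees the assistant text before and after each tool round.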