Ksjsjjdj committed
Commit 950252a · verified · Parent(s): 4cb338c

Update app.py

Files changed (1): app.py (+202, -80)
app.py CHANGED
@@ -3,6 +3,7 @@ import re
 import gc
 import sys
 import time
+import json
 import queue
 import random
 import asyncio
@@ -36,6 +37,8 @@ GPU_LOCK = asyncio.Lock()
 class ChatMessage(BaseModel):
     role: str = Field()
     content: str = Field()
+    name: Optional[str] = Field(None)
+    tool_call_id: Optional[str] = Field(None)

 class Logprob(BaseModel):
     token: str
@@ -76,6 +79,15 @@ class ChatCompletionChunk(BaseModel):
     choices: List[ChatCompletionChoice]
     usage: Optional[Usage]

+class ToolFunction(BaseModel):
+    name: str
+    description: str
+    parameters: Dict[str, Any]
+
+class Tool(BaseModel):
+    type: Literal["function"] = "function"
+    function: ToolFunction
+
 def remove_nested_think_tags_stack(text):
     stack = []
     result = ""
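The hunk shows only the head of remove_nested_think_tags_stack; its body sits outside the diff context. A minimal sketch of a stack-based stripper of this shape (an assumption about the body, hence the _sketch suffix; the real code in app.py may differ):

# Hedged sketch: drop (possibly nested) <think>...</think> spans, keeping only
# the text that sits outside every open tag.
def remove_nested_think_tags_stack_sketch(text: str) -> str:
    stack = []
    result = ""
    i = 0
    while i < len(text):
        if text.startswith("<think>", i):
            stack.append(i)              # entering a (possibly nested) think block
            i += len("<think>")
        elif text.startswith("</think>", i) and stack:
            stack.pop()                  # closing the innermost open block
            i += len("</think>")
        else:
            if not stack:                # emit only characters outside all blocks
                result += text[i]
            i += 1
    return result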
@@ -106,7 +118,17 @@ def cleanMessages(messages: List[ChatMessage], removeThinkingContent: bool = Fal
         role_str = message.role.strip().lower().capitalize()
         if role_str == 'Assistant' and removeThinkingContent:
             content = remove_nested_think_tags_stack(content)
-        promptStrList.append(f"{role_str}: {content}")
+
+        if message.role == "tool":
+            promptStrList.append(f"Tool Output ({message.name}): {content}")
+        elif message.role == "system":
+            promptStrList.append(f"System: {content}")
+        elif message.role == "user":
+            promptStrList.append(f"User: {content}")
+        elif message.role == "assistant":
+            promptStrList.append(f"Assistant: {content}")
+        else:
+            promptStrList.append(f"{role_str}: {content}")
     return "\n\n".join(promptStrList)

 class SamplerConfig(BaseModel):
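With the new role branches, an OpenAI-style history containing a tool result flattens into the RWKV prompt as follows (an illustrative call against the ChatMessage and cleanMessages defined in this file):

# How cleanMessages renders the new "tool" role alongside a user turn.
msgs = [
    ChatMessage(role="user", content="What is RWKV?"),
    ChatMessage(role="tool", content="RWKV is an RNN with transformer-level quality.", name="google_search"),
]
print(cleanMessages(msgs))
# User: What is RWKV?
#
# Tool Output (google_search): RWKV is an RNN with transformer-level quality.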
@@ -252,6 +274,8 @@ class ChatCompletionRequest(BaseModel):
     stream: Optional[bool] = Field(default=False)
     stop: Optional[List[str]] = Field(["\n\n"])
     stop_tokens: Optional[List[int]] = Field([0])
+    tools: Optional[List[Tool]] = Field(default=None)
+    tool_choice: Optional[Union[str, Dict]] = Field(default="auto")

     @model_validator(mode="before")
     @classmethod
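The two new fields accept OpenAI-style tool declarations, so a request can now carry a schema such as this (field values are illustrative):

# Declaring a tool with the ToolFunction/Tool models added earlier in this diff.
tool = Tool(function=ToolFunction(
    name="google_search",
    description="Search the web for fresh information.",
    parameters={"type": "object", "properties": {"query": {"type": "string"}}},
))
print(tool.model_dump_json())  # "type" defaults to "function"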
@@ -261,57 +285,105 @@ class ChatCompletionRequest(BaseModel):
             raise ValueError("messages and prompt cannot coexist.")
         return data

-class TruthProtocol:
-    STRICT_SYSTEM_PROMPT = """
-CRITICAL PROTOCOL: OBJECTIVE_REALITY_ONLY
-1. NEUTRALITY: Abandon all bias, opinions, and emotional language.
-2. FACTS ONLY: State verified facts. If data is missing, say [NO DATA]. Do not guess.
-3. NO HALLUCINATIONS: Do not invent dates, names, or events.
-4. LABELS: Tag assertions with [VERIFIED] or [UNCERTAIN].
-5. TONE: Robotic, precise, dense. No pleasantries.
+class ToolEngine:
+    TOOL_SYSTEM_PROMPT = """
+CAPABILITY: You have access to real-time tools.
+INSTRUCTION: To use a tool, output exactly: <call>tool_name("argument")</call>
+Do not describe the tool, just call it. After the System provides the result, synthesize the answer.
+
+AVAILABLE TOOLS:
+1. google_search(query): Searches Google and DuckDuckGo for real-time information.
+2. visit_page(url): Accesses a specific link, reads the text, and finds sub-links.
 """.strip()

     @staticmethod
-    def enforce_truth_params(request: ChatCompletionRequest):
-        request.temperature = 0.12
-        request.top_p = 0.1
-        request.count_penalty = 1.1
-        request.presence_penalty = 0.6
-        request.penalty_decay = 0.996
+    def google_search_request(query: str) -> str:
+        try:
+            headers = {"User-Agent": "Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36"}
+            resp = requests.get("https://www.google.com/search", params={"q": query, "gl": "us", "hl": "en"}, headers=headers, timeout=6)
+
+            if resp.status_code != 200: raise Exception("Google blocked request")
+
+            clean_text = re.sub(r'<script.*?>.*?</script>', '', resp.text, flags=re.DOTALL)
+            clean_text = re.sub(r'<style.*?>.*?</style>', '', clean_text, flags=re.DOTALL)
+
+            headings = re.findall(r'<h3.*?>(.*?)</h3>', clean_text)
+            links = re.findall(r'<a href="/url\?q=(.*?)&', clean_text)
+
+            limit = min(len(headings), len(links), 5)
+            output = "Google Results:\n"
+            for i in range(limit):
+                output += f"{i+1}. {re.sub(r'<.*?>', '', headings[i])} - Link: {links[i]}\n"
+
+            if not headings:
+                return ToolEngine.duckduckgo_fallback(query)
+
+            return output
+        except:
+            return ToolEngine.duckduckgo_fallback(query)
+
+    @staticmethod
+    def duckduckgo_fallback(query: str) -> str:
+        try:
+            if HAS_DDG:
+                res = DDGS().text(query, max_results=5)
+                return "\n".join([f"- {r['title']}: {r['body']} ({r['href']})" for r in res])
+
+            resp = requests.get("https://html.duckduckgo.com/html/", params={"q": query}, headers={"User-Agent": "Mozilla/5.0"}, timeout=5)
+            titles = re.findall(r'<a class="result__a"[^>]*>(.*?)</a>', resp.text)
+            snippets = re.findall(r'<a class="result__snippet"[^>]*>(.*?)</a>', resp.text)
+
+            limit = min(len(titles), len(snippets), 4)
+            out = "DuckDuckGo HTML Results:\n"
+            for i in range(limit):
+                t = re.sub(r'<.*?>', '', titles[i]).strip()
+                s = re.sub(r'<.*?>', '', snippets[i]).strip()
+                out += f"{i+1}. {t}: {s}\n"
+            return out
+        except Exception as e:
+            return f"Search failed: {str(e)}"
+
+    @staticmethod
+    def visit_page(url: str) -> str:
+        try:
+            headers = {"User-Agent": "Mozilla/5.0 (compatible; RWKV-Bot/1.0)"}
+            resp = requests.get(url, headers=headers, timeout=8)
+            resp.encoding = resp.apparent_encoding
+
+            text = re.sub(r'<head.*?>.*?</head>', '', resp.text, flags=re.DOTALL)
+            text = re.sub(r'<script.*?>.*?</script>', '', text, flags=re.DOTALL)
+            text = re.sub(r'<style.*?>.*?</style>', '', text, flags=re.DOTALL)
+            text = re.sub(r'<!--.*?-->', '', text, flags=re.DOTALL)
+            text = re.sub(r'<[^>]+>', ' ', text)
+            text = re.sub(r'\s+', ' ', text).strip()
+
+            links = re.findall(r'href=["\'](http[s]?://[^"\']+)["\']', resp.text)
+            unique_links = list(set(links))[:5]
+
+            content_preview = text[:3000] + ("..." if len(text) > 3000 else "")
+
+            return f"PAGE CONTENT ({url}):\n{content_preview}\n\nFOUND SUB-LINKS:\n" + "\n".join(unique_links)
+        except Exception as e:
+            return f"Error visiting page: {str(e)}"

     @staticmethod
-    def sanitise_search(query: str, results: List[dict]) -> str:
-        context = "RAW DATA STREAM (IGNORE OPINIONS, EXTRACT FACTS):\n"
-        for i, res in enumerate(results):
-            clean_body = res['body'].replace("\n", " ").strip()
-            context += f"SOURCE [{i+1}]: {clean_body} (Origin: {res['title']})\n"
-        return context
-
-search_cache = collections.OrderedDict()
-
-def search_facts(query: str) -> str:
-    if not HAS_DDG: return ""
-    if query in search_cache: return search_cache[query]
-    try:
-        ddgs = DDGS()
-        results = ddgs.text(query, max_results=4)
-        if any(x in query.lower() for x in ["verdad", "fake", "cierto", "mentira"]):
-            check = ddgs.text(f"{query} fact check verified", max_results=2)
-            if check: results.extend(check)
-        if not results: return ""
-        ctx = TruthProtocol.sanitise_search(query, results)
-        if len(search_cache) > 50: search_cache.popitem(last=False)
-        search_cache[query] = ctx
-        return ctx
-    except:
-        return ""
-
-def needs_verification(msg: str, model: str) -> bool:
-    if ":online" in model: return True
-    triggers = ["es verdad", "dato", "precio", "cuando", "quien", "noticia", "actualidad", "verify"]
-    return any(t in msg.lower() for t in triggers)
-
-app = FastAPI(title="RWKV Ultimate Server")
+    def execute(call_str: str) -> str:
+        try:
+            match = re.match(r'(\w+)\(["\'](.*?)["\']\)', call_str)
+            if not match: return "Invalid tool call syntax."
+
+            func, arg = match.groups()
+
+            if func == "google_search":
+                return ToolEngine.google_search_request(arg)
+            elif func == "visit_page":
+                return ToolEngine.visit_page(arg)
+            else:
+                return f"Unknown tool: {func}"
+        except Exception as e:
+            return f"Tool execution error: {e}"
+
+app = FastAPI(title="RWKV Ultimate Agent Server")

 app.add_middleware(
     CORSMiddleware,
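The dispatch regex in ToolEngine.execute only matches calls of the form name("single argument") with one quoted parameter (single or double quotes). A stand-alone check of how it parses:

import re

# The exact pattern used by ToolEngine.execute above.
CALL_RE = r'(\w+)\(["\'](.*?)["\']\)'

print(re.match(CALL_RE, 'google_search("rwkv v7 paper")').groups())
# ('google_search', 'rwkv v7 paper')

# Unquoted or multi-argument calls do not match, so execute() returns
# "Invalid tool call syntax." for them:
print(re.match(CALL_RE, 'visit_page(https://example.com)'))  # None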
@@ -374,6 +446,7 @@ def generate(request: ChatCompletionRequest, out, model_tokens: List[int], model
     cache_word_list = []

     stop_sequences = request.stop if request.stop else []
+    stop_sequences.append("<call>")

     for i in range(max_tokens):
         for n in occurrence: out[n] -= args.alpha_presence + occurrence[n] * args.alpha_frequency
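One subtlety in the new line: when request.stop is non-empty, stop_sequences aliases it, so the append mutates the request object and accumulates one "<call>" per agent step in the loop introduced below. A defensive copy avoids that (a suggested variant, not what the commit does):

# Suggested variant: copy first so request.stop is never mutated in place.
stop_sequences = list(request.stop or [])
stop_sequences.append("<call>")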
@@ -396,38 +469,92 @@ def generate(request: ChatCompletionRequest, out, model_tokens: List[int], model
             out_last = i + 1

         current_buffer = "".join(cache_word_list)
+
+        if "<call>" in current_buffer:
+            pre_call = current_buffer.split("<call>")[0]
+            yield {"content": pre_call, "finish_reason": "tool_start", "state": model_state}
+            del out; gc.collect(); return
+
         for s in stop_sequences:
-            if s in current_buffer:
+            if s in current_buffer and s != "<call>":
                 final_content = current_buffer.split(s)[0]
                 yield {"content": final_content, "finish_reason": "stop", "state": model_state}
                 del out; gc.collect(); return

-        if len(cache_word_list) > 1:
+        if len(cache_word_list) > 2:
             yield {"content": cache_word_list.pop(0), "finish_reason": None}

     yield {"content": "".join(cache_word_list), "finish_reason": "length"}

 async def chatResponseStream(request: ChatCompletionRequest, model_state: any, completionId: str, enableReasoning: bool):
-    clean_msg = cleanMessages(request.messages, enableReasoning)
-    prompt = f"{clean_msg}\n\nAssistant:{' <think' if enableReasoning else ''}"
-
-    async with GPU_LOCK:
-        try:
-            out, model_tokens, model_state = await runPrefill(request, prompt, [0], model_state)
-
-            yield f"data: {ChatCompletionChunk(id=completionId, created=int(time.time()), model=request.model, choices=[ChatCompletionChoice(index=0, delta=ChatCompletionMessage(role='Assistant', content=''), finish_reason=None)]).model_dump_json()}\n\n"
-
-            for chunk in generate(request, out, model_tokens, model_state, max_tokens=request.max_tokens or 4096):
-                content = chunk.get("content", "")
-                finish = chunk.get("finish_reason", None)
-                if content:
-                    yield f"data: {ChatCompletionChunk(id=completionId, created=int(time.time()), model=request.model, choices=[ChatCompletionChoice(index=0, delta=ChatCompletionMessage(content=content), finish_reason=None)]).model_dump_json()}\n\n"
-                if finish:
-                    yield f"data: {ChatCompletionChunk(id=completionId, created=int(time.time()), model=request.model, choices=[ChatCompletionChoice(index=0, delta=ChatCompletionMessage(content=''), finish_reason=finish)]).model_dump_json()}\n\n"
-                    break
-                await asyncio.sleep(0)
-        finally:
-            pass
+    current_messages = request.messages
+
+    for step in range(4):
+        clean_msg = cleanMessages(current_messages, enableReasoning)
+        prompt = f"{clean_msg}\n\nAssistant:{' <think' if enableReasoning else ''}"
+
+        tool_buffer = ""
+        tool_call_mode = False
+
+        async with GPU_LOCK:
+            try:
+                out, model_tokens, model_state = await runPrefill(request, prompt, [0], model_state)
+
+                if step == 0:
+                    yield f"data: {ChatCompletionChunk(id=completionId, created=int(time.time()), model=request.model, choices=[ChatCompletionChoice(index=0, delta=ChatCompletionMessage(role='Assistant', content=''), finish_reason=None)]).model_dump_json()}\n\n"
+
+                for chunk in generate(request, out, model_tokens, model_state, max_tokens=request.max_tokens or 4096):
+                    content = chunk.get("content", "")
+                    finish = chunk.get("finish_reason", None)
+
+                    if finish == "tool_start":
+                        tool_call_mode = True
+                        if content:
+                            yield f"data: {ChatCompletionChunk(id=completionId, created=int(time.time()), model=request.model, choices=[ChatCompletionChoice(index=0, delta=ChatCompletionMessage(content=content), finish_reason=None)]).model_dump_json()}\n\n"
+                        break
+
+                    if content:
+                        yield f"data: {ChatCompletionChunk(id=completionId, created=int(time.time()), model=request.model, choices=[ChatCompletionChoice(index=0, delta=ChatCompletionMessage(content=content), finish_reason=None)]).model_dump_json()}\n\n"
+
+                    if finish:
+                        yield f"data: {ChatCompletionChunk(id=completionId, created=int(time.time()), model=request.model, choices=[ChatCompletionChoice(index=0, delta=ChatCompletionMessage(content=''), finish_reason=finish)]).model_dump_json()}\n\n"
+                        return
+
+            finally:
+                pass
+
+        if tool_call_mode:
+            full_tool_call = ""
+
+            async with GPU_LOCK:
+                try:
+                    tool_out, tool_tokens, tool_state = await runPrefill(request, "", [0], model_state)
+                    temp_tokens = []
+
+                    current_gen = ""
+
+                    for i in range(200):
+                        args = PIPELINE_ARGS(temperature=0.1, top_p=0.1)
+                        tool_token = MODEL_STORAGE[request.model].pipeline.sample_logits(tool_out, temperature=0.1, top_p=0.1)
+                        tool_out, tool_state = MODEL_STORAGE[request.model].model.forward([tool_token], tool_state)
+
+                        char = MODEL_STORAGE[request.model].pipeline.decode([tool_token])
+                        current_gen += char
+
+                        if "</call>" in current_gen:
+                            full_tool_call = current_gen.split("</call>")[0]
+                            break
+                finally:
+                    pass
+
+            if full_tool_call:
+                result = ToolEngine.execute(full_tool_call)
+                current_messages.append(ChatMessage(role="assistant", content=f"<call>{full_tool_call}</call>"))
+                current_messages.append(ChatMessage(role="tool", content=result, name="system"))
+            else:
+                break
+        else:
+            break

     yield "data: [DONE]\n\n"
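Stripped of streaming, locking, and state plumbing, the new chatResponseStream is a bounded tool loop: generate until a <call> marker appears, execute the call, feed the result back as a tool message, and re-prompt. A distilled sketch (run_model and flatten are hypothetical stand-ins for the prefill/generate and cleanMessages machinery):

# Bounded agent loop: at most 4 hops; each hop either answers or calls a tool.
for step in range(4):
    reply = run_model(flatten(current_messages))            # hypothetical helpers
    if "<call>" not in reply:
        break                                               # plain answer: done
    call = reply.split("<call>")[1].split("</call>")[0]     # e.g. google_search("...")
    current_messages.append(ChatMessage(role="assistant", content=f"<call>{call}</call>"))
    current_messages.append(ChatMessage(role="tool", content=ToolEngine.execute(call), name="system"))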
@@ -453,21 +580,16 @@ async def chat_completions(request: ChatCompletionRequest):
         if req_data.get(k) is None: req_data[k] = v
     realRequest = ChatCompletionRequest(**req_data)

-    sys_msg = ChatMessage(role="System", content=TruthProtocol.STRICT_SYSTEM_PROMPT)
-    if realRequest.messages:
-        if realRequest.messages[0].role == "System":
-            realRequest.messages[0].content = f"{TruthProtocol.STRICT_SYSTEM_PROMPT}\n\n{realRequest.messages[0].content}"
-        else:
-            realRequest.messages.insert(0, sys_msg)
-
-    last_msg = realRequest.messages[-1]
-    if last_msg.role == "user" and needs_verification(last_msg.content, raw_model):
-        ctx = search_facts(last_msg.content)
-        if ctx:
-            realRequest.messages.insert(-1, ChatMessage(role="System", content=ctx))
-
-    TruthProtocol.enforce_truth_params(realRequest)
+    enable_tools = ":online" in raw_model or realRequest.tools is not None
+
+    if enable_tools:
+        sys_msg = ChatMessage(role="System", content=ToolEngine.TOOL_SYSTEM_PROMPT)
+        if realRequest.messages:
+            if realRequest.messages[0].role == "System":
+                realRequest.messages[0].content += f"\n\n{ToolEngine.TOOL_SYSTEM_PROMPT}"
+            else:
+                realRequest.messages.insert(0, sys_msg)

     realRequest.messages = prune_context(realRequest.messages, target_model, realRequest.max_tokens or 1024)

     return StreamingResponse(chatResponseStream(realRequest, None, completionId, is_reasoning), media_type="text/event-stream")
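Agent mode is opt-in per request: suffix the model name with ":online" or send a tools array. A client call might look like this (host, port, route, and model name are assumptions based on the OpenAI-compatible handler, not confirmed by the diff):

import requests

resp = requests.post(
    "http://localhost:8000/v1/chat/completions",  # assumed host/port/route
    json={
        "model": "rwkv-latest:online",            # ":online" switches on ToolEngine
        "messages": [{"role": "user", "content": "What is in the news today?"}],
        "stream": True,
    },
    stream=True,
)
for line in resp.iter_lines():
    if line:
        print(line.decode())                      # SSE "data: {...}" chunks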
 