Ksjsjjdj committed on
Commit 0bc9661 · verified · 1 Parent(s): 819ad30

Update app.py

Files changed (1)
  1. app.py +79 -76
app.py CHANGED
@@ -20,16 +20,18 @@ from fastapi.middleware.cors import CORSMiddleware
  from fastapi.staticfiles import StaticFiles
  from fastapi.middleware.gzip import GZipMiddleware
  from huggingface_hub import hf_hub_download
- from loguru import logger
  from snowflake import SnowflakeGenerator

  if os.environ.get("MODELSCOPE_ENVIRONMENT") == "studio":
      from modelscope import patch_hub
      patch_hub()

- os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:256"
+ os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:64"
  os.environ["RWKV_V7_ON"] = "1"
  os.environ["RWKV_JIT_ON"] = "1"
+ os.environ["RWKV_CUDA_ON"] = "1"
+
+ GPU_LOCK = asyncio.Lock()

  class ChatMessage(BaseModel):
      role: str = Field()
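This hunk tightens the CUDA allocator hint from max_split_size_mb:256 to 64, forces RWKV_CUDA_ON at import time, and introduces a module-level GPU_LOCK = asyncio.Lock() so concurrent requests take turns on the GPU. A minimal, self-contained sketch of that locking pattern (not taken from app.py; the sleep stands in for prefill and token generation):

# Minimal sketch: a module-level asyncio.Lock serializes coroutines that share
# one GPU-bound resource, the same pattern chatResponseStream uses below.
import asyncio

GPU_LOCK = asyncio.Lock()

async def fake_generation(request_id: int) -> None:
    async with GPU_LOCK:                  # only one request runs this section at a time
        print(f"request {request_id}: acquired GPU")
        await asyncio.sleep(0.1)          # stands in for prefill + generation
        print(f"request {request_id}: released GPU")

async def main() -> None:
    # Three "requests" arrive together but execute the locked section one by one.
    await asyncio.gather(*(fake_generation(i) for i in range(3)))

if __name__ == "__main__":
    asyncio.run(main())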
@@ -44,10 +46,6 @@ class LogprobsContent(BaseModel):
      content: Optional[List[Logprob]] = None
      refusal: Optional[List[Logprob]] = None

- class FunctionCall(BaseModel):
-     name: str
-     arguments: str
-
  class ChatCompletionMessage(BaseModel):
      role: Optional[str] = Field(None)
      content: Optional[str] = Field(None)
@@ -57,11 +55,6 @@ class ChatCompletionMessage(BaseModel):
  class PromptTokensDetails(BaseModel):
      cached_tokens: int

- class CompletionTokensDetails(BaseModel):
-     reasoning_tokens: int
-     accepted_prediction_tokens: int
-     rejected_prediction_tokens: int
-
  class Usage(BaseModel):
      prompt_tokens: int
      completion_tokens: int
@@ -75,14 +68,6 @@ class ChatCompletionChoice(BaseModel):
      logprobs: Optional[LogprobsContent] = None
      finish_reason: Optional[str] = Field(...)

- class ChatCompletion(BaseModel):
-     id: str = Field(...)
-     object: Literal["chat.completion"] = "chat.completion"
-     created: int = Field(...)
-     model: str
-     choices: List[ChatCompletionChoice]
-     usage: Usage
-
  class ChatCompletionChunk(BaseModel):
      id: str = Field(...)
      object: Literal["chat.completion.chunk"] = "chat.completion.chunk"
@@ -113,20 +98,6 @@ def remove_nested_think_tags_stack(text):
          i += 1
      return result

- def parse_think_response(full_response: str):
-     think_start = full_response.find("<think")
-     if think_start == -1:
-         return None, full_response.strip()
-     think_end = full_response.find("</think>")
-     if think_end == -1:
-         reasoning = full_response[think_start:].strip()
-         content = ""
-     else:
-         reasoning = full_response[think_start : think_end + 9].strip()
-         content = full_response[think_end + 9 :].strip()
-     reasoning_content = reasoning.replace("<think", "").replace("</think>", "").strip()
-     return reasoning_content, content
-
  def cleanMessages(messages: List[ChatMessage], removeThinkingContent: bool = False):
      promptStrList = []
      for message in messages:
@@ -138,35 +109,6 @@ def cleanMessages(messages: List[ChatMessage], removeThinkingContent: bool = False):
          promptStrList.append(f"{role_str}: {content}")
      return "\n\n".join(promptStrList)

- def format_bytes(size):
-     power = 2**10
-     n = 0
-     power_labels = {0: "", 1: "K", 2: "M", 3: "G", 4: "T"}
-     while size > power:
-         size /= power
-         n += 1
-     return f"{size:.4f}{power_labels[n]+'B'}"
-
- LOGGER_QUEUE = queue.Queue(5)
-
- def logger_worker():
-     while True:
-         item = LOGGER_QUEUE.get()
-         try:
-             requests.post(
-                 os.environ.get("LOG_PORT"),
-                 headers={"Content-Type": "application/json"},
-                 json=item,
-             )
-         except Exception:
-             pass
-
- if os.environ.get("LOG_PORT"):
-     threading.Thread(target=logger_worker).start()
-
- def log(item):
-     LOGGER_QUEUE.put_nowait(item)
-
  class SamplerConfig(BaseModel):
      max_tokens: int = 4096
      temperature: float = 1.0
@@ -187,6 +129,7 @@ class ModelConfig(BaseModel):
      DEFAULT_REASONING: bool = False
      REASONING: bool = False
      VOCAB: str = "rwkv_vocab_v20230424"
+     CTX_LEN: int = 4096
      DEFAULT_SAMPLER: SamplerConfig = Field(default_factory=SamplerConfig)

  class Config(BaseSettings):
@@ -200,19 +143,22 @@ class Config(BaseSettings):
              SERVICE_NAME="rwkv7-g1a4-2.9b-20251118-ctx8192",
              DOWNLOAD_MODEL_FILE_NAME="rwkv7-g1a4-2.9b-20251118-ctx8192.pth",
              DOWNLOAD_MODEL_REPO_ID="BlinkDL/rwkv7-g1",
-             REASONING=True
+             REASONING=True,
+             CTX_LEN=8192
          ),
          ModelConfig(
              SERVICE_NAME="rwkv7-g1a3-1.5b-20251015-ctx8192",
              DOWNLOAD_MODEL_FILE_NAME="rwkv7-g1a3-1.5b-20251015-ctx8192.pth",
              DOWNLOAD_MODEL_REPO_ID="BlinkDL/rwkv7-g1",
-             REASONING=True
+             REASONING=True,
+             CTX_LEN=8192
          ),
          ModelConfig(
              SERVICE_NAME="rwkv7-g1a-0.4b-20250905-ctx4096",
              DOWNLOAD_MODEL_FILE_NAME="rwkv7-g1a-0.4b-20250905-ctx4096.pth",
              DOWNLOAD_MODEL_REPO_ID="BlinkDL/rwkv7-g1",
-             REASONING=True
+             REASONING=True,
+             CTX_LEN=4096
          ),
          ModelConfig(
              SERVICE_NAME="rwkv7-g1a-0.1b-20250728-ctx4096",
@@ -220,7 +166,8 @@ class Config(BaseSettings):
              DOWNLOAD_MODEL_REPO_ID="BlinkDL/rwkv7-g1",
              REASONING=True,
              DEFAULT_CHAT=True,
-             DEFAULT_REASONING=True
+             DEFAULT_REASONING=True,
+             CTX_LEN=4096
          ),
      ]

@@ -248,7 +195,6 @@ if "cuda" in CONFIG.STRATEGY.lower() and not torch.cuda.is_available():
  if CONFIG.RWKV_CUDA_ON and "cuda" in CONFIG.STRATEGY.lower():
      from pynvml import *
      nvmlInit()
-     gpu_h = nvmlDeviceGetHandleByIndex(0)
      os.environ["RWKV_CUDA_ON"] = "1"
      torch.backends.cudnn.benchmark = True
      torch.backends.cudnn.allow_tf32 = True
@@ -365,7 +311,7 @@ def needs_verification(msg: str, model: str) -> bool:
      triggers = ["es verdad", "dato", "precio", "cuando", "quien", "noticia", "actualidad", "verify"]
      return any(t in msg.lower() for t in triggers)

- app = FastAPI(title="RWKV Zero-Bias Server")
+ app = FastAPI(title="RWKV Ultimate Server")

  app.add_middleware(
      CORSMiddleware,
@@ -382,6 +328,28 @@ async def privacy_middleware(request: Request, call_next):
      request.scope["client"] = (fake.ipv4(), request.client.port if request.client else 80)
      return await call_next(request)

+ def prune_context(messages: List[ChatMessage], model_name: str, max_gen_tokens: int):
+     storage = MODEL_STORAGE[model_name]
+     limit = storage.MODEL_CONFIG.CTX_LEN
+     pipeline = storage.pipeline
+
+     current_text = cleanMessages(messages)
+     tokens = pipeline.encode(current_text)
+
+     if len(tokens) + max_gen_tokens < limit:
+         return messages
+
+     system_msgs = [m for m in messages if m.role == "System"]
+     other_msgs = [m for m in messages if m.role != "System"]
+
+     while len(other_msgs) > 1:
+         candidate_text = cleanMessages(system_msgs + other_msgs)
+         if len(pipeline.encode(candidate_text)) + max_gen_tokens < limit:
+             break
+         other_msgs.pop(0)
+
+     return system_msgs + other_msgs
+
  async def runPrefill(request: ChatCompletionRequest, ctx: str, model_tokens: List[int], model_state):
      ctx = ctx.replace("\r\n", "\n")
      tokens = MODEL_STORAGE[request.model].pipeline.encode(ctx)
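The new prune_context keeps every System message and drops the oldest non-System turns until the encoded prompt plus max_gen_tokens fits under CTX_LEN, never dropping the final turn. A toy version of the same policy, using word counts instead of pipeline.encode:

# Toy illustration of the trimming policy in prune_context: System messages are
# always kept, and the oldest non-System turns are dropped first. Token counts
# are faked with word counts; the real code uses pipeline.encode and CTX_LEN.
def prune(messages, limit, max_gen_tokens):
    def count(msgs):
        return sum(len(m["content"].split()) for m in msgs)

    system = [m for m in messages if m["role"] == "System"]
    other = [m for m in messages if m["role"] != "System"]
    while len(other) > 1 and count(system + other) + max_gen_tokens >= limit:
        other.pop(0)                      # drop the oldest non-System turn
    return system + other

history = [
    {"role": "System", "content": "be truthful"},
    {"role": "user", "content": "one " * 50},
    {"role": "Assistant", "content": "two " * 50},
    {"role": "user", "content": "what about now?"},
]
pruned = prune(history, limit=60, max_gen_tokens=20)
print([m["role"] for m in pruned])  # ['System', 'user']: both long early turns were dropped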
@@ -404,36 +372,63 @@ def generate(request: ChatCompletionRequest, out, model_tokens: List[int], model
      out_tokens = []
      out_last = 0
      cache_word_list = []
+
+     stop_sequences = request.stop if request.stop else []
+
      for i in range(max_tokens):
          for n in occurrence: out[n] -= args.alpha_presence + occurrence[n] * args.alpha_frequency
          token = MODEL_STORAGE[request.model].pipeline.sample_logits(out, temperature=args.temperature, top_p=args.top_p)
+
          if token == 0:
              yield {"content": "".join(cache_word_list), "finish_reason": "stop", "state": model_state}
              del out; gc.collect(); return
+
          out, model_state = MODEL_STORAGE[request.model].model.forward([token], model_state)
          model_tokens.append(token)
          out_tokens.append(token)
+
          for xxx in occurrence: occurrence[xxx] *= request.penalty_decay
          occurrence[token] = 1 + (occurrence.get(token, 0))
+
          tmp = MODEL_STORAGE[request.model].pipeline.decode(out_tokens[out_last:])
          if "\ufffd" in tmp: continue
          cache_word_list.append(tmp)
          out_last = i + 1
+
+         current_buffer = "".join(cache_word_list)
+         for s in stop_sequences:
+             if s in current_buffer:
+                 final_content = current_buffer.split(s)[0]
+                 yield {"content": final_content, "finish_reason": "stop", "state": model_state}
+                 del out; gc.collect(); return
+
          if len(cache_word_list) > 1:
              yield {"content": cache_word_list.pop(0), "finish_reason": None}
+
      yield {"content": "".join(cache_word_list), "finish_reason": "length"}

  async def chatResponseStream(request: ChatCompletionRequest, model_state: any, completionId: str, enableReasoning: bool):
      clean_msg = cleanMessages(request.messages, enableReasoning)
      prompt = f"{clean_msg}\n\nAssistant:{' <think' if enableReasoning else ''}"
-     out, model_tokens, model_state = await runPrefill(request, prompt, [0], model_state)
-     yield f"data: {ChatCompletionChunk(id=completionId, created=int(time.time()), model=request.model, choices=[ChatCompletionChoice(index=0, delta=ChatCompletionMessage(role='Assistant', content=''), finish_reason=None)]).model_dump_json()}\n\n"
-     for chunk in generate(request, out, model_tokens, model_state, max_tokens=request.max_tokens or 4096):
-         content = chunk["content"]
-         if content:
-             yield f"data: {ChatCompletionChunk(id=completionId, created=int(time.time()), model=request.model, choices=[ChatCompletionChoice(index=0, delta=ChatCompletionMessage(content=content), finish_reason=None)]).model_dump_json()}\n\n"
-         if chunk.get("finish_reason"): break
-         await asyncio.sleep(0)
+
+     async with GPU_LOCK:
+         try:
+             out, model_tokens, model_state = await runPrefill(request, prompt, [0], model_state)
+
+             yield f"data: {ChatCompletionChunk(id=completionId, created=int(time.time()), model=request.model, choices=[ChatCompletionChoice(index=0, delta=ChatCompletionMessage(role='Assistant', content=''), finish_reason=None)]).model_dump_json()}\n\n"
+
+             for chunk in generate(request, out, model_tokens, model_state, max_tokens=request.max_tokens or 4096):
+                 content = chunk.get("content", "")
+                 finish = chunk.get("finish_reason", None)
+                 if content:
+                     yield f"data: {ChatCompletionChunk(id=completionId, created=int(time.time()), model=request.model, choices=[ChatCompletionChoice(index=0, delta=ChatCompletionMessage(content=content), finish_reason=None)]).model_dump_json()}\n\n"
+                 if finish:
+                     yield f"data: {ChatCompletionChunk(id=completionId, created=int(time.time()), model=request.model, choices=[ChatCompletionChoice(index=0, delta=ChatCompletionMessage(content=''), finish_reason=finish)]).model_dump_json()}\n\n"
+                     break
+                 await asyncio.sleep(0)
+         finally:
+             pass
+
      yield "data: [DONE]\n\n"

  @app.post("/api/v1/chat/completions")
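generate() now honors request.stop: decoded text is buffered, and as soon as any stop string appears the stream ends with finish_reason "stop" and only the text before the stop string is emitted; chatResponseStream also sends a final chunk carrying the finish_reason before the [DONE] marker. A standalone sketch of the buffer-and-truncate logic (not the server code itself):

# Toy version of the stop-sequence check added to generate(): accumulate decoded
# pieces, and as soon as any stop string appears in the buffer, emit only the
# text before it and finish with reason "stop".
from typing import Iterable, Iterator

def stream_with_stops(pieces: Iterable[str], stop_sequences: list[str]) -> Iterator[dict]:
    buffer = []
    for piece in pieces:
        buffer.append(piece)
        current = "".join(buffer)
        for s in stop_sequences:
            if s in current:
                yield {"content": current.split(s)[0], "finish_reason": "stop"}
                return
        if len(buffer) > 1:                       # flush all but the newest piece
            yield {"content": buffer.pop(0), "finish_reason": None}
    yield {"content": "".join(buffer), "finish_reason": "length"}

for chunk in stream_with_stops(["Hel", "lo", "\n\nUser:", " hi"], ["\n\nUser:"]):
    print(chunk)
# {'content': 'Hel', 'finish_reason': None}
# {'content': 'lo', 'finish_reason': 'stop'}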
@@ -446,26 +441,34 @@ async def chat_completions(request: ChatCompletionRequest):
      if "rwkv-latest" in model_key:
          if is_reasoning and DEFAULT_REASONING_MODEL_NAME: target_model = DEFAULT_REASONING_MODEL_NAME
          elif DEFALUT_MODEL_NAME: target_model = DEFALUT_MODEL_NAME
+
      if target_model not in MODEL_STORAGE:
          raise HTTPException(404, f"Model {target_model} not loaded.")
      request.model = target_model
+
      default_sampler = MODEL_STORAGE[target_model].MODEL_CONFIG.DEFAULT_SAMPLER
      req_data = request.model_dump()
      for k, v in default_sampler.model_dump().items():
          if req_data.get(k) is None: req_data[k] = v
      realRequest = ChatCompletionRequest(**req_data)
+
      sys_msg = ChatMessage(role="System", content=TruthProtocol.STRICT_SYSTEM_PROMPT)
      if realRequest.messages:
          if realRequest.messages[0].role == "System":
              realRequest.messages[0].content = f"{TruthProtocol.STRICT_SYSTEM_PROMPT}\n\n{realRequest.messages[0].content}"
          else:
              realRequest.messages.insert(0, sys_msg)
+
      last_msg = realRequest.messages[-1]
      if last_msg.role == "user" and needs_verification(last_msg.content, raw_model):
          ctx = search_facts(last_msg.content)
          if ctx:
              realRequest.messages.insert(-1, ChatMessage(role="System", content=ctx))
+
      TruthProtocol.enforce_truth_params(realRequest)
+
+     realRequest.messages = prune_context(realRequest.messages, target_model, realRequest.max_tokens or 1024)
+
      return StreamingResponse(chatResponseStream(realRequest, None, completionId, is_reasoning), media_type="text/event-stream")

  @app.get("/api/v1/models")
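The endpoint keeps its OpenAI-style SSE shape, so a client can stream deltas as before. A usage sketch; the host, the port, and any request fields beyond those visible in this diff (model, messages, max_tokens, stop) are assumptions:

# Hedged usage sketch: assumes a local deployment on port 8000; only the fields
# shown in this diff are used in the request body.
import json
import requests

resp = requests.post(
    "http://127.0.0.1:8000/api/v1/chat/completions",   # assumed local address
    json={
        "model": "rwkv-latest",                        # resolved to the default model
        "messages": [{"role": "user", "content": "Hello!"}],
        "max_tokens": 256,
        "stop": ["\n\nUser:"],
    },
    stream=True,
)

for line in resp.iter_lines(decode_unicode=True):
    if not line or not line.startswith("data: "):
        continue
    payload = line[len("data: "):]
    if payload == "[DONE]":
        break
    delta = json.loads(payload)["choices"][0]["delta"]
    print(delta.get("content", ""), end="", flush=True)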
 