Ksjsjjdj committed
Commit bbff189 · verified · Parent: 1fd8dc6

Update app.py

Files changed (1)
  1. app.py +514 -717
app.py CHANGED
@@ -1,717 +1,514 @@
1
- import os
2
-
3
- if os.environ.get("MODELSCOPE_ENVIRONMENT") == "studio":
4
- from modelscope import patch_hub
5
-
6
- patch_hub()
7
-
8
- os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:256"
9
-
10
-
11
- from config import CONFIG, ModelConfig
12
- from utils import (
13
- cleanMessages,
14
- parse_think_response,
15
- remove_nested_think_tags_stack,
16
- format_bytes,
17
- log,
18
- )
19
-
20
- import copy, types, gc, sys, re, time, collections, asyncio
21
- from huggingface_hub import hf_hub_download
22
- from loguru import logger
23
- from rich import print
24
-
25
- from snowflake import SnowflakeGenerator
26
-
27
- CompletionIdGenerator = SnowflakeGenerator(42, timestamp=1741101491595)
28
-
29
- from typing import List, Optional, Union, Any, Dict
30
- from pydantic import BaseModel, Field, model_validator
31
- from pydantic_settings import BaseSettings
32
-
33
-
34
- import numpy as np
35
- import torch
36
-
37
-
38
- if "cuda" in CONFIG.STRATEGY.lower() and not torch.cuda.is_available():
39
- logger.info(f"CUDA not found, fall back to cpu")
40
- CONFIG.STRATEGY = "cpu fp16"
41
-
42
- if "cuda" in CONFIG.STRATEGY.lower():
43
- from pynvml import *
44
-
45
- nvmlInit()
46
- gpu_h = nvmlDeviceGetHandleByIndex(0)
47
-
48
-
49
- def logGPUState():
50
- if "cuda" in CONFIG.STRATEGY:
51
- gpu_info = nvmlDeviceGetMemoryInfo(gpu_h)
52
- logger.info(
53
- f"[STATUS] Torch - {format_bytes(torch.cuda.memory_allocated())} - NVML - vram {format_bytes(gpu_info.total)} used {format_bytes(gpu_info.used)} free {format_bytes(gpu_info.free)}"
54
- )
55
-
56
-
57
- torch.backends.cudnn.benchmark = True
58
- torch.backends.cudnn.allow_tf32 = True
59
- torch.backends.cuda.matmul.allow_tf32 = True
60
- os.environ["RWKV_V7_ON"] = "1" # enable this for rwkv-7 models
61
- os.environ["RWKV_JIT_ON"] = "1"
62
- os.environ["RWKV_CUDA_ON"] = (
63
- "1" if CONFIG.RWKV_CUDA_ON and "cuda" in CONFIG.STRATEGY.lower() else "0"
64
- )
65
-
66
- from rwkv.model import RWKV
67
- from rwkv.utils import PIPELINE, PIPELINE_ARGS
68
-
69
- from fastapi import FastAPI, HTTPException
70
- from fastapi.responses import StreamingResponse
71
- from fastapi.middleware.cors import CORSMiddleware
72
- from fastapi.staticfiles import StaticFiles
73
- from fastapi.middleware.gzip import GZipMiddleware
74
-
75
-
76
- from api_types import (
77
- ChatMessage,
78
- ChatCompletion,
79
- ChatCompletionChunk,
80
- Usage,
81
- PromptTokensDetails,
82
- ChatCompletionChoice,
83
- ChatCompletionMessage,
84
- )
85
-
86
-
87
- class ModelStorage:
88
- MODEL_CONFIG: Optional[ModelConfig] = None
89
- model: Optional[RWKV] = None
90
- pipeline: Optional[PIPELINE] = None
91
-
92
-
93
- MODEL_STORAGE: Dict[str, ModelStorage] = {}
94
-
95
- DEFALUT_MODEL_NAME = None
96
- DEFAULT_REASONING_MODEL_NAME = None
97
-
98
- logger.info(f"STRATEGY - {CONFIG.STRATEGY}")
99
-
100
- logGPUState()
101
-
102
- for model_config in CONFIG.MODELS:
103
- logger.info(f"Load Model - {model_config.SERVICE_NAME}")
104
-
105
- if model_config.MODEL_FILE_PATH == None:
106
- model_config.MODEL_FILE_PATH = hf_hub_download(
107
- repo_id=model_config.DOWNLOAD_MODEL_REPO_ID,
108
- filename=model_config.DOWNLOAD_MODEL_FILE_NAME,
109
- local_dir=model_config.DOWNLOAD_MODEL_DIR,
110
- )
111
- logger.info(f"Load Model - Path - {model_config.MODEL_FILE_PATH}")
112
-
113
- if model_config.DEFAULT_CHAT:
114
- if DEFALUT_MODEL_NAME != None:
115
- logger.info(
116
- f"Load Model - Replace `DEFALUT_MODEL_NAME` from `{DEFALUT_MODEL_NAME}` to `{model_config.SERVICE_NAME}`"
117
- )
118
- DEFALUT_MODEL_NAME = model_config.SERVICE_NAME
119
-
120
- if model_config.DEFAULT_REASONING:
121
- if DEFAULT_REASONING_MODEL_NAME != None:
122
- logger.info(
123
- f"Load Model - Replace `DEFAULT_REASONING_MODEL_NAME` from `{DEFAULT_REASONING_MODEL_NAME}` to `{model_config.SERVICE_NAME}`"
124
- )
125
- DEFAULT_REASONING_MODEL_NAME = model_config.SERVICE_NAME
126
-
127
- logger.info(f"Load Model - Loading `{model_config.SERVICE_NAME}`")
128
- print(model_config.DEFAULT_SAMPLER)
129
-
130
- MODEL_STORAGE[model_config.SERVICE_NAME] = ModelStorage()
131
- MODEL_STORAGE[model_config.SERVICE_NAME].MODEL_CONFIG = model_config
132
- MODEL_STORAGE[model_config.SERVICE_NAME].model = RWKV(
133
- model=model_config.MODEL_FILE_PATH.replace(".pth", ""),
134
- strategy=CONFIG.STRATEGY,
135
- )
136
- MODEL_STORAGE[model_config.SERVICE_NAME].pipeline = PIPELINE(
137
- MODEL_STORAGE[model_config.SERVICE_NAME].model, model_config.VOCAB
138
- )
139
- if "cuda" in CONFIG.STRATEGY:
140
- torch.cuda.empty_cache()
141
- gc.collect()
142
- logGPUState()
143
-
144
-
145
- logger.info(f"Load Model - DEFALUT_MODEL_NAME is `{DEFALUT_MODEL_NAME}`")
146
- logger.info(
147
- f"Load Model - DEFAULT_REASONING_MODEL_NAME is `{DEFAULT_REASONING_MODEL_NAME}`"
148
- )
149
-
150
-
151
- class ChatCompletionRequest(BaseModel):
152
- model: str = Field(
153
- default="rwkv-latest",
154
- description="Add `:thinking` suffix to the model name to enable reasoning. Example: `rwkv-latest:thinking`",
155
- )
156
- messages: Optional[List[ChatMessage]] = Field(default=None)
157
- prompt: Optional[str] = Field(default=None)
158
- max_tokens: Optional[int] = Field(default=None)
159
- temperature: Optional[float] = Field(default=None)
160
- top_p: Optional[float] = Field(default=None)
161
- presence_penalty: Optional[float] = Field(default=None)
162
- count_penalty: Optional[float] = Field(default=None)
163
- penalty_decay: Optional[float] = Field(default=None)
164
- stream: Optional[bool] = Field(default=False)
165
- state_name: Optional[str] = Field(default=None)
166
- include_usage: Optional[bool] = Field(default=False)
167
- stop: Optional[list[str]] = Field(["\n\n"])
168
- stop_tokens: Optional[list[int]] = Field([0])
169
-
170
- @model_validator(mode="before")
171
- @classmethod
172
- def validate_mutual_exclusivity(cls, data: Any) -> Any:
173
- if not isinstance(data, dict):
174
- return data
175
-
176
- messages_provided = "messages" in data and data["messages"] != None
177
- prompt_provided = "prompt" in data and data["prompt"] != None
178
-
179
- if messages_provided and prompt_provided:
180
- raise ValueError("messages and prompt cannot coexist. Choose one.")
181
- if not messages_provided and not prompt_provided:
182
- raise ValueError("Either messages or prompt must be provided.")
183
- return data
184
-
185
-
186
- app = FastAPI(title="RWKV OpenAI-Compatible API")
187
-
188
- app.add_middleware(
189
- CORSMiddleware,
190
- allow_origins=["*"],
191
- allow_credentials=True,
192
- allow_methods=["*"],
193
- allow_headers=["*"],
194
- )
195
- app.add_middleware(GZipMiddleware, minimum_size=1000, compresslevel=5)
196
-
197
-
198
- async def runPrefill(
199
- request: ChatCompletionRequest, ctx: str, model_tokens: List[int], model_state
200
- ):
201
- ctx = ctx.replace("\r\n", "\n")
202
-
203
- tokens = MODEL_STORAGE[request.model].pipeline.encode(ctx)
204
- tokens = [int(x) for x in tokens]
205
- model_tokens += tokens
206
-
207
- while len(tokens) > 0:
208
- out, model_state = MODEL_STORAGE[request.model].model.forward(
209
- tokens[: CONFIG.CHUNK_LEN], model_state
210
- )
211
- tokens = tokens[CONFIG.CHUNK_LEN :]
212
- await asyncio.sleep(0)
213
-
214
- return out, model_tokens, model_state
215
-
216
-
217
- def generate(
218
- request: ChatCompletionRequest,
219
- out,
220
- model_tokens: List[int],
221
- model_state,
222
- max_tokens=2048,
223
- ):
224
- args = PIPELINE_ARGS(
225
- temperature=max(0.2, request.temperature),
226
- top_p=request.top_p,
227
- alpha_frequency=request.count_penalty,
228
- alpha_presence=request.presence_penalty,
229
- token_ban=[], # ban the generation of some tokens
230
- token_stop=[0],
231
- ) # stop generation whenever you see any token here
232
-
233
- occurrence = {}
234
- out_tokens: List[int] = []
235
- out_last = 0
236
-
237
- cache_word_list = []
238
- cache_word_len = 5
239
-
240
- for i in range(max_tokens):
241
- for n in occurrence:
242
- out[n] -= args.alpha_presence + occurrence[n] * args.alpha_frequency
243
- # out[0] -= 1e10 # disable END_OF_TEXT
244
-
245
- token = MODEL_STORAGE[request.model].pipeline.sample_logits(
246
- out, temperature=args.temperature, top_p=args.top_p
247
- )
248
-
249
- if token == 0 and token in request.stop_tokens:
250
- yield {
251
- "content": "".join(cache_word_list),
252
- "tokens": out_tokens[out_last:],
253
- "finish_reason": "stop:token:0",
254
- "state": model_state,
255
- }
256
-
257
- del out
258
- gc.collect()
259
- return
260
-
261
- out, model_state = MODEL_STORAGE[request.model].model.forward(
262
- [token], model_state
263
- )
264
-
265
- model_tokens.append(token)
266
- out_tokens.append(token)
267
-
268
- if token in request.stop_tokens:
269
- yield {
270
- "content": "".join(cache_word_list),
271
- "tokens": out_tokens[out_last:],
272
- "finish_reason": f"stop:token:{token}",
273
- "state": model_state,
274
- }
275
-
276
- del out
277
- gc.collect()
278
- return
279
-
280
- for xxx in occurrence:
281
- occurrence[xxx] *= request.penalty_decay
282
- occurrence[token] = 1 + (occurrence[token] if token in occurrence else 0)
283
-
284
- tmp: str = MODEL_STORAGE[request.model].pipeline.decode(out_tokens[out_last:])
285
-
286
- if "\ufffd" in tmp:
287
- continue
288
-
289
- cache_word_list.append(tmp)
290
- output_cache_str = "".join(cache_word_list)
291
-
292
- for stop_words in request.stop:
293
- if stop_words in output_cache_str:
294
- yield {
295
- "content": output_cache_str.replace(stop_words, ""),
296
- "tokens": out_tokens[out_last - cache_word_len :],
297
- "finish_reason": f"stop:words:{stop_words}",
298
- "state": model_state,
299
- }
300
-
301
- del out
302
- gc.collect()
303
- return
304
-
305
- if len(cache_word_list) > cache_word_len:
306
- yield {
307
- "content": cache_word_list.pop(0),
308
- "tokens": out_tokens[out_last - cache_word_len :],
309
- "finish_reason": None,
310
- }
311
-
312
- out_last = i + 1
313
-
314
- else:
315
- yield {
316
- "content": "",
317
- "tokens": [],
318
- "finish_reason": "length",
319
- }
320
-
321
-
322
- async def chatResponse(
323
- request: ChatCompletionRequest,
324
- model_state: any,
325
- completionId: str,
326
- enableReasoning: bool,
327
- ) -> ChatCompletion:
328
- createTimestamp = time.time()
329
-
330
- prompt = (
331
- f"{cleanMessages(request.messages)}\n\nAssistant:{' <think' if enableReasoning else ''}"
332
- if request.prompt == None
333
- else request.prompt.strip()
334
- )
335
- logger.info(f"[REQ] {completionId} - prompt - {prompt}")
336
-
337
- out, model_tokens, model_state = await runPrefill(request, prompt, [0], model_state)
338
-
339
- prefillTime = time.time()
340
- promptTokenCount = len(model_tokens)
341
-
342
- fullResponse = " <think" if enableReasoning else ""
343
- completionTokenCount = 0
344
- finishReason = None
345
-
346
- for chunk in generate(
347
- request,
348
- out,
349
- model_tokens,
350
- model_state,
351
- max_tokens=(
352
- 64000
353
- if "max_tokens" not in request.model_fields_set and enableReasoning
354
- else request.max_tokens
355
- ),
356
- ):
357
- fullResponse += chunk["content"]
358
- completionTokenCount += 1
359
-
360
- if chunk["finish_reason"]:
361
- finishReason = chunk["finish_reason"]
362
- await asyncio.sleep(0)
363
-
364
- genenrateTime = time.time()
365
-
366
- responseLog = {
367
- "content": fullResponse,
368
- "finish": finishReason,
369
- "prefill_len": promptTokenCount,
370
- "prefill_tps": round(promptTokenCount / (prefillTime - createTimestamp), 2),
371
- "gen_len": completionTokenCount,
372
- "gen_tps": round(completionTokenCount / (genenrateTime - prefillTime), 2),
373
- }
374
- logger.info(f"[RES] {completionId} - {responseLog}")
375
-
376
- reasoning_content, content = parse_think_response(fullResponse)
377
-
378
- response = ChatCompletion(
379
- id=completionId,
380
- created=int(createTimestamp),
381
- model=request.model,
382
- usage=Usage(
383
- prompt_tokens=promptTokenCount,
384
- completion_tokens=completionTokenCount,
385
- total_tokens=promptTokenCount + completionTokenCount,
386
- prompt_tokens_details={"cached_tokens": 0},
387
- ),
388
- choices=[
389
- ChatCompletionChoice(
390
- index=0,
391
- message=ChatCompletionMessage(
392
- role="Assistant",
393
- content=content,
394
- reasoning_content=reasoning_content if reasoning_content else None,
395
- ),
396
- logprobs=None,
397
- finish_reason=finishReason,
398
- )
399
- ],
400
- )
401
-
402
- return response
403
-
404
-
405
- async def chatResponseStream(
406
- request: ChatCompletionRequest,
407
- model_state: any,
408
- completionId: str,
409
- enableReasoning: bool,
410
- ):
411
- createTimestamp = int(time.time())
412
-
413
- prompt = (
414
- f"{cleanMessages(request.messages,enableReasoning)}\n\nAssistant:{' <think' if enableReasoning else ''}"
415
- if request.prompt == None
416
- else request.prompt.strip()
417
- )
418
-
419
- logger.info(f"[REQ] {completionId} - context\n```{prompt}```")
420
-
421
- out, model_tokens, model_state = await runPrefill(request, prompt, [0], model_state)
422
-
423
- prefillTime = time.time()
424
- promptTokenCount = len(model_tokens)
425
-
426
- completionTokenCount = 0
427
- finishReason = None
428
-
429
- response = ChatCompletionChunk(
430
- id=completionId,
431
- created=createTimestamp,
432
- model=request.model,
433
- usage=(
434
- Usage(
435
- prompt_tokens=promptTokenCount,
436
- completion_tokens=completionTokenCount,
437
- total_tokens=promptTokenCount + completionTokenCount,
438
- prompt_tokens_details={"cached_tokens": 0},
439
- )
440
- if request.include_usage
441
- else None
442
- ),
443
- choices=[
444
- ChatCompletionChoice(
445
- index=0,
446
- delta=ChatCompletionMessage(
447
- role="Assistant",
448
- content="",
449
- reasoning_content="" if enableReasoning else None,
450
- ),
451
- logprobs=None,
452
- finish_reason=finishReason,
453
- )
454
- ],
455
- )
456
- yield f"data: {response.model_dump_json()}\n\n"
457
-
458
- buffer = []
459
-
460
- if enableReasoning:
461
- buffer.append("<think")
462
-
463
- streamConfig = {
464
- "isChecking": False, # check whether is <think> tag
465
- "fullTextCursor": 0,
466
- "in_think": False,
467
- "cacheStr": "",
468
- }
469
-
470
- for chunk in generate(
471
- request,
472
- out,
473
- model_tokens,
474
- model_state,
475
- max_tokens=(
476
- 64000
477
- if "max_tokens" not in request.model_fields_set and enableReasoning
478
- else request.max_tokens
479
- ),
480
- ):
481
- completionTokenCount += 1
482
-
483
- chunkContent: str = chunk["content"]
484
- buffer.append(chunkContent)
485
-
486
- fullText = "".join(buffer)
487
-
488
- if chunk["finish_reason"]:
489
- finishReason = chunk["finish_reason"]
490
-
491
- response = ChatCompletionChunk(
492
- id=completionId,
493
- created=createTimestamp,
494
- model=request.model,
495
- usage=(
496
- Usage(
497
- prompt_tokens=promptTokenCount,
498
- completion_tokens=completionTokenCount,
499
- total_tokens=promptTokenCount + completionTokenCount,
500
- prompt_tokens_details={"cached_tokens": 0},
501
- )
502
- if request.include_usage
503
- else None
504
- ),
505
- choices=[
506
- ChatCompletionChoice(
507
- index=0,
508
- delta=ChatCompletionMessage(
509
- content=None, reasoning_content=None
510
- ),
511
- logprobs=None,
512
- finish_reason=finishReason,
513
- )
514
- ],
515
- )
516
-
517
- markStart = fullText.find("<", streamConfig["fullTextCursor"])
518
- if not streamConfig["isChecking"] and markStart != -1:
519
- streamConfig["isChecking"] = True
520
-
521
- if streamConfig["in_think"]:
522
- response.choices[0].delta.reasoning_content = fullText[
523
- streamConfig["fullTextCursor"] : markStart
524
- ]
525
- else:
526
- response.choices[0].delta.content = fullText[
527
- streamConfig["fullTextCursor"] : markStart
528
- ]
529
-
530
- streamConfig["cacheStr"] = ""
531
- streamConfig["fullTextCursor"] = markStart
532
-
533
- if streamConfig["isChecking"]:
534
- streamConfig["cacheStr"] = fullText[streamConfig["fullTextCursor"] :]
535
- else:
536
- if streamConfig["in_think"]:
537
- response.choices[0].delta.reasoning_content = chunkContent
538
- else:
539
- response.choices[0].delta.content = chunkContent
540
- streamConfig["fullTextCursor"] = len(fullText)
541
-
542
- markEnd = fullText.find(">", streamConfig["fullTextCursor"])
543
- if (streamConfig["isChecking"] and markEnd != -1) or finishReason != None:
544
- streamConfig["isChecking"] = False
545
-
546
- if (
547
- not streamConfig["in_think"]
548
- and streamConfig["cacheStr"].find("<think>") != -1
549
- ):
550
- streamConfig["in_think"] = True
551
-
552
- response.choices[0].delta.reasoning_content = (
553
- response.choices[0].delta.reasoning_content
554
- if response.choices[0].delta.reasoning_content != None
555
- else "" + streamConfig["cacheStr"].replace("<think>", "")
556
- )
557
-
558
- elif (
559
- streamConfig["in_think"]
560
- and streamConfig["cacheStr"].find("</think>") != -1
561
- ):
562
- streamConfig["in_think"] = False
563
-
564
- response.choices[0].delta.content = (
565
- response.choices[0].delta.content
566
- if response.choices[0].delta.content != None
567
- else "" + streamConfig["cacheStr"].replace("</think>", "")
568
- )
569
- else:
570
- if streamConfig["in_think"]:
571
- response.choices[0].delta.reasoning_content = (
572
- response.choices[0].delta.reasoning_content
573
- if response.choices[0].delta.reasoning_content != None
574
- else "" + streamConfig["cacheStr"]
575
- )
576
- else:
577
- response.choices[0].delta.content = (
578
- response.choices[0].delta.content
579
- if response.choices[0].delta.content != None
580
- else "" + streamConfig["cacheStr"]
581
- )
582
- streamConfig["fullTextCursor"] = len(fullText)
583
-
584
- if (
585
- response.choices[0].delta.content != None
586
- or response.choices[0].delta.reasoning_content != None
587
- ):
588
- yield f"data: {response.model_dump_json()}\n\n"
589
-
590
- await asyncio.sleep(0)
591
-
592
- del streamConfig
593
- else:
594
- for chunk in generate(request, out, model_tokens, model_state):
595
- completionTokenCount += 1
596
- buffer.append(chunk["content"])
597
-
598
- if chunk["finish_reason"]:
599
- finishReason = chunk["finish_reason"]
600
-
601
- response = ChatCompletionChunk(
602
- id=completionId,
603
- created=createTimestamp,
604
- model=request.model,
605
- usage=(
606
- Usage(
607
- prompt_tokens=promptTokenCount,
608
- completion_tokens=completionTokenCount,
609
- total_tokens=promptTokenCount + completionTokenCount,
610
- prompt_tokens_details={"cached_tokens": 0},
611
- )
612
- if request.include_usage
613
- else None
614
- ),
615
- choices=[
616
- ChatCompletionChoice(
617
- index=0,
618
- delta=ChatCompletionMessage(content=chunk["content"]),
619
- logprobs=None,
620
- finish_reason=finishReason,
621
- )
622
- ],
623
- )
624
-
625
- yield f"data: {response.model_dump_json()}\n\n"
626
- await asyncio.sleep(0)
627
-
628
- genenrateTime = time.time()
629
-
630
- responseLog = {
631
- "content": "".join(buffer),
632
- "finish": finishReason,
633
- "prefill_len": promptTokenCount,
634
- "prefill_tps": round(promptTokenCount / (prefillTime - createTimestamp), 2),
635
- "gen_len": completionTokenCount,
636
- "gen_tps": round(completionTokenCount / (genenrateTime - prefillTime), 2),
637
- }
638
- logger.info(f"[RES] {completionId} - {responseLog}")
639
- request.messages.append(
640
- ChatMessage(role="Assistant", content=responseLog["content"])
641
- )
642
- log(
643
- {
644
- **request.model_dump(),
645
- **responseLog,
646
- "completionId": completionId,
647
- "machineLabel": os.environ.get("MACHINE_LABEL"),
648
- }
649
- )
650
-
651
- del buffer
652
-
653
- yield "data: [DONE]\n\n"
654
-
655
-
656
- @app.post("/api/v1/chat/completions")
657
- async def chat_completions(request: ChatCompletionRequest):
658
- completionId = str(next(CompletionIdGenerator))
659
- logger.info(f"[REQ] {completionId} - {request.model_dump()}")
660
-
661
- modelName = request.model.split(":")[0]
662
- enableReasoning = ":thinking" in request.model
663
-
664
- if "rwkv-latest" in request.model:
665
- if enableReasoning:
666
- if DEFAULT_REASONING_MODEL_NAME == None:
667
- raise HTTPException(404, "DEFAULT_REASONING_MODEL_NAME not set")
668
- defaultSamplerConfig = MODEL_STORAGE[
669
- DEFAULT_REASONING_MODEL_NAME
670
- ].MODEL_CONFIG.DEFAULT_SAMPLER
671
- request.model = DEFAULT_REASONING_MODEL_NAME
672
-
673
- else:
674
- if DEFALUT_MODEL_NAME == None:
675
- raise HTTPException(404, "DEFALUT_MODEL_NAME not set")
676
- defaultSamplerConfig = MODEL_STORAGE[
677
- DEFALUT_MODEL_NAME
678
- ].MODEL_CONFIG.DEFAULT_SAMPLER
679
- request.model = DEFALUT_MODEL_NAME
680
-
681
- elif modelName in MODEL_STORAGE:
682
- defaultSamplerConfig = MODEL_STORAGE[modelName].MODEL_CONFIG.DEFAULT_SAMPLER
683
- request.model = modelName
684
- else:
685
- raise f"Can not find `{modelName}`"
686
-
687
- async def chatResponseStreamDisconnect():
688
- logGPUState()
689
-
690
- model_state = None
691
- request_dict = request.model_dump()
692
-
693
- for k, v in defaultSamplerConfig.model_dump().items():
694
- if request_dict[k] == None:
695
- request_dict[k] = v
696
- realRequest = ChatCompletionRequest(**request_dict)
697
-
698
- logger.info(f"[REQ] {completionId} - Real - {request.model_dump()}")
699
-
700
- if request.stream:
701
- r = StreamingResponse(
702
- chatResponseStream(realRequest, model_state, completionId, enableReasoning),
703
- media_type="text/event-stream",
704
- background=chatResponseStreamDisconnect,
705
- )
706
- else:
707
- r = await chatResponse(realRequest, model_state, completionId, enableReasoning)
708
-
709
- return r
710
-
711
-
712
- app.mount("/", StaticFiles(directory="dist-frontend", html=True), name="static")
713
-
714
- if __name__ == "__main__":
715
- import uvicorn
716
-
717
- uvicorn.run(app, host=CONFIG.HOST, port=CONFIG.PORT)
 
1
+ import os
2
+ import copy
3
+ import types
4
+ import gc
5
+ import sys
6
+ import re
7
+ import time
8
+ import collections
9
+ import asyncio
10
+ import random
11
+ from typing import List, Optional, Union, Any, Dict
12
+
13
+ # --- THIRD-PARTY LIBRARIES ---
14
+ if os.environ.get("MODELSCOPE_ENVIRONMENT") == "studio":
15
+ from modelscope import patch_hub
16
+ patch_hub()
17
+
18
+ # PyTorch allocator configuration to avoid fragmentation
19
+ os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:256"
20
+
21
+ # RWKV configuration
22
+ os.environ["RWKV_V7_ON"] = "1"
23
+ os.environ["RWKV_JIT_ON"] = "1"
24
+
25
+ # Project imports
26
+ from config import CONFIG, ModelConfig
27
+ from utils import (
28
+ cleanMessages,
29
+ parse_think_response,
30
+ remove_nested_think_tags_stack,
31
+ format_bytes,
32
+ log,
33
+ )
34
+
35
+ from huggingface_hub import hf_hub_download
36
+ from loguru import logger
37
+ from rich import print
38
+ from snowflake import SnowflakeGenerator
39
+ import numpy as np
40
+ import torch
41
+ import requests
42
+
43
+ # --- NEW LIBRARIES (Faker and Search) ---
44
+ try:
45
+ from duckduckgo_search import DDGS
46
+ HAS_DDG = True
47
+ except ImportError:
48
+ logger.warning("duckduckgo_search not found. Web search disabled.")
49
+ HAS_DDG = False
50
+
51
+ try:
52
+ from faker import Faker
53
+ fake = Faker()
54
+ HAS_FAKER = True
55
+ except ImportError:
56
+ logger.warning("Faker not found. IP masking disabled. Install with `pip install faker`")
57
+ HAS_FAKER = False
58
+
59
+ # FastAPI Imports
60
+ from fastapi import FastAPI, HTTPException, Request, Response
61
+ from fastapi.responses import StreamingResponse
62
+ from fastapi.middleware.cors import CORSMiddleware
63
+ from fastapi.staticfiles import StaticFiles
64
+ from fastapi.middleware.gzip import GZipMiddleware
65
+ from pydantic import BaseModel, Field, model_validator
66
+
67
+ # --- GENERATOR AND MODEL INITIALIZATION ---
68
+
69
+ CompletionIdGenerator = SnowflakeGenerator(42, timestamp=1741101491595)
70
+
71
+ # Strategy configuration (CUDA/CPU)
72
+ if "cuda" in CONFIG.STRATEGY.lower() and not torch.cuda.is_available():
73
+ logger.info(f"CUDA not found, fall back to cpu")
74
+ CONFIG.STRATEGY = "cpu fp16"
75
+
76
+ if "cuda" in CONFIG.STRATEGY.lower():
77
+ from pynvml import *
78
+ nvmlInit()
79
+ gpu_h = nvmlDeviceGetHandleByIndex(0)
80
+ # Enable CUDA optimizations for RWKV
81
+ torch.backends.cudnn.benchmark = True
82
+ torch.backends.cudnn.allow_tf32 = True
83
+ torch.backends.cuda.matmul.allow_tf32 = True
84
+ os.environ["RWKV_CUDA_ON"] = "1" if CONFIG.RWKV_CUDA_ON else "0"
85
+ else:
86
+ os.environ["RWKV_CUDA_ON"] = "0"
87
+
88
+ from rwkv.model import RWKV
89
+ from rwkv.utils import PIPELINE, PIPELINE_ARGS
90
+ from api_types import (
91
+ ChatMessage, ChatCompletion, ChatCompletionChunk, Usage,
92
+ ChatCompletionChoice, ChatCompletionMessage
93
+ )
94
+
95
+ # --- GPU STATE MANAGEMENT ---
96
+ def logGPUState():
97
+ if "cuda" in CONFIG.STRATEGY:
98
+ gpu_info = nvmlDeviceGetMemoryInfo(gpu_h)
99
+ logger.info(
100
+ f"[STATUS] Torch - {format_bytes(torch.cuda.memory_allocated())} - "
101
+ f"NVML - vram {format_bytes(gpu_info.total)} used {format_bytes(gpu_info.used)} free {format_bytes(gpu_info.free)}"
102
+ )
103
+
104
+ # --- MODEL LOADING ---
105
+ class ModelStorage:
106
+ MODEL_CONFIG: Optional[ModelConfig] = None
107
+ model: Optional[RWKV] = None
108
+ pipeline: Optional[PIPELINE] = None
109
+
110
+ MODEL_STORAGE: Dict[str, ModelStorage] = {}
111
+ DEFALUT_MODEL_NAME = None
112
+ DEFAULT_REASONING_MODEL_NAME = None
113
+
114
+ logger.info(f"STRATEGY - {CONFIG.STRATEGY}")
115
+ logGPUState()
116
+
117
+ for model_config in CONFIG.MODELS:
118
+ logger.info(f"Load Model - {model_config.SERVICE_NAME}")
119
+
120
+ if model_config.MODEL_FILE_PATH is None:
121
+ model_config.MODEL_FILE_PATH = hf_hub_download(
122
+ repo_id=model_config.DOWNLOAD_MODEL_REPO_ID,
123
+ filename=model_config.DOWNLOAD_MODEL_FILE_NAME,
124
+ local_dir=model_config.DOWNLOAD_MODEL_DIR,
125
+ )
126
+
127
+ # Default model handling
128
+ if model_config.DEFAULT_CHAT:
129
+ DEFALUT_MODEL_NAME = model_config.SERVICE_NAME
130
+ if model_config.DEFAULT_REASONING:
131
+ DEFAULT_REASONING_MODEL_NAME = model_config.SERVICE_NAME
132
+
133
+ # Load the model weights
134
+ MODEL_STORAGE[model_config.SERVICE_NAME] = ModelStorage()
135
+ MODEL_STORAGE[model_config.SERVICE_NAME].MODEL_CONFIG = model_config
136
+ MODEL_STORAGE[model_config.SERVICE_NAME].model = RWKV(
137
+ model=model_config.MODEL_FILE_PATH.replace(".pth", ""),
138
+ strategy=CONFIG.STRATEGY,
139
+ )
140
+ MODEL_STORAGE[model_config.SERVICE_NAME].pipeline = PIPELINE(
141
+ MODEL_STORAGE[model_config.SERVICE_NAME].model, model_config.VOCAB
142
+ )
143
+
144
+ # Free VRAM after loading
145
+ if "cuda" in CONFIG.STRATEGY:
146
+ torch.cuda.empty_cache()
147
+ gc.collect()
148
+
149
+ logGPUState()
150
+
151
+ # --- DATA CLASSES ---
152
+ class ChatCompletionRequest(BaseModel):
153
+ model: str = Field(
154
+ default="rwkv-latest",
155
+ description="Suffixes: `:thinking` for reasoning, `:online` for web search.",
156
+ )
157
+ messages: Optional[List[ChatMessage]] = Field(default=None)
158
+ prompt: Optional[str] = Field(default=None)
159
+ max_tokens: Optional[int] = Field(default=None)
160
+ temperature: Optional[float] = Field(default=None)
161
+ top_p: Optional[float] = Field(default=None)
162
+ presence_penalty: Optional[float] = Field(default=None)
163
+ count_penalty: Optional[float] = Field(default=None)
164
+ penalty_decay: Optional[float] = Field(default=None)
165
+ stream: Optional[bool] = Field(default=False)
166
+ state_name: Optional[str] = Field(default=None)
167
+ include_usage: Optional[bool] = Field(default=False)
168
+ stop: Optional[list[str]] = Field(["\n\n"])
169
+ stop_tokens: Optional[list[int]] = Field([0])
170
+
171
+ @model_validator(mode="before")
172
+ @classmethod
173
+ def validate_mutual_exclusivity(cls, data: Any) -> Any:
174
+ if not isinstance(data, dict): return data
175
+ if "messages" in data and "prompt" in data and data["messages"] and data["prompt"]:
176
+ raise ValueError("messages and prompt cannot coexist.")
177
+ return data
178
+
179
+ # --- APP SETUP & ADVANCED MIDDLEWARE ---
180
+ app = FastAPI(title="RWKV Advanced Server")
181
+
182
+ app.add_middleware(
183
+ CORSMiddleware,
184
+ allow_origins=["*"],
185
+ allow_credentials=True,
186
+ allow_methods=["*"],
187
+ allow_headers=["*"],
188
+ )
189
+ app.add_middleware(GZipMiddleware, minimum_size=1000, compresslevel=5)
190
+
191
+ # --- 1. MIDDLEWARE: FAKER IP MASKING & SECURITY ---
192
+ @app.middleware("http")
193
+ async def security_and_privacy_middleware(request: Request, call_next):
194
+ # a. IP masking with Faker
195
+ original_ip = request.client.host if request.client else "unknown"
196
+ fake_ip = fake.ipv4() if HAS_FAKER else "127.0.0.1"
197
+
198
+ # Overwrite the client IP in the scope so that subsequent logs and logic see the fake one
199
+ # Esto "oculta" la IPv4 real de cualquier logger subsiguiente
200
+ if HAS_FAKER:
201
+ # Modifying the client object in place is awkward in Starlette,
202
+ # but we can inject a header or modify the scope.
203
+ # Here we make the request appear to come from the fake IP.
204
+ request.scope["client"] = (fake_ip, request.client.port if request.client else 80)
205
+
206
+ # b. Simple rate limiting (anti-abuse)
207
+ # Note: with Faker enabled, rate limiting by real IP becomes useless unless
208
+ # it happens BEFORE the scope is modified. (Here it is only conceptual.)
209
+ # For this example everything is allowed, but the obfuscated IP is logged.
210
+
211
+ logger.info(f"[PRIVACY] Masked Real IP {original_ip} -> Fake IP {fake_ip}")
212
+
213
+ response = await call_next(request)
214
+
215
+ # c. Security Headers
216
+ response.headers["X-Content-Type-Options"] = "nosniff"
217
+ response.headers["X-Frame-Options"] = "DENY"
218
+
219
+ return response
220
+
221
+ # --- 2. ADVANCED MECHANISM: SEARCH CACHE (LRU) ---
222
+ # Avoids repeating the same query to DDG
223
+ search_cache = collections.OrderedDict()
224
+ SEARCH_CACHE_TTL = 600 # 10 minutes
225
+ SEARCH_CACHE_SIZE = 100
226
+
227
+ def get_cached_search(query: str):
228
+ current_time = time.time()
229
+ if query in search_cache:
230
+ timestamp, result = search_cache[query]
231
+ if current_time - timestamp < SEARCH_CACHE_TTL:
232
+ logger.info(f"[CACHE] Hit for query: {query}")
233
+ search_cache.move_to_end(query)
234
+ return result
235
+ return None
236
+
237
+ def set_cached_search(query: str, result: str):
238
+ if len(search_cache) >= SEARCH_CACHE_SIZE:
239
+ search_cache.popitem(last=False)
240
+ search_cache[query] = (time.time(), result)
241
+
242
+ def search_web_and_get_context(query: str, max_results: int = 4) -> str:
243
+ if not HAS_DDG: return ""
244
+
245
+ # Check Cache
246
+ cached = get_cached_search(query)
247
+ if cached: return cached
248
+
249
+ logger.info(f"[SEARCH] Searching external web for: {query}")
250
+ try:
251
+ results = DDGS().text(query, max_results=max_results)
252
+ if not results:
253
+ return "Web search executed but returned no results."
254
+
255
+ context_str = "Web Search Results (Real-time data):\n\n"
256
+ for i, res in enumerate(results):
257
+ context_str += f"Result {i+1} [{res['title']}]: {res['body']} (Source: {res['href']})\n\n"
258
+
259
+ context_str += "Instructions: Answer based strictly on the search results above. If the answer is not there, state it."
260
+
261
+ # Save to Cache
262
+ set_cached_search(query, context_str)
263
+ return context_str
264
+ except Exception as e:
265
+ logger.error(f"[SEARCH] Failed: {e}")
266
+ return ""
267
+
268
+ def should_trigger_search(last_message: str, model_name: str) -> bool:
269
+ if ":online" in model_name: return True
270
+ keywords = ["busca", "search", "google", "internet", "clima", "weather", "news", "noticias", "precio", "price", "who is", "quien es"]
271
+ return any(k in last_message.lower() for k in keywords)
272
+
273
+ # --- RWKV CORE LOGIC (PREFILL & GENERATE) ---
274
+
275
+ async def runPrefill(request: ChatCompletionRequest, ctx: str, model_tokens: List[int], model_state):
276
+ ctx = ctx.replace("\r\n", "\n")
277
+ tokens = MODEL_STORAGE[request.model].pipeline.encode(ctx)
278
+ tokens = [int(x) for x in tokens]
279
+ model_tokens += tokens
280
+
281
+ while len(tokens) > 0:
282
+ out, model_state = MODEL_STORAGE[request.model].model.forward(
283
+ tokens[: CONFIG.CHUNK_LEN], model_state
284
+ )
285
+ tokens = tokens[CONFIG.CHUNK_LEN :]
286
+ await asyncio.sleep(0)
287
+ return out, model_tokens, model_state
288
+
289
+ def generate(request: ChatCompletionRequest, out, model_tokens: List[int], model_state, max_tokens=2048):
290
+ args = PIPELINE_ARGS(
291
+ temperature=max(0.2, request.temperature),
292
+ top_p=request.top_p,
293
+ alpha_frequency=request.count_penalty,
294
+ alpha_presence=request.presence_penalty,
295
+ token_ban=[], token_stop=[0]
296
+ )
297
+
298
+ occurrence = {}
299
+ out_tokens: List[int] = []
300
+ out_last = 0
301
+ cache_word_list = []
302
+ cache_word_len = 5
303
+
304
+ for i in range(max_tokens):
305
+ for n in occurrence:
306
+ out[n] -= args.alpha_presence + occurrence[n] * args.alpha_frequency
307
+
308
+ token = MODEL_STORAGE[request.model].pipeline.sample_logits(
309
+ out, temperature=args.temperature, top_p=args.top_p
310
+ )
311
+
312
+ # Handling Stop Tokens
313
+ if token == 0 and token in request.stop_tokens:
314
+ yield {"content": "".join(cache_word_list), "tokens": out_tokens[out_last:], "finish_reason": "stop:token:0", "state": model_state}
315
+ del out; gc.collect(); return
316
+
317
+ out, model_state = MODEL_STORAGE[request.model].model.forward([token], model_state)
318
+ model_tokens.append(token)
319
+ out_tokens.append(token)
320
+
321
+ # Penalty Decay
322
+ for xxx in occurrence: occurrence[xxx] *= request.penalty_decay
323
+ occurrence[token] = 1 + (occurrence.get(token, 0))
324
+
325
+ # Decoding
326
+ tmp: str = MODEL_STORAGE[request.model].pipeline.decode(out_tokens[out_last:])
327
+ if "\ufffd" in tmp: continue
328
+
329
+ cache_word_list.append(tmp)
330
+ output_cache_str = "".join(cache_word_list)
331
+
332
+ # Handling Stop Words
333
+ for stop_words in request.stop:
334
+ if stop_words in output_cache_str:
335
+ yield {
336
+ "content": output_cache_str.replace(stop_words, ""),
337
+ "tokens": out_tokens[out_last - cache_word_len :],
338
+ "finish_reason": f"stop:words:{stop_words}",
339
+ "state": model_state
340
+ }
341
+ del out; gc.collect(); return
342
+
343
+ if len(cache_word_list) > cache_word_len:
344
+ yield {"content": cache_word_list.pop(0), "tokens": out_tokens[out_last - cache_word_len :], "finish_reason": None}
345
+ out_last = i + 1
346
+ else:
347
+ yield {"content": "", "tokens": [], "finish_reason": "length"}
348
+
349
+ # --- ENDPOINT HANDLERS ---
350
+
351
+ async def chatResponse(request: ChatCompletionRequest, model_state: any, completionId: str, enableReasoning: bool) -> ChatCompletion:
352
+ createTimestamp = time.time()
353
+ prompt = f"{cleanMessages(request.messages)}\n\nAssistant:{' <think' if enableReasoning else ''}" if not request.prompt else request.prompt.strip()
354
+
355
+ out, model_tokens, model_state = await runPrefill(request, prompt, [0], model_state)
356
+
357
+ prefillTime = time.time()
358
+ promptTokenCount = len(model_tokens)
359
+ fullResponse = " <think" if enableReasoning else ""
360
+ finishReason = None
361
+
362
+ for chunk in generate(request, out, model_tokens, model_state, max_tokens=(64000 if enableReasoning else request.max_tokens)):
363
+ fullResponse += chunk["content"]
364
+ if chunk["finish_reason"]: finishReason = chunk["finish_reason"]
365
+ await asyncio.sleep(0)
366
+
367
+ genTime = time.time()
368
+ reasoning_content, content = parse_think_response(fullResponse)
369
+
370
+ responseLog = {
371
+ "id": completionId, "prefill_tps": round(promptTokenCount / (prefillTime - createTimestamp), 2),
372
+ "gen_tps": round(len(fullResponse) / (genTime - prefillTime), 2)
373
+ }
374
+ logger.info(f"[RES-SYNC] {responseLog}")
375
+
376
+ return ChatCompletion(
377
+ id=completionId, created=int(createTimestamp), model=request.model,
378
+ usage=Usage(prompt_tokens=promptTokenCount, completion_tokens=len(fullResponse), total_tokens=promptTokenCount+len(fullResponse)),
379
+ choices=[ChatCompletionChoice(index=0, message=ChatCompletionMessage(role="Assistant", content=content, reasoning_content=reasoning_content), finish_reason=finishReason)]
380
+ )
381
+
382
+ async def chatResponseStream(request: ChatCompletionRequest, model_state: any, completionId: str, enableReasoning: bool):
383
+ createTimestamp = int(time.time())
384
+ prompt = f"{cleanMessages(request.messages, enableReasoning)}\n\nAssistant:{' <think' if enableReasoning else ''}" if not request.prompt else request.prompt.strip()
385
+
386
+ out, model_tokens, model_state = await runPrefill(request, prompt, [0], model_state)
387
+ promptTokenCount = len(model_tokens)
388
+ completionTokenCount = 0
389
+ finishReason = None
390
+
391
+ # Send an initial empty chunk
392
+ yield f"data: {ChatCompletionChunk(id=completionId, created=createTimestamp, model=request.model, choices=[ChatCompletionChoice(index=0, delta=ChatCompletionMessage(role='Assistant', content=''), finish_reason=None)]).model_dump_json()}\n\n"
393
+
394
+ buffer = ["<think"] if enableReasoning else []
395
+ streamConfig = {"isChecking": False, "fullTextCursor": 0, "in_think": False, "cacheStr": ""}
396
+
397
+ for chunk in generate(request, out, model_tokens, model_state, max_tokens=(64000 if enableReasoning else request.max_tokens)):
398
+ completionTokenCount += 1
399
+ chunkContent = chunk["content"]
400
+ finishReason = chunk["finish_reason"]
401
+
402
+ if enableReasoning:
403
+ buffer.append(chunkContent)
404
+ fullText = "".join(buffer)
405
+
406
+ # Complex streaming logic to separate <think> from the content
407
+ # (Simplified to keep the file manageable; logic identical to the original version)
408
+ markStart = fullText.find("<", streamConfig["fullTextCursor"])
409
+ if not streamConfig["isChecking"] and markStart != -1:
410
+ streamConfig["isChecking"] = True
411
+ content_to_send = fullText[streamConfig["fullTextCursor"]:markStart]
412
+ if content_to_send:
413
+ delta = ChatCompletionMessage(reasoning_content=content_to_send) if streamConfig["in_think"] else ChatCompletionMessage(content=content_to_send)
414
+ yield f"data: {ChatCompletionChunk(id=completionId, created=createTimestamp, model=request.model, choices=[ChatCompletionChoice(index=0, delta=delta, finish_reason=None)]).model_dump_json()}\n\n"
415
+ streamConfig["cacheStr"] = ""
416
+ streamConfig["fullTextCursor"] = markStart
417
+
418
+ if streamConfig["isChecking"]:
419
+ streamConfig["cacheStr"] = fullText[streamConfig["fullTextCursor"]:]
420
+ else:
421
+ delta = ChatCompletionMessage(reasoning_content=chunkContent) if streamConfig["in_think"] else ChatCompletionMessage(content=chunkContent)
422
+ yield f"data: {ChatCompletionChunk(id=completionId, created=createTimestamp, model=request.model, choices=[ChatCompletionChoice(index=0, delta=delta, finish_reason=None)]).model_dump_json()}\n\n"
423
+ streamConfig["fullTextCursor"] = len(fullText)
424
+
425
+ markEnd = fullText.find(">", streamConfig["fullTextCursor"])
426
+ if (streamConfig["isChecking"] and markEnd != -1) or finishReason:
427
+ streamConfig["isChecking"] = False
428
+ if "<think>" in streamConfig["cacheStr"]: streamConfig["in_think"] = True
429
+ elif "</think>" in streamConfig["cacheStr"]: streamConfig["in_think"] = False
430
+
431
+ # Flush leftover cached content
432
+ clean_content = streamConfig["cacheStr"].replace("<think>", "").replace("</think>", "")
433
+ if clean_content:
434
+ delta = ChatCompletionMessage(reasoning_content=clean_content) if streamConfig["in_think"] else ChatCompletionMessage(content=clean_content)
435
+ yield f"data: {ChatCompletionChunk(id=completionId, created=createTimestamp, model=request.model, choices=[ChatCompletionChoice(index=0, delta=delta, finish_reason=None)]).model_dump_json()}\n\n"
436
+
437
+ streamConfig["fullTextCursor"] = len(fullText)
438
+
439
+ else:
440
+ # Simple mode without reasoning
441
+ yield f"data: {ChatCompletionChunk(id=completionId, created=createTimestamp, model=request.model, choices=[ChatCompletionChoice(index=0, delta=ChatCompletionMessage(content=chunkContent), finish_reason=finishReason)]).model_dump_json()}\n\n"
442
+
443
+ await asyncio.sleep(0)
444
+
445
+ yield "data: [DONE]\n\n"
446
+
447
+ # --- API ROUTES ---
448
+
449
+ @app.post("/api/v1/chat/completions")
450
+ async def chat_completions(request: ChatCompletionRequest):
451
+ completionId = str(next(CompletionIdGenerator))
452
+
453
+ # Parse model suffixes
454
+ raw_model = request.model
455
+ modelName = request.model.split(":")[0]
456
+ enableReasoning = ":thinking" in request.model
457
+ if ":online" in modelName: modelName = modelName.replace(":online", "")
458
+
459
+ # Resolve model alias
460
+ if "rwkv-latest" in request.model:
461
+ if enableReasoning and DEFAULT_REASONING_MODEL_NAME:
462
+ request.model = DEFAULT_REASONING_MODEL_NAME
463
+ defaultSampler = MODEL_STORAGE[DEFAULT_REASONING_MODEL_NAME].MODEL_CONFIG.DEFAULT_SAMPLER
464
+ elif DEFALUT_MODEL_NAME:
465
+ request.model = DEFALUT_MODEL_NAME
466
+ defaultSampler = MODEL_STORAGE[DEFALUT_MODEL_NAME].MODEL_CONFIG.DEFAULT_SAMPLER
467
+ else:
468
+ raise HTTPException(500, "Default models not configured")
469
+ elif modelName in MODEL_STORAGE:
470
+ request.model = modelName
471
+ defaultSampler = MODEL_STORAGE[modelName].MODEL_CONFIG.DEFAULT_SAMPLER
472
+ else:
473
+ raise HTTPException(404, f"Model {modelName} not found")
474
+
475
+ # Apply default parameters
476
+ req_dict = request.model_dump()
477
+ for k, v in defaultSampler.model_dump().items():
478
+ if req_dict[k] is None: req_dict[k] = v
479
+ realRequest = ChatCompletionRequest(**req_dict)
480
+
481
+ # --- WEB SEARCH INJECTION ---
482
+ if realRequest.messages and len(realRequest.messages) > 0:
483
+ last_msg = realRequest.messages[-1]
484
+ if last_msg.role == "user" and should_trigger_search(last_msg.content, raw_model):
485
+ search_context = search_web_and_get_context(last_msg.content)
486
+ if search_context:
487
+ system_msg = ChatMessage(role="System", content=search_context)
488
+ insert_idx = 1 if len(realRequest.messages) > 0 and realRequest.messages[0].role == "System" else 0
489
+ realRequest.messages.insert(insert_idx, system_msg)
490
+ logger.info(f"[SEARCH] Context injected for {completionId}")
491
+
492
+ # Produce the response
493
+ if request.stream:
494
+ return StreamingResponse(chatResponseStream(realRequest, None, completionId, enableReasoning), media_type="text/event-stream")
495
+ else:
496
+ return await chatResponse(realRequest, None, completionId, enableReasoning)
497
+
498
+ @app.get("/api/v1/models")
499
+ @app.get("/models")
500
+ async def list_models():
501
+ models = [{"id": m, "object": "model", "created": int(time.time()), "owned_by": "rwkv-server"} for m in MODEL_STORAGE.keys()]
502
+ if DEFALUT_MODEL_NAME:
503
+ models.append({"id": "rwkv-latest", "object": "model", "created": int(time.time()), "owned_by": "rwkv-server"})
504
+ models.append({"id": "rwkv-latest:online", "object": "model", "created": int(time.time()), "owned_by": "rwkv-server"})
505
+ if DEFAULT_REASONING_MODEL_NAME:
506
+ models.append({"id": "rwkv-latest:thinking", "object": "model", "created": int(time.time()), "owned_by": "rwkv-server"})
507
+ models.append({"id": "rwkv-latest:thinking:online", "object": "model", "created": int(time.time()), "owned_by": "rwkv-server"})
508
+ return {"object": "list", "data": models}
509
+
510
+ app.mount("/", StaticFiles(directory="dist-frontend", html=True), name="static")
511
+
512
+ if __name__ == "__main__":
513
+ import uvicorn
514
+ uvicorn.run(app, host=CONFIG.HOST, port=CONFIG.PORT)
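Taken together, the new route resolves the `:thinking` and `:online` suffixes, fills in the model's DEFAULT_SAMPLER values, and optionally injects web-search context before generation. Below is a minimal client sketch against `/api/v1/chat/completions`; the `http://localhost:8000` base URL is an assumption for illustration, since the server actually binds to `CONFIG.HOST`/`CONFIG.PORT`.

```python
# Minimal client sketch for the /api/v1/chat/completions endpoint in this commit.
# Assumption: the server is reachable at http://localhost:8000 (really CONFIG.HOST/CONFIG.PORT).
import requests

payload = {
    # ":thinking" selects the reasoning model, ":online" forces the web-search injection
    "model": "rwkv-latest:thinking:online",
    "messages": [{"role": "user", "content": "What is the weather in Madrid today?"}],
    "stream": False,
}

resp = requests.post("http://localhost:8000/api/v1/chat/completions", json=payload, timeout=600)
resp.raise_for_status()
choice = resp.json()["choices"][0]["message"]

print("reasoning:", choice.get("reasoning_content"))
print("answer:", choice["content"])
```

Note that even without the `:online` suffix, a user message containing keywords such as "weather" or "news" would trigger `should_trigger_search` and inject the search context.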
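With `stream` set to true the same endpoint returns Server-Sent Events produced by `chatResponseStream`: one `data: {...}` line per ChatCompletionChunk and a final `data: [DONE]`. A sketch of consuming that stream, under the same host/port assumption as above:

```python
# Sketch of reading the SSE stream produced by chatResponseStream.
# Each event body is a ChatCompletionChunk serialized as JSON.
import json
import requests

payload = {
    "model": "rwkv-latest",
    "messages": [{"role": "user", "content": "Hello"}],
    "stream": True,
}

with requests.post("http://localhost:8000/api/v1/chat/completions",
                   json=payload, stream=True, timeout=600) as resp:
    resp.raise_for_status()
    for raw in resp.iter_lines(decode_unicode=True):
        if not raw or not raw.startswith("data: "):
            continue
        body = raw[len("data: "):]
        if body == "[DONE]":
            break
        delta = json.loads(body)["choices"][0]["delta"]
        piece = delta.get("content") or delta.get("reasoning_content")
        if piece:
            print(piece, end="", flush=True)
```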