import os
# Optional extras (install manually if needed): pip install faker duckduckgo_search
import copy
import types
import gc
import sys
import re
import time
import collections
import asyncio
import random
from typing import List, Optional, Union, Any, Dict

# --- ENVIRONMENT CONFIGURATION ---
if os.environ.get("MODELSCOPE_ENVIRONMENT") == "studio":
    from modelscope import patch_hub
    patch_hub()

os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:256"
os.environ["RWKV_V7_ON"] = "1"
os.environ["RWKV_JIT_ON"] = "1"

# --- IMPORTS ---
from config import CONFIG, ModelConfig
from utils import (
    cleanMessages,
    parse_think_response,
    remove_nested_think_tags_stack,
    format_bytes,
    log,
)
from huggingface_hub import hf_hub_download
from loguru import logger
from snowflake import SnowflakeGenerator
import numpy as np
import torch
import requests

# Optional dependencies
try:
    from duckduckgo_search import DDGS
    HAS_DDG = True
except ImportError:
    HAS_DDG = False

try:
    from faker import Faker
    fake = Faker()
    HAS_FAKER = True
except ImportError:
    HAS_FAKER = False

from fastapi import FastAPI, HTTPException, Request
from fastapi.responses import StreamingResponse
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles
from fastapi.middleware.gzip import GZipMiddleware
from pydantic import BaseModel, Field, model_validator

# --- INITIAL SETUP ---
CompletionIdGenerator = SnowflakeGenerator(42, timestamp=1741101491595)

if "cuda" in CONFIG.STRATEGY.lower() and not torch.cuda.is_available():
    CONFIG.STRATEGY = "cpu fp16"

if "cuda" in CONFIG.STRATEGY.lower():
    from pynvml import nvmlInit, nvmlDeviceGetHandleByIndex
    nvmlInit()
    gpu_h = nvmlDeviceGetHandleByIndex(0)
    torch.backends.cudnn.benchmark = True
    torch.backends.cudnn.allow_tf32 = True
    torch.backends.cuda.matmul.allow_tf32 = True
    os.environ["RWKV_CUDA_ON"] = "1" if CONFIG.RWKV_CUDA_ON else "0"
else:
    os.environ["RWKV_CUDA_ON"] = "0"

from rwkv.model import RWKV
from rwkv.utils import PIPELINE, PIPELINE_ARGS
from api_types import (
    ChatMessage, ChatCompletion, ChatCompletionChunk, Usage,
    ChatCompletionChoice, ChatCompletionMessage
)

# --- MODEL STORAGE ---
class ModelStorage:
    MODEL_CONFIG: Optional[ModelConfig] = None
    model: Optional[RWKV] = None
    pipeline: Optional[PIPELINE] = None

MODEL_STORAGE: Dict[str, ModelStorage] = {}
DEFAULT_MODEL_NAME = None
DEFAULT_REASONING_MODEL_NAME = None

for model_config in CONFIG.MODELS:
    if model_config.MODEL_FILE_PATH is None:
        model_config.MODEL_FILE_PATH = hf_hub_download(
            repo_id=model_config.DOWNLOAD_MODEL_REPO_ID,
            filename=model_config.DOWNLOAD_MODEL_FILE_NAME,
            local_dir=model_config.DOWNLOAD_MODEL_DIR,
        )
    if model_config.DEFAULT_CHAT: DEFAULT_MODEL_NAME = model_config.SERVICE_NAME
    if model_config.DEFAULT_REASONING: DEFAULT_REASONING_MODEL_NAME = model_config.SERVICE_NAME

    MODEL_STORAGE[model_config.SERVICE_NAME] = ModelStorage()
    MODEL_STORAGE[model_config.SERVICE_NAME].MODEL_CONFIG = model_config
    MODEL_STORAGE[model_config.SERVICE_NAME].model = RWKV(
        model=model_config.MODEL_FILE_PATH.replace(".pth", ""),
        strategy=CONFIG.STRATEGY,
    )
    MODEL_STORAGE[model_config.SERVICE_NAME].pipeline = PIPELINE(
        MODEL_STORAGE[model_config.SERVICE_NAME].model, model_config.VOCAB
    )
    if "cuda" in CONFIG.STRATEGY:
        torch.cuda.empty_cache()
        gc.collect()

# --- DATA CLASSES ---
class ChatCompletionRequest(BaseModel):
    model: str = Field(default="rwkv-latest")
    messages: Optional[List[ChatMessage]] = Field(default=None)
    prompt: Optional[str] = Field(default=None)
    max_tokens: Optional[int] = Field(default=None)
    temperature: Optional[float] = Field(default=None)
    top_p: Optional[float] = Field(default=None)
    presence_penalty: Optional[float] = Field(default=None)
    count_penalty: Optional[float] = Field(default=None)
    penalty_decay: Optional[float] = Field(default=None)
    stream: Optional[bool] = Field(default=False)
    stop: Optional[list[str]] = Field(default=["\n\n"])
    stop_tokens: Optional[list[int]] = Field(default=[0])

    @model_validator(mode="before")
    @classmethod
    def validate_mutual_exclusivity(cls, data: Any) -> Any:
        if not isinstance(data, dict): return data
        if "messages" in data and "prompt" in data and data["messages"] and data["prompt"]:
            raise ValueError("messages and prompt cannot coexist.")
        return data
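    # Illustrative only: a body that sets both "prompt" and a non-empty "messages"
    # list is rejected here with "messages and prompt cannot coexist."; sending just
    # one of the two fields is accepted.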

# --- TRUTH AND FLOW PROTOCOL ---
class TruthAndFlowProtocol:
    """
    Manages factual coherence and avoids robotic repetition.
    """

    SYSTEM_INSTRUCTION = """
PROTOCOL: FACTUAL_AND_CONCISE
1. TRUTH: Say ONLY what is verified in the context or internal knowledge.
2. NO REPETITION: Do not repeat facts. Do not repeat sentence structures.
3. CONCISENESS: Get to the point directly.
4. LABELS: Use [VERIFICADO] for confirmed data, [INCIERTO] for contradictions.
5. NO FILLER: Avoid "As an AI", "I think", "Basically".
""".strip()

    @staticmethod
    def optimize_params(request: ChatCompletionRequest):
        """
        Calibraci贸n fina para evitar bucles sin perder la factualidad.
        """
        # Temperatura baja (0.15) pero no cero.
        # Si es 0.0, entra en bucle seguro. 0.15 da el m铆nimo margen para variar palabras.
        request.temperature = 0.15
        
        # Top P estricto (0.1)
        # Solo permite palabras l贸gicas.
        request.top_p = 0.1

        # --- AQU脥 EST脕 LA MAGIA ANTI-REPETICI脫N ---
        
        # Frequency Penalty (1.2):
        # Castigo ALTO si usas la MISMA palabra exacta muchas veces.
        # Evita: "y y y y" o "es es es".
        request.count_penalty = 1.2
        
        # Presence Penalty (0.7):
        # Castigo MEDIO si repites el mismo concepto.
        # Evita decir lo mismo con otras palabras inmediatamente.
        request.presence_penalty = 0.7
        
        # Penalty Decay (0.996):
        # "Perdona" el uso de palabras despu茅s de un rato.
        # Necesario para que pueda volver a usar "el", "de", "que" sin bloquearse.
        request.penalty_decay = 0.996
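    # Worked example (illustrative): with these settings, a token already emitted 3 times
    # has its logit reduced by presence + count * frequency = 0.7 + 3 * 1.2 = 4.3 in the
    # sampling loop below, and each generated step shrinks that count by a factor of 0.996.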

    @staticmethod
    def search_verify(query: str) -> str:
        """B煤squeda y corroboraci贸n web."""
        if not HAS_DDG: return ""
        try:
            # Regular search
            ddgs = DDGS()
            results = ddgs.text(query, max_results=3)
            
            # Run an extra fact-check search when the query looks suspicious
            is_suspicious = any(w in query.lower() for w in ["verdad", "fake", "bulo", "cierto"])
            if is_suspicious:
                 check_res = ddgs.text(f"{query} fact check", max_results=2)
                 if check_res: results.extend(check_res)

            if not results: return ""

            context = "VERIFIED CONTEXT (Use strict labels [VERIFICADO]/[INCIERTO]):\n"
            for r in results:
                context += f"- {r['body']} (Source: {r['title']})\n"
            
            return context
        except Exception:
            return ""

# --- APP SETUP ---
app = FastAPI(title="RWKV High-Fidelity Server")

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
app.add_middleware(GZipMiddleware, minimum_size=1000, compresslevel=5)

# Privacy middleware: when Faker is available, replace the real client IP with a random one
@app.middleware("http")
async def privacy_middleware(request: Request, call_next):
    if HAS_FAKER:
        request.scope["client"] = (fake.ipv4(), request.client.port if request.client else 80)
    return await call_next(request)

# --- CACHE ---
search_cache = collections.OrderedDict()
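# FIFO cache of search contexts, capped at roughly 50 entries; popitem(last=False) evicts the oldest.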

def get_context(query: str) -> str:
    if query in search_cache: return search_cache[query]
    ctx = TruthAndFlowProtocol.search_verify(query)
    if len(search_cache) > 50: search_cache.popitem(last=False)
    search_cache[query] = ctx
    return ctx

def needs_search(msg: str, model: str) -> bool:
    if ":online" in model: return True
    return any(k in msg.lower() for k in ["quien", "cuando", "donde", "precio", "es verdad", "dato"])
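# e.g. needs_search("quien gano el mundial?", "rwkv-latest") -> True (keyword "quien"),
#      needs_search("hola", "rwkv-latest:online")            -> True (forced by ":online").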

# --- CORE RWKV LOOP ---
async def runPrefill(request: ChatCompletionRequest, ctx: str, model_tokens: List[int], model_state):
    ctx = ctx.replace("\r\n", "\n")
    tokens = MODEL_STORAGE[request.model].pipeline.encode(ctx)
    model_tokens.extend([int(x) for x in tokens])
    out = None  # guard: stays None only if the prompt encodes to zero tokens
    while len(tokens) > 0:
        out, model_state = MODEL_STORAGE[request.model].model.forward(tokens[: CONFIG.CHUNK_LEN], model_state)
        tokens = tokens[CONFIG.CHUNK_LEN :]
        await asyncio.sleep(0)
    return out, model_tokens, model_state

def generate(request: ChatCompletionRequest, out, model_tokens: List[int], model_state, max_tokens=2048):
    # Map the request's penalties onto PIPELINE_ARGS.
    # Note: alpha_frequency corresponds to the OpenAI-style frequency/count penalty.
    args = PIPELINE_ARGS(
        temperature=request.temperature,
        top_p=request.top_p,
        alpha_frequency=request.count_penalty,   # penalty for exact token repetition
        alpha_presence=request.presence_penalty, # penalty for a token having appeared at all
        token_ban=[],
        token_stop=[0]
    )
    
    occurrence = {}
    out_tokens = []
    out_last = 0
    cache_word_list = []
    
    for i in range(max_tokens):
        # Manually apply the penalties to the logits vector 'out':
        #   out[token] -= presence_penalty + count(token) * frequency_penalty
        for n in occurrence: 
            out[n] -= args.alpha_presence + occurrence[n] * args.alpha_frequency
        
        token = MODEL_STORAGE[request.model].pipeline.sample_logits(out, temperature=args.temperature, top_p=args.top_p)
        
        if token == 0:
             yield {"content": "".join(cache_word_list), "finish_reason": "stop", "state": model_state}
             del out; gc.collect(); return

        out, model_state = MODEL_STORAGE[request.model].model.forward([token], model_state)
        model_tokens.append(token)
        out_tokens.append(token)
        
        # Decay: the repetition memory fades slowly over time
        for prev_token in occurrence: occurrence[prev_token] *= request.penalty_decay
        occurrence[token] = 1 + (occurrence.get(token, 0))
        
        tmp = MODEL_STORAGE[request.model].pipeline.decode(out_tokens[out_last:])
        if "\ufffd" in tmp: continue
        cache_word_list.append(tmp)
        out_last = i + 1
        
        if len(cache_word_list) > 1:
            yield {"content": cache_word_list.pop(0), "finish_reason": None}
            
    yield {"content": "".join(cache_word_list), "finish_reason": "length"}

# --- HANDLER ---
async def chatResponseStream(request: ChatCompletionRequest, model_state: Any, completionId: str, enableReasoning: bool):
    clean_msg = cleanMessages(request.messages, enableReasoning)
    prompt = f"{clean_msg}\n\nAssistant:{' <think' if enableReasoning else ''}"
    
    out, model_tokens, model_state = await runPrefill(request, prompt, [0], model_state)
    
    yield f"data: {ChatCompletionChunk(id=completionId, created=int(time.time()), model=request.model, choices=[ChatCompletionChoice(index=0, delta=ChatCompletionMessage(role='Assistant', content=''), finish_reason=None)]).model_dump_json()}\n\n"
    
    for chunk in generate(request, out, model_tokens, model_state, max_tokens=request.max_tokens or 4096):
        content = chunk["content"]
        if content:
             yield f"data: {ChatCompletionChunk(id=completionId, created=int(time.time()), model=request.model, choices=[ChatCompletionChoice(index=0, delta=ChatCompletionMessage(content=content), finish_reason=None)]).model_dump_json()}\n\n"
        if chunk.get("finish_reason"): break
        await asyncio.sleep(0)
    
    yield "data: [DONE]\n\n"

@app.post("/v1/chat/completions")
@app.post("/api/v1/chat/completions")
async def chat_completions(request: ChatCompletionRequest):
    completionId = str(next(CompletionIdGenerator))
    
    raw_model = request.model
    model_key = request.model.split(":")[0]  # drop ":online" / ":thinking" suffixes
    is_reasoning = ":thinking" in request.model
    
    target_model = model_key
    if "rwkv-latest" in model_key:
        if is_reasoning and DEFAULT_REASONING_MODEL_NAME: target_model = DEFAULT_REASONING_MODEL_NAME
        elif DEFAULT_MODEL_NAME: target_model = DEFAULT_MODEL_NAME
    
    if target_model not in MODEL_STORAGE: raise HTTPException(404, "Model not found")
    request.model = target_model

    default_sampler = MODEL_STORAGE[target_model].MODEL_CONFIG.DEFAULT_SAMPLER
    req_data = request.model_dump()
    for k, v in default_sampler.model_dump().items():
        if req_data.get(k) is None: req_data[k] = v
    realRequest = ChatCompletionRequest(**req_data)

    # --- OPTIMIZATION LOGIC ---
    
    # 1. Anti-repetition system prompt
    sys_msg = ChatMessage(role="System", content=TruthAndFlowProtocol.SYSTEM_INSTRUCTION)
    if realRequest.messages:
        if realRequest.messages[0].role == "System":
             realRequest.messages[0].content = f"{TruthAndFlowProtocol.SYSTEM_INSTRUCTION}\n\n{realRequest.messages[0].content}"
        else:
            realRequest.messages.insert(0, sys_msg)

    # 2. Context injection (when applicable)
    if realRequest.messages:
        last_msg = realRequest.messages[-1]
        if last_msg.role == "user" and needs_search(last_msg.content, raw_model):
            ctx = get_context(last_msg.content)
            if ctx: realRequest.messages.insert(-1, ChatMessage(role="System", content=ctx))

    # 3. Fine-tune sampling parameters (the anti-repetition core)
    TruthAndFlowProtocol.optimize_params(realRequest)

    logger.info(f"[REQ] {completionId} | Params: T={realRequest.temperature} Freq={realRequest.count_penalty} Pres={realRequest.presence_penalty}")

    return StreamingResponse(chatResponseStream(realRequest, None, completionId, is_reasoning), media_type="text/event-stream")
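
# Example request (illustrative; actual host/port come from CONFIG.HOST / CONFIG.PORT):
#   curl -N http://localhost:8000/v1/chat/completions \
#     -H "Content-Type: application/json" \
#     -d '{"model": "rwkv-latest:online", "stream": true,
#          "messages": [{"role": "user", "content": "quien gano el mundial 2022?"}]}'
# The ":online" suffix forces a web search; the response is a text/event-stream of
# ChatCompletionChunk JSON chunks ending with "data: [DONE]".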

@app.get("/api/v1/models")
@app.get("/v1/models")
async def list_models():
    return {"object": "list", "data": [{"id": "rwkv-latest", "object": "model"}]}

app.mount("/", StaticFiles(directory="dist-frontend", html=True), name="static")

if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host=CONFIG.HOST, port=CONFIG.PORT)