import os
#os.system("pip install faker duckduckgo_search")
import copy
import types
import gc
import sys
import re
import time
import collections
import asyncio
import random
from typing import List, Optional, Union, Any, Dict
# --- ENVIRONMENT CONFIGURATION ---
if os.environ.get("MODELSCOPE_ENVIRONMENT") == "studio":
from modelscope import patch_hub
patch_hub()
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:256"
os.environ["RWKV_V7_ON"] = "1"
os.environ["RWKV_JIT_ON"] = "1"
# --- IMPORTS ---
from config import CONFIG, ModelConfig
from utils import (
cleanMessages,
parse_think_response,
remove_nested_think_tags_stack,
format_bytes,
log,
)
from huggingface_hub import hf_hub_download
from loguru import logger
from snowflake import SnowflakeGenerator
import numpy as np
import torch
import requests
# Optional dependencies
try:
from duckduckgo_search import DDGS
HAS_DDG = True
except ImportError:
HAS_DDG = False
try:
from faker import Faker
fake = Faker()
HAS_FAKER = True
except ImportError:
HAS_FAKER = False
from fastapi import FastAPI, HTTPException, Request
from fastapi.responses import StreamingResponse
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles
from fastapi.middleware.gzip import GZipMiddleware
from pydantic import BaseModel, Field, model_validator
# --- INITIAL SETUP ---
CompletionIdGenerator = SnowflakeGenerator(42, timestamp=1741101491595)
if "cuda" in CONFIG.STRATEGY.lower() and not torch.cuda.is_available():
CONFIG.STRATEGY = "cpu fp16"
if "cuda" in CONFIG.STRATEGY.lower():
    from pynvml import nvmlInit, nvmlDeviceGetHandleByIndex
nvmlInit()
gpu_h = nvmlDeviceGetHandleByIndex(0)
torch.backends.cudnn.benchmark = True
torch.backends.cudnn.allow_tf32 = True
torch.backends.cuda.matmul.allow_tf32 = True
os.environ["RWKV_CUDA_ON"] = "1" if CONFIG.RWKV_CUDA_ON else "0"
else:
os.environ["RWKV_CUDA_ON"] = "0"
from rwkv.model import RWKV
from rwkv.utils import PIPELINE, PIPELINE_ARGS
from api_types import (
ChatMessage, ChatCompletion, ChatCompletionChunk, Usage,
ChatCompletionChoice, ChatCompletionMessage
)
# --- MODEL STORAGE ---
class ModelStorage:
MODEL_CONFIG: Optional[ModelConfig] = None
model: Optional[RWKV] = None
pipeline: Optional[PIPELINE] = None
MODEL_STORAGE: Dict[str, ModelStorage] = {}
DEFAULT_MODEL_NAME = None
DEFAULT_REASONING_MODEL_NAME = None
for model_config in CONFIG.MODELS:
if model_config.MODEL_FILE_PATH is None:
model_config.MODEL_FILE_PATH = hf_hub_download(
repo_id=model_config.DOWNLOAD_MODEL_REPO_ID,
filename=model_config.DOWNLOAD_MODEL_FILE_NAME,
local_dir=model_config.DOWNLOAD_MODEL_DIR,
)
    if model_config.DEFAULT_CHAT: DEFAULT_MODEL_NAME = model_config.SERVICE_NAME
if model_config.DEFAULT_REASONING: DEFAULT_REASONING_MODEL_NAME = model_config.SERVICE_NAME
MODEL_STORAGE[model_config.SERVICE_NAME] = ModelStorage()
MODEL_STORAGE[model_config.SERVICE_NAME].MODEL_CONFIG = model_config
MODEL_STORAGE[model_config.SERVICE_NAME].model = RWKV(
model=model_config.MODEL_FILE_PATH.replace(".pth", ""),
strategy=CONFIG.STRATEGY,
)
MODEL_STORAGE[model_config.SERVICE_NAME].pipeline = PIPELINE(
MODEL_STORAGE[model_config.SERVICE_NAME].model, model_config.VOCAB
)
if "cuda" in CONFIG.STRATEGY:
torch.cuda.empty_cache()
gc.collect()
# --- DATA CLASSES ---
class ChatCompletionRequest(BaseModel):
model: str = Field(default="rwkv-latest")
messages: Optional[List[ChatMessage]] = Field(default=None)
prompt: Optional[str] = Field(default=None)
max_tokens: Optional[int] = Field(default=None)
temperature: Optional[float] = Field(default=None)
top_p: Optional[float] = Field(default=None)
presence_penalty: Optional[float] = Field(default=None)
count_penalty: Optional[float] = Field(default=None)
penalty_decay: Optional[float] = Field(default=None)
stream: Optional[bool] = Field(default=False)
    stop: Optional[list[str]] = Field(default=["\n\n"])
    stop_tokens: Optional[list[int]] = Field(default=[0])
@model_validator(mode="before")
@classmethod
def validate_mutual_exclusivity(cls, data: Any) -> Any:
if not isinstance(data, dict): return data
if "messages" in data and "prompt" in data and data["messages"] and data["prompt"]:
raise ValueError("messages and prompt cannot coexist.")
return data
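# Illustrative request body (an assumed example, mirroring the OpenAI chat format;
# sampler fields left unset are filled in from the model's DEFAULT_SAMPLER in the
# handler below):
#
#   {
#     "model": "rwkv-latest:online",
#     "stream": true,
#     "messages": [{"role": "user", "content": "Is it true that ...?"}]
#   }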
# --- TRUTH AND FLOW PROTOCOL ---
class TruthAndFlowProtocol:
"""
Gestiona la coherencia factual y evita la repetici贸n rob贸tica.
"""
SYSTEM_INSTRUCTION = """
PROTOCOL: FACTUAL_AND_CONCISE
1. TRUTH: Say ONLY what is verified in the context or internal knowledge.
2. NO REPETITION: Do not repeat facts. Do not repeat sentence structures.
3. CONCISENESS: Get to the point directly.
4. LABELS: Use [VERIFICADO] for confirmed data, [INCIERTO] for contradictions.
5. NO FILLER: Avoid "As an AI", "I think", "Basically".
""".strip()
@staticmethod
def optimize_params(request: ChatCompletionRequest):
"""
Calibraci贸n fina para evitar bucles sin perder la factualidad.
"""
# Temperatura baja (0.15) pero no cero.
# Si es 0.0, entra en bucle seguro. 0.15 da el m铆nimo margen para variar palabras.
request.temperature = 0.15
# Top P estricto (0.1)
# Solo permite palabras l贸gicas.
request.top_p = 0.1
# --- AQU脥 EST脕 LA MAGIA ANTI-REPETICI脫N ---
# Frequency Penalty (1.2):
# Castigo ALTO si usas la MISMA palabra exacta muchas veces.
# Evita: "y y y y" o "es es es".
request.count_penalty = 1.2
# Presence Penalty (0.7):
# Castigo MEDIO si repites el mismo concepto.
# Evita decir lo mismo con otras palabras inmediatamente.
request.presence_penalty = 0.7
# Penalty Decay (0.996):
# "Perdona" el uso de palabras despu茅s de un rato.
# Necesario para que pueda volver a usar "el", "de", "que" sin bloquearse.
request.penalty_decay = 0.996
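    # Rough, illustrative intuition for penalty_decay: every generated token
    # multiplies each accumulated occurrence count by 0.996, so a penalty falls
    # to roughly half after ~173 tokens (0.996**173 ≈ 0.5). Common function
    # words recover quickly, while recently repeated content words stay
    # penalized long enough to break loops.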
@staticmethod
def search_verify(query: str) -> str:
"""B煤squeda y corroboraci贸n web."""
if not HAS_DDG: return ""
try:
            # Standard search
ddgs = DDGS()
results = ddgs.text(query, max_results=3)
            # Fact-check search if needed
is_suspicious = any(w in query.lower() for w in ["verdad", "fake", "bulo", "cierto"])
if is_suspicious:
check_res = ddgs.text(f"{query} fact check", max_results=2)
if check_res: results.extend(check_res)
if not results: return ""
context = "VERIFIED CONTEXT (Use strict labels [VERIFICADO]/[INCIERTO]):\n"
for r in results:
context += f"- {r['body']} (Source: {r['title']})\n"
return context
except Exception:
return ""
# --- APP SETUP ---
app = FastAPI(title="RWKV High-Fidelity Server")
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
app.add_middleware(GZipMiddleware, minimum_size=1000, compresslevel=5)
@app.middleware("http")
async def privacy_middleware(request: Request, call_next):
if HAS_FAKER:
request.scope["client"] = (fake.ipv4(), request.client.port if request.client else 80)
return await call_next(request)
# --- CACHE ---
search_cache = collections.OrderedDict()
def get_context(query: str) -> str:
if query in search_cache: return search_cache[query]
ctx = TruthAndFlowProtocol.search_verify(query)
if len(search_cache) > 50: search_cache.popitem(last=False)
search_cache[query] = ctx
return ctx
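# Note: get_context is a simple FIFO-style cache (hits are not re-ordered),
# capped at roughly 50 entries; empty results are cached too, so a failed
# lookup is not retried until its entry is evicted.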
def needs_search(msg: str, model: str) -> bool:
if ":online" in model: return True
return any(k in msg.lower() for k in ["quien", "cuando", "donde", "precio", "es verdad", "dato"])
# --- CORE RWKV LOOP ---
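# runPrefill feeds the prompt through the model in CONFIG.CHUNK_LEN-token chunks,
# yielding to the event loop between chunks; generate() then samples one token at
# a time, applying the presence/frequency penalties directly to the logits before
# each sample and decaying them with penalty_decay.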
async def runPrefill(request: ChatCompletionRequest, ctx: str, model_tokens: List[int], model_state):
ctx = ctx.replace("\r\n", "\n")
tokens = MODEL_STORAGE[request.model].pipeline.encode(ctx)
model_tokens.extend([int(x) for x in tokens])
while len(tokens) > 0:
out, model_state = MODEL_STORAGE[request.model].model.forward(tokens[: CONFIG.CHUNK_LEN], model_state)
tokens = tokens[CONFIG.CHUNK_LEN :]
await asyncio.sleep(0)
return out, model_tokens, model_state
def generate(request: ChatCompletionRequest, out, model_tokens: List[int], model_state, max_tokens=2048):
    # Map the request penalties onto PIPELINE_ARGS.
    # Note: alpha_frequency usually maps to the frequency (count) penalty in the OpenAI-style API.
args = PIPELINE_ARGS(
temperature=request.temperature,
top_p=request.top_p,
        alpha_frequency=request.count_penalty,   # penalty for exact token repetition
        alpha_presence=request.presence_penalty, # penalty for a concept being present at all
token_ban=[],
token_stop=[0]
)
occurrence = {}
out_tokens = []
out_last = 0
cache_word_list = []
for i in range(max_tokens):
        # Manually apply the penalties to the logits vector 'out'
for n in occurrence:
out[n] -= args.alpha_presence + occurrence[n] * args.alpha_frequency
token = MODEL_STORAGE[request.model].pipeline.sample_logits(out, temperature=args.temperature, top_p=args.top_p)
if token == 0:
yield {"content": "".join(cache_word_list), "finish_reason": "stop", "state": model_state}
del out; gc.collect(); return
out, model_state = MODEL_STORAGE[request.model].model.forward([token], model_state)
model_tokens.append(token)
out_tokens.append(token)
        # Decay: the repetition memory fades slowly
for xxx in occurrence: occurrence[xxx] *= request.penalty_decay
occurrence[token] = 1 + (occurrence.get(token, 0))
tmp = MODEL_STORAGE[request.model].pipeline.decode(out_tokens[out_last:])
if "\ufffd" in tmp: continue
cache_word_list.append(tmp)
out_last = i + 1
if len(cache_word_list) > 1:
yield {"content": cache_word_list.pop(0), "finish_reason": None}
yield {"content": "".join(cache_word_list), "finish_reason": "length"}
# --- HANDLER ---
async def chatResponseStream(request: ChatCompletionRequest, model_state: any, completionId: str, enableReasoning: bool):
clean_msg = cleanMessages(request.messages, enableReasoning)
prompt = f"{clean_msg}\n\nAssistant:{' <think' if enableReasoning else ''}"
out, model_tokens, model_state = await runPrefill(request, prompt, [0], model_state)
yield f"data: {ChatCompletionChunk(id=completionId, created=int(time.time()), model=request.model, choices=[ChatCompletionChoice(index=0, delta=ChatCompletionMessage(role='Assistant', content=''), finish_reason=None)]).model_dump_json()}\n\n"
for chunk in generate(request, out, model_tokens, model_state, max_tokens=request.max_tokens or 4096):
content = chunk["content"]
if content:
yield f"data: {ChatCompletionChunk(id=completionId, created=int(time.time()), model=request.model, choices=[ChatCompletionChoice(index=0, delta=ChatCompletionMessage(content=content), finish_reason=None)]).model_dump_json()}\n\n"
if chunk.get("finish_reason"): break
await asyncio.sleep(0)
yield "data: [DONE]\n\n"
@app.post("/v1/chat/completions")
@app.post("/api/v1/chat/completions")
async def chat_completions(request: ChatCompletionRequest):
completionId = str(next(CompletionIdGenerator))
raw_model = request.model
    model_key = request.model.split(":")[0]
is_reasoning = ":thinking" in request.model
target_model = model_key
if "rwkv-latest" in model_key:
if is_reasoning and DEFAULT_REASONING_MODEL_NAME: target_model = DEFAULT_REASONING_MODEL_NAME
        elif DEFAULT_MODEL_NAME: target_model = DEFAULT_MODEL_NAME
if target_model not in MODEL_STORAGE: raise HTTPException(404, "Model not found")
request.model = target_model
default_sampler = MODEL_STORAGE[target_model].MODEL_CONFIG.DEFAULT_SAMPLER
req_data = request.model_dump()
for k, v in default_sampler.model_dump().items():
if req_data.get(k) is None: req_data[k] = v
realRequest = ChatCompletionRequest(**req_data)
    # --- OPTIMIZATION LOGIC ---
    # 1. Anti-repetition system prompt
sys_msg = ChatMessage(role="System", content=TruthAndFlowProtocol.SYSTEM_INSTRUCTION)
if realRequest.messages:
if realRequest.messages[0].role == "System":
realRequest.messages[0].content = f"{TruthAndFlowProtocol.SYSTEM_INSTRUCTION}\n\n{realRequest.messages[0].content}"
else:
realRequest.messages.insert(0, sys_msg)
    # 2. Context injection (when applicable)
    if realRequest.messages:
        last_msg = realRequest.messages[-1]
        if last_msg.role == "user" and needs_search(last_msg.content, raw_model):
            ctx = get_context(last_msg.content)
            if ctx: realRequest.messages.insert(-1, ChatMessage(role="System", content=ctx))
    # 3. Fine-grained parameter tuning (the anti-repetition core)
TruthAndFlowProtocol.optimize_params(realRequest)
logger.info(f"[REQ] {completionId} | Params: T={realRequest.temperature} Freq={realRequest.count_penalty} Pres={realRequest.presence_penalty}")
return StreamingResponse(chatResponseStream(realRequest, None, completionId, is_reasoning), media_type="text/event-stream")
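# Illustrative usage (host and port assumed to come from CONFIG):
#
#   curl -N http://localhost:8000/v1/chat/completions \
#     -H "Content-Type: application/json" \
#     -d '{"model": "rwkv-latest:online", "stream": true,
#          "messages": [{"role": "user", "content": "Is it true that ...?"}]}'
#
# The ":online" suffix forces a web search pass; ":thinking" routes to the
# default reasoning model when one is configured.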
@app.get("/api/v1/models")
@app.get("/v1/models")
async def list_models():
return {"object": "list", "data": [{"id": "rwkv-latest", "object": "model"}]}
app.mount("/", StaticFiles(directory="dist-frontend", html=True), name="static")
if __name__ == "__main__":
import uvicorn
uvicorn.run(app, host=CONFIG.HOST, port=CONFIG.PORT)