import os  # os.system("pip install faker duckduckgo_search")
import copy
import types
import gc
import sys
import re
import time
import collections
import asyncio
import random
from typing import List, Optional, Union, Any, Dict

# --- ENVIRONMENT CONFIGURATION ---
if os.environ.get("MODELSCOPE_ENVIRONMENT") == "studio":
    from modelscope import patch_hub

    patch_hub()

os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:256"
os.environ["RWKV_V7_ON"] = "1"
os.environ["RWKV_JIT_ON"] = "1"

# --- IMPORTS ---
from config import CONFIG, ModelConfig
from utils import (
    cleanMessages,
    parse_think_response,
    remove_nested_think_tags_stack,
    format_bytes,
    log,
)
from huggingface_hub import hf_hub_download
from loguru import logger
from snowflake import SnowflakeGenerator
import numpy as np
import torch
import requests

# Optional dependencies
try:
    from duckduckgo_search import DDGS

    HAS_DDG = True
except ImportError:
    HAS_DDG = False

try:
    from faker import Faker

    fake = Faker()
    HAS_FAKER = True
except ImportError:
    HAS_FAKER = False

from fastapi import FastAPI, HTTPException, Request
from fastapi.responses import StreamingResponse
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles
from fastapi.middleware.gzip import GZipMiddleware
from pydantic import BaseModel, Field, model_validator

# --- INITIAL SETUP ---
CompletionIdGenerator = SnowflakeGenerator(42, timestamp=1741101491595)

# Fall back to CPU if the configured strategy asks for CUDA but none is available.
if "cuda" in CONFIG.STRATEGY.lower() and not torch.cuda.is_available():
    CONFIG.STRATEGY = "cpu fp16"

if "cuda" in CONFIG.STRATEGY.lower():
    from pynvml import *

    nvmlInit()
    gpu_h = nvmlDeviceGetHandleByIndex(0)
    torch.backends.cudnn.benchmark = True
    torch.backends.cudnn.allow_tf32 = True
    torch.backends.cuda.matmul.allow_tf32 = True
    os.environ["RWKV_CUDA_ON"] = "1" if CONFIG.RWKV_CUDA_ON else "0"
else:
    os.environ["RWKV_CUDA_ON"] = "0"

from rwkv.model import RWKV
from rwkv.utils import PIPELINE, PIPELINE_ARGS

from api_types import (
    ChatMessage,
    ChatCompletion,
    ChatCompletionChunk,
    Usage,
    ChatCompletionChoice,
    ChatCompletionMessage,
)

# --- MODEL STORAGE ---
class ModelStorage:
    MODEL_CONFIG: Optional[ModelConfig] = None
    model: Optional[RWKV] = None
    pipeline: Optional[PIPELINE] = None


MODEL_STORAGE: Dict[str, ModelStorage] = {}
DEFALUT_MODEL_NAME = None
DEFAULT_REASONING_MODEL_NAME = None

for model_config in CONFIG.MODELS:
    # Download the weights from the Hub when no local path is configured.
    if model_config.MODEL_FILE_PATH is None:
        model_config.MODEL_FILE_PATH = hf_hub_download(
            repo_id=model_config.DOWNLOAD_MODEL_REPO_ID,
            filename=model_config.DOWNLOAD_MODEL_FILE_NAME,
            local_dir=model_config.DOWNLOAD_MODEL_DIR,
        )
    if model_config.DEFAULT_CHAT:
        DEFALUT_MODEL_NAME = model_config.SERVICE_NAME
    if model_config.DEFAULT_REASONING:
        DEFAULT_REASONING_MODEL_NAME = model_config.SERVICE_NAME

    MODEL_STORAGE[model_config.SERVICE_NAME] = ModelStorage()
    MODEL_STORAGE[model_config.SERVICE_NAME].MODEL_CONFIG = model_config
    MODEL_STORAGE[model_config.SERVICE_NAME].model = RWKV(
        model=model_config.MODEL_FILE_PATH.replace(".pth", ""),
        strategy=CONFIG.STRATEGY,
    )
    MODEL_STORAGE[model_config.SERVICE_NAME].pipeline = PIPELINE(
        MODEL_STORAGE[model_config.SERVICE_NAME].model, model_config.VOCAB
    )

if "cuda" in CONFIG.STRATEGY:
    torch.cuda.empty_cache()
    gc.collect()
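
# Illustrative only: the loading loop above expects each entry in CONFIG.MODELS to
# expose roughly the fields sketched below. The values are made-up placeholders;
# the authoritative definition lives in config.ModelConfig.
#
#   ModelConfig(
#       SERVICE_NAME="rwkv-latest",            # key for MODEL_STORAGE and request.model
#       MODEL_FILE_PATH=None,                  # None -> fetched via hf_hub_download
#       DOWNLOAD_MODEL_REPO_ID="<hf-repo-id>",
#       DOWNLOAD_MODEL_FILE_NAME="<weights>.pth",
#       DOWNLOAD_MODEL_DIR="./models",
#       VOCAB="<pipeline-vocab>",
#       DEFAULT_CHAT=True,
#       DEFAULT_REASONING=False,
#   )
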

# --- DATA CLASSES ---
class ChatCompletionRequest(BaseModel):
    model: str = Field(default="rwkv-latest")
    messages: Optional[List[ChatMessage]] = Field(default=None)
    prompt: Optional[str] = Field(default=None)
    max_tokens: Optional[int] = Field(default=None)
    temperature: Optional[float] = Field(default=None)
    top_p: Optional[float] = Field(default=None)
    presence_penalty: Optional[float] = Field(default=None)
    count_penalty: Optional[float] = Field(default=None)
    penalty_decay: Optional[float] = Field(default=None)
    stream: Optional[bool] = Field(default=False)
    stop: Optional[list[str]] = Field(default=["\n\n"])
    stop_tokens: Optional[list[int]] = Field(default=[0])

    @model_validator(mode="before")
    @classmethod
    def validate_mutual_exclusivity(cls, data: Any) -> Any:
        if not isinstance(data, dict):
            return data
        if "messages" in data and "prompt" in data and data["messages"] and data["prompt"]:
            raise ValueError("messages and prompt cannot coexist.")
        return data


# --- TRUTH AND FLOW PROTOCOL ---
class TruthAndFlowProtocol:
    """
    Manages factual coherence and avoids robotic repetition.
    """

    SYSTEM_INSTRUCTION = """
PROTOCOL: FACTUAL_AND_CONCISE
1. TRUTH: Say ONLY what is verified in the context or internal knowledge.
2. NO REPETITION: Do not repeat facts. Do not repeat sentence structures.
3. CONCISENESS: Get to the point directly.
4. LABELS: Use [VERIFICADO] for confirmed data, [INCIERTO] for contradictions.
5. NO FILLER: Avoid "As an AI", "I think", "Basically".
""".strip()

    @staticmethod
    def optimize_params(request: ChatCompletionRequest):
        """
        Fine calibration to avoid loops without losing factuality.
        """
        # Low temperature (0.15), but not zero.
        # At 0.0 the model falls into guaranteed loops; 0.15 leaves minimal room to vary wording.
        request.temperature = 0.15

        # Strict top_p (0.1): only the most plausible tokens get through.
        request.top_p = 0.1

        # --- HERE IS THE ANTI-REPETITION MAGIC ---

        # Frequency penalty (1.2):
        # HIGH punishment for using the exact SAME word many times.
        # Avoids: "y y y y" or "es es es".
        request.count_penalty = 1.2

        # Presence penalty (0.7):
        # MEDIUM punishment for repeating the same concept.
        # Avoids immediately restating the same thing in other words.
        request.presence_penalty = 0.7

        # Penalty decay (0.996):
        # "Forgives" word usage after a while.
        # Needed so the model can reuse "el", "de", "que" without stalling.
        request.penalty_decay = 0.996
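
    # Illustrative usage (not executed): a handler could apply this protocol to an
    # incoming request before generation. The request below is a made-up example;
    # the resulting values mirror the assignments in optimize_params above.
    #
    #   req = ChatCompletionRequest(
    #       model="rwkv-latest",
    #       messages=[ChatMessage(role="user", content="¿Es verdad que ...?")],
    #   )
    #   TruthAndFlowProtocol.optimize_params(req)
    #   # req.temperature == 0.15, req.top_p == 0.1,
    #   # req.count_penalty == 1.2, req.presence_penalty == 0.7, req.penalty_decay == 0.996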

    @staticmethod
    def search_verify(query: str) -> str:
        """Web search and corroboration."""
        if not HAS_DDG:
            return ""
        try:
            # Regular search; materialize to a list so extend() works even if the
            # library returns a generator.
            ddgs = DDGS()
            results = list(ddgs.text(query, max_results=3))

            # Add a fact-check search when the query looks like a claim to verify.
            is_suspicious = any(w in query.lower() for w in ["verdad", "fake", "bulo", "cierto"])
            if is_suspicious:
                check_res = ddgs.text(f"{query} fact check", max_results=2)
                if check_res:
                    results.extend(check_res)

            if not results:
                return ""

            context = "VERIFIED CONTEXT (Use strict labels [VERIFICADO]/[INCIERTO]):\n"
            for r in results:
                context += f"- {r['body']} (Source: {r['title']})\n"
            return context
        except Exception:
            return ""


# --- APP SETUP ---
app = FastAPI(title="RWKV High-Fidelity Server")
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
app.add_middleware(GZipMiddleware, minimum_size=1000, compresslevel=5)


@app.middleware("http")
async def privacy_middleware(request: Request, call_next):
    # Mask the real client IP with a fake one when Faker is available.
    if HAS_FAKER:
        request.scope["client"] = (fake.ipv4(), request.client.port if request.client else 80)
    return await call_next(request)


# --- CACHE ---
search_cache = collections.OrderedDict()


def get_context(query: str) -> str:
    if query in search_cache:
        return search_cache[query]
    ctx = TruthAndFlowProtocol.search_verify(query)
    if len(search_cache) > 50:
        search_cache.popitem(last=False)  # evict the oldest entry (FIFO)
    search_cache[query] = ctx
    return ctx


def needs_search(msg: str, model: str) -> bool:
    if ":online" in model:
        return True
    return any(k in msg.lower() for k in ["quien", "cuando", "donde", "precio", "es verdad", "dato"])


# --- CORE RWKV LOOP ---
async def runPrefill(request: ChatCompletionRequest, ctx: str, model_tokens: List[int], model_state):
    ctx = ctx.replace("\r\n", "\n")
    tokens = MODEL_STORAGE[request.model].pipeline.encode(ctx)
    model_tokens.extend([int(x) for x in tokens])

    while len(tokens) > 0:
        out, model_state = MODEL_STORAGE[request.model].model.forward(
            tokens[: CONFIG.CHUNK_LEN], model_state
        )
        tokens = tokens[CONFIG.CHUNK_LEN :]
        await asyncio.sleep(0)

    return out, model_tokens, model_state


def generate(request: ChatCompletionRequest, out, model_tokens: List[int], model_state, max_tokens=2048):
    # Map the request penalties onto PIPELINE_ARGS.
    # Note: alpha_frequency usually maps to count_penalty in the OpenAI-style API.
    args = PIPELINE_ARGS(
        temperature=request.temperature,
        top_p=request.top_p,
        alpha_frequency=request.count_penalty,    # penalty for exact-token repetition
        alpha_presence=request.presence_penalty,  # penalty for a concept being present at all
        token_ban=[],
        token_stop=[0],
    )

    occurrence = {}
    out_tokens = []
    out_last = 0
    cache_word_list = []

    for i in range(max_tokens):
        # Apply the penalties manually to the logits vector `out`.
        for n in occurrence:
            out[n] -= args.alpha_presence + occurrence[n] * args.alpha_frequency

        token = MODEL_STORAGE[request.model].pipeline.sample_logits(
            out, temperature=args.temperature, top_p=args.top_p
        )

        if token == 0:
            yield {
                "content": "".join(cache_word_list),
                "finish_reason": "stop",
                "state": model_state,
            }
            del out
            gc.collect()
            return

        out, model_state = MODEL_STORAGE[request.model].model.forward([token], model_state)
        model_tokens.append(token)
        out_tokens.append(token)

        # Decay: the repetition memory fades slowly.
        for xxx in occurrence:
            occurrence[xxx] *= request.penalty_decay
        occurrence[token] = 1 + (occurrence.get(token, 0))

        tmp = MODEL_STORAGE[request.model].pipeline.decode(out_tokens[out_last:])
        if "\ufffd" in tmp:
            continue

        cache_word_list.append(tmp)
        out_last = i + 1

        if len(cache_word_list) > 1:
            yield {"content": cache_word_list.pop(0), "finish_reason": None}

    yield {"content": "".join(cache_word_list), "finish_reason": "length"}
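
# Illustrative only (not executed): how a caller is expected to drain `generate`.
# Each yielded dict carries an incremental chunk of text plus a finish reason
# ("stop" when token 0 is sampled, "length" when max_tokens is exhausted).
# `req`, `out`, `model_tokens` and `model_state` stand for values produced by
# runPrefill above.
#
#   for delta in generate(req, out, model_tokens, model_state, max_tokens=256):
#       print(delta["content"], end="", flush=True)
#       if delta["finish_reason"] is not None:
#           break
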

# --- HANDLER ---
async def chatResponseStream(
    request: ChatCompletionRequest, model_state: Any, completionId: str, enableReasoning: bool
):
    clean_msg = cleanMessages(request.messages, enableReasoning)
    prompt = f"{clean_msg}\n\nAssistant:{'