from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import time

# Map of short model keys to Hugging Face model IDs
MODELS = {
    "yuuki-best": "OpceanAI/Yuuki-best",
    "yuuki-3.7": "OpceanAI/Yuuki-3.7",
    "yuuki-v0.1": "OpceanAI/Yuuki-v0.1"
}

app = FastAPI(
    title="Yuuki API",
    description="Local inference API for Yuuki models",
    version="1.0.0"
)

# Permissive CORS so local front ends can call the API during development
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)

# In-memory caches so each model and tokenizer is loaded at most once per process
loaded_models = {}
loaded_tokenizers = {}


def load_model(model_key: str):
    """Lazy load: only load the model the first time it is requested."""
    if model_key not in loaded_models:
        print(f"Loading {model_key}...")
        model_id = MODELS[model_key]

        loaded_tokenizers[model_key] = AutoTokenizer.from_pretrained(model_id)
        loaded_models[model_key] = AutoModelForCausalLM.from_pretrained(
            model_id,
            torch_dtype=torch.float32
        ).to("cpu")
        loaded_models[model_key].eval()
        print(f"{model_key} ready!")

    return loaded_models[model_key], loaded_tokenizers[model_key]


class GenerateRequest(BaseModel):
    prompt: str = Field(..., min_length=1, max_length=4000)
    model: str = Field(default="yuuki-best", description="Model to use")
    max_new_tokens: int = Field(default=120, ge=1, le=512)
    temperature: float = Field(default=0.7, ge=0.1, le=2.0)
    top_p: float = Field(default=0.95, ge=0.0, le=1.0)


class GenerateResponse(BaseModel):
    response: str
    model: str
    tokens_generated: int
    time_ms: int


@app.get("/") |
|
|
def root(): |
|
|
return { |
|
|
"message": "Yuuki Local Inference API", |
|
|
"models": list(MODELS.keys()), |
|
|
"endpoints": { |
|
|
"health": "GET /health", |
|
|
"models": "GET /models", |
|
|
"generate": "POST /generate", |
|
|
"docs": "GET /docs" |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
@app.get("/health") |
|
|
def health(): |
|
|
return { |
|
|
"status": "ok", |
|
|
"available_models": list(MODELS.keys()), |
|
|
"loaded_models": list(loaded_models.keys()) |
|
|
} |
|
|
|
|
|
|
|
|
@app.get("/models") |
|
|
def list_models(): |
|
|
return { |
|
|
"models": [ |
|
|
{"id": key, "name": value} |
|
|
for key, value in MODELS.items() |
|
|
] |
|
|
} |
|
|
|
|
|
|
|
|
@app.post("/generate", response_model=GenerateResponse) |
|
|
def generate(req: GenerateRequest): |
|
|
|
|
|
if req.model not in MODELS: |
|
|
raise HTTPException( |
|
|
status_code=400, |
|
|
detail=f"Invalid model. Available: {list(MODELS.keys())}" |
|
|
) |
|
|
|
|
|
try: |
|
|
start = time.time() |
|
|
|
|
|
|
|
|
model, tokenizer = load_model(req.model) |
|
|
|
|
|
inputs = tokenizer( |
|
|
req.prompt, |
|
|
return_tensors="pt", |
|
|
truncation=True, |
|
|
max_length=1024 |
|
|
) |
|
|
|
|
|
input_length = inputs["input_ids"].shape[1] |
|
|
|
|
|
with torch.no_grad(): |
|
|
output = model.generate( |
|
|
**inputs, |
|
|
max_new_tokens=req.max_new_tokens, |
|
|
temperature=req.temperature, |
|
|
top_p=req.top_p, |
|
|
do_sample=True, |
|
|
pad_token_id=tokenizer.eos_token_id, |
|
|
repetition_penalty=1.1, |
|
|
) |
|
|
|
|
|
new_tokens = output[0][input_length:] |
|
|
response_text = tokenizer.decode(new_tokens, skip_special_tokens=True) |
|
|
|
|
|
elapsed_ms = int((time.time() - start) * 1000) |
|
|
|
|
|
return GenerateResponse( |
|
|
response=response_text.strip(), |
|
|
model=req.model, |
|
|
tokens_generated=len(new_tokens), |
|
|
time_ms=elapsed_ms |
|
|
) |
|
|
|
|
|
except Exception as e: |
|
|
raise HTTPException(status_code=500, detail=str(e)) |
|
|
|
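
# Optional entry point, a minimal sketch assuming uvicorn is installed;
# the host and port values below are illustrative defaults, adjust as needed.
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)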