from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import uvicorn

app = FastAPI(
    title="GLM-4.6-FP8 API",
    description="REST API for GLM-4.6-FP8 with support for multiple languages",
    version="1.0.0"
)

# Cached model and tokenizer, loaded once at startup
model = None
tokenizer = None
device = "cuda" if torch.cuda.is_available() else "cpu"

class ChatRequest(BaseModel):
    message: str
    max_tokens: int = 512
    temperature: float = 0.7
    top_p: float = 0.95

class ChatResponse(BaseModel):
    response: str
    model: str = "GLM-4.6-FP8"
    device: str = device

@app.on_event("startup")
async def startup_event():
    global model, tokenizer
    try:
        print("Carregando modelo GLM-4.6-FP8...")
        tokenizer = AutoTokenizer.from_pretrained("zai-org/GLM-4.6-FP8")
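        # device_map="auto" lets Accelerate spread the model's layers across
        # the available GPUs (and CPU, if needed) instead of a single device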
        model = AutoModelForCausalLM.from_pretrained(
            "zai-org/GLM-4.6-FP8",
            device_map="auto",
            torch_dtype=torch.float16 if device == "cuda" else torch.float32,
            trust_remote_code=True
        )
        print("Modelo carregado com sucesso!")
    except Exception as e:
        print(f"Erro ao carregar modelo: {e}")
        raise

@app.get("/")
async def root():
    return {
        "message": "GLM-4.6-FP8 API",
        "version": "1.0.0",
        "device": device,
        "model_loaded": model is not None,
        "endpoints": {
            "chat": "/chat",
            "generate": "/generate",
            "health": "/health"
        }
    }

@app.get("/health")
async def health():
    return {
        "status": "ok",
        "model_loaded": model is not None,
        "device": device
    }

@app.post("/chat", response_model=ChatResponse)
async def chat(request: ChatRequest):
    global model, tokenizer
    
    if model is None or tokenizer is None:
        raise HTTPException(status_code=503, detail="Model is not loaded")
    
    try:
        # Tokenize the input
        inputs = tokenizer(request.message, return_tensors="pt").to(device)
        
        # Generate a response
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=request.max_tokens,
                temperature=request.temperature,
                top_p=request.top_p,
                do_sample=True
            )
        
        # Decode only the newly generated tokens so the prompt is not echoed back
        generated_tokens = outputs[0][inputs["input_ids"].shape[-1]:]
        response_text = tokenizer.decode(generated_tokens, skip_special_tokens=True)
        
        return ChatResponse(response=response_text)
    
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Generation error: {str(e)}")

@app.post("/generate", response_model=ChatResponse)
async def generate(request: ChatRequest):
    """Alias para /chat com formato alternativo"""
    return await chat(request)

if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7860)
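
# Example client call (a minimal sketch; assumes the server above is running
# locally on port 7860 and that the `requests` package is installed):
#
#   import requests
#
#   resp = requests.post(
#       "http://localhost:7860/chat",
#       json={"message": "Hello, GLM-4.6!", "max_tokens": 128},
#   )
#   print(resp.json()["response"])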