AARANHA committed on Commit 8c89b4a · verified · 1 Parent(s): cf88a80

Create app.py

Files changed (1)
app.py +105 -0
app.py ADDED
@@ -0,0 +1,105 @@
+ from fastapi import FastAPI, HTTPException
+ from pydantic import BaseModel
+ import torch
+ from transformers import AutoModelForCausalLM, AutoTokenizer
+ import uvicorn
+
+ app = FastAPI(
+     title="GLM-4.6-FP8 API",
+     description="Working REST API for GLM-4.6-FP8 with multi-language support",
+     version="1.0.0"
+ )
+
+ # Model objects are cached globally and loaded once at startup
+ model = None
+ tokenizer = None
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+
+ class ChatRequest(BaseModel):
+     message: str
+     max_tokens: int = 512
+     temperature: float = 0.7
+     top_p: float = 0.95
+
+ class ChatResponse(BaseModel):
+     response: str
+     model: str = "GLM-4.6-FP8"
+     device: str = device
+
+ @app.on_event("startup")
+ async def startup_event():
+     global model, tokenizer
+     try:
+         print("Loading GLM-4.6-FP8 model...")
+         tokenizer = AutoTokenizer.from_pretrained("zai-org/GLM-4.6-FP8")
+         model = AutoModelForCausalLM.from_pretrained(
+             "zai-org/GLM-4.6-FP8",
+             device_map="auto",
+             torch_dtype=torch.float16 if device == "cuda" else torch.float32,
+             trust_remote_code=True
+         )
+         print("Model loaded successfully!")
+     except Exception as e:
+         print(f"Error loading model: {e}")
+         raise
+
+ @app.get("/")
+ async def root():
+     return {
+         "message": "GLM-4.6-FP8 API",
+         "version": "1.0.0",
+         "device": device,
+         "model_loaded": model is not None,
+         "endpoints": {
+             "chat": "/chat",
+             "generate": "/generate",
+             "health": "/health"
+         }
+     }
+
+ @app.get("/health")
+ async def health():
+     return {
+         "status": "ok",
+         "model_loaded": model is not None,
+         "device": device
+     }
+
+ @app.post("/chat", response_model=ChatResponse)
+ async def chat(request: ChatRequest):
+     if model is None or tokenizer is None:
+         raise HTTPException(status_code=503, detail="Model is not loaded")
+
+     try:
+         # Tokenize the input and move it to the device the model ended up on
+         # (with device_map="auto" this may differ from the module-level `device`)
+         inputs = tokenizer(request.message, return_tensors="pt").to(model.device)
+
+         # Generate a response
+         with torch.no_grad():
+             outputs = model.generate(
+                 **inputs,
+                 max_new_tokens=request.max_tokens,
+                 temperature=request.temperature,
+                 top_p=request.top_p,
+                 do_sample=True
+             )
+
+         # Decode only the newly generated tokens so the prompt is not echoed back
+         new_tokens = outputs[0][inputs["input_ids"].shape[1]:]
+         response_text = tokenizer.decode(new_tokens, skip_special_tokens=True)
+
+         return ChatResponse(response=response_text)
+
+     except Exception as e:
+         raise HTTPException(status_code=500, detail=f"Generation error: {e}")
+
+ @app.post("/generate", response_model=ChatResponse)
+ async def generate(request: ChatRequest):
+     """Alias for /chat under an alternative route name."""
+     return await chat(request)
+
+ if __name__ == "__main__":
+     uvicorn.run(app, host="0.0.0.0", port=7860)