hiddenFront's picture
Update app.py
ec61894 verified
raw
history blame
2.57 kB
from fastapi import FastAPI, Request
from transformers import BertForSequenceClassification, AutoTokenizer
from huggingface_hub import hf_hub_download
import torch
import pickle
import os
import sys
import psutil
app = FastAPI()

# Inference is CPU-only in this deployment.
device = torch.device("cpu")

# Process handle used for the memory checkpoints printed during startup.
_PROCESS = psutil.Process(os.getpid())


def _rss_mb() -> float:
    """Return this process's resident set size in megabytes."""
    return _PROCESS.memory_info().rss / (1024 * 1024)


# Load the label mapping produced at training time.
# NOTE(review): pickle.load is only acceptable because category.pkl ships
# with this app — never load a pickle from an untrusted source.
try:
    with open("category.pkl", "rb") as f:
        category = pickle.load(f)
    print("category.pkl loaded successfully.")
except FileNotFoundError:
    print("Error: category.pkl not found.")
    sys.exit(1)

# Tokenizer matching the base model.
tokenizer = AutoTokenizer.from_pretrained("skt/kobert-base-v1")
print("Tokenizer loaded successfully.")

# Rebuild the classifier with a head sized to the number of labels.
num_labels = len(category)
model = BertForSequenceClassification.from_pretrained(
    "skt/kobert-base-v1", num_labels=num_labels
)
model.to(device)

HF_MODEL_REPO_ID = "hiddenFront/TextClassifier"
HF_MODEL_FILENAME = "textClassifierModel.pt"

print(f"Memory usage before model download: {_rss_mb():.2f} MB")

# Download the fine-tuned weights from the Hub and load them into the model.
try:
    model_path = hf_hub_download(repo_id=HF_MODEL_REPO_ID, filename=HF_MODEL_FILENAME)
    print(f"Model file downloaded: {model_path}")
    print(f"Memory usage after download: {_rss_mb():.2f} MB")

    # NOTE(review): consider torch.load(..., weights_only=True) — the file is
    # presumably a plain state_dict, so only tensors should be needed; confirm
    # before enabling.
    state_dict = torch.load(model_path, map_location=device)
    model.load_state_dict(state_dict)
    model.eval()

    print(f"Memory usage after model load: {_rss_mb():.2f} MB")
    print("Model loaded and ready.")
except Exception as e:
    # Startup is all-or-nothing: without weights the API cannot serve.
    print(f"Error while loading model: {e}")
    sys.exit(1)
# Prediction API
@app.post("/predict")
async def predict_api(request: Request):
    """Classify the posted text and return its predicted category label.

    Expects a JSON body of the form {"text": "..."}. Responds with
    {"text": ..., "classification": <label>} on success, or
    {"error": ..., "classification": "null"} on bad input.
    """
    # A missing or malformed JSON body would otherwise raise and surface as
    # an HTTP 500; answer with the endpoint's own error shape instead.
    try:
        data = await request.json()
    except Exception:
        return {"error": "No text provided", "classification": "null"}

    text = data.get("text") if isinstance(data, dict) else None
    if not text:
        return {"error": "No text provided", "classification": "null"}

    # Tokenize to a fixed length of 64 tokens (pad or truncate as needed).
    encoded = tokenizer.encode_plus(
        text,
        max_length=64,
        padding='max_length',
        truncation=True,
        return_tensors='pt',
    )
    # Keep inputs on the same device as the model — a no-op on CPU today,
    # but stays correct if `device` is ever switched to a GPU.
    encoded = {name: tensor.to(device) for name, tensor in encoded.items()}

    with torch.no_grad():
        outputs = model(**encoded)
        probs = torch.nn.functional.softmax(outputs.logits, dim=1)
        predicted = torch.argmax(probs, dim=1).item()

    # NOTE(review): assumes `category` preserves training-time insertion order
    # so that key position i corresponds to class index i — confirm upstream.
    label = list(category.keys())[predicted]
    return {"text": text, "classification": label}