Spaces:
Sleeping
Sleeping
File size: 2,568 Bytes
7f17fe7 ec61894 7f17fe7 e66afc2 7f17fe7 ec61894 7f17fe7 e66afc2 7f17fe7 ec61894 7f17fe7 e66afc2 7f17fe7 ec61894 7f17fe7 ec61894 7f17fe7 e66afc2 7f17fe7 ec61894 7f17fe7 e66afc2 7f17fe7 e66afc2 7f17fe7 ec61894 7f17fe7 e66afc2 ec61894 7f17fe7 ec61894 7f17fe7 e66afc2 7f17fe7 e66afc2 7f17fe7 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 |
from fastapi import FastAPI, Request
from transformers import BertForSequenceClassification, AutoTokenizer
from huggingface_hub import hf_hub_download
import torch
import pickle
import os
import sys
import psutil
# ---------------------------------------------------------------------------
# Application and model setup. Everything below runs once, at import time;
# any failure to load the label map or the model weights terminates the
# process with exit code 1.
# ---------------------------------------------------------------------------
app = FastAPI()
device = torch.device("cpu")

# Load category.pkl. Its length sets the number of output classes below and
# its keys are used as the label names returned by the API.
# NOTE(review): pickle.load can execute arbitrary code embedded in the file;
# only deploy a category.pkl that comes from a trusted source.
try:
    with open("category.pkl", "rb") as f:
        category = pickle.load(f)
    print("✅ category.pkl 로드 성공.")
except FileNotFoundError:
    print("❌ Error: category.pkl 파일을 찾을 수 없습니다.")
    sys.exit(1)

# Load the tokenizer.
tokenizer = AutoTokenizer.from_pretrained("skt/kobert-base-v1")
print("✅ 토크나이저 로드 성공.")

# Rebuild the model architecture so the fine-tuned state_dict can be loaded
# into it further down.
num_labels = len(category)  # one output per class present in category
model = BertForSequenceClassification.from_pretrained("skt/kobert-base-v1", num_labels=num_labels)
model.to(device)

HF_MODEL_REPO_ID = "hiddenFront/TextClassifier"
HF_MODEL_FILENAME = "textClassifierModel.pt"

# Memory usage before the download (process RSS divided by 1024 * 1024).
process = psutil.Process(os.getpid())
mem_before = process.memory_info().rss / (1024 * 1024)
print(f"📦 모델 다운로드 전 메모리 사용량: {mem_before:.2f} MB")

# Download the model weights and load them into the model.
try:
    model_path = hf_hub_download(repo_id=HF_MODEL_REPO_ID, filename=HF_MODEL_FILENAME)
    print(f"✅ 모델 파일 다운로드 성공: {model_path}")
    mem_after_dl = process.memory_info().rss / (1024 * 1024)
    print(f"📦 모델 다운로드 후 메모리 사용량: {mem_after_dl:.2f} MB")

    # Load the state_dict.
    # NOTE(review): torch.load unpickles the downloaded file. If the installed
    # torch version supports it, consider passing weights_only=True so only
    # tensor data is deserialized — confirm against the deployed torch version.
    state_dict = torch.load(model_path, map_location=device)
    model.load_state_dict(state_dict)
    model.eval()  # inference mode: disables dropout

    mem_after_load = process.memory_info().rss / (1024 * 1024)
    print(f"📦 모델 로드 후 메모리 사용량: {mem_after_load:.2f} MB")
    print("✅ 모델 로드 및 준비 완료.")
except Exception as e:
    # Top-level boundary: report whatever went wrong and stop the service,
    # since it cannot answer requests without the model.
    print(f"❌ Error: 모델 로드 중 오류 발생: {e}")
    sys.exit(1)

# Prediction API
@app.post("/predict")
async def predict_api(request: Request):
    """Classify the ``text`` field of a JSON request body.

    The body is expected to be a JSON object with a non-empty string under
    the key ``"text"``. The text is tokenized (padded/truncated to 64 tokens),
    run through the module-level ``model``, and the index of the largest
    logit is mapped to the corresponding key of ``category``.

    Returns:
        ``{"text": <input>, "classification": <label>}`` on success, or
        ``{"error": <message>, "classification": "null"}`` when the body is
        not valid JSON, is not a JSON object, or carries no usable text.
    """
    # A malformed body makes request.json() raise a ValueError subclass
    # (json.JSONDecodeError / UnicodeDecodeError); answer with the same error
    # shape as the other failure path instead of an unhandled 500.
    try:
        data = await request.json()
    except ValueError:
        return {"error": "Invalid JSON body", "classification": "null"}

    # Valid JSON may still be a list or scalar, which has no .get().
    text = data.get("text") if isinstance(data, dict) else None
    # Reject missing, empty, and non-string values before they reach the tokenizer.
    if not isinstance(text, str) or not text:
        return {"error": "No text provided", "classification": "null"}

    encoded = tokenizer.encode_plus(
        text, max_length=64, padding='max_length', truncation=True, return_tensors='pt'
    ).to(device)  # keep the inputs on the same device as the model

    with torch.no_grad():
        outputs = model(**encoded)

    # Softmax is monotonic, so the argmax of the raw logits selects the same
    # class as the argmax of the probabilities.
    predicted = torch.argmax(outputs.logits, dim=1).item()
    label = list(category.keys())[predicted]
    return {"text": text, "classification": label}
|