from fastapi import FastAPI, Request
from transformers import BertModel, AutoTokenizer
from huggingface_hub import hf_hub_download
import torch
import pickle
import os
import sys
import psutil
app = FastAPI()
device = torch.device("cpu")

# Load category.pkl
try:
    with open("category.pkl", "rb") as f:
        category = pickle.load(f)
    print("✅ category.pkl loaded successfully.")
except FileNotFoundError:
    print("❌ Error: category.pkl file not found.")
    sys.exit(1)

# Load the tokenizer
tokenizer = AutoTokenizer.from_pretrained("skt/kobert-base-v1")
print("✅ Tokenizer loaded successfully.")

class CustomClassifier(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # Must restore the exact architecture the weights were saved from
        self.bert = BertModel.from_pretrained("skt/kobert-base-v1")
        self.classifier = torch.nn.Linear(768, len(category))

    def forward(self, input_ids, attention_mask=None, token_type_ids=None):
        outputs = self.bert(input_ids=input_ids,
                            attention_mask=attention_mask,
                            token_type_ids=token_type_ids)
        pooled_output = outputs[1]  # pooled [CLS] representation
        return self.classifier(pooled_output)
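
# Note: forward() returns raw logits of shape (batch, len(category)) rather than
# a transformers ModelOutput, so callers apply softmax/argmax to the tensor directly.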
model = CustomClassifier()
HF_MODEL_REPO_ID = "hiddenFront/TextClassifier"
HF_MODEL_FILENAME = "textClassifierModel.pt"

# Baseline memory measurement
process = psutil.Process(os.getpid())
mem_before = process.memory_info().rss / (1024 * 1024)
print(f"📦 Memory usage before model download: {mem_before:.2f} MB")

# Download the model weights
try:
    model_path = hf_hub_download(repo_id=HF_MODEL_REPO_ID, filename=HF_MODEL_FILENAME)
    print(f"✅ Model file downloaded: {model_path}")
    mem_after_dl = process.memory_info().rss / (1024 * 1024)
    print(f"📦 Memory usage after model download: {mem_after_dl:.2f} MB")

    # Load the state_dict
    state_dict = torch.load(model_path, map_location=device)
    model.load_state_dict(state_dict)
    model.eval()
    mem_after_load = process.memory_info().rss / (1024 * 1024)
    print(f"📦 Memory usage after model load: {mem_after_load:.2f} MB")
    print("✅ Model loaded and ready.")
except Exception as e:
    print(f"❌ Error while loading model: {e}")
    sys.exit(1)
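
# Local launch sketch (an assumption, not part of the original file): FastAPI apps
# are typically served with uvicorn, and 7860 is the default port on Spaces:
#   uvicorn app:app --host 0.0.0.0 --port 7860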

# Prediction API
@app.post("/predict")
async def predict_api(request: Request):
    data = await request.json()
    text = data.get("text")
    if not text:
        return {"error": "No text provided", "classification": "null"}

    encoded = tokenizer.encode_plus(
        text, max_length=64, padding='max_length', truncation=True, return_tensors='pt'
    )
    with torch.no_grad():
        logits = model(**encoded)  # raw logits; CustomClassifier has no .logits attribute
    probs = torch.nn.functional.softmax(logits, dim=1)
    predicted = torch.argmax(probs, dim=1).item()
    label = list(category.keys())[predicted]
    return {"text": text, "classification": label}