Spaces:
Sleeping
Sleeping
File size: 2,801 Bytes
7f17fe7 9dd37b1 7f17fe7 e66afc2 7f17fe7 ec61894 7f17fe7 e66afc2 7f17fe7 ec61894 7f17fe7 e66afc2 7f17fe7 6ba018e 7f17fe7 ec61894 7f17fe7 e66afc2 7f17fe7 0914de7 7f17fe7 e66afc2 7f17fe7 ec61894 0914de7 7f17fe7 ec61894 7f17fe7 ec61894 7f17fe7 4607c9c e66afc2 7f17fe7 e66afc2 7f17fe7 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 |
from fastapi import FastAPI, Request
from transformers import BertModel, BertForSequenceClassification, AutoTokenizer
from huggingface_hub import hf_hub_download
import torch
import pickle
import os
import sys
import psutil
# All inference in this service runs on the CPU.
device = torch.device("cpu")
# ASGI application; the prediction route below is registered on it.
app = FastAPI()
# Load the label mapping from category.pkl (resolved against the current
# working directory). It is used below as a mapping whose length is the number
# of classes and whose keys are returned as labels -- presumably a dict of
# label -> class index; confirm against how the file was produced.
# NOTE(review): pickle.load can execute arbitrary code; only load a
# category.pkl you trust.
try:
    with open("category.pkl", "rb") as f:
        category = pickle.load(f)
except FileNotFoundError:
    # The service cannot size or label its classifier without this file.
    print("❌ Error: category.pkl 파일을 찾을 수 없습니다.", file=sys.stderr)
    sys.exit(1)
else:
    print("✅ category.pkl 로드 성공.")
# Load the tokenizer for the KoBERT base checkpoint.
# NOTE(review): the skt/kobert-base-v1 model card loads its tokenizer through
# KoBERTTokenizer (kobert_tokenizer package); confirm that AutoTokenizer yields
# the same encoding the classifier was trained with.
tokenizer = AutoTokenizer.from_pretrained("skt/kobert-base-v1")
print("✅ 토크나이저 로드 성공.")
class CustomClassifier(torch.nn.Module):
    """KoBERT encoder followed by a single linear classification head.

    Nothing in this file instantiates this class; the service loads
    ``BertForSequenceClassification`` instead. The attribute names ``bert``
    and ``classifier`` determine the state-dict key prefixes, so renaming them
    would break loading weights saved from this class.

    Args:
        num_labels: Number of output classes. Defaults to ``len(category)``
            (the module-level label mapping) when omitted.
    """

    def __init__(self, num_labels=None):
        super().__init__()
        # Must reproduce the architecture that was defined at training time exactly.
        self.bert = BertModel.from_pretrained("skt/kobert-base-v1")
        if num_labels is None:
            num_labels = len(category)
        # Take the encoder width from its config rather than hard-coding 768
        # (the value for a BERT-base encoder).
        self.classifier = torch.nn.Linear(self.bert.config.hidden_size, num_labels)

    def forward(self, input_ids, attention_mask=None, token_type_ids=None):
        """Return class logits computed from the encoder's pooled output.

        Args:
            input_ids: Token ids passed straight to the BERT encoder.
            attention_mask: Optional mask forwarded to the encoder.
            token_type_ids: Optional segment ids forwarded to the encoder.

        Returns:
            The output of ``self.classifier`` (one logit per class).
        """
        outputs = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
        )
        # Index 1 is the pooled [CLS] representation; positional access works
        # for both tuple and ModelOutput return types.
        pooled_output = outputs[1]
        return self.classifier(pooled_output)
HF_MODEL_REPO_ID = "hiddenFront/TextClassifier"
HF_MODEL_FILENAME = "textClassifierModel.pt"

# Memory measurement: this process's resident set size (RSS), in MiB.
process = psutil.Process(os.getpid())


def _rss_mb():
    """Return this process's current resident set size in MiB."""
    return process.memory_info().rss / (1024 * 1024)


mem_before = _rss_mb()
print(f"📦 모델 다운로드 전 메모리 사용량: {mem_before:.2f} MB")

# Load the model: download the fine-tuned weights from the Hub and load them
# into a KoBERT sequence-classification architecture sized to `category`.
try:
    model_path = hf_hub_download(repo_id=HF_MODEL_REPO_ID, filename=HF_MODEL_FILENAME)
    print(f"✅ 모델 파일 다운로드 성공: {model_path}")
    # NOTE(review): torch.load unpickles the downloaded file, which can execute
    # arbitrary code; keep HF_MODEL_REPO_ID pointed at a repository you control.
    # If the checkpoint is a plain tensor state dict, weights_only=True is safer.
    state_dict = torch.load(model_path, map_location=device)
    model = BertForSequenceClassification.from_pretrained(
        "skt/kobert-base-v1",
        num_labels=len(category),
        state_dict=state_dict,
    )
    # The weights are now loaded into `model`; drop the module-level reference
    # so the checkpoint dict is not kept alive for the lifetime of the process.
    del state_dict
    model.to(device)
    model.eval()  # inference mode (disables dropout)
    print("✅ 모델 로드 및 준비 완료.")
except Exception as e:
    # Top-level startup boundary: report and stop rather than serve without a model.
    print(f"❌ Error: 모델 로드 중 오류 발생: {e}", file=sys.stderr)
    sys.exit(1)

# Report the memory cost of loading, completing the before/after measurement.
mem_after = _rss_mb()
print(
    f"📦 모델 로드 후 메모리 사용량: {mem_after:.2f} MB "
    f"({mem_after - mem_before:+.2f} MB)"
)
# Prediction API
@app.post("/predict")
async def predict_api(request: Request):
    """Classify the ``text`` field of a JSON request body.

    Expects a JSON object such as ``{"text": "..."}``.

    Returns:
        ``{"text": <input>, "classification": <label>}`` on success, where the
        label is a key of ``category``. On bad input, returns
        ``{"error": <message>, "classification": "null"}``. The value is the
        string ``"null"``, not JSON null; this is kept for existing clients.
    """
    try:
        data = await request.json()
    except ValueError:
        # Malformed or non-UTF-8 JSON would otherwise surface as an HTTP 500.
        return {"error": "Invalid JSON body", "classification": "null"}

    # A JSON array/string/number body has no .get(); treat it as missing text.
    text = data.get("text") if isinstance(data, dict) else None
    if not text:
        return {"error": "No text provided", "classification": "null"}
    if not isinstance(text, str):
        return {"error": "Text must be a string", "classification": "null"}

    encoded = tokenizer(
        text,
        max_length=64,
        padding="max_length",
        truncation=True,
        return_tensors="pt",
    ).to(device)
    with torch.no_grad():
        logits = model(**encoded).logits
    # softmax is monotonic, so argmax over the raw logits selects the same class.
    predicted = torch.argmax(logits, dim=1).item()
    # NOTE(review): assumes category's key order matches the model's label
    # indices; confirm against how category.pkl was built.
    label = list(category)[predicted]
    return {"text": text, "classification": label}
|