from fastapi import FastAPI
from pydantic import BaseModel
import torch
import pickle
import gluonnlp as nlp
import numpy as np
import sys  # imported so the service can exit on fatal startup errors
# Only AutoTokenizer from transformers is used.
from transformers import AutoTokenizer  # BertModel, BertForSequenceClassification, etc. are no longer needed directly
from torch.utils.data import Dataset, DataLoader
import logging  # logging import kept
from huggingface_hub import hf_hub_download  # hf_hub_download import kept
# The collections module may no longer be needed, but it is kept just in case.
import collections
# --- 1. FastAPI app and global settings ---
app = FastAPI()
device = torch.device("cpu")  # the free tier on Hugging Face Spaces mainly provides CPU
# ✅ Load category (must be at the root of the GitHub repository)
try:
    with open("category.pkl", "rb") as f:
        category = pickle.load(f)
    print("category.pkl loaded successfully.")
except FileNotFoundError:
    print("Error: category.pkl not found. Check that it is in the project root.")
    sys.exit(1)  # do not start the service if the file is missing
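# Assumption (not stated in this file): category is a dict mapping label name
# to index, e.g. {"labelA": 0, "labelB": 1, ...}, with insertion order matching
# the model's output indices; predict() below relies on
# list(category.keys())[index] to recover the label name.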
# ✅ Load vocab (must be at the root of the GitHub repository)
try:
    with open("vocab.pkl", "rb") as f:
        vocab = pickle.load(f)
    print("vocab.pkl loaded successfully.")
except FileNotFoundError:
    print("Error: vocab.pkl not found. Check that it is in the project root.")
    sys.exit(1)  # do not start the service if the file is missing
# โœ… ํ† ํฌ๋‚˜์ด์ € ๋กœ๋“œ (transformers.AutoTokenizer ์‚ฌ์šฉ)
tokenizer = AutoTokenizer.from_pretrained('skt/kobert-base-v1')
print("ํ† ํฌ๋‚˜์ด์ € ๋กœ๋“œ ์„ฑ๊ณต.")
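# Note: nlp.data.BERTSentenceTransform (used in BERTDataset below) tokenizes
# with the function passed to it and then maps tokens to ids through the
# `vocab` loaded above, so vocab.pkl must match this tokenizer's KoBERT
# vocabulary; otherwise the model receives wrong token ids.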
# โœ… ๋ชจ๋ธ ๋กœ๋“œ (Hugging Face Hub์—์„œ ๋‹ค์šด๋กœ๋“œ)
# textClassifierModel.pt ํŒŒ์ผ์€ ์ด๋ฏธ ๊ฒฝ๋Ÿ‰ํ™”๋œ '์™„์ „ํ•œ ๋ชจ๋ธ ๊ฐ์ฒด'๋ผ๊ณ  ๊ฐ€์ •ํ•˜๊ณ  ์ง์ ‘ ๋กœ๋“œํ•ฉ๋‹ˆ๋‹ค.
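# Because torch.load() below unpickles a full model object, the class the model
# was saved from must be resolvable when the load runs. If loading fails with
# "Can't get attribute 'BERTClassifier'", define the original class first. The
# sketch below is an ASSUMPTION modeled on the common KoBERT tutorial
# classifier; its forward() matches the call made in predict()
# (token_ids, valid_length, segment_ids). Adjust hidden_size, num_classes, and
# the class name to match your actual training code.
class BERTClassifier(torch.nn.Module):
    def __init__(self, bert, hidden_size=768, num_classes=len(category), dr_rate=None):
        super().__init__()
        self.bert = bert
        self.dr_rate = dr_rate
        self.classifier = torch.nn.Linear(hidden_size, num_classes)
        if dr_rate:
            self.dropout = torch.nn.Dropout(p=dr_rate)

    def gen_attention_mask(self, token_ids, valid_length):
        # Padding mask: 1 for real tokens, 0 for padding positions.
        attention_mask = torch.zeros_like(token_ids)
        for i, v in enumerate(valid_length):
            attention_mask[i][:v] = 1
        return attention_mask.float()

    def forward(self, token_ids, valid_length, segment_ids):
        attention_mask = self.gen_attention_mask(token_ids, valid_length)
        _, pooler = self.bert(
            input_ids=token_ids,
            token_type_ids=segment_ids.long(),
            attention_mask=attention_mask,
            return_dict=False,
        )
        out = self.dropout(pooler) if self.dr_rate else pooler
        return self.classifier(out)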
try:
    HF_MODEL_REPO_ID = "hiddenFront/TextClassifier"  # your actual Hugging Face repository ID
    HF_MODEL_FILENAME = "textClassifierModel.pt"  # must match the file name uploaded to the Hugging Face Hub
    model_path = hf_hub_download(repo_id=HF_MODEL_REPO_ID, filename=HF_MODEL_FILENAME)
    print(f"Model file downloaded successfully to '{model_path}'.")
    # --- modified core part ---
    # Load the quantized model object directly: the file already is a PyTorch
    # model object (quantized model included), so it can be used as-is.
    # (On PyTorch >= 2.6, torch.load may additionally need weights_only=False
    # to unpickle a full model object.)
    model = torch.load(model_path, map_location=device)
    # --- end of modified core part ---
    model.eval()  # switch to inference mode
    print("Model loaded successfully.")
except Exception as e:
    print(f"Error: failed to download or load the model: {e}")
    sys.exit(1)  # do not start the service if the model fails to load
# --- 2. BERTDataset class definition (moved here from dataset.py) ---
# This class converts raw data into the model's input format.
class BERTDataset(Dataset):
    def __init__(self, dataset, sent_idx, label_idx, bert_tokenizer, vocab, max_len, pad, pair):
        # nlp.data.BERTSentenceTransform expects a tokenizer function;
        # the AutoTokenizer's tokenize method is passed in directly.
        transform = nlp.data.BERTSentenceTransform(
            bert_tokenizer, max_seq_length=max_len, vocab=vocab, pad=pad, pair=pair
        )
        self.sentences = [transform([i[sent_idx]]) for i in dataset]
        self.labels = [np.int32(i[label_idx]) for i in dataset]

    def __getitem__(self, i):
        return (self.sentences[i] + (self.labels[i],))

    def __len__(self):
        return len(self.labels)
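# Each item is therefore (token_ids, valid_length, segment_ids, label), which
# is exactly how the DataLoader batches are unpacked in predict() below.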
# โœ… ๋ฐ์ดํ„ฐ์…‹ ์ƒ์„ฑ์— ํ•„์š”ํ•œ ํŒŒ๋ผ๋ฏธํ„ฐ
max_len = 64
batch_size = 32
# ✅ Prediction function
def predict(predict_sentence):
    data = [predict_sentence, '0']
    dataset_another = [data]
    # num_workers should be 0 in a deployed environment.
    # tokenizer.tokenize is passed into BERTDataset.
    another_test = BERTDataset(dataset_another, 0, 1, tokenizer.tokenize, vocab, max_len, True, False)
    test_dataLoader = DataLoader(another_test, batch_size=batch_size, num_workers=0)
    model.eval()  # set the model to evaluation mode for prediction
    with torch.no_grad():  # disable gradient computation
        for batch_id, (token_ids, valid_length, segment_ids, label) in enumerate(test_dataLoader):
            token_ids = token_ids.long().to(device)
            segment_ids = segment_ids.long().to(device)
            out = model(token_ids, valid_length, segment_ids)
            logits = out.detach().cpu().numpy()
            predicted_category_index = np.argmax(logits)
            predicted_category_name = list(category.keys())[predicted_category_index]
            return predicted_category_name
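# Optional helper (illustrative only; not called by the endpoints below): turn
# one row of logits, as produced in predict(), into a softmax confidence score.
def softmax_confidence(logits_row):
    # Numerically stable softmax; returns the probability of the top class.
    exp = np.exp(logits_row - np.max(logits_row))
    probs = exp / exp.sum()
    return float(np.max(probs))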
# โœ… ์—”๋“œํฌ์ธํŠธ ์ •์˜
class InputText(BaseModel):
    text: str

@app.get("/")
def root():
    return {"message": "Text Classification API (KoBERT)"}

@app.post("/predict")
async def predict_route(item: InputText):
    result = predict(item.text)
    return {"text": item.text, "classification": result}
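# Example request once the Space is running (the URL is a placeholder):
#   curl -X POST https://<your-space>.hf.space/predict \
#        -H "Content-Type: application/json" \
#        -d '{"text": "your sentence here"}'
# Expected response shape: {"text": "...", "classification": "<category name>"}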