hiddenFront's picture
Update app.py
44d2bcd verified
raw
history blame
6.56 kB
from fastapi import FastAPI, Request
from pydantic import BaseModel
import torch
import pickle
import gluonnlp as nlp
import numpy as np
import os
import sys
import collections
import logging # ๋กœ๊น… ๋ชจ๋“ˆ ์ž„ํฌํŠธ
# Import AutoTokenizer and BertModel from transformers
from transformers import AutoTokenizer, BertModel
from torch.utils.data import Dataset, DataLoader
from huggingface_hub import hf_hub_download
# --- Logging configuration ---
# Emit records at INFO level and above, formatted as "time - level - message".
# In a production deployment the level can be raised to WARNING or ERROR to
# cut down on noise.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(__name__)
# --- 1. BERTClassifier model class ---
class BERTClassifier(torch.nn.Module):
    """BERT encoder followed by a single linear layer over its pooled output.

    Args:
        bert: Encoder module. It is called with keyword arguments
            ``input_ids``, ``token_type_ids``, ``attention_mask`` and
            ``return_dict=False``, and must return a 2-tuple whose second
            element (the pooled output) feeds the classifier.
        hidden_size: Width of the pooled output, i.e. the input size of the
            linear head.
        num_classes: Number of output logits. It should equal the number of
            entries in the ``category`` mapping.
        dr_rate: Dropout probability applied to the pooled output. A falsy
            value disables dropout, and no ``dropout`` submodule is created.
        params: Unused; accepted for signature compatibility.
    """

    def __init__(self, bert, hidden_size=768, num_classes=5, dr_rate=None, params=None):
        super().__init__()
        self.bert = bert
        self.dr_rate = dr_rate
        self.classifier = torch.nn.Linear(hidden_size, num_classes)
        # The dropout submodule exists only when a rate was given; forward()
        # tests the same truthiness before using it.
        if dr_rate:
            self.dropout = torch.nn.Dropout(p=dr_rate)

    def gen_attention_mask(self, token_ids, valid_length):
        """Build a float mask with the same shape as ``token_ids``.

        In row ``i``, the first ``valid_length[i]`` positions are 1.0 and the
        remaining positions are 0.0.
        """
        mask = torch.zeros_like(token_ids)
        for row, length in enumerate(valid_length):
            mask[row, :length] = 1
        return mask.float()

    def forward(self, token_ids, valid_length, segment_ids):
        """Return unnormalised class logits for a batch of token ids."""
        mask = self.gen_attention_mask(token_ids, valid_length)
        _sequence_output, pooled = self.bert(
            input_ids=token_ids,
            token_type_ids=segment_ids.long(),
            attention_mask=mask.float().to(token_ids.device),
            return_dict=False,
        )
        features = self.dropout(pooled) if self.dr_rate else pooled
        return self.classifier(features)
# --- 2. BERTDataset class ---
class BERTDataset(Dataset):
    """Map-style dataset of eagerly transformed sentences with integer labels.

    Each row of ``dataset`` is indexed with ``sent_idx`` to get the sentence
    and with ``label_idx`` to get the label. Sentences go through a gluonnlp
    ``BERTSentenceTransform`` in ``__init__``, and labels are cast to
    ``np.int32``.

    NOTE(review): ``dataset`` is iterated twice (once for sentences, once for
    labels), so it must be re-iterable; a one-shot iterator would yield no
    labels.
    """

    def __init__(self, dataset, sent_idx, label_idx, bert_tokenizer, vocab, max_len, pad, pair):
        to_features = nlp.data.BERTSentenceTransform(
            bert_tokenizer,
            max_seq_length=max_len,
            vocab=vocab,
            pad=pad,
            pair=pair,
        )
        # The transform is called with a one-element list holding the sentence.
        self.sentences = [to_features([row[sent_idx]]) for row in dataset]
        self.labels = [np.int32(row[label_idx]) for row in dataset]

    def __getitem__(self, i):
        """Return the transformed sentence tuple with the label appended.

        The tuple presumably holds (token_ids, valid_length, segment_ids), as
        unpacked by predict(); confirm.
        """
        features = self.sentences[i]
        label = self.labels[i]
        return features + (label,)

    def __len__(self):
        """Return the number of rows."""
        return len(self.labels)
# --- 3. FastAPI application and module-wide settings ---
# The model and input ids are moved to this device. CPU is used because the
# Hugging Face Spaces free tier mainly provides CPU hardware.
device = torch.device("cpu")
app = FastAPI()
# Load the category mapping. Its length sets the classifier's output size,
# and predict() maps an index to a label through the order of its keys.
# The file is opened relative to the current working directory. A missing
# file logs an error and exits the process; other errors propagate.
# NOTE(review): pickle.load can execute arbitrary code; category.pkl must come
# from a trusted source.
try:
    with open("category.pkl", "rb") as category_file:
        category = pickle.load(category_file)
except FileNotFoundError:
    logger.error("Error: category.pkl ํŒŒ์ผ์„ ์ฐพ์„ ์ˆ˜ ์—†์Šต๋‹ˆ๋‹ค. ํ”„๋กœ์ ํŠธ ๋ฃจํŠธ์— ์žˆ๋Š”์ง€ ํ™•์ธํ•˜์„ธ์š”.")
    sys.exit(1)
else:
    logger.info("category.pkl ๋กœ๋“œ ์„ฑ๊ณต.")
# Load the vocabulary that predict() hands to BERTDataset together with the
# tokenizer. The file is opened relative to the current working directory.
# A missing file logs an error and exits the process; other errors propagate.
# NOTE(review): pickle.load can execute arbitrary code; vocab.pkl must come
# from a trusted source.
try:
    with open("vocab.pkl", "rb") as vocab_file:
        vocab = pickle.load(vocab_file)
except FileNotFoundError:
    logger.error("Error: vocab.pkl ํŒŒ์ผ์„ ์ฐพ์„ ์ˆ˜ ์—†์Šต๋‹ˆ๋‹ค. ํ”„๋กœ์ ํŠธ ๋ฃจํŠธ์— ์žˆ๋Š”์ง€ ํ™•์ธํ•˜์„ธ์š”.")
    sys.exit(1)
else:
    logger.info("vocab.pkl ๋กœ๋“œ ์„ฑ๊ณต.")
# Load the KoBERT tokenizer. predict() passes its .tokenize method to
# BERTDataset together with the pickled `vocab`.
# NOTE(review): the tokens it produces are presumably looked up in `vocab`, so
# both should come from the same KoBERT vocabulary; confirm.
tokenizer = AutoTokenizer.from_pretrained(
    'skt/kobert-base-v1'
)
logger.info("ํ† ํฌ๋‚˜์ด์ € ๋กœ๋“œ ์„ฑ๊ณต.")
# Load the model: download the fine-tuned weights from the Hugging Face Hub and
# load them into a freshly built BERTClassifier. Any failure here is logged
# with its traceback and terminates the process.
try:
    HF_MODEL_REPO_ID = "hiddenFront/TextClassifier"
    HF_MODEL_FILENAME = "textClassifierModel.pt"
    model_path = hf_hub_download(repo_id=HF_MODEL_REPO_ID, filename=HF_MODEL_FILENAME)
    logger.info(f"๋ชจ๋ธ ํŒŒ์ผ์ด '{model_path}'์— ์„ฑ๊ณต์ ์œผ๋กœ ๋‹ค์šด๋กœ๋“œ๋˜์—ˆ์Šต๋‹ˆ๋‹ค.")

    bert_base_model = BertModel.from_pretrained('skt/kobert-base-v1')
    model = BERTClassifier(
        bert_base_model,
        dr_rate=0.5,  # Set this to the dr_rate used during training.
        num_classes=len(category),
    )

    # NOTE(review): torch.load unpickles the checkpoint, which can execute
    # arbitrary code, and this file comes from a remote repo. If the installed
    # torch supports it, consider torch.load(..., weights_only=True).
    loaded_state_dict = torch.load(model_path, map_location=device)

    # Strip a leading 'module.' from parameter names. The prefix is presumably
    # left by DataParallel/DistributedDataParallel wrapping at training time.
    module_prefix = 'module.'
    new_state_dict = collections.OrderedDict()
    for k, v in loaded_state_dict.items():
        name = k[len(module_prefix):] if k.startswith(module_prefix) else k
        new_state_dict[name] = v

    # strict=False tolerates key mismatches but hides them. A checkpoint whose
    # keys don't line up would leave layers (e.g. the classifier head) at their
    # random initialisation and produce garbage predictions. Report mismatches.
    incompatible_keys = model.load_state_dict(new_state_dict, strict=False)
    if incompatible_keys.missing_keys:
        logger.warning(
            "State dict is missing keys (left at initial values): %s",
            incompatible_keys.missing_keys,
        )
    if incompatible_keys.unexpected_keys:
        logger.warning(
            "State dict has unexpected keys (ignored): %s",
            incompatible_keys.unexpected_keys,
        )

    model.to(device)
    model.eval()  # inference mode: disables the dropout configured above
    logger.info("๋ชจ๋ธ ๋กœ๋“œ ์„ฑ๊ณต.")
except Exception as e:
    # logger.exception records the traceback as well, which logger.error dropped.
    logger.exception(f"Error: ๋ชจ๋ธ ๋‹ค์šด๋กœ๋“œ ๋˜๋Š” ๋กœ๋“œ ์ค‘ ์˜ค๋ฅ˜ ๋ฐœ์ƒ: {e}")
    sys.exit(1)
# Parameters used by predict() when building the inference dataset:
#   max_len    -> max_seq_length given to BERTSentenceTransform (pad=True there)
#   batch_size -> DataLoader batch size; predict() only ever wraps one sentence
max_len, batch_size = 64, 32
# Prediction function
def predict(predict_sentence):
    """Classify one sentence and return the predicted category name.

    The sentence is wrapped in a one-row dataset with a placeholder label '0',
    converted by BERTDataset, and run through ``model`` in eval mode with
    gradient tracking disabled. The input text, raw logits, softmax
    probabilities, argmax index and resulting label are logged at INFO level.

    NOTE(review): the user-supplied text is logged verbatim at INFO; confirm
    that is acceptable.
    NOTE(review): the index is mapped to a name via the insertion order of
    ``category``'s keys. This assumes that order matches the label indices used
    in training; confirm.
    """
    single_row = [[predict_sentence, '0']]
    features = BERTDataset(
        single_row,
        sent_idx=0,
        label_idx=1,
        bert_tokenizer=tokenizer.tokenize,
        vocab=vocab,
        max_len=max_len,
        pad=True,
        pair=False,
    )
    loader = DataLoader(features, batch_size=batch_size, num_workers=0)

    model.eval()
    with torch.no_grad():
        for token_ids, valid_length, segment_ids, _label in loader:
            token_ids = token_ids.long().to(device)
            segment_ids = segment_ids.long().to(device)
            logits = model(token_ids, valid_length, segment_ids)  # raw model output
            probs = logits.softmax(dim=1)
            predicted_category_index = probs.argmax(dim=1).item()
            label_names = list(category.keys())
            predicted_category_name = label_names[predicted_category_index]

            # Detailed per-prediction logging.
            logger.info(f"Input Text: '{predict_sentence}'")
            logger.info(f"Raw Logits: {logits.tolist()}")
            logger.info(f"Probabilities: {probs.tolist()}")
            logger.info(f"Predicted Index: {predicted_category_index}")
            logger.info(f"Predicted Label: '{predicted_category_name}'")
    return predicted_category_name
# Endpoint definitions
# Request body for POST /predict; `text` is the raw sentence passed to predict().
class InputText(BaseModel):
    text: str
@app.get("/")
def root():
    """Return a static message identifying this service."""
    service_info = {"message": "Text Classification API (KoBERT)"}
    return service_info
@app.post("/predict")
async def predict_route(item: InputText):
    """Classify the posted text and echo it back with the predicted label.

    ``predict`` performs a synchronous, CPU-bound model forward pass. Calling
    it directly inside this coroutine would block the event loop, stalling
    every other request (including ``GET /``) until inference finishes. It is
    therefore run in the worker thread pool and awaited.

    Returns:
        ``{"text": <input text>, "classification": <predicted category name>}``
    """
    # Local import so this block stays self-contained; fastapi is already a
    # dependency of this file.
    from fastapi.concurrency import run_in_threadpool

    result = await run_in_threadpool(predict, item.text)
    return {"text": item.text, "classification": result}