from fastapi import FastAPI
from pydantic import BaseModel
import torch
import pickle
import gluonnlp as nlp
import numpy as np
import os
import sys
import collections
import logging  # import the logging module
from transformers import AutoTokenizer, BertModel
from torch.utils.data import Dataset, DataLoader
from huggingface_hub import hf_hub_download
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
# --- 1. BERTClassifier class definition ---
class BERTClassifier(torch.nn.Module):
    def __init__(self,
                 bert,
                 hidden_size=768,
                 num_classes=5,  # number of classes to predict (matches the size of the category dict)
                 dr_rate=None,
                 params=None):
        super(BERTClassifier, self).__init__()
        self.bert = bert
        self.dr_rate = dr_rate
        self.classifier = torch.nn.Linear(hidden_size, num_classes)
        if dr_rate:
            self.dropout = torch.nn.Dropout(p=dr_rate)

    def gen_attention_mask(self, token_ids, valid_length):
        # Build a padding mask: 1 for real tokens (up to valid_length), 0 for padding.
        attention_mask = torch.zeros_like(token_ids)
        for i, v in enumerate(valid_length):
            attention_mask[i][:v] = 1
        return attention_mask.float()

    def forward(self, token_ids, valid_length, segment_ids):
        attention_mask = self.gen_attention_mask(token_ids, valid_length)
        _, pooler = self.bert(
            input_ids=token_ids,
            token_type_ids=segment_ids.long(),
            attention_mask=attention_mask.float().to(token_ids.device),
            return_dict=False,
        )
        if self.dr_rate:
            out = self.dropout(pooler)
        else:
            out = pooler
        return self.classifier(out)
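# Note: the per-row loop in gen_attention_mask above could equivalently be
# vectorized; an optional sketch (assuming valid_length arrives as a 1-D tensor,
# not used by the code above):
#   mask = (torch.arange(token_ids.size(1), device=token_ids.device)
#           < valid_length.unsqueeze(1)).float()
# The loop form is kept above unchanged.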
# --- 2. BERTDataset class definition ---
class BERTDataset(Dataset):
    def __init__(self, dataset, sent_idx, label_idx, bert_tokenizer, vocab, max_len, pad, pair):
        transform = nlp.data.BERTSentenceTransform(
            bert_tokenizer, max_seq_length=max_len, vocab=vocab, pad=pad, pair=pair
        )
        self.sentences = [transform([i[sent_idx]]) for i in dataset]
        self.labels = [np.int32(i[label_idx]) for i in dataset]

    def __getitem__(self, i):
        return (self.sentences[i] + (self.labels[i],))

    def __len__(self):
        return len(self.labels)
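# For reference: nlp.data.BERTSentenceTransform returns a
# (token_ids, valid_length, segment_ids) tuple, so each dataset item unpacks as
# (token_ids, valid_length, segment_ids, label), which is the shape the
# DataLoader loop in predict() below expects.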
app = FastAPI()
device = torch.device("cpu")  # the free tier of Hugging Face Spaces runs on CPU
# βœ… Load the category dictionary
try:
    with open("category.pkl", "rb") as f:
        category = pickle.load(f)
    logger.info("category.pkl loaded successfully.")
except FileNotFoundError:
    logger.error("Error: category.pkl not found. Make sure it is in the project root.")
    sys.exit(1)
# βœ… Load the vocab
try:
    with open("vocab.pkl", "rb") as f:
        vocab = pickle.load(f)
    logger.info("vocab.pkl loaded successfully.")
except FileNotFoundError:
    logger.error("Error: vocab.pkl not found. Make sure it is in the project root.")
    sys.exit(1)
# βœ… ν† ν¬λ‚˜μ΄μ € λ‘œλ“œ
tokenizer = AutoTokenizer.from_pretrained('skt/kobert-base-v1')
logger.info("ν† ν¬λ‚˜μ΄μ € λ‘œλ“œ 성곡.")
# βœ… λͺ¨λΈ λ‘œλ“œ (Hugging Face Hubμ—μ„œ λ‹€μš΄λ‘œλ“œ)
try:
HF_MODEL_REPO_ID = "hiddenFront/TextClassifier"
HF_MODEL_FILENAME = "textClassifierModel.pt"
model_path = hf_hub_download(repo_id=HF_MODEL_REPO_ID, filename=HF_MODEL_FILENAME)
logger.info(f"λͺ¨λΈ 파일이 '{model_path}'에 μ„±κ³΅μ μœΌλ‘œ λ‹€μš΄λ‘œλ“œλ˜μ—ˆμŠ΅λ‹ˆλ‹€.")
bert_base_model = BertModel.from_pretrained('skt/kobert-base-v1')
model = BERTClassifier(
bert_base_model,
dr_rate=0.5, # ν•™μŠ΅ μ‹œ μ‚¬μš©λœ dr_rate κ°’μœΌλ‘œ λ³€κ²½ν•˜μ„Έμš”.
num_classes=len(category)
)
loaded_state_dict = torch.load(model_path, map_location=device)
new_state_dict = collections.OrderedDict()
for k, v in loaded_state_dict.items():
name = k
if name.startswith('module.'):
name = name[7:]
new_state_dict[name] = v
model.load_state_dict(new_state_dict, strict=False)
model.to(device)
model.eval()
logger.info("λͺ¨λΈ λ‘œλ“œ 성곡.")
except Exception as e:
logger.error(f"Error: λͺ¨λΈ λ‹€μš΄λ‘œλ“œ λ˜λŠ” λ‘œλ“œ 쀑 였λ₯˜ λ°œμƒ: {e}")
sys.exit(1)
# βœ… Parameters needed to build the dataset
max_len = 64
batch_size = 32
# βœ… Prediction function
def predict(predict_sentence):
    data = [predict_sentence, '0']  # '0' is a dummy label; only the sentence is used
    dataset_another = [data]
    another_test = BERTDataset(dataset_another, 0, 1, tokenizer.tokenize, vocab, max_len, True, False)
    test_dataLoader = DataLoader(another_test, batch_size=batch_size, num_workers=0)
    model.eval()
    with torch.no_grad():
        for batch_id, (token_ids, valid_length, segment_ids, label) in enumerate(test_dataLoader):
            token_ids = token_ids.long().to(device)
            segment_ids = segment_ids.long().to(device)
            out = model(token_ids, valid_length, segment_ids)
            logits = out  # the raw model output is the logits
            probs = torch.nn.functional.softmax(logits, dim=1)  # class probabilities
            predicted_category_index = torch.argmax(probs, dim=1).item()  # predicted index
            # Map the index back to a name; assumes category's keys are in index order.
            predicted_category_name = list(category.keys())[predicted_category_index]
            # --- Detailed prediction logging ---
            logger.info(f"Input Text: '{predict_sentence}'")
            logger.info(f"Raw Logits: {logits.tolist()}")
            logger.info(f"Probabilities: {probs.tolist()}")
            logger.info(f"Predicted Index: {predicted_category_index}")
            logger.info(f"Predicted Label: '{predicted_category_name}'")
            # --- End of detailed prediction logging ---
            return predicted_category_name
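# Quick sanity check (hypothetical input; run from a shell or test script, not at
# import time):
#   predict("sentence to classify")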
# βœ… μ—”λ“œν¬μΈνŠΈ μ •μ˜
class InputText(BaseModel):
text: str
@app.get("/")
def root():
return {"message": "Text Classification API (KoBERT)"}
@app.post("/predict")
async def predict_route(item: InputText):
result = predict(item.text)
return {"text": item.text, "classification": result}