from fastapi import FastAPI, Request
from pydantic import BaseModel
import torch
import pickle
import sys          # used below to abort startup when required files are missing
import collections  # used below to normalize state_dict keys
import gluonnlp as nlp
import numpy as np
import os
from kobert_tokenizer import KoBERTTokenizer  # keep the kobert_tokenizer import
from transformers import BertModel            # keep the BertModel import
from torch.utils.data import Dataset, DataLoader  # DataLoader import added
import logging  # keep the logging module import
# --- 1. BERTClassifier model class definition (moved from model.py) ---
# This class defines the model architecture.
class BERTClassifier(torch.nn.Module):
    def __init__(self,
                 bert,
                 hidden_size=768,
                 num_classes=5,  # number of classes (must match the size of the category dict)
                 dr_rate=None,
                 params=None):
        super(BERTClassifier, self).__init__()
        self.bert = bert
        self.dr_rate = dr_rate
        self.classifier = torch.nn.Linear(hidden_size, num_classes)
        if dr_rate:
            self.dropout = torch.nn.Dropout(p=dr_rate)

    def gen_attention_mask(self, token_ids, valid_length):
        attention_mask = torch.zeros_like(token_ids)
        for i, v in enumerate(valid_length):
            attention_mask[i][:v] = 1
        return attention_mask.float()

    def forward(self, token_ids, valid_length, segment_ids):
        attention_mask = self.gen_attention_mask(token_ids, valid_length)
        # Adjusted to BertModel's output structure: with return_dict=False,
        # Hugging Face's BertModel returns (last_hidden_state, pooler_output, ...).
        # We use pooler_output (the CLS token's final hidden state passed
        # through the pooler).
        _, pooler = self.bert(input_ids=token_ids,
                              token_type_ids=segment_ids.long(),
                              attention_mask=attention_mask.float().to(token_ids.device),
                              return_dict=False)
        if self.dr_rate:
            out = self.dropout(pooler)
        else:
            out = pooler
        return self.classifier(out)
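
# Note: the per-row loop in gen_attention_mask can be vectorized. A minimal
# sketch (not used above; assumes valid_length is, or is cast to, a 1-D
# tensor on the same device as token_ids):
#
#   positions = torch.arange(token_ids.size(1), device=token_ids.device)
#   attention_mask = (positions.unsqueeze(0) < valid_length.unsqueeze(1)).float()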
# --- 2. BERTDataset class definition (moved from dataset.py) ---
# This class converts raw data into the model's input format.
class BERTDataset(Dataset):
    def __init__(self, dataset, sent_idx, label_idx, bert_tokenizer, vocab, max_len, pad, pair):
        transform = nlp.data.BERTSentenceTransform(
            bert_tokenizer, max_seq_length=max_len, vocab=vocab, pad=pad, pair=pair
        )
        self.sentences = [transform([i[sent_idx]]) for i in dataset]
        self.labels = [np.int32(i[label_idx]) for i in dataset]

    def __getitem__(self, i):
        return (self.sentences[i] + (self.labels[i],))

    def __len__(self):
        return len(self.labels)
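
# Each item yields (token_ids, valid_length, segment_ids, label). A minimal
# usage sketch with hypothetical sample data (tokenizer and vocab are loaded
# further below):
#
#   sample = [["๋ถ„๋ฅ˜ํ•  ๋ฌธ์žฅ ์˜ˆ์‹œ", "0"]]
#   ds = BERTDataset(sample, 0, 1, tokenizer, vocab, max_len=64, pad=True, pair=False)
#   token_ids, valid_length, segment_ids, label = ds[0]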
# --- 3. FastAPI app and global setup ---
app = FastAPI()
device = torch.device("cpu")  # Render's free tier mostly runs on CPU.

# โœ… Load category (must exist at the GitHub repository root)
try:
    with open("category.pkl", "rb") as f:
        category = pickle.load(f)
    print("category.pkl loaded successfully.")
except FileNotFoundError:
    print("Error: category.pkl not found. Make sure it is at the project root.")
    sys.exit(1)  # do not start the service without this file

# โœ… Load vocab (must exist at the GitHub repository root)
try:
    with open("vocab.pkl", "rb") as f:
        vocab = pickle.load(f)
    print("vocab.pkl loaded successfully.")
except FileNotFoundError:
    print("Error: vocab.pkl not found. Make sure it is at the project root.")
    sys.exit(1)  # do not start the service without this file

# โœ… Load the tokenizer (using kobert_tokenizer)
# Kept as-is because this is the approach used in the Colab code.
tokenizer = KoBERTTokenizer.from_pretrained('skt/kobert-base-v1')
print("Tokenizer loaded successfully.")
# โœ… ๋ชจ๋ธ ๋กœ๋“œ
# ๋ชจ๋ธ ์•„ํ‚คํ…์ฒ˜๋ฅผ ์ •์˜ํ•˜๊ณ , ์ €์žฅ๋œ state_dict๋ฅผ ๋กœ๋“œํ•ฉ๋‹ˆ๋‹ค.
# num_classes๋Š” category ๋”•์…”๋„ˆ๋ฆฌ์˜ ํฌ๊ธฐ์™€ ์ผ์น˜ํ•ด์•ผ ํ•ฉ๋‹ˆ๋‹ค.
model = BERTClassifier(
BertModel.from_pretrained('skt/kobert-base-v1'),
dr_rate=0.5, # ํ•™์Šต ์‹œ ์‚ฌ์šฉ๋œ dr_rate ๊ฐ’์œผ๋กœ ๋ณ€๊ฒฝํ•˜์„ธ์š”.
num_classes=len(category)
)
# textClassifierModel.pt ํŒŒ์ผ ๋กœ๋“œ
# ์ด ํŒŒ์ผ์€ GitHub ์ €์žฅ์†Œ์— ์—†์–ด์•ผ ํ•˜๋ฉฐ, Dockerfile์—์„œ Hugging Face Hub์—์„œ ๋‹ค์šด๋กœ๋“œํ•˜๋„๋ก ์„ค์ •๋˜์–ด ์žˆ์Šต๋‹ˆ๋‹ค.
try:
# Dockerfile์—์„œ ๋ชจ๋ธ์„ ๋‹ค์šด๋กœ๋“œํ•˜๋„๋ก ์„ค์ •ํ–ˆ์œผ๋ฏ€๋กœ, ์—ฌ๊ธฐ์„œ๋Š” ๋กœ์ปฌ ๊ฒฝ๋กœ๋ฅผ ์‚ฌ์šฉํ•ฉ๋‹ˆ๋‹ค.
# ๋งŒ์•ฝ Dockerfile์—์„œ hf_hub_download๋ฅผ ์‚ฌ์šฉํ•˜์ง€ ์•Š๋Š”๋‹ค๋ฉด, ์—ฌ๊ธฐ์— hf_hub_download๋ฅผ ์ถ”๊ฐ€ํ•ด์•ผ ํ•ฉ๋‹ˆ๋‹ค.
# ํ˜„์žฌ Dockerfile์€ git+https://github.com/SKTBrain/KOBERT#egg=kobert_tokenizer ๋กœ๋“œ๋งŒ ํฌํ•จํ•˜๊ณ ,
# ๋ชจ๋ธ ํŒŒ์ผ ๋‹ค์šด๋กœ๋“œ๋Š” ํฌํ•จํ•˜์ง€ ์•Š์Šต๋‹ˆ๋‹ค.
# ๋”ฐ๋ผ์„œ, ๋ชจ๋ธ ํŒŒ์ผ์„ Hugging Face Hub์—์„œ ๋‹ค์šด๋กœ๋“œํ•˜๋Š” ๋กœ์ง์„ ๋‹ค์‹œ ์ถ”๊ฐ€ํ•ด์•ผ ํ•ฉ๋‹ˆ๋‹ค.
from huggingface_hub import hf_hub_download
HF_MODEL_REPO_ID = "hiddenFront/TextClassifier" # ์‚ฌ์šฉ์ž๋‹˜์˜ ์‹ค์ œ Hugging Face ์ €์žฅ์†Œ ID
HF_MODEL_FILENAME = "textClassifierModel.pt"
model_path = hf_hub_download(repo_id=HF_MODEL_REPO_ID, filename=HF_MODEL_FILENAME)
print(f"๋ชจ๋ธ ํŒŒ์ผ์ด '{model_path}'์— ์„ฑ๊ณต์ ์œผ๋กœ ๋‹ค์šด๋กœ๋“œ๋˜์—ˆ์Šต๋‹ˆ๋‹ค.")
# ๋ชจ๋ธ์˜ state_dict๋ฅผ ๋กœ๋“œํ•ฉ๋‹ˆ๋‹ค.
loaded_state_dict = torch.load(model_path, map_location=device)
# state_dict ํ‚ค ์กฐ์ • (ํ•„์š”ํ•œ ๊ฒฝ์šฐ)
new_state_dict = collections.OrderedDict()
for k, v in loaded_state_dict.items():
name = k
if name.startswith('module.'):
name = name[7:]
new_state_dict[name] = v
model.load_state_dict(new_state_dict)
model.to(device) # ๋ชจ๋ธ์„ ๋””๋ฐ”์ด์Šค๋กœ ์ด๋™
model.eval() # ์ถ”๋ก  ๋ชจ๋“œ๋กœ ์„ค์ •
print("๋ชจ๋ธ ๋กœ๋“œ ์„ฑ๊ณต.")
except Exception as e:
print(f"Error: ๋ชจ๋ธ ๋‹ค์šด๋กœ๋“œ ๋˜๋Š” ๋กœ๋“œ ์ค‘ ์˜ค๋ฅ˜ ๋ฐœ์ƒ: {e}")
sys.exit(1) # ๋ชจ๋ธ ๋กœ๋“œ ์‹คํŒจ ์‹œ ์„œ๋น„์Šค ์‹œ์ž‘ํ•˜์ง€ ์•Š์Œ
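
# Optional sanity check (a sketch, not in the original): before serving,
# confirm the checkpoint's classifier head matches the category dictionary:
#
#   assert new_state_dict["classifier.weight"].shape[0] == len(category)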
# โœ… ๋ฐ์ดํ„ฐ์…‹ ์ƒ์„ฑ์— ํ•„์š”ํ•œ ํŒŒ๋ผ๋ฏธํ„ฐ
max_len = 64
batch_size = 32
# โœ… ์˜ˆ์ธก ํ•จ์ˆ˜
def predict(predict_sentence):
data = [predict_sentence, '0']
dataset_another = [data]
# num_workers๋Š” ๋ฐฐํฌ ํ™˜๊ฒฝ์—์„œ 0์œผ๋กœ ์„ค์ • ๊ถŒ์žฅ
another_test = BERTDataset(dataset_another, 0, 1, tokenizer, vocab, max_len, True, False)
test_dataLoader = DataLoader(another_test, batch_size=batch_size, num_workers=0)
model.eval() # ์˜ˆ์ธก ์‹œ ๋ชจ๋ธ์„ ํ‰๊ฐ€ ๋ชจ๋“œ๋กœ ์„ค์ •
with torch.no_grad(): # ๊ทธ๋ผ๋””์–ธํŠธ ๊ณ„์‚ฐ ๋น„ํ™œ์„ฑํ™”
for batch_id, (token_ids, valid_length, segment_ids, label) in enumerate(test_dataLoader):
token_ids = token_ids.long().to(device)
segment_ids = segment_ids.long().to(device)
out = model(token_ids, valid_length, segment_ids)
logits = out
logits = logits.detach().cpu().numpy()
predicted_category_index = np.argmax(logits)
predicted_category_name = list(category.keys())[predicted_category_index]
return predicted_category_name
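
# To return class probabilities instead of only the top label, one could apply
# a softmax over the logits (a sketch, not part of the original behavior):
#
#   probs = torch.softmax(torch.from_numpy(logits), dim=-1).squeeze(0)
#   scores = {name: float(p) for name, p in zip(category.keys(), probs)}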
# โœ… ์—”๋“œํฌ์ธํŠธ ์ •์˜
class InputText(BaseModel):
text: str
@app.get("/")
def root():
return {"message": "Text Classification API (KoBERT)"}
@app.post("/predict")
async def predict_route(item: InputText):
result = predict(item.text)
return {"text": item.text, "classification": result}