Spaces:
Sleeping
Sleeping
File size: 5,369 Bytes
7f17fe7 95b43d8 7f17fe7 95b43d8 e66afc2 1efa28d 3cc319e 1efa28d 3cc319e 8153817 1efa28d 8153817 1efa28d 7f17fe7 1efa28d 7f17fe7 8153817 7f17fe7 8153817 7f17fe7 3cc319e 8153817 6ba018e 1efa28d 8153817 1efa28d 8153817 1efa28d 8153817 6ba018e 1efa28d 95b43d8 7f17fe7 95b43d8 8153817 1efa28d 8153817 7f17fe7 8153817 95b43d8 8153817 7f17fe7 95b43d8 4607c9c 689eabe 95b43d8 689eabe 7f17fe7 95b43d8 8153817 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 |
from fastapi import FastAPI, Request
from pydantic import BaseModel
import torch
import pickle
import gluonnlp as nlp
import numpy as np
import os
import sys # ์ค๋ฅ ์ ์๋น์ค ์ข
๋ฃ๋ฅผ ์ํด sys ๋ชจ๋ ์ํฌํธ
# transformers์ AutoTokenizer๋ง ์ฌ์ฉํฉ๋๋ค.
from transformers import AutoTokenizer # BertModel, BertForSequenceClassification ๋ฑ์ ์ด์ ์ง์ ํ์ ์์ต๋๋ค.
from torch.utils.data import Dataset, DataLoader
import logging # ๋ก๊น
๋ชจ๋ ์ํฌํธ ์ ์ง
from huggingface_hub import hf_hub_download # hf_hub_download ์ํฌํธ ์ ์ง
# collections ๋ชจ๋์ ๋ ์ด์ ํ์ ์์ ์ ์์ง๋ง, ํน์ ๋ชฐ๋ผ ์ ์งํฉ๋๋ค.
import collections
# --- 1. FastAPI ์ฑ ๋ฐ ์ ์ญ ๋ณ์ ์ค์ ---
app = FastAPI()
device = torch.device("cpu") # Hugging Face Spaces์ ๋ฌด๋ฃ ํฐ์ด๋ ์ฃผ๋ก CPU๋ฅผ ์ฌ์ฉํฉ๋๋ค.
# โ
category ๋ก๋ (GitHub ์ ์ฅ์ ๋ฃจํธ์ ์์ด์ผ ํจ)
try:
with open("category.pkl", "rb") as f:
category = pickle.load(f)
print("category.pkl ๋ก๋ ์ฑ๊ณต.")
except FileNotFoundError:
print("Error: category.pkl ํ์ผ์ ์ฐพ์ ์ ์์ต๋๋ค. ํ๋ก์ ํธ ๋ฃจํธ์ ์๋์ง ํ์ธํ์ธ์.")
sys.exit(1) # ํ์ผ ์์ผ๋ฉด ์๋น์ค ์์ํ์ง ์์
# โ
vocab ๋ก๋ (GitHub ์ ์ฅ์ ๋ฃจํธ์ ์์ด์ผ ํจ)
try:
with open("vocab.pkl", "rb") as f:
vocab = pickle.load(f)
print("vocab.pkl ๋ก๋ ์ฑ๊ณต.")
except FileNotFoundError:
print("Error: vocab.pkl ํ์ผ์ ์ฐพ์ ์ ์์ต๋๋ค. ํ๋ก์ ํธ ๋ฃจํธ์ ์๋์ง ํ์ธํ์ธ์.")
sys.exit(1) # ํ์ผ ์์ผ๋ฉด ์๋น์ค ์์ํ์ง ์์
# โ
ํ ํฌ๋์ด์ ๋ก๋ (transformers.AutoTokenizer ์ฌ์ฉ)
tokenizer = AutoTokenizer.from_pretrained('skt/kobert-base-v1')
print("ํ ํฌ๋์ด์ ๋ก๋ ์ฑ๊ณต.")
# โ
๋ชจ๋ธ ๋ก๋ (Hugging Face Hub์์ ๋ค์ด๋ก๋)
# textClassifierModel.pt ํ์ผ์ ์ด๋ฏธ ๊ฒฝ๋ํ๋ '์์ ํ ๋ชจ๋ธ ๊ฐ์ฒด'๋ผ๊ณ ๊ฐ์ ํ๊ณ ์ง์ ๋ก๋ํฉ๋๋ค.
try:
HF_MODEL_REPO_ID = "hiddenFront/TextClassifier" # ์ฌ์ฉ์๋์ ์ค์ Hugging Face ์ ์ฅ์ ID
HF_MODEL_FILENAME = "textClassifierModel.pt" # Hugging Face Hub์ ์
๋ก๋ํ ํ์ผ ์ด๋ฆ๊ณผ ์ผ์นํด์ผ ํฉ๋๋ค.
model_path = hf_hub_download(repo_id=HF_MODEL_REPO_ID, filename=HF_MODEL_FILENAME)
print(f"๋ชจ๋ธ ํ์ผ์ด '{model_path}'์ ์ฑ๊ณต์ ์ผ๋ก ๋ค์ด๋ก๋๋์์ต๋๋ค.")
# --- ์์ ๋ ํต์ฌ ๋ถ๋ถ ---
# ๊ฒฝ๋ํ๋ ๋ชจ๋ธ ๊ฐ์ฒด๋ฅผ ์ง์ ๋ก๋ํฉ๋๋ค.
# ์ด ํ์ผ์ ์ด๋ฏธ PyTorch ๋ชจ๋ธ ๊ฐ์ฒด(์์ํ๋ ๋ชจ๋ธ ํฌํจ)์ด๋ฏ๋ก ๋ฐ๋ก ๋ก๋ํ์ฌ ์ฌ์ฉํฉ๋๋ค.
model = torch.load(model_path, map_location=device)
# --- ์์ ๋ ํต์ฌ ๋ถ๋ถ ๋ ---
model.eval() # ์ถ๋ก ๋ชจ๋๋ก ์ค์
print("๋ชจ๋ธ ๋ก๋ ์ฑ๊ณต.")
except Exception as e:
print(f"Error: ๋ชจ๋ธ ๋ค์ด๋ก๋ ๋๋ ๋ก๋ ์ค ์ค๋ฅ ๋ฐ์: {e}")
sys.exit(1) # ๋ชจ๋ธ ๋ก๋ ์คํจ ์ ์๋น์ค ์์ํ์ง ์์
# --- 2. BERTDataset ํด๋์ค ์ ์ (dataset.py์์ ์ฎ๊ฒจ์ด) ---
# ์ด ํด๋์ค๋ ๋ฐ์ดํฐ๋ฅผ ๋ชจ๋ธ ์
๋ ฅ ํ์์ผ๋ก ๋ณํํฉ๋๋ค.
class BERTDataset(Dataset):
def __init__(self, dataset, sent_idx, label_idx, bert_tokenizer, vocab, max_len, pad, pair):
# nlp.data.BERTSentenceTransform์ ํ ํฌ๋์ด์ ํจ์๋ฅผ ๋ฐ์ต๋๋ค.
# AutoTokenizer์ tokenize ๋ฉ์๋๋ฅผ ์ง์ ์ ๋ฌํฉ๋๋ค.
transform = nlp.data.BERTSentenceTransform(
bert_tokenizer, max_seq_length=max_len, vocab=vocab, pad=pad, pair=pair
)
self.sentences = [transform([i[sent_idx]]) for i in dataset]
self.labels = [np.int32(i[label_idx]) for i in dataset]
def __getitem__(self, i):
return (self.sentences[i] + (self.labels[i],))
def __len__(self):
return len(self.labels)
# โ
๋ฐ์ดํฐ์
์์ฑ์ ํ์ํ ํ๋ผ๋ฏธํฐ
max_len = 64
batch_size = 32
# โ
์์ธก ํจ์
def predict(predict_sentence):
data = [predict_sentence, '0']
dataset_another = [data]
# num_workers๋ ๋ฐฐํฌ ํ๊ฒฝ์์ 0์ผ๋ก ์ค์ ๊ถ์ฅ
# tokenizer.tokenize๋ฅผ BERTDataset์ ์ ๋ฌํฉ๋๋ค.
another_test = BERTDataset(dataset_another, 0, 1, tokenizer.tokenize, vocab, max_len, True, False)
test_dataLoader = DataLoader(another_test, batch_size=batch_size, num_workers=0)
model.eval() # ์์ธก ์ ๋ชจ๋ธ์ ํ๊ฐ ๋ชจ๋๋ก ์ค์
with torch.no_grad(): # ๊ทธ๋ผ๋์ธํธ ๊ณ์ฐ ๋นํ์ฑํ
for batch_id, (token_ids, valid_length, segment_ids, label) in enumerate(test_dataLoader):
token_ids = token_ids.long().to(device)
segment_ids = segment_ids.long().to(device)
out = model(token_ids, valid_length, segment_ids)
logits = out
logits = logits.detach().cpu().numpy()
predicted_category_index = np.argmax(logits)
predicted_category_name = list(category.keys())[predicted_category_index]
return predicted_category_name
# โ
์๋ํฌ์ธํธ ์ ์
class InputText(BaseModel):
text: str
@app.get("/")
def root():
return {"message": "Text Classification API (KoBERT)"}
@app.post("/predict")
async def predict_route(item: InputText):
result = predict(item.text)
return {"text": item.text, "classification": result}
|