Spaces:
Sleeping
Sleeping
File size: 6,779 Bytes
7f17fe7 95b43d8 7f17fe7 95b43d8 e66afc2 3cc319e 8153817 3cc319e 7f17fe7 8153817 3cc319e 8153817 7f17fe7 8153817 7f17fe7 8153817 7f17fe7 8153817 7f17fe7 3cc319e 8153817 6ba018e 95b43d8 8153817 3cc319e 95b43d8 3cc319e 8153817 95b43d8 8153817 6ba018e 95b43d8 7f17fe7 95b43d8 8153817 3cc319e 8153817 7f17fe7 8153817 95b43d8 8153817 7f17fe7 95b43d8 4607c9c 689eabe 95b43d8 689eabe 7f17fe7 95b43d8 8153817 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 |
from fastapi import FastAPI, Request
from pydantic import BaseModel
import torch
import pickle
import gluonnlp as nlp
import numpy as np
import os
import sys # added so that sys.exit() below does not raise NameError
# Use transformers.AutoTokenizer instead of KoBERTTokenizer
from transformers import BertModel, AutoTokenizer # keep the AutoTokenizer import
from torch.utils.data import Dataset, DataLoader
import logging # keep the logging module import
from huggingface_hub import hf_hub_download # added: hf_hub_download import
import collections # keep the collections module import
# --- 1. BERTClassifier model class (moved here from model.py) ---
class BERTClassifier(torch.nn.Module):
    """Linear classification head on top of a BERT encoder's pooled output.

    Args:
        bert: Encoder callable. ``forward`` invokes it with ``input_ids``,
            ``token_type_ids``, ``attention_mask`` and ``return_dict=False``
            and uses the second element of the returned pair.
        hidden_size: Width of the pooled vector fed to the linear layer.
        num_classes: Number of classes to predict (should match the size of
            the ``category`` dictionary).
        dr_rate: Dropout probability applied to the pooled vector. A falsy
            value (``None`` or ``0``) disables dropout.
        params: Unused; kept so existing call sites keep working.
    """

    def __init__(self,
                 bert,
                 hidden_size=768,
                 num_classes=5,
                 dr_rate=None,
                 params=None):
        super().__init__()
        self.bert = bert
        self.dr_rate = dr_rate

        self.classifier = torch.nn.Linear(hidden_size, num_classes)
        # The dropout layer only exists when a non-zero rate was requested.
        if dr_rate:
            self.dropout = torch.nn.Dropout(p=dr_rate)

    def gen_attention_mask(self, token_ids, valid_length):
        """Build the padding mask for a batch.

        Args:
            token_ids: 2-D tensor of token ids, one row per sequence.
            valid_length: One non-negative length per row of ``token_ids``
                (tensor, array or sequence of ints).

        Returns:
            Float tensor on ``token_ids``' device, shaped like ``token_ids``,
            holding 1.0 in the first ``valid_length[i]`` positions of row
            ``i`` and 0.0 elsewhere.
        """
        lengths = torch.as_tensor(valid_length, device=token_ids.device).reshape(-1, 1)
        positions = torch.arange(token_ids.size(1), device=token_ids.device)
        # One broadcast comparison replaces the per-row Python loop.
        return (positions.unsqueeze(0) < lengths).float()

    def forward(self, token_ids, valid_length, segment_ids):
        """Return the class logits for a batch.

        Args:
            token_ids: 2-D tensor of token ids.
            valid_length: Number of real (non-padding) tokens per row.
            segment_ids: Token-type ids; cast to ``long`` before use.

        Returns:
            Output of the linear classifier applied to the (optionally
            dropped-out) pooled encoder output.
        """
        # The mask is already float and on token_ids' device.
        attention_mask = self.gen_attention_mask(token_ids, valid_length)

        _, pooler = self.bert(
            input_ids=token_ids,
            token_type_ids=segment_ids.long(),
            attention_mask=attention_mask,
            return_dict=False,
        )

        out = self.dropout(pooler) if self.dr_rate else pooler
        return self.classifier(out)
# --- 2. BERTDataset class (moved here from dataset.py) ---
class BERTDataset(Dataset):
    """Dataset of BERT-encoded sentences paired with integer labels.

    Every row of ``dataset`` is indexed with ``sent_idx`` to get the sentence
    and with ``label_idx`` to get the label. Sentences are encoded once, at
    construction time, with ``nlp.data.BERTSentenceTransform``; labels are
    converted to ``np.int32``.

    Args:
        dataset: Re-iterable collection of indexable rows (it is walked twice).
        sent_idx: Position of the sentence inside a row.
        label_idx: Position of the label inside a row.
        bert_tokenizer: Tokenizer handed to ``BERTSentenceTransform``.
        vocab: Vocabulary handed to ``BERTSentenceTransform``.
        max_len: Forwarded as ``max_seq_length``.
        pad: Forwarded as ``pad``.
        pair: Forwarded as ``pair``.
    """

    def __init__(self, dataset, sent_idx, label_idx, bert_tokenizer, vocab, max_len, pad, pair):
        # nlp.data.BERTSentenceTransform takes a tokenizer function; the
        # AutoTokenizer's tokenize method is meant to be passed in directly.
        to_features = nlp.data.BERTSentenceTransform(
            bert_tokenizer,
            max_seq_length=max_len,
            vocab=vocab,
            pad=pad,
            pair=pair,
        )

        # First pass: encode every sentence.
        encoded = []
        for record in dataset:
            encoded.append(to_features([record[sent_idx]]))
        self.sentences = encoded

        # Second pass: collect the labels as int32.
        targets = []
        for record in dataset:
            targets.append(np.int32(record[label_idx]))
        self.labels = targets

    def __getitem__(self, i):
        """Return the encoded sentence tuple with its label appended."""
        features = self.sentences[i]
        return features + (self.labels[i],)

    def __len__(self):
        """Return the number of samples."""
        return len(self.labels)
# --- 3. FastAPI app and global variable setup ---
app = FastAPI()
device = torch.device("cpu") # Render's free tier mostly runs on CPU.
# Load the category mapping and the vocabulary.
# Both pickle files must sit in the root of the GitHub repository.
def _load_pickle_or_exit(filename):
    """Unpickle ``filename`` or stop the service if the file is missing.

    Args:
        filename: Path of the pickle file, resolved against the current
            working directory.

    Returns:
        The unpickled object. A success message is printed first.

    If the file does not exist, an error message is printed and the process
    exits with status 1 so the service does not start without it.
    """
    try:
        # NOTE(security): pickle.load can execute arbitrary code; only ship
        # pickle files that come from a trusted source.
        with open(filename, "rb") as f:
            loaded = pickle.load(f)
    except FileNotFoundError:
        print(f"Error: {filename} 파일을 찾을 수 없습니다. 프로젝트 루트에 있는지 확인하세요.")
        sys.exit(1)  # do not start the service without the file
    print(f"{filename} 로드 성공.")
    return loaded


category = _load_pickle_or_exit("category.pkl")
vocab = _load_pickle_or_exit("vocab.pkl")
# Load the tokenizer (via transformers.AutoTokenizer).
# AutoTokenizer is used instead of KoBERTTokenizer to load the KoBERT model's
# tokenizer; this avoids the XLNetTokenizer warning and the kobert_tokenizer
# installation problem.
tokenizer = AutoTokenizer.from_pretrained('skt/kobert-base-v1')
print("토크나이저 로드 성공.")
# Build the model.
# num_classes must match the size of the category dictionary.
bertmodel = BertModel.from_pretrained('skt/kobert-base-v1')
model = BERTClassifier(
    bertmodel,
    dr_rate=0.5,  # change this to the dr_rate value used during training
    num_classes=len(category),
)
# Download textClassifierModel.pt from the Hugging Face Hub and load it into
# the classifier. Any failure stops the service from starting.
HF_MODEL_REPO_ID = "hiddenFront/TextClassifier"  # the user's actual Hugging Face repository ID
HF_MODEL_FILENAME = "textClassifierModel.pt"
try:
    model_path = hf_hub_download(repo_id=HF_MODEL_REPO_ID, filename=HF_MODEL_FILENAME)
    print(f"모델 파일이 '{model_path}'에 성공적으로 다운로드되었습니다.")

    # NOTE(security): torch.load may unpickle arbitrary objects; only load
    # checkpoints from a repository you trust.
    loaded_state_dict = torch.load(model_path, map_location=device)

    # Drop a leading 'module.' from parameter names (presumably left by a
    # DataParallel-style wrapper at training time - TODO confirm) so the keys
    # match this unwrapped model.
    wrapper_prefix = 'module.'
    new_state_dict = collections.OrderedDict(
        (key[len(wrapper_prefix):] if key.startswith(wrapper_prefix) else key, value)
        for key, value in loaded_state_dict.items()
    )

    model.load_state_dict(new_state_dict)
    model.to(device)  # move the model to the target device
    model.eval()  # switch to inference mode
    print("모델 로드 성공.")
except Exception as e:
    print(f"Error: 모델 다운로드 또는 로드 중 오류 발생: {e}")
    sys.exit(1)  # do not start the service when the model cannot be loaded
# Parameters needed to build the dataset
max_len = 64  # handed to BERTDataset as the maximum sequence length
batch_size = 32  # handed to the DataLoader
# Prediction function
def predict(predict_sentence):
    """Classify one sentence and return the predicted category name.

    Args:
        predict_sentence: Text to classify.

    Returns:
        The key of ``category`` found at the position of the highest logit.
        Assumes the key order of ``category`` matches the class indices used
        during training - TODO confirm against the training code.
    """
    # BERTDataset expects (sentence, label) rows; '0' is a placeholder label.
    dataset_another = [[predict_sentence, '0']]

    # BERTSentenceTransform calls the tokenizer as a function and expects a
    # list of tokens back. Calling the Hugging Face tokenizer object itself
    # returns an encoding mapping instead, so pass its tokenize method.
    another_test = BERTDataset(dataset_another, 0, 1, tokenizer.tokenize, vocab, max_len, True, False)
    # num_workers=0 is recommended in the deployment environment.
    test_dataLoader = DataLoader(another_test, batch_size=batch_size, num_workers=0)

    model.eval()  # make sure the model is in evaluation mode
    with torch.no_grad():  # no gradients are needed for inference
        for token_ids, valid_length, segment_ids, _label in test_dataLoader:
            token_ids = token_ids.long().to(device)
            segment_ids = segment_ids.long().to(device)

            logits = model(token_ids, valid_length, segment_ids)
            logits = logits.detach().cpu().numpy()

            predicted_category_index = np.argmax(logits)
            predicted_category_name = list(category.keys())[predicted_category_index]

    return predicted_category_name
# Endpoint definitions
# Request body for POST /predict: carries the text to classify.
class InputText(BaseModel):
    text: str
@app.get("/")
def root():
    # Landing endpoint: answers with a fixed banner naming the service.
    banner = {"message": "Text Classification API (KoBERT)"}
    return banner
@app.post("/predict")
async def predict_route(item: InputText):
    # Classify item.text and return it together with the predicted category.
    # predict() does blocking CPU inference; running it directly in this
    # coroutine would stall the event loop (and every other request) for the
    # whole forward pass, so it is handed to the worker thread pool.
    from fastapi.concurrency import run_in_threadpool

    result = await run_in_threadpool(predict, item.text)
    return {"text": item.text, "classification": result}
|