Update app.py

app.py CHANGED
@@ -5,35 +5,134 @@ import pickle
 import gluonnlp as nlp
 import numpy as np
 import os
-from kobert_tokenizer import KoBERTTokenizer
-from transformers import BertModel
-from torch.utils.data import Dataset
-
-import logging
+from kobert_tokenizer import KoBERTTokenizer  # kobert_tokenizer import kept
+from transformers import BertModel  # BertModel import kept
+from torch.utils.data import Dataset, DataLoader  # DataLoader import added
+import logging  # logging module import kept
+import sys  # needed for the sys.exit() calls below
+import collections  # needed for collections.OrderedDict below

+# --- 1. BERTClassifier model class definition (moved in from model.py) ---
+# This class defines the model architecture.
+class BERTClassifier(torch.nn.Module):
+    def __init__(self,
+                 bert,
+                 hidden_size=768,
+                 num_classes=5,  # number of target classes (must match the size of the category dict)
+                 dr_rate=None,
+                 params=None):
+        super(BERTClassifier, self).__init__()
+        self.bert = bert
+        self.dr_rate = dr_rate
+
+        self.classifier = torch.nn.Linear(hidden_size, num_classes)
+        if dr_rate:
+            self.dropout = torch.nn.Dropout(p=dr_rate)
+
+    def gen_attention_mask(self, token_ids, valid_length):
+        # 1 for real tokens up to each sentence's valid length, 0 for padding
+        attention_mask = torch.zeros_like(token_ids)
+        for i, v in enumerate(valid_length):
+            attention_mask[i][:v] = 1
+        return attention_mask.float()
+
+    def forward(self, token_ids, valid_length, segment_ids):
+        attention_mask = self.gen_attention_mask(token_ids, valid_length)
+
+        # Adjusted to BertModel's output structure:
+        # with return_dict=False, Hugging Face Transformers' BertModel returns
+        # (last_hidden_state, pooler_output, ...); pooler_output is the CLS token's
+        # final hidden state passed through an extra dense layer, and is used here.
+        _, pooler = self.bert(input_ids=token_ids, token_type_ids=segment_ids.long(), attention_mask=attention_mask.float().to(token_ids.device), return_dict=False)
+
+        if self.dr_rate:
+            out = self.dropout(pooler)
+        else:
+            out = pooler
+        return self.classifier(out)
+
+# --- 2. BERTDataset class definition (moved in from dataset.py) ---
+# This class converts raw data into the model's input format.
+class BERTDataset(Dataset):
+    def __init__(self, dataset, sent_idx, label_idx, bert_tokenizer, vocab, max_len, pad, pair):
+        transform = nlp.data.BERTSentenceTransform(
+            bert_tokenizer, max_seq_length=max_len, vocab=vocab, pad=pad, pair=pair
+        )
+        self.sentences = [transform([i[sent_idx]]) for i in dataset]
+        self.labels = [np.int32(i[label_idx]) for i in dataset]
+
+    def __getitem__(self, i):
+        return (self.sentences[i] + (self.labels[i],))
+
+    def __len__(self):
+        return len(self.labels)
+
+# --- 3. FastAPI app and global state setup ---
 app = FastAPI()
-device = torch.device("cpu")
+device = torch.device("cpu")  # Render's free tier mostly provides CPU only

-# ✅ Load category
-
-category
+# ✅ Load category (the file must be in the GitHub repository root)
+try:
+    with open("category.pkl", "rb") as f:
+        category = pickle.load(f)
+    print("category.pkl loaded successfully.")
+except FileNotFoundError:
+    print("Error: category.pkl not found. Check that it is in the project root.")
+    sys.exit(1)  # do not start the service if the file is missing

-# ✅ Load vocab
-
-vocab
+# ✅ Load vocab (the file must be in the GitHub repository root)
+try:
+    with open("vocab.pkl", "rb") as f:
+        vocab = pickle.load(f)
+    print("vocab.pkl loaded successfully.")
+except FileNotFoundError:
+    print("Error: vocab.pkl not found. Check that it is in the project root.")
+    sys.exit(1)  # do not start the service if the file is missing

-# ✅ Tokenizer
+# ✅ Load tokenizer (uses kobert_tokenizer)
+# Kept because this is the approach used in the Colab code.
 tokenizer = KoBERTTokenizer.from_pretrained('skt/kobert-base-v1')
+print("Tokenizer loaded successfully.")

 # ✅ Load model
+# Define the model architecture, then load the saved state_dict into it.
+# num_classes must match the size of the category dict.
 model = BERTClassifier(
     BertModel.from_pretrained('skt/kobert-base-v1'),
-    dr_rate=0.5,
+    dr_rate=0.5,  # change this to the dr_rate value used during training
     num_classes=len(category)
 )
-
-
-
+
+# Load the textClassifierModel.pt file.
+# The weights file must be available at runtime; the Dockerfile was expected to
+# fetch it from the Hugging Face Hub.
+try:
+    # If the Dockerfile downloaded the model at build time, a local path could be used here.
+    # The current Dockerfile, however, only installs
+    # git+https://github.com/SKTBrain/KOBERT#egg=kobert_tokenizer and does not download
+    # the model file, so the Hugging Face Hub download logic is added back here.
+    from huggingface_hub import hf_hub_download
+    HF_MODEL_REPO_ID = "hiddenFront/TextClassifier"  # the user's actual Hugging Face repository ID
+    HF_MODEL_FILENAME = "textClassifierModel.pt"
+    model_path = hf_hub_download(repo_id=HF_MODEL_REPO_ID, filename=HF_MODEL_FILENAME)
+    print(f"Model file downloaded successfully to '{model_path}'.")
+
+    # Load the model's state_dict.
+    loaded_state_dict = torch.load(model_path, map_location=device)
+
+    # Adjust state_dict keys if needed: checkpoints saved from a torch.nn.DataParallel
+    # wrapper prefix every key with 'module.', which must be stripped before loading.
+    new_state_dict = collections.OrderedDict()
+    for k, v in loaded_state_dict.items():
+        name = k
+        if name.startswith('module.'):
+            name = name[7:]
+        new_state_dict[name] = v
+
+    model.load_state_dict(new_state_dict)
+    model.to(device)  # move the model to the target device
+    model.eval()  # switch to inference mode
+    print("Model loaded successfully.")
+
+except Exception as e:
+    print(f"Error: failed while downloading or loading the model: {e}")
+    sys.exit(1)  # do not start the service if model loading fails
+

 # ✅ Parameters needed for dataset creation
 max_len = 64
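For reference, the tuple unpacking in BERTClassifier.forward above can be sanity-checked in isolation. A minimal sketch, assuming the same skt/kobert-base-v1 checkpoint and max_len = 64; with return_dict=False the model returns (last_hidden_state, pooler_output, ...), and only the pooled CLS vector feeds the classifier head:

import torch
from transformers import BertModel

bert = BertModel.from_pretrained('skt/kobert-base-v1')
ids = torch.zeros((1, 64), dtype=torch.long)   # dummy batch: one sentence of max_len 64
mask = torch.ones((1, 64))                     # attend to every position
seq, pooled = bert(input_ids=ids, attention_mask=mask, return_dict=False)
print(seq.shape)     # torch.Size([1, 64, 768]): per-token hidden states
print(pooled.shape)  # torch.Size([1, 768]): CLS representation fed to the classifier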
@@ -43,20 +142,26 @@ batch_size = 32
 def predict(predict_sentence):
     data = [predict_sentence, '0']
     dataset_another = [data]
+    # Setting num_workers to 0 is recommended in a deployment environment.
     another_test = BERTDataset(dataset_another, 0, 1, tokenizer, vocab, max_len, True, False)
-    test_dataLoader =
+    test_dataLoader = DataLoader(another_test, batch_size=batch_size, num_workers=0)

-    model.eval()
-    for batch_id, (token_ids, valid_length, segment_ids, label) in enumerate(test_dataLoader):
-        token_ids = token_ids.long().to(device)
-        segment_ids = segment_ids.long().to(device)
+    model.eval()  # put the model in evaluation mode before predicting

-
-
-
-
-
-
+    with torch.no_grad():  # disable gradient computation
+        for batch_id, (token_ids, valid_length, segment_ids, label) in enumerate(test_dataLoader):
+            token_ids = token_ids.long().to(device)
+            segment_ids = segment_ids.long().to(device)
+
+            out = model(token_ids, valid_length, segment_ids)
+
+            logits = out
+            logits = logits.detach().cpu().numpy()
+
+            predicted_category_index = np.argmax(logits)
+            predicted_category_name = list(category.keys())[predicted_category_index]
+
+    return predicted_category_name

 # ✅ Endpoint definitions
 class InputText(BaseModel):
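One subtlety in the loop above: np.argmax is called on the whole (1, num_classes) logits array, which NumPy flattens, so the scalar index is only correct because the batch holds a single sentence. It also assumes the insertion order of category's keys matches the training label indices. A small sketch with a hypothetical category dict:

import numpy as np

category = {"IT": 0, "economy": 1, "sports": 2}  # hypothetical contents of category.pkl
logits = np.array([[0.2, 1.7, -0.4]])            # one sentence, one row of class scores
idx = np.argmax(logits)                          # flattens to shape (3,) and returns 1
print(list(category.keys())[idx])                # -> "economy"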
@@ -70,3 +175,4 @@ def root():
 async def predict_route(item: InputText):
     result = predict(item.text)
     return {"text": item.text, "classification": result}
+
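Once the Space is up, the route can be exercised as below. This is a sketch: the /predict path and port 8000 are assumptions, since the route decorator sits outside the visible hunks.

import requests

resp = requests.post("http://localhost:8000/predict",
                     json={"text": "sample sentence to classify"})
print(resp.json())  # e.g. {"text": "...", "classification": "economy"}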
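Since the comments in the first hunk note that the current Dockerfile installs kobert_tokenizer but never fetches the weights, an alternative to downloading at startup is a build-time step. A sketch, assuming a hypothetical download_model.py invoked from the Dockerfile with RUN python download_model.py; hf_hub_download caches the file, so the identical call in app.py would then hit the warm cache instead of the network:

# download_model.py (hypothetical): pre-populate the Hugging Face cache at image build time
from huggingface_hub import hf_hub_download

path = hf_hub_download(repo_id="hiddenFront/TextClassifier",
                       filename="textClassifierModel.pt")
print(f"Model cached at {path}")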