Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -7,15 +7,66 @@ import numpy as np
|
|
| 7 |
import os
|
| 8 |
import sys # ์ค๋ฅ ์ ์๋น์ค ์ข
๋ฃ๋ฅผ ์ํด sys ๋ชจ๋ ์ํฌํธ
|
| 9 |
|
| 10 |
-
# transformers์ AutoTokenizer
|
| 11 |
-
from transformers import AutoTokenizer # BertModel
|
| 12 |
from torch.utils.data import Dataset, DataLoader
|
| 13 |
import logging # ๋ก๊น
๋ชจ๋ ์ํฌํธ ์ ์ง
|
| 14 |
from huggingface_hub import hf_hub_download # hf_hub_download ์ํฌํธ ์ ์ง
|
| 15 |
-
|
| 16 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 17 |
|
| 18 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 19 |
app = FastAPI()
|
| 20 |
device = torch.device("cpu") # Hugging Face Spaces์ ๋ฌด๋ฃ ํฐ์ด๋ ์ฃผ๋ก CPU๋ฅผ ์ฌ์ฉํฉ๋๋ค.
|
| 21 |
|
|
@@ -42,7 +93,6 @@ tokenizer = AutoTokenizer.from_pretrained('skt/kobert-base-v1')
|
|
| 42 |
print("ํ ํฌ๋์ด์ ๋ก๋ ์ฑ๊ณต.")
|
| 43 |
|
| 44 |
# โ
๋ชจ๋ธ ๋ก๋ (Hugging Face Hub์์ ๋ค์ด๋ก๋)
|
| 45 |
-
# textClassifierModel.pt ํ์ผ์ ์ด๋ฏธ ๊ฒฝ๋ํ๋ '์์ ํ ๋ชจ๋ธ ๊ฐ์ฒด'๋ผ๊ณ ๊ฐ์ ํ๊ณ ์ง์ ๋ก๋ํฉ๋๋ค.
|
| 46 |
try:
|
| 47 |
HF_MODEL_REPO_ID = "hiddenFront/TextClassifier" # ์ฌ์ฉ์๋์ ์ค์ Hugging Face ์ ์ฅ์ ID
|
| 48 |
HF_MODEL_FILENAME = "textClassifierModel.pt" # Hugging Face Hub์ ์
๋ก๋ํ ํ์ผ ์ด๋ฆ๊ณผ ์ผ์นํด์ผ ํฉ๋๋ค.
|
|
@@ -51,11 +101,35 @@ try:
|
|
| 51 |
print(f"๋ชจ๋ธ ํ์ผ์ด '{model_path}'์ ์ฑ๊ณต์ ์ผ๋ก ๋ค์ด๋ก๋๋์์ต๋๋ค.")
|
| 52 |
|
| 53 |
# --- ์์ ๋ ํต์ฌ ๋ถ๋ถ ---
|
| 54 |
-
#
|
| 55 |
-
#
|
| 56 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 57 |
# --- ์์ ๋ ํต์ฌ ๋ถ๋ถ ๋ ---
|
| 58 |
|
|
|
|
| 59 |
model.eval() # ์ถ๋ก ๋ชจ๋๋ก ์ค์
|
| 60 |
print("๋ชจ๋ธ ๋ก๋ ์ฑ๊ณต.")
|
| 61 |
|
|
@@ -64,25 +138,6 @@ except Exception as e:
|
|
| 64 |
sys.exit(1) # ๋ชจ๋ธ ๋ก๋ ์คํจ ์ ์๋น์ค ์์ํ์ง ์์
|
| 65 |
|
| 66 |
|
| 67 |
-
# --- 2. BERTDataset ํด๋์ค ์ ์ (dataset.py์์ ์ฎ๊ฒจ์ด) ---
|
| 68 |
-
# ์ด ํด๋์ค๋ ๋ฐ์ดํฐ๋ฅผ ๋ชจ๋ธ ์
๋ ฅ ํ์์ผ๋ก ๋ณํํฉ๋๋ค.
|
| 69 |
-
class BERTDataset(Dataset):
|
| 70 |
-
def __init__(self, dataset, sent_idx, label_idx, bert_tokenizer, vocab, max_len, pad, pair):
|
| 71 |
-
# nlp.data.BERTSentenceTransform์ ํ ํฌ๋์ด์ ํจ์๋ฅผ ๋ฐ์ต๋๋ค.
|
| 72 |
-
# AutoTokenizer์ tokenize ๋ฉ์๋๋ฅผ ์ง์ ์ ๋ฌํฉ๋๋ค.
|
| 73 |
-
transform = nlp.data.BERTSentenceTransform(
|
| 74 |
-
bert_tokenizer, max_seq_length=max_len, vocab=vocab, pad=pad, pair=pair
|
| 75 |
-
)
|
| 76 |
-
self.sentences = [transform([i[sent_idx]]) for i in dataset]
|
| 77 |
-
self.labels = [np.int32(i[label_idx]) for i in dataset]
|
| 78 |
-
|
| 79 |
-
def __getitem__(self, i):
|
| 80 |
-
return (self.sentences[i] + (self.labels[i],))
|
| 81 |
-
|
| 82 |
-
def __len__(self):
|
| 83 |
-
return len(self.labels)
|
| 84 |
-
|
| 85 |
-
|
| 86 |
# โ
๋ฐ์ดํฐ์
์์ฑ์ ํ์ํ ํ๋ผ๋ฏธํฐ
|
| 87 |
max_len = 64
|
| 88 |
batch_size = 32
|
|
@@ -125,4 +180,3 @@ def root():
|
|
| 125 |
async def predict_route(item: InputText):
|
| 126 |
result = predict(item.text)
|
| 127 |
return {"text": item.text, "classification": result}
|
| 128 |
-
|
|
|
|
| 7 |
import os
|
| 8 |
import sys # ์ค๋ฅ ์ ์๋น์ค ์ข
๋ฃ๋ฅผ ์ํด sys ๋ชจ๋ ์ํฌํธ
|
| 9 |
|
| 10 |
+
# transformers์ AutoTokenizer ๋ฐ BertModel ์ํฌํธ
|
| 11 |
+
from transformers import AutoTokenizer, BertModel # BertModel ์ํฌํธ ์ถ๊ฐ
|
| 12 |
from torch.utils.data import Dataset, DataLoader
|
| 13 |
import logging # ๋ก๊น
๋ชจ๋ ์ํฌํธ ์ ์ง
|
| 14 |
from huggingface_hub import hf_hub_download # hf_hub_download ์ํฌํธ ์ ์ง
|
| 15 |
+
import collections # collections ๋ชจ๋ ์ํฌํธ ์ ์ง
|
| 16 |
+
|
| 17 |
+
# --- 1. BERTClassifier ๋ชจ๋ธ ํด๋์ค ์ ์ ---
|
| 18 |
+
# ์ด ํด๋์ค๋ ๋ชจ๋ธ์ ์ํคํ
์ฒ๋ฅผ ์ ์ํฉ๋๋ค.
|
| 19 |
+
class BERTClassifier(torch.nn.Module):
|
| 20 |
+
def __init__(self,
|
| 21 |
+
bert,
|
| 22 |
+
hidden_size = 768,
|
| 23 |
+
num_classes=5, # ๋ถ๋ฅํ ํด๋์ค ์ (category ๋์
๋๋ฆฌ ํฌ๊ธฐ์ ์ผ์น)
|
| 24 |
+
dr_rate=None,
|
| 25 |
+
params=None):
|
| 26 |
+
super(BERTClassifier, self).__init__()
|
| 27 |
+
self.bert = bert
|
| 28 |
+
self.dr_rate = dr_rate
|
| 29 |
+
|
| 30 |
+
self.classifier = torch.nn.Linear(hidden_size , num_classes)
|
| 31 |
+
if dr_rate:
|
| 32 |
+
self.dropout = torch.nn.Dropout(p=dr_rate)
|
| 33 |
+
|
| 34 |
+
def gen_attention_mask(self, token_ids, valid_length):
|
| 35 |
+
attention_mask = torch.zeros_like(token_ids)
|
| 36 |
+
for i, v in enumerate(valid_length):
|
| 37 |
+
attention_mask[i][:v] = 1
|
| 38 |
+
return attention_mask.float()
|
| 39 |
+
|
| 40 |
+
def forward(self, token_ids, valid_length, segment_ids):
|
| 41 |
+
attention_mask = self.gen_attention_mask(token_ids, valid_length)
|
| 42 |
+
|
| 43 |
+
_, pooler = self.bert(input_ids=token_ids, token_type_ids=segment_ids.long(), attention_mask=attention_mask.float().to(token_ids.device), return_dict=False)
|
| 44 |
+
|
| 45 |
+
if self.dr_rate:
|
| 46 |
+
out = self.dropout(pooler)
|
| 47 |
+
else:
|
| 48 |
+
out = pooler
|
| 49 |
+
return self.classifier(out)
|
| 50 |
+
|
| 51 |
+
# --- 2. BERTDataset ํด๋์ค ์ ์ ---
|
| 52 |
+
# ์ด ํด๋์ค๋ ๋ฐ์ดํฐ๋ฅผ ๋ชจ๋ธ ์
๋ ฅ ํ์์ผ๋ก ๋ณํํฉ๋๋ค.
|
| 53 |
+
class BERTDataset(Dataset):
|
| 54 |
+
def __init__(self, dataset, sent_idx, label_idx, bert_tokenizer, vocab, max_len, pad, pair):
|
| 55 |
+
# nlp.data.BERTSentenceTransform์ ํ ํฌ๋์ด์ ํจ์๋ฅผ ๋ฐ์ต๋๋ค.
|
| 56 |
+
# AutoTokenizer์ tokenize ๋ฉ์๋๋ฅผ ์ง์ ์ ๋ฌํฉ๋๋ค.
|
| 57 |
+
transform = nlp.data.BERTSentenceTransform(
|
| 58 |
+
bert_tokenizer, max_seq_length=max_len, vocab=vocab, pad=pad, pair=pair
|
| 59 |
+
)
|
| 60 |
+
self.sentences = [transform([i[sent_idx]]) for i in dataset]
|
| 61 |
+
self.labels = [np.int32(i[label_idx]) for i in dataset]
|
| 62 |
|
| 63 |
+
def __getitem__(self, i):
|
| 64 |
+
return (self.sentences[i] + (self.labels[i],))
|
| 65 |
+
|
| 66 |
+
def __len__(self):
|
| 67 |
+
return len(self.labels)
|
| 68 |
+
|
| 69 |
+
# --- 3. FastAPI ์ฑ ๋ฐ ์ ์ญ ๋ณ์ ์ค์ ---
|
| 70 |
app = FastAPI()
|
| 71 |
device = torch.device("cpu") # Hugging Face Spaces์ ๋ฌด๋ฃ ํฐ์ด๋ ์ฃผ๋ก CPU๋ฅผ ์ฌ์ฉํฉ๋๋ค.
|
| 72 |
|
|
|
|
| 93 |
print("ํ ํฌ๋์ด์ ๋ก๋ ์ฑ๊ณต.")
|
| 94 |
|
| 95 |
# โ
๋ชจ๋ธ ๋ก๋ (Hugging Face Hub์์ ๋ค์ด๋ก๋)
|
|
|
|
| 96 |
try:
|
| 97 |
HF_MODEL_REPO_ID = "hiddenFront/TextClassifier" # ์ฌ์ฉ์๋์ ์ค์ Hugging Face ์ ์ฅ์ ID
|
| 98 |
HF_MODEL_FILENAME = "textClassifierModel.pt" # Hugging Face Hub์ ์
๋ก๋ํ ํ์ผ ์ด๋ฆ๊ณผ ์ผ์นํด์ผ ํฉ๋๋ค.
|
|
|
|
| 101 |
print(f"๋ชจ๋ธ ํ์ผ์ด '{model_path}'์ ์ฑ๊ณต์ ์ผ๋ก ๋ค์ด๋ก๋๋์์ต๋๋ค.")
|
| 102 |
|
| 103 |
# --- ์์ ๋ ํต์ฌ ๋ถ๋ถ ---
|
| 104 |
+
# 1. BertModel.from_pretrained๋ฅผ ์ฌ์ฉํ์ฌ ๊ธฐ๋ณธ BERT ๋ชจ๋ธ์ ๋ก๋ํฉ๋๋ค.
|
| 105 |
+
# ์ด๋ ๊ฒ ํ๋ฉด ๋ชจ๋ธ์ ์ํคํ
์ฒ์ ์ฌ์ ํ์ต๋ ๊ฐ์ค์น๊ฐ ๋ก๋๋ฉ๋๋ค.
|
| 106 |
+
bert_base_model = BertModel.from_pretrained('skt/kobert-base-v1')
|
| 107 |
+
|
| 108 |
+
# 2. BERTClassifier ์ธ์คํด์ค๋ฅผ ์์ฑํฉ๋๋ค.
|
| 109 |
+
# ์ฌ๊ธฐ์ num_classes๋ category ๋์
๋๋ฆฌ์ ํฌ๊ธฐ์ ์ผ์นํด์ผ ํฉ๋๋ค.
|
| 110 |
+
model = BERTClassifier(
|
| 111 |
+
bert_base_model,
|
| 112 |
+
dr_rate=0.5, # ํ์ต ์ ์ฌ์ฉ๋ dr_rate ๊ฐ์ผ๋ก ๋ณ๊ฒฝํ์ธ์.
|
| 113 |
+
num_classes=len(category)
|
| 114 |
+
)
|
| 115 |
+
|
| 116 |
+
# 3. ๋ค์ด๋ก๋๋ ํ์ผ์์ state_dict๋ฅผ ๋ก๋ํฉ๋๋ค.
|
| 117 |
+
# ์ด ํ์ผ์ ์ฌ์ฉ์๋์ ๊ฒฝ๋ํ๋ ๋ชจ๋ธ์ ๊ฐ์ค์น๋ง ํฌํจํ๊ณ ์์ต๋๋ค.
|
| 118 |
+
loaded_state_dict = torch.load(model_path, map_location=device)
|
| 119 |
+
|
| 120 |
+
# 4. ๋ก๋๋ state_dict์ ํค๋ฅผ ์กฐ์ ํ๊ณ ๋ชจ๋ธ์ ์ ์ฉํฉ๋๋ค.
|
| 121 |
+
# 'module.' ์ ๋์ฌ๊ฐ ๏ฟฝ๏ฟฝ๏ฟฝ์ด์๋ ๊ฒฝ์ฐ ์ ๊ฑฐํ๋ ๋ก์ง์ ํฌํจํฉ๋๋ค.
|
| 122 |
+
new_state_dict = collections.OrderedDict()
|
| 123 |
+
for k, v in loaded_state_dict.items():
|
| 124 |
+
name = k
|
| 125 |
+
if name.startswith('module.'):
|
| 126 |
+
name = name[7:]
|
| 127 |
+
new_state_dict[name] = v
|
| 128 |
+
|
| 129 |
+
model.load_state_dict(new_state_dict)
|
| 130 |
# --- ์์ ๋ ํต์ฌ ๋ถ๋ถ ๋ ---
|
| 131 |
|
| 132 |
+
model.to(device) # ๋ชจ๋ธ์ ๋๋ฐ์ด์ค๋ก ์ด๋
|
| 133 |
model.eval() # ์ถ๋ก ๋ชจ๋๋ก ์ค์
|
| 134 |
print("๋ชจ๋ธ ๋ก๋ ์ฑ๊ณต.")
|
| 135 |
|
|
|
|
| 138 |
sys.exit(1) # ๋ชจ๋ธ ๋ก๋ ์คํจ ์ ์๋น์ค ์์ํ์ง ์์
|
| 139 |
|
| 140 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 141 |
# โ
๋ฐ์ดํฐ์
์์ฑ์ ํ์ํ ํ๋ผ๋ฏธํฐ
|
| 142 |
max_len = 64
|
| 143 |
batch_size = 32
|
|
|
|
| 180 |
async def predict_route(item: InputText):
|
| 181 |
result = predict(item.text)
|
| 182 |
return {"text": item.text, "classification": result}
|
|
|