Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -5,17 +5,22 @@ import pickle
|
|
| 5 |
import gluonnlp as nlp
|
| 6 |
import numpy as np
|
| 7 |
import os
|
| 8 |
-
import sys
|
|
|
|
|
|
|
| 9 |
|
| 10 |
# transformers์ AutoTokenizer ๋ฐ BertModel ์ํฌํธ
|
| 11 |
-
from transformers import AutoTokenizer, BertModel
|
| 12 |
from torch.utils.data import Dataset, DataLoader
|
| 13 |
-
|
| 14 |
-
|
| 15 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 16 |
|
| 17 |
# --- 1. BERTClassifier ๋ชจ๋ธ ํด๋์ค ์ ์ ---
|
| 18 |
-
# ์ด ํด๋์ค๋ ๋ชจ๋ธ์ ์ํคํ
์ฒ๋ฅผ ์ ์ํฉ๋๋ค.
|
| 19 |
class BERTClassifier(torch.nn.Module):
|
| 20 |
def __init__(self,
|
| 21 |
bert,
|
|
@@ -49,11 +54,8 @@ class BERTClassifier(torch.nn.Module):
|
|
| 49 |
return self.classifier(out)
|
| 50 |
|
| 51 |
# --- 2. BERTDataset ํด๋์ค ์ ์ ---
|
| 52 |
-
# ์ด ํด๋์ค๋ ๋ฐ์ดํฐ๋ฅผ ๋ชจ๋ธ ์
๋ ฅ ํ์์ผ๋ก ๋ณํํฉ๋๋ค.
|
| 53 |
class BERTDataset(Dataset):
|
| 54 |
def __init__(self, dataset, sent_idx, label_idx, bert_tokenizer, vocab, max_len, pad, pair):
|
| 55 |
-
# nlp.data.BERTSentenceTransform์ ํ ํฌ๋์ด์ ํจ์๋ฅผ ๋ฐ์ต๋๋ค.
|
| 56 |
-
# AutoTokenizer์ tokenize ๋ฉ์๋๋ฅผ ์ง์ ์ ๋ฌํฉ๋๋ค.
|
| 57 |
transform = nlp.data.BERTSentenceTransform(
|
| 58 |
bert_tokenizer, max_seq_length=max_len, vocab=vocab, pad=pad, pair=pair
|
| 59 |
)
|
|
@@ -70,55 +72,45 @@ class BERTDataset(Dataset):
|
|
| 70 |
app = FastAPI()
|
| 71 |
device = torch.device("cpu") # Hugging Face Spaces์ ๋ฌด๋ฃ ํฐ์ด๋ ์ฃผ๋ก CPU๋ฅผ ์ฌ์ฉํฉ๋๋ค.
|
| 72 |
|
| 73 |
-
# โ
category ๋ก๋
|
| 74 |
try:
|
| 75 |
with open("category.pkl", "rb") as f:
|
| 76 |
category = pickle.load(f)
|
| 77 |
-
|
| 78 |
except FileNotFoundError:
|
| 79 |
-
|
| 80 |
-
sys.exit(1)
|
| 81 |
|
| 82 |
-
# โ
vocab ๋ก๋
|
| 83 |
try:
|
| 84 |
with open("vocab.pkl", "rb") as f:
|
| 85 |
vocab = pickle.load(f)
|
| 86 |
-
|
| 87 |
except FileNotFoundError:
|
| 88 |
-
|
| 89 |
-
sys.exit(1)
|
| 90 |
|
| 91 |
-
# โ
ํ ํฌ๋์ด์ ๋ก๋
|
| 92 |
tokenizer = AutoTokenizer.from_pretrained('skt/kobert-base-v1')
|
| 93 |
-
|
| 94 |
|
| 95 |
# โ
๋ชจ๋ธ ๋ก๋ (Hugging Face Hub์์ ๋ค์ด๋ก๋)
|
| 96 |
try:
|
| 97 |
-
HF_MODEL_REPO_ID = "hiddenFront/TextClassifier"
|
| 98 |
-
HF_MODEL_FILENAME = "textClassifierModel.pt"
|
| 99 |
|
| 100 |
model_path = hf_hub_download(repo_id=HF_MODEL_REPO_ID, filename=HF_MODEL_FILENAME)
|
| 101 |
-
|
| 102 |
|
| 103 |
-
# --- ์์ ๋ ํต์ฌ ๋ถ๋ถ ---
|
| 104 |
-
# 1. BertModel.from_pretrained๋ฅผ ์ฌ์ฉํ์ฌ ๊ธฐ๋ณธ BERT ๋ชจ๋ธ์ ๋ก๋ํฉ๋๋ค.
|
| 105 |
-
# ์ด๋ ๊ฒ ํ๋ฉด ๋ชจ๋ธ์ ์ํคํ
์ฒ์ ์ฌ์ ํ์ต๋ ๊ฐ์ค์น๊ฐ ๋ก๋๋ฉ๋๋ค.
|
| 106 |
bert_base_model = BertModel.from_pretrained('skt/kobert-base-v1')
|
| 107 |
-
|
| 108 |
-
# 2. BERTClassifier ์ธ์คํด์ค๋ฅผ ์์ฑํฉ๋๋ค.
|
| 109 |
-
# ์ฌ๊ธฐ์ num_classes๋ category ๋์
๋๋ฆฌ์ ํฌ๊ธฐ์ ์ผ์นํด์ผ ํฉ๋๋ค.
|
| 110 |
model = BERTClassifier(
|
| 111 |
bert_base_model,
|
| 112 |
dr_rate=0.5, # ํ์ต ์ ์ฌ์ฉ๋ dr_rate ๊ฐ์ผ๋ก ๋ณ๊ฒฝํ์ธ์.
|
| 113 |
num_classes=len(category)
|
| 114 |
)
|
| 115 |
|
| 116 |
-
# 3. ๋ค์ด๋ก๋๋ ํ์ผ์์ state_dict๋ฅผ ๋ก๋ํฉ๋๋ค.
|
| 117 |
-
# ์ด ํ์ผ์ ์ฌ์ฉ์๋์ ๊ฒฝ๋ํ๋ ๋ชจ๋ธ์ ๊ฐ์ค์น๋ง ํฌํจํ๊ณ ์์ต๋๋ค.
|
| 118 |
loaded_state_dict = torch.load(model_path, map_location=device)
|
| 119 |
|
| 120 |
-
# 4. ๋ก๋๋ state_dict์ ํค๋ฅผ ์กฐ์ ํ๊ณ ๋ชจ๋ธ์ ์ ์ฉํฉ๋๋ค.
|
| 121 |
-
# 'module.' ์ ๋์ฌ๊ฐ ๋ถ์ด์๋ ๊ฒฝ์ฐ ์ ๊ฑฐํ๋ ๋ก์ง์ ํฌํจํฉ๋๋ค.
|
| 122 |
new_state_dict = collections.OrderedDict()
|
| 123 |
for k, v in loaded_state_dict.items():
|
| 124 |
name = k
|
|
@@ -126,19 +118,14 @@ try:
|
|
| 126 |
name = name[7:]
|
| 127 |
new_state_dict[name] = v
|
| 128 |
|
| 129 |
-
# strict=False๋ฅผ ์ฌ์ฉํ์ฌ Missing key(s) ์ค๋ฅ๋ฅผ ๋ฐฉ์งํฉ๋๋ค.
|
| 130 |
-
# ์ด๋ new_state_dict์ ์๋ ํค๋ ๋ชจ๋ธ์์ ๊ธฐ์กด ๊ฐ(from_pretrained๋ก ๋ก๋๋)์ ์ ์งํ๊ณ ,
|
| 131 |
-
# ๋ชจ๋ธ์ ์๋ ํค๋ ๋ฌด์ํ๋๋ก ํฉ๋๋ค.
|
| 132 |
model.load_state_dict(new_state_dict, strict=False)
|
| 133 |
-
|
| 134 |
-
|
| 135 |
-
|
| 136 |
-
model.eval() # ์ถ๋ก ๋ชจ๋๋ก ์ค์
|
| 137 |
-
print("๋ชจ๋ธ ๋ก๋ ์ฑ๊ณต.")
|
| 138 |
|
| 139 |
except Exception as e:
|
| 140 |
-
|
| 141 |
-
sys.exit(1)
|
| 142 |
|
| 143 |
|
| 144 |
# โ
๋ฐ์ดํฐ์
์์ฑ์ ํ์ํ ํ๋ผ๋ฏธํฐ
|
|
@@ -149,26 +136,32 @@ batch_size = 32
|
|
| 149 |
def predict(predict_sentence):
|
| 150 |
data = [predict_sentence, '0']
|
| 151 |
dataset_another = [data]
|
| 152 |
-
# num_workers๋ ๋ฐฐํฌ ํ๊ฒฝ์์ 0์ผ๋ก ์ค์ ๊ถ์ฅ
|
| 153 |
-
# tokenizer.tokenize๋ฅผ BERTDataset์ ์ ๋ฌํฉ๋๋ค.
|
| 154 |
another_test = BERTDataset(dataset_another, 0, 1, tokenizer.tokenize, vocab, max_len, True, False)
|
| 155 |
test_dataLoader = DataLoader(another_test, batch_size=batch_size, num_workers=0)
|
| 156 |
|
| 157 |
-
model.eval()
|
| 158 |
|
| 159 |
-
with torch.no_grad():
|
| 160 |
for batch_id, (token_ids, valid_length, segment_ids, label) in enumerate(test_dataLoader):
|
| 161 |
token_ids = token_ids.long().to(device)
|
| 162 |
segment_ids = segment_ids.long().to(device)
|
| 163 |
|
| 164 |
out = model(token_ids, valid_length, segment_ids)
|
| 165 |
|
| 166 |
-
logits = out
|
| 167 |
-
|
| 168 |
-
|
| 169 |
-
predicted_category_index =
|
| 170 |
-
predicted_category_name = list(category.keys())[predicted_category_index]
|
| 171 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 172 |
return predicted_category_name
|
| 173 |
|
| 174 |
# โ
์๋ํฌ์ธํธ ์ ์
|
|
|
|
| 5 |
import gluonnlp as nlp
|
| 6 |
import numpy as np
|
| 7 |
import os
|
| 8 |
+
import sys
|
| 9 |
+
import collections
|
| 10 |
+
import logging # ๋ก๊น
๋ชจ๋ ์ํฌํธ
|
| 11 |
|
| 12 |
# transformers์ AutoTokenizer ๋ฐ BertModel ์ํฌํธ
|
| 13 |
+
from transformers import AutoTokenizer, BertModel
|
| 14 |
from torch.utils.data import Dataset, DataLoader
|
| 15 |
+
from huggingface_hub import hf_hub_download
|
| 16 |
+
|
| 17 |
+
# --- ๋ก๊น
์ค์ ---
|
| 18 |
+
# INFO ๋ ๋ฒจ ์ด์์ ๋ก๊ทธ๋ฅผ ์ถ๋ ฅํ๋๋ก ์ค์ ํฉ๋๋ค.
|
| 19 |
+
# ์ค์ ๋ฐฐํฌ ํ๊ฒฝ์์๋ ๋ก๊ทธ ๋ ๋ฒจ์ WARNING์ด๋ ERROR๋ก ๋์ฌ ๋ถํ์ํ ๋ก๊ทธ๋ฅผ ์ค์ผ ์ ์์ต๋๋ค.
|
| 20 |
+
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|
| 21 |
+
logger = logging.getLogger(__name__)
|
| 22 |
|
| 23 |
# --- 1. BERTClassifier ๋ชจ๋ธ ํด๋์ค ์ ์ ---
|
|
|
|
| 24 |
class BERTClassifier(torch.nn.Module):
|
| 25 |
def __init__(self,
|
| 26 |
bert,
|
|
|
|
| 54 |
return self.classifier(out)
|
| 55 |
|
| 56 |
# --- 2. BERTDataset ํด๋์ค ์ ์ ---
|
|
|
|
| 57 |
class BERTDataset(Dataset):
|
| 58 |
def __init__(self, dataset, sent_idx, label_idx, bert_tokenizer, vocab, max_len, pad, pair):
|
|
|
|
|
|
|
| 59 |
transform = nlp.data.BERTSentenceTransform(
|
| 60 |
bert_tokenizer, max_seq_length=max_len, vocab=vocab, pad=pad, pair=pair
|
| 61 |
)
|
|
|
|
| 72 |
app = FastAPI()
|
| 73 |
device = torch.device("cpu") # Hugging Face Spaces์ ๋ฌด๋ฃ ํฐ์ด๋ ์ฃผ๋ก CPU๋ฅผ ์ฌ์ฉํฉ๋๋ค.
|
| 74 |
|
| 75 |
+
# โ
category ๋ก๋
|
| 76 |
try:
|
| 77 |
with open("category.pkl", "rb") as f:
|
| 78 |
category = pickle.load(f)
|
| 79 |
+
logger.info("category.pkl ๋ก๋ ์ฑ๊ณต.")
|
| 80 |
except FileNotFoundError:
|
| 81 |
+
logger.error("Error: category.pkl ํ์ผ์ ์ฐพ์ ์ ์์ต๋๋ค. ํ๋ก์ ํธ ๋ฃจํธ์ ์๋์ง ํ์ธํ์ธ์.")
|
| 82 |
+
sys.exit(1)
|
| 83 |
|
| 84 |
+
# โ
vocab ๋ก๋
|
| 85 |
try:
|
| 86 |
with open("vocab.pkl", "rb") as f:
|
| 87 |
vocab = pickle.load(f)
|
| 88 |
+
logger.info("vocab.pkl ๋ก๋ ์ฑ๊ณต.")
|
| 89 |
except FileNotFoundError:
|
| 90 |
+
logger.error("Error: vocab.pkl ํ์ผ์ ์ฐพ์ ์ ์์ต๋๋ค. ํ๋ก์ ํธ ๋ฃจํธ์ ์๋์ง ํ์ธํ์ธ์.")
|
| 91 |
+
sys.exit(1)
|
| 92 |
|
| 93 |
+
# โ
ํ ํฌ๋์ด์ ๋ก๋
|
| 94 |
tokenizer = AutoTokenizer.from_pretrained('skt/kobert-base-v1')
|
| 95 |
+
logger.info("ํ ํฌ๋์ด์ ๋ก๋ ์ฑ๊ณต.")
|
| 96 |
|
| 97 |
# โ
๋ชจ๋ธ ๋ก๋ (Hugging Face Hub์์ ๋ค์ด๋ก๋)
|
| 98 |
try:
|
| 99 |
+
HF_MODEL_REPO_ID = "hiddenFront/TextClassifier"
|
| 100 |
+
HF_MODEL_FILENAME = "textClassifierModel.pt"
|
| 101 |
|
| 102 |
model_path = hf_hub_download(repo_id=HF_MODEL_REPO_ID, filename=HF_MODEL_FILENAME)
|
| 103 |
+
logger.info(f"๋ชจ๋ธ ํ์ผ์ด '{model_path}'์ ์ฑ๊ณต์ ์ผ๋ก ๋ค์ด๋ก๋๋์์ต๋๋ค.")
|
| 104 |
|
|
|
|
|
|
|
|
|
|
| 105 |
bert_base_model = BertModel.from_pretrained('skt/kobert-base-v1')
|
|
|
|
|
|
|
|
|
|
| 106 |
model = BERTClassifier(
|
| 107 |
bert_base_model,
|
| 108 |
dr_rate=0.5, # ํ์ต ์ ์ฌ์ฉ๋ dr_rate ๊ฐ์ผ๋ก ๋ณ๊ฒฝํ์ธ์.
|
| 109 |
num_classes=len(category)
|
| 110 |
)
|
| 111 |
|
|
|
|
|
|
|
| 112 |
loaded_state_dict = torch.load(model_path, map_location=device)
|
| 113 |
|
|
|
|
|
|
|
| 114 |
new_state_dict = collections.OrderedDict()
|
| 115 |
for k, v in loaded_state_dict.items():
|
| 116 |
name = k
|
|
|
|
| 118 |
name = name[7:]
|
| 119 |
new_state_dict[name] = v
|
| 120 |
|
|
|
|
|
|
|
|
|
|
| 121 |
model.load_state_dict(new_state_dict, strict=False)
|
| 122 |
+
model.to(device)
|
| 123 |
+
model.eval()
|
| 124 |
+
logger.info("๋ชจ๋ธ ๋ก๋ ์ฑ๊ณต.")
|
|
|
|
|
|
|
| 125 |
|
| 126 |
except Exception as e:
|
| 127 |
+
logger.error(f"Error: ๋ชจ๋ธ ๋ค์ด๋ก๋ ๋๋ ๋ก๋ ์ค ์ค๋ฅ ๋ฐ์: {e}")
|
| 128 |
+
sys.exit(1)
|
| 129 |
|
| 130 |
|
| 131 |
# โ
๋ฐ์ดํฐ์
์์ฑ์ ํ์ํ ํ๋ผ๋ฏธํฐ
|
|
|
|
| 136 |
def predict(predict_sentence):
|
| 137 |
data = [predict_sentence, '0']
|
| 138 |
dataset_another = [data]
|
|
|
|
|
|
|
| 139 |
another_test = BERTDataset(dataset_another, 0, 1, tokenizer.tokenize, vocab, max_len, True, False)
|
| 140 |
test_dataLoader = DataLoader(another_test, batch_size=batch_size, num_workers=0)
|
| 141 |
|
| 142 |
+
model.eval()
|
| 143 |
|
| 144 |
+
with torch.no_grad():
|
| 145 |
for batch_id, (token_ids, valid_length, segment_ids, label) in enumerate(test_dataLoader):
|
| 146 |
token_ids = token_ids.long().to(device)
|
| 147 |
segment_ids = segment_ids.long().to(device)
|
| 148 |
|
| 149 |
out = model(token_ids, valid_length, segment_ids)
|
| 150 |
|
| 151 |
+
logits = out # ๋ชจ๋ธ์ ์ง์ ์ถ๋ ฅ์ ๋ก์ง์
๋๋ค.
|
| 152 |
+
probs = torch.nn.functional.softmax(logits, dim=1) # ํ๋ฅ ๊ณ์ฐ
|
| 153 |
+
|
| 154 |
+
predicted_category_index = torch.argmax(probs, dim=1).item() # ์์ธก ์ธ๋ฑ์ค
|
| 155 |
+
predicted_category_name = list(category.keys())[predicted_category_index] # ์์ธก ์นดํ
๊ณ ๋ฆฌ ์ด๋ฆ
|
| 156 |
|
| 157 |
+
# --- ์์ธก ์์ธ ๋ก๊น
---
|
| 158 |
+
logger.info(f"Input Text: '{predict_sentence}'")
|
| 159 |
+
logger.info(f"Raw Logits: {logits.tolist()}")
|
| 160 |
+
logger.info(f"Probabilities: {probs.tolist()}")
|
| 161 |
+
logger.info(f"Predicted Index: {predicted_category_index}")
|
| 162 |
+
logger.info(f"Predicted Label: '{predicted_category_name}'")
|
| 163 |
+
# --- ์์ธก ์์ธ ๋ก๊น
๋ ---
|
| 164 |
+
|
| 165 |
return predicted_category_name
|
| 166 |
|
| 167 |
# โ
์๋ํฌ์ธํธ ์ ์
|