Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -5,13 +5,16 @@ import pickle
|
|
| 5 |
import gluonnlp as nlp
|
| 6 |
import numpy as np
|
| 7 |
import os
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
|
|
|
|
|
|
|
| 11 |
import logging # ๋ก๊น
๋ชจ๋ ์ํฌํธ ์ ์ง
|
|
|
|
|
|
|
| 12 |
|
| 13 |
# --- 1. BERTClassifier ๋ชจ๋ธ ํด๋์ค ์ ์ (model.py์์ ์ฎ๊ฒจ์ด) ---
|
| 14 |
-
# ์ด ํด๋์ค๋ ๋ชจ๋ธ์ ์ํคํ
์ฒ๋ฅผ ์ ์ํฉ๋๋ค.
|
| 15 |
class BERTClassifier(torch.nn.Module):
|
| 16 |
def __init__(self,
|
| 17 |
bert,
|
|
@@ -36,9 +39,6 @@ class BERTClassifier(torch.nn.Module):
|
|
| 36 |
def forward(self, token_ids, valid_length, segment_ids):
|
| 37 |
attention_mask = self.gen_attention_mask(token_ids, valid_length)
|
| 38 |
|
| 39 |
-
# BertModel์ ์ถ๋ ฅ ๊ตฌ์กฐ์ ๋ฐ๋ผ ์์
|
| 40 |
-
# Hugging Face Transformers์ BertModel์ (last_hidden_state, pooler_output, ...) ๋ฐํ
|
| 41 |
-
# pooler_output (CLS ํ ํฐ์ ์ต์ข
์๋ ์ํ๋ฅผ ํต๊ณผํ ๊ฒฐ๊ณผ) ์ฌ์ฉ
|
| 42 |
_, pooler = self.bert(input_ids=token_ids, token_type_ids=segment_ids.long(), attention_mask=attention_mask.float().to(token_ids.device), return_dict=False)
|
| 43 |
|
| 44 |
if self.dr_rate:
|
|
@@ -48,9 +48,10 @@ class BERTClassifier(torch.nn.Module):
|
|
| 48 |
return self.classifier(out)
|
| 49 |
|
| 50 |
# --- 2. BERTDataset ํด๋์ค ์ ์ (dataset.py์์ ์ฎ๊ฒจ์ด) ---
|
| 51 |
-
# ์ด ํด๋์ค๋ ๋ฐ์ดํฐ๋ฅผ ๋ชจ๋ธ ์
๋ ฅ ํ์์ผ๋ก ๋ณํํฉ๋๋ค.
|
| 52 |
class BERTDataset(Dataset):
|
| 53 |
def __init__(self, dataset, sent_idx, label_idx, bert_tokenizer, vocab, max_len, pad, pair):
|
|
|
|
|
|
|
| 54 |
transform = nlp.data.BERTSentenceTransform(
|
| 55 |
bert_tokenizer, max_seq_length=max_len, vocab=vocab, pad=pad, pair=pair
|
| 56 |
)
|
|
@@ -85,38 +86,30 @@ except FileNotFoundError:
|
|
| 85 |
print("Error: vocab.pkl ํ์ผ์ ์ฐพ์ ์ ์์ต๋๋ค. ํ๋ก์ ํธ ๋ฃจํธ์ ์๋์ง ํ์ธํ์ธ์.")
|
| 86 |
sys.exit(1) # ํ์ผ ์์ผ๋ฉด ์๋น์ค ์์ํ์ง ์์
|
| 87 |
|
| 88 |
-
# โ
ํ ํฌ๋์ด์ ๋ก๋ (
|
| 89 |
-
#
|
| 90 |
-
|
|
|
|
| 91 |
print("ํ ํฌ๋์ด์ ๋ก๋ ์ฑ๊ณต.")
|
| 92 |
|
| 93 |
# โ
๋ชจ๋ธ ๋ก๋
|
| 94 |
-
# ๋ชจ๋ธ ์ํคํ
์ฒ๋ฅผ ์ ์ํ๊ณ , ์ ์ฅ๋ state_dict๋ฅผ ๋ก๋ํฉ๋๋ค.
|
| 95 |
# num_classes๋ category ๋์
๋๋ฆฌ์ ํฌ๊ธฐ์ ์ผ์นํด์ผ ํฉ๋๋ค.
|
|
|
|
| 96 |
model = BERTClassifier(
|
| 97 |
-
|
| 98 |
dr_rate=0.5, # ํ์ต ์ ์ฌ์ฉ๋ dr_rate ๊ฐ์ผ๋ก ๋ณ๊ฒฝํ์ธ์.
|
| 99 |
num_classes=len(category)
|
| 100 |
)
|
| 101 |
|
| 102 |
# textClassifierModel.pt ํ์ผ ๋ก๋
|
| 103 |
-
# ์ด ํ์ผ์ GitHub ์ ์ฅ์์ ์์ด์ผ ํ๋ฉฐ, Dockerfile์์ Hugging Face Hub์์ ๋ค์ด๋ก๋ํ๋๋ก ์ค์ ๋์ด ์์ต๋๋ค.
|
| 104 |
try:
|
| 105 |
-
# Dockerfile์์ ๋ชจ๋ธ์ ๋ค์ด๋ก๋ํ๋๋ก ์ค์ ํ์ผ๋ฏ๋ก, ์ฌ๊ธฐ์๋ ๋ก์ปฌ ๊ฒฝ๋ก๋ฅผ ์ฌ์ฉํฉ๋๋ค.
|
| 106 |
-
# ๋ง์ฝ Dockerfile์์ hf_hub_download๋ฅผ ์ฌ์ฉํ์ง ์๋๋ค๋ฉด, ์ฌ๊ธฐ์ hf_hub_download๋ฅผ ์ถ๊ฐํด์ผ ํฉ๋๋ค.
|
| 107 |
-
# ํ์ฌ Dockerfile์ git+https://github.com/SKTBrain/KOBERT#egg=kobert_tokenizer ๋ก๋๋ง ํฌํจํ๊ณ ,
|
| 108 |
-
# ๋ชจ๋ธ ํ์ผ ๋ค์ด๋ก๋๋ ํฌํจํ์ง ์์ต๋๋ค.
|
| 109 |
-
# ๋ฐ๋ผ์, ๋ชจ๋ธ ํ์ผ์ Hugging Face Hub์์ ๋ค์ด๋ก๋ํ๋ ๋ก์ง์ ๋ค์ ์ถ๊ฐํด์ผ ํฉ๋๋ค.
|
| 110 |
-
from huggingface_hub import hf_hub_download
|
| 111 |
HF_MODEL_REPO_ID = "hiddenFront/TextClassifier" # ์ฌ์ฉ์๋์ ์ค์ Hugging Face ์ ์ฅ์ ID
|
| 112 |
HF_MODEL_FILENAME = "textClassifierModel.pt"
|
| 113 |
model_path = hf_hub_download(repo_id=HF_MODEL_REPO_ID, filename=HF_MODEL_FILENAME)
|
| 114 |
print(f"๋ชจ๋ธ ํ์ผ์ด '{model_path}'์ ์ฑ๊ณต์ ์ผ๋ก ๋ค์ด๋ก๋๋์์ต๋๋ค.")
|
| 115 |
|
| 116 |
-
# ๋ชจ๋ธ์ state_dict๋ฅผ ๋ก๋ํฉ๋๋ค.
|
| 117 |
loaded_state_dict = torch.load(model_path, map_location=device)
|
| 118 |
|
| 119 |
-
# state_dict ํค ์กฐ์ (ํ์ํ ๊ฒฝ์ฐ)
|
| 120 |
new_state_dict = collections.OrderedDict()
|
| 121 |
for k, v in loaded_state_dict.items():
|
| 122 |
name = k
|
|
@@ -143,7 +136,7 @@ def predict(predict_sentence):
|
|
| 143 |
data = [predict_sentence, '0']
|
| 144 |
dataset_another = [data]
|
| 145 |
# num_workers๋ ๋ฐฐํฌ ํ๊ฒฝ์์ 0์ผ๋ก ์ค์ ๊ถ์ฅ
|
| 146 |
-
another_test = BERTDataset(dataset_another, 0, 1, tokenizer, vocab, max_len, True, False)
|
| 147 |
test_dataLoader = DataLoader(another_test, batch_size=batch_size, num_workers=0)
|
| 148 |
|
| 149 |
model.eval() # ์์ธก ์ ๋ชจ๋ธ์ ํ๊ฐ ๋ชจ๋๋ก ์ค์
|
|
|
|
| 5 |
import gluonnlp as nlp
|
| 6 |
import numpy as np
|
| 7 |
import os
|
| 8 |
+
import sys # sys ๋ชจ๋ ์ํฌํธ ์ถ๊ฐ (NameError ํด๊ฒฐ)
|
| 9 |
+
|
| 10 |
+
# KoBERTTokenizer ๋์ transformers.AutoTokenizer ์ฌ์ฉ
|
| 11 |
+
from transformers import BertModel, AutoTokenizer # AutoTokenizer ์ํฌํธ ์ ์ง
|
| 12 |
+
from torch.utils.data import Dataset, DataLoader
|
| 13 |
import logging # ๋ก๊น
๋ชจ๋ ์ํฌํธ ์ ์ง
|
| 14 |
+
from huggingface_hub import hf_hub_download # hf_hub_download ์ํฌํธ ์ถ๊ฐ
|
| 15 |
+
import collections # collections ๋ชจ๋ ์ํฌํธ ์ ์ง
|
| 16 |
|
| 17 |
# --- 1. BERTClassifier ๋ชจ๋ธ ํด๋์ค ์ ์ (model.py์์ ์ฎ๊ฒจ์ด) ---
|
|
|
|
| 18 |
class BERTClassifier(torch.nn.Module):
|
| 19 |
def __init__(self,
|
| 20 |
bert,
|
|
|
|
| 39 |
def forward(self, token_ids, valid_length, segment_ids):
|
| 40 |
attention_mask = self.gen_attention_mask(token_ids, valid_length)
|
| 41 |
|
|
|
|
|
|
|
|
|
|
| 42 |
_, pooler = self.bert(input_ids=token_ids, token_type_ids=segment_ids.long(), attention_mask=attention_mask.float().to(token_ids.device), return_dict=False)
|
| 43 |
|
| 44 |
if self.dr_rate:
|
|
|
|
| 48 |
return self.classifier(out)
|
| 49 |
|
| 50 |
# --- 2. BERTDataset ํด๋์ค ์ ์ (dataset.py์์ ์ฎ๊ฒจ์ด) ---
|
|
|
|
| 51 |
class BERTDataset(Dataset):
|
| 52 |
def __init__(self, dataset, sent_idx, label_idx, bert_tokenizer, vocab, max_len, pad, pair):
|
| 53 |
+
# nlp.data.BERTSentenceTransform์ ํ ํฌ๋์ด์ ํจ์๋ฅผ ๋ฐ์ต๋๋ค.
|
| 54 |
+
# AutoTokenizer์ tokenize ๋ฉ์๋๋ฅผ ์ง์ ์ ๋ฌํฉ๋๋ค.
|
| 55 |
transform = nlp.data.BERTSentenceTransform(
|
| 56 |
bert_tokenizer, max_seq_length=max_len, vocab=vocab, pad=pad, pair=pair
|
| 57 |
)
|
|
|
|
| 86 |
print("Error: vocab.pkl ํ์ผ์ ์ฐพ์ ์ ์์ต๋๋ค. ํ๋ก์ ํธ ๋ฃจํธ์ ์๋์ง ํ์ธํ์ธ์.")
|
| 87 |
sys.exit(1) # ํ์ผ ์์ผ๋ฉด ์๋น์ค ์์ํ์ง ์์
|
| 88 |
|
| 89 |
+
# โ
ํ ํฌ๋์ด์ ๋ก๋ (transformers.AutoTokenizer ์ฌ์ฉ)
|
| 90 |
+
# KoBERTTokenizer ๋์ AutoTokenizer๋ฅผ ์ฌ์ฉํ์ฌ KoBERT ๋ชจ๋ธ์ ํ ํฌ๋์ด์ ๋ฅผ ๋ก๋ํฉ๋๋ค.
|
| 91 |
+
# ์ด๋ ๊ฒ ํ๋ฉด XLNetTokenizer ๊ฒฝ๊ณ ๋ฐ kobert_tokenizer ์ค์น ๋ฌธ์ ๋ฅผ ํผํ ์ ์์ต๋๋ค.
|
| 92 |
+
tokenizer = AutoTokenizer.from_pretrained('skt/kobert-base-v1')
|
| 93 |
print("ํ ํฌ๋์ด์ ๋ก๋ ์ฑ๊ณต.")
|
| 94 |
|
| 95 |
# โ
๋ชจ๋ธ ๋ก๋
|
|
|
|
| 96 |
# num_classes๋ category ๋์
๋๋ฆฌ์ ํฌ๊ธฐ์ ์ผ์นํด์ผ ํฉ๋๋ค.
|
| 97 |
+
bertmodel = BertModel.from_pretrained('skt/kobert-base-v1')
|
| 98 |
model = BERTClassifier(
|
| 99 |
+
bertmodel,
|
| 100 |
dr_rate=0.5, # ํ์ต ์ ์ฌ์ฉ๋ dr_rate ๊ฐ์ผ๋ก ๋ณ๊ฒฝํ์ธ์.
|
| 101 |
num_classes=len(category)
|
| 102 |
)
|
| 103 |
|
| 104 |
# textClassifierModel.pt ํ์ผ ๋ก๋
|
|
|
|
| 105 |
try:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 106 |
HF_MODEL_REPO_ID = "hiddenFront/TextClassifier" # ์ฌ์ฉ์๋์ ์ค์ Hugging Face ์ ์ฅ์ ID
|
| 107 |
HF_MODEL_FILENAME = "textClassifierModel.pt"
|
| 108 |
model_path = hf_hub_download(repo_id=HF_MODEL_REPO_ID, filename=HF_MODEL_FILENAME)
|
| 109 |
print(f"๋ชจ๋ธ ํ์ผ์ด '{model_path}'์ ์ฑ๊ณต์ ์ผ๋ก ๋ค์ด๋ก๋๋์์ต๋๋ค.")
|
| 110 |
|
|
|
|
| 111 |
loaded_state_dict = torch.load(model_path, map_location=device)
|
| 112 |
|
|
|
|
| 113 |
new_state_dict = collections.OrderedDict()
|
| 114 |
for k, v in loaded_state_dict.items():
|
| 115 |
name = k
|
|
|
|
| 136 |
data = [predict_sentence, '0']
|
| 137 |
dataset_another = [data]
|
| 138 |
# num_workers๋ ๋ฐฐํฌ ํ๊ฒฝ์์ 0์ผ๋ก ์ค์ ๊ถ์ฅ
|
| 139 |
+
another_test = BERTDataset(dataset_another, 0, 1, tokenizer, vocab, max_len, True, False) # tokenizer ๊ฐ์ฒด ์ง์ ์ ๋ฌ
|
| 140 |
test_dataLoader = DataLoader(another_test, batch_size=batch_size, num_workers=0)
|
| 141 |
|
| 142 |
model.eval() # ์์ธก ์ ๋ชจ๋ธ์ ํ๊ฐ ๋ชจ๋๋ก ์ค์
|