hiddenFront committed
Commit 44d2bcd · verified · 1 Parent(s): 58648d1

Update app.py

Files changed (1): app.py (+44 -51)
app.py CHANGED
@@ -5,17 +5,22 @@ import pickle
 import gluonnlp as nlp
 import numpy as np
 import os
- import sys  # import the sys module so the service can exit on errors
+ import sys
+ import collections
+ import logging  # import the logging module

 # import AutoTokenizer and BertModel from transformers
- from transformers import AutoTokenizer, BertModel  # added the BertModel import
+ from transformers import AutoTokenizer, BertModel
 from torch.utils.data import Dataset, DataLoader
- import logging  # keep the logging module import
- from huggingface_hub import hf_hub_download  # keep the hf_hub_download import
- import collections  # keep the collections module import
+ from huggingface_hub import hf_hub_download
+
+ # --- Logging setup ---
+ # Configure logging to emit INFO-level messages and above.
+ # In a real deployment the level can be raised to WARNING or ERROR to cut unnecessary logs.
+ logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
+ logger = logging.getLogger(__name__)

 # --- 1. BERTClassifier model class definition ---
- # This class defines the model architecture.
 class BERTClassifier(torch.nn.Module):
     def __init__(self,
                  bert,
@@ -49,11 +54,8 @@ class BERTClassifier(torch.nn.Module):
         return self.classifier(out)

 # --- 2. BERTDataset class definition ---
- # This class converts data into the model's input format.
 class BERTDataset(Dataset):
     def __init__(self, dataset, sent_idx, label_idx, bert_tokenizer, vocab, max_len, pad, pair):
-         # nlp.data.BERTSentenceTransform expects a tokenizer function.
-         # The AutoTokenizer's tokenize method is passed in directly.
         transform = nlp.data.BERTSentenceTransform(
             bert_tokenizer, max_seq_length=max_len, vocab=vocab, pad=pad, pair=pair
         )
@@ -70,55 +72,45 @@ class BERTDataset(Dataset):
 app = FastAPI()
 device = torch.device("cpu")  # the free tier of Hugging Face Spaces mainly uses CPU

- # ✅ load category (must be in the GitHub repository root)
+ # ✅ load category
 try:
     with open("category.pkl", "rb") as f:
         category = pickle.load(f)
-     print("category.pkl loaded successfully.")
+     logger.info("category.pkl loaded successfully.")
 except FileNotFoundError:
-     print("Error: category.pkl not found. Check that it is in the project root.")
-     sys.exit(1)  # do not start the service if the file is missing
+     logger.error("Error: category.pkl not found. Check that it is in the project root.")
+     sys.exit(1)

- # ✅ load vocab (must be in the GitHub repository root)
+ # ✅ load vocab
 try:
     with open("vocab.pkl", "rb") as f:
         vocab = pickle.load(f)
-     print("vocab.pkl loaded successfully.")
+     logger.info("vocab.pkl loaded successfully.")
 except FileNotFoundError:
-     print("Error: vocab.pkl not found. Check that it is in the project root.")
-     sys.exit(1)  # do not start the service if the file is missing
+     logger.error("Error: vocab.pkl not found. Check that it is in the project root.")
+     sys.exit(1)

- # ✅ load the tokenizer (using transformers.AutoTokenizer)
+ # ✅ load the tokenizer
 tokenizer = AutoTokenizer.from_pretrained('skt/kobert-base-v1')
- print("Tokenizer loaded successfully.")
+ logger.info("Tokenizer loaded successfully.")

 # ✅ load the model (downloaded from the Hugging Face Hub)
 try:
-     HF_MODEL_REPO_ID = "hiddenFront/TextClassifier"  # your actual Hugging Face repository ID
-     HF_MODEL_FILENAME = "textClassifierModel.pt"  # must match the filename uploaded to the Hub
+     HF_MODEL_REPO_ID = "hiddenFront/TextClassifier"
+     HF_MODEL_FILENAME = "textClassifierModel.pt"

     model_path = hf_hub_download(repo_id=HF_MODEL_REPO_ID, filename=HF_MODEL_FILENAME)
-     print(f"Model file downloaded successfully to '{model_path}'.")
+     logger.info(f"Model file downloaded successfully to '{model_path}'.")

-     # --- start of the revised core section ---
-     # 1. Load the base BERT model with BertModel.from_pretrained.
-     #    This loads both the model architecture and the pretrained weights.
     bert_base_model = BertModel.from_pretrained('skt/kobert-base-v1')
-
-     # 2. Create the BERTClassifier instance.
-     #    num_classes here must match the size of the category dictionary.
     model = BERTClassifier(
         bert_base_model,
         dr_rate=0.5,  # change this to the dr_rate value used during training
         num_classes=len(category)
     )

-     # 3. Load the state_dict from the downloaded file.
-     #    This file contains only the weights of the slimmed-down model.
     loaded_state_dict = torch.load(model_path, map_location=device)

-     # 4. Adjust the keys of the loaded state_dict and apply them to the model,
-     #    including logic to strip the 'module.' prefix when present.
     new_state_dict = collections.OrderedDict()
     for k, v in loaded_state_dict.items():
         name = k
@@ -126,19 +118,14 @@ try:
             name = name[7:]
         new_state_dict[name] = v

-     # strict=False prevents Missing key(s) errors: keys absent from new_state_dict
-     # keep the values already in the model (loaded via from_pretrained), and keys
-     # absent from the model are ignored.
     model.load_state_dict(new_state_dict, strict=False)
-     # --- end of the revised core section ---
-
-     model.to(device)  # move the model to the device
-     model.eval()  # switch to inference mode
-     print("Model loaded successfully.")
+     model.to(device)
+     model.eval()
+     logger.info("Model loaded successfully.")

 except Exception as e:
-     print(f"Error: a failure occurred while downloading or loading the model: {e}")
-     sys.exit(1)  # do not start the service if the model fails to load
+     logger.error(f"Error: a failure occurred while downloading or loading the model: {e}")
+     sys.exit(1)


 # ✅ parameters needed to build the dataset
@@ -149,26 +136,32 @@ batch_size = 32
 def predict(predict_sentence):
     data = [predict_sentence, '0']
     dataset_another = [data]
-     # num_workers is best set to 0 in a deployed environment
-     # tokenizer.tokenize is passed to BERTDataset
     another_test = BERTDataset(dataset_another, 0, 1, tokenizer.tokenize, vocab, max_len, True, False)
     test_dataLoader = DataLoader(another_test, batch_size=batch_size, num_workers=0)

-     model.eval()  # put the model in evaluation mode for prediction
+     model.eval()

-     with torch.no_grad():  # disable gradient computation
+     with torch.no_grad():
         for batch_id, (token_ids, valid_length, segment_ids, label) in enumerate(test_dataLoader):
             token_ids = token_ids.long().to(device)
             segment_ids = segment_ids.long().to(device)

             out = model(token_ids, valid_length, segment_ids)

-             logits = out
-             logits = logits.detach().cpu().numpy()
-
-             predicted_category_index = np.argmax(logits)
-             predicted_category_name = list(category.keys())[predicted_category_index]
+             logits = out  # the model's raw output is the logits
+             probs = torch.nn.functional.softmax(logits, dim=1)  # compute probabilities
+
+             predicted_category_index = torch.argmax(probs, dim=1).item()  # predicted index
+             predicted_category_name = list(category.keys())[predicted_category_index]  # predicted category name

+             # --- detailed prediction logging ---
+             logger.info(f"Input Text: '{predict_sentence}'")
+             logger.info(f"Raw Logits: {logits.tolist()}")
+             logger.info(f"Probabilities: {probs.tolist()}")
+             logger.info(f"Predicted Index: {predicted_category_index}")
+             logger.info(f"Predicted Label: '{predicted_category_name}'")
+             # --- end of detailed prediction logging ---
+
     return predicted_category_name

 # ✅ endpoint definition
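The state_dict handling that the deleted comments walked through is easy to exercise on its own. Below is a minimal sketch, not part of the commit: normalize_state_dict is a hypothetical helper name, and the checkpoint is assumed to have been saved from an nn.DataParallel-wrapped model, which is what produces the 'module.'-prefixed keys.

import collections

def normalize_state_dict(loaded_state_dict):
    # nn.DataParallel prefixes every parameter key with 'module.';
    # strip it so the keys line up with a plain, unwrapped model.
    new_state_dict = collections.OrderedDict()
    for k, v in loaded_state_dict.items():
        name = k[7:] if k.startswith('module.') else k  # len('module.') == 7
        new_state_dict[name] = v
    return new_state_dict

# Usage sketch, with the filename from the diff above:
# state = torch.load("textClassifierModel.pt", map_location="cpu")
# model.load_state_dict(normalize_state_dict(state), strict=False)

Loading with strict=False makes the operation tolerant in both directions: parameters missing from the checkpoint keep the values from_pretrained already gave them, and checkpoint keys with no counterpart in the model are skipped.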
 
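The behavioural core of the predict() change is the move from np.argmax over raw logits to an explicit softmax followed by torch.argmax. A small self-contained sketch of that post-processing step, with a made-up category dict standing in for the pickled category.pkl:

import torch

category = {"economy": 0, "sports": 1, "politics": 2}  # stand-in for category.pkl

logits = torch.tensor([[1.2, 3.4, 0.3]])            # model output, shape (1, num_classes)
probs = torch.nn.functional.softmax(logits, dim=1)  # each row now sums to 1
idx = torch.argmax(probs, dim=1).item()             # -> 1
label = list(category.keys())[idx]                  # -> 'sports'
print(label, probs.tolist())

Since softmax is monotonic, the argmax (and therefore the returned label) is unchanged from taking it over the raw logits; the probabilities exist to feed the new logging lines. Note that list(category.keys())[idx] relies on the dict's insertion order matching the training label order, an assumption the original code already made.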