hiddenFront committed (verified)
Commit 8153817 · Parent(s): 041f67a

Update app.py

Files changed (1): app.py (+134, -28)
app.py CHANGED
@@ -5,35 +5,134 @@ import pickle
 import gluonnlp as nlp
 import numpy as np
 import os
-from kobert_tokenizer import KoBERTTokenizer
-from model import BERTClassifier
-from dataset import BERTDataset
-from transformers import BertModel
-import logging
+from kobert_tokenizer import KoBERTTokenizer  # keep the kobert_tokenizer import
+from transformers import BertModel  # keep the BertModel import
+from torch.utils.data import Dataset, DataLoader  # add the DataLoader import
+import logging  # keep the logging import
+import sys  # fix: needed for the sys.exit() calls below
+import collections  # fix: needed for collections.OrderedDict below
 
+# --- 1. BERTClassifier model class (moved here from model.py) ---
+# This class defines the model architecture.
+class BERTClassifier(torch.nn.Module):
+    def __init__(self,
+                 bert,
+                 hidden_size=768,
+                 num_classes=5,  # number of classes (must match the size of the category dict)
+                 dr_rate=None,
+                 params=None):
+        super(BERTClassifier, self).__init__()
+        self.bert = bert
+        self.dr_rate = dr_rate
+
+        self.classifier = torch.nn.Linear(hidden_size, num_classes)
+        if dr_rate:
+            self.dropout = torch.nn.Dropout(p=dr_rate)
+
+    def gen_attention_mask(self, token_ids, valid_length):
+        attention_mask = torch.zeros_like(token_ids)
+        for i, v in enumerate(valid_length):
+            attention_mask[i][:v] = 1
+        return attention_mask.float()
+
+    def forward(self, token_ids, valid_length, segment_ids):
+        attention_mask = self.gen_attention_mask(token_ids, valid_length)
+
+        # Matches the Hugging Face Transformers BertModel output: with
+        # return_dict=False it returns (last_hidden_state, pooler_output, ...).
+        # We use pooler_output, the [CLS] hidden state passed through the pooler.
+        _, pooler = self.bert(input_ids=token_ids, token_type_ids=segment_ids.long(), attention_mask=attention_mask.float().to(token_ids.device), return_dict=False)
+
+        if self.dr_rate:
+            out = self.dropout(pooler)
+        else:
+            out = pooler
+        return self.classifier(out)
+
+# --- 2. BERTDataset class (moved here from dataset.py) ---
+# This class converts raw data into the model's input format.
+class BERTDataset(Dataset):
+    def __init__(self, dataset, sent_idx, label_idx, bert_tokenizer, vocab, max_len, pad, pair):
+        transform = nlp.data.BERTSentenceTransform(
+            bert_tokenizer, max_seq_length=max_len, vocab=vocab, pad=pad, pair=pair
+        )
+        self.sentences = [transform([i[sent_idx]]) for i in dataset]
+        self.labels = [np.int32(i[label_idx]) for i in dataset]
+
+    def __getitem__(self, i):
+        return (self.sentences[i] + (self.labels[i],))
+
+    def __len__(self):
+        return len(self.labels)
+
+# --- 3. FastAPI app and global setup ---
 app = FastAPI()
-device = torch.device("cpu")
+device = torch.device("cpu")  # Render's free tier runs on CPU
 
-# ✅ Load category
-with open("category.pkl", "rb") as f:
-    category = pickle.load(f)
+# ✅ Load category (must sit in the repository root)
+try:
+    with open("category.pkl", "rb") as f:
+        category = pickle.load(f)
+    print("Loaded category.pkl.")
+except FileNotFoundError:
+    print("Error: category.pkl not found. Check that it is in the project root.")
+    sys.exit(1)  # do not start the service without this file
 
-# ✅ Load vocab
-with open("vocab.pkl", "rb") as f:
-    vocab = pickle.load(f)
+# ✅ Load vocab (must sit in the repository root)
+try:
+    with open("vocab.pkl", "rb") as f:
+        vocab = pickle.load(f)
+    print("Loaded vocab.pkl.")
+except FileNotFoundError:
+    print("Error: vocab.pkl not found. Check that it is in the project root.")
+    sys.exit(1)  # do not start the service without this file
 
-# ✅ Tokenizer
+# ✅ Load the tokenizer (via kobert_tokenizer)
+# Kept because this is the approach the Colab code used.
 tokenizer = KoBERTTokenizer.from_pretrained('skt/kobert-base-v1')
+print("Loaded tokenizer.")
 
 # ✅ Load the model
+# Define the architecture, then load the saved state_dict.
+# num_classes must match the size of the category dict.
 model = BERTClassifier(
     BertModel.from_pretrained('skt/kobert-base-v1'),
-    dr_rate=0.5,
+    dr_rate=0.5,  # change this to the dr_rate used during training
     num_classes=len(category)
 )
-model.load_state_dict(torch.load("textClassifierModel.pt", map_location=device))
-model.to(device)
-model.eval()
+
+# Load textClassifierModel.pt.
+# The checkpoint does not live in the GitHub repository, and the current Dockerfile
+# only installs git+https://github.com/SKTBrain/KOBERT#egg=kobert_tokenizer; it does
+# not fetch the model file, so we download it from the Hugging Face Hub here.
+try:
+    from huggingface_hub import hf_hub_download
+    HF_MODEL_REPO_ID = "hiddenFront/TextClassifier"  # your actual Hugging Face repo ID
+    HF_MODEL_FILENAME = "textClassifierModel.pt"
+    model_path = hf_hub_download(repo_id=HF_MODEL_REPO_ID, filename=HF_MODEL_FILENAME)
+    print(f"Model file downloaded to '{model_path}'.")
+
+    # Load the model's state_dict.
+    loaded_state_dict = torch.load(model_path, map_location=device)
+
+    # Remap state_dict keys if needed (strip the 'module.' prefix that a
+    # DataParallel-wrapped checkpoint carries).
+    new_state_dict = collections.OrderedDict()
+    for k, v in loaded_state_dict.items():
+        name = k
+        if name.startswith('module.'):
+            name = name[7:]
+        new_state_dict[name] = v
+
+    model.load_state_dict(new_state_dict)
+    model.to(device)  # move the model to the device
+    model.eval()  # set inference mode
+    print("Loaded model.")
+
+except Exception as e:
+    print(f"Error: failed to download or load the model: {e}")
+    sys.exit(1)  # do not start the service if the model cannot be loaded
 
 # ✅ Parameters needed to build the dataset
 max_len = 64
@@ -43,20 +142,26 @@ batch_size = 32
 def predict(predict_sentence):
     data = [predict_sentence, '0']
     dataset_another = [data]
+    # num_workers should stay 0 in a deployed environment
     another_test = BERTDataset(dataset_another, 0, 1, tokenizer, vocab, max_len, True, False)
-    test_dataLoader = torch.utils.data.DataLoader(another_test, batch_size=batch_size, num_workers=0)
+    test_dataLoader = DataLoader(another_test, batch_size=batch_size, num_workers=0)
 
-    model.eval()
-    for batch_id, (token_ids, valid_length, segment_ids, label) in enumerate(test_dataLoader):
-        token_ids = token_ids.long().to(device)
-        segment_ids = segment_ids.long().to(device)
+    model.eval()  # set the model to evaluation mode for prediction
 
-        out = model(token_ids, valid_length, segment_ids)
-        test_eval = []
-        for i in out:
-            logits = i.detach().cpu().numpy()
-            test_eval.append(list(category.keys())[np.argmax(logits)])
-    return test_eval[0]
+    with torch.no_grad():  # disable gradient computation
+        for batch_id, (token_ids, valid_length, segment_ids, label) in enumerate(test_dataLoader):
+            token_ids = token_ids.long().to(device)
+            segment_ids = segment_ids.long().to(device)
+
+            out = model(token_ids, valid_length, segment_ids)
+
+            logits = out.detach().cpu().numpy()
+
+            predicted_category_index = np.argmax(logits)
+            predicted_category_name = list(category.keys())[predicted_category_index]
+
+            return predicted_category_name
 
 # ✅ Endpoint definitions
 class InputText(BaseModel):
@@ -70,3 +175,4 @@ def root():
 async def predict_route(item: InputText):
     result = predict(item.text)
     return {"text": item.text, "classification": result}
+
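
Note on the attention mask: gen_attention_mask above marks the first valid_length positions of each row with ones and leaves the padded tail at zero. A standalone sketch of the same loop, runnable on its own:

import torch

token_ids = torch.zeros(2, 6, dtype=torch.long)  # batch of 2, sequence length 6
valid_length = torch.tensor([3, 5])              # true (unpadded) lengths

attention_mask = torch.zeros_like(token_ids)
for i, v in enumerate(valid_length):
    attention_mask[i][:v] = 1                    # ones up to each true length
print(attention_mask)
# tensor([[1, 1, 1, 0, 0, 0],
#         [1, 1, 1, 1, 1, 0]])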
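
Note on the model download: since the Dockerfile does not fetch the checkpoint, an alternative to downloading at process start is pre-fetching it at image build time so cold starts stay fast. A minimal sketch reusing the repo ID and filename from the diff; the script name and a `RUN python download_model.py` step in the Dockerfile are assumptions:

# download_model.py -- hypothetical build-time fetch into the Hugging Face cache
from huggingface_hub import hf_hub_download

path = hf_hub_download(
    repo_id="hiddenFront/TextClassifier",
    filename="textClassifierModel.pt",
)
print(f"Checkpoint cached at {path}")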
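
Note on the state_dict key remap: the strip-the-prefix loop exists because a checkpoint saved from a torch.nn.DataParallel-wrapped model carries a 'module.' prefix on every key (whether this checkpoint was produced that way is an assumption; the remap is a no-op otherwise). A small demonstration:

import collections
import torch

net = torch.nn.Linear(4, 2)
wrapped = torch.nn.DataParallel(net)     # saving this wrapper prefixes every key
print(next(iter(wrapped.state_dict())))  # -> 'module.weight'

# The same remap as in the diff restores keys the bare module can load:
fixed = collections.OrderedDict(
    (k[len('module.'):] if k.startswith('module.') else k, v)
    for k, v in wrapped.state_dict().items()
)
net.load_state_dict(fixed)               # loads cleanly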
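
Finally, a sketch of exercising the updated service once it is running. It assumes the app is served with uvicorn on localhost:8000 and that predict_route is mounted at POST /predict; the decorator path sits outside the visible hunks, so the exact route is an assumption:

# hypothetical client for the endpoint above (requires the requests package)
import requests

resp = requests.post(
    "http://localhost:8000/predict",                 # assumed host, port, and route
    json={"text": "a sample sentence to classify"},  # placeholder input
)
print(resp.json())  # expected shape: {"text": ..., "classification": ...}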