hiddenFront committed on
Commit 3cc319e · verified · 1 Parent(s): dd0f82c

Update app.py

Files changed (1)
  1. app.py +16 -23
app.py CHANGED
@@ -5,13 +5,16 @@ import pickle
  import gluonnlp as nlp
  import numpy as np
  import os
- from kobert_tokenizer import KoBERTTokenizer # keep the kobert_tokenizer import
- from transformers import BertModel # keep the BertModel import
- from torch.utils.data import Dataset, DataLoader # add the DataLoader import
  import logging # keep the logging module import

  # --- 1. BERTClassifier model class definition (moved from model.py) ---
- # This class defines the model architecture.
  class BERTClassifier(torch.nn.Module):
  def __init__(self,
  bert,
@@ -36,9 +39,6 @@ class BERTClassifier(torch.nn.Module):
  def forward(self, token_ids, valid_length, segment_ids):
  attention_mask = self.gen_attention_mask(token_ids, valid_length)

- # Adjusted to match BertModel's output structure
- # Hugging Face Transformers' BertModel returns (last_hidden_state, pooler_output, ...)
- # use pooler_output (the CLS token's final hidden state passed through the pooler)
  _, pooler = self.bert(input_ids=token_ids, token_type_ids=segment_ids.long(), attention_mask=attention_mask.float().to(token_ids.device), return_dict=False)

  if self.dr_rate:
@@ -48,9 +48,10 @@ class BERTClassifier(torch.nn.Module):
  return self.classifier(out)

  # --- 2. BERTDataset class definition (moved from dataset.py) ---
- # This class converts data into the model's input format.
  class BERTDataset(Dataset):
  def __init__(self, dataset, sent_idx, label_idx, bert_tokenizer, vocab, max_len, pad, pair):
  transform = nlp.data.BERTSentenceTransform(
  bert_tokenizer, max_seq_length=max_len, vocab=vocab, pad=pad, pair=pair
  )
@@ -85,38 +86,30 @@ except FileNotFoundError:
  print("Error: vocab.pkl file not found. Check that it is in the project root.")
  sys.exit(1) # do not start the service if the file is missing

- # ✅ Load tokenizer (using kobert_tokenizer)
- # Kept because this is the approach used in the Colab code.
- tokenizer = KoBERTTokenizer.from_pretrained('skt/kobert-base-v1')
  print("Tokenizer loaded successfully.")
 
  # ✅ Load model
- # Define the model architecture and load the saved state_dict.
  # num_classes must match the size of the category dictionary.
  model = BERTClassifier(
- BertModel.from_pretrained('skt/kobert-base-v1'),
  dr_rate=0.5, # change to the dr_rate value used during training.
  num_classes=len(category)
  )
 
  # Load the textClassifierModel.pt file
- # This file should not be in the GitHub repository; the Dockerfile is set up to download it from the Hugging Face Hub.
  try:
- # Since the Dockerfile was set up to download the model, a local path is used here.
- # If the Dockerfile does not use hf_hub_download, hf_hub_download must be added here.
- # The current Dockerfile only includes loading git+https://github.com/SKTBrain/KOBERT#egg=kobert_tokenizer,
- # and does not include downloading the model file.
- # Therefore, the logic for downloading the model file from the Hugging Face Hub needs to be added back.
- from huggingface_hub import hf_hub_download
  HF_MODEL_REPO_ID = "hiddenFront/TextClassifier" # your actual Hugging Face repository ID
  HF_MODEL_FILENAME = "textClassifierModel.pt"
  model_path = hf_hub_download(repo_id=HF_MODEL_REPO_ID, filename=HF_MODEL_FILENAME)
  print(f"Model file downloaded successfully to '{model_path}'.")

- # Load the model's state_dict.
  loaded_state_dict = torch.load(model_path, map_location=device)

- # Adjust state_dict keys (if needed)
  new_state_dict = collections.OrderedDict()
  for k, v in loaded_state_dict.items():
  name = k
@@ -143,7 +136,7 @@ def predict(predict_sentence):
  data = [predict_sentence, '0']
  dataset_another = [data]
  # num_workers should be set to 0 in a deployment environment
- another_test = BERTDataset(dataset_another, 0, 1, tokenizer, vocab, max_len, True, False)
  test_dataLoader = DataLoader(another_test, batch_size=batch_size, num_workers=0)

  model.eval() # set the model to evaluation mode for prediction
 
  import gluonnlp as nlp
  import numpy as np
  import os
+ import sys # add the sys module import (fixes a NameError)
+
+ # Use transformers.AutoTokenizer instead of KoBERTTokenizer
+ from transformers import BertModel, AutoTokenizer # keep the AutoTokenizer import
+ from torch.utils.data import Dataset, DataLoader
  import logging # keep the logging module import
+ from huggingface_hub import hf_hub_download # add the hf_hub_download import
+ import collections # keep the collections module import

  # --- 1. BERTClassifier model class definition (moved from model.py) ---
  class BERTClassifier(torch.nn.Module):
  def __init__(self,
  bert,
 
  def forward(self, token_ids, valid_length, segment_ids):
  attention_mask = self.gen_attention_mask(token_ids, valid_length)

  _, pooler = self.bert(input_ids=token_ids, token_type_ids=segment_ids.long(), attention_mask=attention_mask.float().to(token_ids.device), return_dict=False)

  if self.dr_rate:
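The helper gen_attention_mask used in forward() is outside this hunk. In the standard KoBERT classifier example it is built from valid_length roughly as follows; take this as a sketch of the usual pattern, not code taken from this commit:

    def gen_attention_mask(self, token_ids, valid_length):
        # start with every position masked out, then unmask the first `v` real tokens per row
        attention_mask = torch.zeros_like(token_ids)
        for i, v in enumerate(valid_length):
            attention_mask[i][:v] = 1
        return attention_mask.float()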
 
  return self.classifier(out)

  # --- 2. BERTDataset class definition (moved from dataset.py) ---
  class BERTDataset(Dataset):
  def __init__(self, dataset, sent_idx, label_idx, bert_tokenizer, vocab, max_len, pad, pair):
+ # nlp.data.BERTSentenceTransform takes a tokenizer function.
+ # The AutoTokenizer's tokenize method is passed in directly.
  transform = nlp.data.BERTSentenceTransform(
  bert_tokenizer, max_seq_length=max_len, vocab=vocab, pad=pad, pair=pair
  )
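The rest of BERTDataset falls outside the hunk. In the common KoBERT example the class usually continues along the following lines; this is a sketch under that assumption, not this repository's exact code:

        # (still inside __init__)
        self.sentences = [transform([i[sent_idx]]) for i in dataset]
        self.labels = [np.int32(i[label_idx]) for i in dataset]

    def __getitem__(self, i):
        # one example: (token_ids, valid_length, segment_ids) plus the label
        return self.sentences[i] + (self.labels[i],)

    def __len__(self):
        return len(self.labels)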
 
  print("Error: vocab.pkl file not found. Check that it is in the project root.")
  sys.exit(1) # do not start the service if the file is missing

+ # ✅ Load tokenizer (using transformers.AutoTokenizer)
+ # Load the KoBERT tokenizer with AutoTokenizer instead of KoBERTTokenizer.
+ # This avoids the XLNetTokenizer warning and the kobert_tokenizer installation issue.
+ tokenizer = AutoTokenizer.from_pretrained('skt/kobert-base-v1')
  print("Tokenizer loaded successfully.")
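A quick way to confirm the swapped-in tokenizer behaves as expected is to tokenize a sample sentence. This snippet is illustrative only and not part of the commit; the sample text is arbitrary:

    sample = "예시 문장입니다"                      # any Korean sentence will do
    tokens = tokenizer.tokenize(sample)            # expect SentencePiece-style subword pieces
    print(tokens, tokenizer.convert_tokens_to_ids(tokens))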
 
  # ✅ Load model
  # num_classes must match the size of the category dictionary.
+ bertmodel = BertModel.from_pretrained('skt/kobert-base-v1')
  model = BERTClassifier(
+ bertmodel,
  dr_rate=0.5, # change to the dr_rate value used during training.
  num_classes=len(category)
  )
 
  # Load the textClassifierModel.pt file
  try:
  HF_MODEL_REPO_ID = "hiddenFront/TextClassifier" # your actual Hugging Face repository ID
  HF_MODEL_FILENAME = "textClassifierModel.pt"
  model_path = hf_hub_download(repo_id=HF_MODEL_REPO_ID, filename=HF_MODEL_FILENAME)
  print(f"Model file downloaded successfully to '{model_path}'.")

  loaded_state_dict = torch.load(model_path, map_location=device)

  new_state_dict = collections.OrderedDict()
  for k, v in loaded_state_dict.items():
  name = k
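The key-adjustment loop is cut off by the diff context right after `name = k`. A typical completion, assuming the checkpoint may have been saved from a DataParallel-wrapped model whose keys carry a "module." prefix, would look like this sketch:

    new_state_dict = collections.OrderedDict()
    for k, v in loaded_state_dict.items():
        # strip a leading "module." so keys match the plain (non-DataParallel) model
        name = k[7:] if k.startswith("module.") else k
        new_state_dict[name] = v
    model.load_state_dict(new_state_dict)
    model.to(device)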
 
  data = [predict_sentence, '0']
  dataset_another = [data]
  # num_workers should be set to 0 in a deployment environment
+ another_test = BERTDataset(dataset_another, 0, 1, tokenizer, vocab, max_len, True, False) # pass the tokenizer object directly
  test_dataLoader = DataLoader(another_test, batch_size=batch_size, num_workers=0)

  model.eval() # set the model to evaluation mode for prediction
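The remainder of predict() lies outside this hunk. It typically iterates over test_dataLoader and takes the argmax of the logits; the sketch below assumes the usual KoBERT loop shape, and how the predicted index maps back to a label name via the category dict is left as an assumption:

    with torch.no_grad():
        for token_ids, valid_length, segment_ids, label in test_dataLoader:
            token_ids = token_ids.long().to(device)
            segment_ids = segment_ids.long().to(device)
            out = model(token_ids, valid_length, segment_ids)
            pred_idx = int(torch.argmax(out, dim=-1)[0])
            return pred_idx  # translate to a category name with the category mapping if needed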