hiddenFront commited on
Commit
6ba018e
ยท
verified ยท
1 Parent(s): ec61894

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +18 -4
app.py CHANGED
@@ -23,10 +23,24 @@ except FileNotFoundError:
23
  tokenizer = AutoTokenizer.from_pretrained("skt/kobert-base-v1")
24
  print("โœ… ํ† ํฌ๋‚˜์ด์ € ๋กœ๋“œ ์„ฑ๊ณต.")
25
 
26
- # ๋ชจ๋ธ ๊ตฌ์กฐ ์žฌ์ •์˜
27
- num_labels = len(category) # ๋ถ„๋ฅ˜ํ•  ํด๋ž˜์Šค ์ˆ˜์— ๋”ฐ๋ผ
28
- model = BertForSequenceClassification.from_pretrained("skt/kobert-base-v1", num_labels=num_labels)
29
- model.to(device)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30
 
31
  HF_MODEL_REPO_ID = "hiddenFront/TextClassifier"
32
  HF_MODEL_FILENAME = "textClassifierModel.pt"
 
23
  tokenizer = AutoTokenizer.from_pretrained("skt/kobert-base-v1")
24
  print("โœ… ํ† ํฌ๋‚˜์ด์ € ๋กœ๋“œ ์„ฑ๊ณต.")
25
 
26
+ class CustomClassifier(torch.nn.Module):
27
+ def __init__(self):
28
+ super().__init__()
29
+ # ์ •์˜ํ–ˆ๋˜ ๊ตฌ์กฐ ๊ทธ๋Œ€๋กœ ๋ณต์›ํ•ด์•ผ ํ•จ
30
+ self.bert = BertModel.from_pretrained("skt/kobert-base-v1")
31
+ self.classifier = torch.nn.Linear(768, len(category))
32
+
33
+ def forward(self, input_ids, attention_mask=None, token_type_ids=None):
34
+ outputs = self.bert(input_ids=input_ids,
35
+ attention_mask=attention_mask,
36
+ token_type_ids=token_type_ids)
37
+ pooled_output = outputs[1] # CLS ํ† ํฐ
38
+ return self.classifier(pooled_output)
39
+
40
+
41
+ model = CustomClassifier()
42
+ model.load_state_dict(torch.load(model_path, map_location=device))
43
+ model.eval()
44
 
45
  HF_MODEL_REPO_ID = "hiddenFront/TextClassifier"
46
  HF_MODEL_FILENAME = "textClassifierModel.pt"