Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -23,10 +23,24 @@ except FileNotFoundError:
|
|
| 23 |
tokenizer = AutoTokenizer.from_pretrained("skt/kobert-base-v1")
|
| 24 |
print("โ
ํ ํฌ๋์ด์ ๋ก๋ ์ฑ๊ณต.")
|
| 25 |
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 30 |
|
| 31 |
HF_MODEL_REPO_ID = "hiddenFront/TextClassifier"
|
| 32 |
HF_MODEL_FILENAME = "textClassifierModel.pt"
|
|
|
|
| 23 |
tokenizer = AutoTokenizer.from_pretrained("skt/kobert-base-v1")
|
| 24 |
print("โ
ํ ํฌ๋์ด์ ๋ก๋ ์ฑ๊ณต.")
|
| 25 |
|
| 26 |
+
class CustomClassifier(torch.nn.Module):
|
| 27 |
+
def __init__(self):
|
| 28 |
+
super().__init__()
|
| 29 |
+
# ์ ์ํ๋ ๊ตฌ์กฐ ๊ทธ๋๋ก ๋ณต์ํด์ผ ํจ
|
| 30 |
+
self.bert = BertModel.from_pretrained("skt/kobert-base-v1")
|
| 31 |
+
self.classifier = torch.nn.Linear(768, len(category))
|
| 32 |
+
|
| 33 |
+
def forward(self, input_ids, attention_mask=None, token_type_ids=None):
|
| 34 |
+
outputs = self.bert(input_ids=input_ids,
|
| 35 |
+
attention_mask=attention_mask,
|
| 36 |
+
token_type_ids=token_type_ids)
|
| 37 |
+
pooled_output = outputs[1] # CLS ํ ํฐ
|
| 38 |
+
return self.classifier(pooled_output)
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
model = CustomClassifier()
|
| 42 |
+
model.load_state_dict(torch.load(model_path, map_location=device))
|
| 43 |
+
model.eval()
|
| 44 |
|
| 45 |
HF_MODEL_REPO_ID = "hiddenFront/TextClassifier"
|
| 46 |
HF_MODEL_FILENAME = "textClassifierModel.pt"
|