File size: 2,041 Bytes
7f17fe7
95b43d8
7f17fe7
 
95b43d8
 
e66afc2
95b43d8
 
 
 
 
7f17fe7
 
 
 
95b43d8
 
 
7f17fe7
95b43d8
 
 
7f17fe7
95b43d8
 
6ba018e
95b43d8
 
 
 
 
 
 
 
 
6ba018e
95b43d8
 
 
7f17fe7
95b43d8
 
 
 
 
 
7f17fe7
 
95b43d8
 
 
 
 
 
 
 
 
 
7f17fe7
95b43d8
 
 
4607c9c
689eabe
95b43d8
 
689eabe
7f17fe7
95b43d8
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
from fastapi import FastAPI, Request
from pydantic import BaseModel
import torch
import pickle
import gluonnlp as nlp
import numpy as np
import os
from kobert_tokenizer import KoBERTTokenizer
from model import BERTClassifier
from dataset import BERTDataset
from transformers import BertModel
import logging

app = FastAPI()
# All inference runs on the CPU; the checkpoint below is remapped onto it.
device = torch.device("cpu")


def _load_pickle(path):
    """Return the object unpickled from the file at *path*.

    NOTE(review): unpickling can execute arbitrary code, so these files must
    be trusted build artifacts shipped with the service, never user input.
    """
    with open(path, "rb") as fh:
        return pickle.load(fh)


# Label mapping; predict() uses its keys, in insertion order, as class names.
category = _load_pickle("category.pkl")
# Vocabulary object handed to BERTDataset by predict().
vocab = _load_pickle("vocab.pkl")

# Tokenizer and encoder backbone come from the same pretrained KoBERT release.
tokenizer = KoBERTTokenizer.from_pretrained('skt/kobert-base-v1')

# Classifier sized to the number of categories; trained weights are then
# restored from the local checkpoint.
model = BERTClassifier(
    BertModel.from_pretrained('skt/kobert-base-v1'),
    dr_rate=0.5,
    num_classes=len(category),
)
model.load_state_dict(torch.load("textClassifierModel.pt", map_location=device))
# nn.Module.to() returns the module itself, so eval() can be chained onto it.
model.to(device).eval()

# Preprocessing / batching parameters consumed by predict().
max_len = 64
batch_size = 32

# Prediction function
def predict(predict_sentence):
    """Classify one sentence and return the predicted category label.

    The sentence is wrapped in a one-row dataset so it goes through the same
    ``BERTDataset`` preprocessing as any other input. The ``'0'`` in that row
    only fills the label column; the label value is never used here.

    Args:
        predict_sentence: Raw text to classify.

    Returns:
        The key of ``category`` whose position matches the argmax of the
        model's output for this sentence.
    """
    # NOTE(review): the positional args presumably mean sent_idx=0,
    # label_idx=1, pad=True, pair=False -- confirm against dataset.py.
    single_row = [[predict_sentence, '0']]
    dataset = BERTDataset(single_row, 0, 1, tokenizer, vocab, max_len, True, False)
    loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, num_workers=0)
    labels = list(category.keys())

    model.eval()
    # Inference only: without no_grad every request would build an autograd
    # graph for the whole forward pass, wasting memory and time.
    with torch.no_grad():
        # One sentence -> exactly one batch, so take it directly rather than
        # looping and returning from inside the first iteration.
        token_ids, valid_length, segment_ids, _label = next(iter(loader))
        token_ids = token_ids.long().to(device)
        segment_ids = segment_ids.long().to(device)
        out = model(token_ids, valid_length, segment_ids)

    # Row 0 is the only sentence in the batch.
    logits = out[0].detach().cpu().numpy()
    return labels[np.argmax(logits)]

# Endpoint definitions
# JSON request body accepted by POST /predict.
class InputText(BaseModel):
    # Raw sentence to classify; passed unchanged to predict().
    text: str

@app.get("/")
def root():
    """Return a fixed message identifying this service."""
    banner = "Text Classification API (KoBERT)"
    return {"message": banner}

@app.post("/predict")
async def predict_route(item: InputText):
    """Classify ``item.text`` and echo it back with the predicted label.

    ``predict`` is synchronous, CPU-bound model inference. Calling it directly
    inside an ``async def`` endpoint blocks the event loop, stalling every
    other request (including ``GET /``) until it finishes. It is therefore
    offloaded to FastAPI's worker threadpool.

    NOTE(review): concurrent requests may now run predict() in parallel
    threads; confirm the tokenizer/model path is safe under that.
    """
    # Local import keeps this fix self-contained; fastapi is already a dependency.
    from fastapi.concurrency import run_in_threadpool

    result = await run_in_threadpool(predict, item.text)
    return {"text": item.text, "classification": result}