Spaces:
Sleeping
Sleeping
File size: 2,041 Bytes
7f17fe7 95b43d8 7f17fe7 95b43d8 e66afc2 95b43d8 7f17fe7 95b43d8 7f17fe7 95b43d8 7f17fe7 95b43d8 6ba018e 95b43d8 6ba018e 95b43d8 7f17fe7 95b43d8 7f17fe7 95b43d8 7f17fe7 95b43d8 4607c9c 689eabe 95b43d8 689eabe 7f17fe7 95b43d8 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 |
from fastapi import FastAPI, Request
from pydantic import BaseModel
import torch
import pickle
import gluonnlp as nlp
import numpy as np
import os
from kobert_tokenizer import KoBERTTokenizer
from model import BERTClassifier
from dataset import BERTDataset
from transformers import BertModel
import logging
app = FastAPI()

# All inference in this service runs on the CPU.
device = torch.device("cpu")

# Load the category mapping; predict() returns one of its keys.
# NOTE(security): pickle.load can execute arbitrary code embedded in the file,
# so category.pkl and vocab.pkl must only ever come from a trusted source.
with open("category.pkl", "rb") as f:
    category = pickle.load(f)

# Load the vocabulary handed to BERTDataset.
with open("vocab.pkl", "rb") as f:
    vocab = pickle.load(f)

# Tokenizer
tokenizer = KoBERTTokenizer.from_pretrained('skt/kobert-base-v1')

# Load the model: a pretrained KoBERT encoder wrapped by the project's
# classifier, sized to one output class per category entry.
model = BERTClassifier(
    BertModel.from_pretrained('skt/kobert-base-v1'),
    dr_rate=0.5,
    num_classes=len(category),
)
# map_location remaps the saved tensors onto `device` while loading.
model.load_state_dict(torch.load("textClassifierModel.pt", map_location=device))
model.to(device)
model.eval()

# Parameters needed to build the dataset
max_len = 64
batch_size = 32
# Prediction function
def predict(predict_sentence):
    """Classify a single sentence and return its predicted category key.

    Args:
        predict_sentence: The raw text to classify.

    Returns:
        The key of the module-level ``category`` mapping whose position matches
        the index of the highest-scoring model output for the sentence.
    """
    # The dataset is fed one [text, label] row; '0' is a dummy label that is
    # never used for the prediction itself.
    # NOTE(review): the positional 0 / 1 arguments presumably select the
    # sentence and label columns of each row -- confirm against BERTDataset.
    rows = [[predict_sentence, '0']]
    dataset = BERTDataset(rows, 0, 1, tokenizer, vocab, max_len, True, False)
    loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, num_workers=0)

    # Build the index -> label lookup once instead of once per output row.
    labels = list(category.keys())
    predictions = []

    model.eval()
    # Inference only: skip autograd bookkeeping for the forward pass.
    with torch.no_grad():
        for token_ids, valid_length, segment_ids, _label in loader:
            token_ids = token_ids.long().to(device)
            segment_ids = segment_ids.long().to(device)
            out = model(token_ids, valid_length, segment_ids)
            scores = out.detach().cpu().numpy()
            predictions.extend(labels[np.argmax(row)] for row in scores)

    # Exactly one sentence was submitted, so the first prediction is the answer.
    return predictions[0]
# Endpoint definitions
class InputText(BaseModel):
    """Request body carrying the text to classify."""

    text: str
@app.get("/")
def root():
    """Return a static message identifying this API."""
    banner = {"message": "Text Classification API (KoBERT)"}
    return banner
@app.post("/predict")
async def predict_route(item: InputText):
    """Classify the submitted text and echo it back with its predicted label.

    Args:
        item: Request body whose ``text`` field is the sentence to classify.

    Returns:
        A dict with the original ``text`` and the predicted ``classification``.
    """
    # Local import keeps this block self-contained; fastapi is already a
    # dependency of this module.
    from fastapi.concurrency import run_in_threadpool

    # predict() is synchronous and CPU-bound; calling it directly inside an
    # async handler would block the event loop for every other request, so it
    # is run on a worker thread instead.
    result = await run_in_threadpool(predict, item.text)
    return {"text": item.text, "classification": result}
|