hiddenFront's picture
Update app.py
95b43d8 verified
raw
history blame
2.04 kB
from fastapi import FastAPI, Request
from pydantic import BaseModel
import torch
import pickle
import gluonnlp as nlp
import numpy as np
import os
from kobert_tokenizer import KoBERTTokenizer
from model import BERTClassifier
from dataset import BERTDataset
from transformers import BertModel
import logging
# Application object and the device used for inference (CPU only).
app = FastAPI()
device = torch.device("cpu")

# Load the category mapping. Its length sets the number of output classes and
# its keys are the labels returned by predict().
# NOTE(review): pickle.load runs arbitrary code from the file -- presumably
# these artifacts are shipped with the app and trusted; confirm.
with open("category.pkl", "rb") as f:
    category = pickle.load(f)

# Load the vocabulary handed to BERTDataset.
with open("vocab.pkl", "rb") as f:
    vocab = pickle.load(f)

# Tokenizer for the pretrained KoBERT checkpoint.
tokenizer = KoBERTTokenizer.from_pretrained('skt/kobert-base-v1')

# Build the classifier around the pretrained backbone, then restore the
# fine-tuned weights from disk and switch to evaluation mode.
_bert_backbone = BertModel.from_pretrained('skt/kobert-base-v1')
model = BERTClassifier(
    _bert_backbone,
    dr_rate=0.5,
    num_classes=len(category),
)
model.load_state_dict(
    torch.load("textClassifierModel.pt", map_location=device)
)
model.to(device)
model.eval()

# Parameters needed to build the dataset / data loader.
max_len = 64
batch_size = 32
# Prediction function
def predict(predict_sentence):
    """Classify a single sentence with the loaded classifier.

    Args:
        predict_sentence: Raw input text to classify.

    Returns:
        The key of ``category`` whose position equals the index of the
        highest-scoring model output for the sentence.
    """
    # The dataset is built from [sentence, label] rows; '0' is a dummy label
    # (the label yielded by the loader is not used below).
    rows = [[predict_sentence, '0']]
    dataset = BERTDataset(rows, 0, 1, tokenizer, vocab, max_len, True, False)
    loader = torch.utils.data.DataLoader(
        dataset, batch_size=batch_size, num_workers=0
    )

    # Output index i maps to the i-th key of ``category``; build the lookup
    # once rather than once per output row.
    labels = list(category.keys())

    predictions = []
    model.eval()
    # Inference only: disable autograd so no graph is recorded for the
    # forward pass (less memory, less work, same outputs).
    with torch.no_grad():
        for token_ids, valid_length, segment_ids, _label in loader:
            token_ids = token_ids.long().to(device)
            segment_ids = segment_ids.long().to(device)
            out = model(token_ids, valid_length, segment_ids)
            for row_logits in out:
                scores = row_logits.detach().cpu().numpy()
                predictions.append(labels[np.argmax(scores)])

    # Exactly one row was submitted, so the first prediction is the answer.
    return predictions[0]
# Endpoint definitions
class InputText(BaseModel):
    # Request body schema for the POST /predict endpoint.
    # (Kept as comments on purpose: a class docstring would be published as
    # the schema description.)

    # text: the raw string to classify; required.
    text: str
@app.get("/")
def root():
    # Landing route: returns a fixed message identifying the service.
    # (Comment rather than docstring so the generated API docs are unchanged.)
    banner = "Text Classification API (KoBERT)"
    return {"message": banner}
@app.post("/predict")
async def predict_route(item: InputText):
    """Classify the submitted text and return it with its predicted category."""
    # Function-scope import keeps this change self-contained; fastapi is
    # already a dependency of this module.
    from fastapi.concurrency import run_in_threadpool

    # predict() is synchronous and CPU-bound. Calling it directly inside an
    # async handler blocks the event loop -- and therefore every other
    # in-flight request -- for the whole forward pass, so hand it to a
    # worker thread and await the result instead.
    result = await run_in_threadpool(predict, item.text)
    return {"text": item.text, "classification": result}