from fastapi import FastAPI
from pydantic import BaseModel
import torch
import pickle
import numpy as np
from kobert_tokenizer import KoBERTTokenizer
from model import BERTClassifier
from dataset import BERTDataset
from transformers import BertModel

app = FastAPI()
device = torch.device("cpu")
# Load the category mapping (its keys are the class label names)
with open("category.pkl", "rb") as f:
    category = pickle.load(f)

# Load the vocabulary used at training time
with open("vocab.pkl", "rb") as f:
    vocab = pickle.load(f)

# Tokenizer
tokenizer = KoBERTTokenizer.from_pretrained('skt/kobert-base-v1')
# Load the model
model = BERTClassifier(
    BertModel.from_pretrained('skt/kobert-base-v1'),
    dr_rate=0.5,
    num_classes=len(category)
)
model.load_state_dict(torch.load("textClassifierModel.pt", map_location=device))
model.to(device)
model.eval()

# Parameters needed to build the dataset
max_len = 64
batch_size = 32
# Prediction function
def predict(predict_sentence):
    data = [predict_sentence, '0']  # dummy label '0'; BERTDataset expects (text, label) pairs
    dataset_another = [data]

    another_test = BERTDataset(dataset_another, 0, 1, tokenizer, vocab, max_len, True, False)
    test_dataLoader = torch.utils.data.DataLoader(another_test, batch_size=batch_size, num_workers=0)

    model.eval()
    with torch.no_grad():  # inference only; no gradients needed
        for batch_id, (token_ids, valid_length, segment_ids, label) in enumerate(test_dataLoader):
            token_ids = token_ids.long().to(device)
            segment_ids = segment_ids.long().to(device)

            out = model(token_ids, valid_length, segment_ids)

            test_eval = []
            for i in out:
                logits = i.detach().cpu().numpy()
                test_eval.append(list(category.keys())[np.argmax(logits)])

            # A single sentence yields a single batch; return its first prediction
            return test_eval[0]
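
# A quick sanity check one could run once the model and pickle files are in
# place (the sentence below is only an illustrative placeholder; the result
# is one of the keys stored in category.pkl):
#
#     print(predict("오늘 날씨가 정말 좋네요"))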
# Endpoint definitions
class InputText(BaseModel):
    text: str

@app.get("/")  # route path assumed; the original file omitted the decorators
def root():
    return {"message": "Text Classification API (KoBERT)"}

@app.post("/predict")  # route path assumed
async def predict_route(item: InputText):
    result = predict(item.text)
    return {"text": item.text, "classification": result}