# app.py — FastAPI inference service for a KoBERT text classifier.
# NOTE(review): removed Hugging Face web-page residue (avatar/commit/"raw
# history blame" lines) that made this file unparseable as Python.
from fastapi import FastAPI, Request
from transformers import BertModel, BertForSequenceClassification, AutoTokenizer
from huggingface_hub import hf_hub_download
import torch
import pickle
import os
import sys
import psutil
app = FastAPI()
device = torch.device("cpu")

# Load the label mapping produced at training time; the service cannot
# serve predictions without it, so exit hard if the file is absent.
try:
    with open("category.pkl", "rb") as fh:
        category = pickle.load(fh)
    print("βœ… category.pkl λ‘œλ“œ 성곡.")
except FileNotFoundError:
    print("❌ Error: category.pkl νŒŒμΌμ„ 찾을 수 μ—†μŠ΅λ‹ˆλ‹€.")
    sys.exit(1)

# KoBERT tokenizer (fetched from the Hub on first run, cached afterwards).
tokenizer = AutoTokenizer.from_pretrained("skt/kobert-base-v1")
print("βœ… ν† ν¬λ‚˜μ΄μ € λ‘œλ“œ 성곡.")
class CustomClassifier(torch.nn.Module):
    """KoBERT encoder topped with a single linear classification head.

    The layout must mirror the architecture used at training time so that
    checkpoints saved against it can be restored one-to-one.
    """

    def __init__(self):
        super().__init__()
        # Restore the exact structure that was defined for training.
        self.bert = BertModel.from_pretrained("skt/kobert-base-v1")
        self.classifier = torch.nn.Linear(768, len(category))

    def forward(self, input_ids, attention_mask=None, token_type_ids=None):
        encoder_out = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
        )
        # Index 1 of the encoder output is the pooled [CLS] representation.
        cls_repr = encoder_out[1]
        logits = self.classifier(cls_repr)
        return logits
# Location of the fine-tuned checkpoint on the Hugging Face Hub.
HF_MODEL_REPO_ID = "hiddenFront/TextClassifier"
HF_MODEL_FILENAME = "textClassifierModel.pt"

# Snapshot resident memory before pulling the model in, so the loading
# cost can be read off the logs.
process = psutil.Process(os.getpid())
mem_before = process.memory_info().rss / 1024 ** 2
print(f"πŸ“¦ λͺ¨λΈ λ‹€μš΄λ‘œλ“œ μ „ λ©”λͺ¨λ¦¬ μ‚¬μš©λŸ‰: {mem_before:.2f} MB")
# λͺ¨λΈ λ‘œλ“œ
try:
model_path = hf_hub_download(repo_id=HF_MODEL_REPO_ID, filename=HF_MODEL_FILENAME)
print(f"βœ… λͺ¨λΈ 파일 λ‹€μš΄λ‘œλ“œ 성곡: {model_path}")
state_dict = torch.load(model_path, map_location=device)
model = BertForSequenceClassification.from_pretrained(
"skt/kobert-base-v1",
num_labels=len(category),
state_dict=state_dict,
)
model.to(device)
model.eval()
print("βœ… λͺ¨λΈ λ‘œλ“œ 및 μ€€λΉ„ μ™„λ£Œ.")
except Exception as e:
print(f"❌ Error: λͺ¨λΈ λ‘œλ“œ 쀑 였λ₯˜ λ°œμƒ: {e}")
sys.exit(1)
# Prediction endpoint: accepts {"text": "..."} and returns the label.
@app.post("/predict")
async def predict_api(request: Request):
    """Classify the `text` field of the JSON body and return its label.

    Returns a structured error object (instead of raising, which would
    surface as an HTTP 500) when the body is not valid JSON, is not an
    object, or carries no usable text.
    """
    try:
        data = await request.json()
    except ValueError:  # json.JSONDecodeError is a ValueError subclass
        return {"error": "Invalid JSON body", "classification": "null"}
    text = data.get("text") if isinstance(data, dict) else None
    if not text or not isinstance(text, str):
        return {"error": "No text provided", "classification": "null"}
    # Tokenize to the fixed length used at training time (64 tokens).
    encoded = tokenizer.encode_plus(
        text, max_length=64, padding='max_length', truncation=True, return_tensors='pt'
    )
    with torch.no_grad():
        outputs = model(**encoded)
        probs = torch.nn.functional.softmax(outputs.logits, dim=1)
        predicted = torch.argmax(probs, dim=1).item()
    # Map the argmax index back to a label name; assumes `category`
    # preserves the training-time label order — TODO confirm against
    # the training code.
    label = list(category.keys())[predicted]
    return {"text": text, "classification": label}