hiddenFront's picture
Create app.py
7f17fe7 verified
raw
history blame
3.46 kB
from fastapi import FastAPI, Request
from transformers import AutoTokenizer, BertForSequenceClassification, BertConfig
from huggingface_hub import hf_hub_download
import torch
import numpy as np
import pickle
import sys
import collections
import os # os λͺ¨λ“ˆ μž„ν¬νŠΈ
import psutil # λ©”λͺ¨λ¦¬ μ‚¬μš©λŸ‰ 확인을 μœ„ν•΄ psutil μž„ν¬νŠΈ (requirements.txt에 μΆ”κ°€ ν•„μš”)
app = FastAPI()
# Inference is pinned to CPU (typical free-tier Space); map_location below matches.
device = torch.device("cpu")

# Load the label mapping produced at training time. Its keys double as the
# human-readable class names returned by /predict, and its length fixes the
# classifier head size (num_labels).
try:
    with open("category.pkl", "rb") as f:
        category = pickle.load(f)
    print("category.pkl λ‘œλ“œ 성곡.")
except FileNotFoundError:
    print("Error: category.pkl νŒŒμΌμ„ 찾을 수 μ—†μŠ΅λ‹ˆλ‹€. ν”„λ‘œμ νŠΈ λ£¨νŠΈμ— μžˆλŠ”μ§€ ν™•μΈν•˜μ„Έμš”.")
    sys.exit(1)

# Tokenizer for the Korean BERT checkpoint the classifier was fine-tuned from.
tokenizer = AutoTokenizer.from_pretrained("skt/kobert-base-v1")
print("ν† ν¬λ‚˜μ΄μ € λ‘œλ“œ 성곡.")

HF_MODEL_REPO_ID = "hiddenFront/TextClassifier"
HF_MODEL_FILENAME = "textClassifierModel.pt"

# --- memory usage logging (RSS in MB, via psutil) ---
process = psutil.Process(os.getpid())
mem_before_model_download = process.memory_info().rss / (1024 * 1024)  # MB
print(f"λͺ¨λΈ λ‹€μš΄λ‘œλ“œ μ „ λ©”λͺ¨λ¦¬ μ‚¬μš©λŸ‰: {mem_before_model_download:.2f} MB")

try:
    # hf_hub_download caches the file locally and returns its path.
    model_path = hf_hub_download(repo_id=HF_MODEL_REPO_ID, filename=HF_MODEL_FILENAME)
    print(f"λͺ¨λΈ 파일이 '{model_path}'에 μ„±κ³΅μ μœΌλ‘œ λ‹€μš΄λ‘œλ“œλ˜μ—ˆμŠ΅λ‹ˆλ‹€.")

    mem_after_model_download = process.memory_info().rss / (1024 * 1024)  # MB
    print(f"λͺ¨λΈ λ‹€μš΄λ‘œλ“œ ν›„ λ©”λͺ¨λ¦¬ μ‚¬μš©λŸ‰: {mem_after_model_download:.2f} MB")

    # 1. Build the architecture only (random init); weights come from the
    #    downloaded checkpoint, so we skip downloading base-model weights.
    config = BertConfig.from_pretrained("skt/kobert-base-v1", num_labels=len(category))
    model = BertForSequenceClassification(config)

    # 2. Load the checkpoint. weights_only=True restricts unpickling to
    #    tensors/containers, preventing arbitrary-code execution from a
    #    tampered .pt file; a plain state_dict loads fine under it.
    #    (Requires torch >= 1.13 — TODO confirm the Space's torch version.)
    loaded_state_dict = torch.load(model_path, map_location=device, weights_only=True)

    # 3. Strip a possible 'module.' prefix (left by nn.DataParallel during
    #    training) so keys line up with the bare model's state_dict.
    new_state_dict = collections.OrderedDict()
    for key, value in loaded_state_dict.items():
        if key.startswith('module.'):
            key = key[len('module.'):]
        new_state_dict[key] = value
    model.load_state_dict(new_state_dict)

    mem_after_model_load = process.memory_info().rss / (1024 * 1024)  # MB
    print(f"λͺ¨λΈ λ‘œλ“œ 및 state_dict 적용 ν›„ λ©”λͺ¨λ¦¬ μ‚¬μš©λŸ‰: {mem_after_model_load:.2f} MB")

    # Inference mode: disables dropout etc. for deterministic predictions.
    model.eval()
    print("λͺ¨λΈ λ‘œλ“œ 성곡.")
except Exception as e:
    # Boundary catch: any download/load failure is fatal for the service.
    print(f"Error: λͺ¨λΈ λ‹€μš΄λ‘œλ“œ λ˜λŠ” λ‘œλ“œ 쀑 였λ₯˜ λ°œμƒ: {e}")
    sys.exit(1)
@app.post("/predict")
async def predict_api(request: Request):
    """Classify one text.

    Expects a JSON object body ``{"text": "..."}`` and returns
    ``{"text": ..., "classification": <label>}`` on success, or
    ``{"error": ..., "classification": "null"}`` on bad input.
    """
    # Robustness: a malformed/non-JSON body would otherwise raise inside
    # request.json() and surface as an unhandled 500 instead of the
    # handler's own error payload convention.
    try:
        data = await request.json()
    except Exception:
        return {"error": "No text provided", "classification": "null"}
    # Guard against a valid-JSON but non-object body (e.g. a bare string),
    # which has no .get().
    text = data.get("text") if isinstance(data, dict) else None
    if not text:
        return {"error": "No text provided", "classification": "null"}
    # Same preprocessing as training is assumed: 64-token window, padded.
    encoded = tokenizer.encode_plus(
        text, max_length=64, padding='max_length', truncation=True, return_tensors='pt'
    )
    with torch.no_grad():
        outputs = model(**encoded)
        probs = torch.nn.functional.softmax(outputs.logits, dim=1)
        predicted = torch.argmax(probs, dim=1).item()
    # NOTE(review): assumes category's insertion order matches the label
    # indices used at training time — verify against the training code.
    label = list(category.keys())[predicted]
    return {"text": text, "classification": label}