from fastapi import FastAPI, Request
from transformers import AutoTokenizer
from huggingface_hub import hf_hub_download
import torch
import pickle
import os
import psutil
import sys

app = FastAPI()
device = torch.device("cpu")

# Load category.pkl
try:
    with open("category.pkl", "rb") as f:
        category = pickle.load(f)
    print("βœ… Loaded category.pkl.")
except FileNotFoundError:
    print("❌ Error: category.pkl not found. Make sure it is in the project root.")
    sys.exit(1)
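
# NOTE (assumption): `category` is taken to be a dict mapping label name ->
# class index, e.g. {"economy": 0, "sports": 1}, since /predict below resolves
# the predicted index via list(category.keys())[predicted]; that lookup is only
# correct if the dict's insertion order matches the model's output indices.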

# ν† ν¬λ‚˜μ΄μ € λ‘œλ“œ
tokenizer = AutoTokenizer.from_pretrained("skt/kobert-base-v1")
print("βœ… ν† ν¬λ‚˜μ΄μ € λ‘œλ“œ 성곡.")

HF_MODEL_REPO_ID = "hiddenFront/TextClassifier"
HF_MODEL_FILENAME = "textClassifierModel.pt"

# Memory check
process = psutil.Process(os.getpid())
mem_before = process.memory_info().rss / (1024 * 1024)
print(f"πŸ“¦ Memory usage before model download: {mem_before:.2f} MB")

# λͺ¨λΈ λ‹€μš΄λ‘œλ“œ 및 λ‘œλ“œ
try:
    model_path = hf_hub_download(repo_id=HF_MODEL_REPO_ID, filename=HF_MODEL_FILENAME)
    print(f"βœ… λͺ¨λΈ 파일 λ‹€μš΄λ‘œλ“œ 성곡: {model_path}")

    mem_after_dl = process.memory_info().rss / (1024 * 1024)
    print(f"πŸ“¦ λͺ¨λΈ λ‹€μš΄λ‘œλ“œ ν›„ λ©”λͺ¨λ¦¬ μ‚¬μš©λŸ‰: {mem_after_dl:.2f} MB")

    model = torch.load(model_path, map_location=device)  # 전체 λͺ¨λΈ 객체 λ‘œλ“œ
    model.eval()

    mem_after_load = process.memory_info().rss / (1024 * 1024)
    print(f"πŸ“¦ λͺ¨λΈ λ‘œλ“œ ν›„ λ©”λͺ¨λ¦¬ μ‚¬μš©λŸ‰: {mem_after_load:.2f} MB")
    print("βœ… λͺ¨λΈ λ‘œλ“œ 성곡")
except Exception as e:
    print(f"❌ Error: λͺ¨λΈ λ‹€μš΄λ‘œλ“œ λ˜λŠ” λ‘œλ“œ 쀑 였λ₯˜ λ°œμƒ: {e}")
    sys.exit(1)
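
# NOTE (assumption): torch.load() on a whole pickled model only works if the
# model's class is importable here under the same module path used when it was
# saved. A more portable sketch, assuming the classifier head was built on
# skt/kobert-base-v1 and only its state_dict was saved:
#   from transformers import BertForSequenceClassification
#   model = BertForSequenceClassification.from_pretrained(
#       "skt/kobert-base-v1", num_labels=len(category))
#   model.load_state_dict(torch.load(model_path, map_location=device))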

# Prediction API
@app.post("/predict")
async def predict_api(request: Request):
    data = await request.json()
    text = data.get("text")
    if not text:
        return {"error": "No text provided", "classification": "null"}

    # Tokenize via the tokenizer's __call__ API (equivalent to encode_plus);
    # returns input_ids, attention_mask, and token_type_ids as tensors
    encoded = tokenizer(
        text, max_length=64, padding='max_length', truncation=True, return_tensors='pt'
    )

    with torch.no_grad():
        outputs = model(**encoded)
        probs = torch.nn.functional.softmax(outputs.logits, dim=1)
        predicted = torch.argmax(probs, dim=1).item()

    label = list(category.keys())[predicted]
    return {"text": text, "classification": label}