from fastapi import FastAPI, Request
from transformers import BertForSequenceClassification, AutoTokenizer
from huggingface_hub import hf_hub_download
import torch
import pickle
import os
import sys
import psutil

app = FastAPI()
device = torch.device("cpu")

# Load category.pkl
try:
    with open("category.pkl", "rb") as f:
        category = pickle.load(f)
    print("✅ Loaded category.pkl.")
except FileNotFoundError:
    print("❌ Error: category.pkl not found.")
    sys.exit(1)
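
# Note (assumption, not stated in the original): `category` is expected to be
# a dict mapping label names to class indices, e.g. {"politics": 0, "economy": 1},
# saved at training time; the /predict handler below recovers the predicted
# label from its key order.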

# Load the tokenizer
tokenizer = AutoTokenizer.from_pretrained("skt/kobert-base-v1")
print("✅ Tokenizer loaded.")

# Rebuild the model architecture with a classification head
num_labels = len(category)  # one output per class
model = BertForSequenceClassification.from_pretrained("skt/kobert-base-v1", num_labels=num_labels)
model.to(device)
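
# The classification head created by from_pretrained above is randomly
# initialized at this point; the fine-tuned head weights are loaded over it
# below via load_state_dict.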

HF_MODEL_REPO_ID = "hiddenFront/TextClassifier"
HF_MODEL_FILENAME = "textClassifierModel.pt"

# Measure memory before downloading the model
process = psutil.Process(os.getpid())
mem_before = process.memory_info().rss / (1024 * 1024)
print(f"📦 Memory usage before model download: {mem_before:.2f} MB")

# Download the fine-tuned model weights
try:
    model_path = hf_hub_download(repo_id=HF_MODEL_REPO_ID, filename=HF_MODEL_FILENAME)
    print(f"✅ Model file downloaded: {model_path}")

    mem_after_dl = process.memory_info().rss / (1024 * 1024)
    print(f"📦 Memory usage after model download: {mem_after_dl:.2f} MB")

    # Load the state_dict into the model
    state_dict = torch.load(model_path, map_location=device)
    model.load_state_dict(state_dict)
    model.eval()

    mem_after_load = process.memory_info().rss / (1024 * 1024)
    print(f"📦 Memory usage after model load: {mem_after_load:.2f} MB")
    print("✅ Model loaded and ready.")
except Exception as e:
    print(f"❌ Error while loading the model: {e}")
    sys.exit(1)
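
# Note (assumption about usage, not from the original): torch.load unpickles
# arbitrary Python objects by default; for checkpoints from untrusted sources,
# recent PyTorch versions offer torch.load(..., weights_only=True) as a safer
# alternative.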

# Prediction API endpoint
@app.post("/predict")
async def predict_api(request: Request):
    data = await request.json()
    text = data.get("text")
    if not text:
        return {"error": "No text provided", "classification": "null"}

    # Tokenize to a fixed length of 64 tokens (longer inputs are truncated)
    encoded = tokenizer.encode_plus(
        text, max_length=64, padding='max_length', truncation=True, return_tensors='pt'
    )

    with torch.no_grad():
        outputs = model(**encoded)
        probs = torch.nn.functional.softmax(outputs.logits, dim=1)
        predicted = torch.argmax(probs, dim=1).item()

    # Map the predicted index back to a label name (assumes the dict's key
    # order matches the index order used at training time)
    label = list(category.keys())[predicted]
    return {"text": text, "classification": label}