File size: 2,801 Bytes
7f17fe7
9dd37b1
7f17fe7
 
 
e66afc2
7f17fe7
ec61894
7f17fe7
 
 
 
 
 
 
 
e66afc2
7f17fe7
ec61894
7f17fe7
 
 
 
e66afc2
7f17fe7
6ba018e
 
 
 
 
 
 
 
 
 
 
 
 
 
7f17fe7
 
 
ec61894
7f17fe7
e66afc2
 
7f17fe7
0914de7
7f17fe7
 
e66afc2
7f17fe7
ec61894
0914de7
 
 
 
 
 
7f17fe7
ec61894
7f17fe7
ec61894
7f17fe7
 
4607c9c
e66afc2
7f17fe7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e66afc2
7f17fe7
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
from fastapi import FastAPI, Request
from transformers import BertModel, BertForSequenceClassification, AutoTokenizer
from huggingface_hub import hf_hub_download
import torch
import pickle
import os
import sys
import psutil

app = FastAPI()
device = torch.device("cpu")  # inference runs on the CPU

# Load the label mapping from category.pkl (path is relative to the current
# working directory). Its length sizes the classification head, and
# predict_api turns a predicted index back into a label through its keys.
# NOTE(security): pickle.load can execute arbitrary code embedded in the file;
# only deploy a category.pkl that comes from a trusted source.
try:
    with open("category.pkl", "rb") as f:
        category = pickle.load(f)
    print("✅ category.pkl 로드 성공.")
except FileNotFoundError:
    # Nothing below can work without the label mapping, so abort start-up.
    print("❌ Error: category.pkl 파일을 찾을 수 없습니다.")
    sys.exit(1)

# Load the tokenizer that matches the KoBERT checkpoint.
tokenizer = AutoTokenizer.from_pretrained("skt/kobert-base-v1")
print("✅ 토크나이저 로드 성공.")

class CustomClassifier(torch.nn.Module):
    """KoBERT encoder followed by a single linear classification layer.

    The attribute names ``bert`` and ``classifier`` become the prefixes of the
    parameter names in ``state_dict()``, so they must stay unchanged for
    previously saved weights to load into this module.
    """

    def __init__(self):
        """Build the encoder and a linear head with ``len(category)`` outputs."""
        super().__init__()
        # The layout must mirror the originally defined model exactly.
        self.bert = BertModel.from_pretrained("skt/kobert-base-v1")
        # Take the head's input width from the encoder config rather than
        # hard-coding 768, so the two cannot drift apart.
        hidden_size = self.bert.config.hidden_size
        self.classifier = torch.nn.Linear(hidden_size, len(category))

    def forward(self, input_ids, attention_mask=None, token_type_ids=None):
        """Return unnormalised class scores for the encoded input.

        Args:
            input_ids: Token ids, forwarded to the BERT encoder.
            attention_mask: Optional attention mask, forwarded to the encoder.
            token_type_ids: Optional segment ids, forwarded to the encoder.

        Returns:
            The linear head applied to the encoder's pooled output; the last
            dimension has one score per entry of ``category``.
        """
        outputs = self.bert(input_ids=input_ids,
                            attention_mask=attention_mask,
                            token_type_ids=token_type_ids)
        # Element 1 of the BertModel output is the pooled representation
        # derived from the [CLS] token.
        pooled_output = outputs[1]
        return self.classifier(pooled_output)

# Hugging Face Hub location of the fine-tuned classifier weights.
HF_MODEL_REPO_ID = "hiddenFront/TextClassifier"
HF_MODEL_FILENAME = "textClassifierModel.pt"

# Record the resident memory of this process before the model is loaded.
process = psutil.Process(os.getpid())
mem_before = process.memory_info().rss / (1024 * 1024)  # bytes -> MiB
print(f"📦 모델 다운로드 전 메모리 사용량: {mem_before:.2f} MB")

# Download the weights and build the model.
try:
    model_path = hf_hub_download(repo_id=HF_MODEL_REPO_ID, filename=HF_MODEL_FILENAME)
    print(f"✅ 모델 파일 다운로드 성공: {model_path}")

    # NOTE(security): torch.load is pickle-based and may execute code stored
    # in the downloaded file. Presumably the file holds a plain state_dict
    # (it is passed as ``state_dict=`` below) — TODO confirm, and if so
    # consider ``weights_only=True``.
    state_dict = torch.load(model_path, map_location=device)
    model = BertForSequenceClassification.from_pretrained(
        "skt/kobert-base-v1",
        num_labels=len(category),
        state_dict=state_dict,
    )
    model.to(device)
    model.eval()  # inference mode: disables dropout
    print("✅ 모델 로드 및 준비 완료.")
except Exception as e:
    # Start-up boundary: any failure here leaves the service unusable, so
    # report it and exit instead of serving requests without a model.
    print(f"❌ Error: 모델 로드 중 오류 발생: {e}")
    sys.exit(1)


# Prediction API
@app.post("/predict")
async def predict_api(request: Request):
    """Classify the ``text`` field of a JSON request body.

    Args:
        request: Incoming request whose body should be a JSON object with a
            non-empty string under the key ``"text"``.

    Returns:
        ``{"text": <input>, "classification": <label>}`` on success, where the
        label is the key of ``category`` at the predicted index. If the body
        is not valid JSON, is not a JSON object, or carries no non-empty
        string ``text``, returns
        ``{"error": "No text provided", "classification": "null"}``.
    """
    no_text = {"error": "No text provided", "classification": "null"}

    try:
        data = await request.json()
    except ValueError:
        # Malformed JSON (JSONDecodeError) or a body that is not valid UTF-8
        # (UnicodeDecodeError); both are ValueError subclasses.
        return no_text

    # A JSON array/scalar has no .get(); treat it like a missing field.
    text = data.get("text") if isinstance(data, dict) else None
    if not isinstance(text, str) or not text:
        return no_text

    # Pad/truncate to a fixed 64 tokens and return PyTorch tensors.
    encoded = tokenizer.encode_plus(
        text, max_length=64, padding='max_length', truncation=True, return_tensors='pt'
    )

    with torch.no_grad():
        outputs = model(**encoded)
        # Softmax is monotonic, so the argmax of the logits is the predicted
        # class; no need to normalise first.
        predicted = torch.argmax(outputs.logits, dim=1).item()

    # NOTE(review): assumes the key order of `category` matches the label
    # indices used in training — confirm against the training code.
    label = list(category.keys())[predicted]
    return {"text": text, "classification": label}