File size: 3,099 Bytes
7f17fe7
9dd37b1
7f17fe7
 
 
e66afc2
7f17fe7
ec61894
7f17fe7
 
 
 
 
 
 
 
e66afc2
7f17fe7
ec61894
7f17fe7
 
 
 
e66afc2
7f17fe7
6ba018e
 
 
 
 
 
 
 
 
 
 
 
 
 
7f17fe7
 
 
ec61894
7f17fe7
e66afc2
 
7f17fe7
0914de7
7f17fe7
 
e66afc2
7f17fe7
ec61894
0914de7
 
 
 
 
 
7f17fe7
ec61894
7f17fe7
ec61894
7f17fe7
 
4607c9c
689eabe
 
 
 
 
 
 
 
 
 
e66afc2
7f17fe7
 
 
 
3dd80ec
7f17fe7
 
 
 
 
 
 
 
 
 
 
e66afc2
7f17fe7
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
from fastapi import FastAPI, Request
from transformers import BertModel, BertForSequenceClassification, AutoTokenizer
from huggingface_hub import hf_hub_download
import torch
import pickle
import os
import sys
import psutil

# FastAPI application object; the route decorators below register on it.
app = FastAPI()
# The model and all tensors are placed on the CPU.
device = torch.device("cpu")

# Load category.pkl: its length sizes the classifier head and its keys are
# used to turn a predicted index back into a label name.
# NOTE(review): presumably a dict whose key order matches the label indices
# used at training time — confirm against the training script.
# NOTE(review): pickle.load can execute arbitrary code; acceptable only
# because category.pkl is a local, trusted artifact (path is relative to the
# current working directory).
try:
    with open("category.pkl", "rb") as f:
        category = pickle.load(f)
    print("โœ… category.pkl ๋กœ๋“œ ์„ฑ๊ณต.")
except FileNotFoundError:
    # Without the label set the service cannot build or interpret the model,
    # so abort startup.
    print("โŒ Error: category.pkl ํŒŒ์ผ์„ ์ฐพ์„ ์ˆ˜ ์—†์Šต๋‹ˆ๋‹ค.")
    sys.exit(1)

# Load the KoBERT ("skt/kobert-base-v1") tokenizer via AutoTokenizer.
tokenizer = AutoTokenizer.from_pretrained("skt/kobert-base-v1")
print("โœ… ํ† ํฌ๋‚˜์ด์ € ๋กœ๋“œ ์„ฑ๊ณต.")

class CustomClassifier(torch.nn.Module):
    """KoBERT encoder followed by a single linear classification head.

    Reproduces the architecture used at training time: the pooled [CLS]
    vector from ``BertModel`` is projected to one logit per entry in the
    module-level ``category`` mapping.

    NOTE(review): the model-loading code in this file builds a
    ``BertForSequenceClassification`` (which has the same ``bert`` /
    ``classifier`` attribute layout) rather than instantiating this class;
    it appears to be kept as a record of the trained structure.
    """

    def __init__(self):
        super(CustomClassifier, self).__init__()
        # The layer layout must match the saved checkpoint exactly.
        self.bert = BertModel.from_pretrained("skt/kobert-base-v1")
        hidden_size = 768  # width of the pooled output fed to the head
        self.classifier = torch.nn.Linear(
            in_features=hidden_size,
            out_features=len(category),
        )

    def forward(self, input_ids, attention_mask=None, token_type_ids=None):
        """Compute classification logits.

        Args:
            input_ids: token ids passed straight to the BERT encoder.
            attention_mask: optional mask forwarded to the encoder.
            token_type_ids: optional segment ids forwarded to the encoder.

        Returns:
            The linear head applied to the encoder's pooled output.
        """
        encoder_kwargs = {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "token_type_ids": token_type_ids,
        }
        encoder_out = self.bert(**encoder_kwargs)
        # Element 1 of the BertModel output is the pooled [CLS] representation.
        cls_vector = encoder_out[1]
        logits = self.classifier(cls_vector)
        return logits

HF_MODEL_REPO_ID = "hiddenFront/TextClassifier"
HF_MODEL_FILENAME = "textClassifierModel.pt"

# Resident memory before the model is downloaded/loaded, so the cost of
# loading can be reported once it finishes.
process = psutil.Process(os.getpid())
mem_before = process.memory_info().rss / (1024 * 1024)
print(f"๐Ÿ“ฆ ๋ชจ๋ธ ๋‹ค์šด๋กœ๋“œ ์ „ ๋ฉ”๋ชจ๋ฆฌ ์‚ฌ์šฉ๋Ÿ‰: {mem_before:.2f} MB")

# Download the fine-tuned weights and load them into a KoBERT sequence
# classifier sized to the label set.
try:
    model_path = hf_hub_download(repo_id=HF_MODEL_REPO_ID, filename=HF_MODEL_FILENAME)
    print(f"โœ… ๋ชจ๋ธ ํŒŒ์ผ ๋‹ค์šด๋กœ๋“œ ์„ฑ๊ณต: {model_path}")

    # NOTE(review): assumes the .pt file holds a plain state dict whose keys
    # match BertForSequenceClassification ("bert.*", "classifier.*") — confirm.
    state_dict = torch.load(model_path, map_location=device)
    model = BertForSequenceClassification.from_pretrained(
        "skt/kobert-base-v1",
        num_labels=len(category),
        state_dict=state_dict,
    )
    # Drop the module-level reference to the raw weights. If from_pretrained
    # copied them into the model's parameters, this frees that duplicate copy
    # instead of keeping it resident for the life of the process; otherwise
    # it just removes a redundant reference.
    del state_dict
    model.to(device)
    model.eval()  # inference mode: disables dropout
    print("โœ… ๋ชจ๋ธ ๋กœ๋“œ ๋ฐ ์ค€๋น„ ์™„๋ฃŒ.")
except Exception as e:
    # Startup boundary: the service is useless without a model, so exit.
    print(f"โŒ Error: ๋ชจ๋ธ ๋กœ๋“œ ์ค‘ ์˜ค๋ฅ˜ ๋ฐœ์ƒ: {e}")
    sys.exit(1)

# Complete the measurement started above so the load cost is visible.
mem_after = process.memory_info().rss / (1024 * 1024)
print(
    f"Memory usage after model load: {mem_after:.2f} MB "
    f"({mem_after - mem_before:+.2f} MB)"
)


@app.get("/")
def root(request: Request):
    """Liveness endpoint that also echoes the caller's address.

    Returns a fixed status message plus the client host and port as reported
    by the ASGI server. Starlette sets ``request.client`` to ``None`` when the
    ASGI scope carries no client address. In that case host and port are
    returned as ``None`` instead of the endpoint failing with AttributeError.
    """
    client = request.client
    if client is None:
        client_host, client_port = None, None
    else:
        client_host, client_port = client.host, client.port
    return {
        "message": "Text Classification API is running!",
        "client_ip": client_host,
        "client_port": client_port,
    }

# Prediction API
@app.post("/predict")
async def predict_api(request: Request):
    """Classify the ``text`` field of a JSON request body.

    Expects a JSON object such as ``{"text": "..."}``.

    On success, returns ``{"text": <input>, "classification": <label>}``. The
    label is the key of ``category`` at the predicted index.

    On bad input, returns ``{"error": ..., "classification": "null"}`` instead
    of an unhandled 500. Bad input is a body that is not valid JSON, a body
    that is not a JSON object, a missing or empty ``text``, or a ``text`` that
    is not a string.
    """
    try:
        data = await request.json()
    except ValueError:
        # json.JSONDecodeError (and UnicodeDecodeError) subclass ValueError.
        return {"error": "Invalid JSON body", "classification": "null"}

    # Only a JSON object can carry a "text" field; arrays, strings and numbers
    # have no .get() and previously crashed here.
    text = data.get("text") if isinstance(data, dict) else None
    print("request data", data)
    if not text:
        return {"error": "No text provided", "classification": "null"}
    if not isinstance(text, str):
        return {"error": "text must be a string", "classification": "null"}

    # Fixed-length encoding (pad/truncate to 64 tokens), moved to the model's
    # device so inputs and weights always live on the same device.
    encoded = tokenizer.encode_plus(
        text, max_length=64, padding="max_length", truncation=True, return_tensors="pt"
    ).to(device)

    with torch.no_grad():
        outputs = model(**encoded)
        probs = torch.nn.functional.softmax(outputs.logits, dim=1)
        predicted = torch.argmax(probs, dim=1).item()

    # NOTE(review): assumes the key order of `category` matches the label
    # indices used in training — confirm.
    label = list(category.keys())[predicted]
    return {"text": text, "classification": label}