hiddenFront committed on
Commit
7f17fe7
·
verified ·
1 Parent(s): 3d84f36

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +91 -0
app.py ADDED
@@ -0,0 +1,91 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI, Request
2
+ from transformers import AutoTokenizer, BertForSequenceClassification, BertConfig
3
+ from huggingface_hub import hf_hub_download
4
+ import torch
5
+ import numpy as np
6
+ import pickle
7
+ import sys
8
+ import collections
9
+ import os # os ๋ชจ๋“ˆ ์ž„ํฌํŠธ
10
+ import psutil # ๋ฉ”๋ชจ๋ฆฌ ์‚ฌ์šฉ๋Ÿ‰ ํ™•์ธ์„ ์œ„ํ•ด psutil ์ž„ํฌํŠธ (requirements.txt์— ์ถ”๊ฐ€ ํ•„์š”)
11
+
12
# FastAPI application instance for the text-classification service.
app = FastAPI()
# All inference runs on CPU — no GPU is assumed in this deployment.
device = torch.device("cpu")
14
+
15
# Load the category mapping used to size and decode the classifier head.
# NOTE: pickle.load is only acceptable because category.pkl ships with the
# app itself — never point this at untrusted input.
try:
    with open("category.pkl", "rb") as f:
        category = pickle.load(f)
    print("category.pkl 로드 성공.")
except FileNotFoundError:
    print("Error: category.pkl 파일을 찾을 수 없습니다. 프로젝트 루트에 있는지 확인하세요.")
    sys.exit(1)
except (pickle.UnpicklingError, EOFError) as e:
    # A corrupt or truncated pickle previously escaped as an unhandled
    # traceback; exit through the same clean failure path as a missing file.
    print(f"Error: category.pkl 로드 중 오류 발생: {e}")
    sys.exit(1)
23
+
24
# Tokenizer for KoBERT — must match the checkpoint the classifier was
# fine-tuned from (assumed to be skt/kobert-base-v1 — TODO confirm).
tokenizer = AutoTokenizer.from_pretrained("skt/kobert-base-v1")
print("토크나이저 로드 성공.")

# Hugging Face Hub location of the fine-tuned classifier weights.
HF_MODEL_REPO_ID = "hiddenFront/TextClassifier"
HF_MODEL_FILENAME = "textClassifierModel.pt"
30
+
31
# --- Memory-usage logging: baseline before the model download ---
process = psutil.Process(os.getpid())
mem_before_model_download = process.memory_info().rss / (1024 * 1024)  # RSS in MB
print(f"모델 다운로드 전 메모리 사용량: {mem_before_model_download:.2f} MB")
# --- End memory-usage logging ---

try:
    # Download (or reuse the cached copy of) the fine-tuned weights file.
    model_path = hf_hub_download(repo_id=HF_MODEL_REPO_ID, filename=HF_MODEL_FILENAME)
    print(f"모델 파일이 '{model_path}'에 성공적으로 다운로드되었습니다.")

    # --- Memory-usage logging: after download ---
    mem_after_model_download = process.memory_info().rss / (1024 * 1024)  # RSS in MB
    print(f"모델 다운로드 후 메모리 사용량: {mem_after_model_download:.2f} MB")
    # --- End memory-usage logging ---

    # 1. Define the model architecture only (no pretrained weights loaded here);
    #    the head is sized to the number of categories from category.pkl.
    config = BertConfig.from_pretrained("skt/kobert-base-v1", num_labels=len(category))
    model = BertForSequenceClassification(config)

    # 2. Load the state_dict from the downloaded checkpoint onto CPU.
    loaded_state_dict = torch.load(model_path, map_location=device)

    # 3. Apply the loaded state_dict to the model, stripping the 'module.'
    #    prefix that DataParallel-style training leaves on every key —
    #    presumably the checkpoint was saved from a wrapped model (verify).
    new_state_dict = collections.OrderedDict()
    for k, v in loaded_state_dict.items():
        name = k
        if name.startswith('module.'):
            name = name[7:]  # len('module.') == 7
        new_state_dict[name] = v

    model.load_state_dict(new_state_dict)

    # --- Memory-usage logging: after weights are resident ---
    mem_after_model_load = process.memory_info().rss / (1024 * 1024)  # RSS in MB
    print(f"모델 로드 및 state_dict 적용 후 메모리 사용량: {mem_after_model_load:.2f} MB")
    # --- End memory-usage logging ---

    model.eval()  # inference mode: disables dropout
    print("모델 로드 성공.")
except Exception as e:
    # Any download/load failure is fatal for the service — exit rather than
    # serve requests with an uninitialized model.
    print(f"Error: 모델 다운로드 또는 로드 중 오류 발생: {e}")
    sys.exit(1)
73
+
74
@app.post("/predict")
async def predict_api(request: Request):
    """Classify a piece of text.

    Expects a JSON object body of the form ``{"text": "..."}`` and returns
    ``{"text": <input>, "classification": <label>}`` on success, or
    ``{"error": ..., "classification": "null"}`` on bad input.
    """
    # A malformed body (invalid JSON, or JSON that is not an object) used to
    # escape as an unhandled exception -> HTTP 500; fail through the same
    # error contract as a missing "text" field instead.
    try:
        data = await request.json()
    except Exception:
        return {"error": "Invalid JSON body", "classification": "null"}
    if not isinstance(data, dict):
        return {"error": "Invalid JSON body", "classification": "null"}

    text = data.get("text")
    if not text:
        return {"error": "No text provided", "classification": "null"}

    # Tokenize to a fixed length of 64 — assumed to match the training
    # setup (TODO confirm).
    encoded = tokenizer.encode_plus(
        text, max_length=64, padding='max_length', truncation=True, return_tensors='pt'
    )

    # Inference only — no gradients needed.
    with torch.no_grad():
        outputs = model(**encoded)
        probs = torch.nn.functional.softmax(outputs.logits, dim=1)
        predicted = torch.argmax(probs, dim=1).item()

    # NOTE(review): assumes the insertion order of `category`'s keys matches
    # the classifier's output indices — verify against how category.pkl was
    # built before trusting the labels.
    label = list(category.keys())[predicted]
    return {"text": text, "classification": label}