hiddenFront commited on
Commit
e66afc2
Β·
verified Β·
1 Parent(s): 4e6dd7c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +21 -42
app.py CHANGED
@@ -1,13 +1,11 @@
1
  from fastapi import FastAPI, Request
2
- from transformers import AutoTokenizer, BertForSequenceClassification, BertConfig
3
  from huggingface_hub import hf_hub_download
4
  import torch
5
- import numpy as np
6
  import pickle
 
 
7
  import sys
8
- import collections
9
- import os # os λͺ¨λ“ˆ μž„ν¬νŠΈ
10
- import psutil # λ©”λͺ¨λ¦¬ μ‚¬μš©λŸ‰ 확인을 μœ„ν•΄ psutil μž„ν¬νŠΈ (requirements.txt에 μΆ”κ°€ ν•„μš”)
11
 
12
  app = FastAPI()
13
  device = torch.device("cpu")
@@ -16,61 +14,42 @@ device = torch.device("cpu")
16
  try:
17
  with open("category.pkl", "rb") as f:
18
  category = pickle.load(f)
19
- print("category.pkl λ‘œλ“œ 성곡.")
20
  except FileNotFoundError:
21
- print("Error: category.pkl νŒŒμΌμ„ 찾을 수 μ—†μŠ΅λ‹ˆλ‹€. ν”„λ‘œμ νŠΈ λ£¨νŠΈμ— μžˆλŠ”μ§€ ν™•μΈν•˜μ„Έμš”.")
22
  sys.exit(1)
23
 
24
  # ν† ν¬λ‚˜μ΄μ € λ‘œλ“œ
25
  tokenizer = AutoTokenizer.from_pretrained("skt/kobert-base-v1")
26
- print("ν† ν¬λ‚˜μ΄μ € λ‘œλ“œ 성곡.")
27
 
28
  HF_MODEL_REPO_ID = "hiddenFront/TextClassifier"
29
  HF_MODEL_FILENAME = "textClassifierModel.pt"
30
 
31
- # --- λ©”λͺ¨λ¦¬ μ‚¬μš©λŸ‰ λ‘œκΉ… μ‹œμž‘ ---
32
  process = psutil.Process(os.getpid())
33
- mem_before_model_download = process.memory_info().rss / (1024 * 1024) # MB λ‹¨μœ„
34
- print(f"λͺ¨λΈ λ‹€μš΄λ‘œλ“œ μ „ λ©”λͺ¨λ¦¬ μ‚¬μš©λŸ‰: {mem_before_model_download:.2f} MB")
35
- # --- λ©”λͺ¨λ¦¬ μ‚¬μš©λŸ‰ λ‘œκΉ… 끝 ---
36
 
 
37
  try:
38
  model_path = hf_hub_download(repo_id=HF_MODEL_REPO_ID, filename=HF_MODEL_FILENAME)
39
- print(f"λͺ¨λΈ 파일이 '{model_path}'에 μ„±κ³΅μ μœΌλ‘œ λ‹€μš΄λ‘œλ“œλ˜μ—ˆμŠ΅λ‹ˆλ‹€.")
40
 
41
- # --- λ©”λͺ¨λ¦¬ μ‚¬μš©λŸ‰ λ‘œκΉ… μ‹œμž‘ ---
42
- mem_after_model_download = process.memory_info().rss / (1024 * 1024) # MB λ‹¨μœ„
43
- print(f"λͺ¨λΈ λ‹€μš΄λ‘œλ“œ ν›„ λ©”λͺ¨λ¦¬ μ‚¬μš©λŸ‰: {mem_after_model_download:.2f} MB")
44
- # --- λ©”λͺ¨λ¦¬ μ‚¬μš©λŸ‰ λ‘œκΉ… 끝 ---
45
-
46
- # 1. λͺ¨λΈ μ•„ν‚€ν…μ²˜ μ •μ˜ (κ°€μ€‘μΉ˜λŠ” λ‘œλ“œν•˜μ§€ μ•Šκ³  ꡬ쑰만 μ΄ˆκΈ°ν™”)
47
- config = BertConfig.from_pretrained("skt/kobert-base-v1", num_labels=len(category))
48
- model = BertForSequenceClassification(config)
49
-
50
- # 2. λ‹€μš΄λ‘œλ“œλœ νŒŒμΌμ—μ„œ state_dictλ₯Ό λ‘œλ“œ
51
- loaded_state_dict = torch.load(model_path, map_location=device)
52
-
53
- # 3. λ‘œλ“œλœ state_dictλ₯Ό μ •μ˜λœ λͺ¨λΈμ— 적용
54
- new_state_dict = collections.OrderedDict()
55
- for k, v in loaded_state_dict.items():
56
- name = k
57
- if name.startswith('module.'):
58
- name = name[7:]
59
- new_state_dict[name] = v
60
-
61
- model.load_state_dict(new_state_dict)
62
-
63
- # --- λ©”λͺ¨λ¦¬ μ‚¬μš©λŸ‰ λ‘œκΉ… μ‹œμž‘ ---
64
- mem_after_model_load = process.memory_info().rss / (1024 * 1024) # MB λ‹¨μœ„
65
- print(f"λͺ¨λΈ λ‘œλ“œ 및 state_dict 적용 ν›„ λ©”λͺ¨λ¦¬ μ‚¬μš©λŸ‰: {mem_after_model_load:.2f} MB")
66
- # --- λ©”λͺ¨λ¦¬ μ‚¬μš©λŸ‰ λ‘œκΉ… 끝 ---
67
 
 
68
  model.eval()
69
- print("λͺ¨λΈ λ‘œλ“œ 성곡.")
 
 
 
70
  except Exception as e:
71
- print(f"Error: λͺ¨λΈ λ‹€μš΄λ‘œλ“œ λ˜λŠ” λ‘œλ“œ 쀑 였λ₯˜ λ°œμƒ: {e}")
72
  sys.exit(1)
73
 
 
74
  @app.post("/predict")
75
  async def predict_api(request: Request):
76
  data = await request.json()
@@ -86,6 +65,6 @@ async def predict_api(request: Request):
86
  outputs = model(**encoded)
87
  probs = torch.nn.functional.softmax(outputs.logits, dim=1)
88
  predicted = torch.argmax(probs, dim=1).item()
89
-
90
  label = list(category.keys())[predicted]
91
  return {"text": text, "classification": label}
 
1
  from fastapi import FastAPI, Request
2
+ from transformers import AutoTokenizer
3
  from huggingface_hub import hf_hub_download
4
  import torch
 
5
  import pickle
6
+ import os
7
+ import psutil
8
  import sys
 
 
 
9
 
10
  app = FastAPI()
11
  device = torch.device("cpu")
 
14
  try:
15
  with open("category.pkl", "rb") as f:
16
  category = pickle.load(f)
17
+ print("βœ… category.pkl λ‘œλ“œ 성곡.")
18
  except FileNotFoundError:
19
+ print("❌ Error: category.pkl νŒŒμΌμ„ 찾을 수 μ—†μŠ΅λ‹ˆλ‹€. ν”„λ‘œμ νŠΈ λ£¨νŠΈμ— μžˆλŠ”μ§€ ν™•μΈν•˜μ„Έμš”.")
20
  sys.exit(1)
21
 
22
  # ν† ν¬λ‚˜μ΄μ € λ‘œλ“œ
23
  tokenizer = AutoTokenizer.from_pretrained("skt/kobert-base-v1")
24
+ print("βœ… ν† ν¬λ‚˜μ΄μ € λ‘œλ“œ 성곡.")
25
 
26
  HF_MODEL_REPO_ID = "hiddenFront/TextClassifier"
27
  HF_MODEL_FILENAME = "textClassifierModel.pt"
28
 
29
+ # λ©”λͺ¨λ¦¬ 확인
30
  process = psutil.Process(os.getpid())
31
+ mem_before = process.memory_info().rss / (1024 * 1024)
32
+ print(f"πŸ“¦ λͺ¨λΈ λ‹€μš΄λ‘œλ“œ μ „ λ©”λͺ¨λ¦¬ μ‚¬μš©λŸ‰: {mem_before:.2f} MB")
 
33
 
34
+ # λͺ¨λΈ λ‹€μš΄λ‘œλ“œ 및 λ‘œλ“œ
35
  try:
36
  model_path = hf_hub_download(repo_id=HF_MODEL_REPO_ID, filename=HF_MODEL_FILENAME)
37
+ print(f"βœ… λͺ¨λΈ 파일 λ‹€μš΄λ‘œλ“œ 성곡: {model_path}")
38
 
39
+ mem_after_dl = process.memory_info().rss / (1024 * 1024)
40
+ print(f"πŸ“¦ λͺ¨λΈ λ‹€μš΄λ‘œλ“œ ν›„ λ©”λͺ¨λ¦¬ μ‚¬μš©λŸ‰: {mem_after_dl:.2f} MB")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41
 
42
+ model = torch.load(model_path, map_location=device) # 전체 λͺ¨λΈ 객체 λ‘œλ“œ
43
  model.eval()
44
+
45
+ mem_after_load = process.memory_info().rss / (1024 * 1024)
46
+ print(f"πŸ“¦ λͺ¨λΈ λ‘œλ“œ ν›„ λ©”λͺ¨λ¦¬ μ‚¬μš©λŸ‰: {mem_after_load:.2f} MB")
47
+ print("βœ… λͺ¨λΈ λ‘œλ“œ 성곡")
48
  except Exception as e:
49
+ print(f"❌ Error: λͺ¨λΈ λ‹€μš΄λ‘œλ“œ λ˜λŠ” λ‘œλ“œ 쀑 였λ₯˜ λ°œμƒ: {e}")
50
  sys.exit(1)
51
 
52
+ # 예츑 API
53
  @app.post("/predict")
54
  async def predict_api(request: Request):
55
  data = await request.json()
 
65
  outputs = model(**encoded)
66
  probs = torch.nn.functional.softmax(outputs.logits, dim=1)
67
  predicted = torch.argmax(probs, dim=1).item()
68
+
69
  label = list(category.keys())[predicted]
70
  return {"text": text, "classification": label}