hiddenFront commited on
Commit
ec61894
Β·
verified Β·
1 Parent(s): e66afc2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +15 -8
app.py CHANGED
@@ -1,11 +1,11 @@
1
  from fastapi import FastAPI, Request
2
- from transformers import AutoTokenizer
3
  from huggingface_hub import hf_hub_download
4
  import torch
5
  import pickle
6
  import os
7
- import psutil
8
  import sys
 
9
 
10
  app = FastAPI()
11
  device = torch.device("cpu")
@@ -16,22 +16,27 @@ try:
16
  category = pickle.load(f)
17
  print("βœ… category.pkl λ‘œλ“œ 성곡.")
18
  except FileNotFoundError:
19
- print("❌ Error: category.pkl νŒŒμΌμ„ 찾을 수 μ—†μŠ΅λ‹ˆλ‹€. ν”„λ‘œμ νŠΈ λ£¨νŠΈμ— μžˆλŠ”μ§€ ν™•μΈν•˜μ„Έμš”.")
20
  sys.exit(1)
21
 
22
  # ν† ν¬λ‚˜μ΄μ € λ‘œλ“œ
23
  tokenizer = AutoTokenizer.from_pretrained("skt/kobert-base-v1")
24
  print("βœ… ν† ν¬λ‚˜μ΄μ € λ‘œλ“œ 성곡.")
25
 
 
 
 
 
 
26
  HF_MODEL_REPO_ID = "hiddenFront/TextClassifier"
27
  HF_MODEL_FILENAME = "textClassifierModel.pt"
28
 
29
- # λ©”λͺ¨λ¦¬ 확인
30
  process = psutil.Process(os.getpid())
31
  mem_before = process.memory_info().rss / (1024 * 1024)
32
  print(f"πŸ“¦ λͺ¨λΈ λ‹€μš΄λ‘œλ“œ μ „ λ©”λͺ¨λ¦¬ μ‚¬μš©λŸ‰: {mem_before:.2f} MB")
33
 
34
- # λͺ¨λΈ λ‹€μš΄λ‘œλ“œ 및 λ‘œλ“œ
35
  try:
36
  model_path = hf_hub_download(repo_id=HF_MODEL_REPO_ID, filename=HF_MODEL_FILENAME)
37
  print(f"βœ… λͺ¨λΈ 파일 λ‹€μš΄λ‘œλ“œ 성곡: {model_path}")
@@ -39,14 +44,16 @@ try:
39
  mem_after_dl = process.memory_info().rss / (1024 * 1024)
40
  print(f"πŸ“¦ λͺ¨λΈ λ‹€μš΄λ‘œλ“œ ν›„ λ©”λͺ¨λ¦¬ μ‚¬μš©λŸ‰: {mem_after_dl:.2f} MB")
41
 
42
- model = torch.load(model_path, map_location=device) # 전체 λͺ¨λΈ 객체 λ‘œλ“œ
 
 
43
  model.eval()
44
 
45
  mem_after_load = process.memory_info().rss / (1024 * 1024)
46
  print(f"πŸ“¦ λͺ¨λΈ λ‘œλ“œ ν›„ λ©”λͺ¨λ¦¬ μ‚¬μš©λŸ‰: {mem_after_load:.2f} MB")
47
- print("βœ… λͺ¨λΈ λ‘œλ“œ 성곡")
48
  except Exception as e:
49
- print(f"❌ Error: λͺ¨λΈ λ‹€μš΄λ‘œλ“œ λ˜λŠ” λ‘œλ“œ 쀑 였λ₯˜ λ°œμƒ: {e}")
50
  sys.exit(1)
51
 
52
  # 예츑 API
 
1
  from fastapi import FastAPI, Request
2
+ from transformers import BertForSequenceClassification, AutoTokenizer
3
  from huggingface_hub import hf_hub_download
4
  import torch
5
  import pickle
6
  import os
 
7
  import sys
8
+ import psutil
9
 
10
  app = FastAPI()
11
  device = torch.device("cpu")
 
16
  category = pickle.load(f)
17
  print("βœ… category.pkl λ‘œλ“œ 성곡.")
18
  except FileNotFoundError:
19
+ print("❌ Error: category.pkl νŒŒμΌμ„ 찾을 수 μ—†μŠ΅λ‹ˆλ‹€.")
20
  sys.exit(1)
21
 
22
  # ν† ν¬λ‚˜μ΄μ € λ‘œλ“œ
23
  tokenizer = AutoTokenizer.from_pretrained("skt/kobert-base-v1")
24
  print("βœ… ν† ν¬λ‚˜μ΄μ € λ‘œλ“œ 성곡.")
25
 
26
+ # λͺ¨λΈ ꡬ쑰 μž¬μ •μ˜
27
+ num_labels = len(category) # λΆ„λ₯˜ν•  클래슀 μˆ˜μ— 따라
28
+ model = BertForSequenceClassification.from_pretrained("skt/kobert-base-v1", num_labels=num_labels)
29
+ model.to(device)
30
+
31
  HF_MODEL_REPO_ID = "hiddenFront/TextClassifier"
32
  HF_MODEL_FILENAME = "textClassifierModel.pt"
33
 
34
+ # λ©”λͺ¨λ¦¬ μΈ‘μ • μ „
35
  process = psutil.Process(os.getpid())
36
  mem_before = process.memory_info().rss / (1024 * 1024)
37
  print(f"πŸ“¦ λͺ¨λΈ λ‹€μš΄λ‘œλ“œ μ „ λ©”λͺ¨λ¦¬ μ‚¬μš©λŸ‰: {mem_before:.2f} MB")
38
 
39
+ # λͺ¨λΈ κ°€μ€‘μΉ˜ λ‹€μš΄λ‘œλ“œ
40
  try:
41
  model_path = hf_hub_download(repo_id=HF_MODEL_REPO_ID, filename=HF_MODEL_FILENAME)
42
  print(f"βœ… λͺ¨λΈ 파일 λ‹€μš΄λ‘œλ“œ 성곡: {model_path}")
 
44
  mem_after_dl = process.memory_info().rss / (1024 * 1024)
45
  print(f"πŸ“¦ λͺ¨λΈ λ‹€μš΄λ‘œλ“œ ν›„ λ©”λͺ¨λ¦¬ μ‚¬μš©λŸ‰: {mem_after_dl:.2f} MB")
46
 
47
+ # state_dict λ‘œλ“œ
48
+ state_dict = torch.load(model_path, map_location=device)
49
+ model.load_state_dict(state_dict)
50
  model.eval()
51
 
52
  mem_after_load = process.memory_info().rss / (1024 * 1024)
53
  print(f"πŸ“¦ λͺ¨λΈ λ‘œλ“œ ν›„ λ©”λͺ¨λ¦¬ μ‚¬μš©λŸ‰: {mem_after_load:.2f} MB")
54
+ print("βœ… λͺ¨λΈ λ‘œλ“œ 및 μ€€λΉ„ μ™„λ£Œ.")
55
  except Exception as e:
56
+ print(f"❌ Error: λͺ¨λΈ λ‘œλ“œ 쀑 였λ₯˜ λ°œμƒ: {e}")
57
  sys.exit(1)
58
 
59
  # 예츑 API