hiddenFront commited on
Commit
4607c9c
ยท
verified ยท
1 Parent(s): 9dd37b1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +3 -7
app.py CHANGED
@@ -37,11 +37,6 @@ class CustomClassifier(torch.nn.Module):
37
  pooled_output = outputs[1] # CLS ํ† ํฐ
38
  return self.classifier(pooled_output)
39
 
40
-
41
- model = CustomClassifier()
42
- model.load_state_dict(torch.load(model_path, map_location=device))
43
- model.eval()
44
-
45
  HF_MODEL_REPO_ID = "hiddenFront/TextClassifier"
46
  HF_MODEL_FILENAME = "textClassifierModel.pt"
47
 
@@ -50,7 +45,6 @@ process = psutil.Process(os.getpid())
50
  mem_before = process.memory_info().rss / (1024 * 1024)
51
  print(f"๐Ÿ“ฆ ๋ชจ๋ธ ๋‹ค์šด๋กœ๋“œ ์ „ ๋ฉ”๋ชจ๋ฆฌ ์‚ฌ์šฉ๋Ÿ‰: {mem_before:.2f} MB")
52
 
53
- # ๋ชจ๋ธ ๊ฐ€์ค‘์น˜ ๋‹ค์šด๋กœ๋“œ
54
  try:
55
  model_path = hf_hub_download(repo_id=HF_MODEL_REPO_ID, filename=HF_MODEL_FILENAME)
56
  print(f"โœ… ๋ชจ๋ธ ํŒŒ์ผ ๋‹ค์šด๋กœ๋“œ ์„ฑ๊ณต: {model_path}")
@@ -58,7 +52,8 @@ try:
58
  mem_after_dl = process.memory_info().rss / (1024 * 1024)
59
  print(f"๐Ÿ“ฆ ๋ชจ๋ธ ๋‹ค์šด๋กœ๋“œ ํ›„ ๋ฉ”๋ชจ๋ฆฌ ์‚ฌ์šฉ๋Ÿ‰: {mem_after_dl:.2f} MB")
60
 
61
- # state_dict ๋กœ๋“œ
 
62
  state_dict = torch.load(model_path, map_location=device)
63
  model.load_state_dict(state_dict)
64
  model.eval()
@@ -70,6 +65,7 @@ except Exception as e:
70
  print(f"โŒ Error: ๋ชจ๋ธ ๋กœ๋“œ ์ค‘ ์˜ค๋ฅ˜ ๋ฐœ์ƒ: {e}")
71
  sys.exit(1)
72
 
 
73
  # ์˜ˆ์ธก API
74
  @app.post("/predict")
75
  async def predict_api(request: Request):
 
37
  pooled_output = outputs[1] # CLS ํ† ํฐ
38
  return self.classifier(pooled_output)
39
 
 
 
 
 
 
40
  HF_MODEL_REPO_ID = "hiddenFront/TextClassifier"
41
  HF_MODEL_FILENAME = "textClassifierModel.pt"
42
 
 
45
  mem_before = process.memory_info().rss / (1024 * 1024)
46
  print(f"๐Ÿ“ฆ ๋ชจ๋ธ ๋‹ค์šด๋กœ๋“œ ์ „ ๋ฉ”๋ชจ๋ฆฌ ์‚ฌ์šฉ๋Ÿ‰: {mem_before:.2f} MB")
47
 
 
48
  try:
49
  model_path = hf_hub_download(repo_id=HF_MODEL_REPO_ID, filename=HF_MODEL_FILENAME)
50
  print(f"โœ… ๋ชจ๋ธ ํŒŒ์ผ ๋‹ค์šด๋กœ๋“œ ์„ฑ๊ณต: {model_path}")
 
52
  mem_after_dl = process.memory_info().rss / (1024 * 1024)
53
  print(f"๐Ÿ“ฆ ๋ชจ๋ธ ๋‹ค์šด๋กœ๋“œ ํ›„ ๋ฉ”๋ชจ๋ฆฌ ์‚ฌ์šฉ๋Ÿ‰: {mem_after_dl:.2f} MB")
54
 
55
+ # ๋ชจ๋ธ ๊ตฌ์„ฑ ๋ฐ state_dict ๋กœ๋“œ
56
+ model = CustomClassifier()
57
  state_dict = torch.load(model_path, map_location=device)
58
  model.load_state_dict(state_dict)
59
  model.eval()
 
65
  print(f"โŒ Error: ๋ชจ๋ธ ๋กœ๋“œ ์ค‘ ์˜ค๋ฅ˜ ๋ฐœ์ƒ: {e}")
66
  sys.exit(1)
67
 
68
+
69
  # ์˜ˆ์ธก API
70
  @app.post("/predict")
71
  async def predict_api(request: Request):