Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -37,11 +37,6 @@ class CustomClassifier(torch.nn.Module):
|
|
| 37 |
pooled_output = outputs[1] # CLS ํ ํฐ
|
| 38 |
return self.classifier(pooled_output)
|
| 39 |
|
| 40 |
-
|
| 41 |
-
model = CustomClassifier()
|
| 42 |
-
model.load_state_dict(torch.load(model_path, map_location=device))
|
| 43 |
-
model.eval()
|
| 44 |
-
|
| 45 |
HF_MODEL_REPO_ID = "hiddenFront/TextClassifier"
|
| 46 |
HF_MODEL_FILENAME = "textClassifierModel.pt"
|
| 47 |
|
|
@@ -50,7 +45,6 @@ process = psutil.Process(os.getpid())
|
|
| 50 |
mem_before = process.memory_info().rss / (1024 * 1024)
|
| 51 |
print(f"๐ฆ ๋ชจ๋ธ ๋ค์ด๋ก๋ ์ ๋ฉ๋ชจ๋ฆฌ ์ฌ์ฉ๋: {mem_before:.2f} MB")
|
| 52 |
|
| 53 |
-
# ๋ชจ๋ธ ๊ฐ์ค์น ๋ค์ด๋ก๋
|
| 54 |
try:
|
| 55 |
model_path = hf_hub_download(repo_id=HF_MODEL_REPO_ID, filename=HF_MODEL_FILENAME)
|
| 56 |
print(f"โ
๋ชจ๋ธ ํ์ผ ๋ค์ด๋ก๋ ์ฑ๊ณต: {model_path}")
|
|
@@ -58,7 +52,8 @@ try:
|
|
| 58 |
mem_after_dl = process.memory_info().rss / (1024 * 1024)
|
| 59 |
print(f"๐ฆ ๋ชจ๋ธ ๋ค์ด๋ก๋ ํ ๋ฉ๋ชจ๋ฆฌ ์ฌ์ฉ๋: {mem_after_dl:.2f} MB")
|
| 60 |
|
| 61 |
-
# state_dict ๋ก๋
|
|
|
|
| 62 |
state_dict = torch.load(model_path, map_location=device)
|
| 63 |
model.load_state_dict(state_dict)
|
| 64 |
model.eval()
|
|
@@ -70,6 +65,7 @@ except Exception as e:
|
|
| 70 |
print(f"โ Error: ๋ชจ๋ธ ๋ก๋ ์ค ์ค๋ฅ ๋ฐ์: {e}")
|
| 71 |
sys.exit(1)
|
| 72 |
|
|
|
|
| 73 |
# ์์ธก API
|
| 74 |
@app.post("/predict")
|
| 75 |
async def predict_api(request: Request):
|
|
|
|
| 37 |
pooled_output = outputs[1] # CLS ํ ํฐ
|
| 38 |
return self.classifier(pooled_output)
|
| 39 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 40 |
HF_MODEL_REPO_ID = "hiddenFront/TextClassifier"
|
| 41 |
HF_MODEL_FILENAME = "textClassifierModel.pt"
|
| 42 |
|
|
|
|
| 45 |
mem_before = process.memory_info().rss / (1024 * 1024)
|
| 46 |
print(f"๐ฆ ๋ชจ๋ธ ๋ค์ด๋ก๋ ์ ๋ฉ๋ชจ๋ฆฌ ์ฌ์ฉ๋: {mem_before:.2f} MB")
|
| 47 |
|
|
|
|
| 48 |
try:
|
| 49 |
model_path = hf_hub_download(repo_id=HF_MODEL_REPO_ID, filename=HF_MODEL_FILENAME)
|
| 50 |
print(f"โ
๋ชจ๋ธ ํ์ผ ๋ค์ด๋ก๋ ์ฑ๊ณต: {model_path}")
|
|
|
|
| 52 |
mem_after_dl = process.memory_info().rss / (1024 * 1024)
|
| 53 |
print(f"๐ฆ ๋ชจ๋ธ ๋ค์ด๋ก๋ ํ ๋ฉ๋ชจ๋ฆฌ ์ฌ์ฉ๋: {mem_after_dl:.2f} MB")
|
| 54 |
|
| 55 |
+
# ๋ชจ๋ธ ๊ตฌ์ฑ ๋ฐ state_dict ๋ก๋
|
| 56 |
+
model = CustomClassifier()
|
| 57 |
state_dict = torch.load(model_path, map_location=device)
|
| 58 |
model.load_state_dict(state_dict)
|
| 59 |
model.eval()
|
|
|
|
| 65 |
print(f"โ Error: ๋ชจ๋ธ ๋ก๋ ์ค ์ค๋ฅ ๋ฐ์: {e}")
|
| 66 |
sys.exit(1)
|
| 67 |
|
| 68 |
+
|
| 69 |
# ์์ธก API
|
| 70 |
@app.post("/predict")
|
| 71 |
async def predict_api(request: Request):
|