KaiquanMah commited on
Commit
b0bc543
·
verified ·
1 Parent(s): 6454725

Yair - Added error handling for XGB model

Browse files
Files changed (1) hide show
  1. model_manager.py +5 -3
model_manager.py CHANGED
@@ -6,7 +6,9 @@ from config import CATBOOST_MODEL_PATH, XGB_MODEL_PATH, RF_MODEL_PATH
6
  def save_models(models):
7
  """ Save trained models """
8
  models["CatBoost"].save_model(CATBOOST_MODEL_PATH)
9
- models["XGBoost"].save_model(XGB_MODEL_PATH)
 
 
10
  joblib.dump(models["RandomForest"], RF_MODEL_PATH)
11
  print("✅ Models saved successfully!")
12
 
@@ -15,9 +17,9 @@ def load_models():
15
  catboost = CatBoostClassifier()
16
  catboost.load_model(CATBOOST_MODEL_PATH)
17
 
18
- xgb = XGBClassifier()
19
  xgb.load_model(XGB_MODEL_PATH)
20
 
21
  rf = joblib.load(RF_MODEL_PATH)
22
 
23
- return {"CatBoost": catboost, "XGBoost": xgb, "RandomForest": rf}
 
6
  def save_models(models):
7
  """ Save trained models """
8
  models["CatBoost"].save_model(CATBOOST_MODEL_PATH)
9
+ if models["XGBoost"] is not None:
10
+ # Save XGBoost model in binary format to reduce memory usage
11
+ models["XGBoost"].get_booster().save_model(XGB_MODEL_PATH)
12
  joblib.dump(models["RandomForest"], RF_MODEL_PATH)
13
  print("✅ Models saved successfully!")
14
 
 
17
  catboost = CatBoostClassifier()
18
  catboost.load_model(CATBOOST_MODEL_PATH)
19
 
20
+ xgb = XGBClassifier() # Load XGBoost model in binary format
21
  xgb.load_model(XGB_MODEL_PATH)
22
 
23
  rf = joblib.load(RF_MODEL_PATH)
24
 
25
+ return {"CatBoost": catboost, "XGBoost": xgb, "RandomForest": rf}