KaiquanMah committed on
Commit
f5755a2
·
verified ·
1 Parent(s): b0bc543

Yair - Fixed CatBoost. Works with cat_features

Browse files
Files changed (1) hide show
  1. model_predictor.py +17 -1
model_predictor.py CHANGED
@@ -1,9 +1,24 @@
1
  import numpy as np
2
  import pandas as pd
 
 
3
 
4
  def predict(models, X_test):
5
  """ Make predictions using trained models """
6
- test_predictions = {name: np.array(model.predict(X_test)).squeeze() for name, model in models.items()}
 
 
 
 
 
 
 
 
 
 
 
 
 
7
 
8
  test_predictions_df = pd.DataFrame(test_predictions)
9
 
@@ -15,3 +30,4 @@ def predict(models, X_test):
15
  test_predictions_df["is_click_predicted"] = test_predictions_df.max(axis=1)
16
 
17
  return test_predictions_df
 
 
1
  import numpy as np
2
  import pandas as pd
3
+ from catboost import Pool
4
+ from data_loader import CATEGORICAL_COLUMNS, IDS_COLUMNS, TARGET_COLUMN, FEATURE_COLUMNS, AGGREGATED_COLUMNS, TEMPORAL_COLUMNS
5
 
6
  def predict(models, X_test):
7
  """ Make predictions using trained models """
8
+ # Ensure categorical features are properly handled
9
+ cat_features = CATEGORICAL_COLUMNS
10
+ test_predictions = {}
11
+ #
12
+ # test_predictions = {name: np.array(model.predict(X_test)).squeeze() for name, model in models.items()}
13
+ for name, model in models.items():
14
+ if "CatBoost" in name: # Handle CatBoost models
15
+ pool = Pool(data=X_test, cat_features=cat_features)
16
+ test_predictions[name] = model.predict(pool)
17
+ else: # Other models
18
+ # reordering columns to match the order of columns in the model
19
+ new_X_test = X_test[IDS_COLUMNS + FEATURE_COLUMNS + AGGREGATED_COLUMNS + TEMPORAL_COLUMNS]
20
+ test_predictions[name] = np.array(model.predict(new_X_test)).squeeze()
21
+
22
 
23
  test_predictions_df = pd.DataFrame(test_predictions)
24
 
 
30
  test_predictions_df["is_click_predicted"] = test_predictions_df.max(axis=1)
31
 
32
  return test_predictions_df
33
+