KaiquanMah commited on
Commit
832efdd
·
verified ·
1 Parent(s): 26f170a

Update train.py

Browse files
Files changed (1) hide show
  1. train.py +3 -0
train.py CHANGED
@@ -5,6 +5,7 @@ from sklearn.tree import DecisionTreeClassifier
5
  from sklearn.model_selection import GridSearchCV
6
  import joblib
7
  import wandb
 
8
 
9
 
10
 
@@ -41,6 +42,8 @@ def run_training(max_depth=10, min_samples_split=2):
41
 
42
  # Save the best model
43
  best_model = grid_search.best_estimator_
 
 
44
  joblib.dump(best_model, "model/best_model.pkl")
45
 
46
  print("Training completed and best model saved.")
 
5
  from sklearn.model_selection import GridSearchCV
6
  import joblib
7
  import wandb
8
+ import os
9
 
10
 
11
 
 
42
 
43
  # Save the best model
44
  best_model = grid_search.best_estimator_
45
+ # Create the 'model' directory if it doesn't exist
46
+ os.makedirs('model', exist_ok=True)
47
  joblib.dump(best_model, "model/best_model.pkl")
48
 
49
  print("Training completed and best model saved.")