Spaces:
Configuration error
Configuration error
| import argparse | |
| import pandas as pd | |
| from sklearn.model_selection import train_test_split | |
| from sklearn.tree import DecisionTreeClassifier | |
| from sklearn.model_selection import GridSearchCV | |
| import joblib | |
| import wandb | |
def run_training(max_depth=10, min_samples_split=2):
    """Grid-search a decision tree on the training CSV and save the best model.

    Reads ``data/train.csv``, uses the ``is_click`` column as the target and
    every other column as a feature, then runs a 5-fold ``GridSearchCV``
    scored with F1 over ``max_depth`` in ``1..max_depth`` and
    ``min_samples_split`` in ``2..min_samples_split``. The best estimator is
    written to ``model/best_model.pkl`` and the search results are logged to
    the ``dsip`` Weights & Biases project.

    Args:
        max_depth: Inclusive upper bound of the tree depths to try (>= 1).
        min_samples_split: Inclusive upper bound of the ``min_samples_split``
            values to try (>= 2).

    Returns:
        The dict of best hyperparameters found by the grid search.

    Raises:
        ValueError: If either bound would produce an empty search range.
    """
    from pathlib import Path

    # Fail fast: an empty range would otherwise only be rejected inside
    # GridSearchCV, after the data was loaded and a wandb run was opened.
    if max_depth < 1:
        raise ValueError(f"max_depth must be >= 1, got {max_depth}")
    if min_samples_split < 2:
        raise ValueError(
            f"min_samples_split must be >= 2, got {min_samples_split}"
        )

    train_path = "data/train.csv"
    model_path = Path("model/best_model.pkl")

    # Load dataset
    train_dataset = pd.read_csv(train_path)
    X_train = train_dataset.drop("is_click", axis="columns")
    y_train = train_dataset["is_click"]

    # Define the search space. Plain lists (rather than range objects) are
    # used so the same dict can be handed to wandb as run config.
    param_grid = {
        'max_depth': list(range(1, max_depth + 1)),
        'min_samples_split': list(range(2, min_samples_split + 1)),
    }

    # Start a new wandb run to track this script.
    wandb.init(
        # set the wandb project where this run will be logged
        project="dsip",
        # track hyperparameters and run metadata
        config=param_grid,
    )
    try:
        grid_search = GridSearchCV(
            DecisionTreeClassifier(random_state=42),
            param_grid,
            cv=5,
            scoring='f1',
            verbose=1,
        )
        grid_search.fit(X_train, y_train)

        # Save the best model, creating the output directory on first use.
        best_model = grid_search.best_estimator_
        model_path.parent.mkdir(parents=True, exist_ok=True)
        joblib.dump(best_model, str(model_path))
        print("Training completed and best model saved.")

        wandb.log({"best_params_": grid_search.best_params_,
                   "best_score_": grid_search.best_score_,
                   "cv_results_": grid_search.cv_results_})
    except Exception:
        # Close the run as failed instead of leaving it dangling.
        wandb.finish(exit_code=1)
        raise
    wandb.finish()
    return grid_search.best_params_
def main():
    """Parse the hyperparameter bounds from the command line and train.

    Accepts ``--max_depth`` (default 10) and ``--min_samples_split``
    (default 2), echoes the chosen values, and passes them to
    ``run_training``.
    """
    cli = argparse.ArgumentParser()
    # Both options are integer bounds whose help text is the option name.
    for option, fallback in (("max_depth", 10), ("min_samples_split", 2)):
        cli.add_argument(f"--{option}", type=int, default=fallback, help=option)
    options = cli.parse_args()

    depth = options.max_depth
    split = options.min_samples_split
    print(f"Training with max_depth={depth}, min_samples_split={split}")
    run_training(max_depth=depth, min_samples_split=split)
# Run the CLI only when executed as a script, not when imported.
if __name__ == "__main__":
    main()