File size: 2,054 Bytes
f01f652
331bb85
 
 
 
 
a7ec079
 
 
331bb85
 
 
 
b1bb87f
 
331bb85
 
b1bb87f
 
abfaf79
 
 
a7ec079
331bb85
 
 
 
 
a7ec079
 
 
 
 
 
 
 
 
331bb85
 
 
 
 
 
 
 
 
a7ec079
 
 
1131316
 
331bb85
 
 
 
 
f01f652
 
 
331bb85
 
f01f652
 
331bb85
f01f652
331bb85
 
f01f652
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
import argparse
from pathlib import Path

import joblib
import pandas as pd
import wandb
from sklearn.model_selection import GridSearchCV
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier





def run_training(max_depth: int = 10, min_samples_split: int = 2) -> dict:
    """Grid-search a decision tree on the training CSV and save the best model.

    Reads ``data/train.csv``, uses the ``is_click`` column as the target and
    every other column as features, then runs a 5-fold ``GridSearchCV``
    scored by F1 over ``max_depth`` in ``1..max_depth`` and
    ``min_samples_split`` in ``2..min_samples_split``. The best estimator is
    written to ``model/best_model.pkl`` and the search results are logged to
    the ``dsip`` Weights & Biases project.

    Args:
        max_depth: Largest tree depth to try; must be at least 1.
        min_samples_split: Largest ``min_samples_split`` value to try; must
            be at least 2.

    Returns:
        The best parameter combination found by the grid search.

    Raises:
        ValueError: If either bound would produce an empty search range.
    """
    # Reject empty ranges before any I/O or tracking run is started; otherwise
    # the failure only surfaces later, inside the grid search.
    if max_depth < 1:
        raise ValueError(f"max_depth must be >= 1, got {max_depth}")
    if min_samples_split < 2:
        raise ValueError(
            f"min_samples_split must be >= 2, got {min_samples_split}"
        )

    train_path = "data/train.csv"
    model_path = Path("model/best_model.pkl")

    # Load dataset
    train_dataset = pd.read_csv(train_path)
    X_train = train_dataset.drop("is_click", axis="columns")
    y_train = train_dataset["is_click"]

    # Materialise the ranges as lists so the same object can serve both as
    # the search grid and as plain, serialisable run configuration.
    param_grid = {
        'max_depth': list(range(1, max_depth + 1)),
        'min_samples_split': list(range(2, min_samples_split + 1)),
    }

    # Make sure the output directory exists *before* the expensive search so
    # the result is not lost to a missing folder at save time.
    model_path.parent.mkdir(parents=True, exist_ok=True)

    # Using the run as a context manager finishes it even when training raises.
    with wandb.init(project="dsip", config=param_grid):
        grid_search = GridSearchCV(
            DecisionTreeClassifier(random_state=42),
            param_grid,
            cv=5,
            scoring='f1',
            verbose=1,
        )
        grid_search.fit(X_train, y_train)

        # Save the best model
        best_model = grid_search.best_estimator_
        joblib.dump(best_model, model_path)

        print("Training completed and best model saved.")
        wandb.log({"best_params_": grid_search.best_params_,
                   "best_score_": grid_search.best_score_,
                   "cv_results_": grid_search.cv_results_})

    return grid_search.best_params_



    

def main():
    """Parse the hyperparameter bounds from the command line and train.

    Accepts ``--max_depth`` (default 10) and ``--min_samples_split``
    (default 2), both integers, echoes the chosen values, and forwards them
    to ``run_training``.
    """
    cli = argparse.ArgumentParser()
    # Both options are integer flags whose help text is just the flag name.
    for flag, fallback in (("max_depth", 10), ("min_samples_split", 2)):
        cli.add_argument(f"--{flag}", type=int, default=fallback, help=flag)
    opts = cli.parse_args()

    depth = opts.max_depth
    split = opts.min_samples_split
    print(f"Training with max_depth={depth}, min_samples_split={split}")
    run_training(max_depth=depth, min_samples_split=split)
    

# Run the CLI only when executed as a script, not when imported.
if "__main__" == __name__:
    main()