# DSIP / train.py
import argparse

import joblib
import pandas as pd
import wandb
from sklearn.model_selection import GridSearchCV, train_test_split
from sklearn.tree import DecisionTreeClassifier

def run_training(max_depth=10, min_samples_split=2):
    train_path = "data/train.csv"
    test_path = "data/test.csv"

    # Load the datasets (test.csv is read here but not used during training)
    train_dataset = pd.read_csv(train_path)
    test_dataset = pd.read_csv(test_path)

    X_train = train_dataset.drop("is_click", axis="columns")
    y_train = train_dataset["is_click"]

    # Define the hyperparameter grid for GridSearchCV over the DecisionTreeClassifier
    param_grid = {
        'max_depth': range(1, max_depth + 1),
        'min_samples_split': range(2, min_samples_split + 1)
    }
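    # With the defaults (max_depth=10, min_samples_split=2) this grid holds
    # 10 x 1 = 10 candidate settings, each refit 5 times by the cv=5 search below.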

    # Start a new wandb run to track this script
    wandb.init(
        # Set the wandb project where this run will be logged
        project="dsip",
        # Track hyperparameters and run metadata
        config=param_grid
    )
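    # Note: wandb.init() assumes the environment is already authenticated with
    # Weights & Biases (e.g. via `wandb login` or the WANDB_API_KEY variable).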

    # Grid-search a DecisionTreeClassifier with 5-fold cross-validation, scored on F1
    grid_search = GridSearchCV(DecisionTreeClassifier(random_state=42), param_grid,
                               cv=5, scoring='f1', verbose=1)
    grid_search.fit(X_train, y_train)

    # Save the best model
    best_model = grid_search.best_estimator_
    joblib.dump(best_model, "model/best_model.pkl")
    print("Training completed and best model saved.")

    # Log the search results to wandb and close the run
    wandb.log({"best_params_": grid_search.best_params_,
               "best_score_": grid_search.best_score_,
               "cv_results_": grid_search.cv_results_})
    wandb.finish()
    return grid_search.best_params_
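

# Hypothetical helper, not called anywhere in this script: a minimal sketch of how
# the saved model could be scored on data/test.csv, assuming test.csv has the same
# columns as train.csv (including the "is_click" label) and that the F1 metric used
# by the grid search above is the evaluation of interest.
def evaluate_on_test(model_path="model/best_model.pkl", test_path="data/test.csv"):
    from sklearn.metrics import f1_score

    # Reload the persisted best estimator and the held-out split
    model = joblib.load(model_path)
    test_dataset = pd.read_csv(test_path)
    X_test = test_dataset.drop("is_click", axis="columns")
    y_test = test_dataset["is_click"]

    # Score with the same metric GridSearchCV optimised
    return f1_score(y_test, model.predict(X_test))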


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--max_depth', type=int, default=10, help='max_depth')
    parser.add_argument('--min_samples_split', type=int, default=2, help='min_samples_split')
    args = parser.parse_args()
    print(f"Training with max_depth={args.max_depth}, min_samples_split={args.min_samples_split}")
    # Run training with the parsed hyperparameters
    run_training(max_depth=args.max_depth, min_samples_split=args.min_samples_split)


if __name__ == '__main__':
    main()
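

# Example invocation (a sketch; assumes data/train.csv, data/test.csv and the model/
# output directory already exist, and that wandb is authenticated):
#   python train.py --max_depth 10 --min_samples_split 4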