KaiquanMah commited on
Commit
9e1219a
·
verified ·
1 Parent(s): b1bb87f

Update predict.py

Browse files
Files changed (1) hide show
  1. predict.py +46 -10
predict.py CHANGED
@@ -1,14 +1,50 @@
1
  import argparse
 
 
2
 
3
- def main():
4
- parser = argparse.ArgumentParser()
5
- parser.add_argument('--model-path', type=str, required=True, help='Path to the trained model')
6
- parser.add_argument('--input-data', type=str, required=True, help='Path to input data for prediction')
7
- args = parser.parse_args()
8
 
9
- print(f"Loading model from {args.model_path}")
10
- print(f"Predicting on data from {args.input_data}")
11
- # Add prediction logic here
 
12
 
13
- if __name__ == '__main__':
14
- main()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import argparse
2
+ import pandas as pd
3
+ import joblib
4
 
 
 
 
 
 
5
 
6
+ def run_prediction():
7
+ input_path = "data/X_test_1st.csv"
8
+ output_path = "results/predictions.csv"
9
+ model_path = "model/best_model.pkl"
10
 
11
+ # Load data and model
12
+ df = pd.read_csv(input_path)
13
+ model = joblib.load(model_path)
14
+
15
+ # Preprocessing
16
+ features = [
17
+ 'product_category_1',
18
+ 'product_category_2',
19
+ 'user_depth',
20
+ 'age_level',
21
+ 'city_development_index',
22
+ 'var_1',
23
+ 'gender'
24
+ ]
25
+
26
+ X = df[features]
27
+ X = pd.get_dummies(X, columns=['gender'], drop_first=True)
28
+
29
+ # Predict
30
+ predictions = model.predict(X)
31
+
32
+ # Save predictions
33
+ df['predictions'] = predictions
34
+ df.to_csv(output_path, index=False)
35
+ print(f"Predictions saved to {output_path}")
36
+
37
+
38
+
39
+ # def main():
40
+ # parser = argparse.ArgumentParser()
41
+ # parser.add_argument('--model-path', type=str, required=True, help='Path to the trained model')
42
+ # parser.add_argument('--input-data', type=str, required=True, help='Path to input data for prediction')
43
+ # args = parser.parse_args()
44
+
45
+ # print(f"Loading model from {args.model_path}")
46
+ # print(f"Predicting on data from {args.input_data}")
47
+ # # Add prediction logic here
48
+
49
+ # if __name__ == '__main__':
50
+ # main()