kurniawan commited on
Commit
a1bef06
·
1 Parent(s): 9b56554

Add cache check before downloading models

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +28 -20
src/streamlit_app.py CHANGED
@@ -3,6 +3,7 @@ import pickle
3
  import numpy as np
4
  import streamlit as st
5
  import gdown
 
6
 
7
  # File IDs
8
  model_id = "1HSQTjJ_hvBBmVJmYUmrkq5T7ubpfDwzF"
@@ -14,15 +15,19 @@ top_country_url = f"https://drive.google.com/uc?id={top_country_id}"
14
 
15
  @st.cache_resource
16
  def load_model():
17
- gdown.download(model_url, "best_rf_model.pkl", quiet=False)
18
- with open("best_rf_model.pkl", "rb") as f:
 
 
19
  return pickle.load(f)
20
 
21
 
22
  @st.cache_resource
23
  def load_top_country():
24
- gdown.download(top_country_url, "top_country.pkl", quiet=False)
25
- with open("top_country.pkl", "rb") as f:
 
 
26
  return pickle.load(f)
27
 
28
 
@@ -31,7 +36,8 @@ top_country = load_top_country()
31
 
32
  st.set_page_config(page_title="Hotel Booking Prediction", layout="wide")
33
 
34
- st.markdown("""
 
35
  <div style="
36
  background-color: white;
37
  padding: 50px;
@@ -51,14 +57,16 @@ st.markdown("""
51
  Fill in the form below to predict hotel booking!
52
  </p>
53
  </div>
54
- """, unsafe_allow_html=True)
 
 
55
 
56
  st.write("")
57
  st.write("")
58
 
59
  with st.form(key="hotel_bookings"):
60
  col1, col2 = st.columns(2)
61
-
62
  with col1:
63
  name = st.selectbox("Hotel Type", ("city_hotel", "resort_hotel"), index=0)
64
  lead = st.number_input(
@@ -88,7 +96,7 @@ with st.form(key="hotel_bookings"):
88
  ),
89
  index=0,
90
  )
91
-
92
  with col2:
93
  arrival_week = st.number_input(
94
  "Arrival Weeks",
@@ -108,29 +116,29 @@ with st.form(key="hotel_bookings"):
108
  )
109
 
110
  submitted = st.form_submit_button("Predict", use_container_width=True)
111
-
112
  if submitted:
113
  # Prepare data for prediction
114
  data = {
115
- 'hotel': name,
116
- 'lead_time': lead,
117
- 'arrival_date_year': int(arrival_year),
118
- 'arrival_date_month': arrival_month,
119
- 'arrival_date_week_number': arrival_week,
120
- 'arrival_date_day_of_month': arrival_day
121
  }
122
-
123
  df = pd.DataFrame([data])
124
-
125
  try:
126
  prediction = model.predict(df)
127
-
128
  st.success("Prediction Complete!")
129
-
130
  if prediction[0] == 1:
131
  st.error("⚠️ This booking is likely to be CANCELLED")
132
  else:
133
  st.success("✅ This booking is likely to be CONFIRMED")
134
-
135
  except Exception as e:
136
  st.error(f"Error making prediction: {str(e)}")
 
3
  import numpy as np
4
  import streamlit as st
5
  import gdown
6
+ import os
7
 
8
  # File IDs
9
  model_id = "1HSQTjJ_hvBBmVJmYUmrkq5T7ubpfDwzF"
 
15
 
16
  @st.cache_resource
17
  def load_model():
18
+ model_path = "best_rf_model.pkl"
19
+ if not os.path.exists(model_path):
20
+ gdown.download(model_url, model_path, quiet=False)
21
+ with open(model_path, "rb") as f:
22
  return pickle.load(f)
23
 
24
 
25
  @st.cache_resource
26
  def load_top_country():
27
+ country_path = "top_country.pkl"
28
+ if not os.path.exists(country_path):
29
+ gdown.download(top_country_url, country_path, quiet=False)
30
+ with open(country_path, "rb") as f:
31
  return pickle.load(f)
32
 
33
 
 
36
 
37
  st.set_page_config(page_title="Hotel Booking Prediction", layout="wide")
38
 
39
+ st.markdown(
40
+ """
41
  <div style="
42
  background-color: white;
43
  padding: 50px;
 
57
  Fill in the form below to predict hotel booking!
58
  </p>
59
  </div>
60
+ """,
61
+ unsafe_allow_html=True,
62
+ )
63
 
64
  st.write("")
65
  st.write("")
66
 
67
  with st.form(key="hotel_bookings"):
68
  col1, col2 = st.columns(2)
69
+
70
  with col1:
71
  name = st.selectbox("Hotel Type", ("city_hotel", "resort_hotel"), index=0)
72
  lead = st.number_input(
 
96
  ),
97
  index=0,
98
  )
99
+
100
  with col2:
101
  arrival_week = st.number_input(
102
  "Arrival Weeks",
 
116
  )
117
 
118
  submitted = st.form_submit_button("Predict", use_container_width=True)
119
+
120
  if submitted:
121
  # Prepare data for prediction
122
  data = {
123
+ "hotel": name,
124
+ "lead_time": lead,
125
+ "arrival_date_year": int(arrival_year),
126
+ "arrival_date_month": arrival_month,
127
+ "arrival_date_week_number": arrival_week,
128
+ "arrival_date_day_of_month": arrival_day,
129
  }
130
+
131
  df = pd.DataFrame([data])
132
+
133
  try:
134
  prediction = model.predict(df)
135
+
136
  st.success("Prediction Complete!")
137
+
138
  if prediction[0] == 1:
139
  st.error("⚠️ This booking is likely to be CANCELLED")
140
  else:
141
  st.success("✅ This booking is likely to be CONFIRMED")
142
+
143
  except Exception as e:
144
  st.error(f"Error making prediction: {str(e)}")