Starberry15 committed on
Commit
16e585d
·
verified ·
1 Parent(s): 20dc7a6

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +119 -68
src/streamlit_app.py CHANGED
@@ -7,14 +7,13 @@ import plotly.figure_factory as ff
7
  from dotenv import load_dotenv
8
  from huggingface_hub import InferenceClient, login
9
  from io import StringIO
10
- from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
11
 
12
  # ======================================================
13
  # ⚙️ APP CONFIGURATION
14
  # ======================================================
15
  st.set_page_config(page_title="📊 Smart Data Analyst Pro", layout="wide")
16
  st.title("📊 Smart Data Analyst Pro")
17
- st.caption("AI that cleans, analyzes, and visualizes your data — powered by Hugging Face Inference API and local open-source models.")
18
 
19
  # ======================================================
20
  # 🔐 Load Environment Variables
@@ -27,7 +26,7 @@ else:
27
  login(token=HF_TOKEN)
28
 
29
  # ======================================================
30
- # 🧠 MODEL SETTINGS
31
  # ======================================================
32
  with st.sidebar:
33
  st.header("⚙️ Model Settings")
@@ -37,18 +36,17 @@ with st.sidebar:
37
  [
38
  "Qwen/Qwen2.5-Coder-7B-Instruct",
39
  "meta-llama/Meta-Llama-3-8B-Instruct",
40
- "microsoft/Phi-3-mini-4k-instruct"
 
41
  ],
42
  index=0
43
  )
44
 
45
  ANALYST_MODEL = st.selectbox(
46
- "Select Analysis Model (Local Open-Source Recommended):",
47
- [
48
- "meta-llama/Meta-Llama-3-8B-Instruct",
49
- "Qwen/Qwen2.5-Coder-7B-Instruct",
50
- "HuggingFaceH4/zephyr-7b-beta",
51
- "mistralai/Mistral-7B-Instruct-v0.3"
52
  ],
53
  index=0
54
  )
@@ -56,91 +54,122 @@ with st.sidebar:
56
  temperature = st.slider("Temperature", 0.0, 1.0, 0.3)
57
  max_tokens = st.slider("Max Tokens", 128, 2048, 512)
58
 
59
- # Initialize cleaner client (HF API)
60
  cleaner_client = InferenceClient(model=CLEANER_MODEL, token=HF_TOKEN)
61
-
62
- # Initialize local analyst if open-source
63
- local_analyst = None
64
- if ANALYST_MODEL in ["meta-llama/Meta-Llama-3-8B-Instruct"]:
65
- try:
66
- tokenizer = AutoTokenizer.from_pretrained(ANALYST_MODEL)
67
- model = AutoModelForCausalLM.from_pretrained(ANALYST_MODEL)
68
- local_analyst = pipeline("text-generation", model=model, tokenizer=tokenizer)
69
- except Exception as e:
70
- st.warning(f"⚠️ Failed to load local analyst: {e}")
71
 
72
  # ======================================================
73
- # 🧩 DATA CLEANING FUNCTIONS
74
  # ======================================================
75
  def fallback_clean(df: pd.DataFrame) -> pd.DataFrame:
 
76
  df = df.copy()
77
  df.dropna(axis=1, how="all", inplace=True)
78
  df.columns = [c.strip().replace(" ", "_").lower() for c in df.columns]
79
  for col in df.columns:
80
  if df[col].dtype == "O":
81
- df[col].fillna(df[col].mode()[0] if not df[col].mode().empty else "Unknown", inplace=True)
 
 
 
82
  else:
83
  df[col].fillna(df[col].median(), inplace=True)
84
  df.drop_duplicates(inplace=True)
85
  return df
86
 
 
87
  def ai_clean_dataset(df: pd.DataFrame) -> pd.DataFrame:
 
 
 
88
  raw_preview = df.head(5).to_csv(index=False)
89
  prompt = f"""
90
- You are a Python data cleaning expert.
91
- Clean and standardize the dataset dynamically:
92
- - Handle missing values logically
93
- - Correct and normalize column names
94
- - Detect and fix datatype inconsistencies
95
- - Remove duplicates or invalid rows
96
- Return ONLY valid CSV text (no Markdown).
97
 
98
  --- RAW SAMPLE ---
99
  {raw_preview}
100
  """
 
101
  try:
102
- response = cleaner_client.text_generation(prompt, max_new_tokens=1024, temperature=0.1, return_full_text=False)
 
 
 
 
 
 
103
  cleaned_str = response.strip()
104
  except Exception as e:
105
- st.warning(f"⚠️ AI cleaning failed: {e}")
106
- return fallback_clean(df)
 
 
 
 
 
 
 
 
 
 
 
 
 
107
 
108
- cleaned_str = cleaned_str.replace("```csv","").replace("```","").replace("###","").replace(";",",").strip()
109
- lines = [l for l in cleaned_str.splitlines() if "," in l]
 
 
 
 
 
 
 
 
 
 
110
  cleaned_str = "\n".join(lines)
111
 
 
112
  try:
113
  cleaned_df = pd.read_csv(StringIO(cleaned_str), on_bad_lines="skip")
114
- cleaned_df.dropna(axis=1, how="all", inplace=True)
115
  cleaned_df.columns = [c.strip().replace(" ", "_").lower() for c in cleaned_df.columns]
116
  return cleaned_df
117
  except Exception as e:
118
- st.warning(f"⚠️ CSV parse failed: {e}")
119
  return fallback_clean(df)
120
 
 
121
  def summarize_dataframe(df: pd.DataFrame) -> str:
 
122
  lines = [f"Rows: {len(df)} | Columns: {len(df.columns)}", "Column summaries:"]
123
  for col in df.columns[:10]:
124
  non_null = int(df[col].notnull().sum())
125
  if pd.api.types.is_numeric_dtype(df[col]):
126
- mean = df[col].mean()
127
- median = df[col].median() if non_null > 0 else None
 
128
  lines.append(f"- {col}: mean={mean:.3f}, median={median}, non_null={non_null}")
129
  else:
130
  top = df[col].value_counts().head(3).to_dict()
131
  lines.append(f"- {col}: top_values={top}, non_null={non_null}")
132
  return "\n".join(lines)
133
 
134
- # ======================================================
135
- # 🧠 ANALYSIS FUNCTION
136
- # ======================================================
137
  def query_analysis_model(df: pd.DataFrame, user_query: str, dataset_name: str) -> str:
 
138
  df_summary = summarize_dataframe(df)
139
  sample = df.head(6).to_csv(index=False)
140
  prompt = f"""
141
- You are a data analyst.
142
- Analyze '{dataset_name}' and answer the question below.
143
- Base your insights only on the provided data.
144
 
145
  --- SUMMARY ---
146
  {df_summary}
@@ -148,32 +177,41 @@ Base your insights only on the provided data.
148
  --- SAMPLE DATA ---
149
  {sample}
150
 
151
- --- QUESTION ---
152
  {user_query}
153
 
154
- Respond concisely with key insights, numbers, patterns, and recommended steps.
 
 
 
 
155
  """
156
- if local_analyst:
157
- try:
158
- response = local_analyst(prompt, max_new_tokens=max_tokens, temperature=temperature)
159
- return response[0]['generated_text']
160
- except Exception as e:
161
- return f"⚠️ Local analysis failed: {e}"
162
- else:
163
- st.warning("⚠️ Analyst model is not local. Using HF API may require payment.")
164
- return "Analysis not available for free model."
 
 
 
 
 
 
 
 
 
165
 
166
  # ======================================================
167
- # 🚀 MAIN APP
168
  # ======================================================
169
  uploaded = st.file_uploader("📎 Upload CSV or Excel file", type=["csv", "xlsx"])
170
 
171
  if uploaded:
172
- try:
173
- df = pd.read_csv(uploaded) if uploaded.name.endswith(".csv") else pd.read_excel(uploaded)
174
- except Exception as e:
175
- st.error(f"❌ File load failed: {e}")
176
- st.stop()
177
 
178
  with st.spinner("🧼 AI Cleaning your dataset..."):
179
  cleaned_df = ai_clean_dataset(df)
@@ -181,33 +219,46 @@ if uploaded:
181
  st.subheader("✅ Cleaned Dataset Preview")
182
  st.dataframe(cleaned_df.head(), use_container_width=True)
183
 
184
- with st.expander("📋 Cleaning Summary"):
185
  st.text(summarize_dataframe(cleaned_df))
186
 
187
  with st.expander("📈 Quick Visualizations", expanded=True):
188
  numeric_cols = cleaned_df.select_dtypes(include="number").columns.tolist()
189
  categorical_cols = cleaned_df.select_dtypes(exclude="number").columns.tolist()
190
- viz_type = st.selectbox("Visualization Type", ["Scatter Plot", "Histogram", "Box Plot", "Correlation Heatmap", "Categorical Count"])
 
 
 
 
191
 
192
  if viz_type == "Scatter Plot" and len(numeric_cols) >= 2:
193
  x = st.selectbox("X-axis", numeric_cols)
194
- y = st.selectbox("Y-axis", numeric_cols, index=min(1,len(numeric_cols)-1))
195
  color = st.selectbox("Color", ["None"] + categorical_cols)
196
  fig = px.scatter(cleaned_df, x=x, y=y, color=None if color=="None" else color)
197
  st.plotly_chart(fig, use_container_width=True)
 
198
  elif viz_type == "Histogram" and numeric_cols:
199
  col = st.selectbox("Column", numeric_cols)
200
  fig = px.histogram(cleaned_df, x=col, nbins=30)
201
  st.plotly_chart(fig, use_container_width=True)
 
202
  elif viz_type == "Box Plot" and numeric_cols:
203
  col = st.selectbox("Column", numeric_cols)
204
  fig = px.box(cleaned_df, y=col)
205
  st.plotly_chart(fig, use_container_width=True)
 
206
  elif viz_type == "Correlation Heatmap" and len(numeric_cols) > 1:
207
  corr = cleaned_df[numeric_cols].corr()
208
- fig = ff.create_annotated_heatmap(z=corr.values, x=list(corr.columns), y=list(corr.index),
209
- annotation_text=corr.round(2).values, showscale=True)
 
 
 
 
 
210
  st.plotly_chart(fig, use_container_width=True)
 
211
  elif viz_type == "Categorical Count" and categorical_cols:
212
  cat = st.selectbox("Category", categorical_cols)
213
  fig = px.bar(cleaned_df[cat].value_counts().reset_index(), x="index", y=cat)
@@ -216,7 +267,7 @@ if uploaded:
216
  st.warning("⚠️ Not enough columns for this visualization type.")
217
 
218
  st.subheader("💬 Ask AI About Your Data")
219
- user_query = st.text_area("Enter your question:", placeholder="e.g. What factors influence sales?")
220
  if st.button("Analyze with AI", use_container_width=True) and user_query:
221
  with st.spinner("🤖 Interpreting data..."):
222
  result = query_analysis_model(cleaned_df, user_query, uploaded.name)
 
7
  from dotenv import load_dotenv
8
  from huggingface_hub import InferenceClient, login
9
  from io import StringIO
 
10
 
11
  # ======================================================
12
  # ⚙️ APP CONFIGURATION
13
  # ======================================================
14
  st.set_page_config(page_title="📊 Smart Data Analyst Pro", layout="wide")
15
  st.title("📊 Smart Data Analyst Pro")
16
+ st.caption("AI that cleans, analyzes, and visualizes your data — powered by Hugging Face Inference API.")
17
 
18
  # ======================================================
19
  # 🔐 Load Environment Variables
 
26
  login(token=HF_TOKEN)
27
 
28
  # ======================================================
29
+ # 🧠 MODEL SETUP
30
  # ======================================================
31
  with st.sidebar:
32
  st.header("⚙️ Model Settings")
 
36
  [
37
  "Qwen/Qwen2.5-Coder-7B-Instruct",
38
  "meta-llama/Meta-Llama-3-8B-Instruct",
39
+ "microsoft/Phi-3-mini-4k-instruct",
40
+ "mistralai/Mistral-7B-Instruct-v0.3"
41
  ],
42
  index=0
43
  )
44
 
45
  ANALYST_MODEL = st.selectbox(
46
+ "Select Analysis Model:",
47
+ [ "Qwen/Qwen2.5-14B-Instruct",
48
+ "mistralai/Mistral-7B-Instruct-v0.3",
49
+ "HuggingFaceH4/zephyr-7b-beta"
 
 
50
  ],
51
  index=0
52
  )
 
54
  temperature = st.slider("Temperature", 0.0, 1.0, 0.3)
55
  max_tokens = st.slider("Max Tokens", 128, 2048, 512)
56
 
57
+ # Initialize inference clients
58
  cleaner_client = InferenceClient(model=CLEANER_MODEL, token=HF_TOKEN)
59
+ analyst_client = InferenceClient(model=ANALYST_MODEL, token=HF_TOKEN)
 
 
 
 
 
 
 
 
 
60
 
61
  # ======================================================
62
+ # 🧩 SMART DATA CLEANING
63
  # ======================================================
64
  def fallback_clean(df: pd.DataFrame) -> pd.DataFrame:
65
+ """Backup rule-based cleaner."""
66
  df = df.copy()
67
  df.dropna(axis=1, how="all", inplace=True)
68
  df.columns = [c.strip().replace(" ", "_").lower() for c in df.columns]
69
  for col in df.columns:
70
  if df[col].dtype == "O":
71
+ if not df[col].mode().empty:
72
+ df[col].fillna(df[col].mode()[0], inplace=True)
73
+ else:
74
+ df[col].fillna("Unknown", inplace=True)
75
  else:
76
  df[col].fillna(df[col].median(), inplace=True)
77
  df.drop_duplicates(inplace=True)
78
  return df
79
 
80
+
81
  def ai_clean_dataset(df: pd.DataFrame) -> pd.DataFrame:
82
+ """
83
+ Cleans the dataset using the selected AI model. Falls back gracefully if the model fails.
84
+ """
85
  raw_preview = df.head(5).to_csv(index=False)
86
  prompt = f"""
87
+ You are a professional data cleaning assistant.
88
+ Clean and standardize the dataset below dynamically:
89
+ 1. Handle missing values
90
+ 2. Fix column name inconsistencies
91
+ 3. Convert data types (dates, numbers, categories)
92
+ 4. Remove irrelevant or duplicate rows
93
+ Return ONLY a valid CSV text (no markdown, no explanations).
94
 
95
  --- RAW SAMPLE ---
96
  {raw_preview}
97
  """
98
+
99
  try:
100
+ # Try text-generation task first
101
+ response = cleaner_client.text_generation(
102
+ prompt,
103
+ max_new_tokens=1024,
104
+ temperature=0.1,
105
+ return_full_text=False,
106
+ )
107
  cleaned_str = response.strip()
108
  except Exception as e:
109
+ # Retry with chat completion if needed
110
+ if "Supported task: conversational" in str(e) or "not supported" in str(e):
111
+ try:
112
+ chat_resp = cleaner_client.chat_completion(
113
+ messages=[{"role": "user", "content": prompt}],
114
+ max_tokens=1024,
115
+ temperature=0.1,
116
+ )
117
+ cleaned_str = chat_resp["choices"][0]["message"]["content"].strip()
118
+ except Exception as e2:
119
+ st.warning(f"⚠️ AI cleaning failed (chat mode): {e2}")
120
+ return fallback_clean(df)
121
+ else:
122
+ st.warning(f"⚠️ AI cleaning failed ({e})")
123
+ return fallback_clean(df)
124
 
125
+ # Remove possible markdown/code fences
126
+ cleaned_str = (
127
+ cleaned_str.replace("```csv", "")
128
+ .replace("```", "")
129
+ .replace("###", "")
130
+ .replace(";", ",")
131
+ .strip()
132
+ )
133
+
134
+ # Keep only valid CSV-like lines
135
+ lines = cleaned_str.splitlines()
136
+ lines = [line for line in lines if "," in line and not line.lower().startswith(("note", "summary"))]
137
  cleaned_str = "\n".join(lines)
138
 
139
+ # Try parsing robustly
140
  try:
141
  cleaned_df = pd.read_csv(StringIO(cleaned_str), on_bad_lines="skip")
142
+ cleaned_df = cleaned_df.dropna(axis=1, how="all")
143
  cleaned_df.columns = [c.strip().replace(" ", "_").lower() for c in cleaned_df.columns]
144
  return cleaned_df
145
  except Exception as e:
146
+ st.warning(f"⚠️ AI CSV parse failed: {e}")
147
  return fallback_clean(df)
148
 
149
+
150
  def summarize_dataframe(df: pd.DataFrame) -> str:
151
+ """Generate a concise summary of the dataframe."""
152
  lines = [f"Rows: {len(df)} | Columns: {len(df.columns)}", "Column summaries:"]
153
  for col in df.columns[:10]:
154
  non_null = int(df[col].notnull().sum())
155
  if pd.api.types.is_numeric_dtype(df[col]):
156
+ desc = df[col].describe().to_dict()
157
+ mean = float(desc.get("mean", np.nan))
158
+ median = float(df[col].median()) if non_null > 0 else None
159
  lines.append(f"- {col}: mean={mean:.3f}, median={median}, non_null={non_null}")
160
  else:
161
  top = df[col].value_counts().head(3).to_dict()
162
  lines.append(f"- {col}: top_values={top}, non_null={non_null}")
163
  return "\n".join(lines)
164
 
165
+
 
 
166
  def query_analysis_model(df: pd.DataFrame, user_query: str, dataset_name: str) -> str:
167
+ """Send the dataframe and user query to the analysis model for interpretation."""
168
  df_summary = summarize_dataframe(df)
169
  sample = df.head(6).to_csv(index=False)
170
  prompt = f"""
171
+ You are a professional data analyst.
172
+ Analyze the dataset '{dataset_name}' and answer the user's question.
 
173
 
174
  --- SUMMARY ---
175
  {df_summary}
 
177
  --- SAMPLE DATA ---
178
  {sample}
179
 
180
+ --- USER QUESTION ---
181
  {user_query}
182
 
183
+ Respond with:
184
+ 1. Key insights and patterns
185
+ 2. Quantitative findings
186
+ 3. Notable relationships or anomalies
187
+ 4. Data-driven recommendations
188
  """
189
+ try:
190
+ response = analyst_client.text_generation(
191
+ prompt, temperature=temperature, max_new_tokens=max_tokens, return_full_text=False
192
+ )
193
+ return response.strip()
194
+ except Exception as e:
195
+ if "Supported task: conversational" in str(e) or "not supported" in str(e):
196
+ try:
197
+ chat_resp = analyst_client.chat_completion(
198
+ messages=[{"role": "user", "content": prompt}],
199
+ max_tokens=max_tokens,
200
+ temperature=temperature,
201
+ )
202
+ return chat_resp["choices"][0]["message"]["content"].strip()
203
+ except Exception as e2:
204
+ return f"⚠️ Analysis failed (chat mode): {e2}"
205
+ return f"⚠️ Analysis failed: {e}"
206
+
207
 
208
  # ======================================================
209
+ # 🚀 MAIN APP LOGIC
210
  # ======================================================
211
  uploaded = st.file_uploader("📎 Upload CSV or Excel file", type=["csv", "xlsx"])
212
 
213
  if uploaded:
214
+ df = pd.read_csv(uploaded) if uploaded.name.endswith(".csv") else pd.read_excel(uploaded)
 
 
 
 
215
 
216
  with st.spinner("🧼 AI Cleaning your dataset..."):
217
  cleaned_df = ai_clean_dataset(df)
 
219
  st.subheader("✅ Cleaned Dataset Preview")
220
  st.dataframe(cleaned_df.head(), use_container_width=True)
221
 
222
+ with st.expander("📋 Cleaning Summary", expanded=False):
223
  st.text(summarize_dataframe(cleaned_df))
224
 
225
  with st.expander("📈 Quick Visualizations", expanded=True):
226
  numeric_cols = cleaned_df.select_dtypes(include="number").columns.tolist()
227
  categorical_cols = cleaned_df.select_dtypes(exclude="number").columns.tolist()
228
+
229
+ viz_type = st.selectbox(
230
+ "Visualization Type",
231
+ ["Scatter Plot", "Histogram", "Box Plot", "Correlation Heatmap", "Categorical Count"]
232
+ )
233
 
234
  if viz_type == "Scatter Plot" and len(numeric_cols) >= 2:
235
  x = st.selectbox("X-axis", numeric_cols)
236
+ y = st.selectbox("Y-axis", numeric_cols, index=min(1, len(numeric_cols)-1))
237
  color = st.selectbox("Color", ["None"] + categorical_cols)
238
  fig = px.scatter(cleaned_df, x=x, y=y, color=None if color=="None" else color)
239
  st.plotly_chart(fig, use_container_width=True)
240
+
241
  elif viz_type == "Histogram" and numeric_cols:
242
  col = st.selectbox("Column", numeric_cols)
243
  fig = px.histogram(cleaned_df, x=col, nbins=30)
244
  st.plotly_chart(fig, use_container_width=True)
245
+
246
  elif viz_type == "Box Plot" and numeric_cols:
247
  col = st.selectbox("Column", numeric_cols)
248
  fig = px.box(cleaned_df, y=col)
249
  st.plotly_chart(fig, use_container_width=True)
250
+
251
  elif viz_type == "Correlation Heatmap" and len(numeric_cols) > 1:
252
  corr = cleaned_df[numeric_cols].corr()
253
+ fig = ff.create_annotated_heatmap(
254
+ z=corr.values,
255
+ x=list(corr.columns),
256
+ y=list(corr.index),
257
+ annotation_text=corr.round(2).values,
258
+ showscale=True
259
+ )
260
  st.plotly_chart(fig, use_container_width=True)
261
+
262
  elif viz_type == "Categorical Count" and categorical_cols:
263
  cat = st.selectbox("Category", categorical_cols)
264
  fig = px.bar(cleaned_df[cat].value_counts().reset_index(), x="index", y=cat)
 
267
  st.warning("⚠️ Not enough columns for this visualization type.")
268
 
269
  st.subheader("💬 Ask AI About Your Data")
270
+ user_query = st.text_area("Enter your question:", placeholder="e.g. What factors influence sales the most?")
271
  if st.button("Analyze with AI", use_container_width=True) and user_query:
272
  with st.spinner("🤖 Interpreting data..."):
273
  result = query_analysis_model(cleaned_df, user_query, uploaded.name)