Danielos100 committed on
Commit
64b2f09
·
verified ·
1 Parent(s): 6427279

Create app.py

Files changed (1)
  1. app.py +283 -0
app.py ADDED
@@ -0,0 +1,283 @@
# app.py
# 🎁 Gift Recommender – Gradio app (English / USD)
# Dataset: ckandemir/amazon-products (Hugging Face)
# Baseline retrieval: TF-IDF + cosine (fast & dependency-light)
# Optional: switch to embeddings + FAISS by flipping USE_EMBEDDINGS to True.

import os
import re
import random
from typing import Dict

import numpy as np
import pandas as pd
from datasets import load_dataset
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.neighbors import NearestNeighbors
import gradio as gr

# ========= Configuration =========
USE_EMBEDDINGS = False  # set True to try SentenceTransformers + FAISS (see the optional block below)
MAX_ROWS = int(os.getenv("MAX_ROWS", "5000"))  # cap catalog size for speed
DEFAULT_OCCASIONS = "birthday, thank_you, housewarming"

# ========= Data Loading & Schema =========
def _to_price_usd(x):
    s = str(x).strip()
    s = s.replace("$", "").replace(",", "")
    try:
        return float(s)
    except Exception:
        return np.nan

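# For example: _to_price_usd("$59.00") -> 59.0, _to_price_usd("1,299.99") -> 1299.99,
# and non-numeric input (e.g. "N/A") -> NaN.
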
def map_amazon_to_schema(df_raw: pd.DataFrame) -> pd.DataFrame:
    # Normalize column lookup (case-insensitive)
    cols = {c.lower().strip(): c for c in df_raw.columns}

    # Fetch a source column by its lower-cased name (empty string if missing)
    get = lambda key: df_raw.get(cols.get(key, ""), "")

    out = pd.DataFrame({
        "name": get("product name"),
        "short_desc": get("description"),
        "tags": get("category"),
        "price_usd": get("selling price").map(_to_price_usd) if "selling price" in cols else np.nan,
        "age_range": "any",
        "gender_tags": "any",
        "occasion_tags": DEFAULT_OCCASIONS,
        "persona_fit": get("category"),
        "image_url": get("image") if "image" in cols else "",
    })

    # Basic cleaning
    out["name"] = out["name"].astype(str).str.strip().str.slice(0, 120)
    out["short_desc"] = out["short_desc"].astype(str).str.strip().str.slice(0, 400)
    out["tags"] = out["tags"].astype(str).str.replace("|", ", ", regex=False).str.lower()
    out["persona_fit"] = out["persona_fit"].astype(str).str.lower()
    return out

def build_doc(row: pd.Series) -> str:
    parts = [
        str(row.get("name", "")),
        str(row.get("short_desc", "")),
        str(row.get("tags", "")),
        str(row.get("persona_fit", "")),
        str(row.get("occasion_tags", "")),
    ]
    return " | ".join([p for p in parts if p])

def load_catalog() -> pd.DataFrame:
    # Load the HF dataset (internet required in the Space). If it fails, build a tiny fallback.
    try:
        ds = load_dataset("ckandemir/amazon-products", split="train")
        raw = ds.to_pandas()
    except Exception:
        # Minimal fallback (keeps the app alive even without internet)
        raw = pd.DataFrame(
            {
                "Product Name": ["Wireless Earbuds", "Coffee Sampler", "Strategy Board Game"],
                "Description": [
                    "Compact earbuds with noise isolation and long battery life.",
                    "Four single-origin roasts from small roasters.",
                    "Modern eurogame for 2–4 players, 45–60 minutes."
                ],
                "Category": ["Electronics | Audio", "Grocery | Coffee", "Toys & Games | Board Games"],
                "Selling Price": ["$59.00", "$34.00", "$39.00"],
                "Image": ["", "", ""],
            }
        )

    df = map_amazon_to_schema(raw).drop_duplicates(subset=["name", "short_desc"])
    if len(df) > MAX_ROWS:
        df = df.sample(n=MAX_ROWS, random_state=42)
    # Reset to a clean RangeIndex so row labels match the row positions of the TF-IDF matrix below.
    df = df.reset_index(drop=True)
    df["doc"] = df.apply(build_doc, axis=1)
    return df

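# The returned frame has the schema columns (name, short_desc, tags, price_usd, age_range,
# gender_tags, occasion_tags, persona_fit, image_url) plus "doc", the text used for retrieval.
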
CATALOG = load_catalog()

# ========= Retrieval (Baseline: TF-IDF) =========
_vectorizer = TfidfVectorizer(min_df=1, ngram_range=(1, 2))
_X = _vectorizer.fit_transform(CATALOG["doc"].fillna(""))
_nn = NearestNeighbors(n_neighbors=10, metric="cosine").fit(_X)

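# Illustrative sketch (not used by the app): the index can be queried directly with an ad-hoc string.
# qv = _vectorizer.transform(["coffee gift for a remote worker"])
# dists, inds = _nn.kneighbors(qv, n_neighbors=3)  # cosine distance; similarity = 1 - distance
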
def profile_to_query(profile: Dict) -> str:
    interests = ", ".join(profile.get("interests", []))
    occasion = profile.get("occasion", "")
    budget = profile.get("budget_usd", "")
    extras = profile.get("extras", "")
    return f"{interests}. occasion: {occasion}. budget: {budget} USD. {extras}".strip()

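# For example, {"interests": ["music", "fitness"], "occasion": "birthday", "budget_usd": 60}
# becomes the query string "music, fitness. occasion: birthday. budget: 60 USD."
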
def filter_business(df: pd.DataFrame, budget_min=None, budget_max=None, occasion: str = None) -> pd.DataFrame:
    m = pd.Series(True, index=df.index)
    if budget_min is not None:
        m &= df["price_usd"].fillna(0) >= float(budget_min)
    if budget_max is not None:
        m &= df["price_usd"].fillna(1e9) <= float(budget_max)
    if occasion:
        # case-insensitive contains in occasion_tags
        pattern = re.escape(str(occasion))
        m &= df["occasion_tags"].fillna("").str.contains(pattern, case=False, regex=True)
    return df[m]

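# For example, filter_business(CATALOG, 20, 60, "birthday") keeps rows priced 20-60 USD whose
# occasion_tags contain "birthday" (true for every row by default via DEFAULT_OCCASIONS).
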
def recommend_topk(profile: Dict, k: int = 3) -> pd.DataFrame:
    q = profile_to_query(profile)
    q_vec = _vectorizer.transform([q])

    df_f = filter_business(
        CATALOG,
        profile.get("budget_min"),
        profile.get("budget_max"),
        profile.get("occasion"),
    )
    if df_f.empty:
        df_f = CATALOG

    # kneighbors returns row positions into _X, which match CATALOG's RangeIndex.
    # Over-fetch candidates, then keep only rows that survived the business filter.
    allowed = set(df_f.index)
    n_cand = min(max(k * 10, k), len(CATALOG))
    dists, inds = _nn.kneighbors(q_vec, n_neighbors=n_cand)

    seen, picks = set(), []
    for ci, dist in zip(inds[0], dists[0]):
        if ci not in allowed:
            continue
        nm = CATALOG.loc[ci, "name"]
        if nm in seen:
            continue
        seen.add(nm)
        picks.append((ci, 1 - float(dist)))  # cosine similarity = 1 - cosine distance
        if len(picks) >= k:
            break

    if not picks:
        # None of the candidates passed the filter; fall back to the unfiltered top-k.
        picks = [(int(ci), 1 - float(dist)) for ci, dist in zip(inds[0][:k], dists[0][:k])]

    res = CATALOG.loc[[ci for ci, _ in picks]].copy()
    res["similarity"] = [sim for _, sim in picks]
    return res[["name", "short_desc", "price_usd", "occasion_tags", "persona_fit", "image_url", "similarity"]]

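# Illustrative call (arbitrary values):
# recommend_topk({"interests": ["coffee", "remote work"], "occasion": "birthday",
#                 "budget_min": 20, "budget_max": 60}, k=3)
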
# ========= Optional: Embeddings + FAISS (toggle USE_EMBEDDINGS=True) =========
# If you want to try embeddings, uncomment this block and flip the flag to True. This is optional.
# import faiss
# from sentence_transformers import SentenceTransformer
#
# _st_model = None
# _faiss_index = None
# _MODEL_BUILT = False
#
# def _build_embeddings_index(model_name="sentence-transformers/all-MiniLM-L6-v2"):
#     global _st_model, _faiss_index
#     _st_model = SentenceTransformer(model_name)
#     embs = _st_model.encode(CATALOG["doc"].tolist(), convert_to_numpy=True, normalize_embeddings=True)
#     _faiss_index = faiss.IndexFlatIP(embs.shape[1])  # inner product == cosine on normalized vectors
#     _faiss_index.add(embs)
#
# def recommend_topk_embeddings(profile: Dict, k: int = 3) -> pd.DataFrame:
#     global _MODEL_BUILT
#     if not _MODEL_BUILT:
#         _build_embeddings_index()
#         _MODEL_BUILT = True
#     query = profile_to_query(profile)
#     qv = _st_model.encode([query], convert_to_numpy=True, normalize_embeddings=True)
#     sims, idxs = _faiss_index.search(qv, min(max(k * 6, k), len(CATALOG)))
#     order = np.argsort(-sims[0])
#     picks = [int(idxs[0][i]) for i in order[:k]]
#     out = CATALOG.iloc[picks].copy()
#     out["similarity"] = sims[0][order][:k]
#     return out[["name", "short_desc", "price_usd", "occasion_tags", "persona_fit", "image_url", "similarity"]]

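# Note: enabling the block above assumes faiss-cpu and sentence-transformers are added to the
# Space's requirements.txt; neither is needed for the TF-IDF baseline.
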
# ========= Generative placeholders (synthetic idea + message) =========
def generate_item(profile: Dict) -> Dict:
    interests = profile.get("interests", [])
    occasion = profile.get("occasion", "birthday")
    budget = profile.get("budget_max", profile.get("budget_usd", 50)) or 50
    style = random.choice(["personalized", "experience", "bundle"])
    core = (interests[0] if interests else "hobby").strip()
    if style == "personalized":
        name = f"Custom {core} accessory with initials"
        desc = f"Thoughtful personalized {core} accessory tailored to their taste."
    elif style == "experience":
        name = f"{core.title()} workshop voucher"
        desc = f"A guided intro session to explore {core} in a fun, hands-on way."
    else:
        name = f"{core.title()} starter bundle"
        desc = f"A curated set to kickstart their {core} passion."
    return {
        "name": f"{name} ({occasion})",
        "short_desc": desc,
        "price_usd": float(np.clip(float(budget), 20, 200)),
        "occasion_tags": occasion,
        "persona_fit": ", ".join(interests) or "general",
        "image_url": ""
    }

def generate_message(profile: Dict, language: str = "en") -> str:
    name = profile.get("recipient_name", "Friend")
    occasion = profile.get("occasion", "birthday")
    tone = profile.get("tone", "warm and friendly")
    return (
        f"Dear {name},\n"
        f"Happy {occasion}! Wishing you health, joy, and a year full of great memories. "
        f"May your goals come true. With {tone} wishes."
    )

# ========= Gradio UI =========
EXAMPLES = [
    ["music, fitness", "birthday", 20, 60, "Noa", "warm and friendly"],
    ["coffee, remote work", "housewarming", 20, 40, "Daniel", "warm"],
    ["travel, design", "hanukkah", 20, 70, "Maya", "friendly"],
    ["photography, tech", "birthday", 30, 100, "Omer", "fun"],
    ["wellness, yoga", "thank_you", 15, 35, "Lior", "heartfelt"],
]

def ui_predict(interests: str, occasion: str, budget_min, budget_max, recipient_name: str, tone: str):
    profile = {
        "recipient_name": recipient_name or "Friend",
        "interests": [s.strip() for s in (interests or "").split(",") if s.strip()],
        "occasion": occasion or "birthday",
        "budget_min": float(budget_min) if budget_min not in (None, "") else None,
        "budget_max": float(budget_max) if budget_max not in (None, "") else None,
        "budget_usd": float(budget_max) if budget_max not in (None, "") else 50.0,
        "tone": tone or "warm and friendly",
    }

    # Retrieval
    if USE_EMBEDDINGS:
        # out_df = recommend_topk_embeddings(profile, k=3)
        # The template keeps the TF-IDF default; if you enable embeddings, uncomment the line above.
        out_df = recommend_topk(profile, k=3)
    else:
        out_df = recommend_topk(profile, k=3)

    # Generated idea + message
    gen = generate_item(profile)
    msg = generate_message(profile, language="en")

    # Present results (to_markdown requires the tabulate package)
    top3_md = out_df[["name", "short_desc", "price_usd", "similarity"]].to_markdown(index=False)
    gen_md = f"**{gen['name']}**\n\n{gen['short_desc']}\n\n~${gen['price_usd']:.0f}"
    return top3_md, gen_md, msg

with gr.Blocks() as demo:
    gr.Markdown("## 🎁 Gift Recommender – English / USD (Top-3 + 1 Generated + Message)")

    with gr.Row():
        interests = gr.Textbox(label="Interests (comma-separated)", value="music, fitness")
        occasion = gr.Textbox(label="Occasion", value="birthday")

    with gr.Row():
        budget_min = gr.Number(label="Budget min (USD)", value=20)
        budget_max = gr.Number(label="Budget max (USD)", value=60)

    with gr.Row():
        recipient_name = gr.Textbox(label="Recipient name", value="Noa")
        tone = gr.Textbox(label="Message tone", value="warm and friendly")

    go = gr.Button("Recommend 🎯")
    out_top3 = gr.Markdown(label="Top-3 recommendations")
    out_gen = gr.Markdown(label="Generated item")
    out_msg = gr.Markdown(label="Personalized message")

    gr.Examples(EXAMPLES, [interests, occasion, budget_min, budget_max, recipient_name, tone])
    go.click(ui_predict, [interests, occasion, budget_min, budget_max, recipient_name, tone],
             [out_top3, out_gen, out_msg])

# For Spaces
if __name__ == "__main__":
    demo.launch()