import time
from html import escape
from pathlib import Path
from typing import Callable, Dict, List, Sequence, Tuple

import gradio as gr

from config import (
    GRADIO_THEME,
    CUSTOM_CSS,
    EXAMPLE_QUERIES_BY_CATEGORY,
)
from src.search.bm25_lexical_search import search_bm25
from src.search.vector_search import (
    search_vector,
    search_vector_with_expansion,
    search_vector_with_reranking,
)
from src.data_prep.data_prep import load_clean_amazon_product_data

# Parent of the directory containing this file; static assets live under it.
_FILE_PATH = Path(__file__).parents[1]

# Fixed divisor that maps raw BM25 scores toward a 0-1 display range;
# search_stage_1 clamps the divided values at 1.0.
_BM25_SCORE_DIVISOR = 5.0

# Characters of description shown per result card / per sample-table row.
_RESULT_DESCRIPTION_CHARS = 150
_SAMPLE_DESCRIPTION_CHARS = 80

# Labels for the four pipeline stages. Shared by the result tabs and the
# comparison table so both use the same stage numbering.
STAGE_LABELS = [
    "Stage 1: BM25 Baseline",
    "Stage 2: + Vector Embeddings",
    "Stage 3: + Query Expansion",
    "Stage 4: + LLM Reranking",
]

# Example pre-loaded at start-up; the dropdowns and the query box both use it
# so the UI never shows a category that does not match the loaded query.
DEFAULT_CATEGORY = "Baby Products"
DEFAULT_AMBIGUITY = "Ambiguous"


def _truncate(text: str, limit: int) -> str:
    """Return *text* cut to *limit* characters, adding "..." only if it was cut."""
    return text[:limit] + "..." if len(text) > limit else text


def format_results(results: List[Dict], stage_name: str, metrics: Dict) -> str:
    """Render one stage's metrics and ranked results for a ``gr.Markdown``.

    Args:
        results: List of dicts with keys: product_name, description,
            main_category, secondary_category, score.
        stage_name: Heading shown above the results.
        metrics: Dict with keys: top1_score, top5_avg, latency_ms.

    Returns:
        A Markdown heading followed by HTML metric and result cards. Product
        text is HTML-escaped because it is interpolated into markup.
    """
    # NOTE(review): this markup is plain, class-less HTML. If config.CUSTOM_CSS
    # defines card/metric styles, add the matching class attributes here
    # (the class names are not visible from this file).
    html_parts = [
        f"## 🔍 {stage_name}\n\n",
        # Performance metrics at the top with prominent styling
        f"""<div>
<div><small>TOP-1 SCORE</small><br><strong>{metrics['top1_score']:.3f}</strong><br><small>Best result</small></div>
<div><small>TOP-5 AVG</small><br><strong>{metrics['top5_avg']:.3f}</strong><br><small>Overall quality</small></div>
<div><small>LATENCY</small><br><strong>{metrics['latency_ms']}ms</strong><br><small>Response time</small></div>
</div>
""",
        # Blank line ends the HTML block so each card renders independently.
        "\n\n",
    ]

    # Results section
    for idx, result in enumerate(results, 1):
        name = escape(str(result["product_name"]))
        # Truncate before escaping so an HTML entity is never cut in half.
        description = escape(
            _truncate(str(result.get("description") or ""), _RESULT_DESCRIPTION_CHARS)
        )
        category = escape(
            f"{result.get('main_category', 'N/A')} > {result.get('secondary_category', 'N/A')}"
        )
        html_parts.append(
            f"""<div>
<div><strong>{idx}. {name}</strong></div>
<div>{description}</div>
<div><small>Category: {category}</small></div>
</div>

"""
        )
    return "".join(html_parts)


def run_search_function_and_time(
    query: str, func: Callable, top_n: int = 5
) -> Tuple[List[Dict], int]:
    """Run ``func(query)`` and measure how long it took.

    Args:
        query: Search query passed to *func*.
        func: Search function returning a list of result dicts.
        top_n: Number of leading results to keep.

    Returns:
        ``(results[:top_n], latency_ms)``, with latency truncated to whole
        milliseconds.
    """
    # perf_counter is monotonic and high-resolution; time.time() can jump.
    start = time.perf_counter()
    results = func(query)
    latency = int((time.perf_counter() - start) * 1000)
    return results[:top_n], latency


def get_average_score(results: List[Dict]) -> float:
    """Return the mean ``score`` of *results*, or 0.0 when empty."""
    return sum(r["score"] for r in results) / len(results) if results else 0.0


def get_weighted_score(
    results: List[Dict], weights: Sequence[float] = (5, 4, 3, 2, 1)
) -> float:
    """Calculate a position-weighted average score.

    Top positions get higher weight (5x for #1, 4x for #2, ... by default),
    which rewards putting the best results at the top. Only the first
    ``len(weights)`` results are used, and the result is normalized by the
    weights actually applied, so fewer results than weights is not penalized.

    Args:
        results: List of search results with a 'score' field.
        weights: Per-position weights, highest rank first.

    Returns:
        Weighted average on the same scale as the input scores; 0.0 when
        there are no results or the applied weights sum to zero.
    """
    pairs = list(zip(weights, results))
    total_weight = sum(weight for weight, _ in pairs)
    if not pairs or total_weight == 0:
        return 0.0
    weighted_sum = sum(weight * r["score"] for weight, r in pairs)
    return weighted_sum / total_weight


def _build_metrics(
    results: List[Dict],
    latency_ms: int,
    score_divisor: float = 1.0,
    clamp_to_one: bool = False,
) -> Dict:
    """Summarize a stage's results into the metrics dict shown in the UI.

    Args:
        results: Top-N results, each with a numeric ``score``.
        latency_ms: Measured search latency in milliseconds.
        score_divisor: Divisor applied to raw scores (1.0 leaves them as-is).
        clamp_to_one: Cap both score metrics at 1.0 after dividing.

    Returns:
        Dict with ``top1_score``, ``top5_avg`` and ``latency_ms``. The score
        metrics are 0.0 when *results* is empty.
    """
    top1_score = results[0]["score"] / score_divisor if results else 0.0
    top5_avg = get_average_score(results) / score_divisor
    if clamp_to_one:
        top1_score = min(1.0, top1_score)
        top5_avg = min(1.0, top5_avg)
    return {
        "top1_score": top1_score,
        "top5_avg": top5_avg,
        "latency_ms": latency_ms,
    }


def search_stage_1(query: str) -> Tuple[str, Dict]:
    """Stage 1: Baseline BM25 keyword search.

    Raw BM25 scores are divided by ``_BM25_SCORE_DIVISOR`` and clamped at 1.0
    so they can be displayed alongside the other stages' scores.
    """
    results, latency = run_search_function_and_time(query, search_bm25)
    metrics = _build_metrics(
        results, latency, score_divisor=_BM25_SCORE_DIVISOR, clamp_to_one=True
    )
    print(f"Searched BM25 for {query} in {latency}ms")
    return format_results(results, "Stage 1: BM25 Baseline", metrics), metrics


def search_stage_2(query: str) -> Tuple[str, Dict]:
    """Stage 2: Vector Embeddings using FAISS."""
    results, latency = run_search_function_and_time(query, search_vector)
    metrics = _build_metrics(results, latency)
    print(f"Searched vector embeddings for '{query}' in {latency}ms")
    return format_results(results, "Stage 2: Vector Embeddings", metrics), metrics


def search_stage_3(query: str) -> Tuple[str, Dict]:
    """Stage 3: Query Expansion + Vector Embeddings."""
    results, latency = run_search_function_and_time(query, search_vector_with_expansion)
    metrics = _build_metrics(results, latency)
    print(f"Searched vector embeddings with query expansion for '{query}' in {latency}ms")
    return format_results(results, "Stage 3: Query Expansion", metrics), metrics


def search_stage_4(query: str) -> Tuple[str, Dict]:
    """Stage 4: Query Expansion + Vector Embeddings + Reranking."""
    results, latency = run_search_function_and_time(query, search_vector_with_reranking)
    metrics = _build_metrics(results, latency)
    print(f"Searched vector embeddings with expansion + reranking for '{query}' in {latency}ms")
    return format_results(results, "Stage 4: Reranking", metrics), metrics


def search_all_stages(query: str) -> Tuple[str, str, str, str, str]:
    """Run *query* through all four stages and build the comparison view.

    Returns:
        Five strings: the rendered results of stages 1-4, then the comparison
        table. For a missing or blank query, all five are a prompt message.
    """
    # Guard against None as well as whitespace-only input.
    if not query or not query.strip():
        empty_msg = "Please enter a search query."
        return empty_msg, empty_msg, empty_msg, empty_msg, empty_msg

    rendered: List[str] = []
    all_metrics: List[Dict] = []
    for stage_fn in (search_stage_1, search_stage_2, search_stage_3, search_stage_4):
        stage_html, stage_metrics = stage_fn(query)
        rendered.append(stage_html)
        all_metrics.append(stage_metrics)
    return (*rendered, generate_comparison_table(all_metrics))


def calculate_improvement(metric1, metric2, metric_name):
    """Return the percentage change of ``metric1[metric_name]`` over ``metric2[metric_name]``.

    When the baseline (``metric2``) value is 0, a relative change is undefined,
    so the absolute difference x 100 is returned instead.
    """
    if metric2[metric_name] == 0:
        return (metric1[metric_name] - metric2[metric_name]) * 100
    return (metric1[metric_name] - metric2[metric_name]) / metric2[metric_name] * 100


def _describe_change(percent: float) -> str:
    """Phrase a signed percentage as 'increases by **N%**' or 'decreases by **N%**'."""
    verb = "increases" if percent >= 0 else "decreases"
    return f"{verb} by **{abs(percent):.0f}%**"


_RERANKER_NOTE_HTML = """
<div>
<h3>💡 Understanding Reranker Scores</h3>
<p><strong>Note:</strong> You may notice that reranking shows the same cosine similarity scores as the earlier vector-search stages despite improved result ordering. This is intentional and highlights an important concept:</p>
<ul>
<li>The reranker changes the <em>order</em> of the candidates retrieved by vector search; it does not replace their scores.</li>
<li>The score shown for each result is still its retrieval cosine similarity, so Top-1 / Top-5 measure retrieval similarity rather than the reranker's relevance judgment.</li>
</ul>
<p>This two-stage approach (retrieve + rerank) is the industry standard for building high-quality search systems at scale.</p>
</div>
"""


def generate_comparison_table(all_metrics: List[Dict]) -> str:
    """Build the Markdown comparison of all stages plus summary insights.

    Args:
        all_metrics: Metrics dicts (``top1_score``, ``top5_avg``,
            ``latency_ms``) in stage order; the first is the BM25 baseline.

    Returns:
        Markdown/HTML for the "Compare All Stages" tab.
    """
    if not all_metrics:
        return "## Stage-by-Stage Comparison\n\nNo results to compare."

    # Build markdown table
    parts = [
        "## Stage-by-Stage Comparison\n\n",
        "| Stage | Top-1 Score | Top-5 Avg | Latency (ms) |\n",
        "|-------|-------------|-----------|-------------|\n",
    ]
    for name, metrics in zip(STAGE_LABELS, all_metrics):
        parts.append(
            f"| **{name}** | {metrics['top1_score']:.3f} | {metrics['top5_avg']:.3f} | {metrics['latency_ms']} |\n"
        )

    # Final stage vs. the baseline (first entry).
    baseline, final = all_metrics[0], all_metrics[-1]
    top1_change = calculate_improvement(final, baseline, "top1_score")
    top5_change = calculate_improvement(final, baseline, "top5_avg")
    max_latency = max(m["latency_ms"] for m in all_metrics)

    parts += [
        "\n---\n\n",
        "## Key Insights\n\n",
        f"- **Relevance Change (final stage vs. BM25)**: Top-1 score {_describe_change(top1_change)}, Top-5 average {_describe_change(top5_change)}\n",
        f"- **Production-Ready Performance**: Every stage completes within **{max_latency}ms**\n",
        "- **Semantic Understanding**: Vector embeddings provide the largest single improvement in search quality\n",
        "- **Progressive Enhancement**: Each stage builds upon the previous, creating a robust pipeline\n",
        "- **Real-World Applicability**: This architecture scales to millions of documents with proper infrastructure\n",
        "\n\n---\n\n",
        _RERANKER_NOTE_HTML,
    ]
    return "".join(parts)


def set_example(example: str) -> str:
    """Set an example query (identity passthrough for a UI callback)."""
    return example


def load_example_query(category: str, ambiguity: str) -> str:
    """Load the example query for *category* at the given ambiguity level.

    *ambiguity* is a dropdown label such as "Somewhat Ambiguous"; it is
    lower-cased with spaces turned into underscores to form the lookup key.
    """
    ambiguity_key = ambiguity.lower().replace(" ", "_")
    return EXAMPLE_QUERIES_BY_CATEGORY[category][ambiguity_key]


def generate_category_distribution_table(df=None) -> str:
    """Build an HTML table of product counts per ``MainCategory``.

    Args:
        df: Optional pre-loaded product DataFrame; loaded with
            ``load_clean_amazon_product_data()`` when omitted.

    Returns:
        A Markdown heading followed by an HTML table with one row per
        category (in ``value_counts`` order) and a total row.
    """
    if df is None:
        df = load_clean_amazon_product_data()
    category_counts = df["MainCategory"].value_counts()
    total = len(df)

    parts = [
        "\n### Dataset Category Distribution\n\n",
        "<table>\n<thead><tr><th>Category</th><th>Count</th><th>Percentage</th></tr></thead>\n<tbody>\n",
    ]
    for category, count in category_counts.items():
        percentage = (count / total) * 100
        parts.append(
            f"<tr><td>{escape(str(category))}</td><td>{count:,}</td><td>{percentage:.1f}%</td></tr>\n"
        )
    parts.append(
        f"<tr><td><strong>Total</strong></td><td><strong>{total:,}</strong></td><td><strong>100.0%</strong></td></tr>\n"
        "</tbody>\n</table>\n"
    )
    return "".join(parts)


def generate_sample_data_table(df=None, n: int = 5, random_state: int = 42) -> str:
    """Build an HTML table of randomly sampled products.

    Args:
        df: Optional pre-loaded product DataFrame; loaded with
            ``load_clean_amazon_product_data()`` when omitted.
        n: Number of rows to sample (capped at the number of rows available).
        random_state: Seed for ``DataFrame.sample`` so the sample is repeatable.

    Returns:
        A Markdown heading followed by an HTML table of the sampled rows.
    """
    if df is None:
        df = load_clean_amazon_product_data()
    # DataFrame.sample raises if asked for more rows than exist (without replacement).
    sample_df = df.sample(n=min(n, len(df)), random_state=random_state)

    parts = [
        "\n### Sample Products from Dataset\n\n",
        "<table>\n<thead><tr><th>Product Name</th><th>Main Category</th>"
        "<th>Secondary Category</th><th>Description</th></tr></thead>\n<tbody>\n",
    ]
    for _, row in sample_df.iterrows():
        description = _truncate(str(row["Description"]), _SAMPLE_DESCRIPTION_CHARS)
        parts.append(
            f"<tr><td>{escape(str(row['Product Name']))}</td>"
            f"<td>{escape(str(row['MainCategory']))}</td>"
            f"<td>{escape(str(row['SecondaryCategory']))}</td>"
            f"<td>{escape(description)}</td></tr>\n"
        )
    parts.append("</tbody>\n</table>\n")
    return "".join(parts)


_HEADER_HTML = """
<div>
<h1>🧙 Search Alchemy 🧙</h1>
<p><strong>Building Production Search Pipelines with Fireworks AI</strong></p>
<p>Four progressive stages demonstrating how to build production-grade semantic search:
<strong>BM25</strong> → <strong>Vector Embeddings</strong> → <strong>Query Expansion</strong> → <strong>Reranking</strong></p>
</div>
"""

_INTRO_MARKDOWN = """
**The Context:**
- **[Dataset](https://huggingface.co/datasets/ckandemir/amazon-products):** 10,000+ Amazon products across Toys, Home & Kitchen, Clothing, Sports, and Baby Products
- **The Problem:** Users search with vague terms like "keep kids busy" or "make bedroom nicer" instead of specific product names
- **The Solution:** Four progressive stages showing how semantic search handles ambiguity better than keyword matching

**How to Use:**
1. The default query "keep kids busy" is intentionally ambiguous - try it first to see the dramatic improvement across stages
2. Select different categories and specificity levels to explore more examples
3. Click **Search** and compare **Top-1 Score** and **Top-5 Avg** across all stages in the "Compare All Stages" tab
4. Notice how BM25 (keyword matching) struggles with ambiguous queries while vector embeddings + reranking excel

Note: scores are normalized to a 0-1 scale; higher is better.
"""

# Loaded once and shared by both dataset tables below.
_dataset_df = load_clean_amazon_product_data()

# Build Gradio Interface
with gr.Blocks(
    css=CUSTOM_CSS, theme=GRADIO_THEME, title="Search Alchemy - Fireworks AI"
) as demo:
    # Header
    with gr.Row():
        with gr.Column(scale=3):
            gr.Markdown(_HEADER_HTML)
        # NOTE(review): nesting of this "compact-header" row next to the title
        # column is assumed; confirm against the intended layout.
        with gr.Row(elem_classes="compact-header"):
            with gr.Column(scale=1, min_width=150):
                gr.Markdown("<div><small>Powered by</small></div>")
                gr.Image(
                    value=str(_FILE_PATH / "assets" / "fireworks_logo.png"),
                    height=35,
                    width=140,
                    show_label=False,
                    show_download_button=False,
                    container=False,
                    show_fullscreen_button=False,
                    show_share_button=False,
                )

    # Introduction Section
    with gr.Row():
        gr.Markdown(_INTRO_MARKDOWN)

    with gr.Row():
        gr.Markdown(
            "**Try Example Queries:** Select a category and specificity level to auto-load an example"
        )
    with gr.Row():
        with gr.Column(scale=1):
            category_dropdown = gr.Dropdown(
                choices=list(EXAMPLE_QUERIES_BY_CATEGORY.keys()),
                value=DEFAULT_CATEGORY,
                label="Category",
                container=True,
            )
        with gr.Column(scale=1):
            ambiguity_dropdown = gr.Dropdown(
                choices=["Clear", "Somewhat Ambiguous", "Ambiguous"],
                value=DEFAULT_AMBIGUITY,
                label="Query Specificity",
                container=True,
            )
        with gr.Column(scale=1):
            search_btn = gr.Button("Search", variant="primary", scale=1, size="lg")

    with gr.Row():
        gr.Markdown(
            "**Or write your own query:** Write your own query to find products in the database"
        )
    with gr.Row():
        with gr.Column(scale=4):
            query_input = gr.Textbox(
                label="Search Query",
                placeholder="...",
                # Same example the dropdown defaults point at.
                value=load_example_query(DEFAULT_CATEGORY, DEFAULT_AMBIGUITY),
                scale=3,
                elem_classes="search-box",
            )

    with gr.Tabs() as tabs:
        with gr.Tab(STAGE_LABELS[0]):
            stage1_output = gr.Markdown(
                value="Click **Search** to see results", label="Results"
            )
        with gr.Tab(STAGE_LABELS[1]):
            stage2_output = gr.Markdown(
                value="Click **Search** to see results", label="Results"
            )
        with gr.Tab(STAGE_LABELS[2]):
            stage3_output = gr.Markdown(
                value="Click **Search** to see results", label="Results"
            )
        with gr.Tab(STAGE_LABELS[3]):
            stage4_output = gr.Markdown(
                value="Click **Search** to see results", label="Results"
            )
        with gr.Tab("Compare All Stages"):
            comparison_output = gr.Markdown(
                value="Click **Search** to see results", label="Comparison"
            )

    with gr.Accordion("Dataset Information", open=False):
        gr.Markdown("Explore the dataset used for this search demo")
        with gr.Row():
            category_dist_table = gr.Markdown(
                value=generate_category_distribution_table(_dataset_df)
            )
        with gr.Row():
            sample_data_table = gr.Markdown(
                value=generate_sample_data_table(_dataset_df)
            )

    # Event handlers - auto-load query when either dropdown changes
    for _dropdown in (category_dropdown, ambiguity_dropdown):
        _dropdown.change(
            fn=load_example_query,
            inputs=[category_dropdown, ambiguity_dropdown],
            outputs=[query_input],
        )

    search_btn.click(
        fn=search_all_stages,
        inputs=[query_input],
        outputs=[
            stage1_output,
            stage2_output,
            stage3_output,
            stage4_output,
            comparison_output,
        ],
    )


if __name__ == "__main__":
    demo.launch()