import gradio as gr
import time
from typing import List, Dict, Tuple, Callable
from pathlib import Path

from config import (
    GRADIO_THEME,
    CUSTOM_CSS,
    EXAMPLE_QUERIES_BY_CATEGORY,
)
from src.search.bm25_lexical_search import search_bm25
from src.search.vector_search import (
    search_vector,
    search_vector_with_expansion,
    search_vector_with_reranking,
)
from src.data_prep.data_prep import load_clean_amazon_product_data

_FILE_PATH = Path(__file__).parents[1]


def format_results(results: List[Dict], stage_name: str, metrics: Dict) -> str:
    """Format search results as Markdown for display in a gr.Markdown component.

    Args:
        results: List of dicts with keys: product_name, description, main_category, secondary_category, score
        stage_name: Name of the search stage
        metrics: Dict with keys: top1_score, top5_avg, latency_ms
    """
    # Performance metrics at the top, followed by the results section
    html_parts = [
        f"## 🔍 {stage_name}\n\n",
        "| TOP-1 SCORE | TOP-5 AVG | LATENCY |\n",
        "|-------------|-----------|---------|\n",
        f"| {metrics['top1_score']:.3f} (best result) "
        f"| {metrics['top5_avg']:.3f} (overall quality) "
        f"| {metrics['latency_ms']}ms (response time) |\n",
        "\n\n",
    ]
    # Results section
    for idx, result in enumerate(results, 1):
        category = f"{result.get('main_category', 'N/A')} > {result.get('secondary_category', 'N/A')}"
        html_parts.append(
            f"**{idx}. {result['product_name']}**\n\n"
            f"{result['description'][:150]}...\n\n"
            f"*Category: {category}*\n\n"
        )
    html_parts.append("\n")
    return "".join(html_parts)


def run_search_function_and_time(query: str, func: Callable, top_n: int = 5):
    start = time.time()
    results = func(query)
    latency = int((time.time() - start) * 1000)
    return results[:top_n], latency


def get_average_score(results: List[Dict]) -> float:
    return sum(r["score"] for r in results) / len(results) if results else 0


def get_weighted_score(results: List[Dict]) -> float:
    """
    Calculate position-weighted average score.

    Top positions get higher weight (5x for #1, 4x for #2, etc.)
    This rewards ranking quality - putting best results at the top.

    Args:
        results: List of search results with 'score' field

    Returns:
        Weighted average score (0-1 scale)
    """
    if not results:
        return 0.0
    weights = [5, 4, 3, 2, 1]
    total_weight = sum(weights)
    weighted_sum = sum((weights[i] * r["score"]) for i, r in enumerate(results[:5]))
    return weighted_sum / total_weight
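

# Illustrative example: for scores [1.0, 0.8, 0.6, 0.4, 0.2], get_weighted_score returns
# (5*1.0 + 4*0.8 + 3*0.6 + 2*0.4 + 1*0.2) / 15 ≈ 0.733, while get_average_score returns 0.6;
# the position weighting rewards putting the strongest matches at the top.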


def search_stage_1(query: str) -> Tuple[str, Dict]:
    """Stage 1: Baseline BM25 keyword search."""
    results, latency = run_search_function_and_time(query, search_bm25)
    # Roughly normalize raw BM25 scores to a 0-1 range so they are comparable to cosine scores
    top1_score = results[0]["score"] / 5.0 if results else 0.0
    top5_avg = get_average_score(results) / 5.0 if results else 0.0
    metrics = {
        "top1_score": min(1.0, top1_score),
        "top5_avg": min(1.0, top5_avg),
        "latency_ms": latency,
    }
    print(f"Searched BM25 for '{query}' in {latency}ms")
    return format_results(results, "Stage 1: BM25 Baseline", metrics), metrics


def search_stage_2(query: str) -> Tuple[str, Dict]:
    """Stage 2: Vector Embeddings using FAISS."""
    results, latency = run_search_function_and_time(query, search_vector)
    top1_score = results[0]["score"] if results else 0.0
    top5_avg = get_average_score(results)
    metrics = {
        "top1_score": top1_score,
        "top5_avg": top5_avg,
        "latency_ms": latency,
    }
    print(f"Searched vector embeddings for '{query}' in {latency}ms")
    return format_results(results[:5], "Stage 2: Vector Embeddings", metrics), metrics


def search_stage_3(query: str) -> Tuple[str, Dict]:
    """Stage 3: Query Expansion + Vector Embeddings."""
    results, latency = run_search_function_and_time(query, search_vector_with_expansion)
    top1_score = results[0]["score"] if results else 0.0
    top5_avg = get_average_score(results)
    metrics = {
        "top1_score": top1_score,
        "top5_avg": top5_avg,
        "latency_ms": latency,
    }
    return format_results(results[:5], "Stage 3: Query Expansion", metrics), metrics


def search_stage_4(query: str) -> Tuple[str, Dict]:
    """Stage 4: Query Expansion + Vector Embeddings + Reranking."""
    results, latency = run_search_function_and_time(query, search_vector_with_reranking)
    top1_score = results[0]["score"] if results else 0.0
    top5_avg = get_average_score(results)
    metrics = {
        "top1_score": top1_score,
        "top5_avg": top5_avg,
        "latency_ms": latency,
    }
    return format_results(results, "Stage 4: Reranking", metrics), metrics


def search_all_stages(query: str) -> Tuple[str, str, str, str, str]:
    """Run the search across all stages and return per-stage results plus a comparison."""
    if not query.strip():
        empty_msg = "Please enter a search query."
        return empty_msg, empty_msg, empty_msg, empty_msg, empty_msg
    results_1, metrics_1 = search_stage_1(query)
    results_2, metrics_2 = search_stage_2(query)
    results_3, metrics_3 = search_stage_3(query)
    results_4, metrics_4 = search_stage_4(query)
    comparison = generate_comparison_table([metrics_1, metrics_2, metrics_3, metrics_4])
    return results_1, results_2, results_3, results_4, comparison


def calculate_improvement(metric1: Dict, metric2: Dict, metric_name: str) -> float:
    """Calculate the improvement of metric1 over metric2 as a percentage."""
    if metric2[metric_name] == 0:
        # Avoid division by zero: fall back to the absolute difference in percentage points
        return (metric1[metric_name] - metric2[metric_name]) * 100
    return (metric1[metric_name] - metric2[metric_name]) / metric2[metric_name] * 100
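

# Example: calculate_improvement({"top1_score": 0.9}, {"top1_score": 0.6}, "top1_score")
# returns (0.9 - 0.6) / 0.6 * 100 = 50.0, i.e. a 50% relative improvement over the baseline.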


def generate_comparison_table(all_metrics: List[Dict]) -> str:
    """Generate the comparison table for all stages."""
    stage_names = [
        "Stage 1: BM25 Baseline",
        "Stage 2: + Vector Embeddings",
        "Stage 3: + Query Expansion",
        "Stage 4: + Reranking",
    ]
    # Build markdown table
    html = "## Stage-by-Stage Comparison\n\n"
    html += "| Stage | Top-1 Score | Top-5 Avg | Latency (ms) |\n"
    html += "|-------|-------------|-----------|--------------|\n"
    for name, metrics in zip(stage_names, all_metrics):
        html += f"| **{name}** | {metrics['top1_score']:.3f} | {metrics['top5_avg']:.3f} | {metrics['latency_ms']} |\n"
    # Improvement of the final stage (reranking) over the BM25 baseline
    top5_improvement = calculate_improvement(all_metrics[3], all_metrics[0], "top5_avg")
    top1_improvement = calculate_improvement(all_metrics[3], all_metrics[0], "top1_score")
    html += "\n---\n\n"
    html += "## Key Insights\n\n"
    html += f"- **Relevance Improvement**: Top-1 score increases by **{top1_improvement:.0f}%**, Top-5 average by **{top5_improvement:.0f}%**\n"
    html += f"- **Production-Ready Performance**: All stages complete within **{max(m['latency_ms'] for m in all_metrics)}ms**\n"
    html += "- **Semantic Understanding**: Vector embeddings provide the largest single improvement in search quality\n"
    html += "- **Progressive Enhancement**: Each stage builds upon the previous, creating a robust pipeline\n"
    html += "- **Real-World Applicability**: This architecture scales to millions of documents with proper infrastructure\n"
    html += "\n\n---\n\n"
    html += """
### 💡 Understanding Reranker Scores

**Note:** You may notice that reranking shows the same cosine similarity scores as Stage 2
despite improved result ordering. This is intentional and highlights an important concept:

- **Different ranking mechanisms:** Cosine similarity measures vector distance in embedding space,
  while rerankers use cross-encoder models that analyze query-document pairs directly for deeper semantic understanding.
- **Why reranking works better:** Cross-encoders examine token-level interactions between the query
  and each document, capturing nuances that simple vector distance misses.
- **The scores displayed:** We preserve cosine scores to show that reranking reorders results
  based on relevance, not similarity. A document with a slightly lower cosine similarity might be more contextually relevant.
- **Production best practice:** Use fast vector search to retrieve candidates (~100-1000 results),
  then apply computationally expensive reranking to the top results for maximum accuracy.

This two-stage approach (retrieve + rerank) is the industry standard for building high-quality search systems at scale.
"""
    return html
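

# Illustrative sketch (not wired into the UI) of the generic two-stage retrieve + rerank
# pattern described in the note above. The project's actual pipeline is
# src.search.vector_search.search_vector_with_reranking; the callables below are
# placeholders supplied by the caller, not assumptions about that module's API.
def retrieve_then_rerank(
    query: str,
    retrieve_fn: Callable[[str], List[Dict]],
    rerank_score_fn: Callable[[str, Dict], float],
    top_k: int = 100,
    top_n: int = 5,
) -> List[Dict]:
    """Fetch cheap candidates first, then apply an expensive reranker to pick the best few."""
    candidates = retrieve_fn(query)[:top_k]  # fast first-stage retrieval (e.g. vector search)
    reranked = sorted(candidates, key=lambda doc: rerank_score_fn(query, doc), reverse=True)
    return reranked[:top_n]  # return only the strongest results after reranking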


def set_example(example: str) -> str:
    """Set an example query."""
    return example


def load_example_query(category: str, ambiguity: str) -> str:
    """Load an example query based on category and ambiguity level."""
    ambiguity_key = ambiguity.lower().replace(" ", "_")
    return EXAMPLE_QUERIES_BY_CATEGORY[category][ambiguity_key]
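

# Example: load_example_query("Toys", "Somewhat Ambiguous") lowercases the label, maps it to
# the key "somewhat_ambiguous", and returns EXAMPLE_QUERIES_BY_CATEGORY["Toys"]["somewhat_ambiguous"].
# ("Toys" here is purely illustrative; the real category keys come from config.)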


def generate_category_distribution_table() -> str:
    """Generate a Markdown table showing the MainCategory distribution."""
    df = load_clean_amazon_product_data()
    category_counts = df["MainCategory"].value_counts()
    total = len(df)
    html = "### Dataset Category Distribution\n\n"
    html += "| Category | Count | Percentage |\n"
    html += "|----------|-------|------------|\n"
    for category, count in category_counts.items():
        percentage = (count / total) * 100
        html += f"| {category} | {count:,} | {percentage:.1f}% |\n"
    html += f"| **Total** | **{total:,}** | **100.0%** |\n"
    return html


def generate_sample_data_table() -> str:
    """Generate a Markdown table showing sample rows from the dataset."""
    df = load_clean_amazon_product_data()
    sample_df = df.sample(n=5, random_state=42)
    html = "### Sample Products from Dataset\n\n"
    html += "| Product Name | Main Category | Secondary Category | Description |\n"
    html += "|--------------|---------------|--------------------|-------------|\n"
    for _, row in sample_df.iterrows():
        description = (
            row["Description"][:80] + "..."
            if len(row["Description"]) > 80
            else row["Description"]
        )
        html += f"| {row['Product Name']} | {row['MainCategory']} | {row['SecondaryCategory']} | {description} |\n"
    return html


# Build Gradio Interface
with gr.Blocks(
    css=CUSTOM_CSS, theme=GRADIO_THEME, title="Search Alchemy - Fireworks AI"
) as demo:
    # Header
    with gr.Row():
        with gr.Column(scale=3):
            gr.Markdown(
                """
# Building Production Search Pipelines with Fireworks AI

Four progressive stages demonstrating how to build production-grade semantic search:
**BM25 → Vector Embeddings → Query Expansion → Reranking**
"""
            )
    with gr.Row(elem_classes="compact-header"):
        with gr.Column(scale=1, min_width=150):
            gr.Markdown("Powered by")
            gr.Image(
                value=str(_FILE_PATH / "assets" / "fireworks_logo.png"),
                height=35,
                width=140,
                show_label=False,
                show_download_button=False,
                container=False,
                show_fullscreen_button=False,
                show_share_button=False,
            )
    # Introduction Section
    with gr.Row():
        gr.Markdown(
            """
**The Context:**

- **[Dataset](https://huggingface.co/datasets/ckandemir/amazon-products):** 10,000+ Amazon products across Toys, Home & Kitchen, Clothing, Sports, and Baby Products
- **The Problem:** Users search with vague terms like "keep kids busy" or "make bedroom nicer" instead of specific product names
- **The Solution:** Four progressive stages showing how semantic search handles ambiguity better than keyword matching

**How to Use:**

1. The default query "keep kids busy" is intentionally ambiguous - try it first to see the dramatic improvement across stages
2. Select different categories and specificity levels to explore more examples
3. Click **Search** and compare **Top-1 Score** and **Top-5 Avg** across all stages in the "Compare All Stages" tab
4. Notice how BM25 (keyword matching) struggles with ambiguous queries while vector embeddings + reranking excel

Note: scores are normalized to a 0-1 scale; higher is better.
"""
        )
    with gr.Row():
        gr.Markdown(
            "**Try Example Queries:** Select a category and specificity level to auto-load an example"
        )
    with gr.Row():
        with gr.Column(scale=1):
            category_dropdown = gr.Dropdown(
                choices=list(EXAMPLE_QUERIES_BY_CATEGORY.keys()),
                value=list(EXAMPLE_QUERIES_BY_CATEGORY.keys())[0],
                label="Category",
                container=True,
            )
        with gr.Column(scale=1):
            ambiguity_dropdown = gr.Dropdown(
                choices=["Clear", "Somewhat Ambiguous", "Ambiguous"],
                value="Ambiguous",
                label="Query Specificity",
                container=True,
            )
        with gr.Column(scale=1):
            search_btn = gr.Button("Search", variant="primary", scale=1, size="lg")
    with gr.Row():
        gr.Markdown(
            "**Or write your own query:** Type a free-form query to find products in the database"
        )
    with gr.Row():
        with gr.Column(scale=4):
            query_input = gr.Textbox(
                label="Search Query",
                placeholder="...",
                value=EXAMPLE_QUERIES_BY_CATEGORY["Baby Products"]["ambiguous"],
                scale=3,
                elem_classes="search-box",
            )
    with gr.Tabs() as tabs:
        with gr.Tab("Stage 1: BM25 Baseline"):
            stage1_output = gr.Markdown(
                value="Click **Search** to see results", label="Results"
            )
        with gr.Tab("Stage 2: + Vector Embeddings"):
            stage2_output = gr.Markdown(
                value="Click **Search** to see results", label="Results"
            )
        with gr.Tab("Stage 3: + Query Expansion"):
            stage3_output = gr.Markdown(
                value="Click **Search** to see results", label="Results"
            )
        with gr.Tab("Stage 4: + LLM Reranking"):
            stage4_output = gr.Markdown(
                value="Click **Search** to see results", label="Results"
            )
        with gr.Tab("Compare All Stages"):
            comparison_output = gr.Markdown(
                value="Click **Search** to see results", label="Comparison"
            )
    with gr.Accordion("Dataset Information", open=False):
        gr.Markdown("Explore the dataset used for this search demo")
        with gr.Row():
            category_dist_table = gr.Markdown(
                value=generate_category_distribution_table()
            )
        with gr.Row():
            sample_data_table = gr.Markdown(value=generate_sample_data_table())
    # Event handlers - auto-load the query when either dropdown changes
    category_dropdown.change(
        fn=load_example_query,
        inputs=[category_dropdown, ambiguity_dropdown],
        outputs=[query_input],
    )
    ambiguity_dropdown.change(
        fn=load_example_query,
        inputs=[category_dropdown, ambiguity_dropdown],
        outputs=[query_input],
    )
    search_btn.click(
        fn=search_all_stages,
        inputs=[query_input],
        outputs=[
            stage1_output,
            stage2_output,
            stage3_output,
            stage4_output,
            comparison_output,
        ],
    )


if __name__ == "__main__":
    demo.launch()