import os import gradio as gr import markdown import requests import yaml from dotenv import load_dotenv try: from src.api.models.provider_models import MODEL_REGISTRY except ImportError as e: raise ImportError( "Could not import MODEL_REGISTRY from src.api.models.provider_models. " "Check the path and file existence." ) from e # Initialize environment variables load_dotenv() BACKEND_URL = os.getenv("BACKEND_URL", "http://localhost:8080") API_BASE_URL = f"{BACKEND_URL}/search" # Load feeds from YAML def load_feeds(): """Load feeds from the YAML configuration file. Returns: list: List of feeds with their details. """ feeds_path = os.path.join(os.path.dirname(__file__), "../src/configs/feeds_rss.yaml") with open(feeds_path) as f: feeds_yaml = yaml.safe_load(f) return feeds_yaml.get("feeds", []) feeds = load_feeds() feed_names = [f["name"] for f in feeds] feed_authors = [f["author"] for f in feeds] # ----------------------- # Custom CSS for modern UI # ----------------------- CUSTOM_CSS = """ /* Minimal, utility-first vibe with a neutral palette */ :root { --border: 1px solid rgba(2, 6, 23, 0.08); --surface: #ffffff; --surface-muted: #f8fafc; --text: #0f172a; --muted: #475569; --accent: #0ea5e9; --accent-strong: #0284c7; --radius: 12px; --shadow: 0 8px 20px rgba(2, 6, 23, 0.06); } .gradio-container, body { background: var(--surface-muted); color: var(--text); } .dark .gradio-container, .dark body { background: #0b1220; color: #e5e7eb; } .section { background: var(--surface); border: var(--border); border-radius: var(--radius); box-shadow: var(--shadow); padding: 16px; } .dark .section { background: #0f172a; border: 1px solid rgba(255,255,255,0.08); } .header { display: flex; align-items: baseline; justify-content: space-between; margin-bottom: 12px; } .header h2 { margin: 0; font-size: 22px; } .subtle { color: var(--muted); font-size: 13px; } .results-table { width: 100%; border-collapse: collapse; font-size: 14px; } .results-table th, .results-table td { border: 1px solid #e2e8f0; padding: 10px; text-align: left; vertical-align: top; } .results-table th { background: #f1f5f9; } .dark .results-table th { background: #0b1325; border-color: rgba(255,255,255,0.08); color: #e5e7eb; } .dark .results-table td { border-color: rgba(255,255,255,0.08); color: #e2e8f0; } .results-table a { color: var(--accent-strong); text-decoration: none; font-weight: 600; } .results-table a:hover { text-decoration: underline; } .dark .results-table a { color: #7dd3fc; } .answer { background: var(--surface); border: var(--border); border-radius: var(--radius); padding: 14px; } .dark .answer { background: #0f172a; border: 1px solid rgba(255,255,255,0.08); color: #e5e7eb; } .model-badge { display: inline-block; margin-top: 6px; padding: 6px 10px; border-radius: 999px; border: var(--border); background: #eef2ff; color: #3730a3; font-weight: 600; } .dark .model-badge { background: rgba(59,130,246,0.15); color: #c7d2fe; border: 1px solid rgba(255,255,255,0.08); } .error { border: 1px solid #fecaca; background: #fff1f2; color: #7f1d1d; border-radius: var(--radius); padding: 10px 12px; } .dark .error { border: 1px solid rgba(248,113,113,0.35); background: rgba(127,29,29,0.25); color: #fecaca; } /* Sticky status banner with spinner */ #status-banner { position: sticky; top: 0; z-index: 1000; margin: 8px 0 12px 0; } #status-banner .banner { display: flex; align-items: center; gap: 10px; padding: 10px 12px; border-radius: var(--radius); border: 1px solid #bae6fd; background: #e0f2fe; color: #075985; box-shadow: var(--shadow); } #status-banner .spinner { width: 16px; height: 16px; border-radius: 999px; border: 2px solid currentColor; border-right-color: transparent; animation: spin 0.8s linear infinite; } @keyframes spin { to { transform: rotate(360deg); } } .dark #status-banner .banner { border-color: rgba(59,130,246,0.35); background: rgba(2,6,23,0.55); color: #93c5fd; } /* Actions row aligns buttons to the right, outside filter sections */ .actions { display: flex; justify-content: flex-end; margin: 8px 0 12px 0; gap: 8px; } /* Prominent CTA buttons (not full-width) */ .cta { display: inline-flex; } .cta .gr-button { background: linear-gradient(180deg, var(--accent), var(--accent-strong)); color: #ffffff; border: none; border-radius: 14px; padding: 12px 18px; font-weight: 700; font-size: 15px; box-shadow: 0 10px 22px rgba(2,6,23,0.18); width: auto !important; } .cta .gr-button:hover { transform: translateY(-1px); filter: brightness(1.05); } .cta .gr-button:focus-visible { outline: 2px solid #93c5fd; outline-offset: 2px; } .dark .cta .gr-button { box-shadow: 0 12px 26px rgba(2,6,23,0.45); } """ # ----------------------- # API helpers # ----------------------- def fetch_unique_titles(payload): """ Fetch unique article titles based on the search criteria. Args: payload (dict): The search criteria including query_text, feed_author, feed_name, limit, and optional title_keywords. Returns: list: A list of articles matching the criteria. Raises: Exception: If the API request fails. """ try: resp = requests.post(f"{API_BASE_URL}/unique-titles", json=payload) resp.raise_for_status() return resp.json().get("results", []) except Exception as e: raise Exception(f"Failed to fetch titles: {str(e)}") from e def call_ai(payload, streaming=True): """ " Call the AI endpoint with the given payload. Args: payload (dict): The payload to send to the AI endpoint. streaming (bool): Whether to use streaming or non-streaming endpoint. Yields: tuple: A tuple containing the type of response and the response text. """ endpoint = f"{API_BASE_URL}/ask/stream" if streaming else f"{API_BASE_URL}/ask" answer_text = "" try: if streaming: with requests.post(endpoint, json=payload, stream=True) as r: r.raise_for_status() for chunk in r.iter_content(chunk_size=None, decode_unicode=True): if not chunk: continue if chunk.startswith("__model_used__:"): yield "model", chunk.replace("__model_used__:", "").strip() elif chunk.startswith("__error__"): yield "error", "Request failed. Please try again later." break elif chunk.startswith("__truncated__"): yield "truncated", "AI response truncated due to token limit." else: answer_text += chunk yield "text", answer_text else: resp = requests.post(endpoint, json=payload) resp.raise_for_status() data = resp.json() answer_text = data.get("answer", "") yield "text", answer_text if data.get("finish_reason") == "length": yield "truncated", "AI response truncated due to token limit." except Exception as e: yield "error", f"Request failed: {str(e)}" def get_models_for_provider(provider): """ Get available models for a provider Args: provider (str): The name of the provider (e.g., "openrouter", "openai") Returns: list: List of model names available for the provider """ provider_key = provider.lower() try: config = MODEL_REGISTRY.get_config(provider_key) return ( ["Automatic Model Selection (Model Routing)"] + ([config.primary_model] if config.primary_model else []) + list(config.candidate_models) ) except Exception: return ["Automatic Model Selection (Model Routing)"] # ----------------------- # Gradio interface functions # ----------------------- def handle_search_articles(query_text, feed_name, feed_author, title_keywords, limit): """ Handle article search Args: query_text (str): The text to search for in article titles. feed_name (str): The name of the feed to filter articles by. feed_author (str): The author of the feed to filter articles by. title_keywords (str): Keywords to search for in article titles. limit (int): The maximum number of articles to return. Returns: str: HTML formatted string of search results or error message. Raises: Exception: If the API request fails. """ if not query_text.strip(): return "Please enter a query text." payload = { "query_text": query_text.strip().lower(), "feed_author": feed_author.strip() if feed_author else "", "feed_name": feed_name.strip() if feed_name else "", "limit": limit, "title_keywords": title_keywords.strip().lower() if title_keywords else None, } try: results = fetch_unique_titles(payload) if not results: return "No results found." # Render results as a compact table html_output = ( "
" "

Results

Unique titles
" " " " " " " " " " " ) for item in results: title = item.get("title", "No title") feed_n = item.get("feed_name", "N/A") feed_a = item.get("feed_author", "N/A") authors = ", ".join(item.get("article_author") or ["N/A"]) url = item.get("url", "#") html_output += ( " " f" " f" " f" " f" " f" " " " ) html_output += "
TitleNewsletterFeed AuthorArticle AuthorsLink
{title}{feed_n}{feed_a}{authors}Open
" return html_output except Exception as e: return f"
Error: {str(e)}
" def handle_ai_question_streaming( query_text, feed_name, feed_author, limit, provider, model, ): """ Handle AI question with streaming Args: query_text (str): The question to ask the AI. feed_name (str): The name of the feed to filter articles by. feed_author (str): The author of the feed to filter articles by. limit (int): The maximum number of articles to consider. provider (str): The LLM provider to use. model (str): The specific model to use from the provider. Yields: tuple: (HTML formatted answer string, model info string) """ if not query_text.strip(): yield "Please enter a query text.", "" return if not provider or not model: yield "Please select provider and model.", "" return payload = { "query_text": query_text.strip().lower(), "feed_author": feed_author.strip() if feed_author else "", "feed_name": feed_name.strip() if feed_name else "", "limit": limit, "provider": provider.lower(), } if model != "Automatic Model Selection (Model Routing)": payload["model"] = model try: answer_html = "" model_info = f"Provider: {provider}" for _, (event_type, content) in enumerate(call_ai(payload, streaming=True)): if event_type == "text": html_content = markdown.markdown(content, extensions=["tables"]) answer_html = f"
{html_content}
" yield answer_html, model_info elif event_type == "model": model_info = f"Provider: {provider} | Model: {content}" yield answer_html, model_info elif event_type == "truncated": answer_html += f"
⚠️ {content}
" yield answer_html, model_info elif event_type == "error": error_html = f"
❌ {content}
" yield error_html, model_info break except Exception as e: error_html = "
Error: {}
".format(str(e)) yield error_html, model_info def handle_ai_question_non_streaming(query_text, feed_name, feed_author, limit, provider, model): """ Handle AI question without streaming Args: query_text (str): The question to ask the AI. feed_name (str): The name of the feed to filter articles by. feed_author (str): The author of the feed to filter articles by. limit (int): The maximum number of articles to consider. provider (str): The LLM provider to use. model (str): The specific model to use from the provider. Returns: tuple: (HTML formatted answer string, model info string) """ if not query_text.strip(): return "Please enter a query text.", "" if not provider or not model: return "Please select provider and model.", "" payload = { "query_text": query_text.strip().lower(), "feed_author": feed_author.strip() if feed_author else "", "feed_name": feed_name.strip() if feed_name else "", "limit": limit, "provider": provider.lower(), } if model != "Automatic Model Selection (Model Routing)": payload["model"] = model try: answer_html = "" model_info = f"Provider: {provider}" for event_type, content in call_ai(payload, streaming=False): if event_type == "text": html_content = markdown.markdown(content, extensions=["tables"]) answer_html = f"
{html_content}
" elif event_type == "model": model_info = f"Provider: {provider} | Model: {content}" elif event_type == "truncated": answer_html += f"
⚠️ {content}
" elif event_type == "error": return ( f"
❌ {content}
", model_info, ) return answer_html, model_info except Exception as e: return ( f"
Error: {str(e)}
", f"Provider: {provider}", ) def update_model_choices(provider): """ Update model choices based on selected provider Args: provider (str): The selected LLM provider Returns: gr.Dropdown: Updated model dropdown component """ models = get_models_for_provider(provider) return gr.Dropdown(choices=models, value=models[0] if models else "") # ----------------------- # Progress/status helpers # ----------------------- def start_search_status(): return "" def start_ai_status(streaming_mode): mode = "streaming" if streaming_mode == "Streaming" else "non‑streaming" return f"" def clear_status(): return "" # ----------------------- # Gradio UI (new layout) # ----------------------- def ask_ai_router( streaming_mode, query_text, feed_name, feed_author, limit, provider, model, ): """ Route AI question to streaming or non-streaming handler. Yields: tuple: (answer_html, model_info_html) """ if streaming_mode == "Streaming": yield from handle_ai_question_streaming( query_text, feed_name, feed_author, limit, provider, model ) else: result_html, model_info_text = handle_ai_question_non_streaming( query_text, feed_name, feed_author, limit, provider, model ) yield result_html, model_info_text with gr.Blocks(title="Article Search Engine", theme=gr.themes.Base(), css=CUSTOM_CSS) as demo: gr.Markdown( "### Article Search Engine\n" "Search across substack, medium and top publications articles on AI topics or ask questions with an AI assistant." ) # Sticky status banner (empty by default) status_banner = gr.HTML(value="", elem_id="status-banner") with gr.Tabs(): # Search Tab with gr.Tab("Search"): with gr.Group(elem_classes="section"): gr.Markdown("#### Find articles on any AI topic") search_query = gr.Textbox( label="Query", placeholder="What are you looking for?", lines=3, ) with gr.Row(): search_feed_author = gr.Dropdown( choices=[""] + feed_authors, label="Author (optional)", value="" ) search_feed_name = gr.Dropdown( choices=[""] + feed_names, label="Newsletter (optional)", value="" ) with gr.Row(): search_title_keywords = gr.Textbox( label="Title keywords (optional)", placeholder="Filter by words in the title", ) search_limit = gr.Slider( minimum=1, maximum=20, step=1, label="Number of results", value=5 ) with gr.Row(elem_classes="actions"): search_btn = gr.Button("Search", variant="primary", elem_classes="cta") search_output = gr.HTML(label="Results") # Ask AI Tab with gr.Tab("Ask AI"): with gr.Group(elem_classes="section"): gr.Markdown("#### Ask an AI assistant about any AI topic") ai_query = gr.Textbox( label="Your question", placeholder="Ask a question. The AI will use the articles for context.", lines=4, ) with gr.Row(): ai_feed_author = gr.Dropdown( choices=[""] + feed_authors, label="Author (optional)", value="" ) ai_feed_name = gr.Dropdown( choices=[""] + feed_names, label="Newsletter (optional)", value="" ) ai_limit = gr.Slider( minimum=1, maximum=20, step=1, label="Max articles", value=5 ) with gr.Row(): provider_dd = gr.Dropdown( choices=["OpenRouter", "HuggingFace", "OpenAI"], label="LLM Provider", value="OpenRouter", ) model_dd = gr.Dropdown( choices=get_models_for_provider("OpenRouter"), label="Model", value="Automatic Model Selection (Model Routing)", ) streaming_mode_dd = gr.Radio( choices=["Streaming", "Non-Streaming"], value="Streaming", label="Answer mode", ) with gr.Row(elem_classes="actions"): ask_btn = gr.Button("Run", variant="primary", elem_classes="cta") ai_answer = gr.HTML(label="Answer") ai_model_info = gr.HTML(label="Model") # Wire events with sticky status banner search_btn.click( fn=start_search_status, inputs=[], outputs=[status_banner], show_progress=False, ).then( fn=handle_search_articles, inputs=[ search_query, search_feed_name, search_feed_author, search_title_keywords, search_limit, ], outputs=[search_output], show_progress=False, ).then( fn=clear_status, inputs=[], outputs=[status_banner], show_progress=False, ) provider_dd.change(fn=update_model_choices, inputs=[provider_dd], outputs=[model_dd]) ask_btn.click( fn=start_ai_status, inputs=[streaming_mode_dd], outputs=[status_banner], show_progress=False, ).then( fn=ask_ai_router, inputs=[ streaming_mode_dd, ai_query, ai_feed_name, ai_feed_author, ai_limit, provider_dd, model_dd, ], outputs=[ai_answer, ai_model_info], show_progress=False, ).then( fn=clear_status, inputs=[], outputs=[status_banner], show_progress=False, ) # For local testing if __name__ == "__main__": demo.launch() # # For Google Cloud Run deployment # if __name__ == "__main__": # demo.launch( # server_name="0.0.0.0", # server_port=int(os.environ.get("PORT", 8080)) # )