# ChatFed Orchestrator -- HuggingFace Space
# Standard library
import ast
import configparser
import io
import logging
import os
import re
from typing import TypedDict, Optional

# Third-party
import gradio as gr
from dotenv import load_dotenv
from gradio_client import Client
from langgraph.graph import StateGraph, START, END
from PIL import Image

# Load the local .env file BEFORE reading any environment variables.
# (Previously HF_TOKEN was read before load_dotenv() ran, so a token
# defined only in .env was silently ignored.)
load_dotenv()

#OPEN QUESTION: SHOULD WE PASS ALL PARAMS FROM THE ORCHESTRATOR TO THE NODES INSTEAD OF SETTING IN EACH MODULE?
# Token used to authenticate against the private ChatFed Spaces below.
HF_TOKEN = os.environ.get("HF_TOKEN")
def getconfig(configfile_path: str):
    """
    Read the config file.

    Params
    ----------------
    configfile_path: file path of .cfg file

    Returns
    ----------------
    configparser.ConfigParser with the parsed contents, or None when the
    file is missing or cannot be parsed (a warning is logged).
    """
    config = configparser.ConfigParser()
    try:
        # Use a context manager so the file handle is always closed
        # (the original open() call leaked the handle).
        with open(configfile_path) as cfg_file:
            config.read_file(cfg_file)
        return config
    except (OSError, configparser.Error):
        # Narrow exceptions instead of a bare `except:` that would also
        # swallow KeyboardInterrupt/SystemExit and hide real bugs.
        logging.warning("config file not found")
        return None
def get_auth(provider: str) -> dict:
    """Get authentication configuration for different providers.

    Looks up the provider's API key from the corresponding environment
    variable. Raises ValueError for unknown providers; returns
    ``{"api_key": None}`` (with a warning) when the variable is unset/empty.
    """
    env_var_for = {
        "huggingface": "HF_TOKEN",
        "qdrant": "QDRANT_API_KEY",
    }

    name = provider.lower()  # provider lookup is case-insensitive
    if name not in env_var_for:
        raise ValueError(f"Unsupported provider: {name}")

    api_key = os.getenv(env_var_for[name])
    if not api_key:
        logging.warning(f"No API key found for provider '{name}'. Please set the appropriate environment variable.")
        api_key = None  # normalize empty string to None

    return {"api_key": api_key}
# Define the state schema shared by all nodes in the LangGraph workflow.
class GraphState(TypedDict):
    # The user's input query/question.
    query: str
    # Context retrieved by the retriever node, consumed by the generator.
    context: str
    # Final generated answer produced by the generator node.
    result: str
    # Orchestrator-level retrieval filters, passed through to the retriever
    # (addresses the open question about passing params from the orchestrator).
    reports_filter: str
    sources_filter: str
    subtype_filter: str
    year_filter: str
# node 2: retriever
def retrieve_node(state: GraphState) -> GraphState:
    """Fetch context for the query from the remote ChatFed retriever Space."""
    retriever = Client("giz/chatfed_retriever", hf_token=HF_TOKEN)  # HF repo name

    # Forward the orchestrator-level filters, defaulting each to "".
    filter_keys = ("reports_filter", "sources_filter", "subtype_filter", "year_filter")
    filters = {key: state.get(key, "") for key in filter_keys}

    context = retriever.predict(
        query=state["query"],
        api_name="/retrieve",
        **filters,
    )
    return {"context": context}
# node 3: generator
def generate_node(state: GraphState) -> GraphState:
    """Produce the final answer from the query plus retrieved context."""
    generator = Client("giz/chatfed_generator", hf_token=HF_TOKEN)
    answer = generator.predict(
        query=state["query"],
        context=state["context"],
        api_name="/generate",
    )
    return {"result": answer}
# Build the two-step retrieve -> generate pipeline graph.
workflow = StateGraph(GraphState)

# Register the nodes.
for node_name, node_fn in (("retrieve", retrieve_node), ("generate", generate_node)):
    workflow.add_node(node_name, node_fn)

# Wire the linear flow: START -> retrieve -> generate -> END.
for source, target in ((START, "retrieve"), ("retrieve", "generate"), ("generate", END)):
    workflow.add_edge(source, target)

# Compile into an executable graph.
graph = workflow.compile()
# Single tool for processing queries
def process_query(
    query: str,
    reports_filter: str = "",
    sources_filter: str = "",
    subtype_filter: str = "",
    year_filter: str = ""
) -> str:
    """
    Execute the ChatFed orchestration pipeline for a user query.

    Runs the compiled two-node workflow: the retriever service fetches
    context (honoring any filters), then the generator service produces
    the final answer from query + context.

    Args:
        query (str): The user's input query/question to be processed
        reports_filter (str, optional): Filter for specific report types. Defaults to "".
        sources_filter (str, optional): Filter for specific data sources. Defaults to "".
        subtype_filter (str, optional): Filter for document subtypes. Defaults to "".
        year_filter (str, optional): Filter for specific years. Defaults to "".

    Returns:
        str: The generated response from the ChatFed generator service
    """
    # Coerce None filters to "" so downstream .predict() calls get strings.
    state: GraphState = {
        "query": query,
        "context": "",
        "result": "",
        "reports_filter": reports_filter or "",
        "sources_filter": sources_filter or "",
        "subtype_filter": subtype_filter or "",
        "year_filter": year_filter or "",
    }
    return graph.invoke(state)["result"]
# Simple testing interface
# Guidance for ChatUI - can be removed later. Questionable whether front end even necessary. Maybe nice to show the graph.
with gr.Blocks(title="ChatFed Orchestrator") as demo:
    with gr.Row():
        # Left column - query input
        with gr.Column():
            question_box = gr.Textbox(
                label="query",
                lines=2,
                placeholder="Enter your search query here",
                info="The query to search for in the vector database",
            )
            run_button = gr.Button("Submit", variant="primary")

        # Right column - generated answer
        with gr.Column():
            answer_box = gr.Textbox(
                label="answer",
                lines=10,
                show_copy_button=True,
            )

    # UI event handler: route the query through the full pipeline.
    run_button.click(
        fn=process_query,
        inputs=question_box,
        outputs=answer_box,
        api_name="process_query",
    )
if __name__ == "__main__":
    # Launch the Gradio app; 0.0.0.0 binds all interfaces (container-friendly)
    # on the standard HF Spaces port 7860.
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        # NOTE(review): also exposes the app as an MCP server; presumably
        # requires the gradio[mcp] extra -- confirm against deployment deps.
        mcp_server=True,
        show_error=True
    )