Spaces:
Sleeping
Sleeping
from typing import Any, Dict

from langchain.prompts import ChatPromptTemplate
from langchain.schema import StrOutputParser
from langchain_openai import ChatOpenAI

from scripts.rag_chat import build_general_qa_chain
def build_router_chain(model_name=None):
    """Build a category router that dispatches user requests to sub-chains.

    A small classifier LLM labels each request as "code", "summarize",
    "calculate", or "general", and the matching handler chain produces the
    answer.

    Args:
        model_name: Optional chat-model name passed to both the classifier
            LLM and the general-QA RAG chain. Defaults to "gpt-4o-mini".

    Returns:
        A Router object whose ``invoke({"input": <text>})`` returns a dict
        with a ``"result"`` key for the code/summarize/calculate categories,
        or the raw general-QA chain result for everything else.
    """
    general_qa = build_general_qa_chain(model_name=model_name)
    # temperature=0.0: routing must be deterministic.
    llm = ChatOpenAI(model_name=model_name or "gpt-4o-mini", temperature=0.0)

    # This prompt asks the LLM to choose which "mode" to use; it must
    # return exactly one category word, which we normalize before matching.
    router_prompt = ChatPromptTemplate.from_template("""
You are a routing assistant for a chatbot.
Classify the following user request into one of these categories:
- "code" for programming or debugging
- "summarize" for summary requests
- "calculate" for math or numeric calculations
- "general" for general Q&A using course files
Return ONLY the category word.
User request: {input}
""")
    router_chain = router_prompt | llm | StrOutputParser()

    def _answer_with(template: str, user_input: str) -> Dict[str, Any]:
        """Run a one-shot prompt | llm | parser chain and wrap its answer."""
        chain = ChatPromptTemplate.from_template(template) | llm | StrOutputParser()
        return {"result": chain.invoke({"input": user_input})}

    class Router:
        """Dispatches a request to the handler chosen by the classifier LLM."""

        def invoke(self, input_dict: Dict[str, Any]) -> Dict[str, Any]:
            """Classify ``input_dict["input"]`` and run the matching handler."""
            user_input = input_dict["input"]
            category = router_chain.invoke({"input": user_input}).strip().lower()
            print(f"[ROUTER] User query routed to category: {category}")

            if category == "code":
                return _answer_with(
                    "As a coding assistant, help with this Python question.\nQuestion: {input}\nAnswer:",
                    user_input,
                )
            elif category == "summarize":
                # 1) Retrieve relevant documents via the existing RAG chain.
                rag_result = general_qa({"query": user_input})
                # 2) Retrieved docs are already LangChain Document objects.
                source_docs = rag_result.get("source_documents", []) or []
                # 3) Build the summarizer; local imports keep the summarizer
                #    dependency out of module import time.
                from langchain.docstore.document import Document
                from scripts.summarizer import get_summarizer
                summarizer_chain = get_summarizer()
                # If retrieval returned nothing, fall back to summarizing the
                # user's own text instead of failing on an empty doc list.
                docs = source_docs if source_docs else [Document(page_content=user_input)]
                # 4) Summarize — load_summarize_chain returns {"output_text": "..."},
                #    but tolerate any other return shape by stringifying it.
                out = summarizer_chain.invoke(docs)
                summary = out["output_text"] if isinstance(out, dict) and "output_text" in out else str(out)
                # 5) Append deduplicated, sorted source names — only when the
                #    summary actually came from retrieved documents.
                if source_docs:
                    sources = sorted({str(d.metadata.get("source", "unknown")) for d in source_docs})
                    if sources:
                        summary += f"\n\n📚 Sources: {', '.join(sources)}"
                return {"result": summary}
            elif category == "calculate":
                return _answer_with(
                    "Solve the following calculation step-by-step:\n{input}",
                    user_input,
                )
            else:  # "general" (and any unexpected label) falls through to RAG QA
                return general_qa({"query": user_input})

    return Router()