# Import modules
from typing import TypedDict, Dict

from langgraph.graph import StateGraph, END
from langchain_core.prompts import ChatPromptTemplate
# Optional imports for rendering the graph in a notebook:
# from langchain_core.runnables.graph import MermaidDrawMethod
# from IPython.display import Image, display
import gradio as gr
import os
from langchain_groq import ChatGroq
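
# Dependencies (assumed PyPI names):
#   pip install langgraph langchain-core langchain-groq gradio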

# Define the State data structure
class State(TypedDict):
    query: str      # raw customer query
    category: str   # Technical / Billing / General, set by categorize
    sentiment: str  # Positive / Neutral / Negative, set by analyze_sentiment
    response: str   # final answer, set by one of the handler nodes
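
# Each graph node below receives the current State dict, updates one field,
# and returns it; LangGraph merges the returned dict back into the shared state.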

# Function to get the language model
def get_llm(api_key=None):
    # Fall back to the environment variable when no key is supplied.
    # Gradio passes an empty string for a blank textbox, so test truthiness
    # rather than only checking for None.
    if not api_key:
        api_key = os.getenv('GROQ_API_KEY')
    llm = ChatGroq(
        temperature=0,  # deterministic output keeps routing labels stable
        groq_api_key=api_key,
        model_name="llama-3.3-70b-versatile"
    )
    return llm
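
# Usage sketch: get_llm() with no argument reads GROQ_API_KEY from the
# environment; the Gradio handler further down passes the key typed into the UI.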

# Define the processing functions
def categorize(state: State, llm) -> State:
    prompt = ChatPromptTemplate.from_template(
        "Categorize the following customer query into one of these categories: "
        "Technical, Billing, General. Query: {query}"
    )
    chain = prompt | llm
    category = chain.invoke({"query": state["query"]}).content.strip()
    state["category"] = category
    return state

def analyze_sentiment(state: State, llm) -> State:
    prompt = ChatPromptTemplate.from_template(
        "Analyze the sentiment of the following customer query. "
        "Respond with either 'Positive', 'Neutral', or 'Negative'. Query: {query}"
    )
    chain = prompt | llm
    sentiment = chain.invoke({"query": state["query"]}).content.strip()
    state["sentiment"] = sentiment
    return state

def handle_technical(state: State, llm) -> State:
    prompt = ChatPromptTemplate.from_template(
        "Provide a technical support response to the following query: {query}"
    )
    chain = prompt | llm
    response = chain.invoke({"query": state["query"]}).content.strip()
    state["response"] = response
    return state

def handle_billing(state: State, llm) -> State:
    prompt = ChatPromptTemplate.from_template(
        "Provide a billing-related support response to the following query: {query}"
    )
    chain = prompt | llm
    response = chain.invoke({"query": state["query"]}).content.strip()
    state["response"] = response
    return state

def handle_general(state: State, llm) -> State:
    prompt = ChatPromptTemplate.from_template(
        "Provide a general support response to the following query: {query}"
    )
    chain = prompt | llm
    response = chain.invoke({"query": state["query"]}).content.strip()
    state["response"] = response
    return state

def escalate(state: State) -> State:
    state["response"] = "This query has been escalated to a human agent due to its negative sentiment."
    return state

def route_query(state: State) -> str:
    # Negative sentiment takes priority over category-based routing.
    # Substring matching tolerates extra wording in the model's labels
    # (e.g. "Category: Technical"), which an exact comparison would miss.
    if "negative" in state["sentiment"].lower():
        return "escalate"
    elif "technical" in state["category"].lower():
        return "handle_technical"
    elif "billing" in state["category"].lower():
        return "handle_billing"
    else:
        return "handle_general"

# Function to compile the workflow
def get_workflow(llm):
    workflow = StateGraph(State)

    # Nodes: each wraps one processing function, binding the llm argument
    workflow.add_node("categorize", lambda state: categorize(state, llm))
    workflow.add_node("analyze_sentiment", lambda state: analyze_sentiment(state, llm))
    workflow.add_node("handle_technical", lambda state: handle_technical(state, llm))
    workflow.add_node("handle_billing", lambda state: handle_billing(state, llm))
    workflow.add_node("handle_general", lambda state: handle_general(state, llm))
    workflow.add_node("escalate", escalate)

    # Edges: categorization feeds sentiment analysis, which routes conditionally
    workflow.add_edge("categorize", "analyze_sentiment")
    workflow.add_conditional_edges(
        "analyze_sentiment",
        route_query,
        {
            "handle_technical": "handle_technical",
            "handle_billing": "handle_billing",
            "handle_general": "handle_general",
            "escalate": "escalate",
        },
    )
    workflow.add_edge("handle_technical", END)
    workflow.add_edge("handle_billing", END)
    workflow.add_edge("handle_general", END)
    workflow.add_edge("escalate", END)

    workflow.set_entry_point("categorize")
    return workflow.compile()
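
# Sketch of a direct (non-UI) run, assuming GROQ_API_KEY is set in the
# environment; the query string is illustrative only:
#
#   app = get_workflow(get_llm())
#   result = app.invoke({"query": "I was charged twice this month."})
#   print(result["category"], result["sentiment"], result["response"])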

# Gradio interface function
def run_customer_support(query: str, api_key: str) -> Dict[str, str]:
    llm = get_llm(api_key)
    app = get_workflow(llm)
    result = app.invoke({"query": query})
    return {
        # "Query": query,
        # "Category": result.get("category", "").strip(),
        # "Sentiment": result.get("sentiment", "").strip(),
        "Response": result.get("response", "").strip()
    }

# Create the Gradio interface
gr_interface = gr.Interface(
    fn=run_customer_support,
    inputs=[
        gr.Textbox(lines=2, label="Customer Query", placeholder="Enter your customer support query here..."),
        # type="password" masks the key as it is typed
        gr.Textbox(label="GROQ API Key", placeholder="Enter your GROQ API key", type="password"),
    ],
    outputs=gr.JSON(label="Response"),
    title="Customer Support Chatbot",
    description="Enter your query to receive assistance.",
)

# Launch the Gradio interface
gr_interface.launch()