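# ReAct-style agent built on LangGraph: an "assistant" node bound to a set of
# custom tools, a ToolNode that executes the requested tool calls, and
# conditional routing between the two until the model produces a final answer.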
import os
from typing import TypedDict, Optional, Annotated

from dotenv import load_dotenv
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_mistralai import ChatMistralAI
from langchain_groq import ChatGroq
from langgraph.graph import StateGraph, START, END
from langchain_core.messages import AnyMessage, HumanMessage, SystemMessage
from langgraph.graph.message import add_messages
from langgraph.prebuilt import ToolNode, tools_condition

from custom_tools import custom_tools  # project-local list of tools bound to the LLM below
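
# Graph state shared by all nodes: the optional path of a file attached to the
# question, plus the conversation history (the add_messages reducer appends new
# messages instead of overwriting the list).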
class QuestionState(TypedDict):
    input_file: Optional[str]
    messages: Annotated[list[AnyMessage], add_messages]

class NodesReActAgent:
    def __init__(self, provider: str = "Google", model: str = "gemini-2.5-pro"):
        print("Initializing ReActAgent...")
        load_dotenv()

        # Set up the LLM based on the provider
        if provider == "Google":
            os.environ["GOOGLE_API_KEY"] = os.getenv("GOOGLE")
            llm = ChatGoogleGenerativeAI(model=model, temperature=0, max_retries=5)
        elif provider == "Mistral":
            os.environ["MISTRAL_API_KEY"] = os.getenv("MISTRAL")
            llm = ChatMistralAI(model=model, temperature=0, max_retries=5)
        elif provider == "Groq":
            os.environ["GROQ_API_KEY"] = os.getenv("GROQ")
            llm = ChatGroq(model=model, temperature=0, max_retries=5)
        else:
            raise ValueError(f"Unknown provider: {provider}")

        self.llm_with_tools = llm.bind_tools(custom_tools)
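
        # Assistant node: prepend the system prompt to the conversation so far
        # and let the tool-enabled LLM decide whether to answer or call a tool.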
        def assistant(state: QuestionState):
            input_file = state["input_file"]
            sys_prompt = f"""You are a general AI assistant. I will ask you a question. Report your thoughts, and finish your answer with the following template: FINAL ANSWER: [YOUR FINAL ANSWER].

YOUR FINAL ANSWER should be a number OR as few words as possible OR a comma separated list of numbers and/or strings. If you are asked for a number, DON'T use commas to write your number NOR units such as $ or percent signs unless specified otherwise. If you are asked for a string, DON'T use articles or abbreviations (e.g. for cities); capitalize the first letter, and write digits in plain text unless specified otherwise. If you are asked for a comma separated list, apply the above rules (except the first letter capitalization) to each element depending on whether it is a number or a string.

EXAMPLES:
- What is US President Obama's first name? FINAL ANSWER: Barack
- What are the 3 mandatory ingredients for pancakes? FINAL ANSWER: eggs, flour, milk
- What is the final cost of an invoice comprising a $345.00 product and a $355.00 product? Provide the answer with two decimals. FINAL ANSWER: 700.00
- How many pairs of chromosomes does a human cell contain? FINAL ANSWER: 23

You will be provided with tools to help you answer questions.
If you are asked to make a calculation, AVOID calculating by yourself and ABSOLUTELY use the appropriate tools.
If you are asked to find something in a list of things or people, prefer the wiki_search tool; otherwise, prefer the web_search tool. After using the web_search tool, open the first URL it returns with the url_search tool and check whether the answer is in the tool response. If it is, answer the question. If not, try the other links.

If needed, use one tool first, then feed its output into further reasoning and into the next tool call.

You have access to some optional files. Currently the loaded file is: {input_file}"""

            return {
                # Send the prompt as a SystemMessage rather than a bare string so
                # it is not interpreted as just another human turn.
                "messages": [self.llm_with_tools.invoke([SystemMessage(content=sys_prompt)] + state["messages"])],
                "input_file": state["input_file"],
            }

        # The graph
        builder = StateGraph(QuestionState)

        # Define nodes: these do the work
        builder.add_node("assistant", assistant)
        builder.add_node("tools", ToolNode(custom_tools))

        builder.add_edge(START, "assistant")
        builder.add_conditional_edges(
            "assistant",
            # If the latest message requires a tool, route to "tools";
            # otherwise, route to END and provide a direct response
            tools_condition,
        )
        builder.add_edge("tools", "assistant")

        self.react_graph = builder.compile()
        print(f"ReActAgent initialized with {provider} - {model}.")

    def __call__(self, question: str, input_file: str = "") -> str:
        input_msg = [HumanMessage(content=question)]
        out = self.react_graph.invoke({"messages": input_msg, "input_file": input_file})
        for o in out["messages"]:
            o.pretty_print()

        # The last message contains the agent's reply
        reply = out["messages"][-1].content
        # Strip the "FINAL ANSWER: " header if the model followed the template
        if "FINAL ANSWER: " in reply:
            reply = reply.split("FINAL ANSWER: ")[-1].strip()
        return reply
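

# A minimal usage sketch, not part of the original module. It assumes the .env
# file provides the GOOGLE / MISTRAL / GROQ keys read in __init__, and that
# custom_tools exposes the wiki_search / web_search / url_search tools the
# system prompt refers to.
if __name__ == "__main__":
    agent = NodesReActAgent(provider="Google", model="gemini-2.5-pro")
    answer = agent("How many pairs of chromosomes does a human cell contain?")
    print(answer)  # expected output: 23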