Michele De Stefano
Now using Tavily for web searches. It's a lot more powerful than DuckDuckGo
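For reference, the Tavily tool can be tried on its own before wiring it into the graph. A minimal sketch, assuming `langchain-tavily` is installed and `TAVILY_API_KEY` is set in the environment (the query string is made up):

import dotenv
from langchain_tavily import TavilySearch

dotenv.load_dotenv()  # expects TAVILY_API_KEY in .env or the environment

search = TavilySearch(topic="general", max_results=5, include_answer="advanced")
# Invoking the tool with a query returns a dict of search results (plus an
# LLM-generated "answer" field when include_answer is enabled).
print(search.invoke({"query": "LangGraph conditional edges documentation"}))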
import datetime as dt
import dotenv
import re
from typing import Any, Literal

from langchain_community.tools import DuckDuckGoSearchResults
from langchain_core.messages import SystemMessage, AnyMessage
from langchain_core.runnables import Runnable
from langchain_core.tools import BaseTool
from langchain_ollama import ChatOllama
from langchain_tavily import TavilySearch, TavilyExtract
from langgraph.constants import START, END
from langgraph.graph import MessagesState, StateGraph
from langgraph.graph.graph import CompiledGraph
from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint
from langgraph.prebuilt import ToolNode
from pydantic import BaseModel

from tools import (
    execute_python_script,
    get_excel_table_content,
    get_youtube_video_transcript,
    reverse_string,
    sum_list,
    transcribe_audio_file,
    web_page_info_retriever,
    youtube_video_to_frame_captions,
)

dotenv.load_dotenv()
class AgentFactory:
    """
    A factory for the agent. It is assumed that an Ollama server is running
    on the machine where the factory is used.
    """

    __system_prompt: str = (
        "You have to answer some test questions.\n"
        "Sometimes auxiliary files may be attached to the question.\n"
        "Each question is a JSON string with the following fields:\n"
        "1. task_id: unique hash identifier of the question.\n"
        "2. question: the text of the question.\n"
        "3. Level: ignore this field.\n"
        "4. file_name: the name of the file needed to answer the question. "
        "This is empty if the question does not refer to any file. "
        "IMPORTANT: The text of the question may mention a file name that "
        "is different from the one reported in the \"file_name\" JSON "
        "field. YOU HAVE TO IGNORE THE FILE NAME MENTIONED IN \"question\" "
        "AND YOU MUST USE THE FILE NAME PROVIDED IN THE \"file_name\" "
        "FIELD.\n"
        "\n"
        "Achieve the solution by dividing your reasoning into steps, and\n"
        "provide an explanation for each step.\n"
        "\n"
        "The format of your final answer must be\n"
        "\n"
        "<ANSWER>your_final_answer</ANSWER>, where your_final_answer is a\n"
        "number OR as few words as possible OR a comma separated list of\n"
        "numbers and/or strings. If you are asked for a number, don't use\n"
        "commas to write your number, nor units such as $ or percent\n"
        "signs, unless specified otherwise. If you are asked for a string,\n"
        "don't use articles or abbreviations (e.g. for cities), and write\n"
        "the digits in plain text unless specified otherwise. If you are\n"
        "asked for a comma separated list, apply the above rules depending\n"
        "on whether the element to be put in the list is a number or a\n"
        "string.\n"
        "ALWAYS PRESENT THE FINAL ANSWER BETWEEN THE <ANSWER> AND\n"
        "</ANSWER> TAGS.\n"
        "\n"
        "When, for achieving the solution, you have to perform a sum,\n"
        "DON'T try to do that yourself. Exploit the tool that is able to\n"
        "sum a list of numbers. If you have to sum the results of previous\n"
        "sums, call the same tool again.\n"
        "You are advised to cycle between reasoning and tool calling,\n"
        "possibly multiple times. Provide an answer only when you are sure\n"
        "you don't have to call any tool again.\n"
        "\n"
        f"If you need it, today's date is {dt.date.today()}."
    )
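
    # For illustration only: a hypothetical example of the question payload
    # described by the prompt above (all values are made up):
    #
    # {"task_id": "8f3a9c1e-2b4d-4e5f-9a6b-7c8d9e0f1a2b",
    #  "question": "How many pages does the attached document have?",
    #  "Level": 1,
    #  "file_name": "document.pdf"}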
    __llm: Runnable
    __tools: list[BaseTool]

    def __init__(
        self,
        model: str = "qwen2.5-coder:32b",
        # model: str = "mistral-small3.1",
        # model: str = "phi4-mini",
        temperature: float = 0.0,
        num_ctx: int = 8192
    ) -> None:
        """
        Constructor.

        Args:
            model: The name of the Ollama model to use.
            temperature: Temperature parameter.
            num_ctx: Size of the context window used to generate the
                next token.
        """
        # search_tool = DuckDuckGoSearchResults(
        #     description=(
        #         "A wrapper around Duck Duck Go Search. Useful for when "
        #         "you need to answer questions about information you can "
        #         "find on the web. Input should be a search query. It is "
        #         "advisable to use this tool to retrieve web page URLs "
        #         "and use another tool to analyze the pages. If the web "
        #         "source is suggested by the user query, prefer "
        #         "retrieving information from that source. For example, "
        #         "the query may suggest to search on Wikipedia or Medium. "
        #         "In those cases, prepend the query with "
        #         "'site: <name of the source>'. For example: "
        #         "'site: wikipedia.org'"
        #     ),
        #     output_format="list"
        # )
        search_tool = TavilySearch(
            topic="general",
            max_results=5,
            include_answer="advanced",
        )
        # search_tool.with_retry()
        extract_tool = TavilyExtract(
            extract_depth="advanced",
            include_images=False,
        )
        self.__tools = [
            execute_python_script,
            get_excel_table_content,
            get_youtube_video_transcript,
            reverse_string,
            search_tool,
            extract_tool,
            sum_list,
            transcribe_audio_file,
            # web_page_info_retriever,
            youtube_video_to_frame_captions
        ]
        self.__llm = ChatOllama(
            model=model,
            temperature=temperature,
            num_ctx=num_ctx
        ).bind_tools(tools=self.__tools)
        # llm_endpoint = HuggingFaceEndpoint(
        #     repo_id="Qwen/Qwen2.5-72B-Instruct",
        #     task="text-generation",
        #     max_new_tokens=num_ctx,
        #     do_sample=False,
        #     repetition_penalty=1.03,
        #     temperature=temperature,
        # )
        #
        # self.__llm = (
        #     ChatHuggingFace(llm=llm_endpoint)
        #     .bind_tools(tools=self.__tools)
        # )

    def __run_llm(self, state: MessagesState) -> dict[str, Any]:
        answer = self.__llm.invoke(state["messages"])
        # Remove the thinking pattern, if present
        pattern = r'\n*<think>.*?</think>\n*'
        answer.content = re.sub(
            pattern, "", answer.content, flags=re.DOTALL
        )
        return {"messages": [answer]}
    @staticmethod
    def __extract_last_message(
        state: list[AnyMessage] | dict[str, Any] | BaseModel,
        messages_key: str
    ) -> AnyMessage:
        if isinstance(state, list):
            last_message = state[-1]
        elif isinstance(state, dict) and (
            messages := state.get(messages_key, [])
        ):
            last_message = messages[-1]
        elif messages := getattr(state, messages_key, []):
            last_message = messages[-1]
        else:
            raise ValueError(
                f"No messages found in input state to tool_edge: {state}"
            )
        return last_message

    def __route_from_llm(
        self,
        state: list[AnyMessage] | dict[str, Any] | BaseModel,
        messages_key: str = "messages",
    ) -> Literal["tools", "extract_final_answer"]:
        ai_message = self.__extract_last_message(state, messages_key)
        if hasattr(ai_message, "tool_calls") and len(ai_message.tool_calls) > 0:
            return "tools"
        return "extract_final_answer"

    @staticmethod
    def __extract_final_answer(state: MessagesState) -> dict[str, Any]:
        last_message = state["messages"][-1].content
        pattern = r"<ANSWER>(?P<answer>.*?)</ANSWER>"
        m = re.search(pattern, last_message, flags=re.DOTALL)
        answer = m.group("answer").strip() if m else ""
        return {"messages": [answer]}
    def system_prompt(self) -> SystemMessage:
        """
        Returns:
            The system prompt to use with the agent.
        """
        return SystemMessage(content=self.__system_prompt)

    def get(self) -> CompiledGraph:
        """
        Factory method.

        Returns:
            The instance of the agent.
        """
        graph_builder = StateGraph(MessagesState)
        graph_builder.add_node("LLM", self.__run_llm)
        graph_builder.add_node("tools", ToolNode(tools=self.__tools))
        graph_builder.add_node(
            "extract_final_answer",
            self.__extract_final_answer
        )
        graph_builder.add_edge(start_key=START, end_key="LLM")
        graph_builder.add_conditional_edges(
            source="LLM",
            path=self.__route_from_llm,
            path_map={
                "tools": "tools",
                "extract_final_answer": "extract_final_answer"
            }
        )
        graph_builder.add_edge(start_key="tools", end_key="LLM")
        graph_builder.add_edge(start_key="extract_final_answer", end_key=END)
        return graph_builder.compile()
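
For completeness, a minimal usage sketch. Assumptions: an Ollama server is running locally with the chosen model pulled, the repo's `tools` module is importable, and the question payload below is made up:

if __name__ == "__main__":
    factory = AgentFactory()
    agent = factory.get()

    # Hypothetical question payload, mirroring the format the system
    # prompt describes.
    question = (
        '{"task_id": "abc123", "question": "What is 17 + 25?", '
        '"Level": 1, "file_name": ""}'
    )
    result = agent.invoke(
        {"messages": [factory.system_prompt(), ("user", question)]}
    )
    # The extract_final_answer node appends the text found between the
    # <ANSWER> tags as the last message.
    print(result["messages"][-1].content)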