Commit
·
283e426
1
Parent(s):
809f87e
add tools to langgraph
Browse files- app.py +6 -3
- langgraph_dir/agent.py +21 -13
- langgraph_dir/custom_tools.py +50 -15
- requirements.txt +4 -1
app.py
CHANGED
|
@@ -10,8 +10,8 @@ DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space"
|
|
| 10 |
|
| 11 |
|
| 12 |
# --- Choice of framework (either "langgraph" or "llamaindex") ---
|
| 13 |
-
|
| 14 |
-
FRAMEWORK = 'llamaindex'
|
| 15 |
|
| 16 |
|
| 17 |
async def run_and_submit_all(profile: gr.OAuthProfile | None):
|
|
@@ -98,7 +98,10 @@ async def run_and_submit_all(profile: gr.OAuthProfile | None):
|
|
| 98 |
pass
|
| 99 |
|
| 100 |
# call the agent
|
| 101 |
-
|
|
|
|
|
|
|
|
|
|
| 102 |
answers_payload.append({"task_id": task_id, "submitted_answer": submitted_answer})
|
| 103 |
results_log.append({"Task ID": task_id, "Question": question_text, "Submitted Answer": submitted_answer})
|
| 104 |
agent.ctx.clear() # clear context for next question
|
|
|
|
| 10 |
|
| 11 |
|
| 12 |
# --- Choice of framework (either "langgraph" or "llamaindex") ---
|
| 13 |
+
FRAMEWORK = 'langgraph'
|
| 14 |
+
# FRAMEWORK = 'llamaindex'
|
| 15 |
|
| 16 |
|
| 17 |
async def run_and_submit_all(profile: gr.OAuthProfile | None):
|
|
|
|
| 98 |
pass
|
| 99 |
|
| 100 |
# call the agent
|
| 101 |
+
if FRAMEWORK == 'llamaindex':
|
| 102 |
+
submitted_answer = await agent(question_text)
|
| 103 |
+
else:
|
| 104 |
+
submitted_answer = agent(question_text)
|
| 105 |
answers_payload.append({"task_id": task_id, "submitted_answer": submitted_answer})
|
| 106 |
results_log.append({"Task ID": task_id, "Question": question_text, "Submitted Answer": submitted_answer})
|
| 107 |
agent.ctx.clear() # clear context for next question
|
langgraph_dir/agent.py
CHANGED
|
@@ -1,12 +1,15 @@
|
|
| 1 |
-
|
| 2 |
|
|
|
|
| 3 |
from langchain_openai import ChatOpenAI
|
| 4 |
from langgraph.graph import MessagesState
|
| 5 |
from langchain_core.messages import SystemMessage, HumanMessage, ToolMessage
|
| 6 |
from langgraph.graph import StateGraph, START, END
|
|
|
|
|
|
|
| 7 |
|
| 8 |
from .prompt import system_prompt
|
| 9 |
-
from .custom_tools import multiply, add, divide
|
| 10 |
|
| 11 |
|
| 12 |
class LangGraphAgent:
|
|
@@ -16,11 +19,18 @@ class LangGraphAgent:
|
|
| 16 |
show_prompt=True):
|
| 17 |
|
| 18 |
# =========== LLM definition ===========
|
| 19 |
-
llm = ChatOpenAI(model=model_name, temperature=0)
|
| 20 |
print(f"LangGraphAgent initialized with model \"{model_name}\"")
|
| 21 |
|
| 22 |
# =========== Augment the LLM with tools ===========
|
| 23 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 24 |
tools_by_name = {tool.name: tool for tool in tools}
|
| 25 |
llm_with_tools = llm.bind_tools(tools)
|
| 26 |
|
|
@@ -95,17 +105,15 @@ class LangGraphAgent:
|
|
| 95 |
# Compile the agent
|
| 96 |
self.agent = agent_builder.compile()
|
| 97 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 98 |
|
| 99 |
-
|
| 100 |
-
|
| 101 |
-
|
| 102 |
-
# print(tool.metadata.description)
|
| 103 |
|
| 104 |
-
# if show_prompt:
|
| 105 |
-
# prompt_dict = self.agent.get_prompts()
|
| 106 |
-
# for k, v in prompt_dict.items():
|
| 107 |
-
# print("\n" + "="*30 + f" Prompt: {k} " + "="*30)
|
| 108 |
-
# print(v.template)
|
| 109 |
|
| 110 |
def __call__(self, question: str) -> str:
|
| 111 |
print("\n\n"+"*"*50)
|
|
|
|
| 1 |
+
import json
|
| 2 |
|
| 3 |
+
from typing import Literal
|
| 4 |
from langchain_openai import ChatOpenAI
|
| 5 |
from langgraph.graph import MessagesState
|
| 6 |
from langchain_core.messages import SystemMessage, HumanMessage, ToolMessage
|
| 7 |
from langgraph.graph import StateGraph, START, END
|
| 8 |
+
from langchain.agents import load_tools
|
| 9 |
+
from langchain_community.tools.riza.command import ExecPython
|
| 10 |
|
| 11 |
from .prompt import system_prompt
|
| 12 |
+
from .custom_tools import multiply, add, subtract, divide, modulus, power
|
| 13 |
|
| 14 |
|
| 15 |
class LangGraphAgent:
|
|
|
|
| 19 |
show_prompt=True):
|
| 20 |
|
| 21 |
# =========== LLM definition ===========
|
| 22 |
+
llm = ChatOpenAI(model=model_name, temperature=0) # needs OPENAI_API_KEY
|
| 23 |
print(f"LangGraphAgent initialized with model \"{model_name}\"")
|
| 24 |
|
| 25 |
# =========== Augment the LLM with tools ===========
|
| 26 |
+
community_tool_names = [
|
| 27 |
+
"ddg-search", # DuckDuckGo search
|
| 28 |
+
"wikipedia",
|
| 29 |
+
]
|
| 30 |
+
community_tools = load_tools(community_tool_names)
|
| 31 |
+
community_tools += [ExecPython()] # Riza code interpreter (needs RIZA_API_KEY) (not supported by load_tools)
|
| 32 |
+
custom_tools = [multiply, add, subtract, divide, modulus, power]
|
| 33 |
+
tools = community_tools + custom_tools
|
| 34 |
tools_by_name = {tool.name: tool for tool in tools}
|
| 35 |
llm_with_tools = llm.bind_tools(tools)
|
| 36 |
|
|
|
|
| 105 |
# Compile the agent
|
| 106 |
self.agent = agent_builder.compile()
|
| 107 |
|
| 108 |
+
if show_tools_desc:
|
| 109 |
+
for i, tool in enumerate(llm_with_tools.kwargs['tools']):
|
| 110 |
+
print("\n" + "="*30 + f" Tool {i+1} " + "="*30)
|
| 111 |
+
print(json.dumps(tool[tool['type']], indent=4))
|
| 112 |
|
| 113 |
+
if show_prompt:
|
| 114 |
+
print("\n" + "="*30 + f" System prompt " + "="*30)
|
| 115 |
+
print(system_prompt)
|
|
|
|
| 116 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 117 |
|
| 118 |
def __call__(self, question: str) -> str:
|
| 119 |
print("\n\n"+"*"*50)
|
langgraph_dir/custom_tools.py
CHANGED
|
@@ -1,33 +1,68 @@
|
|
| 1 |
from langchain_core.tools import tool
|
| 2 |
|
| 3 |
@tool
|
| 4 |
-
def multiply(a:
|
| 5 |
-
"""
|
| 6 |
-
|
| 7 |
Args:
|
| 8 |
-
a: first
|
| 9 |
-
b: second
|
| 10 |
"""
|
| 11 |
return a * b
|
| 12 |
|
| 13 |
|
| 14 |
@tool
|
| 15 |
-
def add(a:
|
| 16 |
-
"""
|
| 17 |
-
|
| 18 |
Args:
|
| 19 |
-
a: first
|
| 20 |
-
b: second
|
| 21 |
"""
|
| 22 |
return a + b
|
| 23 |
|
| 24 |
|
| 25 |
@tool
|
| 26 |
-
def
|
| 27 |
-
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 28 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 29 |
Args:
|
| 30 |
-
a: first
|
| 31 |
-
b: second
|
| 32 |
"""
|
| 33 |
-
return a
|
|
|
|
| 1 |
from langchain_core.tools import tool
|
| 2 |
|
| 3 |
@tool
|
| 4 |
+
def multiply(a: float, b: float) -> float:
|
| 5 |
+
"""
|
| 6 |
+
Multiplies two numbers.
|
| 7 |
Args:
|
| 8 |
+
a (float): the first number
|
| 9 |
+
b (float): the second number
|
| 10 |
"""
|
| 11 |
return a * b
|
| 12 |
|
| 13 |
|
| 14 |
@tool
|
| 15 |
+
def add(a: float, b: float) -> float:
|
| 16 |
+
"""
|
| 17 |
+
Adds two numbers.
|
| 18 |
Args:
|
| 19 |
+
a (float): the first number
|
| 20 |
+
b (float): the second number
|
| 21 |
"""
|
| 22 |
return a + b
|
| 23 |
|
| 24 |
|
| 25 |
@tool
|
| 26 |
+
def subtract(a: float, b: float) -> int:
|
| 27 |
+
"""
|
| 28 |
+
Subtracts two numbers.
|
| 29 |
+
Args:
|
| 30 |
+
a (float): the first number
|
| 31 |
+
b (float): the second number
|
| 32 |
+
"""
|
| 33 |
+
return a - b
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
@tool
|
| 37 |
+
def divide(a: float, b: float) -> float:
|
| 38 |
+
"""
|
| 39 |
+
Divides two numbers.
|
| 40 |
+
Args:
|
| 41 |
+
a (float): the first float number
|
| 42 |
+
b (float): the second float number
|
| 43 |
+
"""
|
| 44 |
+
if b == 0:
|
| 45 |
+
raise ValueError("Cannot divided by zero.")
|
| 46 |
+
return a / b
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
@tool
|
| 50 |
+
def modulus(a: int, b: int) -> int:
|
| 51 |
+
"""
|
| 52 |
+
Get the modulus of two numbers.
|
| 53 |
+
Args:
|
| 54 |
+
a (int): the first number
|
| 55 |
+
b (int): the second number
|
| 56 |
+
"""
|
| 57 |
+
return a % b
|
| 58 |
+
|
| 59 |
|
| 60 |
+
@tool
|
| 61 |
+
def power(a: float, b: float) -> float:
|
| 62 |
+
"""
|
| 63 |
+
Get the power of two numbers.
|
| 64 |
Args:
|
| 65 |
+
a (float): the first number
|
| 66 |
+
b (float): the second number
|
| 67 |
"""
|
| 68 |
+
return a**b
|
requirements.txt
CHANGED
|
@@ -7,4 +7,7 @@ llama_index.tools.duckduckgo
|
|
| 7 |
llama_index.tools.code_interpreter
|
| 8 |
langchain
|
| 9 |
langgraph
|
| 10 |
-
langchain-openai
|
|
|
|
|
|
|
|
|
|
|
|
| 7 |
llama_index.tools.code_interpreter
|
| 8 |
langchain
|
| 9 |
langgraph
|
| 10 |
+
langchain-openai
|
| 11 |
+
langchain-community
|
| 12 |
+
duckduckgo-search
|
| 13 |
+
rizaio
|