|
|
"""Load agent.""" |
|
|
|
|
|
from typing import Any, Optional, Sequence |
|
|
|
|
|
from langchain_core._api import deprecated |
|
|
from langchain_core.callbacks import BaseCallbackManager |
|
|
from langchain_core.language_models import BaseLanguageModel |
|
|
from langchain_core.tools import BaseTool |
|
|
|
|
|
from langchain._api.deprecation import AGENT_DEPRECATION_WARNING |
|
|
from langchain.agents.agent import AgentExecutor |
|
|
from langchain.agents.agent_types import AgentType |
|
|
from langchain.agents.loading import AGENT_TO_CLASS, load_agent |
|
|
|
|
|
|
|
|
@deprecated(
    "0.1.0",
    message=AGENT_DEPRECATION_WARNING,
    removal="1.0",
)
def initialize_agent(
    tools: Sequence[BaseTool],
    llm: BaseLanguageModel,
    agent: Optional[AgentType] = None,
    callback_manager: Optional[BaseCallbackManager] = None,
    agent_path: Optional[str] = None,
    agent_kwargs: Optional[dict] = None,
    *,
    tags: Optional[Sequence[str]] = None,
    **kwargs: Any,
) -> AgentExecutor:
    """Build an ``AgentExecutor`` from a set of tools and a language model.

    The agent is selected either by type (``agent``) or by loading a
    serialized agent (``agent_path``). The two selectors are mutually
    exclusive; when neither is supplied,
    ``AgentType.ZERO_SHOT_REACT_DESCRIPTION`` is used.

    Args:
        tools: Tools made available to the agent and to the executor.
        llm: Language model that drives the agent.
        agent: Agent type to construct; must be a key of ``AGENT_TO_CLASS``.
            Defaults to None.
        callback_manager: Callback manager forwarded to the agent
            constructor/loader and to the executor. Defaults to None.
        agent_path: Location of a serialized agent, handed to ``load_agent``.
            Defaults to None.
        agent_kwargs: Extra keyword arguments for the agent class's
            ``from_llm_and_tools``. Only used when the agent is selected by
            type. Defaults to None.
        tags: Tags passed on to the executor. They are copied into a new
            list, so the caller's sequence is never mutated. The agent type
            is appended to that copy when it is available. Defaults to None.
        kwargs: Extra keyword arguments forwarded to
            ``AgentExecutor.from_agent_and_tools``.

    Returns:
        The executor wrapping the constructed or loaded agent.

    Raises:
        ValueError: If both ``agent`` and ``agent_path`` are given.
        ValueError: If ``agent`` is not a key of ``AGENT_TO_CLASS``.
        ValueError: If neither selector is set once defaults are applied.
            This is a defensive guard; the defaulting step should make it
            unreachable.
    """
    # Work on a private copy so the caller's tag sequence stays untouched.
    run_tags = [] if not tags else list(tags)

    if agent is None and agent_path is None:
        # Nothing selected: fall back to the default agent type.
        agent = AgentType.ZERO_SHOT_REACT_DESCRIPTION
    elif agent is not None and agent_path is not None:
        raise ValueError(
            "Both `agent` and `agent_path` are specified, "
            "but at most only one should be."
        )

    if agent is not None:
        # Selection by agent type.
        if agent not in AGENT_TO_CLASS:
            raise ValueError(
                f"Got unknown agent type: {agent}. "
                f"Valid types are: {AGENT_TO_CLASS.keys()}."
            )
        # Tag with the enum's underlying value when an AgentType member is
        # given, otherwise with the key itself.
        if isinstance(agent, AgentType):
            type_tag = agent.value
        else:
            type_tag = agent
        run_tags.append(type_tag)

        agent_class = AGENT_TO_CLASS[agent]
        extra_agent_args = {} if not agent_kwargs else agent_kwargs
        built_agent = agent_class.from_llm_and_tools(
            llm,
            tools,
            callback_manager=callback_manager,
            **extra_agent_args,
        )
    elif agent_path is not None:
        # Selection by serialized agent.
        built_agent = load_agent(
            agent_path,
            llm=llm,
            tools=tools,
            callback_manager=callback_manager,
        )
        try:
            loaded_type = built_agent._agent_type
        except NotImplementedError:
            # The loaded agent does not expose a type; leave the tags as-is.
            pass
        else:
            run_tags.append(loaded_type)
    else:
        raise ValueError(
            "Somehow both `agent` and `agent_path` are None, "
            "this should never happen."
        )

    return AgentExecutor.from_agent_and_tools(
        agent=built_agent,
        tools=tools,
        callback_manager=callback_manager,
        tags=run_tags,
        **kwargs,
    )
|
|
|