|
|
from __future__ import annotations |
|
|
|
|
|
import asyncio |
|
|
import logging |
|
|
import time |
|
|
from typing import ( |
|
|
TYPE_CHECKING, |
|
|
Any, |
|
|
AsyncIterator, |
|
|
Dict, |
|
|
Iterator, |
|
|
List, |
|
|
Optional, |
|
|
Tuple, |
|
|
Union, |
|
|
) |
|
|
from uuid import UUID |
|
|
|
|
|
from langchain_core.agents import ( |
|
|
AgentAction, |
|
|
AgentFinish, |
|
|
AgentStep, |
|
|
) |
|
|
from langchain_core.callbacks import ( |
|
|
AsyncCallbackManager, |
|
|
AsyncCallbackManagerForChainRun, |
|
|
CallbackManager, |
|
|
CallbackManagerForChainRun, |
|
|
Callbacks, |
|
|
) |
|
|
from langchain_core.load.dump import dumpd |
|
|
from langchain_core.outputs import RunInfo |
|
|
from langchain_core.runnables.utils import AddableDict |
|
|
from langchain_core.tools import BaseTool |
|
|
from langchain_core.utils.input import get_color_mapping |
|
|
|
|
|
from langchain.schema import RUN_KEY |
|
|
from langchain.utilities.asyncio import asyncio_timeout |
|
|
|
|
|
if TYPE_CHECKING: |
|
|
from langchain.agents.agent import AgentExecutor, NextStepOutput |
|
|
|
|
|
# Module-level logger: the iterator logs state resets and per-step progress at
# DEBUG level and premature (early-stopping) termination at WARNING level.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
class AgentExecutorIterator:
    """Step-wise iterator over an ``AgentExecutor`` run.

    Iterating synchronously (``__iter__``) or asynchronously (``__aiter__``)
    drives the agent loop one step at a time. It yields ``AddableDict`` outputs:
    intermediate-step outputs, or individual actions/steps when
    ``yield_actions`` is set, followed by one final output unless an exception
    is raised. Every iteration starts from a fresh state (see ``reset``).
    """

    def __init__(
        self,
        agent_executor: AgentExecutor,
        inputs: Any,
        callbacks: Callbacks = None,
        *,
        tags: Optional[list[str]] = None,
        metadata: Optional[Dict[str, Any]] = None,
        run_name: Optional[str] = None,
        run_id: Optional[UUID] = None,
        include_run_info: bool = False,
        yield_actions: bool = False,
    ) -> None:
        """
        Initialize the AgentExecutorIterator with the given AgentExecutor,
        inputs, and optional callbacks.

        Args:
            agent_executor (AgentExecutor): The AgentExecutor to iterate over.
            inputs (Any): The inputs to the AgentExecutor.
            callbacks (Callbacks, optional): The callbacks to use during iteration.
                Defaults to None.
            tags (Optional[list[str]], optional): The tags to use during iteration.
                Defaults to None.
            metadata (Optional[Dict[str, Any]], optional): The metadata to use
                during iteration. Defaults to None.
            run_name (Optional[str], optional): The name of the run. Defaults to None.
            run_id (Optional[UUID], optional): The ID of the run. Defaults to None.
            include_run_info (bool, optional): Whether to include run info
                in the output. Defaults to False.
            yield_actions (bool, optional): Whether to yield actions as they
                are generated. Defaults to False.
        """
        # Set the private field directly: the public ``agent_executor`` setter
        # re-prepares ``self.inputs``, which does not exist yet at this point.
        self._agent_executor = agent_executor
        # Goes through the ``inputs`` setter, i.e. ``agent_executor.prep_inputs``.
        self.inputs = inputs
        self.callbacks = callbacks
        self.tags = tags
        self.metadata = metadata
        self.run_name = run_name
        self.run_id = run_id
        self.include_run_info = include_run_info
        self.yield_actions = yield_actions
        self.reset()

    _inputs: Dict[str, str]
    callbacks: Callbacks
    tags: Optional[list[str]]
    metadata: Optional[Dict[str, Any]]
    run_name: Optional[str]
    run_id: Optional[UUID]
    include_run_info: bool
    yield_actions: bool

    @property
    def inputs(self) -> Dict[str, str]:
        """The inputs to the AgentExecutor, as returned by ``prep_inputs``."""
        return self._inputs

    @inputs.setter
    def inputs(self, inputs: Any) -> None:
        self._inputs = self.agent_executor.prep_inputs(inputs)

    @property
    def agent_executor(self) -> AgentExecutor:
        """The AgentExecutor to iterate over."""
        return self._agent_executor

    @agent_executor.setter
    def agent_executor(self, agent_executor: AgentExecutor) -> None:
        self._agent_executor = agent_executor
        # Re-run input preparation against the newly assigned executor.
        self.inputs = self.inputs

    @property
    def name_to_tool_map(self) -> Dict[str, BaseTool]:
        """A mapping of tool names to tools."""
        return {tool.name: tool for tool in self.agent_executor.tools}

    @property
    def color_mapping(self) -> Dict[str, str]:
        """A mapping of tool names to colors (green and red are reserved)."""
        return get_color_mapping(
            [tool.name for tool in self.agent_executor.tools],
            excluded_colors=["green", "red"],
        )

    def reset(self) -> None:
        """
        Reset the iterator to its initial state, clearing intermediate steps,
        iterations, and time elapsed.
        """
        logger.debug("(Re)setting AgentExecutorIterator to fresh state")
        self.intermediate_steps: list[tuple[AgentAction, str]] = []
        self.iterations = 0
        # time_elapsed is measured relative to start_time (see update_iterations).
        self.time_elapsed = 0.0
        self.start_time = time.time()

    def update_iterations(self) -> None:
        """
        Increment the number of iterations and update the time elapsed.
        """
        self.iterations += 1
        self.time_elapsed = time.time() - self.start_time
        logger.debug(
            f"Agent Iterations: {self.iterations} ({self.time_elapsed:.2f}s elapsed)"
        )

    def make_final_outputs(
        self,
        outputs: Dict[str, Any],
        run_manager: Union[CallbackManagerForChainRun, AsyncCallbackManagerForChainRun],
    ) -> AddableDict:
        """Prepare the final outputs of a run.

        Passes ``outputs`` through ``agent_executor.prep_outputs`` with
        ``return_only_outputs=True``. When ``include_run_info`` is set, it also
        attaches a ``RunInfo`` carrying the run id under ``RUN_KEY``.
        """
        prepared_outputs = AddableDict(
            self.agent_executor.prep_outputs(
                self.inputs, outputs, return_only_outputs=True
            )
        )
        if self.include_run_info:
            prepared_outputs[RUN_KEY] = RunInfo(run_id=run_manager.run_id)
        return prepared_outputs

    @staticmethod
    def _chunk_output(chunk: Any) -> Optional[AddableDict]:
        """Wrap a streamed ``AgentAction``/``AgentStep`` chunk as an output dict.

        Returns None for any other chunk type; such chunks are not streamed and
        are only handled once the whole step has been consumed.
        """
        if isinstance(chunk, AgentAction):
            return AddableDict(actions=[chunk], messages=chunk.messages)
        if isinstance(chunk, AgentStep):
            return AddableDict(steps=[chunk], messages=chunk.messages)
        return None

    def __iter__(self: "AgentExecutorIterator") -> Iterator[AddableDict]:
        """Run the agent loop synchronously, yielding outputs step by step.

        Yields one of two kinds of output per step:
        - intermediate-step outputs, when ``yield_actions`` is not set;
        - per-action / per-step outputs as they are produced, when it is set.

        It then yields one final output. That output is either the agent's
        finish or a tool's direct return, or the early-stopping response once
        ``agent_executor._should_continue`` returns False. Any exception raised
        while the run is in progress is reported through ``on_chain_error`` and
        re-raised.
        """
        logger.debug("Initialising AgentExecutorIterator")
        self.reset()
        callback_manager = CallbackManager.configure(
            self.callbacks,
            self.agent_executor.callbacks,
            self.agent_executor.verbose,
            self.tags,
            self.agent_executor.tags,
            self.metadata,
            self.agent_executor.metadata,
        )
        run_manager = callback_manager.on_chain_start(
            dumpd(self.agent_executor),
            self.inputs,
            self.run_id,
            name=self.run_name,
        )
        final_output: Optional[AddableDict] = None
        try:
            while self.agent_executor._should_continue(
                self.iterations, self.time_elapsed
            ):
                # Plan and execute the next step, collecting every chunk so the
                # step can be consumed as a whole below.
                next_step_seq: NextStepOutput = []
                for chunk in self.agent_executor._iter_next_step(
                    self.name_to_tool_map,
                    self.color_mapping,
                    self.inputs,
                    self.intermediate_steps,
                    run_manager,
                ):
                    next_step_seq.append(chunk)
                    if self.yield_actions:
                        chunk_output = self._chunk_output(chunk)
                        if chunk_output is not None:
                            yield chunk_output

                next_step = self.agent_executor._consume_next_step(next_step_seq)
                self.update_iterations()

                output = self._process_next_step_output(next_step, run_manager)
                if "intermediate_step" not in output:
                    # Final output: _return has already fired on_chain_end. It is
                    # yielded outside the try so that a consumer closing the
                    # generator at this item is not reported as a chain error.
                    final_output = output
                    break
                # For backwards compatibility, intermediate steps are yielded only
                # when individual actions are not being streamed.
                if not self.yield_actions:
                    yield output

            if final_output is None:
                # Iteration/time limits exhausted. The early-stopping response
                # is built inside the try so that its failures are reported too.
                final_output = self._stop(run_manager)
        except BaseException as e:
            run_manager.on_chain_error(e)
            raise

        yield final_output

    async def __aiter__(self) -> AsyncIterator[AddableDict]:
        """Run the agent loop asynchronously, yielding outputs step by step.

        Same contract as ``__iter__``. In addition,
        ``agent_executor.max_execution_time`` is enforced through
        ``asyncio_timeout``. A ``TimeoutError`` raised inside the loop ends it,
        and the early-stopping response is yielded as the final output.
        """
        logger.debug("Initialising AgentExecutorIterator (async)")
        self.reset()
        callback_manager = AsyncCallbackManager.configure(
            self.callbacks,
            self.agent_executor.callbacks,
            self.agent_executor.verbose,
            self.tags,
            self.agent_executor.tags,
            self.metadata,
            self.agent_executor.metadata,
        )
        run_manager = await callback_manager.on_chain_start(
            dumpd(self.agent_executor),
            self.inputs,
            self.run_id,
            name=self.run_name,
        )
        final_output: Optional[AddableDict] = None
        try:
            try:
                # NOTE(review): the intermediate yields below happen inside the
                # timeout scope, so time the consumer spends between items also
                # counts toward max_execution_time. Confirm this is intended.
                async with asyncio_timeout(self.agent_executor.max_execution_time):
                    while self.agent_executor._should_continue(
                        self.iterations, self.time_elapsed
                    ):
                        next_step_seq: NextStepOutput = []
                        async for chunk in self.agent_executor._aiter_next_step(
                            self.name_to_tool_map,
                            self.color_mapping,
                            self.inputs,
                            self.intermediate_steps,
                            run_manager,
                        ):
                            next_step_seq.append(chunk)
                            if self.yield_actions:
                                chunk_output = self._chunk_output(chunk)
                                if chunk_output is not None:
                                    yield chunk_output

                        next_step = self.agent_executor._consume_next_step(
                            next_step_seq
                        )
                        self.update_iterations()

                        output = await self._aprocess_next_step_output(
                            next_step, run_manager
                        )
                        if "intermediate_step" not in output:
                            # Final output (on_chain_end already fired); yielded
                            # outside the try, as in __iter__.
                            final_output = output
                            break
                        if not self.yield_actions:
                            yield output
            except (TimeoutError, asyncio.TimeoutError):
                # A timeout ends the loop; the early-stopping response below
                # becomes the final output (unless one was already produced).
                logger.debug("Agent loop timed out; switching to early stopping")

            if final_output is None:
                # Limits exhausted or timed out. Built inside the outer try so
                # that failures in the stop path are reported via on_chain_error.
                final_output = await self._astop(run_manager)
        except BaseException as e:
            await run_manager.on_chain_error(e)
            raise

        yield final_output

    def _record_step_output(
        self, next_step_output: List[Tuple[AgentAction, str]]
    ) -> Optional[AgentFinish]:
        """Record executed steps and return a tool's direct return, if any.

        Appends ``next_step_output`` to ``intermediate_steps``. A direct tool
        return is only looked up when the step produced exactly one action.
        """
        self.intermediate_steps.extend(next_step_output)
        logger.debug("Updated intermediate_steps with step output")
        if len(next_step_output) == 1:
            return self.agent_executor._get_tool_return(next_step_output[0])
        return None

    def _process_next_step_output(
        self,
        next_step_output: Union[AgentFinish, List[Tuple[AgentAction, str]]],
        run_manager: CallbackManagerForChainRun,
    ) -> AddableDict:
        """
        Process the output of the next step, handling AgentFinish and tool
        return cases.

        Returns the final outputs (via ``_return``) when the agent finished or
        a tool returned directly. Otherwise returns
        ``AddableDict(intermediate_step=...)``.
        """
        logger.debug("Processing output of Agent loop step")
        if isinstance(next_step_output, AgentFinish):
            logger.debug(
                "Hit AgentFinish: _return -> on_chain_end -> run final output logic"
            )
            return self._return(next_step_output, run_manager=run_manager)

        tool_return = self._record_step_output(next_step_output)
        if tool_return is not None:
            return self._return(tool_return, run_manager=run_manager)

        return AddableDict(intermediate_step=next_step_output)

    async def _aprocess_next_step_output(
        self,
        next_step_output: Union[AgentFinish, List[Tuple[AgentAction, str]]],
        run_manager: AsyncCallbackManagerForChainRun,
    ) -> AddableDict:
        """
        Process the output of the next async step, handling AgentFinish and
        tool return cases.

        Async counterpart of ``_process_next_step_output``; finalizes through
        ``_areturn``.
        """
        logger.debug("Processing output of async Agent loop step")
        if isinstance(next_step_output, AgentFinish):
            logger.debug(
                "Hit AgentFinish: _areturn -> on_chain_end -> run final output logic"
            )
            return await self._areturn(next_step_output, run_manager=run_manager)

        tool_return = self._record_step_output(next_step_output)
        if tool_return is not None:
            return await self._areturn(tool_return, run_manager=run_manager)

        return AddableDict(intermediate_step=next_step_output)

    def _stopped_response(self) -> AgentFinish:
        """Ask the action agent for its early-stopping response.

        Uses ``early_stopping_method`` and the steps recorded so far.
        """
        logger.warning("Stopping agent prematurely due to triggering stop condition")
        return self.agent_executor._action_agent.return_stopped_response(
            self.agent_executor.early_stopping_method,
            self.intermediate_steps,
            **self.inputs,
        )

    def _stop(self, run_manager: CallbackManagerForChainRun) -> AddableDict:
        """
        Finish the run early and return the stopped response as final outputs.

        Nothing is raised here. The response is finalized through ``_return``,
        which fires ``on_chain_end``.
        """
        return self._return(self._stopped_response(), run_manager=run_manager)

    async def _astop(self, run_manager: AsyncCallbackManagerForChainRun) -> AddableDict:
        """
        Async counterpart of ``_stop``: finish the run early and return the
        stopped response as final outputs (finalized through ``_areturn``).
        """
        return await self._areturn(self._stopped_response(), run_manager=run_manager)

    def _return(
        self, output: AgentFinish, run_manager: CallbackManagerForChainRun
    ) -> AddableDict:
        """
        Return the final output of the iterator.

        Delegates to ``agent_executor._return``, attaches ``output.messages``,
        fires ``on_chain_end`` and then prepares the final outputs.
        """
        returned_output = self.agent_executor._return(
            output, self.intermediate_steps, run_manager=run_manager
        )
        returned_output["messages"] = output.messages
        run_manager.on_chain_end(returned_output)
        return self.make_final_outputs(returned_output, run_manager)

    async def _areturn(
        self, output: AgentFinish, run_manager: AsyncCallbackManagerForChainRun
    ) -> AddableDict:
        """
        Return the final output of the async iterator (async counterpart of
        ``_return``).
        """
        returned_output = await self.agent_executor._areturn(
            output, self.intermediate_steps, run_manager=run_manager
        )
        returned_output["messages"] = output.messages
        await run_manager.on_chain_end(returned_output)
        return self.make_final_outputs(returned_output, run_manager)
|
|