|
|
"""Chain that implements the ReAct paper from https://arxiv.org/pdf/2210.03629.pdf.""" |
|
|
|
|
|
from __future__ import annotations |
|
|
|
|
|
from typing import TYPE_CHECKING, Any, List, Optional, Sequence |
|
|
|
|
|
from langchain_core._api import deprecated |
|
|
from langchain_core.documents import Document |
|
|
from langchain_core.language_models import BaseLanguageModel |
|
|
from langchain_core.prompts import BasePromptTemplate |
|
|
from langchain_core.tools import BaseTool, Tool |
|
|
from pydantic import Field |
|
|
|
|
|
from langchain._api.deprecation import AGENT_DEPRECATION_WARNING |
|
|
from langchain.agents.agent import Agent, AgentExecutor, AgentOutputParser |
|
|
from langchain.agents.agent_types import AgentType |
|
|
from langchain.agents.react.output_parser import ReActOutputParser |
|
|
from langchain.agents.react.textworld_prompt import TEXTWORLD_PROMPT |
|
|
from langchain.agents.react.wiki_prompt import WIKI_PROMPT |
|
|
from langchain.agents.utils import validate_tools_single_input |
|
|
|
|
|
if TYPE_CHECKING: |
|
|
from langchain_community.docstore.base import Docstore |
|
|
|
|
|
|
|
|
@deprecated("0.1.0", message=AGENT_DEPRECATION_WARNING, removal="1.0")
class ReActDocstoreAgent(Agent):
    """ReAct agent that works against a docstore through two tools.

    The agent only accepts a tool list made of exactly one tool named
    ``Search`` and one named ``Lookup`` (see ``_validate_tools``).
    """

    # Parser applied to the LLM output; a new ReActOutputParser per instance.
    output_parser: AgentOutputParser = Field(default_factory=ReActOutputParser)

    @classmethod
    def _get_default_output_parser(cls, **kwargs: Any) -> AgentOutputParser:
        """Build the default output parser.

        Args:
            **kwargs: Accepted for signature compatibility; not used here.

        Returns:
            A new ``ReActOutputParser`` instance.
        """
        return ReActOutputParser()

    @property
    def _agent_type(self) -> str:
        """Identifier of this agent type (``AgentType.REACT_DOCSTORE``)."""
        return AgentType.REACT_DOCSTORE

    @classmethod
    def create_prompt(cls, tools: Sequence[BaseTool]) -> BasePromptTemplate:
        """Return the default prompt for this agent.

        Args:
            tools: Tools available to the agent; not consulted here.

        Returns:
            The module-level ``WIKI_PROMPT``.
        """
        return WIKI_PROMPT

    @classmethod
    def _validate_tools(cls, tools: Sequence[BaseTool]) -> None:
        """Check that ``tools`` is exactly the Lookup/Search pair.

        Args:
            tools: Candidate tools for the agent.

        Raises:
            ValueError: If the number of tools is not two, or if their names
                are not exactly ``Lookup`` and ``Search``. The single-input
                check and the parent validation run first and may raise
                their own errors.
        """
        validate_tools_single_input(cls.__name__, tools)
        super()._validate_tools(tools)
        if len(tools) != 2:
            raise ValueError(f"Exactly two tools must be specified, but got {tools}")
        names = {t.name for t in tools}
        if names == {"Lookup", "Search"}:
            return
        raise ValueError(f"Tool names should be Lookup and Search, got {names}")

    @property
    def observation_prefix(self) -> str:
        """Text placed in front of each observation."""
        return "Observation: "

    @property
    def _stop(self) -> List[str]:
        """Marker strings for the observation boundary.

        NOTE(review): presumably handed to the LLM as stop sequences by the
        base ``Agent`` -- confirm against the parent class.
        """
        return ["\nObservation:"]

    @property
    def llm_prefix(self) -> str:
        """Text placed in front of each LLM call."""
        return "Thought:"
|
|
|
|
|
|
|
|
@deprecated(
    "0.1.0",
    message=AGENT_DEPRECATION_WARNING,
    removal="1.0",
)
class DocstoreExplorer:
    """Class to assist with exploration of a document store.

    Remembers the document returned by the most recent successful search and
    the paging position reached while stepping through lookup matches inside
    that document.
    """

    def __init__(self, docstore: Docstore):
        """Initialize with a docstore, and set initial document to None.

        Args:
            docstore: Store that ``search`` queries.
        """
        self.docstore = docstore
        self.document: Optional[Document] = None
        # Lower-cased term of the lookup currently being paged through.
        self.lookup_str = ""
        # Zero-based index of the match returned by the latest lookup.
        self.lookup_index = 0

    def search(self, term: str) -> str:
        """Search for a term in the docstore, and if found save.

        Every completed search also resets the lookup paging state, so the
        next ``lookup`` starts at the first match of the new document even
        when it repeats the term used on the previous document.

        Args:
            term: Term passed to the docstore's ``search``.

        Returns:
            The first paragraph of the document when the docstore returns a
            ``Document``; otherwise the docstore's return value unchanged
            (presumably a "not found" message -- depends on the docstore).
        """
        result = self.docstore.search(term)
        # The paging state describes the previously saved document; keeping
        # it would make a repeated lookup term skip the new document's first
        # match. Reset only after the docstore call so a raising search
        # leaves the explorer untouched.
        self.lookup_str = ""
        self.lookup_index = 0
        if isinstance(result, Document):
            self.document = result
            return self._summary
        self.document = None
        return result

    def lookup(self, term: str) -> str:
        """Lookup a term in document (if saved).

        Matching is a case-insensitive substring test against each paragraph
        of the saved document. Repeating the same term on consecutive calls
        advances to the next matching paragraph; a different term restarts
        at the first match.

        Args:
            term: Text to look for.

        Returns:
            ``"(Result i/n) <paragraph>"`` for the current match,
            ``"No Results"`` when no paragraph matches, or
            ``"No More Results"`` once every match has been returned.

        Raises:
            ValueError: If no document is saved from a successful search.
        """
        if self.document is None:
            raise ValueError("Cannot lookup without a successful search first")
        needle = term.lower()
        if needle != self.lookup_str:
            # New term: restart paging from the first match.
            self.lookup_str = needle
            self.lookup_index = 0
        else:
            # Same term again: advance to the next match.
            self.lookup_index += 1
        lookups = [p for p in self._paragraphs if needle in p.lower()]
        if not lookups:
            return "No Results"
        if self.lookup_index >= len(lookups):
            return "No More Results"
        result_prefix = f"(Result {self.lookup_index + 1}/{len(lookups)})"
        return f"{result_prefix} {lookups[self.lookup_index]}"

    @property
    def _summary(self) -> str:
        """First paragraph of the saved document."""
        return self._paragraphs[0]

    @property
    def _paragraphs(self) -> List[str]:
        """Saved document's page content split on blank lines.

        Raises:
            ValueError: If no document is saved.
        """
        if self.document is None:
            raise ValueError("Cannot get paragraphs without a document")
        return self.document.page_content.split("\n\n")
|
|
|
|
|
|
|
|
@deprecated(
    "0.1.0",
    message=AGENT_DEPRECATION_WARNING,
    removal="1.0",
)
class ReActTextWorldAgent(ReActDocstoreAgent):
    """Agent for the ReAct TextWorld chain.

    Reuses the docstore agent's parser and prefixes, but uses the TextWorld
    prompt and accepts exactly one tool, named ``Play``.
    """

    @classmethod
    def create_prompt(cls, tools: Sequence[BaseTool]) -> BasePromptTemplate:
        """Return default prompt.

        Args:
            tools: Tools available to the agent; not consulted here.

        Returns:
            The module-level ``TEXTWORLD_PROMPT``.
        """
        return TEXTWORLD_PROMPT

    @classmethod
    def _validate_tools(cls, tools: Sequence[BaseTool]) -> None:
        """Check that ``tools`` is exactly one tool named ``Play``.

        Args:
            tools: Candidate tools for the agent.

        Raises:
            ValueError: If the number of tools is not one, or if the tool is
                not named ``Play``. The single-input check and the base
                validation run first and may raise their own errors.
        """
        validate_tools_single_input(cls.__name__, tools)
        # Deliberately skip ReActDocstoreAgent._validate_tools: it requires
        # exactly two tools named Lookup and Search, which contradicts the
        # single-Play-tool rule below and made every tool list fail. Start
        # the MRO lookup after ReActDocstoreAgent so only the validation of
        # the classes above it still runs.
        super(ReActDocstoreAgent, cls)._validate_tools(tools)
        if len(tools) != 1:
            raise ValueError(f"Exactly one tool must be specified, but got {tools}")
        tool_names = {tool.name for tool in tools}
        if tool_names != {"Play"}:
            raise ValueError(f"Tool name should be Play, got {tool_names}")
|
|
|
|
|
|
|
|
@deprecated("0.1.0", message=AGENT_DEPRECATION_WARNING, removal="1.0")
class ReActChain(AgentExecutor):
    """[Deprecated] Agent executor implementing the ReAct paper on a docstore.

    Wires a ``ReActDocstoreAgent`` to ``Search`` and ``Lookup`` tools that
    share a single ``DocstoreExplorer``.
    """

    def __init__(self, llm: BaseLanguageModel, docstore: Docstore, **kwargs: Any):
        """Build the executor from an LLM and a docstore.

        Args:
            llm: Language model driving the agent.
            docstore: Store explored through the Search and Lookup tools.
            **kwargs: Forwarded unchanged to the ``AgentExecutor``
                initializer.
        """
        # One explorer backs both tools, so Lookup operates on the document
        # saved by the latest Search.
        explorer = DocstoreExplorer(docstore)
        search_tool = Tool(
            name="Search",
            func=explorer.search,
            description="Search for a term in the docstore.",
        )
        lookup_tool = Tool(
            name="Lookup",
            func=explorer.lookup,
            description="Lookup a term in the docstore.",
        )
        tools = [search_tool, lookup_tool]
        super().__init__(
            agent=ReActDocstoreAgent.from_llm_and_tools(llm, tools),
            tools=tools,
            **kwargs,
        )
|
|
|