Spaces:
Sleeping
Sleeping
"""
This file is a modified version, adapted for ChatGLM3-6B, of the original
glm3_agent.py file from the LangChain repository.
"""
| from __future__ import annotations | |
| import json | |
| import logging | |
| from typing import Any, List, Sequence, Tuple, Optional, Union | |
| from pydantic.schema import model_schema | |
| from langchain.agents.structured_chat.output_parser import StructuredChatOutputParser | |
| from langchain.memory import ConversationBufferWindowMemory | |
| from langchain.agents.agent import Agent | |
| from langchain.chains.llm import LLMChain | |
| from langchain.prompts.chat import ChatPromptTemplate, SystemMessagePromptTemplate | |
| from langchain.agents.agent import AgentOutputParser | |
| from langchain.output_parsers import OutputFixingParser | |
| from langchain.pydantic_v1 import Field | |
| from langchain.schema import AgentAction, AgentFinish, OutputParserException, BasePromptTemplate | |
| from langchain.agents.agent import AgentExecutor | |
| from langchain.callbacks.base import BaseCallbackManager | |
| from langchain.schema.language_model import BaseLanguageModel | |
| from langchain.tools.base import BaseTool | |
# Template for the human turn: the user's input followed by the agent's
# running scratchpad (previous thoughts/observations).
HUMAN_MESSAGE_TEMPLATE = "{input}\n\n{agent_scratchpad}"

# Module-level logger, per the standard `logging` convention.
logger = logging.getLogger(__name__)
class StructuredChatOutputParserWithRetries(AgentOutputParser):
    """Output parser with retries for the structured chat agent.

    ChatGLM3-6B emits tool invocations as ``tool_name ... tool_call(k=v, ...)``
    rather than the JSON blob the structured-chat agent expects, so this parser
    first normalizes the model output into the canonical
    ``Action:\\n```\\n{json}\\n``` `` form and then delegates to the base parser
    (or to the output-fixing parser when one is configured).
    """

    base_parser: AgentOutputParser = Field(default_factory=StructuredChatOutputParser)
    """The base parser to use."""
    output_fixing_parser: Optional[OutputFixingParser] = None
    """The output fixing parser to use."""

    def parse(self, text: str) -> Union[AgentAction, AgentFinish]:
        """Parse raw ChatGLM3-6B output into an AgentAction or AgentFinish.

        Args:
            text: The raw LLM completion.

        Returns:
            An ``AgentAction`` when a tool call is detected in the text,
            otherwise an ``AgentFinish`` carrying the text as the final answer.

        Raises:
            OutputParserException: If the normalized action string cannot be
                parsed by the underlying parser.
        """
        # Truncate at the earliest special token so any trailing chatter
        # after the action (or the observation marker) is ignored.
        special_tokens = ["Action:", "<|observation|>"]
        first_index = min(
            text.find(token) if token in text else len(text)
            for token in special_tokens
        )
        text = text[:first_index]

        if "tool_call" in text:
            # Everything before the opening code fence is the tool name.
            action_end = text.find("```")
            action = text[:action_end].strip()
            # Extract "k=v" pairs from inside tool_call(...); values may be
            # quoted with either single or double quotes, which are stripped.
            params_str_start = text.find("(") + 1
            params_str_end = text.rfind(")")
            params_str = text[params_str_start:params_str_end]
            # NOTE(review): splitting on "," breaks for argument values that
            # themselves contain commas or "="; appears acceptable for
            # ChatGLM3's simple flat tool calls — confirm against model output.
            params_pairs = [
                param.split("=") for param in params_str.split(",") if "=" in param
            ]
            params = {
                pair[0].strip(): pair[1].strip().strip("'\"") for pair in params_pairs
            }
            action_json = {
                "action": action,
                "action_input": params,
            }
        else:
            # No tool call: treat the whole (truncated) text as the answer.
            action_json = {
                "action": "Final Answer",
                "action_input": text,
            }

        action_str = f"""
Action:
```
{json.dumps(action_json, ensure_ascii=False)}
```"""
        try:
            if self.output_fixing_parser is not None:
                parsed_obj: Union[
                    AgentAction, AgentFinish
                ] = self.output_fixing_parser.parse(action_str)
            else:
                parsed_obj = self.base_parser.parse(action_str)
            return parsed_obj
        except Exception as e:
            raise OutputParserException(f"Could not parse LLM output: {text}") from e

    @property
    def _type(self) -> str:
        # The base class declares ``_type`` as a property used by langchain's
        # serialization machinery; without @property callers would receive a
        # bound method instead of the identifier string.
        return "structured_chat_ChatGLM3_6b_with_retries"
class StructuredGLM3ChatAgent(Agent):
    """Structured chat agent adapted to ChatGLM3-6B's prompt and output format."""

    output_parser: AgentOutputParser = Field(
        default_factory=StructuredChatOutputParserWithRetries
    )
    """Output parser for the agent."""

    @property
    def observation_prefix(self) -> str:
        """Prefix to append the ChatGLM3-6B observation with."""
        # The Agent base class consumes this as a property when assembling
        # the scratchpad, so the @property decorator is required.
        return "Observation:"

    @property
    def llm_prefix(self) -> str:
        """Prefix to append the llm call with."""
        return "Thought:"

    def _construct_scratchpad(
        self, intermediate_steps: List[Tuple[AgentAction, str]]
    ) -> str:
        """Build the scratchpad text from prior (action, observation) steps.

        Raises:
            ValueError: If the base implementation returns a non-string
                scratchpad (this agent only supports string scratchpads).
        """
        agent_scratchpad = super()._construct_scratchpad(intermediate_steps)
        if not isinstance(agent_scratchpad, str):
            raise ValueError("agent_scratchpad should be of type string.")
        if agent_scratchpad:
            return (
                f"This was your previous work "
                f"(but I haven't seen any of it! I only see what "
                f"you return as final answer):\n{agent_scratchpad}"
            )
        else:
            return agent_scratchpad

    @classmethod
    def _get_default_output_parser(
        cls, llm: Optional[BaseLanguageModel] = None, **kwargs: Any
    ) -> AgentOutputParser:
        """Return the default output parser for this agent."""
        return StructuredChatOutputParserWithRetries(llm=llm)

    @property
    def _stop(self) -> List[str]:
        # ChatGLM3 marks the end of an action with this special token; it is
        # used as the generation stop sequence.
        return ["<|observation|>"]

    @classmethod
    def create_prompt(
        cls,
        tools: Sequence[BaseTool],
        prompt: Optional[str] = None,
        input_variables: Optional[List[str]] = None,
        memory_prompts: Optional[List[BasePromptTemplate]] = None,
    ) -> BasePromptTemplate:
        """Render the system prompt, filling in tool names and descriptions.

        Args:
            tools: Tools whose names/descriptions/schemas are embedded in the
                prompt.
            prompt: Template string with ``{tool_names}``, ``{tools}``,
                ``{history}``, ``{input}`` and ``{agent_scratchpad}``
                placeholders.
            input_variables: Variables the resulting template expects;
                defaults to ``["input", "agent_scratchpad"]``.
            memory_prompts: Extra prompt templates (e.g. memory) appended
                after the system message.

        Returns:
            A ``ChatPromptTemplate`` ready for the LLM chain.
        """
        tools_json = []
        tool_names = []
        for tool in tools:
            # Derive a simplified JSON schema for each tool's arguments.
            tool_schema = model_schema(tool.args_schema) if tool.args_schema else {}
            simplified_config_langchain = {
                "name": tool.name,
                "description": tool.description,
                "parameters": tool_schema.get("properties", {}),
            }
            tools_json.append(simplified_config_langchain)
            tool_names.append(tool.name)
        formatted_tools = "\n".join([
            f"{tool['name']}: {tool['description']}, args: {tool['parameters']}"
            for tool in tools_json
        ])
        # Escape braces/quotes so the tool listing is not mistaken for
        # template placeholders when the prompt is rendered again later.
        formatted_tools = formatted_tools.replace("'", "\\'").replace("{", "{{").replace("}", "}}")
        template = prompt.format(tool_names=tool_names,
                                 tools=formatted_tools,
                                 history="None",
                                 input="{input}",
                                 agent_scratchpad="{agent_scratchpad}")
        if input_variables is None:
            input_variables = ["input", "agent_scratchpad"]
        _memory_prompts = memory_prompts or []
        messages = [
            SystemMessagePromptTemplate.from_template(template),
            *_memory_prompts,
        ]
        return ChatPromptTemplate(input_variables=input_variables, messages=messages)

    @classmethod
    def from_llm_and_tools(
        cls,
        llm: BaseLanguageModel,
        tools: Sequence[BaseTool],
        prompt: Optional[str] = None,
        callback_manager: Optional[BaseCallbackManager] = None,
        output_parser: Optional[AgentOutputParser] = None,
        human_message_template: str = HUMAN_MESSAGE_TEMPLATE,
        input_variables: Optional[List[str]] = None,
        memory_prompts: Optional[List[BasePromptTemplate]] = None,
        **kwargs: Any,
    ) -> Agent:
        """Construct an agent from an LLM and tools."""
        cls._validate_tools(tools)
        prompt = cls.create_prompt(
            tools,
            prompt=prompt,
            input_variables=input_variables,
            memory_prompts=memory_prompts,
        )
        llm_chain = LLMChain(
            llm=llm,
            prompt=prompt,
            callback_manager=callback_manager,
        )
        tool_names = [tool.name for tool in tools]
        _output_parser = output_parser or cls._get_default_output_parser(llm=llm)
        return cls(
            llm_chain=llm_chain,
            allowed_tools=tool_names,
            output_parser=_output_parser,
            **kwargs,
        )

    @property
    def _agent_type(self) -> str:
        # Serialization is intentionally unsupported for this agent type.
        raise ValueError
def initialize_glm3_agent(
    tools: Sequence[BaseTool],
    llm: BaseLanguageModel,
    prompt: Optional[str] = None,
    memory: Optional[ConversationBufferWindowMemory] = None,
    agent_kwargs: Optional[dict] = None,
    *,
    tags: Optional[Sequence[str]] = None,
    **kwargs: Any,
) -> AgentExecutor:
    """Create an ``AgentExecutor`` wired for ChatGLM3-6B.

    Args:
        tools: Tools the agent may call.
        llm: Language model driving the agent.
        prompt: Prompt template string with ``{tool_names}``, ``{tools}``,
            ``{history}``, ``{input}`` and ``{agent_scratchpad}``
            placeholders, forwarded to the agent's prompt builder.
        memory: Optional conversation memory attached to the executor.
        agent_kwargs: Extra keyword arguments forwarded to the agent
            constructor.
        tags: Tags attached to the executor for tracing/filtering.
        **kwargs: Forwarded to ``AgentExecutor.from_agent_and_tools``.

    Returns:
        A ready-to-run ``AgentExecutor``.
    """
    tags_ = list(tags) if tags else []
    agent_kwargs = agent_kwargs or {}
    agent_obj = StructuredGLM3ChatAgent.from_llm_and_tools(
        llm=llm,
        tools=tools,
        prompt=prompt,
        **agent_kwargs,
    )
    return AgentExecutor.from_agent_and_tools(
        agent=agent_obj,
        tools=tools,
        memory=memory,
        tags=tags_,
        **kwargs,
    )