File size: 5,586 Bytes
f1e6b80 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 |
"""Chain that carries on a conversation and calls an LLM."""
from typing import List
from langchain_core._api import deprecated
from langchain_core.memory import BaseMemory
from langchain_core.prompts import BasePromptTemplate
from pydantic import ConfigDict, Field, model_validator
from typing_extensions import Self
from langchain.chains.conversation.prompt import PROMPT
from langchain.chains.llm import LLMChain
from langchain.memory.buffer import ConversationBufferMemory
@deprecated(
    since="0.2.7",
    alternative=(
        "RunnableWithMessageHistory: "
        "https://python.langchain.com/v0.2/api_reference/core/runnables/langchain_core.runnables.history.RunnableWithMessageHistory.html"  # noqa: E501
    ),
    removal="1.0",
)
class ConversationChain(LLMChain):  # type: ignore[override]
    """Chain to have a conversation and load context from memory.

    This class is deprecated in favor of ``RunnableWithMessageHistory``. Please refer
    to this tutorial for more detail: https://python.langchain.com/docs/tutorials/chatbot/

    ``RunnableWithMessageHistory`` offers several benefits, including:

    - Stream, batch, and async support;
    - More flexible memory handling, including the ability to manage memory
      outside the chain;
    - Support for multiple threads.

    Below is a minimal implementation, analogous to using ``ConversationChain`` with
    the default ``ConversationBufferMemory``:

    .. code-block:: python

        from langchain_core.chat_history import InMemoryChatMessageHistory
        from langchain_core.runnables.history import RunnableWithMessageHistory
        from langchain_openai import ChatOpenAI


        store = {}  # memory is maintained outside the chain

        def get_session_history(session_id: str) -> InMemoryChatMessageHistory:
            if session_id not in store:
                store[session_id] = InMemoryChatMessageHistory()
            return store[session_id]

        llm = ChatOpenAI(model="gpt-3.5-turbo-0125")

        chain = RunnableWithMessageHistory(llm, get_session_history)
        chain.invoke(
            "Hi I'm Bob.",
            config={"configurable": {"session_id": "1"}},
        )  # session_id determines thread

    Memory objects can also be incorporated into the ``get_session_history`` callable:

    .. code-block:: python

        from langchain.memory import ConversationBufferWindowMemory
        from langchain_core.chat_history import InMemoryChatMessageHistory
        from langchain_core.runnables.history import RunnableWithMessageHistory
        from langchain_openai import ChatOpenAI


        store = {}  # memory is maintained outside the chain

        def get_session_history(session_id: str) -> InMemoryChatMessageHistory:
            if session_id not in store:
                store[session_id] = InMemoryChatMessageHistory()
                return store[session_id]

            memory = ConversationBufferWindowMemory(
                chat_memory=store[session_id],
                k=3,
                return_messages=True,
            )
            assert len(memory.memory_variables) == 1
            key = memory.memory_variables[0]
            messages = memory.load_memory_variables({})[key]
            store[session_id] = InMemoryChatMessageHistory(messages=messages)
            return store[session_id]

        llm = ChatOpenAI(model="gpt-3.5-turbo-0125")

        chain = RunnableWithMessageHistory(llm, get_session_history)
        chain.invoke(
            "Hi I'm Bob.",
            config={"configurable": {"session_id": "1"}},
        )  # session_id determines thread

    Example:
        .. code-block:: python

            from langchain.chains import ConversationChain
            from langchain_community.llms import OpenAI

            conversation = ConversationChain(llm=OpenAI())
    """

    memory: BaseMemory = Field(default_factory=ConversationBufferMemory)
    """Default memory store."""
    prompt: BasePromptTemplate = PROMPT
    """Default conversation prompt to use."""

    input_key: str = "input"  #: :meta private:
    output_key: str = "response"  #: :meta private:

    model_config = ConfigDict(
        arbitrary_types_allowed=True,
        extra="forbid",
    )

    @classmethod
    def is_lc_serializable(cls) -> bool:
        """Return ``False``: this chain does not opt in to LangChain serialization."""
        return False

    @property
    def input_keys(self) -> List[str]:
        """Return only the user-input key.

        The remaining prompt variables are supplied from memory (history), so
        callers only pass ``input_key``.
        """
        return [self.input_key]

    @model_validator(mode="after")
    def validate_prompt_input_variables(self) -> Self:
        """Validate that prompt input variables are consistent.

        The prompt's input variables must be exactly the memory variables plus
        ``input_key``, and ``input_key`` must not also be a memory variable.

        Returns:
            The validated model instance.

        Raises:
            ValueError: If ``input_key`` is also one of the memory variables, or
                if the prompt's input variables differ from the memory variables
                plus ``input_key``.
        """
        memory_keys = self.memory.memory_variables
        input_key = self.input_key
        if input_key in memory_keys:
            raise ValueError(
                f"The input key {input_key} was also found in the memory keys "
                f"({memory_keys}) - please provide keys that don't overlap."
            )
        prompt_variables = self.prompt.input_variables
        # Unpack into a set instead of ``memory_keys + [input_key]`` so that any
        # sequence returned by ``memory_variables`` (not only a list) is accepted.
        expected_keys = {*memory_keys, input_key}
        if expected_keys != set(prompt_variables):
            raise ValueError(
                "Got unexpected prompt input variables. The prompt expects "
                f"{prompt_variables}, but got {memory_keys} as inputs from "
                f"memory, and {input_key} as the normal input key."
            )
        return self
|