|
|
"""Load summarizing chains.""" |
|
|
|
|
|
from typing import Any, Mapping, Optional, Protocol |
|
|
|
|
|
from langchain_core.callbacks import Callbacks |
|
|
from langchain_core.language_models import BaseLanguageModel |
|
|
from langchain_core.prompts import BasePromptTemplate |
|
|
|
|
|
from langchain.chains.combine_documents.base import BaseCombineDocumentsChain |
|
|
from langchain.chains.combine_documents.map_reduce import MapReduceDocumentsChain |
|
|
from langchain.chains.combine_documents.reduce import ReduceDocumentsChain |
|
|
from langchain.chains.combine_documents.refine import RefineDocumentsChain |
|
|
from langchain.chains.combine_documents.stuff import StuffDocumentsChain |
|
|
from langchain.chains.llm import LLMChain |
|
|
from langchain.chains.summarize import map_reduce_prompt, refine_prompts, stuff_prompt |
|
|
|
|
|
|
|
|
class LoadingCallable(Protocol):
    """Interface for loading the combine documents chain."""

    # Structural (duck-typed) interface: the private ``_load_*_chain``
    # factories below conform to this Protocol without inheriting from it.
    # ``load_summarize_chain`` uses it to type its loader dispatch table.
    def __call__(
        self, llm: BaseLanguageModel, **kwargs: Any
    ) -> BaseCombineDocumentsChain:
        """Callable to load the combine documents chain."""
|
|
|
|
|
|
|
|
def _load_stuff_chain(
    llm: BaseLanguageModel,
    prompt: BasePromptTemplate = stuff_prompt.PROMPT,
    document_variable_name: str = "text",
    verbose: Optional[bool] = None,
    **kwargs: Any,
) -> StuffDocumentsChain:
    """Build a "stuff" summarization chain.

    All documents are inserted into a single prompt (under
    ``document_variable_name``) and sent to ``llm`` in one call.

    Args:
        llm: Language model used to produce the summary.
        prompt: Prompt template applied to the stuffed documents.
        document_variable_name: Prompt variable that receives the documents.
        verbose: Verbosity flag propagated to the inner and outer chains.
        **kwargs: Extra fields forwarded to ``StuffDocumentsChain``.

    Returns:
        A configured ``StuffDocumentsChain``.
    """
    # Single LLM call wrapped by the document-combining chain.
    return StuffDocumentsChain(
        llm_chain=LLMChain(llm=llm, prompt=prompt, verbose=verbose),
        document_variable_name=document_variable_name,
        verbose=verbose,
        **kwargs,
    )
|
|
|
|
|
|
|
|
def _load_map_reduce_chain(
    llm: BaseLanguageModel,
    map_prompt: BasePromptTemplate = map_reduce_prompt.PROMPT,
    combine_prompt: BasePromptTemplate = map_reduce_prompt.PROMPT,
    combine_document_variable_name: str = "text",
    map_reduce_document_variable_name: str = "text",
    collapse_prompt: Optional[BasePromptTemplate] = None,
    reduce_llm: Optional[BaseLanguageModel] = None,
    collapse_llm: Optional[BaseLanguageModel] = None,
    verbose: Optional[bool] = None,
    token_max: int = 3000,
    callbacks: Callbacks = None,
    *,
    collapse_max_retries: Optional[int] = None,
    **kwargs: Any,
) -> MapReduceDocumentsChain:
    """Build a map-reduce summarization chain.

    Each document is summarized independently (map step), then the partial
    summaries are combined into a final one (reduce step), optionally
    collapsing intermediate results that exceed ``token_max``.

    Args:
        llm: Model for the map step (and, by default, reduce/collapse).
        map_prompt: Prompt applied to each document individually.
        combine_prompt: Prompt used to merge the per-document summaries.
        combine_document_variable_name: Prompt variable for the reduce step.
        map_reduce_document_variable_name: Prompt variable for the map step.
        collapse_prompt: Optional prompt for collapsing oversized
            intermediate summaries; required if ``collapse_llm`` is given.
        reduce_llm: Optional model override for the reduce step.
        collapse_llm: Optional model override for the collapse step.
        verbose: Verbosity flag propagated to every sub-chain.
        token_max: Token budget before intermediate results are collapsed.
        callbacks: Callbacks propagated to every sub-chain.
        collapse_max_retries: Max collapse attempts before giving up.
        **kwargs: Extra fields forwarded to ``MapReduceDocumentsChain``.

    Returns:
        A configured ``MapReduceDocumentsChain``.

    Raises:
        ValueError: If ``collapse_llm`` is given without ``collapse_prompt``.
    """
    # Map step: summarize each document on its own.
    mapper = LLMChain(llm=llm, prompt=map_prompt, verbose=verbose, callbacks=callbacks)

    # Reduce step: merge the partial summaries (falls back to the map LLM).
    reducer = LLMChain(
        llm=reduce_llm or llm,
        prompt=combine_prompt,
        verbose=verbose,
        callbacks=callbacks,
    )
    final_combiner = StuffDocumentsChain(
        llm_chain=reducer,
        document_variable_name=combine_document_variable_name,
        verbose=verbose,
        callbacks=callbacks,
    )

    # Optional collapse step for intermediate results over the token budget.
    if collapse_prompt is not None:
        collapser: Optional[StuffDocumentsChain] = StuffDocumentsChain(
            llm_chain=LLMChain(
                llm=collapse_llm or llm,
                prompt=collapse_prompt,
                verbose=verbose,
                callbacks=callbacks,
            ),
            document_variable_name=combine_document_variable_name,
        )
    else:
        # A collapse LLM without a prompt is a configuration error.
        if collapse_llm is not None:
            raise ValueError(
                "collapse_llm provided, but collapse_prompt was not: please "
                "provide one or stop providing collapse_llm."
            )
        collapser = None

    reducer_chain = ReduceDocumentsChain(
        combine_documents_chain=final_combiner,
        collapse_documents_chain=collapser,
        token_max=token_max,
        verbose=verbose,
        callbacks=callbacks,
        collapse_max_retries=collapse_max_retries,
    )
    return MapReduceDocumentsChain(
        llm_chain=mapper,
        reduce_documents_chain=reducer_chain,
        document_variable_name=map_reduce_document_variable_name,
        verbose=verbose,
        callbacks=callbacks,
        **kwargs,
    )
|
|
|
|
|
|
|
|
def _load_refine_chain(
    llm: BaseLanguageModel,
    question_prompt: BasePromptTemplate = refine_prompts.PROMPT,
    refine_prompt: BasePromptTemplate = refine_prompts.REFINE_PROMPT,
    document_variable_name: str = "text",
    initial_response_name: str = "existing_answer",
    refine_llm: Optional[BaseLanguageModel] = None,
    verbose: Optional[bool] = None,
    **kwargs: Any,
) -> RefineDocumentsChain:
    """Build a "refine" summarization chain.

    An initial summary is produced from the first document, then iteratively
    refined with each subsequent document.

    Args:
        llm: Model for the initial summary (and, by default, refinement).
        question_prompt: Prompt for the initial summary.
        refine_prompt: Prompt used to refine the running summary.
        document_variable_name: Prompt variable that receives each document.
        initial_response_name: Prompt variable holding the running summary.
        refine_llm: Optional model override for the refinement step.
        verbose: Verbosity flag propagated to the sub-chains.
        **kwargs: Extra fields forwarded to ``RefineDocumentsChain``.

    Returns:
        A configured ``RefineDocumentsChain``.
    """
    # Refinement reuses the initial LLM unless a dedicated one is supplied.
    first_pass = LLMChain(llm=llm, prompt=question_prompt, verbose=verbose)
    refinement = LLMChain(
        llm=refine_llm or llm,
        prompt=refine_prompt,
        verbose=verbose,
    )
    return RefineDocumentsChain(
        initial_llm_chain=first_pass,
        refine_llm_chain=refinement,
        document_variable_name=document_variable_name,
        initial_response_name=initial_response_name,
        verbose=verbose,
        **kwargs,
    )
|
|
|
|
|
|
|
|
def load_summarize_chain(
    llm: BaseLanguageModel,
    chain_type: str = "stuff",
    verbose: Optional[bool] = None,
    **kwargs: Any,
) -> BaseCombineDocumentsChain:
    """Load summarizing chain.

    Args:
        llm: Language Model to use in the chain.
        chain_type: Type of document combining chain to use. Should be one of "stuff",
            "map_reduce", and "refine".
        verbose: Whether chains should be run in verbose mode or not. Note that this
            applies to all chains that make up the final chain.
        **kwargs: Additional keyword arguments forwarded to the selected
            chain's loader (e.g. ``prompt`` for "stuff", ``map_prompt`` and
            ``combine_prompt`` for "map_reduce").

    Returns:
        A chain to use for summarizing.

    Raises:
        ValueError: If ``chain_type`` is not a supported chain type.
    """
    loader_mapping: Mapping[str, LoadingCallable] = {
        "stuff": _load_stuff_chain,
        "map_reduce": _load_map_reduce_chain,
        "refine": _load_refine_chain,
    }
    if chain_type not in loader_mapping:
        # Render the valid options as a plain list; interpolating
        # ``loader_mapping.keys()`` directly would print the unhelpful
        # ``dict_keys([...])`` repr in the error message.
        raise ValueError(
            f"Got unsupported chain type: {chain_type}. "
            f"Should be one of {list(loader_mapping)}"
        )
    return loader_mapping[chain_type](llm, verbose=verbose, **kwargs)
|
|
|