"""Hypothetical Document Embeddings. |
|
|
|
|
|
https://arxiv.org/abs/2212.10496 |
|
|
""" |
|
|
|
|
|
from __future__ import annotations |
|
|
|
|
|
from typing import Any, Dict, List, Optional |
|
|
|
|
|
import numpy as np |
|
|
from langchain_core.callbacks import CallbackManagerForChainRun |
|
|
from langchain_core.embeddings import Embeddings |
|
|
from langchain_core.language_models import BaseLanguageModel |
|
|
from langchain_core.output_parsers import StrOutputParser |
|
|
from langchain_core.prompts import BasePromptTemplate |
|
|
from langchain_core.runnables import Runnable |
|
|
from pydantic import ConfigDict |
|
|
|
|
|
from langchain.chains.base import Chain |
|
|
from langchain.chains.hyde.prompts import PROMPT_MAP |
|
|
from langchain.chains.llm import LLMChain |
|
|
|
|
|
|
|
|
class HypotheticalDocumentEmbedder(Chain, Embeddings):
    """Generate a hypothetical document for a query, then embed that document.

    Based on https://arxiv.org/abs/2212.10496
    """

    base_embeddings: Embeddings
    llm_chain: Runnable

    model_config = ConfigDict(
        arbitrary_types_allowed=True,
        extra="forbid",
    )

    @property
    def input_keys(self) -> List[str]:
        """Input keys for HyDE's LLM chain."""
        return self.llm_chain.input_schema.model_json_schema()["required"]

    @property
    def output_keys(self) -> List[str]:
        """Output keys for HyDE's LLM chain."""
        if isinstance(self.llm_chain, LLMChain):
            return self.llm_chain.output_keys
        return ["text"]

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Call the base embeddings."""
        return self.base_embeddings.embed_documents(texts)

    def combine_embeddings(self, embeddings: List[List[float]]) -> List[float]:
        """Combine document embeddings into one by taking their element-wise mean."""
        return list(np.array(embeddings).mean(axis=0))

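    # Note: the HyDE paper samples multiple hypothetical documents per query;
    # `combine_embeddings` reduces their vectors with an element-wise mean,
    # e.g. [[1.0, 2.0], [3.0, 4.0]] -> [2.0, 3.0].
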
    def embed_query(self, text: str) -> List[float]:
        """Generate a hypothetical document and embed it."""
        var_name = self.input_keys[0]
        result = self.llm_chain.invoke({var_name: text})
        if isinstance(self.llm_chain, LLMChain):
            documents = [result[self.output_keys[0]]]
        else:
            documents = [result]
        embeddings = self.embed_documents(documents)
        return self.combine_embeddings(embeddings)

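    # A minimal sketch of the intended flow (hedged: assumes "web_search" is a
    # valid PROMPT_MAP key and that `llm` and `base_embeddings` already exist):
    #
    #     hyde = HypotheticalDocumentEmbedder.from_llm(
    #         llm, base_embeddings, prompt_key="web_search"
    #     )
    #     vector = hyde.embed_query("What can I bring in carry-on luggage?")
    #
    # The returned vector embeds the LLM's hypothetical answer rather than the
    # raw question, which tends to land closer to relevant documents.
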
    def _call(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, str]:
        """Call the internal LLM chain."""
        _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
        return self.llm_chain.invoke(
            inputs, config={"callbacks": _run_manager.get_child()}
        )

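    # `_call` simply delegates to the LLM chain, so invoking this chain directly
    # returns the generated hypothetical document itself, which is useful when
    # inspecting what a prompt produces before it is embedded.
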
    @classmethod
    def from_llm(
        cls,
        llm: BaseLanguageModel,
        base_embeddings: Embeddings,
        prompt_key: Optional[str] = None,
        custom_prompt: Optional[BasePromptTemplate] = None,
        **kwargs: Any,
    ) -> HypotheticalDocumentEmbedder:
        """Construct the chain from an LLM, using either a known prompt key or a
        custom prompt."""
        if custom_prompt is not None:
            prompt = custom_prompt
        elif prompt_key is not None and prompt_key in PROMPT_MAP:
            prompt = PROMPT_MAP[prompt_key]
        else:
            raise ValueError(
                "Must specify prompt_key if custom_prompt is not provided. "
                f"Should be one of {list(PROMPT_MAP.keys())}."
            )

        llm_chain = prompt | llm | StrOutputParser()
        return cls(base_embeddings=base_embeddings, llm_chain=llm_chain, **kwargs)

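    # `custom_prompt` takes precedence over `prompt_key`. A minimal sketch of
    # the custom-prompt path (the template text and `{question}` variable are
    # illustrative, not defined by this module):
    #
    #     from langchain_core.prompts import PromptTemplate
    #
    #     prompt = PromptTemplate.from_template(
    #         "Please write a short passage to answer the question: {question}"
    #     )
    #     embedder = HypotheticalDocumentEmbedder.from_llm(
    #         llm, base_embeddings, custom_prompt=prompt
    #     )
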
    @property
    def _chain_type(self) -> str:
        return "hyde_chain"
