|
|
from enum import Enum |
|
|
from typing import Any, Dict, List, Optional |
|
|
|
|
|
from langchain_core.callbacks import ( |
|
|
AsyncCallbackManagerForRetrieverRun, |
|
|
CallbackManagerForRetrieverRun, |
|
|
) |
|
|
from langchain_core.documents import Document |
|
|
from langchain_core.retrievers import BaseRetriever |
|
|
from langchain_core.stores import BaseStore, ByteStore |
|
|
from langchain_core.vectorstores import VectorStore |
|
|
from pydantic import Field, model_validator |
|
|
|
|
|
from langchain.storage._lc_store import create_kv_docstore |
|
|
|
|
|
|
|
|
class SearchType(str, Enum):
    """Search strategies the retriever can run against its vector store.

    Members are also ``str`` instances, so each one compares equal to its
    plain string value.
    """

    similarity = "similarity"
    """Plain similarity search."""

    similarity_score_threshold = "similarity_score_threshold"
    """Similarity search that applies a score threshold."""

    mmr = "mmr"
    """Maximal Marginal Relevance reranking of similarity search results."""
|
|
|
|
|
|
|
|
class MultiVectorRetriever(BaseRetriever):
    """Retrieve from a set of multiple embeddings for the same document.

    A query is run against ``vectorstore`` to find matching sub-documents.
    The value stored under ``id_key`` in each hit's metadata is then used to
    look up the corresponding parent documents in ``docstore``.
    """

    vectorstore: VectorStore
    """The underlying vectorstore to use to store small chunks
    and their embedding vectors"""
    byte_store: Optional[ByteStore] = None
    """The lower-level backing storage layer for the parent documents"""
    docstore: BaseStore[str, Document]
    """The storage interface for the parent documents"""
    id_key: str = "doc_id"
    """Metadata key under which a sub-document stores its parent document id."""
    search_kwargs: dict = Field(default_factory=dict)
    """Keyword arguments to pass to the search function."""
    search_type: SearchType = SearchType.similarity
    """Type of search to perform (similarity / similarity_score_threshold / mmr)"""

    @model_validator(mode="before")
    @classmethod
    def shim_docstore(cls, values: Dict) -> Any:
        """Resolve the ``docstore`` entry of the raw input values.

        When ``byte_store`` is supplied it takes precedence: it is wrapped
        with ``create_kv_docstore`` and the result replaces any ``docstore``
        that was also passed. Otherwise an explicitly passed ``docstore`` is
        kept as-is.

        Args:
            values: Raw values the model is being built from.

        Returns:
            The same ``values`` mapping with ``"docstore"`` set.

        Raises:
            Exception: If neither ``byte_store`` nor ``docstore`` is provided.
        """
        byte_store = values.get("byte_store")
        docstore = values.get("docstore")
        if byte_store is not None:
            docstore = create_kv_docstore(byte_store)
        elif docstore is None:
            raise Exception("You must pass a `byte_store` parameter.")
        values["docstore"] = docstore
        return values

    def _collect_parent_ids(self, sub_docs: List[Document]) -> List[Any]:
        """Return the distinct parent ids of ``sub_docs`` in first-seen order.

        Sub-documents whose metadata has no ``id_key`` entry are skipped.
        Shared by the sync and async retrieval paths so both de-duplicate
        identically.

        Args:
            sub_docs: Sub-documents returned by the vector store search.

        Returns:
            Parent ids without duplicates, ordered by first occurrence.
        """
        ids: List[Any] = []
        for d in sub_docs:
            # List membership (not a set) keeps the original equality-based
            # de-duplication and preserves the search result ordering.
            if self.id_key in d.metadata and d.metadata[self.id_key] not in ids:
                ids.append(d.metadata[self.id_key])
        return ids

    def _get_relevant_documents(
        self, query: str, *, run_manager: CallbackManagerForRetrieverRun
    ) -> List[Document]:
        """Get documents relevant to a query.

        Args:
            query: String to find relevant documents for
            run_manager: The callbacks handler to use

        Returns:
            List of relevant documents
        """
        if self.search_type == SearchType.mmr:
            sub_docs = self.vectorstore.max_marginal_relevance_search(
                query, **self.search_kwargs
            )
        elif self.search_type == SearchType.similarity_score_threshold:
            sub_docs_and_similarities = (
                self.vectorstore.similarity_search_with_relevance_scores(
                    query, **self.search_kwargs
                )
            )
            # Only the documents are needed; the scores are discarded.
            sub_docs = [sub_doc for sub_doc, _ in sub_docs_and_similarities]
        else:
            sub_docs = self.vectorstore.similarity_search(query, **self.search_kwargs)

        ids = self._collect_parent_ids(sub_docs)
        docs = self.docstore.mget(ids)
        # Drop ids for which the docstore returned nothing.
        return [d for d in docs if d is not None]

    async def _aget_relevant_documents(
        self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
    ) -> List[Document]:
        """Asynchronously get documents relevant to a query.

        Args:
            query: String to find relevant documents for
            run_manager: The callbacks handler to use

        Returns:
            List of relevant documents
        """
        if self.search_type == SearchType.mmr:
            sub_docs = await self.vectorstore.amax_marginal_relevance_search(
                query, **self.search_kwargs
            )
        elif self.search_type == SearchType.similarity_score_threshold:
            sub_docs_and_similarities = (
                await self.vectorstore.asimilarity_search_with_relevance_scores(
                    query, **self.search_kwargs
                )
            )
            # Only the documents are needed; the scores are discarded.
            sub_docs = [sub_doc for sub_doc, _ in sub_docs_and_similarities]
        else:
            sub_docs = await self.vectorstore.asimilarity_search(
                query, **self.search_kwargs
            )

        ids = self._collect_parent_ids(sub_docs)
        docs = await self.docstore.amget(ids)
        # Drop ids for which the docstore returned nothing.
        return [d for d in docs if d is not None]
|
|
|