|
|
"""Retriever that generates and executes structured queries over its own data source.""" |
|
|
|
|
|
import logging |
|
|
from typing import Any, Dict, List, Optional, Sequence, Tuple, Type, Union |
|
|
|
|
|
from langchain_core.callbacks.manager import ( |
|
|
AsyncCallbackManagerForRetrieverRun, |
|
|
CallbackManagerForRetrieverRun, |
|
|
) |
|
|
from langchain_core.documents import Document |
|
|
from langchain_core.language_models import BaseLanguageModel |
|
|
from langchain_core.retrievers import BaseRetriever |
|
|
from langchain_core.runnables import Runnable |
|
|
from langchain_core.structured_query import StructuredQuery, Visitor |
|
|
from langchain_core.vectorstores import VectorStore |
|
|
from pydantic import ConfigDict, Field, model_validator |
|
|
|
|
|
from langchain.chains.query_constructor.base import load_query_constructor_runnable |
|
|
from langchain.chains.query_constructor.schema import AttributeInfo |
|
|
|
|
|
logger = logging.getLogger(__name__) |
|
|
QUERY_CONSTRUCTOR_RUN_NAME = "query_constructor" |
|
|
|
|
|
|
|
|
def _get_builtin_translator(vectorstore: VectorStore) -> Visitor: |
|
|
"""Get the translator class corresponding to the vector store class.""" |
|
|
try: |
|
|
import langchain_community |
|
|
except ImportError: |
|
|
raise ImportError( |
|
|
"The langchain-community package must be installed to use this feature." |
|
|
" Please install it using `pip install langchain-community`." |
|
|
) |
|
|
|
|
|
from langchain_community.query_constructors.astradb import AstraDBTranslator |
|
|
from langchain_community.query_constructors.chroma import ChromaTranslator |
|
|
from langchain_community.query_constructors.dashvector import DashvectorTranslator |
|
|
from langchain_community.query_constructors.databricks_vector_search import ( |
|
|
DatabricksVectorSearchTranslator, |
|
|
) |
|
|
from langchain_community.query_constructors.deeplake import DeepLakeTranslator |
|
|
from langchain_community.query_constructors.dingo import DingoDBTranslator |
|
|
from langchain_community.query_constructors.elasticsearch import ( |
|
|
ElasticsearchTranslator, |
|
|
) |
|
|
from langchain_community.query_constructors.milvus import MilvusTranslator |
|
|
from langchain_community.query_constructors.mongodb_atlas import ( |
|
|
MongoDBAtlasTranslator, |
|
|
) |
|
|
from langchain_community.query_constructors.myscale import MyScaleTranslator |
|
|
from langchain_community.query_constructors.neo4j import Neo4jTranslator |
|
|
from langchain_community.query_constructors.opensearch import OpenSearchTranslator |
|
|
from langchain_community.query_constructors.pgvector import PGVectorTranslator |
|
|
from langchain_community.query_constructors.pinecone import PineconeTranslator |
|
|
from langchain_community.query_constructors.qdrant import QdrantTranslator |
|
|
from langchain_community.query_constructors.redis import RedisTranslator |
|
|
from langchain_community.query_constructors.supabase import SupabaseVectorTranslator |
|
|
from langchain_community.query_constructors.tencentvectordb import ( |
|
|
TencentVectorDBTranslator, |
|
|
) |
|
|
from langchain_community.query_constructors.timescalevector import ( |
|
|
TimescaleVectorTranslator, |
|
|
) |
|
|
from langchain_community.query_constructors.vectara import VectaraTranslator |
|
|
from langchain_community.query_constructors.weaviate import WeaviateTranslator |
|
|
from langchain_community.vectorstores import ( |
|
|
AstraDB, |
|
|
DashVector, |
|
|
DatabricksVectorSearch, |
|
|
DeepLake, |
|
|
Dingo, |
|
|
Milvus, |
|
|
MyScale, |
|
|
Neo4jVector, |
|
|
OpenSearchVectorSearch, |
|
|
PGVector, |
|
|
Qdrant, |
|
|
Redis, |
|
|
SupabaseVectorStore, |
|
|
TencentVectorDB, |
|
|
TimescaleVector, |
|
|
Vectara, |
|
|
Weaviate, |
|
|
) |
|
|
from langchain_community.vectorstores import ( |
|
|
Chroma as CommunityChroma, |
|
|
) |
|
|
from langchain_community.vectorstores import ( |
|
|
ElasticsearchStore as ElasticsearchStoreCommunity, |
|
|
) |
|
|
from langchain_community.vectorstores import ( |
|
|
MongoDBAtlasVectorSearch as CommunityMongoDBAtlasVectorSearch, |
|
|
) |
|
|
from langchain_community.vectorstores import ( |
|
|
Pinecone as CommunityPinecone, |
|
|
) |
|
|
|
|
|
BUILTIN_TRANSLATORS: Dict[Type[VectorStore], Type[Visitor]] = { |
|
|
AstraDB: AstraDBTranslator, |
|
|
PGVector: PGVectorTranslator, |
|
|
CommunityPinecone: PineconeTranslator, |
|
|
CommunityChroma: ChromaTranslator, |
|
|
DashVector: DashvectorTranslator, |
|
|
Dingo: DingoDBTranslator, |
|
|
Weaviate: WeaviateTranslator, |
|
|
Vectara: VectaraTranslator, |
|
|
Qdrant: QdrantTranslator, |
|
|
MyScale: MyScaleTranslator, |
|
|
DeepLake: DeepLakeTranslator, |
|
|
ElasticsearchStoreCommunity: ElasticsearchTranslator, |
|
|
Milvus: MilvusTranslator, |
|
|
SupabaseVectorStore: SupabaseVectorTranslator, |
|
|
TimescaleVector: TimescaleVectorTranslator, |
|
|
OpenSearchVectorSearch: OpenSearchTranslator, |
|
|
CommunityMongoDBAtlasVectorSearch: MongoDBAtlasTranslator, |
|
|
Neo4jVector: Neo4jTranslator, |
|
|
} |
|
|
if isinstance(vectorstore, DatabricksVectorSearch): |
|
|
return DatabricksVectorSearchTranslator() |
|
|
elif isinstance(vectorstore, MyScale): |
|
|
return MyScaleTranslator(metadata_key=vectorstore.metadata_column) |
|
|
elif isinstance(vectorstore, Redis): |
|
|
return RedisTranslator.from_vectorstore(vectorstore) |
|
|
elif isinstance(vectorstore, TencentVectorDB): |
|
|
fields = [ |
|
|
field.name for field in (vectorstore.meta_fields or []) if field.index |
|
|
] |
|
|
return TencentVectorDBTranslator(fields) |
|
|
elif vectorstore.__class__ in BUILTIN_TRANSLATORS: |
|
|
return BUILTIN_TRANSLATORS[vectorstore.__class__]() |
|
|
else: |
|
|
try: |
|
|
from langchain_astradb.vectorstores import AstraDBVectorStore |
|
|
except ImportError: |
|
|
pass |
|
|
else: |
|
|
if isinstance(vectorstore, AstraDBVectorStore): |
|
|
return AstraDBTranslator() |
|
|
|
|
|
try: |
|
|
from langchain_elasticsearch.vectorstores import ElasticsearchStore |
|
|
except ImportError: |
|
|
pass |
|
|
else: |
|
|
if isinstance(vectorstore, ElasticsearchStore): |
|
|
return ElasticsearchTranslator() |
|
|
|
|
|
try: |
|
|
from langchain_pinecone import PineconeVectorStore |
|
|
except ImportError: |
|
|
pass |
|
|
else: |
|
|
if isinstance(vectorstore, PineconeVectorStore): |
|
|
return PineconeTranslator() |
|
|
|
|
|
try: |
|
|
from langchain_mongodb import MongoDBAtlasVectorSearch |
|
|
except ImportError: |
|
|
pass |
|
|
else: |
|
|
if isinstance(vectorstore, MongoDBAtlasVectorSearch): |
|
|
return MongoDBAtlasTranslator() |
|
|
|
|
|
try: |
|
|
from langchain_neo4j import Neo4jVector |
|
|
except ImportError: |
|
|
pass |
|
|
else: |
|
|
if isinstance(vectorstore, Neo4jVector): |
|
|
return Neo4jTranslator() |
|
|
|
|
|
try: |
|
|
|
|
|
from langchain_chroma import Chroma |
|
|
except ImportError: |
|
|
pass |
|
|
else: |
|
|
if isinstance(vectorstore, Chroma): |
|
|
return ChromaTranslator() |
|
|
|
|
|
try: |
|
|
from langchain_postgres import PGVector |
|
|
from langchain_postgres import PGVectorTranslator as NewPGVectorTranslator |
|
|
except ImportError: |
|
|
pass |
|
|
else: |
|
|
if isinstance(vectorstore, PGVector): |
|
|
return NewPGVectorTranslator() |
|
|
|
|
|
try: |
|
|
from langchain_qdrant import QdrantVectorStore |
|
|
except ImportError: |
|
|
pass |
|
|
else: |
|
|
if isinstance(vectorstore, QdrantVectorStore): |
|
|
return QdrantTranslator(metadata_key=vectorstore.metadata_payload_key) |
|
|
|
|
|
try: |
|
|
|
|
|
from langchain_community.query_constructors.hanavector import HanaTranslator |
|
|
from langchain_community.vectorstores import HanaDB |
|
|
except ImportError: |
|
|
pass |
|
|
else: |
|
|
if isinstance(vectorstore, HanaDB): |
|
|
return HanaTranslator() |
|
|
|
|
|
try: |
|
|
|
|
|
from langchain_weaviate.vectorstores import WeaviateVectorStore |
|
|
|
|
|
except ImportError: |
|
|
pass |
|
|
else: |
|
|
if isinstance(vectorstore, WeaviateVectorStore): |
|
|
return WeaviateTranslator() |
|
|
|
|
|
raise ValueError( |
|
|
f"Self query retriever with Vector Store type {vectorstore.__class__}" |
|
|
f" not supported." |
|
|
) |
|
|
|
|
|
|
|
|
class SelfQueryRetriever(BaseRetriever): |
|
|
"""Retriever that uses a vector store and an LLM to generate |
|
|
the vector store queries.""" |
|
|
|
|
|
vectorstore: VectorStore |
|
|
"""The underlying vector store from which documents will be retrieved.""" |
|
|
query_constructor: Runnable[dict, StructuredQuery] = Field(alias="llm_chain") |
|
|
"""The query constructor chain for generating the vector store queries. |
|
|
|
|
|
llm_chain is legacy name kept for backwards compatibility.""" |
|
|
search_type: str = "similarity" |
|
|
"""The search type to perform on the vector store.""" |
|
|
search_kwargs: dict = Field(default_factory=dict) |
|
|
"""Keyword arguments to pass in to the vector store search.""" |
|
|
structured_query_translator: Visitor |
|
|
"""Translator for turning internal query language into vectorstore search params.""" |
|
|
verbose: bool = False |
|
|
|
|
|
use_original_query: bool = False |
|
|
"""Use original query instead of the revised new query from LLM""" |
|
|
|
|
|
model_config = ConfigDict( |
|
|
populate_by_name=True, |
|
|
arbitrary_types_allowed=True, |
|
|
) |
|
|
|
|
|
@model_validator(mode="before") |
|
|
@classmethod |
|
|
def validate_translator(cls, values: Dict) -> Any: |
|
|
"""Validate translator.""" |
|
|
if "structured_query_translator" not in values: |
|
|
values["structured_query_translator"] = _get_builtin_translator( |
|
|
values["vectorstore"] |
|
|
) |
|
|
return values |
|
|
|
|
|
@property |
|
|
def llm_chain(self) -> Runnable: |
|
|
"""llm_chain is legacy name kept for backwards compatibility.""" |
|
|
return self.query_constructor |
|
|
|
|
|
def _prepare_query( |
|
|
self, query: str, structured_query: StructuredQuery |
|
|
) -> Tuple[str, Dict[str, Any]]: |
|
|
new_query, new_kwargs = self.structured_query_translator.visit_structured_query( |
|
|
structured_query |
|
|
) |
|
|
if structured_query.limit is not None: |
|
|
new_kwargs["k"] = structured_query.limit |
|
|
if self.use_original_query: |
|
|
new_query = query |
|
|
search_kwargs = {**self.search_kwargs, **new_kwargs} |
|
|
return new_query, search_kwargs |
|
|
|
|
|
def _get_docs_with_query( |
|
|
self, query: str, search_kwargs: Dict[str, Any] |
|
|
) -> List[Document]: |
|
|
docs = self.vectorstore.search(query, self.search_type, **search_kwargs) |
|
|
return docs |
|
|
|
|
|
async def _aget_docs_with_query( |
|
|
self, query: str, search_kwargs: Dict[str, Any] |
|
|
) -> List[Document]: |
|
|
docs = await self.vectorstore.asearch(query, self.search_type, **search_kwargs) |
|
|
return docs |
|
|
|
|
|
def _get_relevant_documents( |
|
|
self, query: str, *, run_manager: CallbackManagerForRetrieverRun |
|
|
) -> List[Document]: |
|
|
"""Get documents relevant for a query. |
|
|
|
|
|
Args: |
|
|
query: string to find relevant documents for |
|
|
|
|
|
Returns: |
|
|
List of relevant documents |
|
|
""" |
|
|
structured_query = self.query_constructor.invoke( |
|
|
{"query": query}, config={"callbacks": run_manager.get_child()} |
|
|
) |
|
|
if self.verbose: |
|
|
logger.info(f"Generated Query: {structured_query}") |
|
|
new_query, search_kwargs = self._prepare_query(query, structured_query) |
|
|
docs = self._get_docs_with_query(new_query, search_kwargs) |
|
|
return docs |
|
|
|
|
|
async def _aget_relevant_documents( |
|
|
self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun |
|
|
) -> List[Document]: |
|
|
"""Get documents relevant for a query. |
|
|
|
|
|
Args: |
|
|
query: string to find relevant documents for |
|
|
|
|
|
Returns: |
|
|
List of relevant documents |
|
|
""" |
|
|
structured_query = await self.query_constructor.ainvoke( |
|
|
{"query": query}, config={"callbacks": run_manager.get_child()} |
|
|
) |
|
|
if self.verbose: |
|
|
logger.info(f"Generated Query: {structured_query}") |
|
|
new_query, search_kwargs = self._prepare_query(query, structured_query) |
|
|
docs = await self._aget_docs_with_query(new_query, search_kwargs) |
|
|
return docs |
|
|
|
|
|
@classmethod |
|
|
def from_llm( |
|
|
cls, |
|
|
llm: BaseLanguageModel, |
|
|
vectorstore: VectorStore, |
|
|
document_contents: str, |
|
|
metadata_field_info: Sequence[Union[AttributeInfo, dict]], |
|
|
structured_query_translator: Optional[Visitor] = None, |
|
|
chain_kwargs: Optional[Dict] = None, |
|
|
enable_limit: bool = False, |
|
|
use_original_query: bool = False, |
|
|
**kwargs: Any, |
|
|
) -> "SelfQueryRetriever": |
|
|
if structured_query_translator is None: |
|
|
structured_query_translator = _get_builtin_translator(vectorstore) |
|
|
chain_kwargs = chain_kwargs or {} |
|
|
|
|
|
if ( |
|
|
"allowed_comparators" not in chain_kwargs |
|
|
and structured_query_translator.allowed_comparators is not None |
|
|
): |
|
|
chain_kwargs["allowed_comparators"] = ( |
|
|
structured_query_translator.allowed_comparators |
|
|
) |
|
|
if ( |
|
|
"allowed_operators" not in chain_kwargs |
|
|
and structured_query_translator.allowed_operators is not None |
|
|
): |
|
|
chain_kwargs["allowed_operators"] = ( |
|
|
structured_query_translator.allowed_operators |
|
|
) |
|
|
query_constructor = load_query_constructor_runnable( |
|
|
llm, |
|
|
document_contents, |
|
|
metadata_field_info, |
|
|
enable_limit=enable_limit, |
|
|
**chain_kwargs, |
|
|
) |
|
|
query_constructor = query_constructor.with_config( |
|
|
run_name=QUERY_CONSTRUCTOR_RUN_NAME |
|
|
) |
|
|
return cls( |
|
|
query_constructor=query_constructor, |
|
|
vectorstore=vectorstore, |
|
|
use_original_query=use_original_query, |
|
|
structured_query_translator=structured_query_translator, |
|
|
**kwargs, |
|
|
) |
|
|
|