import asyncio
import os
import re
import time
import traceback
from typing import Dict, List, Optional

import tiktoken
from cohere import AsyncClient
from dotenv import load_dotenv
from llama_index.core import Document
from llama_index.core.retrievers import (
    BaseRetriever,
    KeywordTableSimpleRetriever,
    VectorIndexRetriever,
)
from llama_index.core.schema import MetadataMode, NodeWithScore, QueryBundle
from llama_index.postprocessor.cohere_rerank import CohereRerank
from rapidfuzz import fuzz, process  # fuzzy matching for typo correction

load_dotenv()


def normalize_and_correct_query(query: str, known_terms: List[str]) -> str:
    """Lowercase the query, strip punctuation, and fuzzily correct each word
    against a list of known domain terms (e.g. "acounting" -> "accounting")."""
    cleaned = re.sub(r"[^\w\s]", "", query.lower())
    words = cleaned.split()
    if not known_terms:  # extractOne returns None for empty choices
        return " ".join(words)
    corrected_words = []
    for word in words:
        match, score, _ = process.extractOne(word, known_terms, scorer=fuzz.ratio)
        # Only substitute when the match is close enough; 80 is a conservative
        # cutoff that catches single-character typos without mangling real words.
        corrected_words.append(match if score > 80 else word)
    return " ".join(corrected_words)
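# A quick sanity check of the correction step (illustrative inputs only, not
# part of the production flow). Single-character typos comfortably clear the
# 80-point rapidfuzz ratio cutoff:
#
#   >>> normalize_and_correct_query("Acounting servises?", ["accounting", "services"])
#   'accounting services'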
class AsyncCohereRerank(CohereRerank):
    """Async drop-in for CohereRerank: shadows the sync ``postprocess_nodes``
    with a coroutine that calls Cohere's AsyncClient directly."""

    def __init__(
        self,
        top_n: int = 5,
        model: str = "rerank-english-v3.0",
        api_key: Optional[str] = None,
    ) -> None:
        super().__init__(top_n=top_n, model=model, api_key=api_key)
        # Fall back to the environment so an explicit ``api_key=None`` does not
        # hand the async client an empty credential (COHERE_API_KEY is assumed
        # to be loaded from .env above).
        self._api_key = api_key or os.environ.get("COHERE_API_KEY")
        self._model = model
        self._top_n = top_n

    async def postprocess_nodes(
        self,
        nodes: List[NodeWithScore],
        query_bundle: Optional[QueryBundle] = None,
    ) -> List[NodeWithScore]:
        if query_bundle is None:
            raise ValueError("Query bundle must be provided.")
        if len(nodes) == 0:
            return []

        async_client = AsyncClient(api_key=self._api_key)
        texts = [
            node.node.get_content(metadata_mode=MetadataMode.EMBED)
            for node in nodes
        ]
        results = await async_client.rerank(
            model=self._model,
            top_n=self._top_n,
            query=query_bundle.query_str,
            documents=texts,
        )
        return [
            NodeWithScore(node=nodes[result.index].node, score=result.relevance_score)
            for result in results.results
        ]


class CustomRetriever(BaseRetriever):
    """Custom retriever that combines semantic (vector) search with optional
    keyword search, then reranks and filters the merged results."""

    def __init__(
        self,
        vector_retriever: VectorIndexRetriever,
        document_dict: Dict[str, Document],
        keyword_retriever: Optional[KeywordTableSimpleRetriever] = None,
        mode: str = "AND",
    ) -> None:
        if mode not in ("AND", "OR"):
            raise ValueError("Invalid mode; expected 'AND' or 'OR'.")
        self._vector_retriever = vector_retriever
        self._document_dict = document_dict
        self._keyword_retriever = keyword_retriever
        self._mode = mode
        super().__init__()

    async def _process_retrieval(
        self, query_bundle: QueryBundle, is_async: bool = True
    ) -> List[NodeWithScore]:
        # Strip a known prompt artefact before any further processing.
        query_bundle.query_str = (
            query_bundle.query_str.replace("\ninput is ", "").rstrip()
        )

        # Typo correction against the domain vocabulary.
        known_keywords = [
            "accounting", "audit", "assurance", "consulting", "tax", "advisory",
            "technology", "outsourcing", "virtual cfo", "services", "team",
            "leadership", "india", "usa", "projects", "cloud", "data", "ai",
            "ml", "education", "training", "academy", "sox", "compliance",
            "clients", "mission", "vision", "culture", "offices", "partners",
            "strategy",
        ]
        query_bundle.query_str = normalize_and_correct_query(
            query_bundle.query_str, known_keywords
        )

        start = time.time()
        if is_async:
            nodes = await self._vector_retriever.aretrieve(query_bundle)
        else:
            nodes = self._vector_retriever.retrieve(query_bundle)

        keyword_nodes = []
        if self._keyword_retriever:
            if is_async:
                keyword_nodes = await self._keyword_retriever.aretrieve(query_bundle)
            else:
                keyword_nodes = self._keyword_retriever.retrieve(query_bundle)

        vector_ids = {n.node.node_id for n in nodes}
        keyword_ids = {n.node.node_id for n in keyword_nodes}
        combined_dict = {n.node.node_id: n for n in nodes}
        combined_dict.update({n.node.node_id: n for n in keyword_nodes})

        if not keyword_nodes:
            retrieve_ids = vector_ids
        else:
            retrieve_ids = (
                vector_ids.intersection(keyword_ids)
                if self._mode == "AND"
                else vector_ids.union(keyword_ids)
            )

        nodes = [combined_dict[rid] for rid in retrieve_ids]
        nodes = self._filter_nodes_by_unique_doc_id(nodes)

        # Optionally swap a chunk for its full source document when the node
        # carries the ``retrieve_doc`` metadata flag.
        for node in nodes:
            source = node.node.source_node
            if source is None:  # guard: not every node has a source document
                continue
            if node.metadata.get("retrieve_doc", False):
                doc = self._document_dict.get(source.node_id)
                if doc:
                    node.node.text = doc.text
                    node.node.node_id = source.node_id

        try:
            reranker = (
                AsyncCohereRerank(top_n=5, model="rerank-english-v3.0")
                if is_async
                else CohereRerank(top_n=5, model="rerank-english-v3.0")
            )
            nodes = (
                await reranker.postprocess_nodes(nodes, query_bundle)
                if is_async
                else reranker.postprocess_nodes(nodes, query_bundle)
            )
        except Exception as e:
            # Reranking is best-effort: on failure, fall through with the
            # unreranked nodes rather than aborting retrieval.
            print(f"Error during reranking: {type(e).__name__}: {e}")
            traceback.print_exc()

        nodes_filtered = self._filter_by_score_and_tokens(nodes)
        duration = time.time() - start
        print(f"Retrieving nodes took {duration:.2f}s")
        return nodes_filtered[:5]

    def _filter_nodes_by_unique_doc_id(
        self, nodes: List[NodeWithScore]
    ) -> List[NodeWithScore]:
        """Keep only the first node seen per source document."""
        unique_nodes = {}
        for node in nodes:
            source = node.node.source_node
            doc_id = source.node_id if source is not None else None
            if doc_id is not None and doc_id not in unique_nodes:
                unique_nodes[doc_id] = node
        return list(unique_nodes.values())

    def _filter_by_score_and_tokens(
        self, nodes: List[NodeWithScore]
    ) -> List[NodeWithScore]:
        """Drop low-relevance nodes and cap total context at 100k tokens."""
        nodes_filtered = []
        total_tokens = 0
        enc = tiktoken.encoding_for_model("gpt-4")
        for node in nodes:
            if node.score is None or node.score < 0.10:
                continue
            node_tokens = len(enc.encode(node.node.text))
            if total_tokens + node_tokens > 100_000:
                break
            total_tokens += node_tokens
            nodes_filtered.append(node)
        return nodes_filtered

    async def _aretrieve(
        self, query_bundle: QueryBundle, **kwargs
    ) -> List[NodeWithScore]:
        return await self._process_retrieval(query_bundle, is_async=True)

    def _retrieve(self, query_bundle: QueryBundle, **kwargs) -> List[NodeWithScore]:
        # Note: asyncio.run raises RuntimeError if an event loop is already
        # running; callers inside one should use _aretrieve instead.
        return asyncio.run(self._process_retrieval(query_bundle, is_async=False))
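
# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative, not part of the retriever itself). It
# assumes a ./data directory of documents plus OPENAI_API_KEY and
# COHERE_API_KEY in the environment (loaded from .env above); the query text
# is a made-up example that exercises the typo-correction path.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from llama_index.core import SimpleDirectoryReader, VectorStoreIndex

    documents = SimpleDirectoryReader("./data").load_data()
    index = VectorStoreIndex.from_documents(documents)

    retriever = CustomRetriever(
        vector_retriever=index.as_retriever(similarity_top_k=10),
        document_dict={doc.doc_id: doc for doc in documents},
    )
    results = asyncio.run(
        retriever.aretrieve("What servises does the acounting team offer?")
    )
    for result in results:
        print(f"{result.score:.3f}  {result.node.node_id}")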