# scripts/custom_retriever.py
import asyncio
import re
import time
import traceback
from typing import List, Optional

import tiktoken
from cohere import AsyncClient
from dotenv import load_dotenv
from llama_index.core import QueryBundle
from llama_index.core.retrievers import (
    BaseRetriever,
    KeywordTableSimpleRetriever,
    VectorIndexRetriever,
)
from llama_index.core.schema import MetadataMode, NodeWithScore
from llama_index.postprocessor.cohere_rerank import CohereRerank
from rapidfuzz import fuzz, process  # ✨ NEW: fuzzy matching for typo correction

load_dotenv()

# ✨ New function: fuzzy correction for queries
def normalize_and_correct_query(query: str, known_terms: List[str]) -> str:
    """Lowercase the query, strip punctuation, and snap each word to the
    closest known term when the fuzzy-match score clears 80."""
    cleaned = re.sub(r"[^\w\s]", "", query.lower())
    words = cleaned.split()
    corrected_words = []
    for word in words:
        best = process.extractOne(word, known_terms, scorer=fuzz.ratio)
        # extractOne returns None when known_terms is empty; keep the word as-is.
        if best is not None and best[1] > 80:
            corrected_words.append(best[0])
        else:
            corrected_words.append(word)
    return " ".join(corrected_words)
class AsyncCohereRerank(CohereRerank):
    """CohereRerank variant whose postprocess_nodes is awaitable, backed by
    cohere's AsyncClient instead of the synchronous client."""

    def __init__(
        self,
        top_n: int = 5,
        model: str = "rerank-english-v3.0",
        api_key: Optional[str] = None,
    ) -> None:
        super().__init__(top_n=top_n, model=model, api_key=api_key)
        self._api_key = api_key
        self._model = model
        self._top_n = top_n

    async def postprocess_nodes(
        self,
        nodes: List[NodeWithScore],
        query_bundle: Optional[QueryBundle] = None,
    ) -> List[NodeWithScore]:
        if query_bundle is None:
            raise ValueError("Query bundle must be provided.")
        if len(nodes) == 0:
            return []
        async_client = AsyncClient(api_key=self._api_key)
        texts = [
            node.node.get_content(metadata_mode=MetadataMode.EMBED)
            for node in nodes
        ]
        results = await async_client.rerank(
            model=self._model,
            top_n=self._top_n,
            query=query_bundle.query_str,
            documents=texts,
        )
        # Map Cohere's result indices back onto the original nodes, carrying
        # over the relevance score assigned by the reranker.
        return [
            NodeWithScore(
                node=nodes[result.index].node,
                score=result.relevance_score,
            )
            for result in results.results
        ]
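

# Illustrative usage (a minimal sketch; `candidate_nodes` and `query_bundle`
# are hypothetical names, not defined in this file):
#
#   reranker = AsyncCohereRerank(top_n=5, model="rerank-english-v3.0")
#   top_nodes = await reranker.postprocess_nodes(candidate_nodes, query_bundle)
#
# Unlike the parent CohereRerank, this call must be awaited, so it only fits
# an async code path such as _aretrieve below.
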
class CustomRetriever(BaseRetriever):
    """Custom retriever that performs both semantic search and hybrid search."""

    def __init__(
        self,
        vector_retriever: VectorIndexRetriever,
        document_dict: dict,
        keyword_retriever: Optional[KeywordTableSimpleRetriever] = None,
        mode: str = "AND",
    ) -> None:
        self._vector_retriever = vector_retriever
        self._document_dict = document_dict
        self._keyword_retriever = keyword_retriever
        if mode not in ("AND", "OR"):
            raise ValueError(f"Invalid mode: {mode!r}. Expected 'AND' or 'OR'.")
        self._mode = mode
        super().__init__()
    async def _process_retrieval(
        self, query_bundle: QueryBundle, is_async: bool = True
    ) -> List[NodeWithScore]:
        # Strip prompt scaffolding some callers append to the query.
        query_bundle.query_str = query_bundle.query_str.replace(
            "\ninput is ", ""
        ).rstrip()

        # ✅ Typo correction using fuzzy logic
        known_keywords = [
            "accounting", "audit", "assurance", "consulting", "tax", "advisory",
            "technology", "outsourcing", "virtual cfo", "services", "team",
            "leadership", "india", "usa", "projects", "cloud", "data", "ai",
            "ml", "education", "training", "academy", "sox", "compliance",
            "clients", "mission", "vision", "culture", "offices", "partners",
            "strategy",
        ]
        query_bundle.query_str = normalize_and_correct_query(
            query_bundle.query_str, known_keywords
        )

        start = time.time()

        # Semantic retrieval via the vector index.
        if is_async:
            nodes = await self._vector_retriever.aretrieve(query_bundle)
        else:
            nodes = self._vector_retriever.retrieve(query_bundle)

        # Optional keyword retrieval for hybrid search.
        keyword_nodes = []
        if self._keyword_retriever:
            if is_async:
                keyword_nodes = await self._keyword_retriever.aretrieve(query_bundle)
            else:
                keyword_nodes = self._keyword_retriever.retrieve(query_bundle)

        # Combine the result sets: intersection for "AND", union for "OR".
        vector_ids = {n.node.node_id for n in nodes}
        keyword_ids = {n.node.node_id for n in keyword_nodes}
        combined_dict = {n.node.node_id: n for n in nodes}
        combined_dict.update({n.node.node_id: n for n in keyword_nodes})
        if not self._keyword_retriever or not keyword_nodes:
            retrieve_ids = vector_ids
        else:
            retrieve_ids = (
                vector_ids.intersection(keyword_ids)
                if self._mode == "AND"
                else vector_ids.union(keyword_ids)
            )
        nodes = [combined_dict[rid] for rid in retrieve_ids]
        nodes = self._filter_nodes_by_unique_doc_id(nodes)

        # For nodes flagged with retrieve_doc, swap the chunk text for the
        # full source document.
        for node in nodes:
            source = node.node.source_node
            if source is None:
                continue
            doc_id = source.node_id
            if node.metadata.get("retrieve_doc", False):
                doc = self._document_dict.get(doc_id)
                if doc:
                    node.node.text = doc.text
                    node.node.node_id = doc_id

        # Rerank with Cohere; fall back to the unreranked nodes on failure.
        try:
            reranker = (
                AsyncCohereRerank(top_n=5, model="rerank-english-v3.0")
                if is_async
                else CohereRerank(top_n=5, model="rerank-english-v3.0")
            )
            nodes = (
                await reranker.postprocess_nodes(nodes, query_bundle)
                if is_async
                else reranker.postprocess_nodes(nodes, query_bundle)
            )
        except Exception as e:
            print(f"Error during reranking: {type(e).__name__}: {e}")
            traceback.print_exc()

        nodes_filtered = self._filter_by_score_and_tokens(nodes)
        duration = time.time() - start
        print(f"Retrieving nodes took {duration:.2f}s")
        return nodes_filtered[:5]
    def _filter_nodes_by_unique_doc_id(
        self, nodes: List[NodeWithScore]
    ) -> List[NodeWithScore]:
        """Keep only the first (highest-ranked) node per source document."""
        unique_nodes = {}
        for node in nodes:
            source = node.node.source_node
            doc_id = source.node_id if source is not None else None
            if doc_id is not None and doc_id not in unique_nodes:
                unique_nodes[doc_id] = node
        return list(unique_nodes.values())

    def _filter_by_score_and_tokens(
        self, nodes: List[NodeWithScore]
    ) -> List[NodeWithScore]:
        """Drop low-relevance nodes and cap total context at 100k tokens."""
        nodes_filtered = []
        total_tokens = 0
        enc = tiktoken.encoding_for_model("gpt-4")
        for node in nodes:
            # node.score can be None when reranking failed; keep those nodes
            # rather than comparing None against the cutoff.
            if node.score is not None and node.score < 0.10:
                continue
            node_tokens = len(enc.encode(node.node.text))
            if total_tokens + node_tokens > 100_000:
                break
            total_tokens += node_tokens
            nodes_filtered.append(node)
        return nodes_filtered
    async def _aretrieve(self, query_bundle: QueryBundle, **kwargs) -> List[NodeWithScore]:
        return await self._process_retrieval(query_bundle, is_async=True)

    def _retrieve(self, query_bundle: QueryBundle, **kwargs) -> List[NodeWithScore]:
        return asyncio.run(self._process_retrieval(query_bundle, is_async=False))
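

# Illustrative wiring (a minimal sketch; `index` and `documents` are
# hypothetical objects, not defined in this file):
#
#   vector_retriever = VectorIndexRetriever(index=index, similarity_top_k=10)
#   document_dict = {doc.doc_id: doc for doc in documents}
#   retriever = CustomRetriever(
#       vector_retriever=vector_retriever,
#       document_dict=document_dict,
#       mode="OR",
#   )
#   nodes = await retriever.aretrieve("acounting servises in india")
#
# Note that the sync `_retrieve` path drives the coroutine with asyncio.run,
# so it assumes no event loop is already running in the calling thread.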