File size: 7,355 Bytes
5ddcfe5
 
 
 
 
06e3736
5ddcfe5
deb30d8
5ddcfe5
 
deb30d8
 
5ddcfe5
 
deb30d8
5ddcfe5
 
06e3736
4410de7
06e3736
 
 
 
 
 
 
 
5ddcfe5
 
 
06e3736
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a044b52
5ddcfe5
 
 
 
 
 
 
 
 
 
 
 
4410de7
5ddcfe5
 
 
 
deb30d8
 
 
5ddcfe5
 
 
06e3736
57148f8
 
 
 
 
 
 
 
06e3736
 
 
 
deb30d8
06e3736
 
5ddcfe5
 
 
06e3736
deb30d8
5ddcfe5
 
 
 
 
 
 
 
 
 
deb30d8
4410de7
5ddcfe5
deb30d8
5ddcfe5
06e3736
 
 
 
 
 
 
 
 
 
 
 
 
 
deb30d8
5ddcfe5
deb30d8
5ddcfe5
 
 
 
 
deb30d8
5ddcfe5
deb30d8
 
 
 
5ddcfe5
deb30d8
 
5ddcfe5
 
 
deb30d8
 
5ddcfe5
deb30d8
 
 
 
 
5ddcfe5
deb30d8
 
5ddcfe5
deb30d8
06e3736
4410de7
 
 
 
06e3736
5ddcfe5
 
 
deb30d8
 
 
5ddcfe5
deb30d8
4410de7
deb30d8
 
5ddcfe5
 
06e3736
57148f8
5ddcfe5
deb30d8
 
06e3736
 
 
deb30d8
5ddcfe5
06e3736
 
 
deb30d8
5ddcfe5
06e3736
 
deb30d8
 
 
06e3736
 
 
deb30d8
 
06e3736
deb30d8
 
 
 
 
 
 
5ddcfe5
deb30d8
 
 
 
 
 
06e3736
 
deb30d8
06e3736
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
import asyncio
import os
import time
import traceback
from typing import List, Optional
import re

import tiktoken
from cohere import AsyncClient
from dotenv import load_dotenv
from llama_index.core import Document, QueryBundle
from llama_index.core.async_utils import run_async_tasks
from llama_index.core.retrievers import (
    BaseRetriever,
    KeywordTableSimpleRetriever,
    VectorIndexRetriever,
)
from llama_index.core.schema import MetadataMode, NodeWithScore, QueryBundle, TextNode
from llama_index.postprocessor.cohere_rerank import CohereRerank
from llama_index.core.vector_stores import (
    FilterCondition,
    FilterOperator,
    MetadataFilter,
    MetadataFilters,
)

from rapidfuzz import fuzz, process  # ✨ NEW for typo correction

# Runs at import time, before any class below is instantiated.
# NOTE(review): presumably this loads settings such as the Cohere API key from
# a local .env file into the process environment — confirm against deployment.
load_dotenv()

# Fuzzy typo correction for incoming queries.
def normalize_and_correct_query(query: str, known_terms: List[str]) -> str:
    """Normalize a query and snap misspelled words onto known vocabulary terms.

    The query is lower-cased, every character that is neither a word character
    nor whitespace is removed, and the remainder is split on whitespace. Each
    word is then compared against ``known_terms`` with ``fuzz.ratio``; a word
    whose best match scores strictly above 80 is replaced by that term, and
    every other word is kept as typed.

    Args:
        query: Raw query text.
        known_terms: Vocabulary to correct against. May be empty, in which
            case the query is only normalized.

    Returns:
        The resulting words joined by single spaces (``""`` when the query
        contains no words).
    """
    words = re.sub(r"[^\w\s]", "", query.lower()).split()

    # With no candidate terms extractOne() returns None, and unpacking that
    # would raise TypeError; there is nothing to correct against anyway.
    if not words or not known_terms:
        return " ".join(words)

    corrected_words = []
    for word in words:
        best = process.extractOne(word, known_terms, scorer=fuzz.ratio)
        if best is not None and best[1] > 80:
            corrected_words.append(best[0])
        else:
            corrected_words.append(word)

    return " ".join(corrected_words)


class AsyncCohereRerank(CohereRerank):
    """Cohere reranker whose ``postprocess_nodes`` is a coroutine.

    The base class is initialised normally, but reranking itself goes through
    a Cohere ``AsyncClient`` so that callers can ``await`` it.

    Note that ``postprocess_nodes`` is overridden as ``async``: it must be
    awaited and is not a drop-in replacement wherever a synchronous
    postprocessor is expected.
    """

    def __init__(
        self,
        top_n: int = 5,
        model: str = "rerank-english-v3.0",
        api_key: Optional[str] = None,
    ) -> None:
        """Store the rerank settings used by the async client.

        Args:
            top_n: Maximum number of results requested from the rerank API.
            model: Name of the Cohere rerank model.
            api_key: Cohere API key. When omitted, ``COHERE_API_KEY`` from the
                environment is used for the async client.
        """
        super().__init__(top_n=top_n, model=model, api_key=api_key)
        # The async client is created independently of the base class, so it
        # needs its own fallback when no key is passed explicitly.
        # NOTE(review): COHERE_API_KEY is the variable the llama-index base
        # class is understood to read — confirm against the installed version.
        self._api_key = api_key or os.environ.get("COHERE_API_KEY")
        self._model = model
        self._top_n = top_n

    async def postprocess_nodes(
        self,
        nodes: List[NodeWithScore],
        query_bundle: Optional[QueryBundle] = None,
    ) -> List[NodeWithScore]:
        """Rerank ``nodes`` against the query using the Cohere rerank API.

        Args:
            nodes: Candidate nodes to rerank.
            query_bundle: Query whose ``query_str`` is sent to the API.

        Returns:
            New ``NodeWithScore`` objects, one per API result in the order the
            API returned them, scored with the API's relevance score. Empty
            when ``nodes`` is empty.

        Raises:
            ValueError: If ``query_bundle`` is ``None``.
        """
        if query_bundle is None:
            raise ValueError("Query bundle must be provided.")
        if not nodes:
            return []

        texts = [
            node.node.get_content(metadata_mode=MetadataMode.EMBED)
            for node in nodes
        ]

        # A client is created per call; the context manager closes its
        # underlying HTTP session instead of leaking it.
        async with AsyncClient(api_key=self._api_key) as async_client:
            results = await async_client.rerank(
                model=self._model,
                top_n=self._top_n,
                query=query_bundle.query_str,
                documents=texts,
            )

        # result.index refers back to the position in the submitted texts.
        return [
            NodeWithScore(
                node=nodes[result.index].node,
                score=result.relevance_score,
            )
            for result in results.results
        ]


class CustomRetriever(BaseRetriever):
    """Hybrid retriever: vector search, optional keyword search, Cohere rerank.

    Retrieval pipeline:

    1. Clean the query (strip the ``"\\ninput is "`` marker, fuzzy-correct
       typos against a fixed vocabulary). The cleaned text is written back to
       ``query_bundle.query_str``, so the caller's bundle is modified.
    2. Run the vector retriever and, if configured, the keyword retriever.
    3. Combine the hits (``AND`` = intersection, ``OR`` = union; vector-only
       when the keyword retriever returns nothing).
    4. Keep one hit per source document, optionally swapping in the full
       document text, and re-key the node by its source document id.
    5. Rerank with Cohere; on failure the error is printed and the un-reranked
       hits are used.
    6. Drop low-scoring hits, enforce a token budget, return at most 5 nodes.
    """

    # Vocabulary used to fuzzy-correct typos in incoming queries.
    _KNOWN_KEYWORDS = (
        "accounting", "audit", "assurance", "consulting", "tax", "advisory", "technology",
        "outsourcing", "virtual cfo", "services", "team", "leadership", "india", "usa", "projects",
        "cloud", "data", "ai", "ml", "education", "training", "academy", "sox", "compliance",
        "clients", "mission", "vision", "culture", "offices", "partners", "strategy",
    )
    _RERANK_MODEL = "rerank-english-v3.0"
    _TOP_N = 5  # rerank top_n and final result cap
    _MIN_SCORE = 0.10  # scored nodes below this are discarded
    _MAX_CONTEXT_TOKENS = 100_000  # cumulative token budget of returned nodes

    def __init__(
        self,
        vector_retriever: VectorIndexRetriever,
        document_dict: dict,
        keyword_retriever=None,
        mode: str = "AND",
    ) -> None:
        """Configure the retriever.

        Args:
            vector_retriever: Retriever used for semantic search.
            document_dict: Maps source document ids to document objects whose
                ``text`` replaces a node's text when the node's metadata has a
                truthy ``retrieve_doc`` entry.
            keyword_retriever: Optional keyword retriever for hybrid search.
            mode: ``"AND"`` to intersect vector and keyword hits, ``"OR"`` to
                union them.

        Raises:
            ValueError: If ``mode`` is neither ``"AND"`` nor ``"OR"``.
        """
        self._vector_retriever = vector_retriever
        self._document_dict = document_dict
        self._keyword_retriever = keyword_retriever
        if mode not in ("AND", "OR"):
            raise ValueError("Invalid mode.")
        self._mode = mode
        super().__init__()

    async def _process_retrieval(
        self, query_bundle: QueryBundle, is_async: bool = True
    ) -> List[NodeWithScore]:
        """Run the retrieval pipeline.

        Args:
            query_bundle: Query to run; its ``query_str`` is rewritten in place
                with the cleaned, typo-corrected text.
            is_async: When false, the blocking pipeline is run instead (kept
                for backward compatibility with existing callers).

        Returns:
            At most 5 nodes that passed the score and token filters.
        """
        if not is_async:
            return self._process_retrieval_sync(query_bundle)

        self._prepare_query(query_bundle)
        start = time.perf_counter()

        vector_nodes = await self._vector_retriever.aretrieve(query_bundle)
        keyword_nodes = []
        if self._keyword_retriever:
            keyword_nodes = await self._keyword_retriever.aretrieve(query_bundle)

        nodes = self._select_candidates(vector_nodes, keyword_nodes)

        # Reranking is best-effort: on failure fall back to the merged hits.
        try:
            reranker = AsyncCohereRerank(top_n=self._TOP_N, model=self._RERANK_MODEL)
            nodes = await reranker.postprocess_nodes(nodes, query_bundle)
        except Exception as e:
            self._report_rerank_failure(e)

        return self._finalize(nodes, start)

    def _process_retrieval_sync(
        self, query_bundle: QueryBundle
    ) -> List[NodeWithScore]:
        """Blocking twin of ``_process_retrieval``; needs no event loop."""
        self._prepare_query(query_bundle)
        start = time.perf_counter()

        vector_nodes = self._vector_retriever.retrieve(query_bundle)
        keyword_nodes = []
        if self._keyword_retriever:
            keyword_nodes = self._keyword_retriever.retrieve(query_bundle)

        nodes = self._select_candidates(vector_nodes, keyword_nodes)

        # Reranking is best-effort: on failure fall back to the merged hits.
        try:
            reranker = CohereRerank(top_n=self._TOP_N, model=self._RERANK_MODEL)
            nodes = reranker.postprocess_nodes(nodes, query_bundle)
        except Exception as e:
            self._report_rerank_failure(e)

        return self._finalize(nodes, start)

    def _prepare_query(self, query_bundle: QueryBundle) -> None:
        """Strip the input marker and typo-correct ``query_str`` in place."""
        cleaned = query_bundle.query_str.replace("\ninput is ", "").rstrip()
        query_bundle.query_str = normalize_and_correct_query(
            cleaned, list(self._KNOWN_KEYWORDS)
        )

    def _select_candidates(
        self,
        vector_nodes: List[NodeWithScore],
        keyword_nodes: List[NodeWithScore],
    ) -> List[NodeWithScore]:
        """Merge both hit lists, dedupe per document, attach document text."""
        nodes = self._merge_results(vector_nodes, keyword_nodes)
        nodes = self._filter_nodes_by_unique_doc_id(nodes)
        self._attach_source_documents(nodes)
        return nodes

    def _merge_results(
        self,
        vector_nodes: List[NodeWithScore],
        keyword_nodes: List[NodeWithScore],
    ) -> List[NodeWithScore]:
        """Combine vector and keyword hits according to the configured mode.

        The result keeps retrieval order (vector hits first, then keyword-only
        hits) so later steps are deterministic. When both retrievers return
        the same node, the vector hit is kept.
        """
        merged = {}
        for hit in [*vector_nodes, *keyword_nodes]:
            merged.setdefault(hit.node.node_id, hit)

        vector_ids = {hit.node.node_id for hit in vector_nodes}
        if not keyword_nodes:
            # No keyword retriever, or it found nothing: vector-only results.
            wanted = vector_ids
        else:
            keyword_ids = {hit.node.node_id for hit in keyword_nodes}
            wanted = (
                vector_ids & keyword_ids
                if self._mode == "AND"
                else vector_ids | keyword_ids
            )

        return [hit for node_id, hit in merged.items() if node_id in wanted]

    def _attach_source_documents(self, nodes: List[NodeWithScore]) -> None:
        """Re-key each node by its source document and optionally expand it.

        Nodes whose metadata has a truthy ``retrieve_doc`` entry get their text
        replaced by the full document text when that document is present in
        ``document_dict``. Every node's ``node_id`` is then set to the source
        document id. Nodes are modified in place.
        """
        for node in nodes:
            doc_id = self._source_doc_id(node)
            if doc_id is None:
                continue
            if node.metadata.get("retrieve_doc", False):
                doc = self._document_dict.get(doc_id)
                if doc:
                    node.node.text = doc.text
            node.node.node_id = doc_id

    @staticmethod
    def _source_doc_id(node: NodeWithScore) -> Optional[str]:
        """Return the id of the node's source document, or None if it has none."""
        source = node.node.source_node
        return source.node_id if source is not None else None

    @staticmethod
    def _report_rerank_failure(e: Exception) -> None:
        """Print a rerank error and its traceback (call from an except block)."""
        print(f"Error during reranking: {type(e).__name__}: {str(e)}")
        traceback.print_exc()

    def _finalize(
        self, nodes: List[NodeWithScore], start: float
    ) -> List[NodeWithScore]:
        """Apply score/token filters, print the elapsed time, cap the result."""
        nodes_filtered = self._filter_by_score_and_tokens(nodes)

        duration = time.perf_counter() - start
        print(f"Retrieving nodes took {duration:.2f}s")

        return nodes_filtered[: self._TOP_N]

    def _filter_nodes_by_unique_doc_id(
        self, nodes: List[NodeWithScore]
    ) -> List[NodeWithScore]:
        """Keep the first node seen for each source document.

        Nodes without a source document id are dropped.
        """
        unique_nodes = {}
        for node in nodes:
            doc_id = self._source_doc_id(node)
            if doc_id is not None and doc_id not in unique_nodes:
                unique_nodes[doc_id] = node
        return list(unique_nodes.values())

    def _filter_by_score_and_tokens(
        self, nodes: List[NodeWithScore]
    ) -> List[NodeWithScore]:
        """Drop low-scoring nodes and stop once the token budget is reached.

        Nodes are taken in order. A node with a score below 0.10 is skipped;
        a node without a score (possible when reranking failed) is kept, since
        its relevance is unknown. Processing stops at the first node that
        would push the cumulative gpt-4 token count past 100,000.
        """
        nodes_filtered = []
        total_tokens = 0
        enc = tiktoken.encoding_for_model("gpt-4")

        for node in nodes:
            if node.score is not None and node.score < self._MIN_SCORE:
                continue

            node_tokens = len(enc.encode(node.node.text))
            if total_tokens + node_tokens > self._MAX_CONTEXT_TOKENS:
                break

            total_tokens += node_tokens
            nodes_filtered.append(node)

        return nodes_filtered

    async def _aretrieve(self, query_bundle: QueryBundle, **kwargs) -> List[NodeWithScore]:
        """Async retrieval entry point."""
        return await self._process_retrieval(query_bundle, is_async=True)

    def _retrieve(self, query_bundle: QueryBundle, **kwargs) -> List[NodeWithScore]:
        """Sync retrieval entry point.

        Runs the blocking pipeline directly rather than through
        ``asyncio.run``, which raises ``RuntimeError`` when an event loop is
        already running in the calling thread.
        """
        return self._process_retrieval_sync(query_bundle)