Ahambrahmasmi committed on
Commit
deb30d8
·
verified ·
1 Parent(s): 3e5fcec

Update scripts/custom_retriever.py

Browse files
Files changed (1) hide show
  1. scripts/custom_retriever.py +98 -63
scripts/custom_retriever.py CHANGED
@@ -4,15 +4,21 @@ import time
4
  import traceback
5
  from typing import List, Optional
6
 
 
7
  from cohere import AsyncClient
8
  from dotenv import load_dotenv
9
- from llama_index.core import QueryBundle
 
10
  from llama_index.core.retrievers import (
11
  BaseRetriever,
 
12
  VectorIndexRetriever,
13
  )
14
  from llama_index.core.schema import MetadataMode, NodeWithScore, QueryBundle
15
  from llama_index.core.vector_stores import (
 
 
 
16
  MetadataFilters,
17
  )
18
  from llama_index.postprocessor.cohere_rerank import CohereRerank
@@ -36,10 +42,13 @@ class AsyncCohereRerank(CohereRerank):
36
  nodes: List[NodeWithScore],
37
  query_bundle: Optional[QueryBundle] = None,
38
  ) -> List[NodeWithScore]:
39
- if query_bundle is None or len(nodes) == 0:
 
 
40
  return []
41
 
42
  async_client = AsyncClient(api_key=self._api_key)
 
43
  texts = [
44
  node.node.get_content(metadata_mode=MetadataMode.EMBED)
45
  for node in nodes
@@ -52,13 +61,19 @@ class AsyncCohereRerank(CohereRerank):
52
  documents=texts,
53
  )
54
 
55
- return [
56
- NodeWithScore(node=nodes[result.index].node, score=result.relevance_score)
57
- for result in results.results
58
- ]
 
 
 
 
59
 
60
 
61
  class CustomRetriever(BaseRetriever):
 
 
62
  def __init__(
63
  self,
64
  vector_retriever: VectorIndexRetriever,
@@ -66,95 +81,115 @@ class CustomRetriever(BaseRetriever):
66
  keyword_retriever=None,
67
  mode: str = "AND",
68
  ) -> None:
69
- super().__init__()
70
  self._vector_retriever = vector_retriever
71
  self._document_dict = document_dict
72
  self._keyword_retriever = keyword_retriever
 
 
73
  self._mode = mode
74
-
75
- def retrieve(self, query: str, filters: Optional[MetadataFilters] = None) -> List[NodeWithScore]:
76
- query_bundle = QueryBundle(query_str=query)
77
- if filters:
78
- self._vector_retriever.filters = filters
79
- return self._retrieve(query_bundle)
80
-
81
- async def aretrieve(self, query: str, filters: Optional[MetadataFilters] = None) -> List[NodeWithScore]:
82
- query_bundle = QueryBundle(query_str=query)
83
- if filters:
84
- self._vector_retriever.filters = filters
85
- return await self._aretrieve(query_bundle)
86
-
87
- def _retrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
88
- return asyncio.run(self._process_retrieval(query_bundle, is_async=False))
89
-
90
- async def _aretrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
91
- return await self._process_retrieval(query_bundle, is_async=True)
92
 
93
  async def _process_retrieval(
94
  self, query_bundle: QueryBundle, is_async: bool = True
95
  ) -> List[NodeWithScore]:
 
 
 
 
 
 
96
  start = time.time()
 
97
  if is_async:
98
  nodes = await self._vector_retriever.aretrieve(query_bundle)
99
  else:
100
  nodes = self._vector_retriever.retrieve(query_bundle)
101
 
 
102
  if self._keyword_retriever:
103
- keyword_nodes = (
104
- await self._keyword_retriever.aretrieve(query_bundle)
105
- if is_async else self._keyword_retriever.retrieve(query_bundle)
106
- )
107
- else:
108
- keyword_nodes = []
109
 
 
 
110
  combined_dict = {n.node.node_id: n for n in nodes}
111
  combined_dict.update({n.node.node_id: n for n in keyword_nodes})
112
 
113
- if self._keyword_retriever:
114
- if self._mode == "AND":
115
- ids = set(combined_dict) & {n.node.node_id for n in keyword_nodes}
116
- else:
117
- ids = set(combined_dict)
118
  else:
119
- ids = set(combined_dict)
 
 
 
 
120
 
121
- filtered_nodes = [combined_dict[i] for i in ids]
 
122
 
123
- # Restore full text if `retrieve_doc` is True
124
- for node in filtered_nodes:
125
  doc_id = node.node.source_node.node_id
126
- if node.metadata.get("retrieve_doc"):
127
- doc = self._document_dict.get(doc_id)
128
- if doc:
129
- node.node.text = doc.text
130
  node.node.node_id = doc_id
131
 
132
- # Optional: rerank using Cohere
133
  try:
134
  reranker = (
135
- AsyncCohereRerank(top_n=5)
136
- if is_async else CohereRerank(top_n=5)
 
137
  )
138
- filtered_nodes = (
139
- await reranker.apostprocess_nodes(filtered_nodes, query_bundle)
140
- if is_async else reranker.postprocess_nodes(filtered_nodes, query_bundle)
 
141
  )
142
  except Exception as e:
143
- print(f"Reranking failed: {type(e).__name__}: {e}")
144
  traceback.print_exc()
145
 
146
- filtered = self._limit_results_by_length(filtered_nodes)
147
- print(f"✅ Retrieved in {time.time() - start:.2f}s")
148
- return filtered
 
 
 
149
 
150
- def _limit_results_by_length(self, nodes: List[NodeWithScore]) -> List[NodeWithScore]:
151
- total_chars = 0
152
- limited = []
 
153
  for node in nodes:
154
- length = len(node.node.text)
155
- if total_chars + length > 60_000: # rough char limit to stay safe with Gemini context
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
156
  break
157
- total_chars += length
158
- if node.score >= 0.1:
159
- limited.append(node)
160
- return limited
 
 
 
 
 
 
 
 
4
  import traceback
5
  from typing import List, Optional
6
 
7
+ import tiktoken
8
  from cohere import AsyncClient
9
  from dotenv import load_dotenv
10
+ from llama_index.core import Document, QueryBundle
11
+ from llama_index.core.async_utils import run_async_tasks
12
  from llama_index.core.retrievers import (
13
  BaseRetriever,
14
+ KeywordTableSimpleRetriever,
15
  VectorIndexRetriever,
16
  )
17
  from llama_index.core.schema import MetadataMode, NodeWithScore, QueryBundle
18
  from llama_index.core.vector_stores import (
19
+ FilterCondition,
20
+ FilterOperator,
21
+ MetadataFilter,
22
  MetadataFilters,
23
  )
24
  from llama_index.postprocessor.cohere_rerank import CohereRerank
 
42
  nodes: List[NodeWithScore],
43
  query_bundle: Optional[QueryBundle] = None,
44
  ) -> List[NodeWithScore]:
45
+ if query_bundle is None:
46
+ raise ValueError("Query bundle must be provided.")
47
+ if len(nodes) == 0:
48
  return []
49
 
50
  async_client = AsyncClient(api_key=self._api_key)
51
+
52
  texts = [
53
  node.node.get_content(metadata_mode=MetadataMode.EMBED)
54
  for node in nodes
 
61
  documents=texts,
62
  )
63
 
64
+ new_nodes = []
65
+ for result in results.results:
66
+ new_node_with_score = NodeWithScore(
67
+ node=nodes[result.index].node, score=result.relevance_score
68
+ )
69
+ new_nodes.append(new_node_with_score)
70
+
71
+ return new_nodes
72
 
73
 
74
  class CustomRetriever(BaseRetriever):
75
+ """Custom retriever that performs both semantic search and hybrid search."""
76
+
77
  def __init__(
78
  self,
79
  vector_retriever: VectorIndexRetriever,
 
81
  keyword_retriever=None,
82
  mode: str = "AND",
83
  ) -> None:
 
84
  self._vector_retriever = vector_retriever
85
  self._document_dict = document_dict
86
  self._keyword_retriever = keyword_retriever
87
+ if mode not in ("AND", "OR"):
88
+ raise ValueError("Invalid mode. Use 'AND' or 'OR'")
89
  self._mode = mode
90
+ super().__init__()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
91
 
92
  async def _process_retrieval(
93
  self, query_bundle: QueryBundle, is_async: bool = True
94
  ) -> List[NodeWithScore]:
95
+
96
+ if not isinstance(query_bundle, QueryBundle):
97
+ raise TypeError(f"Expected QueryBundle, got {type(query_bundle)}")
98
+
99
+ query_bundle.query_str = query_bundle.query_str.replace("\ninput is ", "").rstrip()
100
+
101
  start = time.time()
102
+
103
  if is_async:
104
  nodes = await self._vector_retriever.aretrieve(query_bundle)
105
  else:
106
  nodes = self._vector_retriever.retrieve(query_bundle)
107
 
108
+ keyword_nodes = []
109
  if self._keyword_retriever:
110
+ if is_async:
111
+ keyword_nodes = await self._keyword_retriever.aretrieve(query_bundle)
112
+ else:
113
+ keyword_nodes = self._keyword_retriever.retrieve(query_bundle)
 
 
114
 
115
+ vector_ids = {n.node.node_id for n in nodes}
116
+ keyword_ids = {n.node.node_id for n in keyword_nodes}
117
  combined_dict = {n.node.node_id: n for n in nodes}
118
  combined_dict.update({n.node.node_id: n for n in keyword_nodes})
119
 
120
+ if not self._keyword_retriever or not keyword_nodes:
121
+ retrieve_ids = vector_ids
 
 
 
122
  else:
123
+ retrieve_ids = (
124
+ vector_ids.intersection(keyword_ids)
125
+ if self._mode == "AND"
126
+ else vector_ids.union(keyword_ids)
127
+ )
128
 
129
+ nodes = [combined_dict[rid] for rid in retrieve_ids]
130
+ nodes = self._filter_nodes_by_unique_doc_id(nodes)
131
 
132
+ for node in nodes:
 
133
  doc_id = node.node.source_node.node_id
134
+ if node.metadata["retrieve_doc"]:
135
+ doc = self._document_dict[doc_id]
136
+ node.node.text = doc.text
 
137
  node.node.node_id = doc_id
138
 
 
139
  try:
140
  reranker = (
141
+ AsyncCohereRerank(top_n=5, model="rerank-english-v3.0")
142
+ if is_async
143
+ else CohereRerank(top_n=5, model="rerank-english-v3.0")
144
  )
145
+ nodes = (
146
+ await reranker.apostprocess_nodes(nodes, query_bundle)
147
+ if is_async
148
+ else reranker.postprocess_nodes(nodes, query_bundle)
149
  )
150
  except Exception as e:
151
+ print(f"Error during reranking: {type(e).__name__}: {str(e)}")
152
  traceback.print_exc()
153
 
154
+ nodes_filtered = self._filter_by_score_and_tokens(nodes)
155
+
156
+ duration = time.time() - start
157
+ print(f"Retrieving nodes took {duration:.2f}s")
158
+
159
+ return nodes_filtered[:5]
160
 
161
def _filter_nodes_by_unique_doc_id(
    self, nodes: List[NodeWithScore]
) -> List[NodeWithScore]:
    """Keep only the first node encountered for each source document.

    Args:
        nodes: Retrieved nodes, in the order that decides which one wins
            when several share a source document.

    Returns:
        One node per distinct source-document id, in first-seen order.
        Nodes with no source node, or whose source id is None, are dropped.
    """
    unique_nodes = {}
    for node in nodes:
        # Read the source defensively: dereferencing `.node_id` on a missing
        # source raised AttributeError before the None check could run.
        # (LlamaIndex types `source_node` as optional — TODO confirm for the
        # node types used here.)
        source = node.node.source_node
        doc_id = getattr(source, "node_id", None)
        if doc_id is not None and doc_id not in unique_nodes:
            unique_nodes[doc_id] = node
    return list(unique_nodes.values())
170
+
171
def _filter_by_score_and_tokens(
    self,
    nodes: List[NodeWithScore],
    min_score: float = 0.10,
    max_tokens: int = 100_000,
) -> List[NodeWithScore]:
    """Drop low-scoring nodes and stop once a token budget would be exceeded.

    Nodes are visited in the given order. A node scoring below ``min_score``
    is skipped. The first node whose text would push the running token total
    past ``max_tokens`` ends the scan, so later (even shorter) nodes are not
    considered.

    Args:
        nodes: Candidate nodes, in priority order.
        min_score: Minimum score a node needs to be kept.
        max_tokens: Upper bound on the summed token count of kept nodes.

    Returns:
        The kept nodes, in their original order.
    """
    enc = tiktoken.encoding_for_model("gpt-4")  # tokenizer model name is fine for now

    nodes_filtered = []
    total_tokens = 0
    for node in nodes:
        # A node may carry no score (e.g. reranking failed); comparing None
        # with a float raises TypeError, so treat it as 0.0.
        score = node.score if node.score is not None else 0.0
        if score < min_score:
            continue

        # disallowed_special=() encodes strings such as "<|endoftext|>" as
        # plain text; by default tiktoken raises ValueError on them.
        node_tokens = len(enc.encode(node.node.text, disallowed_special=()))
        if total_tokens + node_tokens > max_tokens:
            break

        total_tokens += node_tokens
        nodes_filtered.append(node)

    return nodes_filtered
190
+
191
async def _aretrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
    """Async retrieval hook: run the shared pipeline in async mode.

    Args:
        query_bundle: The query to retrieve nodes for.

    Returns:
        Whatever ``_process_retrieval`` produces with ``is_async=True``.
    """
    retrieved = await self._process_retrieval(query_bundle, is_async=True)
    return retrieved
193
+
194
def _retrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
    """Synchronous retrieval hook: run the shared pipeline to completion.

    ``asyncio.run`` raises RuntimeError when the calling thread already has a
    running event loop (notebooks, async web handlers). In that case the
    coroutine is executed on a fresh loop in a short-lived worker thread.
    The calling thread blocks until it finishes; async callers should
    prefer ``_aretrieve``.

    Args:
        query_bundle: The query to retrieve nodes for.

    Returns:
        Whatever ``_process_retrieval`` produces with ``is_async=False``.
    """
    coro = self._process_retrieval(query_bundle, is_async=False)
    try:
        asyncio.get_running_loop()
    except RuntimeError:
        # No loop is running in this thread: the plain path is safe.
        return asyncio.run(coro)

    # Imported here so the fix does not depend on the file's import block.
    from concurrent.futures import ThreadPoolExecutor

    with ThreadPoolExecutor(max_workers=1) as pool:
        return pool.submit(asyncio.run, coro).result()