KUNAL SHAW committed on
Commit
add1aec
·
1 Parent(s): 75729d2

Replace towhee milvus-client with direct MilvusClient search for Zilliz Serverless

Browse files
Files changed (1) hide show
  1. app.py +36 -21
app.py CHANGED
@@ -317,29 +317,46 @@ utility.loading_progress(COLLECTION_NAME)
317
 
318
  max_input_length = 500 # Maximum length allowed by the model
319
 
320
- # Configure Towhee Milvus Client arguments based on connection type
321
- milvus_args = {
322
- "collection_name": COLLECTION_NAME,
323
- "limit": 1
324
- }
325
  if milvus_uri and milvus_token:
326
- milvus_args["uri"] = milvus_uri
327
- milvus_args["token"] = milvus_token
328
  else:
329
- milvus_args["host"] = host_milvus
330
- milvus_args["port"] = '19530'
331
 
332
- # Create the combined pipe for question encoding and answer retrieval
333
- combined_pipe = (
334
  pipe.input('question')
335
- .map('question', 'vec', lambda x: x[:max_input_length]) # Truncate the question if longer than 512 tokens
336
- .map('vec', 'vec', ops.text_embedding.dpr(model_name='facebook/dpr-ctx_encoder-single-nq-base'))
337
- .map('vec', 'vec', lambda x: x / np.linalg.norm(x, axis=0))
338
- .map('vec', 'res', ops.ann_search.milvus_client(**milvus_args))
339
- .map('res', 'answer', lambda x: [id_answer[int(i[0])] for i in x])
340
- .output('question', 'answer')
341
  )
342
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
343
  # Step 3 - Custom LLM
344
  from openai import OpenAI
345
  def generate_stream(prompt, model="mixtral-8x7b"):
@@ -382,10 +399,8 @@ class CustomRetrieverLang(BaseRetriever):
382
  self, query: str, *, run_manager: CallbackManagerForRetrieverRun
383
  ) -> List[Document]:
384
  # Perform the encoding and retrieval for a specific question
385
- ans = combined_pipe(query)
386
- ans = DataCollection(ans)
387
- answer=ans[0]['answer']
388
- answer_string = ' '.join(answer)
389
  return [Document(page_content=answer_string)]
390
  # Ensure correct VectorStoreRetriever usage
391
  retriever = CustomRetrieverLang()
 
317
 
318
  max_input_length = 500 # Maximum length allowed by the model
319
 
320
+ # Initialize MilvusClient for search (compatible with Zilliz Serverless)
321
+ from pymilvus import MilvusClient as SearchClient
 
 
 
322
  if milvus_uri and milvus_token:
323
+ search_client = SearchClient(uri=milvus_uri, token=milvus_token)
 
324
  else:
325
+ search_client = SearchClient(uri=f"http://{host_milvus}:19530")
 
326
 
327
+ # Initialize embedding pipeline (without Milvus search - we'll do that separately)
328
+ embedding_pipe = (
329
  pipe.input('question')
330
+ .map('question', 'vec', lambda x: x[:max_input_length])
331
+ .map('vec', 'vec', ops.text_embedding.dpr(model_name='facebook/dpr-ctx_encoder-single-nq-base'))
332
+ .map('vec', 'vec', lambda x: x / np.linalg.norm(x, axis=0))
333
+ .output('vec')
 
 
334
  )
335
 
336
+ def search_similar_questions(question: str) -> list:
337
+ """Search for similar questions using MilvusClient directly (Zilliz Serverless compatible)."""
338
+ # Get embedding for the question
339
+ result = embedding_pipe(question)
340
+ query_vector = result.get()[0].tolist()
341
+
342
+ # Search using MilvusClient
343
+ search_results = search_client.search(
344
+ collection_name=COLLECTION_NAME,
345
+ data=[query_vector],
346
+ limit=1,
347
+ output_fields=["id"]
348
+ )
349
+
350
+ # Extract answers from results
351
+ answers = []
352
+ for hits in search_results:
353
+ for hit in hits:
354
+ doc_id = hit['id']
355
+ if doc_id in id_answer:
356
+ answers.append(id_answer[doc_id])
357
+
358
+ return answers
359
+
360
  # Step 3 - Custom LLM
361
  from openai import OpenAI
362
  def generate_stream(prompt, model="mixtral-8x7b"):
 
399
  self, query: str, *, run_manager: CallbackManagerForRetrieverRun
400
  ) -> List[Document]:
401
  # Perform the encoding and retrieval for a specific question
402
+ answers = search_similar_questions(query)
403
+ answer_string = ' '.join(answers) if answers else "No relevant information found."
 
 
404
  return [Document(page_content=answer_string)]
405
  # Ensure correct VectorStoreRetriever usage
406
  retriever = CustomRetrieverLang()