KUNAL SHAW commited on
Commit
dc3a91a
·
1 Parent(s): e0075a3

Fix Milvus collection initialization and Towhee connection

Browse files
Files changed (1) hide show
  1. app.py +56 -2
app.py CHANGED
@@ -243,20 +243,74 @@ else:
243
  print(f"Connecting to Milvus Host: {host_milvus}")
244
  connections.connect(host=host_milvus, port='19530')
245
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
246
 
247
  collection = Collection(COLLECTION_NAME)
248
- collection.load(replica_number=1)
249
  utility.load_state(COLLECTION_NAME)
250
  utility.loading_progress(COLLECTION_NAME)
251
 
252
  max_input_length = 500 # Maximum length allowed by the model
 
 
 
 
 
 
 
 
 
 
 
 
 
253
  # Create the combined pipe for question encoding and answer retrieval
254
  combined_pipe = (
255
  pipe.input('question')
256
  .map('question', 'vec', lambda x: x[:max_input_length]) # Truncate the question if longer than 512 tokens
257
  .map('vec', 'vec', ops.text_embedding.dpr(model_name='facebook/dpr-ctx_encoder-single-nq-base'))
258
  .map('vec', 'vec', lambda x: x / np.linalg.norm(x, axis=0))
259
- .map('vec', 'res', ops.ann_search.milvus_client(host=host_milvus, port='19530', collection_name=COLLECTION_NAME, limit=1))
260
  .map('res', 'answer', lambda x: [id_answer[int(i[0])] for i in x])
261
  .output('question', 'answer')
262
  )
 
243
  print(f"Connecting to Milvus Host: {host_milvus}")
244
  connections.connect(host=host_milvus, port='19530')
245
 
246
+ # Check if collection exists, if not create and populate it
247
+ if not utility.has_collection(COLLECTION_NAME):
248
+ print(f"Collection {COLLECTION_NAME} not found. Creating and populating...")
249
+
250
+ # 1. Define Schema
251
+ fields = [
252
+ FieldSchema(name="id", dtype=DataType.INT64, is_primary=True, auto_id=False),
253
+ FieldSchema(name="embedding", dtype=DataType.FLOAT_VECTOR, dim=768) # DPR uses 768 dims
254
+ ]
255
+ schema = CollectionSchema(fields, "Medical Chatbot QA")
256
+ collection = Collection(COLLECTION_NAME, schema)
257
+
258
+ # 2. Generate Embeddings
259
+ print("Generating embeddings for initial data...")
260
+ embedding_pipe = (
261
+ pipe.input('question')
262
+ .map('question', 'vec', lambda x: x[:500])
263
+ .map('vec', 'vec', ops.text_embedding.dpr(model_name='facebook/dpr-ctx_encoder-single-nq-base'))
264
+ .map('vec', 'vec', lambda x: x / np.linalg.norm(x, axis=0))
265
+ .output('vec')
266
+ )
267
+
268
+ vectors = []
269
+ # Process in batches to be safe
270
+ for q in df['question']:
271
+ res = embedding_pipe(q)
272
+ vectors.append(res.get()[0])
273
+
274
+ # 3. Insert Data
275
+ print("Inserting data into Zilliz...")
276
+ collection.insert([df['id'].tolist(), vectors])
277
+
278
+ # 4. Create Index
279
+ print("Creating index...")
280
+ index_params = {
281
+ "metric_type": "IP",
282
+ "index_type": "AUTOINDEX",
283
+ "params": {}
284
+ }
285
+ collection.create_index(field_name="embedding", index_params=index_params)
286
+ print("Collection setup complete.")
287
 
288
  collection = Collection(COLLECTION_NAME)
289
+ collection.load()
290
  utility.load_state(COLLECTION_NAME)
291
  utility.loading_progress(COLLECTION_NAME)
292
 
293
  max_input_length = 500 # Maximum length allowed by the model
294
+
295
+ # Configure Towhee Milvus Client arguments based on connection type
296
+ milvus_args = {
297
+ "collection_name": COLLECTION_NAME,
298
+ "limit": 1
299
+ }
300
+ if milvus_uri and milvus_token:
301
+ milvus_args["uri"] = milvus_uri
302
+ milvus_args["token"] = milvus_token
303
+ else:
304
+ milvus_args["host"] = host_milvus
305
+ milvus_args["port"] = '19530'
306
+
307
  # Create the combined pipe for question encoding and answer retrieval
308
  combined_pipe = (
309
  pipe.input('question')
310
  .map('question', 'vec', lambda x: x[:max_input_length]) # Truncate the question if longer than 512 tokens
311
  .map('vec', 'vec', ops.text_embedding.dpr(model_name='facebook/dpr-ctx_encoder-single-nq-base'))
312
  .map('vec', 'vec', lambda x: x / np.linalg.norm(x, axis=0))
313
+ .map('vec', 'res', ops.ann_search.milvus_client(**milvus_args))
314
  .map('res', 'answer', lambda x: [id_answer[int(i[0])] for i in x])
315
  .output('question', 'answer')
316
  )