Spaces:
Sleeping
Sleeping
| import os | |
| from utils.logger import logger | |
| from config.settings import settings | |
| from uuid import uuid4 | |
| from typing import List, Tuple, Dict, Any | |
| from qdrant_client import QdrantClient | |
| from qdrant_client.http.models import Distance, VectorParams, PointStruct, UpdateStatus | |
class VectorDBManager:
    """Wrapper around a single Qdrant collection used to store and query embeddings.

    On construction the collection is created (cosine distance, vectors of size
    ``embedding_dim``) if it does not already exist. Query/count failures are
    logged and turned into empty results; upsert failures are logged only.
    """

    def __init__(self, collection_name: str, embedding_dim: int, client: QdrantClient = None):
        """Bind to ``collection_name``, creating the collection if necessary.

        Args:
            collection_name: Name of the Qdrant collection to use.
            embedding_dim: Vector size; only used when the collection has to be
                created.
            client: Optional shared ``QdrantClient``. When omitted, a new
                path-based (local, on-disk) client is opened at
                ``<settings.DATA_DIR>/qdrant_data``.

        Raises:
            Exception: Errors from listing/creating the collection are logged
                and re-raised by ``create_collection_if_not_exists``.
        """
        logger.info(f"Initializing Qdrant VectorDBManager for collection: '{collection_name}'")
        if client is not None:
            self.client = client
            logger.info("Using shared Qdrant client instance.")
        else:
            # NOTE(review): a path-based Qdrant store is generally not meant to be opened by
            # several clients at once, which is presumably why a shared client is preferred
            # (hence the warning) — confirm for the qdrant-client version in use.
            logger.warning("No shared Qdrant client provided. Creating a new local instance.")
            qdrant_db_path = os.path.join(settings.DATA_DIR, "qdrant_data")
            self.client = QdrantClient(path=qdrant_db_path)
        self.collection_name = collection_name
        self.embedding_dim = embedding_dim
        self.create_collection_if_not_exists()

    def create_collection_if_not_exists(self):
        """Create ``self.collection_name`` (cosine distance) if it is missing.

        Uses ``create_collection`` instead of the deprecated
        ``recreate_collection``: the latter deletes any existing collection of
        the same name, so if one appeared between the existence check and the
        create call its data would be silently wiped. ``create_collection``
        fails loudly in that case instead.

        Raises:
            Exception: Any error from the Qdrant client is logged and re-raised.
        """
        try:
            existing_names = {collection.name for collection in self.client.get_collections().collections}
            if self.collection_name in existing_names:
                logger.info(f"Collection '{self.collection_name}' already exists.")
                return
            logger.info(f"Collection '{self.collection_name}' not found. Creating a new one...")
            self.client.create_collection(
                collection_name=self.collection_name,
                vectors_config=VectorParams(
                    size=self.embedding_dim,
                    distance=Distance.COSINE,
                ),
            )
            logger.success(f"Collection '{self.collection_name}' created successfully.")
        except Exception as e:
            logger.error(f"Error checking or creating collection '{self.collection_name}': {e}")
            raise

    def add_vectors(self, embeddings: List[List[float]], metadatas: List[Dict[str, Any]]):
        """Upsert embeddings, each with its metadata as the point payload.

        Every point gets a fresh random UUID as its id, so adding the same data
        twice stores duplicates rather than overwriting earlier points.

        Args:
            embeddings: Vectors to store; expected to match the collection's
                configured size (``embedding_dim``).
            metadatas: One payload dict per embedding, in the same order.

        Raises:
            ValueError: If ``embeddings`` and ``metadatas`` differ in length.

        Errors raised by the upsert call itself are logged and NOT propagated,
        so the caller receives no signal that nothing was written.
        """
        if not embeddings:
            logger.warning("No embeddings to add. Skipping.")
            return
        if len(embeddings) != len(metadatas):
            logger.error("Number of embeddings and metadatas must match.")
            raise ValueError("Embeddings and metadatas count mismatch.")
        points_to_add = [
            PointStruct(id=str(uuid4()), vector=embedding, payload=metadata)
            for embedding, metadata in zip(embeddings, metadatas)
        ]
        try:
            # wait=True: block until the upsert has been applied before checking its status.
            operation_info = self.client.upsert(
                collection_name=self.collection_name,
                wait=True,
                points=points_to_add,
            )
        except Exception as e:
            logger.error(f"Error upserting points to collection '{self.collection_name}': {e}")
        else:
            if operation_info.status == UpdateStatus.COMPLETED:
                logger.debug(f"Successfully upserted {len(points_to_add)} points to collection '{self.collection_name}'.")
            else:
                logger.warning(f"Upsert operation finished with status: {operation_info.status}")

    def search_vectors(self, query_embedding: List[float], k: int = 5, filter_payload: Dict = None) -> List[Tuple[float, Dict[str, Any]]]:
        """Return up to ``k`` nearest neighbours of ``query_embedding``.

        Args:
            query_embedding: Query vector.
            k: Maximum number of hits to return.
            filter_payload: Forwarded unchanged as Qdrant's ``query_filter``.

        Returns:
            ``(score, payload)`` tuples in the order Qdrant returns them, or an
            empty list if the search raised (the error is logged).
        """
        # NOTE(review): qdrant-client types ``query_filter`` as ``models.Filter``; a plain
        # dict may not be accepted by every client mode — confirm what callers pass here.
        # NOTE(review): newer qdrant-client releases deprecate ``search`` in favour of
        # ``query_points``; revisit when upgrading the client.
        try:
            search_results = self.client.search(
                collection_name=self.collection_name,
                query_vector=query_embedding,
                query_filter=filter_payload,
                limit=k,
                with_payload=True,  # include payload in return
                with_vectors=False,  # exclude vectors in return
            )
            formatted_results = [(hit.score, hit.payload) for hit in search_results]
            logger.debug(f"Searched for top {k} neighbors. Found {len(formatted_results)} results.")
            return formatted_results
        except Exception as e:
            logger.error(f"Error searching in collection '{self.collection_name}': {e}")
            return []

    def get_total_vectors(self) -> int:
        """Return the exact number of points in the collection.

        Returns 0 (after logging) if the count request raises, so callers
        cannot distinguish an empty collection from a failed request.
        """
        try:
            count_result = self.client.count(
                collection_name=self.collection_name,
                exact=True,
            )
            return count_result.count
        except Exception as e:
            logger.error(f"Error counting vectors in collection '{self.collection_name}': {e}")
            return 0