import uuid
from pathlib import Path
from typing import List, Dict, Any, Optional

import chromadb
from chromadb.config import Settings
from sentence_transformers import SentenceTransformer


class RAGEngine:
    """RAG engine using ChromaDB for vector storage and retrieval.

    Documents are embedded with a SentenceTransformer model and stored in a
    persistent ChromaDB collection created with ``hnsw:space = cosine``.

    NOTE(review): the public methods are declared ``async`` but call the
    embedding model and ChromaDB synchronously, so each call occupies the
    event loop until it finishes. Confirm this is acceptable for the
    callers, or offload to an executor at the call site.
    """

    DEFAULT_COLLECTION_NAME = "documents"
    DEFAULT_EMBEDDING_MODEL = "all-MiniLM-L6-v2"
    # Settings applied whenever the collection is (re)created.
    _COLLECTION_METADATA = {"hnsw:space": "cosine"}

    def __init__(
        self,
        persist_directory: str = "data/chroma_db",
        collection_name: str = DEFAULT_COLLECTION_NAME,
        embedding_model_name: str = DEFAULT_EMBEDDING_MODEL,
    ):
        """Initialize RAG engine with ChromaDB.

        Args:
            persist_directory: Directory for the ChromaDB files; created
                (including parents) if it does not exist.
            collection_name: Name of the collection to open or create.
            embedding_model_name: SentenceTransformer model identifier.
        """
        Path(persist_directory).mkdir(parents=True, exist_ok=True)

        self._collection_name = collection_name

        # Initialize ChromaDB client
        self.client = chromadb.PersistentClient(path=persist_directory)

        # Get or create collection
        self.collection = self._open_collection()

        # Initialize embedding model
        self.embedding_model = SentenceTransformer(embedding_model_name)

    def _open_collection(self):
        """Get or create the engine's collection with its standard settings."""
        return self.client.get_or_create_collection(
            name=self._collection_name,
            # Copy so the class-level dict can never be mutated downstream.
            metadata=dict(self._COLLECTION_METADATA),
        )

    @staticmethod
    def _normalize_metadata(
        metadata: Optional[Dict[str, Any]]
    ) -> Optional[Dict[str, Any]]:
        """Map missing or empty metadata to ``None``.

        ChromaDB releases that validate metadata reject empty dicts, while
        ``None`` is accepted, so "no metadata" is always passed as ``None``.
        """
        return metadata if metadata else None

    async def add_document(
        self,
        text: str,
        metadata: Optional[Dict[str, Any]] = None
    ) -> str:
        """
        Add document to RAG index

        Args:
            text: Document text
            metadata: Document metadata (optional; empty is treated as none)

        Returns:
            Document ID (a freshly generated UUID4 string)
        """
        doc_id = str(uuid.uuid4())

        # Generate embedding
        embedding = self.embedding_model.encode(text).tolist()

        clean_metadata = self._normalize_metadata(metadata)

        # Add to collection
        self.collection.add(
            ids=[doc_id],
            embeddings=[embedding],
            documents=[text],
            metadatas=[clean_metadata] if clean_metadata is not None else None
        )

        return doc_id

    async def add_documents(
        self,
        texts: List[str],
        metadatas: Optional[List[Dict[str, Any]]] = None
    ) -> List[str]:
        """
        Add multiple documents at once

        Args:
            texts: Document texts
            metadatas: Optional per-document metadata, same length as texts

        Returns:
            List of generated document IDs, in the order of ``texts``
            (empty when ``texts`` is empty)

        Raises:
            ValueError: If ``metadatas`` is given with a different length
                than ``texts``
        """
        texts = list(texts)
        if not texts:
            # Nothing to embed or store; avoid sending an empty batch.
            return []

        if metadatas is not None and len(metadatas) != len(texts):
            raise ValueError(
                f"metadatas has {len(metadatas)} entries but texts has "
                f"{len(texts)}"
            )

        doc_ids = [str(uuid.uuid4()) for _ in texts]

        # Generate embeddings
        embeddings = self.embedding_model.encode(texts).tolist()

        clean_metadatas = None
        if metadatas is not None:
            normalized = [self._normalize_metadata(m) for m in metadatas]
            # Only send the list if at least one document has metadata.
            if any(m is not None for m in normalized):
                clean_metadatas = normalized

        # Add to collection
        self.collection.add(
            ids=doc_ids,
            embeddings=embeddings,
            documents=texts,
            metadatas=clean_metadatas
        )

        return doc_ids

    async def search(
        self,
        query: str,
        k: int = 5,
        filter_metadata: Optional[Dict[str, Any]] = None
    ) -> List[Dict[str, Any]]:
        """
        Search for relevant documents

        Args:
            query: Search query
            k: Number of results
            filter_metadata: Metadata filters, passed as ChromaDB ``where``
                (``None`` or empty means no filtering)

        Returns:
            List of dicts with keys 'id', 'text', 'metadata' and 'distance'
            (distance as reported by ChromaDB, 0 when none is returned)
        """
        # Generate query embedding
        query_embedding = self.embedding_model.encode(query).tolist()

        # Search; an empty filter dict is sent as None because some ChromaDB
        # releases reject an empty ``where`` clause.
        results = self.collection.query(
            query_embeddings=[query_embedding],
            n_results=k,
            where=filter_metadata or None
        )

        # Results are nested one list per query; only one query was sent.
        texts = results['documents'][0] if results['documents'] else []
        if not texts:
            return []

        ids = results['ids'][0]
        metadatas = results['metadatas'][0] if results['metadatas'] else None
        distances = results['distances'][0] if results['distances'] else None

        # Format results
        return [
            {
                'id': ids[i],
                'text': text,
                # Stored metadata may be None; always hand back a dict.
                'metadata': (metadatas[i] if metadatas else None) or {},
                'distance': distances[i] if distances else 0
            }
            for i, text in enumerate(texts)
        ]

    async def get_document(self, doc_id: str) -> Optional[Dict[str, Any]]:
        """
        Get document by ID

        Returns:
            Dict with keys 'id', 'text' and 'metadata', or None when no
            document is returned for ``doc_id``
        """
        result = self.collection.get(ids=[doc_id])

        if not result['documents']:
            return None

        metadata = result['metadatas'][0] if result['metadatas'] else None
        return {
            'id': doc_id,
            'text': result['documents'][0],
            'metadata': metadata or {}
        }

    async def delete_document(self, doc_id: str) -> None:
        """Delete document by ID"""
        self.collection.delete(ids=[doc_id])

    async def count_documents(self) -> int:
        """Get total number of documents"""
        return self.collection.count()

    async def clear_all(self) -> None:
        """Clear all documents by dropping and recreating the collection"""
        self.client.delete_collection(name=self._collection_name)
        self.collection = self._open_collection()