# Spaces: Running
import uuid
from pathlib import Path
from typing import Any, Dict, List, Optional

import chromadb
from chromadb.config import Settings
from sentence_transformers import SentenceTransformer
class RAGEngine:
    """RAG engine using ChromaDB for vector storage and retrieval.

    Texts are embedded with a SentenceTransformer model and stored in one
    persistent ChromaDB collection configured for cosine distance.

    NOTE: the public methods are coroutines, but the embedding and ChromaDB
    calls inside them are made synchronously (nothing is awaited), so each
    call runs to completion before control returns to the event loop.
    """

    def __init__(
        self,
        persist_directory: str = "data/chroma_db",
        collection_name: str = "documents",
        model_name: str = "all-MiniLM-L6-v2",
    ):
        """Initialize RAG engine with ChromaDB.

        Args:
            persist_directory: Directory for the on-disk ChromaDB store;
                created (with parents) if it does not exist.
            collection_name: Name of the collection to open or create.
            model_name: SentenceTransformer model used for embeddings.
        """
        Path(persist_directory).mkdir(parents=True, exist_ok=True)

        # Remembered so clear_all() recreates the same collection.
        self.collection_name = collection_name

        self.client = chromadb.PersistentClient(path=persist_directory)
        self.collection = self._open_collection()
        self.embedding_model = SentenceTransformer(model_name)

    def _open_collection(self):
        """Return the collection, creating it with cosine distance if absent."""
        return self.client.get_or_create_collection(
            name=self.collection_name,
            metadata={"hnsw:space": "cosine"},
        )

    async def add_document(
        self,
        text: str,
        metadata: Optional[Dict[str, Any]] = None,
    ) -> str:
        """Add a single document to the RAG index.

        Args:
            text: Document text.
            metadata: Document metadata; omitted from the stored record when
                ``None`` or empty.

        Returns:
            The generated document ID (a UUID4 string).
        """
        doc_id = str(uuid.uuid4())
        embedding = self.embedding_model.encode(text).tolist()

        self.collection.add(
            ids=[doc_id],
            embeddings=[embedding],
            documents=[text],
            # NOTE(review): ChromaDB's metadata validation rejects empty
            # dicts in recent releases, so pass nothing rather than [{}].
            metadatas=[metadata] if metadata else None,
        )
        return doc_id

    async def add_documents(
        self,
        texts: List[str],
        metadatas: Optional[List[Dict[str, Any]]] = None,
    ) -> List[str]:
        """Add multiple documents at once.

        Args:
            texts: Document texts.
            metadatas: Optional metadata, one entry per text.

        Returns:
            The generated document IDs, in the same order as ``texts``.
            An empty ``texts`` returns ``[]`` without touching the index.

        Raises:
            ValueError: If ``metadatas`` is given but its length differs
                from ``texts``.
        """
        if not texts:
            return []
        # Fail before the (comparatively expensive) embedding step.
        if metadatas and len(metadatas) != len(texts):
            raise ValueError(
                f"Got {len(metadatas)} metadata entries for {len(texts)} texts"
            )

        doc_ids = [str(uuid.uuid4()) for _ in texts]
        embeddings = self.embedding_model.encode(texts).tolist()

        self.collection.add(
            ids=doc_ids,
            embeddings=embeddings,
            documents=texts,
            # See add_document: omit the argument instead of sending
            # a list of empty dicts.
            metadatas=metadatas if metadatas else None,
        )
        return doc_ids

    async def search(
        self,
        query: str,
        k: int = 5,
        filter_metadata: Optional[Dict[str, Any]] = None,
    ) -> List[Dict[str, Any]]:
        """Search for relevant documents.

        Args:
            query: Search query.
            k: Maximum number of results.
            filter_metadata: Metadata filter passed to ChromaDB as ``where``;
                ``None`` or an empty dict means no filtering.

        Returns:
            A list of dicts with keys ``id``, ``text``, ``metadata`` and
            ``distance``, in the order ChromaDB returns them. Empty when the
            collection holds no documents or nothing matches.
        """
        available = self.collection.count()
        if available == 0:
            return []

        query_embedding = self.embedding_model.encode(query).tolist()

        results = self.collection.query(
            query_embeddings=[query_embedding],
            # NOTE(review): some ChromaDB versions raise when n_results
            # exceeds the number of stored items, so cap it.
            n_results=min(k, available),
            where=filter_metadata or None,
        )

        found = results['documents']
        if not found or not found[0]:
            return []

        ids = results['ids'][0]
        metas = results['metadatas'][0] if results['metadatas'] else None
        dists = results['distances'][0] if results['distances'] else None

        return [
            {
                'id': ids[i],
                'text': doc,
                # Documents stored without metadata come back as None.
                'metadata': (metas[i] or {}) if metas is not None else {},
                'distance': dists[i] if dists is not None else 0,
            }
            for i, doc in enumerate(found[0])
        ]

    async def get_document(self, doc_id: str) -> Optional[Dict[str, Any]]:
        """Get document by ID.

        Returns:
            A dict with keys ``id``, ``text`` and ``metadata``, or ``None``
            when no document with that ID is returned by the collection.
        """
        result = self.collection.get(ids=[doc_id])
        if not result['documents']:
            return None

        metadatas = result['metadatas']
        return {
            'id': doc_id,
            'text': result['documents'][0],
            'metadata': (metadatas[0] or {}) if metadatas else {},
        }

    async def delete_document(self, doc_id: str) -> None:
        """Delete document by ID."""
        self.collection.delete(ids=[doc_id])

    async def count_documents(self) -> int:
        """Get total number of documents."""
        return self.collection.count()

    async def clear_all(self) -> None:
        """Clear all documents by dropping and recreating the collection."""
        self.client.delete_collection(name=self.collection_name)
        self.collection = self._open_collection()