# LifeAdmin-AI / agent/rag_engine.py
# Commit e64b016 (verified) by Maheen001: "Create agent/rag_engine.py" (4.4 kB)
import uuid
from pathlib import Path
from typing import Any, Dict, List, Optional

import chromadb
from chromadb.config import Settings
from sentence_transformers import SentenceTransformer
class RAGEngine:
    """RAG engine using ChromaDB for vector storage and retrieval.

    Texts are embedded with a SentenceTransformer model and stored, together
    with optional metadata, in one persistent ChromaDB collection that is
    configured for cosine distance.

    NOTE: the public methods are declared ``async`` but call the embedding
    model and ChromaDB synchronously (nothing is awaited), so each call runs
    to completion before control returns to the event loop.
    """

    def __init__(
        self,
        persist_directory: str = "data/chroma_db",
        collection_name: str = "documents",
        embedding_model_name: str = "all-MiniLM-L6-v2",
    ):
        """Initialize RAG engine with ChromaDB.

        Args:
            persist_directory: Directory for the ChromaDB files; created
                (including parents) if it does not exist.
            collection_name: Name of the collection to open or create.
            embedding_model_name: SentenceTransformer model used for both
                documents and queries.
        """
        Path(persist_directory).mkdir(parents=True, exist_ok=True)

        # Remembered so clear_all() recreates the same collection.
        self._collection_name = collection_name

        self.client = chromadb.PersistentClient(path=persist_directory)
        self.collection = self._open_collection()
        self.embedding_model = SentenceTransformer(embedding_model_name)

    def _open_collection(self):
        """Return the engine's collection, creating it if it is missing."""
        return self.client.get_or_create_collection(
            name=self._collection_name,
            metadata={"hnsw:space": "cosine"},
        )

    async def add_document(
        self,
        text: str,
        metadata: Optional[Dict[str, Any]] = None,
    ) -> str:
        """Add a single document to the RAG index.

        Args:
            text: Document text.
            metadata: Document metadata; ``None`` or an empty dict stores the
                document without metadata.

        Returns:
            The generated document ID (a UUID4 string).
        """
        doc_id = str(uuid.uuid4())
        embedding = self.embedding_model.encode(text).tolist()

        self.collection.add(
            ids=[doc_id],
            embeddings=[embedding],
            documents=[text],
            # An empty dict is sent as "no metadata": ChromaDB's metadata
            # validation rejects ``{}`` in a number of releases.
            metadatas=[metadata] if metadata else None,
        )
        return doc_id

    async def add_documents(
        self,
        texts: List[str],
        metadatas: Optional[List[Dict[str, Any]]] = None,
    ) -> List[str]:
        """Add multiple documents at once.

        Args:
            texts: Document texts.
            metadatas: One metadata dict per text, in the same order.
                ``None`` or an empty list stores every document without
                metadata; empty/``None`` entries store that one document
                without metadata.

        Returns:
            The generated document IDs, in the order of ``texts``. An empty
            ``texts`` returns ``[]`` without touching the collection.

        Raises:
            ValueError: If ``metadatas`` is non-empty and its length differs
                from the number of texts.
        """
        texts = list(texts)

        if metadatas and len(metadatas) != len(texts):
            raise ValueError(
                f"Got {len(texts)} texts but {len(metadatas)} metadatas; "
                "lengths must match"
            )

        # Nothing to embed or store; an empty batch is not forwarded.
        if not texts:
            return []

        doc_ids = [str(uuid.uuid4()) for _ in texts]
        embeddings = self.embedding_model.encode(texts).tolist()

        # Replace empty dicts with None, and drop the argument entirely when
        # no entry carries any metadata.
        cleaned_metadatas = None
        if metadatas and any(metadatas):
            cleaned_metadatas = [entry or None for entry in metadatas]

        self.collection.add(
            ids=doc_ids,
            embeddings=embeddings,
            documents=texts,
            metadatas=cleaned_metadatas,
        )
        return doc_ids

    async def search(
        self,
        query: str,
        k: int = 5,
        filter_metadata: Optional[Dict[str, Any]] = None,
    ) -> List[Dict[str, Any]]:
        """Search for relevant documents.

        Args:
            query: Search query.
            k: Maximum number of results.
            filter_metadata: Metadata filter passed to ChromaDB as ``where``;
                ``None`` or an empty dict applies no filter.

        Returns:
            A list of dicts with keys ``id``, ``text``, ``metadata`` (``{}``
            when the document has none) and ``distance`` (the distance value
            reported by ChromaDB; ``0`` if none was returned). Empty when
            ``k <= 0``, the collection is empty, or nothing matches.
        """
        if k <= 0:
            return []

        # Never request more neighbours than the collection holds.
        available = self.collection.count()
        if available == 0:
            return []

        query_embedding = self.embedding_model.encode(query).tolist()

        results = self.collection.query(
            query_embeddings=[query_embedding],
            n_results=min(k, available),
            where=filter_metadata or None,
        )

        # Results are nested one list per query; only one query was sent.
        if not (results['documents'] and results['documents'][0]):
            return []

        texts = results['documents'][0]
        ids = results['ids'][0]
        if results['metadatas']:
            metadatas = results['metadatas'][0]
        else:
            metadatas = [None] * len(texts)
        if results['distances']:
            distances = results['distances'][0]
        else:
            distances = [0] * len(texts)

        return [
            {
                'id': doc_id,
                'text': text,
                'metadata': metadata or {},
                'distance': distance,
            }
            for doc_id, text, metadata, distance in zip(
                ids, texts, metadatas, distances
            )
        ]

    async def get_document(self, doc_id: str) -> Optional[Dict[str, Any]]:
        """Get document by ID.

        Returns:
            A dict with keys ``id``, ``text`` and ``metadata`` (``{}`` when
            the document has none), or ``None`` if no document has that ID.
        """
        result = self.collection.get(ids=[doc_id])
        if not result['documents']:
            return None

        metadata = result['metadatas'][0] if result['metadatas'] else None
        return {
            'id': doc_id,
            'text': result['documents'][0],
            'metadata': metadata or {},
        }

    async def delete_document(self, doc_id: str) -> None:
        """Delete document by ID."""
        self.collection.delete(ids=[doc_id])

    async def count_documents(self) -> int:
        """Get total number of documents."""
        return self.collection.count()

    async def clear_all(self) -> None:
        """Clear all documents by dropping and recreating the collection."""
        self.client.delete_collection(name=self._collection_name)
        self.collection = self._open_collection()