Maheen001 commited on
Commit
e64b016
·
verified ·
1 Parent(s): 0b27752

Create agent/rag_engine.py

Browse files
Files changed (1) hide show
  1. agent/rag_engine.py +146 -0
agent/rag_engine.py ADDED
@@ -0,0 +1,146 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import chromadb
2
+ from chromadb.config import Settings
3
+ from sentence_transformers import SentenceTransformer
4
+ from typing import List, Dict, Any
5
+ import uuid
6
+ from pathlib import Path
7
+
8
+
9
+ class RAGEngine:
10
+ """RAG engine using ChromaDB for vector storage and retrieval"""
11
+
12
+ def __init__(self, persist_directory: str = "data/chroma_db"):
13
+ """Initialize RAG engine with ChromaDB"""
14
+ Path(persist_directory).mkdir(parents=True, exist_ok=True)
15
+
16
+ # Initialize ChromaDB client
17
+ self.client = chromadb.PersistentClient(path=persist_directory)
18
+
19
+ # Get or create collection
20
+ self.collection = self.client.get_or_create_collection(
21
+ name="documents",
22
+ metadata={"hnsw:space": "cosine"}
23
+ )
24
+
25
+ # Initialize embedding model
26
+ self.embedding_model = SentenceTransformer('all-MiniLM-L6-v2')
27
+
28
+ async def add_document(
29
+ self,
30
+ text: str,
31
+ metadata: Dict[str, Any] = None
32
+ ) -> str:
33
+ """
34
+ Add document to RAG index
35
+
36
+ Args:
37
+ text: Document text
38
+ metadata: Document metadata
39
+
40
+ Returns:
41
+ Document ID
42
+ """
43
+ doc_id = str(uuid.uuid4())
44
+
45
+ # Generate embedding
46
+ embedding = self.embedding_model.encode(text).tolist()
47
+
48
+ # Add to collection
49
+ self.collection.add(
50
+ ids=[doc_id],
51
+ embeddings=[embedding],
52
+ documents=[text],
53
+ metadatas=[metadata or {}]
54
+ )
55
+
56
+ return doc_id
57
+
58
+ async def add_documents(
59
+ self,
60
+ texts: List[str],
61
+ metadatas: List[Dict[str, Any]] = None
62
+ ) -> List[str]:
63
+ """Add multiple documents at once"""
64
+ doc_ids = [str(uuid.uuid4()) for _ in texts]
65
+
66
+ # Generate embeddings
67
+ embeddings = self.embedding_model.encode(texts).tolist()
68
+
69
+ # Add to collection
70
+ self.collection.add(
71
+ ids=doc_ids,
72
+ embeddings=embeddings,
73
+ documents=texts,
74
+ metadatas=metadatas or [{} for _ in texts]
75
+ )
76
+
77
+ return doc_ids
78
+
79
+ async def search(
80
+ self,
81
+ query: str,
82
+ k: int = 5,
83
+ filter_metadata: Dict[str, Any] = None
84
+ ) -> List[Dict[str, Any]]:
85
+ """
86
+ Search for relevant documents
87
+
88
+ Args:
89
+ query: Search query
90
+ k: Number of results
91
+ filter_metadata: Metadata filters
92
+
93
+ Returns:
94
+ List of matching documents with scores
95
+ """
96
+ # Generate query embedding
97
+ query_embedding = self.embedding_model.encode(query).tolist()
98
+
99
+ # Search
100
+ results = self.collection.query(
101
+ query_embeddings=[query_embedding],
102
+ n_results=k,
103
+ where=filter_metadata
104
+ )
105
+
106
+ # Format results
107
+ documents = []
108
+ if results['documents'] and results['documents'][0]:
109
+ for i, doc in enumerate(results['documents'][0]):
110
+ documents.append({
111
+ 'id': results['ids'][0][i],
112
+ 'text': doc,
113
+ 'metadata': results['metadatas'][0][i] if results['metadatas'] else {},
114
+ 'distance': results['distances'][0][i] if results['distances'] else 0
115
+ })
116
+
117
+ return documents
118
+
119
+ async def get_document(self, doc_id: str) -> Dict[str, Any]:
120
+ """Get document by ID"""
121
+ result = self.collection.get(ids=[doc_id])
122
+
123
+ if result['documents']:
124
+ return {
125
+ 'id': doc_id,
126
+ 'text': result['documents'][0],
127
+ 'metadata': result['metadatas'][0] if result['metadatas'] else {}
128
+ }
129
+
130
+ return None
131
+
132
+ async def delete_document(self, doc_id: str):
133
+ """Delete document by ID"""
134
+ self.collection.delete(ids=[doc_id])
135
+
136
+ async def count_documents(self) -> int:
137
+ """Get total number of documents"""
138
+ return self.collection.count()
139
+
140
+ async def clear_all(self):
141
+ """Clear all documents"""
142
+ self.client.delete_collection(name="documents")
143
+ self.collection = self.client.get_or_create_collection(
144
+ name="documents",
145
+ metadata={"hnsw:space": "cosine"}
146
+ )