al1kss committed on
Commit
5d9aa53
·
verified ·
1 Parent(s): feab486

Update lightrag_manager.py

Browse files
Files changed (1) hide show
  1. lightrag_manager.py +552 -247
lightrag_manager.py CHANGED
@@ -9,15 +9,57 @@ from typing import Dict, List, Optional, Any, Tuple
9
  from datetime import datetime
10
  import uuid
11
  import httpx
 
 
12
 
13
  # LightRAG imports
14
  from lightrag import LightRAG, QueryParam
15
  from lightrag.utils import EmbeddingFunc
16
 
17
- # Import enhanced database manager
18
- from enhanced_database_manager import EnhancedDatabaseManager
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
 
20
  class CloudflareWorker:
 
 
21
  def __init__(
22
  self,
23
  cloudflare_api_key: str,
@@ -33,100 +75,348 @@ class CloudflareWorker:
33
  self.embedding_model_name = embedding_model_name
34
  self.max_tokens = max_tokens
35
  self.max_response_tokens = max_response_tokens
 
36
 
37
- async def _send_request(self, model_name: str, input_: dict):
 
38
  headers = {"Authorization": f"Bearer {self.cloudflare_api_key}"}
39
 
40
  try:
41
- async with httpx.AsyncClient() as client:
42
  response = await client.post(
43
  f"{self.api_base_url}{model_name}",
44
  headers=headers,
45
- json=input_,
46
- timeout=30.0
47
  )
48
  response.raise_for_status()
49
  response_data = response.json()
50
 
51
  result = response_data.get("result", {})
52
 
 
53
  if "data" in result:
54
  return np.array(result["data"])
55
 
 
56
  if "response" in result:
57
  return result["response"]
58
 
59
  raise ValueError("Unexpected Cloudflare response format")
60
 
61
  except Exception as e:
62
- logging.error(f"Cloudflare API error: {e}")
63
  raise
64
 
65
  async def query(self, prompt: str, system_prompt: str = "", **kwargs) -> str:
66
- # Clean kwargs to avoid LightRAG-specific parameters
67
- clean_kwargs = {k: v for k, v in kwargs.items() if k not in ['hashing_kv', 'history_messages']}
 
 
 
 
 
 
 
68
 
69
- message = [
70
- {"role": "system", "content": system_prompt},
71
  {"role": "user", "content": prompt},
72
  ]
73
 
74
- input_ = {
75
- "messages": message,
76
  "max_tokens": self.max_tokens,
77
- "response_token_limit": self.max_response_tokens,
78
  }
79
 
80
- return await self._send_request(self.llm_model_name, input_)
81
 
82
  async def embedding_chunk(self, texts: List[str]) -> np.ndarray:
83
- input_ = {
 
84
  "text": texts,
85
  "max_tokens": self.max_tokens,
86
- "response_token_limit": self.max_response_tokens,
87
  }
88
 
89
- return await self._send_request(self.embedding_model_name, input_)
90
 
91
  class VercelBlobClient:
92
- """Vercel Blob storage client for backup storage"""
93
 
94
  def __init__(self, token: str):
95
  self.token = token
96
  self.logger = logging.getLogger(__name__)
97
 
98
  async def put(self, filename: str, data: bytes) -> str:
99
- """Upload data to Vercel Blob as backup"""
100
  try:
101
- async with httpx.AsyncClient() as client:
102
  response = await client.put(
103
  f"https://blob.vercel-storage.com/{filename}",
104
  headers={"Authorization": f"Bearer {self.token}"},
105
- content=data,
106
- timeout=120.0
107
  )
108
  response.raise_for_status()
109
- return f"https://blob.vercel-storage.com/{filename}"
 
110
  except Exception as e:
111
  self.logger.error(f"Failed to upload to Vercel Blob: {e}")
112
  raise
113
 
114
- class ProductionLightRAGManager:
115
- """Final Production LightRAG Manager with Complete Database Storage"""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
116
 
117
  def __init__(
118
  self,
119
  cloudflare_worker: CloudflareWorker,
120
- database_manager: EnhancedDatabaseManager,
121
- vercel_blob_client: Optional[VercelBlobClient] = None
122
  ):
123
  self.cloudflare_worker = cloudflare_worker
124
  self.db = database_manager
125
- self.blob_client = vercel_blob_client
126
  self.rag_instances: Dict[str, LightRAG] = {}
127
- self.processing_lock: Dict[str, asyncio.Lock] = {}
 
128
  self.logger = logging.getLogger(__name__)
129
-
130
  async def get_or_create_rag_instance(
131
  self,
132
  ai_type: str,
@@ -135,20 +425,28 @@ class ProductionLightRAGManager:
135
  name: Optional[str] = None,
136
  description: Optional[str] = None
137
  ) -> LightRAG:
138
- """Get or create a LightRAG instance with complete database storage"""
139
 
140
- cache_key = self._get_cache_key(ai_type, user_id, ai_id)
 
 
 
 
 
 
 
 
141
 
142
- # Check memory cache first
143
  if cache_key in self.rag_instances:
144
  self.logger.info(f"Returning cached RAG instance: {cache_key}")
145
  return self.rag_instances[cache_key]
146
 
147
- # Ensure only one instance is created/loaded at a time
148
- if cache_key not in self.processing_lock:
149
- self.processing_lock[cache_key] = asyncio.Lock()
150
 
151
- async with self.processing_lock[cache_key]:
152
  # Double-check after acquiring lock
153
  if cache_key in self.rag_instances:
154
  return self.rag_instances[cache_key]
@@ -157,7 +455,7 @@ class ProductionLightRAGManager:
157
 
158
  # Try to load from database
159
  try:
160
- rag_instance = await self._load_from_database(ai_type, user_id, ai_id)
161
  if rag_instance:
162
  self.rag_instances[cache_key] = rag_instance
163
  self.logger.info(f"Loaded RAG instance from database: {cache_key}")
@@ -166,16 +464,10 @@ class ProductionLightRAGManager:
166
  self.logger.warning(f"Failed to load RAG from database: {e}")
167
 
168
  # Create new instance
169
- rag_instance = await self._create_new_rag_instance(
170
- ai_type, user_id, ai_id, name or f"{ai_type} AI", description
171
- )
172
 
173
  # Save to database
174
- await self._save_to_database(
175
- ai_type, user_id, ai_id,
176
- name or f"{ai_type} AI", description,
177
- rag_instance
178
- )
179
 
180
  # Cache in memory
181
  self.rag_instances[cache_key] = rag_instance
@@ -183,88 +475,14 @@ class ProductionLightRAGManager:
183
 
184
  return rag_instance
185
 
186
- async def _load_from_database(
187
- self,
188
- ai_type: str,
189
- user_id: Optional[str],
190
- ai_id: Optional[str]
191
- ) -> Optional[LightRAG]:
192
- """Load RAG instance from database"""
193
-
194
- # Get complete RAG instance from database
195
- rag_data = await self.db.load_complete_rag_instance(ai_type, user_id, ai_id)
196
- if not rag_data:
197
- return None
198
 
199
- try:
200
- # Reconstruct RAG instance from database data
201
- rag_instance = await self._deserialize_rag_state(rag_data['rag_state'])
202
-
203
- self.logger.info(f"Successfully loaded RAG from database: {ai_type}")
204
- return rag_instance
205
-
206
- except Exception as e:
207
- self.logger.error(f"Failed to reconstruct RAG from database: {e}")
208
- return None
209
-
210
- async def _save_to_database(
211
- self,
212
- ai_type: str,
213
- user_id: Optional[str],
214
- ai_id: Optional[str],
215
- name: str,
216
- description: Optional[str],
217
- rag_instance: LightRAG
218
- ):
219
- """Save RAG instance completely to database"""
220
-
221
- try:
222
- # Serialize RAG state
223
- rag_state = await self._serialize_rag_state(rag_instance)
224
-
225
- # Optional: Create backup in Vercel Blob
226
- blob_url = None
227
- if self.blob_client:
228
- try:
229
- compressed_data = gzip.compress(pickle.dumps(rag_state))
230
- blob_filename = f"lightrag_backup/{ai_type}_{user_id or 'system'}_{ai_id or 'default'}_{uuid.uuid4()}.pkl.gz"
231
- blob_url = await self.blob_client.put(blob_filename, compressed_data)
232
- self.logger.info(f"Created backup in Vercel Blob: {blob_url}")
233
- except Exception as e:
234
- self.logger.warning(f"Failed to create Vercel Blob backup: {e}")
235
-
236
- # Save everything to database
237
- await self.db.save_complete_rag_instance(
238
- ai_type=ai_type,
239
- user_id=user_id,
240
- ai_id=ai_id,
241
- name=name,
242
- description=description,
243
- rag_state=rag_state,
244
- blob_url=blob_url
245
- )
246
-
247
- self.logger.info(f"Successfully saved RAG to database: {ai_type}")
248
-
249
- except Exception as e:
250
- self.logger.error(f"Failed to save RAG to database: {e}")
251
- raise
252
-
253
- async def _create_new_rag_instance(
254
- self,
255
- ai_type: str,
256
- user_id: Optional[str],
257
- ai_id: Optional[str],
258
- name: str,
259
- description: Optional[str]
260
- ) -> LightRAG:
261
- """Create a new LightRAG instance"""
262
-
263
- # Create temporary working directory (will be serialized to database)
264
- working_dir = f"/tmp/rag_temp_{ai_type}_{user_id or 'system'}_{ai_id or 'default'}_{uuid.uuid4()}"
265
  os.makedirs(working_dir, exist_ok=True)
266
 
267
- # Initialize LightRAG
268
  rag = LightRAG(
269
  working_dir=working_dir,
270
  max_parallel_insert=2,
@@ -280,91 +498,157 @@ class ProductionLightRAGManager:
280
  vector_storage="NanoVectorDBStorage",
281
  )
282
 
283
- # Wait for initialization to complete
284
  await rag.initialize_storages()
285
 
286
  # Load knowledge based on AI type
287
- if ai_type == "fire-safety":
288
  await self._load_fire_safety_knowledge(rag)
289
 
290
  return rag
291
 
292
  async def _load_fire_safety_knowledge(self, rag: LightRAG):
293
- """Load fire safety knowledge from files"""
294
- knowledge_sources = [
295
- "/app/book.txt",
296
- "/app/book.pdf",
297
- "/app/fire_safety.txt",
298
- "./book.txt",
299
- "./book.pdf"
300
- ]
301
 
302
- combined_content = ""
303
- processed_files = []
 
304
 
305
- for source_file in knowledge_sources:
306
- try:
307
- if os.path.exists(source_file):
308
- if source_file.endswith('.txt'):
309
- with open(source_file, 'r', encoding='utf-8') as f:
310
- content = f.read()
311
- elif source_file.endswith('.pdf'):
312
- content = await self._extract_pdf_content(source_file)
313
- else:
314
- continue
315
-
316
- if content.strip():
317
- combined_content += f"\n\n=== Content from {source_file} ===\n\n{content}\n\n"
318
- processed_files.append(source_file)
319
-
320
- except Exception as e:
321
- self.logger.warning(f"Failed to process {source_file}: {e}")
 
 
 
 
 
 
 
322
 
323
- if combined_content.strip():
324
- self.logger.info(f"Inserting fire safety knowledge from {len(processed_files)} files")
325
- await rag.ainsert(combined_content)
326
- else:
327
- self.logger.warning("No fire safety knowledge found")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
328
 
329
- async def _extract_pdf_content(self, pdf_path: str) -> str:
330
- """Extract text from PDF file"""
 
331
  try:
332
- import PyPDF2
333
- content = ""
334
- with open(pdf_path, 'rb') as file:
335
- pdf_reader = PyPDF2.PdfReader(file)
336
- for page in pdf_reader.pages:
337
- text = page.extract_text()
338
- if text:
339
- content += text + "\n"
340
- return content
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
341
  except Exception as e:
342
- self.logger.warning(f"Failed to extract PDF content: {e}")
343
- return ""
344
 
345
- async def _serialize_rag_state(self, rag: LightRAG) -> Dict[str, Any]:
346
- """Serialize LightRAG state for database storage"""
 
347
  try:
348
- graph_storage = rag.graph_storage
349
- vector_storage = rag.vector_storage
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
350
 
 
 
 
 
 
 
 
351
  # Extract graph data
352
  graph_data = {"nodes": [], "edges": [], "graph_attrs": {}}
353
- if hasattr(graph_storage, 'graph') and graph_storage.graph:
354
  graph_data = {
355
- "nodes": list(graph_storage.graph.nodes(data=True)),
356
- "edges": list(graph_storage.graph.edges(data=True)),
357
- "graph_attrs": dict(graph_storage.graph.graph)
358
  }
359
 
360
  # Extract vector data
361
  vector_data = {"embeddings": [], "metadata": [], "dimension": 1024}
362
- if hasattr(vector_storage, 'embeddings') and vector_storage.embeddings is not None:
363
- vector_data["embeddings"] = vector_storage.embeddings.tolist()
364
- if hasattr(vector_storage, 'metadata'):
365
- vector_data["metadata"] = getattr(vector_storage, 'metadata', [])
366
- if hasattr(vector_storage, 'dimension'):
367
- vector_data["dimension"] = getattr(vector_storage, 'dimension', 1024)
368
 
369
  # Configuration
370
  config_data = {
@@ -387,7 +671,7 @@ class ProductionLightRAGManager:
387
  raise
388
 
389
  async def _deserialize_rag_state(self, rag_state: Dict[str, Any]) -> LightRAG:
390
- """Deserialize RAG state and reconstruct LightRAG instance"""
391
  try:
392
  config = rag_state["config"]
393
 
@@ -395,7 +679,7 @@ class ProductionLightRAGManager:
395
  working_dir = f"/tmp/rag_restored_{uuid.uuid4()}"
396
  os.makedirs(working_dir, exist_ok=True)
397
 
398
- # Create new RAG instance
399
  rag = LightRAG(
400
  working_dir=working_dir,
401
  max_parallel_insert=config.get("max_parallel_insert", 2),
@@ -419,8 +703,6 @@ class ProductionLightRAGManager:
419
  rag.graph_storage.graph.add_nodes_from(graph_data["nodes"])
420
  if graph_data["edges"] and hasattr(rag.graph_storage, 'graph'):
421
  rag.graph_storage.graph.add_edges_from(graph_data["edges"])
422
- if graph_data["graph_attrs"] and hasattr(rag.graph_storage, 'graph'):
423
- rag.graph_storage.graph.graph.update(graph_data["graph_attrs"])
424
 
425
  # Restore vectors
426
  vector_data = rag_state["vectors"]
@@ -428,8 +710,6 @@ class ProductionLightRAGManager:
428
  rag.vector_storage.embeddings = np.array(vector_data["embeddings"])
429
  if hasattr(rag.vector_storage, 'metadata'):
430
  rag.vector_storage.metadata = vector_data["metadata"]
431
- if hasattr(rag.vector_storage, 'dimension'):
432
- rag.vector_storage.dimension = vector_data["dimension"]
433
 
434
  return rag
435
 
@@ -447,43 +727,43 @@ class ProductionLightRAGManager:
447
  mode: str = "hybrid",
448
  max_memory_turns: int = 10
449
  ) -> str:
450
- """Query RAG with database-persisted conversation memory"""
451
 
452
  try:
453
  # Get RAG instance
454
  rag = await self.get_or_create_rag_instance(ai_type, user_id, ai_id)
455
 
456
- # Get conversation messages from database
457
  messages = await self.db.get_conversation_messages(conversation_id)
458
 
459
- # Build context with conversation memory
460
  context_prompt = self._build_context_prompt(question, messages[-max_memory_turns*2:])
461
 
462
  # Query LightRAG
463
  response = await rag.aquery(context_prompt, QueryParam(mode=mode))
464
 
465
- # Save conversation messages to database
466
- await self.db.save_conversation_message(conversation_id, "user", question)
467
- await self.db.save_conversation_message(conversation_id, "assistant", response, {
468
- "mode": mode,
469
- "ai_type": ai_type,
470
- "user_id": user_id,
471
- "ai_id": ai_id
472
- })
 
 
 
 
 
 
 
 
473
 
474
  return response
475
 
476
  except Exception as e:
477
  self.logger.error(f"Query with memory failed: {e}")
478
-
479
- # Fallback to direct query
480
- try:
481
- rag = await self.get_or_create_rag_instance(ai_type, user_id, ai_id)
482
- response = await rag.aquery(question, QueryParam(mode=mode))
483
- return response
484
- except Exception as fallback_error:
485
- self.logger.error(f"Fallback query also failed: {fallback_error}")
486
- return "I apologize, but I'm experiencing technical difficulties. Please try again later."
487
 
488
  def _build_context_prompt(self, question: str, messages: List[Dict[str, Any]]) -> str:
489
  """Build context prompt with conversation memory"""
@@ -506,7 +786,7 @@ class ProductionLightRAGManager:
506
  description: str,
507
  uploaded_files: List[Dict[str, Any]]
508
  ) -> str:
509
- """Create custom AI with database storage"""
510
 
511
  ai_id = str(uuid.uuid4())
512
 
@@ -526,30 +806,16 @@ class ProductionLightRAGManager:
526
  if combined_content.strip():
527
  await rag.ainsert(combined_content)
528
 
529
- # Save to database with updated knowledge
530
- await self._save_to_database(
531
  ai_type="custom",
532
  user_id=user_id,
533
  ai_id=ai_id,
534
  name=ai_name,
535
- description=description,
536
- rag_instance=rag
537
  )
538
 
539
- # Save file metadata to database
540
- rag_data = await self.db.load_complete_rag_instance("custom", user_id, ai_id)
541
- if rag_data:
542
- rag_instance_id = rag_data['metadata']['id']
543
-
544
- for file_data in uploaded_files:
545
- await self.db.save_knowledge_file(
546
- rag_instance_id=rag_instance_id,
547
- filename=file_data['filename'],
548
- original_filename=file_data['filename'],
549
- file_type=file_data.get('type', 'unknown'),
550
- file_size=file_data.get('size', 0),
551
- content_text=file_data.get('content', '')
552
- )
553
 
554
  return ai_id
555
 
@@ -557,39 +823,78 @@ class ProductionLightRAGManager:
557
  self.logger.error(f"Failed to create custom AI: {e}")
558
  raise
559
 
560
- async def get_user_ais(self, user_id: str) -> List[Dict[str, Any]]:
561
- return await self.db.list_user_rag_instances(user_id)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
562
 
563
  async def cleanup(self):
564
  """Clean up resources"""
565
  self.rag_instances.clear()
566
  self.conversation_memory.clear()
567
- self.processing_lock.clear()
568
  await self.db.close()
569
  self.logger.info("LightRAG manager cleaned up")
570
 
571
  # Global instance
572
- lightrag_manager: Optional[ProductionLightRAGManager] = None
573
-
574
- async def initialize_lightrag_manager(
575
- cloudflare_worker: CloudflareWorker,
576
- database_url: str,
577
- vercel_blob_token: str
578
- ) -> ProductionLightRAGManager:
579
- """Initialize the production LightRAG manager"""
580
  global lightrag_manager
581
 
582
  if lightrag_manager is None:
583
- # Initialize database
584
- db_manager = DatabaseManager(database_url)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
585
  await db_manager.connect()
586
 
587
  # Initialize blob client
588
- blob_client = VercelBlobClient(vercel_blob_token)
589
 
590
  # Create manager
591
- lightrag_manager = ProductionLightRAGManager(
592
  cloudflare_worker, db_manager, blob_client
593
  )
594
 
 
 
 
 
595
  return lightrag_manager
 
9
  from datetime import datetime
10
  import uuid
11
  import httpx
12
+ import base64
13
+ from dataclasses import dataclass
14
 
15
  # LightRAG imports
16
  from lightrag import LightRAG, QueryParam
17
  from lightrag.utils import EmbeddingFunc
18
 
19
+ # Database imports
20
+ import asyncpg
21
+ from redis import Redis
22
+
23
+ # Environment validation
24
# Names of the environment variables that must be set (and non-empty) before
# the manager can start.
REQUIRED_ENV_VARS = [
    'CLOUDFLARE_API_KEY',
    'CLOUDFLARE_ACCOUNT_ID',
    'DATABASE_URL',
    'BLOB_READ_WRITE_TOKEN',
    'REDIS_URL',
    'JWT_SECRET',
]


class EnvironmentError(Exception):
    """Raised when required environment variables are missing.

    NOTE(review): this name shadows the builtin ``EnvironmentError`` (an alias
    of ``OSError``) for the rest of this module, so ``except EnvironmentError``
    written below this point no longer catches OS-level errors. The name is
    kept unchanged because callers may already import it.
    """


def validate_environment():
    """Check that every variable in ``REQUIRED_ENV_VARS`` is set and non-empty.

    Returns:
        None when all variables are present.

    Raises:
        EnvironmentError: if one or more variables are unset or empty. The
            message lists them in ``REQUIRED_ENV_VARS`` order.
    """
    # An empty string counts as missing, hence the truthiness test.
    missing_vars = [var for var in REQUIRED_ENV_VARS if not os.getenv(var)]

    if missing_vars:
        raise EnvironmentError(f"Missing required environment variables: {', '.join(missing_vars)}")
46
+
47
@dataclass
class RAGConfig:
    """Identity and display metadata of a single RAG instance."""

    # Kind of assistant, e.g. "fire-safety" or "custom".
    ai_type: str
    # Owner of the instance; None is rendered as "system" in the cache key.
    user_id: Optional[str] = None
    # Identifier of a specific AI; None is rendered as "default" in the cache key.
    ai_id: Optional[str] = None
    # Human-readable name and description.
    name: Optional[str] = None
    description: Optional[str] = None

    def get_cache_key(self) -> str:
        """Return the cache key for this RAG configuration.

        The key has the form ``rag_<ai_type>_<user_id>_<ai_id>``, with
        ``system`` and ``default`` substituted for a missing (or empty)
        ``user_id`` and ``ai_id`` respectively.
        """
        owner = self.user_id or 'system'
        instance = self.ai_id or 'default'
        return f"rag_{self.ai_type}_{owner}_{instance}"
59
 
60
  class CloudflareWorker:
61
+ """Cloudflare Workers AI integration with proper LightRAG compatibility"""
62
+
63
  def __init__(
64
  self,
65
  cloudflare_api_key: str,
 
75
  self.embedding_model_name = embedding_model_name
76
  self.max_tokens = max_tokens
77
  self.max_response_tokens = max_response_tokens
78
+ self.logger = logging.getLogger(__name__)
79
 
80
+ async def _send_request(self, model_name: str, input_: dict) -> Any:
81
+ """Send request to Cloudflare Workers AI"""
82
  headers = {"Authorization": f"Bearer {self.cloudflare_api_key}"}
83
 
84
  try:
85
+ async with httpx.AsyncClient(timeout=30.0) as client:
86
  response = await client.post(
87
  f"{self.api_base_url}{model_name}",
88
  headers=headers,
89
+ json=input_
 
90
  )
91
  response.raise_for_status()
92
  response_data = response.json()
93
 
94
  result = response_data.get("result", {})
95
 
96
+ # Handle embedding response
97
  if "data" in result:
98
  return np.array(result["data"])
99
 
100
+ # Handle LLM response
101
  if "response" in result:
102
  return result["response"]
103
 
104
  raise ValueError("Unexpected Cloudflare response format")
105
 
106
  except Exception as e:
107
+ self.logger.error(f"Cloudflare API error: {e}")
108
  raise
109
 
110
  async def query(self, prompt: str, system_prompt: str = "", **kwargs) -> str:
111
+ """
112
+ LightRAG-compatible query method
113
+ Fixed to handle LightRAG's parameter expectations
114
+ """
115
+ # Filter out LightRAG-specific parameters that shouldn't go to Cloudflare
116
+ filtered_kwargs = {
117
+ k: v for k, v in kwargs.items()
118
+ if k not in ['hashing_kv', 'history_messages', 'global_kv', 'text_chunks']
119
+ }
120
 
121
+ messages = [
122
+ {"role": "system", "content": system_prompt or "You are a helpful AI assistant."},
123
  {"role": "user", "content": prompt},
124
  ]
125
 
126
+ input_data = {
127
+ "messages": messages,
128
  "max_tokens": self.max_tokens,
129
+ **filtered_kwargs
130
  }
131
 
132
+ return await self._send_request(self.llm_model_name, input_data)
133
 
134
  async def embedding_chunk(self, texts: List[str]) -> np.ndarray:
135
+ """Generate embeddings for text chunks"""
136
+ input_data = {
137
  "text": texts,
138
  "max_tokens": self.max_tokens,
 
139
  }
140
 
141
+ return await self._send_request(self.embedding_model_name, input_data)
142
 
143
class VercelBlobClient:
    """Vercel Blob storage client for RAG state persistence."""

    def __init__(self, token: str):
        """Store the bearer token used to authorize uploads.

        Args:
            token: Vercel Blob read/write token.
        """
        self.token = token
        self.logger = logging.getLogger(__name__)

    async def put(self, filename: str, data: bytes) -> str:
        """Upload data to Vercel Blob.

        Args:
            filename: Path of the blob below the storage host.
            data: Raw bytes to upload.

        Returns:
            The ``url`` field of the JSON response when present, otherwise
            the URL the data was uploaded to.

        Raises:
            Exception: any transport or HTTP-status error, logged and
                re-raised unchanged.
        """
        upload_url = f"https://blob.vercel-storage.com/{filename}"
        try:
            async with httpx.AsyncClient(timeout=120.0) as client:
                response = await client.put(
                    upload_url,
                    headers={"Authorization": f"Bearer {self.token}"},
                    content=data,
                )
                response.raise_for_status()
        except Exception as e:
            self.logger.error(f"Failed to upload to Vercel Blob: {e}")
            raise

        # The upload has succeeded at this point. A body that is not a JSON
        # object must not turn that success into a reported failure, so fall
        # back to the URL we uploaded to.
        try:
            result = response.json()
        except ValueError:
            return upload_url
        if not isinstance(result, dict):
            return upload_url
        return result.get('url', upload_url)

    async def get(self, url: str) -> bytes:
        """Download data from Vercel Blob.

        Args:
            url: Full URL of the blob to fetch.

        Returns:
            The response body as bytes.

        Raises:
            Exception: any transport or HTTP-status error, logged and
                re-raised unchanged.
        """
        try:
            async with httpx.AsyncClient(timeout=120.0) as client:
                response = await client.get(url)
                response.raise_for_status()
                return response.content
        except Exception as e:
            self.logger.error(f"Failed to download from Vercel Blob: {e}")
            raise
176
+
177
class DatabaseManager:
    """Database manager with complete RAG persistence.

    Owns an asyncpg connection pool (PostgreSQL) and a Redis client, creates
    the tables it needs on connect, and stores RAG instance metadata and
    conversation messages.
    """

    def __init__(self, database_url: str, redis_url: str):
        """Remember the connection URLs; no connection is opened here.

        Args:
            database_url: PostgreSQL DSN passed to ``asyncpg.create_pool``.
            redis_url: URL passed to ``Redis.from_url``.
        """
        self.database_url = database_url
        self.redis_url = redis_url
        self.pool = None
        self.redis = None
        self.logger = logging.getLogger(__name__)

    async def connect(self):
        """Initialize database connections and make sure the tables exist.

        Raises:
            Exception: any connection or DDL error, logged and re-raised.
        """
        try:
            # PostgreSQL connection pool
            self.pool = await asyncpg.create_pool(
                self.database_url,
                min_size=2,
                max_size=20,
                command_timeout=60,
            )

            # Redis connection
            self.redis = Redis.from_url(self.redis_url, decode_responses=True)

            self.logger.info("Database connections established successfully")

            # Create tables if they don't exist
            await self._create_tables()

        except Exception as e:
            self.logger.error(f"Database connection failed: {e}")
            raise

    async def _create_tables(self):
        """Create necessary tables and indexes for RAG persistence."""
        async with self.pool.acquire() as conn:
            await conn.execute("""
                CREATE TABLE IF NOT EXISTS rag_instances (
                    id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
                    ai_type VARCHAR(50) NOT NULL,
                    user_id VARCHAR(100),
                    ai_id VARCHAR(100),
                    name VARCHAR(255) NOT NULL,
                    description TEXT,

                    -- Blob storage URLs
                    graph_blob_url TEXT,
                    vector_blob_url TEXT,
                    config_blob_url TEXT,

                    -- Metadata
                    total_chunks INTEGER DEFAULT 0,
                    total_tokens INTEGER DEFAULT 0,
                    file_count INTEGER DEFAULT 0,

                    -- Timestamps
                    created_at TIMESTAMP DEFAULT NOW(),
                    updated_at TIMESTAMP DEFAULT NOW(),
                    last_accessed_at TIMESTAMP DEFAULT NOW(),

                    -- Status
                    status VARCHAR(20) DEFAULT 'active',

                    UNIQUE(ai_type, user_id, ai_id)
                );

                CREATE TABLE IF NOT EXISTS knowledge_files (
                    id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
                    rag_instance_id UUID REFERENCES rag_instances(id) ON DELETE CASCADE,
                    filename VARCHAR(255) NOT NULL,
                    original_filename VARCHAR(255),
                    file_type VARCHAR(50),
                    file_size INTEGER,
                    blob_url TEXT,
                    content_text TEXT,
                    processed_at TIMESTAMP DEFAULT NOW(),
                    processing_status VARCHAR(20) DEFAULT 'processed',
                    token_count INTEGER DEFAULT 0,
                    created_at TIMESTAMP DEFAULT NOW()
                );

                CREATE TABLE IF NOT EXISTS conversations (
                    id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
                    user_id VARCHAR(100) NOT NULL,
                    ai_type VARCHAR(50) NOT NULL,
                    ai_id VARCHAR(100),
                    title VARCHAR(255),
                    created_at TIMESTAMP DEFAULT NOW(),
                    updated_at TIMESTAMP DEFAULT NOW(),
                    is_active BOOLEAN DEFAULT TRUE
                );

                CREATE TABLE IF NOT EXISTS conversation_messages (
                    id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
                    conversation_id UUID REFERENCES conversations(id) ON DELETE CASCADE,
                    role VARCHAR(20) NOT NULL,
                    content TEXT NOT NULL,
                    metadata JSONB DEFAULT '{}',
                    created_at TIMESTAMP DEFAULT NOW()
                );

                -- Indexes for performance
                CREATE INDEX IF NOT EXISTS idx_rag_instances_lookup ON rag_instances(ai_type, user_id, ai_id);
                CREATE INDEX IF NOT EXISTS idx_conversations_user ON conversations(user_id);
                CREATE INDEX IF NOT EXISTS idx_conversation_messages_conv ON conversation_messages(conversation_id);
            """)

        self.logger.info("Database tables created/verified successfully")

    async def save_rag_instance(
        self,
        config: "RAGConfig",
        graph_blob_url: str,
        vector_blob_url: str,
        config_blob_url: str,
        metadata: Dict[str, Any]
    ) -> str:
        """Insert or update the metadata row of a RAG instance.

        Args:
            config: Identity (ai_type, user_id, ai_id) plus name/description.
            graph_blob_url: Blob URL of the serialized graph.
            vector_blob_url: Blob URL of the serialized vectors.
            config_blob_url: Blob URL of the serialized configuration.
            metadata: May carry ``total_chunks``, ``total_tokens`` and
                ``file_count``; each defaults to 0 when absent.

        Returns:
            The row id as a string.
        """
        # NOTE(review): PostgreSQL treats NULLs as distinct in a plain UNIQUE
        # constraint, so for rows whose user_id or ai_id is NULL the
        # ON CONFLICT branch presumably never fires and a new row is inserted
        # on every save. Confirm and tighten the constraint in a migration.
        async with self.pool.acquire() as conn:
            rag_instance_id = await conn.fetchval(
                """
                INSERT INTO rag_instances (
                    ai_type, user_id, ai_id, name, description,
                    graph_blob_url, vector_blob_url, config_blob_url,
                    total_chunks, total_tokens, file_count
                ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11)
                ON CONFLICT (ai_type, user_id, ai_id) DO UPDATE SET
                    name = EXCLUDED.name,
                    description = EXCLUDED.description,
                    graph_blob_url = EXCLUDED.graph_blob_url,
                    vector_blob_url = EXCLUDED.vector_blob_url,
                    config_blob_url = EXCLUDED.config_blob_url,
                    total_chunks = EXCLUDED.total_chunks,
                    total_tokens = EXCLUDED.total_tokens,
                    file_count = EXCLUDED.file_count,
                    updated_at = NOW()
                RETURNING id;
                """,
                config.ai_type, config.user_id, config.ai_id,
                config.name, config.description,
                graph_blob_url, vector_blob_url, config_blob_url,
                metadata.get('total_chunks', 0),
                metadata.get('total_tokens', 0),
                metadata.get('file_count', 0),
            )

        return str(rag_instance_id)

    async def get_rag_instance(self, config: "RAGConfig") -> Optional[Dict[str, Any]]:
        """Fetch the active RAG instance row matching ``config``.

        ``user_id`` and ``ai_id`` are nullable, so they are compared with
        ``IS NOT DISTINCT FROM``: a plain ``=`` against NULL never matches,
        which would make instances saved without an owner or id impossible
        to load. When several rows match, the most recently updated wins.

        Also stamps ``last_accessed_at`` on the row that is returned.

        Returns:
            The row as a dict, or None when no active row matches.
        """
        async with self.pool.acquire() as conn:
            result = await conn.fetchrow(
                """
                SELECT id, ai_type, user_id, ai_id, name, description,
                       graph_blob_url, vector_blob_url, config_blob_url,
                       total_chunks, total_tokens, file_count,
                       created_at, updated_at, last_accessed_at, status
                FROM rag_instances
                WHERE ai_type = $1
                  AND user_id IS NOT DISTINCT FROM $2
                  AND ai_id IS NOT DISTINCT FROM $3
                  AND status = 'active'
                ORDER BY updated_at DESC
                LIMIT 1
                """,
                config.ai_type, config.user_id, config.ai_id,
            )

            if result:
                # Update last accessed time
                await conn.execute(
                    """
                    UPDATE rag_instances SET last_accessed_at = NOW() WHERE id = $1
                    """,
                    result['id'],
                )

                return dict(result)

            return None

    async def save_conversation_message(
        self,
        conversation_id: str,
        role: str,
        content: str,
        metadata: Optional[Dict[str, Any]] = None
    ) -> str:
        """Save a conversation message, creating the conversation if needed.

        Args:
            conversation_id: Id of the conversation the message belongs to.
            role: Message role, e.g. "user" or "assistant".
            content: Message text.
            metadata: Optional extra data stored as JSON with the message.
                Its ``user_id``, ``ai_type``, ``ai_id`` and ``title`` entries
                seed the conversation row when it does not exist yet.

        Returns:
            The id of the inserted message as a string.
        """
        # metadata defaults to None; normalise once so the lookups below
        # cannot fail with AttributeError.
        meta = metadata or {}

        async with self.pool.acquire() as conn:
            # Create conversation if it doesn't exist
            await conn.execute(
                """
                INSERT INTO conversations (id, user_id, ai_type, ai_id, title)
                VALUES ($1, $2, $3, $4, $5)
                ON CONFLICT (id) DO NOTHING
                """,
                conversation_id,
                meta.get('user_id', 'anonymous'),
                meta.get('ai_type', 'unknown'),
                meta.get('ai_id'),
                meta.get('title', 'New Conversation'),
            )

            # Save message
            message_id = await conn.fetchval(
                """
                INSERT INTO conversation_messages (conversation_id, role, content, metadata)
                VALUES ($1, $2, $3, $4)
                RETURNING id
                """,
                conversation_id, role, content, json.dumps(meta),
            )

        return str(message_id)

    async def get_conversation_messages(
        self,
        conversation_id: str,
        limit: int = 50
    ) -> List[Dict[str, Any]]:
        """Return the latest messages of a conversation, oldest first.

        Args:
            conversation_id: Id of the conversation to read.
            limit: Maximum number of messages to return.

        Returns:
            Up to ``limit`` most recent messages as dicts, in chronological
            order.
        """
        async with self.pool.acquire() as conn:
            # Newest-first so LIMIT keeps the most recent messages ...
            messages = await conn.fetch(
                """
                SELECT id, role, content, metadata, created_at
                FROM conversation_messages
                WHERE conversation_id = $1
                ORDER BY created_at DESC
                LIMIT $2
                """,
                conversation_id, limit,
            )

        # ... then flip back to chronological order for the caller.
        return [dict(msg) for msg in reversed(messages)]

    async def close(self):
        """Close database connections that were opened by ``connect``."""
        if self.pool:
            await self.pool.close()
        if self.redis:
            self.redis.close()
399
+
400
+ class PersistentLightRAGManager:
401
+ """
402
+ Complete LightRAG manager with Vercel-only persistence
403
+ Zero dependency on HuggingFace ephemeral storage
404
+ """
405
 
406
  def __init__(
407
  self,
408
  cloudflare_worker: CloudflareWorker,
409
+ database_manager: DatabaseManager,
410
+ blob_client: VercelBlobClient
411
  ):
412
  self.cloudflare_worker = cloudflare_worker
413
  self.db = database_manager
414
+ self.blob_client = blob_client
415
  self.rag_instances: Dict[str, LightRAG] = {}
416
+ self.processing_locks: Dict[str, asyncio.Lock] = {}
417
+ self.conversation_memory: Dict[str, List[Dict[str, Any]]] = {}
418
  self.logger = logging.getLogger(__name__)
419
+
420
  async def get_or_create_rag_instance(
421
  self,
422
  ai_type: str,
 
425
  name: Optional[str] = None,
426
  description: Optional[str] = None
427
  ) -> LightRAG:
428
+ """Get or create RAG instance with complete Vercel persistence"""
429
 
430
+ config = RAGConfig(
431
+ ai_type=ai_type,
432
+ user_id=user_id,
433
+ ai_id=ai_id,
434
+ name=name or f"{ai_type} AI",
435
+ description=description
436
+ )
437
+
438
+ cache_key = config.get_cache_key()
439
 
440
+ # Check memory cache
441
  if cache_key in self.rag_instances:
442
  self.logger.info(f"Returning cached RAG instance: {cache_key}")
443
  return self.rag_instances[cache_key]
444
 
445
+ # Ensure thread safety
446
+ if cache_key not in self.processing_locks:
447
+ self.processing_locks[cache_key] = asyncio.Lock()
448
 
449
+ async with self.processing_locks[cache_key]:
450
  # Double-check after acquiring lock
451
  if cache_key in self.rag_instances:
452
  return self.rag_instances[cache_key]
 
455
 
456
  # Try to load from database
457
  try:
458
+ rag_instance = await self._load_from_database(config)
459
  if rag_instance:
460
  self.rag_instances[cache_key] = rag_instance
461
  self.logger.info(f"Loaded RAG instance from database: {cache_key}")
 
464
  self.logger.warning(f"Failed to load RAG from database: {e}")
465
 
466
  # Create new instance
467
+ rag_instance = await self._create_new_rag_instance(config)
 
 
468
 
469
  # Save to database
470
+ await self._save_to_database(config, rag_instance)
 
 
 
 
471
 
472
  # Cache in memory
473
  self.rag_instances[cache_key] = rag_instance
 
475
 
476
  return rag_instance
477
 
478
+ async def _create_new_rag_instance(self, config: RAGConfig) -> LightRAG:
479
+ """Create new RAG instance with in-memory storage"""
 
 
 
 
 
 
 
 
 
 
480
 
481
+ # Create in-memory working directory structure
482
+ working_dir = f"/tmp/rag_memory_{config.get_cache_key()}_{uuid.uuid4()}"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
483
  os.makedirs(working_dir, exist_ok=True)
484
 
485
+ # Initialize LightRAG with memory-based storage
486
  rag = LightRAG(
487
  working_dir=working_dir,
488
  max_parallel_insert=2,
 
498
  vector_storage="NanoVectorDBStorage",
499
  )
500
 
501
+ # Initialize storage components
502
  await rag.initialize_storages()
503
 
504
  # Load knowledge based on AI type
505
+ if config.ai_type == "fire-safety":
506
  await self._load_fire_safety_knowledge(rag)
507
 
508
  return rag
509
 
510
  async def _load_fire_safety_knowledge(self, rag: LightRAG):
511
+ """Load fire safety knowledge from available sources"""
 
 
 
 
 
 
 
512
 
513
+ # Fire safety knowledge content
514
+ fire_safety_content = """
515
+ Fire Safety Regulations and Building Codes:
516
 
517
+ 1. Emergency Exits:
518
+ - Buildings must have at least two exits on each floor
519
+ - Maximum travel distance to exit: 75 feet in unsprinklered buildings, 100 feet in sprinklered buildings
520
+ - Exit doors must swing in direction of travel
521
+ - Exits must be clearly marked and illuminated
522
+
523
+ 2. Fire Extinguishers:
524
+ - Type A: Ordinary combustibles (wood, paper, cloth)
525
+ - Type B: Flammable liquids (gasoline, oil, paint)
526
+ - Type C: Electrical equipment
527
+ - Type D: Combustible metals
528
+ - Type K: Cooking oils and fats
529
+
530
+ 3. Fire Detection Systems:
531
+ - Smoke detectors required in all sleeping areas
532
+ - Heat detectors in areas where smoke detectors are not suitable
533
+ - Manual fire alarm pull stations near exits
534
+ - Central monitoring systems in commercial buildings
535
+
536
+ 4. Sprinkler Systems:
537
+ - Required in buildings over certain heights
538
+ - Wet pipe systems most common
539
+ - Dry pipe systems in areas subject to freezing
540
+ - Deluge systems for high-hazard areas
541
 
542
+ 5. Emergency Lighting:
543
+ - Required in all exit routes
544
+ - Must provide minimum 1 foot-candle illumination
545
+ - Battery backup required for minimum 90 minutes
546
+ - Monthly testing required
547
+
548
+ 6. Fire Doors:
549
+ - Must be self-closing and self-latching
550
+ - Fire rating must match wall rating
551
+ - Annual inspection required
552
+ - No propping open unless connected to fire alarm system
553
+
554
+ 7. Occupancy Limits:
555
+ - Based on building type and exit capacity
556
+ - Assembly: 7 sq ft per person (concentrated use)
557
+ - Business: 100 sq ft per person
558
+ - Educational: 20 sq ft per person
559
+ - Industrial: 100 sq ft per person
560
+ """
561
+
562
+ self.logger.info("Loading fire safety knowledge base")
563
+ await rag.ainsert(fire_safety_content)
564
+ self.logger.info("Fire safety knowledge loaded successfully")
565
 
566
async def _save_to_database(self, config: RAGConfig, rag: LightRAG):
    """Persist a RAG instance: three blobs plus one metadata record.

    The state returned by ``self._serialize_rag_state(rag)`` is split into
    its ``graph``, ``vectors`` and ``config`` parts. Each part is pickled,
    gzip-compressed and handed to ``self.blob_client.put`` under a
    timestamped file name; the three values returned by ``put`` are then
    passed, together with a small metadata dict, to
    ``self.db.save_rag_instance``.

    Raises:
        Exception: any failure is logged and re-raised unchanged.
    """
    try:
        state = await self._serialize_rag_state(rag)

        # File-name stem shared by the three blobs.
        stamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        stem = f"rag_{config.ai_type}_{config.user_id or 'system'}_{config.ai_id or 'default'}_{stamp}"

        # Compress every part before the first upload starts, so a missing
        # part fails the save before anything has been written remotely.
        parts = ("graph", "vectors", "config")
        compressed = {
            part: gzip.compress(pickle.dumps(state[part])) for part in parts
        }

        # Uploads run one after another, in the order graph, vectors, config.
        uploaded = []
        for part in parts:
            uploaded.append(
                await self.blob_client.put(f"{stem}_{part}.pkl.gz", compressed[part])
            )
        graph_blob_url, vector_blob_url, config_blob_url = uploaded

        summary = {
            'total_chunks': len(state['vectors'].get('embeddings', [])),
            'total_tokens': self._estimate_tokens(state),
            'file_count': 1 if config.ai_type == 'fire-safety' else 0
        }

        await self.db.save_rag_instance(
            config, graph_blob_url, vector_blob_url, config_blob_url, summary
        )

        self.logger.info(f"Successfully saved RAG to Vercel storage: {config.ai_type}")

    except Exception as e:
        self.logger.error(f"Failed to save RAG to database: {e}")
        raise
602
 
603
async def _load_from_database(self, config: RAGConfig) -> Optional[LightRAG]:
    """Rebuild a RAG instance from its database record and stored blobs.

    Looks the record up with ``self.db.get_rag_instance(config)``, fetches
    the three blobs it references through ``self.blob_client.get``,
    gunzips and unpickles them, and passes the resulting state to
    ``self._deserialize_rag_state``.

    Returns:
        The reconstructed instance, or ``None`` when no record exists or
        when any step raises (the error is logged, not propagated).
    """
    try:
        record = await self.db.get_rag_instance(config)
        if not record:
            return None

        # Fetch all three blobs first (graph, vectors, config order) ...
        raw_parts = {}
        for part, url_field in (
            ('graph', 'graph_blob_url'),
            ('vectors', 'vector_blob_url'),
            ('config', 'config_blob_url'),
        ):
            raw_parts[part] = await self.blob_client.get(record[url_field])

        # ... then decode them.
        # NOTE(review): pickle.loads runs on bytes fetched from remote
        # storage; unpickling can execute arbitrary code, so this relies on
        # the blob store being trusted -- confirm.
        state = {
            part: pickle.loads(gzip.decompress(raw))
            for part, raw in raw_parts.items()
        }

        restored = await self._deserialize_rag_state(state)

        self.logger.info(f"Successfully loaded RAG from Vercel storage: {config.ai_type}")
        return restored

    except Exception as e:
        self.logger.error(f"Failed to load RAG from database: {e}")
        return None
633
+
634
+ async def _serialize_rag_state(self, rag: LightRAG) -> Dict[str, Any]:
635
+ """Serialize RAG state for storage"""
636
+ try:
637
  # Extract graph data
638
  graph_data = {"nodes": [], "edges": [], "graph_attrs": {}}
639
+ if hasattr(rag.graph_storage, 'graph') and rag.graph_storage.graph:
640
  graph_data = {
641
+ "nodes": list(rag.graph_storage.graph.nodes(data=True)),
642
+ "edges": list(rag.graph_storage.graph.edges(data=True)),
643
+ "graph_attrs": dict(rag.graph_storage.graph.graph)
644
  }
645
 
646
  # Extract vector data
647
  vector_data = {"embeddings": [], "metadata": [], "dimension": 1024}
648
+ if hasattr(rag.vector_storage, 'embeddings') and rag.vector_storage.embeddings is not None:
649
+ vector_data["embeddings"] = rag.vector_storage.embeddings.tolist()
650
+ if hasattr(rag.vector_storage, 'metadata'):
651
+ vector_data["metadata"] = getattr(rag.vector_storage, 'metadata', [])
 
 
652
 
653
  # Configuration
654
  config_data = {
 
671
  raise
672
 
673
  async def _deserialize_rag_state(self, rag_state: Dict[str, Any]) -> LightRAG:
674
+ """Deserialize RAG state and reconstruct LightRAG"""
675
  try:
676
  config = rag_state["config"]
677
 
 
679
  working_dir = f"/tmp/rag_restored_{uuid.uuid4()}"
680
  os.makedirs(working_dir, exist_ok=True)
681
 
682
+ # Create RAG instance
683
  rag = LightRAG(
684
  working_dir=working_dir,
685
  max_parallel_insert=config.get("max_parallel_insert", 2),
 
703
  rag.graph_storage.graph.add_nodes_from(graph_data["nodes"])
704
  if graph_data["edges"] and hasattr(rag.graph_storage, 'graph'):
705
  rag.graph_storage.graph.add_edges_from(graph_data["edges"])
 
 
706
 
707
  # Restore vectors
708
  vector_data = rag_state["vectors"]
 
710
  rag.vector_storage.embeddings = np.array(vector_data["embeddings"])
711
  if hasattr(rag.vector_storage, 'metadata'):
712
  rag.vector_storage.metadata = vector_data["metadata"]
 
 
713
 
714
  return rag
715
 
 
727
  mode: str = "hybrid",
728
  max_memory_turns: int = 10
729
  ) -> str:
730
+ """Query RAG with conversation memory"""
731
 
732
  try:
733
  # Get RAG instance
734
  rag = await self.get_or_create_rag_instance(ai_type, user_id, ai_id)
735
 
736
+ # Get conversation memory
737
  messages = await self.db.get_conversation_messages(conversation_id)
738
 
739
+ # Build context with memory
740
  context_prompt = self._build_context_prompt(question, messages[-max_memory_turns*2:])
741
 
742
  # Query LightRAG
743
  response = await rag.aquery(context_prompt, QueryParam(mode=mode))
744
 
745
+ # Save conversation
746
+ await self.db.save_conversation_message(
747
+ conversation_id, "user", question, {
748
+ "user_id": user_id,
749
+ "ai_type": ai_type,
750
+ "ai_id": ai_id
751
+ }
752
+ )
753
+ await self.db.save_conversation_message(
754
+ conversation_id, "assistant", response, {
755
+ "mode": mode,
756
+ "ai_type": ai_type,
757
+ "user_id": user_id,
758
+ "ai_id": ai_id
759
+ }
760
+ )
761
 
762
  return response
763
 
764
  except Exception as e:
765
  self.logger.error(f"Query with memory failed: {e}")
766
+ return "I apologize, but I'm experiencing technical difficulties. Please try again later."
 
 
 
 
 
 
 
 
767
 
768
  def _build_context_prompt(self, question: str, messages: List[Dict[str, Any]]) -> str:
769
  """Build context prompt with conversation memory"""
 
786
  description: str,
787
  uploaded_files: List[Dict[str, Any]]
788
  ) -> str:
789
+ """Create custom AI with uploaded files"""
790
 
791
  ai_id = str(uuid.uuid4())
792
 
 
806
  if combined_content.strip():
807
  await rag.ainsert(combined_content)
808
 
809
+ # Save to database
810
+ config = RAGConfig(
811
  ai_type="custom",
812
  user_id=user_id,
813
  ai_id=ai_id,
814
  name=ai_name,
815
+ description=description
 
816
  )
817
 
818
+ await self._save_to_database(config, rag)
 
 
 
 
 
 
 
 
 
 
 
 
 
819
 
820
  return ai_id
821
 
 
823
  self.logger.error(f"Failed to create custom AI: {e}")
824
  raise
825
 
826
+ def _estimate_tokens(self, rag_state: Dict[str, Any]) -> int:
827
+ """Estimate token count from RAG state"""
828
+ try:
829
+ content_size = len(json.dumps(rag_state))
830
+ return content_size // 4 # Rough estimate: 4 chars per token
831
+ except:
832
+ return 0
833
+
834
def get_conversation_memory_status(self, conversation_id: str) -> Dict[str, Any]:
    """Report whether in-process memory holds an entry for a conversation.

    Returns:
        ``{"has_memory": False, "message_count": 0}`` when
        ``conversation_id`` is not a key of ``self.conversation_memory``.
        Otherwise a dict with ``has_memory`` set to ``True``, the length of
        the stored entry as ``message_count``, and ``last_updated`` set to
        the ISO-formatted time of this call (not the time the entry was
        last modified).
    """
    memory = self.conversation_memory

    if conversation_id not in memory:
        return {"has_memory": False, "message_count": 0}

    return {
        "has_memory": True,
        "message_count": len(memory[conversation_id]),
        "last_updated": datetime.now().isoformat()
    }
843
+
844
def clear_conversation_memory(self, conversation_id: str):
    """Drop the in-process memory entry for one conversation, if present.

    Unknown conversation ids are ignored; other entries are left untouched.
    """
    # pop() with a default removes the key when present and is a no-op
    # otherwise.
    self.conversation_memory.pop(conversation_id, None)
848
 
849
  async def cleanup(self):
850
  """Clean up resources"""
851
  self.rag_instances.clear()
852
  self.conversation_memory.clear()
853
+ self.processing_locks.clear()
854
  await self.db.close()
855
  self.logger.info("LightRAG manager cleaned up")
856
 
857
  # Global instance
858
lightrag_manager: Optional[PersistentLightRAGManager] = None


async def initialize_lightrag_manager() -> PersistentLightRAGManager:
    """Return the module-level manager, building it on the first call.

    When no manager exists yet, this runs ``validate_environment()``, reads
    the Cloudflare, database, Redis and blob settings from environment
    variables, builds the Cloudflare worker, connects the database manager,
    creates the blob client and stores the resulting manager in the
    module-level ``lightrag_manager``. Later calls return that same object.
    """
    global lightrag_manager

    # Already built: nothing to do.
    if lightrag_manager is not None:
        return lightrag_manager

    validate_environment()

    # Read every setting up front, before any component is constructed.
    api_key = os.getenv("CLOUDFLARE_API_KEY")
    account_id = os.getenv("CLOUDFLARE_ACCOUNT_ID")
    db_url = os.getenv("DATABASE_URL")
    cache_url = os.getenv("REDIS_URL")
    blob_rw_token = os.getenv("BLOB_READ_WRITE_TOKEN")

    worker = CloudflareWorker(
        cloudflare_api_key=api_key,
        api_base_url=f"https://api.cloudflare.com/client/v4/accounts/{account_id}/ai/run/",
        llm_model_name="@cf/meta/llama-3.2-3b-instruct",
        embedding_model_name="@cf/baai/bge-m3"
    )

    # The database manager must be connected before the manager is built.
    store = DatabaseManager(db_url, cache_url)
    await store.connect()

    blobs = VercelBlobClient(blob_rw_token)

    lightrag_manager = PersistentLightRAGManager(worker, store, blobs)
    return lightrag_manager
897
+
898
def get_lightrag_manager() -> Optional[PersistentLightRAGManager]:
    """Return the module-level manager as it currently stands.

    This never creates a manager: the result is ``None`` until the
    module-level ``lightrag_manager`` has been assigned.
    """
    return lightrag_manager