import asyncio
import logging
import os
from datetime import datetime
from pathlib import Path
from typing import Dict, List, Optional

import numpy as np
import requests

from lightrag import LightRAG, QueryParam
from lightrag.utils import EmbeddingFunc
from lightrag.kg.shared_storage import initialize_pipeline_status


# Your existing cloudflare worker (keep this as is)
class CloudflareWorker:
    def __init__(
        self,
        cloudflare_api_key: str,
        api_base_url: str,
        llm_model_name: str,
        embedding_model_name: str,
        max_tokens: int = 4080,
        max_response_tokens: int = 4080,
    ):
        self.cloudflare_api_key = cloudflare_api_key
        self.api_base_url = api_base_url
        self.llm_model_name = llm_model_name
        self.embedding_model_name = embedding_model_name
        self.max_tokens = max_tokens
        self.max_response_tokens = max_response_tokens

    async def _send_request(self, model_name: str, input_: dict, debug_log: str = ""):
        headers = {"Authorization": f"Bearer {self.cloudflare_api_key}"}

        def _post() -> dict:
            return requests.post(
                f"{self.api_base_url}{model_name}", headers=headers, json=input_
            ).json()

        try:
            # requests is blocking, so run it in a worker thread to avoid
            # stalling the event loop while LightRAG runs parallel inserts
            response_raw = await asyncio.to_thread(_post)
            result = response_raw.get("result", {})
            if "data" in result:  # Embedding response
                return np.array(result["data"])
            if "response" in result:  # LLM response
                return result["response"]
            raise ValueError("Unexpected Cloudflare response format")
        except Exception as e:
            logging.error(f"Cloudflare API error: {e}")
            return None

    async def query(self, prompt: str, system_prompt: str = "", **kwargs) -> str:
        # Remove LightRAG-specific kwargs before building the request payload
        kwargs.pop("hashing_kv", None)
        messages = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": prompt},
        ]
        input_ = {
            "messages": messages,
            "max_tokens": self.max_tokens,
            "response_token_limit": self.max_response_tokens,
        }
        return await self._send_request(self.llm_model_name, input_)

    async def embedding_chunk(self, texts: List[str]) -> np.ndarray:
        input_ = {
            "text": texts,
            "max_tokens": self.max_tokens,
            "response_token_limit": self.max_response_tokens,
        }
        return await self._send_request(self.embedding_model_name, input_)
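
# A minimal construction sketch for the worker above. The env var names and
# Workers AI model IDs below are illustrative assumptions, not part of this
# module; substitute your own account ID, key, and models.
def make_example_worker() -> CloudflareWorker:
    return CloudflareWorker(
        cloudflare_api_key=os.getenv("CLOUDFLARE_API_KEY", ""),
        # Workers AI REST prefix; _send_request appends the model name to it
        api_base_url=os.getenv(
            "CLOUDFLARE_API_BASE_URL",
            "https://api.cloudflare.com/client/v4/accounts/<ACCOUNT_ID>/ai/run/",
        ),
        llm_model_name="@cf/meta/llama-3.1-8b-instruct",  # assumed model ID
        embedding_model_name="@cf/baai/bge-m3",  # assumed 1024-dim embedding model
    )
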
class LightRAGManager:
    """Enhanced LightRAG Manager with conversation memory and multi-AI support"""

    def __init__(
        self,
        cloudflare_worker: CloudflareWorker,
        base_working_dir: str = "./lightrag_storage",
    ):
        self.cloudflare_worker = cloudflare_worker
        self.base_working_dir = Path(base_working_dir)
        self.base_working_dir.mkdir(exist_ok=True)

        # Store multiple RAG instances
        self.rag_instances: Dict[str, LightRAG] = {}

        # Conversation memory store
        self.conversation_memory: Dict[str, List[Dict]] = {}

        # Initialize default fire safety RAG
        self.fire_safety_rag: Optional[LightRAG] = None

        # Setup logging
        self.logger = logging.getLogger(__name__)

    async def initialize_fire_safety_rag(self) -> LightRAG:
        """Initialize the fire safety RAG instance"""
        if self.fire_safety_rag is not None:
            return self.fire_safety_rag

        working_dir = self.base_working_dir / "fire_safety"
        working_dir.mkdir(exist_ok=True)

        # Check if a knowledge base already exists on disk
        knowledge_files = list(working_dir.glob("*.json"))
        if not knowledge_files:
            self.logger.warning("No existing fire safety knowledge base found")

        self.fire_safety_rag = await self._create_rag_instance(str(working_dir))
        return self.fire_safety_rag

    async def create_custom_rag(
        self, user_id: str, ai_id: str, knowledge_texts: List[str]
    ) -> LightRAG:
        """Create a custom RAG instance for a user's AI"""
        instance_key = f"{user_id}_{ai_id}"

        if instance_key in self.rag_instances:
            return self.rag_instances[instance_key]

        # Create working directory for this custom AI
        working_dir = self.base_working_dir / "custom" / user_id / ai_id
        working_dir.mkdir(parents=True, exist_ok=True)

        # Create RAG instance
        rag = await self._create_rag_instance(str(working_dir))

        # Insert knowledge if provided
        if knowledge_texts:
            await self._insert_knowledge_batch(rag, knowledge_texts)

        self.rag_instances[instance_key] = rag
        return rag

    async def _create_rag_instance(self, working_dir: str) -> LightRAG:
        """Create a new LightRAG instance"""
        rag = LightRAG(
            working_dir=working_dir,
            max_parallel_insert=2,
            llm_model_func=self.cloudflare_worker.query,
            llm_model_name=self.cloudflare_worker.llm_model_name,
            llm_model_max_token_size=4080,
            embedding_func=EmbeddingFunc(
                embedding_dim=int(os.getenv("EMBEDDING_DIM", "1024")),
                max_token_size=int(os.getenv("MAX_EMBED_TOKENS", "2048")),
                func=self.cloudflare_worker.embedding_chunk,
            ),
            # Use NetworkX for simplicity (no external dependencies)
            graph_storage="NetworkXStorage",
            # Use NanoVectorDB for lightweight vector storage
            vector_storage="NanoVectorDBStorage",
        )

        # Initialize storage
        await rag.initialize_storages()
        await initialize_pipeline_status()

        return rag

    async def _insert_knowledge_batch(self, rag: LightRAG, texts: List[str]):
        """Insert knowledge texts into RAG instance"""
        for text in texts:
            if text.strip():
                await rag.ainsert(text)

    async def query_with_memory(
        self,
        rag: LightRAG,
        question: str,
        conversation_id: str,
        mode: str = "hybrid",
        max_memory_turns: int = 10,
    ) -> str:
        """Query RAG with conversation memory"""
        # Get conversation memory
        memory = self.conversation_memory.get(conversation_id, [])

        # Build context with recent conversation history
        context_prompt = self._build_context_prompt(question, memory, max_memory_turns)

        # Query LightRAG
        response = await rag.aquery(context_prompt, param=QueryParam(mode=mode))

        # Update conversation memory
        self._update_conversation_memory(conversation_id, question, response)

        return response
    def _build_context_prompt(
        self, question: str, memory: List[Dict], max_turns: int
    ) -> str:
        """Build context prompt with conversation history"""
        if not memory:
            return question

        # Get recent conversation turns (two messages per turn; slicing
        # already handles histories shorter than the window)
        recent_memory = memory[-max_turns * 2:]

        # Build conversation context, truncating long messages
        context = "Previous conversation:\n"
        for msg in recent_memory:
            role = msg["role"]
            content = (
                msg["content"][:200] + "..."
                if len(msg["content"]) > 200
                else msg["content"]
            )
            context += f"{role.title()}: {content}\n"

        context += f"\nCurrent question: {question}"
        return context

    def _update_conversation_memory(self, conversation_id: str, question: str, response: str):
        """Update conversation memory"""
        if conversation_id not in self.conversation_memory:
            self.conversation_memory[conversation_id] = []

        memory = self.conversation_memory[conversation_id]

        # Add user question and assistant response
        memory.append({
            "role": "user",
            "content": question,
            "timestamp": datetime.now().isoformat(),
        })
        memory.append({
            "role": "assistant",
            "content": response,
            "timestamp": datetime.now().isoformat(),
        })

        # Keep only the last 50 messages to bound memory usage
        if len(memory) > 50:
            self.conversation_memory[conversation_id] = memory[-50:]

    def clear_conversation_memory(self, conversation_id: str):
        """Clear conversation memory for a specific conversation"""
        if conversation_id in self.conversation_memory:
            del self.conversation_memory[conversation_id]

    async def get_rag_instance(
        self,
        ai_type: str,
        user_id: Optional[str] = None,
        ai_id: Optional[str] = None,
    ) -> LightRAG:
        """Get appropriate RAG instance based on AI type"""
        if ai_type == "fire-safety":
            return await self.initialize_fire_safety_rag()

        elif ai_type == "custom" and user_id and ai_id:
            # For custom AI, load the existing knowledge base from disk
            instance_key = f"{user_id}_{ai_id}"
            if instance_key not in self.rag_instances:
                # Create instance but don't add knowledge (it should already exist)
                working_dir = self.base_working_dir / "custom" / user_id / ai_id
                if working_dir.exists():
                    rag = await self._create_rag_instance(str(working_dir))
                    self.rag_instances[instance_key] = rag
                else:
                    raise ValueError(f"Custom AI {ai_id} knowledge base not found")
            return self.rag_instances[instance_key]

        else:
            raise ValueError(f"Unknown AI type: {ai_type}")

    async def insert_knowledge_for_custom_ai(
        self, user_id: str, ai_id: str, knowledge_texts: List[str]
    ):
        """Insert knowledge for custom AI"""
        # Create (or fetch) the instance without seeding it, then insert once;
        # passing the texts to create_custom_rag as well would insert them twice
        rag = await self.create_custom_rag(user_id, ai_id, [])
        await self._insert_knowledge_batch(rag, knowledge_texts)

    async def cleanup(self):
        """Cleanup resources"""
        for rag in self.rag_instances.values():
            if hasattr(rag, "finalize_storages"):
                await rag.finalize_storages()

        if self.fire_safety_rag and hasattr(self.fire_safety_rag, "finalize_storages"):
            await self.fire_safety_rag.finalize_storages()


# Global instance
lightrag_manager: Optional[LightRAGManager] = None


async def initialize_lightrag_manager(cloudflare_worker: CloudflareWorker) -> LightRAGManager:
    """Initialize the global LightRAG manager"""
    global lightrag_manager
    if lightrag_manager is None:
        lightrag_manager = LightRAGManager(cloudflare_worker)
    return lightrag_manager


def get_lightrag_manager() -> LightRAGManager:
    """Get the global LightRAG manager"""
    if lightrag_manager is None:
        raise RuntimeError("LightRAG manager not initialized")
    return lightrag_manager
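
# A minimal end-to-end sketch of the intended flow, reusing the illustrative
# make_example_worker() helper defined above. The question and conversation ID
# are placeholders; this assumes valid Cloudflare credentials in the env.
async def _example_main():
    worker = make_example_worker()
    manager = await initialize_lightrag_manager(worker)
    try:
        # Fetch the built-in fire safety RAG and ask a question with memory
        rag = await manager.get_rag_instance("fire-safety")
        answer = await manager.query_with_memory(
            rag,
            "What are the main classes of fire extinguishers?",
            conversation_id="demo",
        )
        print(answer)
    finally:
        await manager.cleanup()


if __name__ == "__main__":
    asyncio.run(_example_main())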