# safetyAI/lightrag_manager.py
import asyncio
import os
import json
import logging
from typing import Dict, List, Optional, Any
from pathlib import Path
import hashlib
from datetime import datetime

import numpy as np
import requests

from lightrag import LightRAG, QueryParam
from lightrag.utils import EmbeddingFunc
from lightrag.kg.shared_storage import initialize_pipeline_status

# Your existing Cloudflare worker (keep this as is)
class CloudflareWorker:
    def __init__(
        self,
        cloudflare_api_key: str,
        api_base_url: str,
        llm_model_name: str,
        embedding_model_name: str,
        max_tokens: int = 4080,
        max_response_tokens: int = 4080,
    ):
        self.cloudflare_api_key = cloudflare_api_key
        self.api_base_url = api_base_url
        self.llm_model_name = llm_model_name
        self.embedding_model_name = embedding_model_name
        self.max_tokens = max_tokens
        self.max_response_tokens = max_response_tokens

    async def _send_request(self, model_name: str, input_: dict, debug_log: str = ""):
        headers = {"Authorization": f"Bearer {self.cloudflare_api_key}"}
        try:
            # requests is blocking; run it in a worker thread so the event
            # loop (and LightRAG's parallel pipeline) is not stalled.
            response = await asyncio.to_thread(
                requests.post,
                f"{self.api_base_url}{model_name}",
                headers=headers,
                json=input_,
            )
            response_raw = response.json()
            result = response_raw.get("result", {})
            if "data" in result:  # Embedding case
                return np.array(result["data"])
            if "response" in result:  # LLM response
                return result["response"]
            raise ValueError("Unexpected Cloudflare response format")
        except Exception as e:
            logging.error(f"Cloudflare API error: {e}")
            # Propagate instead of returning None: a None embedding/response
            # would only fail later inside LightRAG with a harder-to-trace error.
            raise

    async def query(self, prompt: str, system_prompt: str = "", **kwargs) -> str:
        # Remove LightRAG-specific kwargs
        kwargs.pop("hashing_kv", None)
        message = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": prompt},
        ]
        input_ = {
            "messages": message,
            "max_tokens": self.max_tokens,
            "response_token_limit": self.max_response_tokens,
        }
        return await self._send_request(self.llm_model_name, input_)

    async def embedding_chunk(self, texts: List[str]) -> np.ndarray:
        input_ = {
            "text": texts,
            "max_tokens": self.max_tokens,
            "response_token_limit": self.max_response_tokens,
        }
        return await self._send_request(self.embedding_model_name, input_)
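
# A minimal wiring sketch. The endpoint format is Cloudflare's Workers AI REST
# API; the model names are illustrative choices (not pinned by this file) --
# bge-large-en-v1.5 matches the 1024-dim EMBEDDING_DIM default used below.
#
#   worker = CloudflareWorker(
#       cloudflare_api_key=os.environ["CLOUDFLARE_API_KEY"],
#       api_base_url="https://api.cloudflare.com/client/v4/accounts/<ACCOUNT_ID>/ai/run/",
#       llm_model_name="@cf/meta/llama-3.1-8b-instruct",
#       embedding_model_name="@cf/baai/bge-large-en-v1.5",
#   )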

class LightRAGManager:
    """Enhanced LightRAG Manager with conversation memory and multi-AI support"""

    def __init__(self, cloudflare_worker: CloudflareWorker, base_working_dir: str = "./lightrag_storage"):
        self.cloudflare_worker = cloudflare_worker
        self.base_working_dir = Path(base_working_dir)
        self.base_working_dir.mkdir(exist_ok=True)
        # Store multiple RAG instances
        self.rag_instances: Dict[str, LightRAG] = {}
        # Conversation memory store
        self.conversation_memory: Dict[str, List[Dict]] = {}
        # Initialize default fire safety RAG
        self.fire_safety_rag = None
        # Setup logging
        self.logger = logging.getLogger(__name__)

    async def initialize_fire_safety_rag(self) -> LightRAG:
        """Initialize the fire safety RAG instance"""
        if self.fire_safety_rag is not None:
            return self.fire_safety_rag

        working_dir = self.base_working_dir / "fire_safety"
        working_dir.mkdir(exist_ok=True)

        # Check if knowledge base exists
        knowledge_files = list(working_dir.glob("*.json"))
        if not knowledge_files:
            self.logger.warning("No existing fire safety knowledge base found")

        self.fire_safety_rag = await self._create_rag_instance(str(working_dir))
        return self.fire_safety_rag

    async def create_custom_rag(self, user_id: str, ai_id: str, knowledge_texts: List[str]) -> LightRAG:
        """Create a custom RAG instance for a user's AI"""
        instance_key = f"{user_id}_{ai_id}"
        if instance_key in self.rag_instances:
            return self.rag_instances[instance_key]

        # Create working directory for this custom AI
        working_dir = self.base_working_dir / "custom" / user_id / ai_id
        working_dir.mkdir(parents=True, exist_ok=True)

        # Create RAG instance
        rag = await self._create_rag_instance(str(working_dir))

        # Insert knowledge if provided
        if knowledge_texts:
            await self._insert_knowledge_batch(rag, knowledge_texts)

        self.rag_instances[instance_key] = rag
        return rag
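
    # Usage sketch (the user/AI identifiers and text are illustrative, not
    # values from this repo):
    #   rag = await manager.create_custom_rag(
    #       "user-42", "chem-safety", ["Acetone is highly flammable..."]
    #   )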

    async def _create_rag_instance(self, working_dir: str) -> LightRAG:
        """Create a new LightRAG instance"""
        rag = LightRAG(
            working_dir=working_dir,
            max_parallel_insert=2,
            llm_model_func=self.cloudflare_worker.query,
            llm_model_name=self.cloudflare_worker.llm_model_name,
            llm_model_max_token_size=4080,
            embedding_func=EmbeddingFunc(
                embedding_dim=int(os.getenv("EMBEDDING_DIM", "1024")),
                max_token_size=int(os.getenv("MAX_EMBED_TOKENS", "2048")),
                func=self.cloudflare_worker.embedding_chunk,
            ),
            # Use NetworkX for simplicity (no external dependencies)
            graph_storage="NetworkXStorage",
            # Use NanoVectorDB for lightweight vector storage
            vector_storage="NanoVectorDBStorage",
        )
        # Initialize storage
        await rag.initialize_storages()
        await initialize_pipeline_status()
        return rag

    async def _insert_knowledge_batch(self, rag: LightRAG, texts: List[str]):
        """Insert knowledge texts into RAG instance"""
        for text in texts:
            if text.strip():
                await rag.ainsert(text)

    async def query_with_memory(
        self,
        rag: LightRAG,
        question: str,
        conversation_id: str,
        mode: str = "hybrid",
        max_memory_turns: int = 10
    ) -> str:
        """Query RAG with conversation memory"""
        # Get conversation memory
        memory = self.conversation_memory.get(conversation_id, [])

        # Build context with recent conversation history
        context_prompt = self._build_context_prompt(question, memory, max_memory_turns)

        # Query LightRAG
        response = await rag.aquery(
            context_prompt,
            param=QueryParam(mode=mode)
        )

        # Update conversation memory
        self._update_conversation_memory(conversation_id, question, response)
        return response

    def _build_context_prompt(self, question: str, memory: List[Dict], max_turns: int) -> str:
        """Build context prompt with conversation history"""
        if not memory:
            return question

        # Get recent conversation turns (each turn is a user/assistant pair)
        recent_memory = memory[-max_turns * 2:] if len(memory) > max_turns * 2 else memory

        # Build conversation context
        context = "Previous conversation:\n"
        for msg in recent_memory:
            role = msg['role']
            content = msg['content'][:200] + "..." if len(msg['content']) > 200 else msg['content']
            context += f"{role.title()}: {content}\n"

        context += f"\nCurrent question: {question}"
        return context
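
    # For illustration, with one prior turn the prompt passed to LightRAG reads:
    #
    #   Previous conversation:
    #   User: Where should extinguishers be mounted?
    #   Assistant: Near exits, along escape routes...
    #
    #   Current question: How often must they be inspected?
    #
    # (The sample content is invented; messages over 200 chars are truncated.)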

    def _update_conversation_memory(self, conversation_id: str, question: str, response: str):
        """Update conversation memory"""
        if conversation_id not in self.conversation_memory:
            self.conversation_memory[conversation_id] = []

        memory = self.conversation_memory[conversation_id]

        # Add user question and assistant response
        memory.append({
            'role': 'user',
            'content': question,
            'timestamp': datetime.now().isoformat()
        })
        memory.append({
            'role': 'assistant',
            'content': response,
            'timestamp': datetime.now().isoformat()
        })

        # Keep only last 50 messages to manage memory
        if len(memory) > 50:
            self.conversation_memory[conversation_id] = memory[-50:]

    def clear_conversation_memory(self, conversation_id: str):
        """Clear conversation memory for a specific conversation"""
        if conversation_id in self.conversation_memory:
            del self.conversation_memory[conversation_id]

    async def get_rag_instance(self, ai_type: str, user_id: Optional[str] = None, ai_id: Optional[str] = None) -> LightRAG:
        """Get appropriate RAG instance based on AI type"""
        if ai_type == "fire-safety":
            return await self.initialize_fire_safety_rag()
        elif ai_type == "custom" and user_id and ai_id:
            # For custom AI, we need to load existing knowledge
            instance_key = f"{user_id}_{ai_id}"
            if instance_key not in self.rag_instances:
                # Create instance but don't add knowledge (it should already exist on disk)
                working_dir = self.base_working_dir / "custom" / user_id / ai_id
                if working_dir.exists():
                    rag = await self._create_rag_instance(str(working_dir))
                    self.rag_instances[instance_key] = rag
                else:
                    raise ValueError(f"Custom AI {ai_id} knowledge base not found")
            return self.rag_instances[instance_key]
        else:
            raise ValueError(f"Unknown AI type: {ai_type}")

    async def insert_knowledge_for_custom_ai(self, user_id: str, ai_id: str, knowledge_texts: List[str]):
        """Insert knowledge for custom AI"""
        # Fetch (or lazily create) the instance with no seed texts, then insert
        # exactly once here; passing knowledge_texts to create_custom_rag as
        # well would ingest every document twice on first creation.
        rag = await self.create_custom_rag(user_id, ai_id, [])
        await self._insert_knowledge_batch(rag, knowledge_texts)

    async def cleanup(self):
        """Cleanup resources"""
        for rag in self.rag_instances.values():
            if hasattr(rag, 'finalize_storages'):
                await rag.finalize_storages()

        if self.fire_safety_rag and hasattr(self.fire_safety_rag, 'finalize_storages'):
            await self.fire_safety_rag.finalize_storages()
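
# If this manager is embedded in an ASGI app (an assumption about how the
# module is used, not confirmed by this file), wire cleanup() into the
# shutdown path, e.g. with FastAPI:
#
#   @app.on_event("shutdown")
#   async def _shutdown():
#       await get_lightrag_manager().cleanup()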

# Global instance
lightrag_manager: Optional[LightRAGManager] = None


async def initialize_lightrag_manager(cloudflare_worker: CloudflareWorker) -> LightRAGManager:
    """Initialize the global LightRAG manager"""
    global lightrag_manager
    if lightrag_manager is None:
        lightrag_manager = LightRAGManager(cloudflare_worker)
    return lightrag_manager


def get_lightrag_manager() -> LightRAGManager:
    """Get the global LightRAG manager"""
    if lightrag_manager is None:
        raise RuntimeError("LightRAG manager not initialized")
    return lightrag_manager
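

# End-to-end smoke test: a minimal sketch, not part of the Space's runtime
# path. The env-var names and the questions are illustrative assumptions.
async def _demo():
    worker = CloudflareWorker(
        cloudflare_api_key=os.environ["CLOUDFLARE_API_KEY"],
        api_base_url=os.environ["CLOUDFLARE_API_BASE_URL"],
        llm_model_name=os.environ["LLM_MODEL"],
        embedding_model_name=os.environ["EMBEDDING_MODEL"],
    )
    manager = await initialize_lightrag_manager(worker)
    rag = await manager.get_rag_instance("fire-safety")
    # Ask twice in the same conversation so the second query carries memory
    print(await manager.query_with_memory(rag, "What is a Class B fire?", "demo"))
    print(await manager.query_with_memory(rag, "Which extinguisher handles it?", "demo"))
    await manager.cleanup()


if __name__ == "__main__":
    asyncio.run(_demo())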