import asyncio
import os
import json
import logging
import hashlib
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, List, Optional

import numpy as np
import requests

from lightrag import LightRAG, QueryParam
from lightrag.utils import EmbeddingFunc
from lightrag.kg.shared_storage import initialize_pipeline_status


class CloudflareWorker:
    """Thin async client for the Cloudflare Workers AI REST endpoints."""

    def __init__(
        self,
        cloudflare_api_key: str,
        api_base_url: str,
        llm_model_name: str,
        embedding_model_name: str,
        max_tokens: int = 4080,
        max_response_tokens: int = 4080,
    ):
        self.cloudflare_api_key = cloudflare_api_key
        self.api_base_url = api_base_url
        self.llm_model_name = llm_model_name
        self.embedding_model_name = embedding_model_name
        self.max_tokens = max_tokens
        self.max_response_tokens = max_response_tokens

    async def _send_request(self, model_name: str, input_: dict, debug_log: str = ""):
        """POST a payload to a Workers AI model and normalize the result.

        Embedding responses carry a "data" field (returned as an ndarray);
        chat responses carry a "response" field (returned as a string).
        """
        headers = {"Authorization": f"Bearer {self.cloudflare_api_key}"}

        try:
            # Run the blocking HTTP call in a worker thread so the event loop stays responsive.
            response_raw = await asyncio.to_thread(
                requests.post,
                f"{self.api_base_url}{model_name}",
                headers=headers,
                json=input_,
            )
            result = response_raw.json().get("result", {})

            if "data" in result:
                return np.array(result["data"])

            if "response" in result:
                return result["response"]

            raise ValueError("Unexpected Cloudflare response format")

        except Exception as e:
            logging.error(f"Cloudflare API error: {e}")
            return None

    async def query(self, prompt: str, system_prompt: str = "", **kwargs) -> str:
        """LLM completion function with the signature LightRAG expects for llm_model_func."""
        # LightRAG passes its key-value store through kwargs; it is not part of the API payload.
        kwargs.pop("hashing_kv", None)

        messages = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": prompt},
        ]

        input_ = {
            "messages": messages,
            "max_tokens": self.max_tokens,
            "response_token_limit": self.max_response_tokens,
        }

        return await self._send_request(self.llm_model_name, input_)

    async def embedding_chunk(self, texts: List[str]) -> np.ndarray:
        """Embedding function with the signature LightRAG expects for EmbeddingFunc."""
        input_ = {
            "text": texts,
            "max_tokens": self.max_tokens,
            "response_token_limit": self.max_response_tokens,
        }

        return await self._send_request(self.embedding_model_name, input_)


class LightRAGManager:
    """Enhanced LightRAG Manager with conversation memory and multi-AI support"""

    def __init__(self, cloudflare_worker: CloudflareWorker, base_working_dir: str = "./lightrag_storage"):
        self.cloudflare_worker = cloudflare_worker
        self.base_working_dir = Path(base_working_dir)
        self.base_working_dir.mkdir(exist_ok=True)

        # Cached RAG instances for custom AIs, keyed by "{user_id}_{ai_id}"
        self.rag_instances: Dict[str, LightRAG] = {}

        # Per-conversation message history, keyed by conversation id
        self.conversation_memory: Dict[str, List[Dict]] = {}

        # Shared fire safety knowledge base, lazily initialized
        self.fire_safety_rag: Optional[LightRAG] = None

        self.logger = logging.getLogger(__name__)

    async def initialize_fire_safety_rag(self) -> LightRAG:
        """Initialize the fire safety RAG instance"""
        if self.fire_safety_rag is not None:
            return self.fire_safety_rag

        working_dir = self.base_working_dir / "fire_safety"
        working_dir.mkdir(exist_ok=True)

        knowledge_files = list(working_dir.glob("*.json"))
        if not knowledge_files:
            self.logger.warning("No existing fire safety knowledge base found")

        self.fire_safety_rag = await self._create_rag_instance(str(working_dir))
        return self.fire_safety_rag

    async def create_custom_rag(self, user_id: str, ai_id: str, knowledge_texts: List[str]) -> LightRAG:
        """Create a custom RAG instance for a user's AI"""
        instance_key = f"{user_id}_{ai_id}"

        if instance_key in self.rag_instances:
            return self.rag_instances[instance_key]

        working_dir = self.base_working_dir / "custom" / user_id / ai_id
        working_dir.mkdir(parents=True, exist_ok=True)

        rag = await self._create_rag_instance(str(working_dir))

        if knowledge_texts:
            await self._insert_knowledge_batch(rag, knowledge_texts)

        self.rag_instances[instance_key] = rag
        return rag

    async def _create_rag_instance(self, working_dir: str) -> LightRAG:
        """Create a new LightRAG instance"""
        rag = LightRAG(
            working_dir=working_dir,
            max_parallel_insert=2,
            llm_model_func=self.cloudflare_worker.query,
            llm_model_name=self.cloudflare_worker.llm_model_name,
            llm_model_max_token_size=4080,
            embedding_func=EmbeddingFunc(
                embedding_dim=int(os.getenv("EMBEDDING_DIM", "1024")),
                max_token_size=int(os.getenv("MAX_EMBED_TOKENS", "2048")),
                func=self.cloudflare_worker.embedding_chunk,
            ),
            # File-based storage backends; no external database services required
            graph_storage="NetworkXStorage",
            vector_storage="NanoVectorDBStorage",
        )

        await rag.initialize_storages()
        await initialize_pipeline_status()

        return rag

    async def _insert_knowledge_batch(self, rag: LightRAG, texts: List[str]):
        """Insert knowledge texts into RAG instance"""
        for text in texts:
            if text.strip():
                await rag.ainsert(text)

    async def query_with_memory(
        self,
        rag: LightRAG,
        question: str,
        conversation_id: str,
        mode: str = "hybrid",
        max_memory_turns: int = 10,
    ) -> str:
        """Query RAG with conversation memory"""
        # Pull any prior turns for this conversation
        memory = self.conversation_memory.get(conversation_id, [])

        # Prepend recent history so follow-up questions keep their context
        context_prompt = self._build_context_prompt(question, memory, max_memory_turns)

        response = await rag.aquery(
            context_prompt,
            param=QueryParam(mode=mode),
        )

        # Record the new question/answer pair
        self._update_conversation_memory(conversation_id, question, response)

        return response

    def _build_context_prompt(self, question: str, memory: List[Dict], max_turns: int) -> str:
        """Build context prompt with conversation history"""
        if not memory:
            return question

        # Keep at most max_turns question/answer pairs (two messages per turn)
        recent_memory = memory[-max_turns * 2:] if len(memory) > max_turns * 2 else memory

        context = "Previous conversation:\n"
        for msg in recent_memory:
            role = msg['role']
            # Truncate long messages so the prompt stays compact
            content = msg['content'][:200] + "..." if len(msg['content']) > 200 else msg['content']
            context += f"{role.title()}: {content}\n"

        context += f"\nCurrent question: {question}"
        return context

    def _update_conversation_memory(self, conversation_id: str, question: str, response: str):
        """Update conversation memory"""
        if conversation_id not in self.conversation_memory:
            self.conversation_memory[conversation_id] = []

        memory = self.conversation_memory[conversation_id]

        memory.append({
            'role': 'user',
            'content': question,
            'timestamp': datetime.now().isoformat()
        })

        memory.append({
            'role': 'assistant',
            'content': response,
            'timestamp': datetime.now().isoformat()
        })

        # Cap the history at the 50 most recent messages
        if len(memory) > 50:
            self.conversation_memory[conversation_id] = memory[-50:]

    def clear_conversation_memory(self, conversation_id: str):
        """Clear conversation memory for a specific conversation"""
        if conversation_id in self.conversation_memory:
            del self.conversation_memory[conversation_id]

    async def get_rag_instance(self, ai_type: str, user_id: Optional[str] = None, ai_id: Optional[str] = None) -> LightRAG:
        """Get appropriate RAG instance based on AI type"""
        if ai_type == "fire-safety":
            return await self.initialize_fire_safety_rag()
        elif ai_type == "custom" and user_id and ai_id:
            instance_key = f"{user_id}_{ai_id}"
            if instance_key not in self.rag_instances:
                # Rehydrate the instance from disk if its working directory exists
                working_dir = self.base_working_dir / "custom" / user_id / ai_id
                if working_dir.exists():
                    rag = await self._create_rag_instance(str(working_dir))
                    self.rag_instances[instance_key] = rag
                else:
                    raise ValueError(f"Custom AI {ai_id} knowledge base not found")
            return self.rag_instances[instance_key]
        else:
            raise ValueError(f"Unknown AI type: {ai_type}")

    async def insert_knowledge_for_custom_ai(self, user_id: str, ai_id: str, knowledge_texts: List[str]):
        """Insert knowledge for custom AI"""
        # Fetch or create the instance without seeding it, then insert once;
        # passing the texts to create_custom_rag as well would ingest them twice
        # for a newly created instance.
        rag = await self.create_custom_rag(user_id, ai_id, [])
        await self._insert_knowledge_batch(rag, knowledge_texts)

    async def cleanup(self):
        """Cleanup resources"""
        for rag in self.rag_instances.values():
            if hasattr(rag, 'finalize_storages'):
                await rag.finalize_storages()

        if self.fire_safety_rag and hasattr(self.fire_safety_rag, 'finalize_storages'):
            await self.fire_safety_rag.finalize_storages()


# Global manager instance, created once via initialize_lightrag_manager()
lightrag_manager: Optional[LightRAGManager] = None


async def initialize_lightrag_manager(cloudflare_worker: CloudflareWorker) -> LightRAGManager:
    """Initialize the global LightRAG manager"""
    global lightrag_manager
    if lightrag_manager is None:
        lightrag_manager = LightRAGManager(cloudflare_worker)
    return lightrag_manager


def get_lightrag_manager() -> LightRAGManager:
    """Get the global LightRAG manager"""
    if lightrag_manager is None:
        raise RuntimeError("LightRAG manager not initialized")
    return lightrag_manager
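

# Hypothetical usage sketch (illustration only): wires a CloudflareWorker into the
# manager and runs a single query. The environment variable names, the account-id
# placeholder in the URL, the model names, and the sample question are assumptions,
# not values defined elsewhere in this module.
async def _example_usage():
    worker = CloudflareWorker(
        cloudflare_api_key=os.getenv("CLOUDFLARE_API_KEY", ""),
        api_base_url=os.getenv(
            "CLOUDFLARE_API_BASE_URL",
            "https://api.cloudflare.com/client/v4/accounts/<account_id>/ai/run/",
        ),
        llm_model_name=os.getenv("LLM_MODEL", "@cf/meta/llama-3.1-8b-instruct"),
        embedding_model_name=os.getenv("EMBEDDING_MODEL", "@cf/baai/bge-m3"),
    )
    manager = await initialize_lightrag_manager(worker)
    rag = await manager.initialize_fire_safety_rag()
    answer = await manager.query_with_memory(
        rag, "What does a fire alarm inspection involve?", conversation_id="demo"
    )
    print(answer)
    await manager.cleanup()


if __name__ == "__main__":
    asyncio.run(_example_usage())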