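"""Fire Safety AI Assistant: a Gradio front end over a LightRAG knowledge base.

LLM and embedding calls go through Cloudflare Workers AI; if LightRAG cannot be
imported or initialised, a simple keyword-matching store is used instead.
"""
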
import gradio as gr
import asyncio
import os
import zipfile
import requests
from pathlib import Path
import numpy as np
from typing import List

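# LightRAG's import layout differs across versions, so try several known paths
# before giving up and switching to the fallback keyword store.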
try:
    from lightrag import LightRAG, QueryParam
    from lightrag.utils import EmbeddingFunc
    LIGHTRAG_AVAILABLE = True
except ImportError:
    try:
        from lightrag.lightrag import LightRAG
        from lightrag.query import QueryParam
        from lightrag.utils import EmbeddingFunc
        LIGHTRAG_AVAILABLE = True
    except ImportError:
        try:
            from lightrag.core import LightRAG
            from lightrag.core import QueryParam
            from lightrag.utils import EmbeddingFunc
            LIGHTRAG_AVAILABLE = True
        except ImportError:
            print("❌ LightRAG import failed - using fallback mode")
            LIGHTRAG_AVAILABLE = False

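# Thin async client for the Cloudflare Workers AI REST API: query() sends chat
# completions to the LLM model, embedding_chunk() requests text embeddings.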
class CloudflareWorker:
    def __init__(self, cloudflare_api_key: str, api_base_url: str, llm_model_name: str, embedding_model_name: str):
        self.cloudflare_api_key = cloudflare_api_key
        self.api_base_url = api_base_url
        self.llm_model_name = llm_model_name
        self.embedding_model_name = embedding_model_name
        self.max_tokens = 4080
        self.max_response_tokens = 4080

    async def _send_request(self, model_name: str, input_: dict, debug_log: str = ""):
        headers = {"Authorization": f"Bearer {self.cloudflare_api_key}"}

        try:
            response_raw = requests.post(
                f"{self.api_base_url}{model_name}",
                headers=headers,
                json=input_,
                timeout=30
            ).json()

            result = response_raw.get("result", {})

            if "data" in result:
                return np.array(result["data"]) if LIGHTRAG_AVAILABLE else result["data"]
            if "response" in result:
                return result["response"]

            raise ValueError(f"Unexpected response format: {response_raw}")

        except Exception as e:
            print(f"Cloudflare API Error: {e}")
            return None

    async def query(self, prompt: str, system_prompt: str = '', **kwargs) -> str:
        kwargs.pop("hashing_kv", None)

        message = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": prompt}
        ]

        input_ = {
            "messages": message,
            "max_tokens": self.max_tokens,
            "response_token_limit": self.max_response_tokens,
        }

        result = await self._send_request(self.llm_model_name, input_)
        return result if result is not None else "Error: Failed to get response"

    async def embedding_chunk(self, texts: List[str]):
        input_ = {
            "text": texts,
            "max_tokens": self.max_tokens,
            "response_token_limit": self.max_response_tokens,
        }

        result = await self._send_request(self.embedding_model_name, input_)

        if result is None:
            if LIGHTRAG_AVAILABLE:
                return np.random.rand(len(texts), 1024).astype(np.float32)
            else:
                return [[0.0] * 1024 for _ in texts]

        return result

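# Fallback knowledge store: reads the pre-built LightRAG JSON files directly
# and answers retrieval requests with simple keyword matching.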
class SimpleKnowledgeStore:
    def __init__(self, data_dir: str):
        self.data_dir = data_dir
        self.chunks = []
        self.entities = []
        self.load_data()

    def load_data(self):
        try:
            import json
            chunks_file = Path(self.data_dir) / "kv_store_text_chunks.json"
            if chunks_file.exists():
                with open(chunks_file, 'r', encoding='utf-8') as f:
                    data = json.load(f)
                    self.chunks = list(data.values()) if data else []

            entities_file = Path(self.data_dir) / "vdb_entities.json"
            if entities_file.exists():
                with open(entities_file, 'r', encoding='utf-8') as f:
                    entities_data = json.load(f)
                    if isinstance(entities_data, dict) and 'data' in entities_data:
                        self.entities = entities_data['data']
                    elif isinstance(entities_data, list):
                        self.entities = entities_data
                    else:
                        self.entities = []

            print(f"✅ Loaded {len(self.chunks)} chunks and {len(self.entities)} entities")

        except Exception as e:
            print(f"⚠️ Error loading data: {e}")
            self.chunks = []
            self.entities = []

    def search(self, query: str, limit: int = 5) -> List[str]:
        query_lower = query.lower()
        results = []

        for chunk in self.chunks:
            if isinstance(chunk, dict) and 'content' in chunk:
                content = chunk['content']
                if any(word in content.lower() for word in query_lower.split()):
                    results.append(content)

        for entity in self.entities:
            if isinstance(entity, dict):
                entity_text = str(entity)
                if any(word in entity_text.lower() for word in query_lower.split()):
                    results.append(entity_text)

        return results[:limit]

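# Cloudflare Workers AI configuration. The API key is read from the environment
# (no hard-coded default); the account id is part of the run endpoint URL.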
CLOUDFLARE_API_KEY = os.getenv('CLOUDFLARE_API_KEY', '')
API_BASE_URL = "https://api.cloudflare.com/client/v4/accounts/07c4bcfbc1891c3e528e1c439fee68bd/ai/run/"
EMBEDDING_MODEL = '@cf/baai/bge-m3'
LLM_MODEL = "@cf/meta/llama-3.2-3b-instruct"
WORKING_DIR = "./dickens"

rag_instance = None
knowledge_store = None
cloudflare_worker = None

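# One-time startup: fetch the pre-built knowledge base if it is missing, create
# the Cloudflare client, then initialise LightRAG (or the fallback store).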
async def initialize_system():
    global rag_instance, knowledge_store, cloudflare_worker

    print("🚀 Initializing system...")

    dickens_path = Path(WORKING_DIR)
    has_data = dickens_path.exists() and len(list(dickens_path.glob("*.json"))) > 0

    if not has_data:
        print("📥 Downloading RAG database...")
        try:
            data_url = "https://github.com/YOUR_USERNAME/fire-safety-ai/releases/download/v1.0-data/dickens.zip"

            response = requests.get(data_url, timeout=60)
            response.raise_for_status()

            with open("dickens.zip", "wb") as f:
                f.write(response.content)

            with zipfile.ZipFile("dickens.zip", 'r') as zip_ref:
                zip_ref.extractall(".")

            os.remove("dickens.zip")
            print("✅ Data downloaded!")

        except Exception as e:
            print(f"⚠️ Download failed: {e}")
            os.makedirs(WORKING_DIR, exist_ok=True)

    cloudflare_worker = CloudflareWorker(
        cloudflare_api_key=CLOUDFLARE_API_KEY,
        api_base_url=API_BASE_URL,
        embedding_model_name=EMBEDDING_MODEL,
        llm_model_name=LLM_MODEL,
    )

    if LIGHTRAG_AVAILABLE:
        try:
            rag_instance = LightRAG(
                working_dir=WORKING_DIR,
                max_parallel_insert=2,
                llm_model_func=cloudflare_worker.query,
                llm_model_name=LLM_MODEL,
                llm_model_max_token_size=4080,
                embedding_func=EmbeddingFunc(
                    embedding_dim=1024,
                    max_token_size=2048,
                    func=lambda texts: cloudflare_worker.embedding_chunk(texts),
                ),
            )

            await rag_instance.initialize_storages()
            print("✅ LightRAG system initialized!")

        except Exception as e:
            print(f"⚠️ LightRAG failed, using fallback: {e}")
            knowledge_store = SimpleKnowledgeStore(WORKING_DIR)
    else:
        print("📚 Using simple knowledge store...")
        knowledge_store = SimpleKnowledgeStore(WORKING_DIR)

    print("✅ System ready!")


asyncio.run(initialize_system())

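# Core Q&A path: prefer a LightRAG query; otherwise build a context from the
# keyword store and ask the Cloudflare LLM directly.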
async def ask_question(question, mode="hybrid"):
    if not question.strip():
        return "❌ Please enter a question."

    try:
        print(f"🔍 Processing question: {question}")

        if rag_instance and LIGHTRAG_AVAILABLE:
            response = await rag_instance.aquery(
                question,
                param=QueryParam(mode=mode)
            )
            return response

        elif knowledge_store and cloudflare_worker:
            relevant_chunks = knowledge_store.search(question, limit=3)
            context = "\n".join(relevant_chunks) if relevant_chunks else "No specific context found."

            system_prompt = """You are a Fire Safety AI Assistant specializing in Vietnamese fire safety regulations.
Use the provided context to answer questions about building codes, emergency exits, and fire safety requirements."""

            user_prompt = f"""Context: {context}

Question: {question}

Please provide a helpful answer based on the context about Vietnamese fire safety regulations."""

            response = await cloudflare_worker.query(user_prompt, system_prompt)
            return response

        else:
            return "❌ System not initialized yet. Please wait..."

    except Exception as e:
        return f"❌ Error: {str(e)}"


def sync_ask_question(question, mode):
    return asyncio.run(ask_question(question, mode))

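# Gradio UI: question box and search-mode selector on the left, answer panel on
# the right, plus clickable example questions wired back into the input box.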
with gr.Blocks(title="🔥 Fire Safety AI Assistant", theme=gr.themes.Soft()) as demo:
    gr.HTML("<h1 style='text-align: center;'>🔥 Fire Safety AI Assistant</h1>")
    gr.HTML("<p style='text-align: center;'>Ask questions about Vietnamese fire safety regulations</p>")

    with gr.Row():
        with gr.Column(scale=1):
            question_input = gr.Textbox(
                label="Your Question",
                placeholder="What are the requirements for emergency exits?",
                lines=3
            )
            mode_dropdown = gr.Dropdown(
                choices=["hybrid", "local", "global", "naive"],
                value="hybrid",
                label="Search Mode",
                info="Hybrid is recommended for best results"
            )
            submit_btn = gr.Button("🔍 Ask Question", variant="primary", size="lg")

        with gr.Column(scale=2):
            answer_output = gr.Textbox(
                label="Answer",
                lines=15,
                show_copy_button=True
            )

    status_text = "✅ LightRAG System" if LIGHTRAG_AVAILABLE else "⚠️ Fallback Mode"
    gr.HTML(f"<p style='text-align: center; color: gray;'>Status: {status_text}</p>")

    gr.HTML("<h3 style='text-align: center;'>💡 Example Questions:</h3>")

    with gr.Row():
        example1 = gr.Button("What are the requirements for emergency exits?", size="sm")
        example2 = gr.Button("How many exits does a building need?", size="sm")

    with gr.Row():
        example3 = gr.Button("What are fire safety rules for stairwells?", size="sm")
        example4 = gr.Button("What are building safety requirements?", size="sm")

    submit_btn.click(
        sync_ask_question,
        inputs=[question_input, mode_dropdown],
        outputs=answer_output
    )

    question_input.submit(
        sync_ask_question,
        inputs=[question_input, mode_dropdown],
        outputs=answer_output
    )

    example1.click(lambda: "What are the requirements for emergency exits?", outputs=question_input)
    example2.click(lambda: "How many exits does a building need?", outputs=question_input)
    example3.click(lambda: "What are fire safety rules for stairwells?", outputs=question_input)
    example4.click(lambda: "What are building safety requirements?", outputs=question_input)

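# Start the Gradio server when the script is executed directly.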
if __name__ == "__main__":
    demo.launch()