al1kss committed on
Commit
663e454
Β·
verified Β·
1 Parent(s): fb0b94c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +299 -51
app.py CHANGED
@@ -1,36 +1,271 @@
1
  import gradio as gr
2
  import asyncio
3
- from main import app, rag_instance, startup_event
4
- import uvicorn
5
- import threading
6
- import time
 
 
7
 
8
- # Initialize the FastAPI app
9
- async def init_app():
10
- await startup_event()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
 
12
- # Run FastAPI in background
13
- def run_fastapi():
14
- uvicorn.run(app, host="0.0.0.0", port=7860)
 
 
 
15
 
16
- # Start FastAPI server in background thread
17
- threading.Thread(target=run_fastapi, daemon=True).start()
 
 
18
 
19
- # Initialize RAG system
20
- asyncio.run(init_app())
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
 
22
- # Simple Gradio interface
23
  async def ask_question(question, mode="hybrid"):
24
- if not rag_instance:
25
- return "❌ RAG system not initialized yet. Please wait..."
26
 
27
  try:
28
- from lightrag import QueryParam
29
- response = await rag_instance.aquery(
30
- question,
31
- param=QueryParam(mode=mode)
32
- )
33
- return response
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34
  except Exception as e:
35
  return f"❌ Error: {str(e)}"
36
 
@@ -38,51 +273,64 @@ def sync_ask_question(question, mode):
38
  return asyncio.run(ask_question(question, mode))
39
 
40
  # Create Gradio interface
41
- with gr.Blocks(title="πŸ”₯ Fire Safety AI Assistant") as demo:
42
- gr.HTML("<h1>πŸ”₯ Fire Safety AI Assistant</h1>")
43
- gr.HTML("<p>Ask questions about Vietnamese fire safety regulations</p>")
44
 
45
  with gr.Row():
46
- with gr.Column():
47
  question_input = gr.Textbox(
48
  label="Your Question",
49
  placeholder="What are the requirements for emergency exits?",
50
- lines=2
51
  )
52
  mode_dropdown = gr.Dropdown(
53
  choices=["hybrid", "local", "global", "naive"],
54
  value="hybrid",
55
- label="Search Mode"
 
 
 
 
 
 
 
 
 
56
  )
57
- submit_btn = gr.Button("Ask Question", variant="primary")
58
 
59
- with gr.Column():
60
- answer_output = gr.Textbox(
61
- label="Answer",
62
- lines=10,
63
- show_copy_button=True
64
- )
65
 
66
  # Example questions
67
- gr.HTML("<h3>Example Questions:</h3>")
68
- examples = [
69
- "What are the requirements for emergency exits?",
70
- "How many exits does a building need?",
71
- "What are fire safety rules for stairwells?",
72
- "What are building safety requirements?",
73
- ]
74
-
75
- for example in examples:
76
- gr.Button(example).click(
77
- lambda x=example: x,
78
- outputs=question_input
79
- )
80
 
 
 
 
 
 
81
  submit_btn.click(
82
  sync_ask_question,
83
  inputs=[question_input, mode_dropdown],
84
  outputs=answer_output
85
  )
 
 
 
 
 
 
 
 
 
 
 
86
 
87
- # Also expose FastAPI at /api
88
- demo.launch(server_name="0.0.0.0", server_port=7860)
 
1
  import gradio as gr
2
  import asyncio
3
+ import os
4
+ import zipfile
5
+ import requests
6
+ from pathlib import Path
7
+ import numpy as np
8
+ from typing import List
9
 
10
# LightRAG has moved its public symbols between releases; probe the known
# module layouts in order and degrade gracefully if none of them import.
# On success, LightRAG / QueryParam / EmbeddingFunc are bound at module level.
import importlib

LIGHTRAG_AVAILABLE = False
for _rag_mod, _param_mod, _utils_mod in (
    ("lightrag", "lightrag", "lightrag.utils"),
    ("lightrag.lightrag", "lightrag.query", "lightrag.utils"),
    ("lightrag.core", "lightrag.core", "lightrag.utils"),
):
    try:
        LightRAG = importlib.import_module(_rag_mod).LightRAG
        QueryParam = importlib.import_module(_param_mod).QueryParam
        EmbeddingFunc = importlib.import_module(_utils_mod).EmbeddingFunc
        LIGHTRAG_AVAILABLE = True
        break
    # A from-import raises ImportError for a missing attribute; the getattr
    # form raises AttributeError — catch both to match the original coverage.
    except (ImportError, AttributeError):
        continue

if not LIGHTRAG_AVAILABLE:
    print("❌ LightRAG import failed - using fallback mode")
30
+
31
# Cloudflare Workers AI client, used both as the LightRAG model backend and
# directly by the fallback answer path. All failures are reported as None /
# error strings rather than raised, so callers stay responsive.
class CloudflareWorker:
    def __init__(self, cloudflare_api_key: str, api_base_url: str, llm_model_name: str, embedding_model_name: str):
        self.cloudflare_api_key = cloudflare_api_key
        self.api_base_url = api_base_url
        self.llm_model_name = llm_model_name
        self.embedding_model_name = embedding_model_name
        # Token ceilings sent with every model call.
        self.max_tokens = 4080
        self.max_response_tokens = 4080

    async def _send_request(self, model_name: str, input_: dict, debug_log: str = ""):
        """POST *input_* to the Cloudflare AI endpoint for *model_name*.

        Returns embedding data (ndarray when numpy-backed LightRAG is in use,
        otherwise a plain list), an LLM response string, or None on any
        failure. *debug_log* is kept for interface compatibility.
        """
        headers = {"Authorization": f"Bearer {self.cloudflare_api_key}"}

        try:
            response_raw = requests.post(
                f"{self.api_base_url}{model_name}",
                headers=headers,
                json=input_,
                timeout=30,
            )
            # Fail fast on HTTP errors instead of trying to parse an error
            # page as a model payload (handled by the except below → None).
            response_raw.raise_for_status()
            payload = response_raw.json()

            result = payload.get("result", {})

            if "data" in result:  # embedding endpoint shape
                return np.array(result["data"]) if LIGHTRAG_AVAILABLE else result["data"]
            if "response" in result:  # chat endpoint shape
                return result["response"]

            raise ValueError(f"Unexpected response format: {payload}")

        except Exception as e:
            # Top-level boundary for this client: log and signal via None.
            print(f"Cloudflare API Error: {e}")
            return None

    async def query(self, prompt: str, system_prompt: str = '', **kwargs) -> str:
        """Run a chat completion; returns the response text or an error string."""
        # LightRAG passes a hashing_kv kwarg the API does not accept.
        kwargs.pop("hashing_kv", None)

        message = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": prompt}
        ]

        input_ = {
            "messages": message,
            "max_tokens": self.max_tokens,
            "response_token_limit": self.max_response_tokens,
        }

        result = await self._send_request(self.llm_model_name, input_)
        return result if result is not None else "Error: Failed to get response"

    async def embedding_chunk(self, texts: List[str]):
        """Embed *texts*; on failure return placeholder vectors (random for
        LightRAG's numpy path, zeros otherwise) so the pipeline keeps going."""
        input_ = {
            "text": texts,
            "max_tokens": self.max_tokens,
            "response_token_limit": self.max_response_tokens,
        }

        result = await self._send_request(self.embedding_model_name, input_)

        if result is None:
            # NOTE(review): 1024 must match the embedding_dim configured for
            # the model (@cf/baai/bge-m3 in this file) — confirm if changed.
            if LIGHTRAG_AVAILABLE:
                return np.random.rand(len(texts), 1024).astype(np.float32)
            else:
                return [[0.0] * 1024 for _ in texts]

        return result
98
+
99
# Keyword-search fallback over the exported LightRAG JSON stores, used when
# the LightRAG package cannot be imported or fails to initialize.
class SimpleKnowledgeStore:
    def __init__(self, data_dir: str):
        self.data_dir = data_dir
        self.chunks = []    # chunk dicts, each expected to carry a 'content' key
        self.entities = []  # raw entity records from the vector-store dump
        self.load_data()

    def load_data(self):
        """Load chunks and entities from *data_dir*.

        Missing files are skipped; on any error both stores are reset to
        empty so search still works (and simply finds nothing).
        """
        try:
            import json
            chunks_file = Path(self.data_dir) / "kv_store_text_chunks.json"
            if chunks_file.exists():
                with open(chunks_file, 'r', encoding='utf-8') as f:
                    data = json.load(f)
                self.chunks = list(data.values()) if data else []

            entities_file = Path(self.data_dir) / "vdb_entities.json"
            if entities_file.exists():
                with open(entities_file, 'r', encoding='utf-8') as f:
                    entities_data = json.load(f)
                # Dump format varies by version: plain list or {'data': [...]}.
                if isinstance(entities_data, dict) and 'data' in entities_data:
                    self.entities = entities_data['data']
                elif isinstance(entities_data, list):
                    self.entities = entities_data
                else:
                    self.entities = []

            print(f"βœ… Loaded {len(self.chunks)} chunks and {len(self.entities)} entities")

        except Exception as e:
            print(f"⚠️ Error loading data: {e}")
            self.chunks = []
            self.entities = []

    def search(self, query: str, limit: int = 5) -> List[str]:
        """Return up to *limit* chunk texts / entity reprs that contain any
        whitespace-separated word of *query*, case-insensitively.
        Chunks are matched first, then entities (same order as loaded)."""
        # Hoisted out of the loops: the original re-split the query and
        # re-lowered each record once per word on every iteration.
        words = query.lower().split()
        if not words:
            return []

        results = []

        for chunk in self.chunks:
            if isinstance(chunk, dict) and 'content' in chunk:
                content = chunk['content']
                lowered = content.lower()
                if any(word in lowered for word in words):
                    results.append(content)

        for entity in self.entities:
            if isinstance(entity, dict):
                entity_text = str(entity)
                lowered = entity_text.lower()
                if any(word in lowered for word in words):
                    results.append(entity_text)

        return results[:limit]
151
 
152
# Configuration. SECURITY: the API key must come from the environment only —
# the previous in-source default was a real (leaked) Cloudflare key and must
# be rotated; never commit a secret as a getenv fallback.
CLOUDFLARE_API_KEY = os.getenv('CLOUDFLARE_API_KEY', '')
API_BASE_URL = "https://api.cloudflare.com/client/v4/accounts/07c4bcfbc1891c3e528e1c439fee68bd/ai/run/"
EMBEDDING_MODEL = '@cf/baai/bge-m3'
LLM_MODEL = "@cf/meta/llama-3.2-3b-instruct"
WORKING_DIR = "./dickens"
158
 
159
# Global singletons populated by initialize_system().
rag_instance = None        # LightRAG instance when the library is usable
knowledge_store = None     # SimpleKnowledgeStore fallback when it is not
cloudflare_worker = None   # shared Cloudflare AI client


async def initialize_system():
    """Download the RAG database if missing and build the query backend.

    Populates the module-level ``cloudflare_worker`` plus either
    ``rag_instance`` (preferred, when LightRAG imported) or
    ``knowledge_store`` (keyword-search fallback).
    """
    global rag_instance, knowledge_store, cloudflare_worker

    print("πŸ”„ Initializing system...")

    # Consider the data present if the working dir holds any JSON store.
    dickens_path = Path(WORKING_DIR)
    has_data = dickens_path.exists() and any(dickens_path.glob("*.json"))

    if not has_data:
        print("πŸ“₯ Downloading RAG database...")
        archive = "dickens.zip"
        try:
            # NOTE(review): the default URL still contains the YOUR_USERNAME
            # placeholder — set DATA_URL (or edit the default) before release.
            data_url = os.getenv(
                "DATA_URL",
                "https://github.com/YOUR_USERNAME/fire-safety-ai/releases/download/v1.0-data/dickens.zip",
            )

            # Stream to disk so a large archive never has to fit in memory.
            with requests.get(data_url, timeout=60, stream=True) as response:
                response.raise_for_status()
                with open(archive, "wb") as f:
                    for chunk in response.iter_content(chunk_size=1 << 20):
                        f.write(chunk)

            with zipfile.ZipFile(archive, 'r') as zip_ref:
                zip_ref.extractall(".")
            print("βœ… Data downloaded!")

        except Exception as e:
            # Best-effort: the app still runs in fallback mode without data.
            print(f"⚠️ Download failed: {e}")
            os.makedirs(WORKING_DIR, exist_ok=True)
        finally:
            # Never leave a partial/stale archive behind.
            if os.path.exists(archive):
                os.remove(archive)

    cloudflare_worker = CloudflareWorker(
        cloudflare_api_key=CLOUDFLARE_API_KEY,
        api_base_url=API_BASE_URL,
        embedding_model_name=EMBEDDING_MODEL,
        llm_model_name=LLM_MODEL,
    )

    # Prefer the full LightRAG pipeline; fall back to keyword search.
    if LIGHTRAG_AVAILABLE:
        try:
            rag_instance = LightRAG(
                working_dir=WORKING_DIR,
                max_parallel_insert=2,
                llm_model_func=cloudflare_worker.query,
                llm_model_name=LLM_MODEL,
                llm_model_max_token_size=4080,
                embedding_func=EmbeddingFunc(
                    embedding_dim=1024,
                    max_token_size=2048,
                    func=cloudflare_worker.embedding_chunk,
                ),
            )
            await rag_instance.initialize_storages()
            print("βœ… LightRAG system initialized!")

        except Exception as e:
            print(f"⚠️ LightRAG failed, using fallback: {e}")
            knowledge_store = SimpleKnowledgeStore(WORKING_DIR)
    else:
        print("πŸ”„ Using simple knowledge store...")
        knowledge_store = SimpleKnowledgeStore(WORKING_DIR)

    print("βœ… System ready!")
230
+
231
# Initialize on startup. Guarded so a transient init failure (e.g. a failed
# data download) degrades to ask_question's "not initialized" message in the
# UI instead of crashing the whole app at import time.
try:
    asyncio.run(initialize_system())
except Exception as e:
    print(f"⚠️ Startup initialization failed: {e}")
233
 
 
234
async def ask_question(question, mode="hybrid"):
    """Answer *question* via LightRAG when available, otherwise via keyword
    retrieval over the local store plus one Cloudflare LLM call.

    Returns the answer text, or a user-facing "❌ ..." message on bad input
    or failure — errors are reported, never raised, because the return value
    feeds a Gradio textbox.
    """
    # Gradio can deliver None as well as "" — the original called .strip()
    # unconditionally and crashed on None (guard sits outside the try).
    if not question or not question.strip():
        return "❌ Please enter a question."

    try:
        print(f"πŸ” Processing question: {question}")

        # Preferred path: full RAG pipeline.
        if rag_instance and LIGHTRAG_AVAILABLE:
            response = await rag_instance.aquery(
                question,
                param=QueryParam(mode=mode)
            )
            return response

        # Fallback path: naive keyword retrieval + a direct LLM call.
        if knowledge_store and cloudflare_worker:
            relevant_chunks = knowledge_store.search(question, limit=3)
            context = "\n".join(relevant_chunks) if relevant_chunks else "No specific context found."

            system_prompt = """You are a Fire Safety AI Assistant specializing in Vietnamese fire safety regulations.
Use the provided context to answer questions about building codes, emergency exits, and fire safety requirements."""

            user_prompt = f"""Context: {context}

Question: {question}

Please provide a helpful answer based on the context about Vietnamese fire safety regulations."""

            response = await cloudflare_worker.query(user_prompt, system_prompt)
            return response

        return "❌ System not initialized yet. Please wait..."

    except Exception as e:
        # Surface the failure in the UI instead of crashing the handler.
        return f"❌ Error: {str(e)}"
271
 
 
273
  return asyncio.run(ask_question(question, mode))
274
 
275
# ---- Gradio UI -------------------------------------------------------------
with gr.Blocks(title="πŸ”₯ Fire Safety AI Assistant", theme=gr.themes.Soft()) as demo:
    gr.HTML("<h1 style='text-align: center;'>πŸ”₯ Fire Safety AI Assistant</h1>")
    gr.HTML("<p style='text-align: center;'>Ask questions about Vietnamese fire safety regulations</p>")

    with gr.Row():
        # Left column: input controls.
        with gr.Column(scale=1):
            question_input = gr.Textbox(
                label="Your Question",
                placeholder="What are the requirements for emergency exits?",
                lines=3,
            )
            mode_dropdown = gr.Dropdown(
                choices=["hybrid", "local", "global", "naive"],
                value="hybrid",
                label="Search Mode",
                info="Hybrid is recommended for best results",
            )
            submit_btn = gr.Button("πŸ” Ask Question", variant="primary", size="lg")

        # Right column: the answer panel.
        with gr.Column(scale=2):
            answer_output = gr.Textbox(
                label="Answer",
                lines=15,
                show_copy_button=True,
            )

    # Backend status banner.
    status_text = "βœ… LightRAG System" if LIGHTRAG_AVAILABLE else "⚠️ Fallback Mode"
    gr.HTML(f"<p style='text-align: center; color: gray;'>Status: {status_text}</p>")

    # Example questions — clicking one copies its text into the question box.
    gr.HTML("<h3 style='text-align: center;'>πŸ’‘ Example Questions:</h3>")
    example_texts = [
        "What are the requirements for emergency exits?",
        "How many exits does a building need?",
        "What are fire safety rules for stairwells?",
        "What are building safety requirements?",
    ]
    for row_start in range(0, len(example_texts), 2):
        with gr.Row():
            for prompt_text in example_texts[row_start:row_start + 2]:
                # Bind prompt_text as a default arg to avoid late-binding capture.
                gr.Button(prompt_text, size="sm").click(
                    lambda t=prompt_text: t,
                    outputs=question_input,
                )

    # Both the button and pressing Enter in the textbox submit the question.
    submit_btn.click(
        sync_ask_question,
        inputs=[question_input, mode_dropdown],
        outputs=answer_output,
    )
    question_input.submit(
        sync_ask_question,
        inputs=[question_input, mode_dropdown],
        outputs=answer_output,
    )

if __name__ == "__main__":
    demo.launch()