dungeon29 committed on
Commit
e2284a9
·
verified ·
1 Parent(s): 090933b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +205 -87
app.py CHANGED
@@ -11,47 +11,50 @@ import time
11
  import os
12
 
13
  # --- import your architecture ---
14
- from model import DeBERTaLSTMClassifier
 
 
15
 
16
  # --- Import RAG modules ---
17
- # Đảm bảo bạn đã có file rag_engine.py llm_client.py trong cùng thư mục
18
- try:
19
- from rag_engine import RAGEngine
20
- from llm_client import LLMClient
21
- RAG_AVAILABLE = True
22
- except ImportError:
23
- print("⚠️ Warning: RAG modules not found. The LLM tab might not work.")
24
- RAG_AVAILABLE = False
25
 
26
  # --------- Config ----------
27
- REPO_ID = "dungeon29/phishing-deberta-lstm"
28
- CKPT_NAME = "pytorch_model.bin"
29
- MODEL_NAME = "microsoft/deberta-base"
30
- LABELS = ["benign", "phishing"]
 
 
 
 
31
 
32
  # --------- Load model/tokenizer once (global) ----------
33
  device = "cuda" if torch.cuda.is_available() else "cpu"
34
  tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
35
 
36
- print(f"Loading DeBERTa model from {REPO_ID}...")
37
  ckpt_path = hf_hub_download(repo_id=REPO_ID, filename=CKPT_NAME)
38
  checkpoint = torch.load(ckpt_path, map_location=device)
39
 
40
- model_args = checkpoint.get("model_args", {})
 
41
  model = DeBERTaLSTMClassifier(**model_args)
42
 
43
  # Load weights
44
  try:
45
  state_dict = torch.load(ckpt_path, map_location=device)
 
 
46
  if "model_state_dict" in state_dict:
47
  state_dict = state_dict["model_state_dict"]
48
 
49
  model.load_state_dict(state_dict, strict=False)
50
 
 
51
  if hasattr(model, 'attention') and 'attention.weight' not in state_dict:
52
  print("⚠️ Loaded model without attention layer, using newly initialized attention weights")
53
  else:
54
- print("✅ Load DeBERTa weights successfully!")
55
 
56
  except Exception as e:
57
  print(f"❌ Error when loading weights: {e}")
@@ -60,33 +63,29 @@ except Exception as e:
60
  model.to(device).eval()
61
 
62
  # --------- Initialize RAG & LLM ----------
63
- rag_engine = None
64
- llm_client = None
 
65
 
66
- if RAG_AVAILABLE:
67
- print("Initializing RAG Engine...")
68
- try:
69
- rag_engine = RAGEngine()
70
- print("RAG Engine ready.")
71
-
72
- print("Initializing Qwen2.5-3B LLM (this may take a minute)...")
73
- llm_client = LLMClient()
74
- print("LLM ready.")
75
- except Exception as e:
76
- print(f"❌ Error initializing RAG/LLM: {e}")
77
 
78
  # --------- Helper functions ----------
79
  def is_url(text):
 
80
  url_pattern = re.compile(
81
- r'^https?://'
82
- r'(?:(?:[A-Z0-9](?:[A-Z0-9-]{0,61}[A-Z0-9])?\.)+[A-Z]{2,6}\.?|'
83
- r'localhost|'
84
- r'\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3})'
85
- r'(?::\d+)?'
86
  r'(?:/?|[/?]\S+)$', re.IGNORECASE)
87
  return url_pattern.match(text) is not None
88
 
89
  def fetch_html_content(url, timeout=10):
 
90
  try:
91
  headers = {
92
  'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36'
@@ -95,10 +94,12 @@ def fetch_html_content(url, timeout=10):
95
  response.raise_for_status()
96
 
97
  soup = BeautifulSoup(response.text, 'html.parser')
 
98
  for script in soup(["script", "style", "meta", "noscript", "header", "footer"]):
99
  script.decompose()
100
 
101
  text = soup.get_text(separator=' ')
 
102
  clean_text = " ".join(text.split())
103
  return clean_text, response.status_code
104
  except requests.exceptions.RequestException as e:
@@ -107,6 +108,8 @@ def fetch_html_content(url, timeout=10):
107
  return None, f"General error: {str(e)}"
108
 
109
  def predict_single_text(text, text_type="text"):
 
 
110
  inputs = tokenizer(
111
  text,
112
  return_tensors="pt",
@@ -114,11 +117,14 @@ def predict_single_text(text, text_type="text"):
114
  padding=True,
115
  max_length=256
116
  )
 
117
  inputs.pop("token_type_ids", None)
 
118
  inputs = {k: v.to(device) for k, v in inputs.items()}
119
 
120
  with torch.no_grad():
121
  try:
 
122
  result = model(**inputs, return_attention=True)
123
  if isinstance(result, tuple) and len(result) == 3:
124
  logits, attention_weights, deberta_attentions = result
@@ -127,55 +133,73 @@ def predict_single_text(text, text_type="text"):
127
  logits = result
128
  has_attention = False
129
  except TypeError:
 
130
  logits = model(**inputs)
131
  has_attention = False
132
 
133
  probs = F.softmax(logits, dim=-1).squeeze(0).tolist()
134
 
 
135
  tokens = tokenizer.convert_ids_to_tokens(inputs['input_ids'].squeeze(0).tolist())
 
136
  return probs, tokens, has_attention, attention_weights if has_attention else None
137
 
138
  def combine_predictions(url_probs, html_probs, url_weight=0.3, html_weight=0.7):
 
139
  combined_probs = [
140
- url_weight * url_probs[0] + html_weight * html_probs[0],
141
- url_weight * url_probs[1] + html_weight * html_probs[1]
142
  ]
143
  return combined_probs
144
 
145
- # --------- Inference function (DeBERTa) ----------
146
  def predict_fn(text: str):
147
  if not text or not text.strip():
148
- return "Please enter a URL or text."
149
 
 
150
  if is_url(text.strip()):
 
151
  url = text.strip()
 
 
152
  url_probs, url_tokens, url_has_attention, url_attention = predict_single_text(url, "URL")
 
 
153
  html_content, status = fetch_html_content(url)
154
 
155
  if html_content:
 
156
  html_probs, html_tokens, html_has_attention, html_attention = predict_single_text(html_content, "HTML")
 
 
157
  combined_probs = combine_predictions(url_probs, html_probs)
158
 
 
159
  probs = combined_probs
160
- tokens = url_tokens + ["[SEP]"] + html_tokens[:50]
161
  has_attention = url_has_attention or html_has_attention
162
  attention_weights = url_attention if url_has_attention else html_attention
163
 
164
  analysis_type = "Combined URL + HTML Analysis"
165
  fetch_status = f"✅ Successfully fetched HTML content (Status: {status})"
 
166
  else:
 
167
  probs = url_probs
168
  tokens = url_tokens
169
  has_attention = url_has_attention
170
  attention_weights = url_attention
 
171
  analysis_type = "URL-only Analysis"
172
  fetch_status = f"⚠️ Could not fetch HTML content: {status}"
173
  else:
 
174
  probs, tokens, has_attention, attention_weights = predict_single_text(text, "text")
175
  analysis_type = "Text Analysis"
176
  fetch_status = ""
177
 
178
- # Create HTML output
179
  predicted_class = "phishing" if probs[1] > probs[0] else "benign"
180
  confidence = max(probs)
181
 
@@ -194,6 +218,7 @@ def predict_fn(text: str):
194
  </div>
195
  </div>
196
  """
 
197
  if fetch_status:
198
  detailed_analysis += f"""
199
  <div style="background: #2d2d2d; padding: 15px; border-radius: 10px; margin: 15px 0; border-left: 4px solid #4caf50; color: #e0e0e0;">
@@ -203,39 +228,68 @@ def predict_fn(text: str):
203
 
204
  if has_attention and attention_weights is not None:
205
  attention_scores = attention_weights.squeeze(0).tolist()
 
206
  token_analysis = []
207
  for i, (token, score) in enumerate(zip(tokens, attention_scores)):
 
208
  if token not in ['[CLS]', '[SEP]', '[PAD]', '<s>', '</s>'] and len(token.strip()) > 0 and score > 0.005:
209
- clean_token = token.replace(' ', '').replace('Ġ', '').strip()
210
- if clean_token:
211
- token_analysis.append({'token': clean_token, 'importance': score})
 
 
 
 
212
 
 
213
  token_analysis.sort(key=lambda x: x['importance'], reverse=True)
214
 
215
  detailed_analysis += f"""
216
  ## Top important tokens:
217
  <div style="background: #2d2d2d; padding: 15px; border-radius: 10px; margin: 15px 0; border-left: 4px solid #4caf50; color: #e0e0e0;">
218
- Found {len(token_analysis)} important tokens
219
  </div>
220
  <div style="font-family: Arial, sans-serif;">
221
  """
222
- for i, token_info in enumerate(token_analysis[:10]):
 
223
  bar_width = int(token_info['importance'] * 100)
224
  color = "#ff4444" if predicted_class == "phishing" else "#44ff44"
 
225
  detailed_analysis += f"""
226
  <div style="margin: 8px 0; display: flex; align-items: center; background: #2d2d2d; padding: 8px; border-radius: 8px; border-left: 4px solid {color};">
227
- <div style="width: 30px; text-align: right; margin-right: 10px; font-weight: bold; color: #ffffff;">{i+1}.</div>
228
- <div style="width: 120px; margin-right: 10px; font-weight: bold; color: #e0e0e0; text-align: right;">{token_info['token']}</div>
 
 
 
 
229
  <div style="width: 300px; background-color: #404040; border-radius: 10px; overflow: hidden; margin-right: 10px; border: 1px solid #555;">
230
- <div style="width: {bar_width}%; background-color: {color}; height: 20px; border-radius: 10px;"></div>
 
 
 
231
  </div>
232
- <div style="color: #cccccc; font-size: 12px; font-weight: bold;">{token_info['importance']:.1%}</div>
233
  </div>
234
  """
235
- detailed_analysis += "</div>"
236
 
237
- # Add stats and simple bars
 
238
  detailed_analysis += f"""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
239
  <div style="font-family: Arial, sans-serif; margin: 15px 0; background: #2d2d2d; padding: 20px; border-radius: 15px; border: 1px solid #555;">
240
  <h3 style="color: #ffffff; margin-bottom: 15px;"> Prediction Confidence</h3>
241
  <div style="display: flex; justify-content: space-between; margin-bottom: 10px;">
@@ -247,36 +301,66 @@ def predict_fn(text: str):
247
  {probs[1]:.1%}
248
  </div>
249
  </div>
 
 
 
250
  </div>
251
  """
252
  else:
253
- # Fallback when no attention
254
  detailed_analysis += f"""
255
- <div style="background: #2d2d2d; padding: 20px; border-radius: 15px; margin: 15px 0; border: 1px solid #555;">
256
- <h3 style="color: #ffffff; margin: 0 0 15px 0;">Prediction:</h3>
257
- <p style="color:white">Phishing: {probs[1]:.1%} | Benign: {probs[0]:.1%}</p>
258
- <p style="color: #ffcc02;">Note: Detailed attention weights not available.</p>
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
259
  </div>
260
  """
 
 
 
 
 
 
261
 
262
- return detailed_analysis
263
 
264
  # --------- RAG Inference function ----------
265
  def rag_predict_fn(text: str):
266
  if not text or not text.strip():
267
  return "Please enter text to analyze."
268
 
269
- if not RAG_AVAILABLE or rag_engine is None or llm_client is None:
270
- return "⚠️ RAG/LLM system is not initialized properly. Check console logs."
 
 
271
 
272
- try:
273
- # 1. Retrieve context
274
- context = rag_engine.retrieve(text)
275
- # 2. Call LLM
276
- response = llm_client.analyze(text, context)
277
- return response
278
- except Exception as e:
279
- return f"Error during RAG analysis: {str(e)}"
280
 
281
  # --------- Gradio UI ----------
282
  css_style="""
@@ -288,11 +372,12 @@ css_style="""
288
  .dark .gradio-container {
289
  background-color: #1e1e1e !important;
290
  }
 
291
  .block {
292
  background-color: #2d2d2d !important;
293
  border: 1px solid #444 !important;
294
  }
295
- .gradio-textbox, .gradio-dropdown {
296
  background-color: #3d3d3d !important;
297
  color: #ffffff !important;
298
  border: 1px solid #666 !important;
@@ -304,65 +389,98 @@ css_style="""
304
  }
305
  .gradio-button:hover {
306
  background-color: #5a5a5a !important;
 
 
307
  }
308
  """
309
-
310
- with gr.Blocks() as demo:
311
- gr.Markdown("# 🛡️ Advanced Phishing Detector System")
312
 
313
  with gr.Tabs():
314
- # --- Tab 1: DeBERTa Analysis ---
315
- with gr.TabItem("DeBERTa + LSTM Detection"):
316
  gr.Markdown("""
317
- ### Hybrid Detection Model
318
- Enter a URL or text. The system uses **DeBERTa** combined with **LSTM** to analyze semantic meaning and structure.
319
- For URLs, it fetches the HTML content and performs a weighted analysis (30% URL structure + 70% Content).
 
 
 
 
 
 
 
 
 
 
 
320
  """)
321
 
322
  with gr.Row():
323
  with gr.Column(scale=2):
324
  input_box = gr.Textbox(
325
- label="Input URL or Text",
326
- placeholder="Enter URL (e.g., http://example.com) or paste suspicious text...",
327
  lines=3
328
  )
329
- btn_submit = gr.Button("Analyze Security", variant="primary")
330
 
331
  gr.Examples(
332
  examples=[
333
  ["http://rendmoiunserviceeee.com"],
334
  ["https://www.google.com"],
335
- ["Dear customer, your account has been suspended. Click here to verify."],
 
336
  ["http://paypaI-security-update.net/login"],
 
 
 
337
  ],
338
  inputs=input_box
339
  )
340
 
341
  with gr.Column(scale=3):
342
  output_html = gr.HTML(label="Analysis Result")
343
-
344
  btn_submit.click(fn=predict_fn, inputs=input_box, outputs=output_html)
345
 
346
  # --- Tab 2: LLM + RAG Analysis ---
347
- with gr.TabItem("LLM + RAG Assistant"):
348
  gr.Markdown("""
349
- ### AI Security Analyst (Qwen2.5 + RAG)
350
- This agent retrieves context from a security knowledge base and uses a Large Language Model to explain *why* something is suspicious.
 
 
 
 
351
  """)
352
 
353
  with gr.Row():
354
  with gr.Column(scale=1):
355
  rag_input = gr.Textbox(
356
- label="Suspicious Content / Question",
357
- placeholder="Paste email content, headers, or ask: 'Why is this url suspicious?'...",
358
  lines=5
359
  )
360
- btn_rag = gr.Button("Ask AI Assistant", variant="primary")
 
 
361
 
 
 
 
 
 
 
 
 
 
362
  with gr.Column(scale=1):
363
  rag_output = gr.Markdown(label="AI Analysis")
 
364
 
365
  btn_rag.click(fn=rag_predict_fn, inputs=[rag_input], outputs=rag_output)
 
366
 
367
  if __name__ == "__main__":
368
  demo.launch()
 
11
  import os
12
 
13
  # --- import your architecture ---
14
+ # Make sure this file is in the repo (e.g., models/deberta_lstm_classifier.py)
15
+ # and update the import path accordingly.
16
+ from model import DeBERTaLSTMClassifier # <-- your class
17
 
18
  # --- Import RAG modules ---
19
+ from rag_engine import RAGEngine
20
+ from llm_client import LLMClient
 
 
 
 
 
 
21
 
22
  # --------- Config ----------
23
+ REPO_ID = "dungeon29/phishing-deberta-lstm" # HF repo that holds the checkpoint
24
+ CKPT_NAME = "pytorch_model.bin" # the .pt file name
25
+ MODEL_NAME = "microsoft/deberta-base" # base tokenizer/backbone
26
+ LABELS = ["benign", "phishing"] # adjust to your classes
27
+
28
+ # If your checkpoint contains hyperparams, you can fetch them like:
29
+ # checkpoint.get("config") or checkpoint.get("model_args")
30
+ # and pass into DeBERTaLSTMClassifier(**model_args)
31
 
32
  # --------- Load model/tokenizer once (global) ----------
33
  device = "cuda" if torch.cuda.is_available() else "cpu"
34
  tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
35
 
 
36
  ckpt_path = hf_hub_download(repo_id=REPO_ID, filename=CKPT_NAME)
37
  checkpoint = torch.load(ckpt_path, map_location=device)
38
 
39
+ # If you saved hyperparams in the checkpoint, use them:
40
+ model_args = checkpoint.get("model_args", {}) # e.g., {"lstm_hidden":256, "num_labels":2, ...}
41
  model = DeBERTaLSTMClassifier(**model_args)
42
 
43
  # Load weights
44
  try:
45
  state_dict = torch.load(ckpt_path, map_location=device)
46
+
47
+ # Handle files saved as a full checkpoint (containing a "model_state_dict" key)
48
  if "model_state_dict" in state_dict:
49
  state_dict = state_dict["model_state_dict"]
50
 
51
  model.load_state_dict(state_dict, strict=False)
52
 
53
+ # Check for the attention layer
54
  if hasattr(model, 'attention') and 'attention.weight' not in state_dict:
55
  print("⚠️ Loaded model without attention layer, using newly initialized attention weights")
56
  else:
57
+ print("✅ Load weights successfully!")
58
 
59
  except Exception as e:
60
  print(f"❌ Error when loading weights: {e}")
 
63
  model.to(device).eval()
64
 
65
  # --------- Initialize RAG & LLM ----------
66
+ print("Initializing RAG Engine (LangChain)...")
67
+ rag_engine = RAGEngine()
68
+ print("RAG Engine ready.")
69
 
70
+ print("Initializing Qwen2.5-3B LLM (LangChain)...")
71
+ # Pass vector_store to LLMClient for RetrievalQA
72
+ llm_client = LLMClient(vector_store=rag_engine.vector_store)
73
+ print("LLM ready.")
 
 
 
 
 
 
 
74
 
75
  # --------- Helper functions ----------
76
  def is_url(text):
77
+ """Check if text is a URL"""
78
  url_pattern = re.compile(
79
+ r'^https?://' # http:// or https://
80
+ r'(?:(?:[A-Z0-9](?:[A-Z0-9-]{0,61}[A-Z0-9])?\.)+[A-Z]{2,6}\.?|' # domain...
81
+ r'localhost|' # localhost...
82
+ r'\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3})' # ...or ip
83
+ r'(?::\d+)?' # optional port
84
  r'(?:/?|[/?]\S+)$', re.IGNORECASE)
85
  return url_pattern.match(text) is not None
86
 
87
  def fetch_html_content(url, timeout=10):
88
+ """Fetch HTML content from URL"""
89
  try:
90
  headers = {
91
  'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36'
 
94
  response.raise_for_status()
95
 
96
  soup = BeautifulSoup(response.text, 'html.parser')
97
+
98
  for script in soup(["script", "style", "meta", "noscript", "header", "footer"]):
99
  script.decompose()
100
 
101
  text = soup.get_text(separator=' ')
102
+
103
  clean_text = " ".join(text.split())
104
  return clean_text, response.status_code
105
  except requests.exceptions.RequestException as e:
 
108
  return None, f"General error: {str(e)}"
109
 
110
  def predict_single_text(text, text_type="text"):
111
+ """Predict for a single text input"""
112
+ # Tokenize
113
  inputs = tokenizer(
114
  text,
115
  return_tensors="pt",
 
117
  padding=True,
118
  max_length=256
119
  )
120
+ # DeBERTa typically doesn't use token_type_ids
121
  inputs.pop("token_type_ids", None)
122
+ # Move to device
123
  inputs = {k: v.to(device) for k, v in inputs.items()}
124
 
125
  with torch.no_grad():
126
  try:
127
+ # Try to get predictions with attention weights
128
  result = model(**inputs, return_attention=True)
129
  if isinstance(result, tuple) and len(result) == 3:
130
  logits, attention_weights, deberta_attentions = result
 
133
  logits = result
134
  has_attention = False
135
  except TypeError:
136
+ # Fallback for older model without return_attention parameter
137
  logits = model(**inputs)
138
  has_attention = False
139
 
140
  probs = F.softmax(logits, dim=-1).squeeze(0).tolist()
141
 
142
+ # Get tokens for visualization
143
  tokens = tokenizer.convert_ids_to_tokens(inputs['input_ids'].squeeze(0).tolist())
144
+
145
  return probs, tokens, has_attention, attention_weights if has_attention else None
146
 
147
  def combine_predictions(url_probs, html_probs, url_weight=0.3, html_weight=0.7):
148
+ """Combine URL and HTML content predictions"""
149
  combined_probs = [
150
+ url_weight * url_probs[0] + html_weight * html_probs[0], # benign
151
+ url_weight * url_probs[1] + html_weight * html_probs[1] # phishing
152
  ]
153
  return combined_probs
154
 
155
+ # --------- Inference function ----------
156
  def predict_fn(text: str):
157
  if not text or not text.strip():
158
+ return {"error": "Please enter a URL or text."}, ""
159
 
160
+ # Check if input is URL
161
  if is_url(text.strip()):
162
+ # Process URL
163
  url = text.strip()
164
+
165
+ # Get prediction for URL itself
166
  url_probs, url_tokens, url_has_attention, url_attention = predict_single_text(url, "URL")
167
+
168
+ # Try to fetch HTML content
169
  html_content, status = fetch_html_content(url)
170
 
171
  if html_content:
172
+ # Get prediction for HTML content
173
  html_probs, html_tokens, html_has_attention, html_attention = predict_single_text(html_content, "HTML")
174
+
175
+ # Combine predictions
176
  combined_probs = combine_predictions(url_probs, html_probs)
177
 
178
+ # Use combined probabilities but show analysis for both
179
  probs = combined_probs
180
+ tokens = url_tokens + ["[SEP]"] + html_tokens[:50] # Limit HTML tokens for display
181
  has_attention = url_has_attention or html_has_attention
182
  attention_weights = url_attention if url_has_attention else html_attention
183
 
184
  analysis_type = "Combined URL + HTML Analysis"
185
  fetch_status = f"✅ Successfully fetched HTML content (Status: {status})"
186
+
187
  else:
188
+ # Fallback to URL-only analysis
189
  probs = url_probs
190
  tokens = url_tokens
191
  has_attention = url_has_attention
192
  attention_weights = url_attention
193
+
194
  analysis_type = "URL-only Analysis"
195
  fetch_status = f"⚠️ Could not fetch HTML content: {status}"
196
  else:
197
+ # Process as regular text
198
  probs, tokens, has_attention, attention_weights = predict_single_text(text, "text")
199
  analysis_type = "Text Analysis"
200
  fetch_status = ""
201
 
202
+ # Create detailed analysis
203
  predicted_class = "phishing" if probs[1] > probs[0] else "benign"
204
  confidence = max(probs)
205
 
 
218
  </div>
219
  </div>
220
  """
221
+
222
  if fetch_status:
223
  detailed_analysis += f"""
224
  <div style="background: #2d2d2d; padding: 15px; border-radius: 10px; margin: 15px 0; border-left: 4px solid #4caf50; color: #e0e0e0;">
 
228
 
229
  if has_attention and attention_weights is not None:
230
  attention_scores = attention_weights.squeeze(0).tolist()
231
+
232
  token_analysis = []
233
  for i, (token, score) in enumerate(zip(tokens, attention_scores)):
234
+ # More lenient filtering - include more tokens for text analysis
235
  if token not in ['[CLS]', '[SEP]', '[PAD]', '<s>', '</s>'] and len(token.strip()) > 0 and score > 0.005:
236
+ clean_token = token.replace(' ', '').replace('Ġ', '').strip() # Handle different tokenizer prefixes
237
+ if clean_token: # Only add if token has content after cleaning
238
+ token_analysis.append({
239
+ 'token': clean_token,
240
+ 'importance': score,
241
+ 'position': i
242
+ })
243
 
244
+ # Sort by importance
245
  token_analysis.sort(key=lambda x: x['importance'], reverse=True)
246
 
247
  detailed_analysis += f"""
248
  ## Top important tokens:
249
  <div style="background: #2d2d2d; padding: 15px; border-radius: 10px; margin: 15px 0; border-left: 4px solid #4caf50; color: #e0e0e0;">
250
+ <strong>Analysis Info:</strong> Found {len(token_analysis)} important tokens out of {len(tokens)} total tokens
251
  </div>
252
  <div style="font-family: Arial, sans-serif;">
253
  """
254
+
255
+ for i, token_info in enumerate(token_analysis[:10]): # Top 10 tokens
256
  bar_width = int(token_info['importance'] * 100)
257
  color = "#ff4444" if predicted_class == "phishing" else "#44ff44"
258
+
259
  detailed_analysis += f"""
260
  <div style="margin: 8px 0; display: flex; align-items: center; background: #2d2d2d; padding: 8px; border-radius: 8px; border-left: 4px solid {color};">
261
+ <div style="width: 30px; text-align: right; margin-right: 10px; font-weight: bold; color: #ffffff;">
262
+ {i+1}.
263
+ </div>
264
+ <div style="width: 120px; margin-right: 10px; font-weight: bold; color: #e0e0e0; text-align: right;">
265
+ {token_info['token']}
266
+ </div>
267
  <div style="width: 300px; background-color: #404040; border-radius: 10px; overflow: hidden; margin-right: 10px; border: 1px solid #555;">
268
+ <div style="width: {bar_width}%; background-color: {color}; height: 20px; border-radius: 10px; transition: width 0.3s ease;"></div>
269
+ </div>
270
+ <div style="color: #cccccc; font-size: 12px; font-weight: bold;">
271
+ {token_info['importance']:.1%}
272
  </div>
 
273
  </div>
274
  """
 
275
 
276
+ detailed_analysis += "</div>\n"
277
+
278
  detailed_analysis += f"""
279
+ ## Detailed analysis:
280
+ <div style="font-family: Arial, sans-serif; background: linear-gradient(135deg, #1a237e 0%, #3949ab 100%); padding: 20px; border-radius: 15px; color: white; margin: 15px 0; border: 2px solid #3f51b5;">
281
+ <h3 style="margin: 0 0 15px 0; color: white;">Statistical Overview</h3>
282
+ <div style="display: grid; grid-template-columns: repeat(2, 1fr); gap: 15px;">
283
+ <div style="background: rgba(255,255,255,0.1); padding: 15px; border-radius: 10px; border: 1px solid rgba(255,255,255,0.2);">
284
+ <div style="font-size: 24px; font-weight: bold; color: white;">{len([t for t in tokens if t not in ['[CLS]', '[SEP]', '[PAD]']])}</div>
285
+ <div style="font-size: 14px; color: #e0e0e0;">Total tokens</div>
286
+ </div>
287
+ <div style="background: rgba(255,255,255,0.1); padding: 15px; border-radius: 10px; border: 1px solid rgba(255,255,255,0.2);">
288
+ <div style="font-size: 24px; font-weight: bold, color: white;">{len([t for t in token_analysis if t['importance'] > 0.05])}</div>
289
+ <div style="font-size: 14px, color: #e0e0e0;">High impact tokens (>5%)</div>
290
+ </div>
291
+ </div>
292
+ </div>
293
  <div style="font-family: Arial, sans-serif; margin: 15px 0; background: #2d2d2d; padding: 20px; border-radius: 15px; border: 1px solid #555;">
294
  <h3 style="color: #ffffff; margin-bottom: 15px;"> Prediction Confidence</h3>
295
  <div style="display: flex; justify-content: space-between; margin-bottom: 10px;">
 
301
  {probs[1]:.1%}
302
  </div>
303
  </div>
304
+ <div style="margin-top: 10px; text-align: center; color: #cccccc; font-size: 14px;">
305
+ Benign: {probs[0]:.1%}
306
+ </div>
307
  </div>
308
  """
309
  else:
310
+ # Fallback analysis without attention weights
311
  detailed_analysis += f"""
312
+ <div style="background: linear-gradient(135deg, #1a237e 0%, #3949ab 100%); padding: 20px; border-radius: 15px; color: white; margin: 15px 0; border: 2px solid #3f51b5;">
313
+ <h3 style="margin: 0 0 15px 0; color: white;">Basic Analysis</h3>
314
+ <div style="display: grid; grid-template-columns: repeat(3, 1fr); gap: 15px;">
315
+ <div style="background: rgba(255,255,255,0.1); padding: 15px; border-radius: 10px; text-align: center; border: 1px solid rgba(255,255,255,0.2);">
316
+ <div style="font-size: 24px; font-weight: bold; color: white;">{probs[1]:.1%}</div>
317
+ <div style="font-size: 14px; color: #e0e0e0;">Phishing</div>
318
+ </div>
319
+ <div style="background: rgba(255,255,255,0.1); padding: 15px; border-radius: 10px; text-align: center; border: 1px solid rgba(255,255,255,0.2);">
320
+ <div style="font-size: 24px; font-weight: bold; color: white;">{probs[0]:.1%}</div>
321
+ <div style="font-size: 14px; color: #e0e0e0;">Benign</div>
322
+ </div>
323
+ <div style="background: rgba(255,255,255,0.1); padding: 15px; border-radius: 10px; text-align: center; border: 1px solid rgba(255,255,255,0.2);">
324
+ <div style="font-size: 24px; font-weight: bold; color: white;">{len([t for t in tokens if t not in ['[CLS]', '[SEP]', '[PAD]']])}</div>
325
+ <div style="font-size: 14px; color: #e0e0e0;">Tokens</div>
326
+ </div>
327
+ </div>
328
+ </div>
329
+ <div style="font-family: Arial, sans-serif; margin: 15px 0; background: #2d2d2d; padding: 20px; border-radius: 15px; border: 1px solid #555;">
330
+ <h3 style="color: #ffffff; margin: 0 0 15px 0;">🔤 Tokens in text:</h3>
331
+ <div style="display: flex; flex-wrap: wrap; gap: 8px;">""" + ''.join([f'<span style="background: #404040; color: #64b5f6; padding: 4px 8px; border-radius: 15px; font-size: 12px; border: 1px solid #666;">{token.replace(" ", "")}</span>' for token in tokens if token not in ['[CLS]', '[SEP]', '[PAD]']]) + f"""</div>
332
+ <div style="margin-top: 15px; padding: 10px; background: #3d2914; border-radius: 8px; border-left: 4px solid #ff9800;">
333
+ <strong style="color: #ffcc02;">Debug info:</strong> <span style="color: #e0e0e0;">Found {len(tokens)} total tokens, {len([t for t in tokens if t not in ['[CLS]', '[SEP]', '[PAD]']])} content tokens</span>
334
+ </div>
335
+ </div>
336
+ <div style="background: #3d2914; padding: 15px; border-radius: 10px; border-left: 4px solid #ff9800; margin: 15px 0;">
337
+ <p style="margin: 0; color: #ffcc02; font-size: 14px;">
338
+ <strong>Note:</strong> Detailed attention weights analysis is not available for the current model.
339
+ </p>
340
  </div>
341
  """
342
+
343
+ # Build label->prob mapping for Gradio Label output
344
+ if len(LABELS) == len(probs):
345
+ prediction_result = {LABELS[i]: float(probs[i]) for i in range(len(LABELS))}
346
+ else:
347
+ prediction_result = {f"class_{i}": float(p) for i, p in enumerate(probs)}
348
 
349
+ return prediction_result, detailed_analysis
350
 
351
  # --------- RAG Inference function ----------
352
  def rag_predict_fn(text: str):
353
  if not text or not text.strip():
354
  return "Please enter text to analyze."
355
 
356
+ # Call LLM (which now handles retrieval internally via LangChain)
357
+ response = llm_client.analyze(text)
358
+
359
+ return response
360
 
361
+ # --------- Refresh Knowledge Base function ----------
362
def refresh_kb():
    """Ask the RAG engine to refresh its knowledge base.

    Returns:
        The value returned by ``rag_engine.refresh_knowledge_base()``, or an
        error message string if the refresh raises.
    """
    try:
        return rag_engine.refresh_knowledge_base()
    except Exception as e:
        # UI boundary: show the failure in the status panel rather than
        # letting the exception escape the Gradio handler.
        return f"❌ Error refreshing knowledge base: {e}"
 
 
 
 
 
364
 
365
  # --------- Gradio UI ----------
366
  css_style="""
 
372
  .dark .gradio-container {
373
  background-color: #1e1e1e !important;
374
  }
375
+ /* Dark theme for all components */
376
  .block {
377
  background-color: #2d2d2d !important;
378
  border: 1px solid #444 !important;
379
  }
380
+ .gradio-textbox {
381
  background-color: #3d3d3d !important;
382
  color: #ffffff !important;
383
  border: 1px solid #666 !important;
 
389
  }
390
  .gradio-button:hover {
391
  background-color: #5a5a5a !important;
392
+ color: #ffffff !important;
393
+ border: 1px solid #666 !important;
394
  }
395
  """
396
def _standard_detection(text):
    """Adapt predict_fn's two return values to the (Label, HTML) outputs.

    predict_fn returns ``(scores, analysis_html)``. For empty input it
    returns ``({"error": message}, "")``, which is not a label -> probability
    mapping, so that message is routed to the HTML panel and the label is
    cleared instead.
    """
    scores, analysis_html = predict_fn(text)
    if isinstance(scores, dict) and isinstance(scores.get("error"), str):
        return None, f"<p>{scores['error']}</p>"
    return scores, analysis_html


# Two-tab Gradio app: tab 1 runs the DeBERTa + LSTM classifier, tab 2 the
# LLM + RAG assistant.
with gr.Blocks(css=css_style) as demo:
    gr.Markdown("# 🛡️ Phishing Detector (DeBERTa + LSTM + RAG)")

    with gr.Tabs():
        # --- Tab 1: Standard Detection ---
        with gr.TabItem("🔍 Standard Detection"):
            gr.Markdown("""
            Enter a URL or text for analysis using the DeBERTa + LSTM model.

            **Features:**
            - **URL Analysis**: For URLs, the system will fetch HTML content and combine both URL and content analysis
            - **Combined Prediction**: Uses weighted combination of URL structure and webpage content analysis
            - **Visual Analysis**: Predict phishing/benign probability with visual charts
            - **Token Importance**: Display the most important tokens in classification
            - **Detailed Insights**: Comprehensive analysis of the impact of each token

            **How it works for URLs:**
            1. Analyze the URL structure itself
            2. Fetch the webpage HTML content
            3. Analyze the webpage content
            4. Combine both results for final prediction (30% URL + 70% content)
            """)

            with gr.Row():
                with gr.Column(scale=2):
                    input_box = gr.Textbox(
                        label="URL or text",
                        placeholder="Example: http://suspicious-site.example or paste any text",
                        lines=3
                    )
                    btn_submit = gr.Button("🔍 Analyze", variant="primary")

                    gr.Examples(
                        examples=[
                            ["http://rendmoiunserviceeee.com"],
                            ["https://www.google.com"],
                            ["Dear customer, your account has been suspended. Click here to verify your identity immediately."],
                            ["https://mail-secure-login-verify.example/path?token=suspicious"],
                            ["http://paypaI-security-update.net/login"],
                            ["Your package has been delivered successfully. Thank you for using our service."],
                            ["https://github.com/user/repo"],
                            ["Dear customer, your account has been suspended. Click here to verify."],
                        ],
                        inputs=input_box
                    )

                with gr.Column(scale=3):
                    # predict_fn yields a label->probability mapping and an
                    # HTML report, so there is one component for each.
                    output_label = gr.Label(label="Prediction")
                    output_html = gr.HTML(label="Analysis Result")

            # Two outputs, matching the two values the handler returns.
            btn_submit.click(
                fn=_standard_detection,
                inputs=input_box,
                outputs=[output_label, output_html],
            )

        # --- Tab 2: LLM + RAG Analysis ---
        with gr.TabItem("🤖 AI Assistant (RAG)"):
            gr.Markdown("""
            **AI Assistant** uses **Qwen2.5-3B** + **LangChain** to explain *why* a message is suspicious.

            **Features:**
            - 🌐 Multilingual support (English + Vietnamese)
            - 📚 Knowledge Base retrieval (Auto-sync)
            - 🚀 No rate limits (self-hosted)
            """)

            with gr.Row():
                with gr.Column(scale=1):
                    rag_input = gr.Textbox(
                        label="Suspicious Text/URL",
                        placeholder="Paste the email content or URL here...",
                        lines=5
                    )
                    with gr.Row():
                        btn_rag = gr.Button("🤖 Ask AI Assistant", variant="primary")
                        btn_refresh = gr.Button("♻️ Refresh Knowledge Base")

                    gr.Examples(
                        examples=[
                            ["Your PayPal account has been suspended. Click http://paypal-verify.com to unlock."],
                            ["Tài khoản ngân hàng của bạn bị khóa. Nhấn vào đây để mở khóa ngay."],
                            ["Your package is ready for delivery. Track here: https://fedex-track.com"],
                        ],
                        inputs=rag_input
                    )

                with gr.Column(scale=1):
                    rag_output = gr.Markdown(label="AI Analysis")
                    refresh_output = gr.Markdown(label="Status")

            btn_rag.click(fn=rag_predict_fn, inputs=[rag_input], outputs=rag_output)
            btn_refresh.click(fn=refresh_kb, inputs=[], outputs=refresh_output)
484
 
485
  if __name__ == "__main__":
486
  demo.launch()