dungeon29 committed on
Commit
72b29fc
·
verified ·
1 Parent(s): 36d2903

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +228 -123
app.py CHANGED
@@ -6,83 +6,70 @@ import gradio as gr
6
  import requests
7
  import re
8
  from urllib.parse import urlparse
9
- from bs4 import BeautifulSoup
10
  import time
11
- import joblib
12
  import os
13
 
14
  # --- import your architecture ---
15
- from model import DeBERTaLSTMClassifier
 
 
16
 
17
  # --- Import RAG modules ---
18
  from rag_engine import RAGEngine
19
  from llm_client import LLMClient
20
 
21
  # --------- Config ----------
22
- REPO_ID = "khoa-done/phishing-detector" # HF repo that holds the checkpoint
23
- CKPT_NAME = "deberta_lstm_checkpoint.pt" # the .pt file name
24
  MODEL_NAME = "microsoft/deberta-base" # base tokenizer/backbone
25
  LABELS = ["benign", "phishing"] # adjust to your classes
26
 
27
- # --------- Load DeBERTa model/tokenizer once (global) ----------
28
- print("🔷 Loading DeBERTa Model...")
 
 
 
29
  device = "cuda" if torch.cuda.is_available() else "cpu"
30
  tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
31
 
32
- # Check if checkpoint exists locally, otherwise download from HF
33
- if os.path.exists(CKPT_NAME):
34
- print(f"📂 Found local checkpoint: {CKPT_NAME}")
35
- ckpt_path = CKPT_NAME
36
- else:
37
- print(f"⬇️ Downloading checkpoint {CKPT_NAME} from HF Hub...")
38
- try:
39
- ckpt_path = hf_hub_download(repo_id=REPO_ID, filename=CKPT_NAME)
40
- except Exception as e:
41
- print(f"⚠️ Could not download from HF: {e}")
42
- print("🔄 Trying fallback to pytorch_model.bin...")
43
- ckpt_path = hf_hub_download(repo_id=REPO_ID, filename="pytorch_model.bin")
44
-
45
  checkpoint = torch.load(ckpt_path, map_location=device)
46
 
47
- # Initialize model
48
- if isinstance(checkpoint, dict):
49
- model_args = checkpoint.get("model_args", {})
50
- else:
51
- model_args = {}
52
  model = DeBERTaLSTMClassifier(**model_args)
53
 
54
  # Load weights
55
  try:
56
- if isinstance(checkpoint, dict) and "model_state_dict" in checkpoint:
57
- state_dict = checkpoint["model_state_dict"]
58
- elif isinstance(checkpoint, dict):
59
- state_dict = checkpoint
60
- else:
61
- state_dict = checkpoint
62
 
63
  model.load_state_dict(state_dict, strict=False)
64
 
 
65
  if hasattr(model, 'attention') and 'attention.weight' not in state_dict:
66
  print("⚠️ Loaded model without attention layer, using newly initialized attention weights")
67
  else:
68
- print("✅ DeBERTa Weights loaded successfully!")
69
 
70
  except Exception as e:
71
  print(f"❌ Error when loading weights: {e}")
72
- # Don't raise, allow app to start even if model fails (for debugging)
73
 
74
  model.to(device).eval()
75
 
76
  # --------- Initialize RAG & LLM ----------
77
- print("🔷 Initializing RAG Engine (LangChain)...")
78
  rag_engine = RAGEngine()
79
- print("RAG Engine ready.")
80
 
81
- print("🔷 Initializing Qwen2.5-1.5B LLM (LangChain)...")
82
  # Pass vector_store to LLMClient for RetrievalQA
83
  llm_client = LLMClient(vector_store=rag_engine.vector_store)
84
- print("LLM ready.")
85
-
86
 
87
  # --------- Helper functions ----------
88
  def is_url(text):
@@ -104,7 +91,7 @@ def fetch_html_content(url, timeout=10):
104
  }
105
  response = requests.get(url, headers=headers, timeout=timeout, verify=False)
106
  response.raise_for_status()
107
-
108
  return response.text, response.status_code
109
  except requests.exceptions.RequestException as e:
110
  return None, f"Request error: {str(e)}"
@@ -113,6 +100,7 @@ def fetch_html_content(url, timeout=10):
113
 
114
  def predict_single_text(text, text_type="text"):
115
  """Predict for a single text input"""
 
116
  inputs = tokenizer(
117
  text,
118
  return_tensors="pt",
@@ -120,11 +108,14 @@ def predict_single_text(text, text_type="text"):
120
  padding=True,
121
  max_length=256
122
  )
 
123
  inputs.pop("token_type_ids", None)
 
124
  inputs = {k: v.to(device) for k, v in inputs.items()}
125
 
126
  with torch.no_grad():
127
  try:
 
128
  result = model(**inputs, return_attention=True)
129
  if isinstance(result, tuple) and len(result) == 3:
130
  logits, attention_weights, deberta_attentions = result
@@ -133,11 +124,13 @@ def predict_single_text(text, text_type="text"):
133
  logits = result
134
  has_attention = False
135
  except TypeError:
 
136
  logits = model(**inputs)
137
  has_attention = False
138
 
139
  probs = F.softmax(logits, dim=-1).squeeze(0).tolist()
140
 
 
141
  tokens = tokenizer.convert_ids_to_tokens(inputs['input_ids'].squeeze(0).tolist())
142
 
143
  return probs, tokens, has_attention, attention_weights if has_attention else None
@@ -150,38 +143,54 @@ def combine_predictions(url_probs, html_probs, url_weight=0.3, html_weight=0.7):
150
  ]
151
  return combined_probs
152
 
153
- # --------- DeBERTa Inference function ----------
154
  def predict_fn(text: str):
155
  if not text or not text.strip():
156
  return {"error": "Please enter a URL or text."}, ""
157
 
158
  # Check if input is URL
159
  if is_url(text.strip()):
 
160
  url = text.strip()
 
 
161
  url_probs, url_tokens, url_has_attention, url_attention = predict_single_text(url, "URL")
 
 
162
  html_content, status = fetch_html_content(url)
163
 
164
  if html_content:
 
165
  html_probs, html_tokens, html_has_attention, html_attention = predict_single_text(html_content, "HTML")
 
 
166
  combined_probs = combine_predictions(url_probs, html_probs)
 
 
167
  probs = combined_probs
168
- tokens = url_tokens + ["[SEP]"] + html_tokens[:50]
169
  has_attention = url_has_attention or html_has_attention
170
  attention_weights = url_attention if url_has_attention else html_attention
 
171
  analysis_type = "Combined URL + HTML Analysis"
172
  fetch_status = f"✅ Successfully fetched HTML content (Status: {status})"
 
173
  else:
 
174
  probs = url_probs
175
  tokens = url_tokens
176
  has_attention = url_has_attention
177
  attention_weights = url_attention
 
178
  analysis_type = "URL-only Analysis"
179
  fetch_status = f"⚠️ Could not fetch HTML content: {status}"
180
  else:
 
181
  probs, tokens, has_attention, attention_weights = predict_single_text(text, "text")
182
  analysis_type = "Text Analysis"
183
  fetch_status = ""
184
 
 
185
  predicted_class = "phishing" if probs[1] > probs[0] else "benign"
186
  confidence = max(probs)
187
 
@@ -210,13 +219,20 @@ def predict_fn(text: str):
210
 
211
  if has_attention and attention_weights is not None:
212
  attention_scores = attention_weights.squeeze(0).tolist()
 
213
  token_analysis = []
214
  for i, (token, score) in enumerate(zip(tokens, attention_scores)):
 
215
  if token not in ['[CLS]', '[SEP]', '[PAD]', '<s>', '</s>'] and len(token.strip()) > 0 and score > 0.005:
216
- clean_token = token.replace(' ', '').replace('Ġ', '').strip()
217
- if clean_token:
218
- token_analysis.append({'token': clean_token, 'importance': score, 'position': i})
 
 
 
 
219
 
 
220
  token_analysis.sort(key=lambda x: x['importance'], reverse=True)
221
 
222
  detailed_analysis += f"""
@@ -226,20 +242,94 @@ def predict_fn(text: str):
226
  </div>
227
  <div style="font-family: Arial, sans-serif;">
228
  """
229
- for i, token_info in enumerate(token_analysis[:10]):
 
230
  bar_width = int(token_info['importance'] * 100)
231
  color = "#ff4444" if predicted_class == "phishing" else "#44ff44"
 
232
  detailed_analysis += f"""
233
  <div style="margin: 8px 0; display: flex; align-items: center; background: #2d2d2d; padding: 8px; border-radius: 8px; border-left: 4px solid {color};">
234
- <div style="width: 30px; text-align: right; margin-right: 10px; font-weight: bold; color: #ffffff;">{i+1}.</div>
235
- <div style="width: 120px; margin-right: 10px; font-weight: bold; color: #e0e0e0; text-align: right;">{token_info['token']}</div>
 
 
 
 
236
  <div style="width: 300px; background-color: #404040; border-radius: 10px; overflow: hidden; margin-right: 10px; border: 1px solid #555;">
237
  <div style="width: {bar_width}%; background-color: {color}; height: 20px; border-radius: 10px; transition: width 0.3s ease;"></div>
238
  </div>
239
- <div style="color: #cccccc; font-size: 12px; font-weight: bold;">{token_info['importance']:.1%}</div>
 
 
240
  </div>
241
  """
 
242
  detailed_analysis += "</div>\n"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
243
 
244
  # Build label->prob mapping for Gradio Label output
245
  if len(LABELS) == len(probs):
@@ -266,7 +356,7 @@ def rag_predict_fn(text: str):
266
 
267
  if fetched_content:
268
  # Limit content length to avoid token overflow
269
- truncated_content = fetched_content[:1500]
270
  analysis_context = f"URL: {input_text}\n\nWebsite Content:\n{truncated_content}\n..."
271
  print(f"✅ Successfully fetched {len(fetched_content)} chars from URL.")
272
  else:
@@ -278,11 +368,12 @@ def rag_predict_fn(text: str):
278
 
279
  return response
280
 
 
281
  def refresh_kb():
282
  return rag_engine.refresh_knowledge_base()
283
 
284
  # --------- Gradio UI ----------
285
- css_style = """
286
  .gradio-container {
287
  font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
288
  background-color: #1e1e1e !important;
@@ -308,84 +399,98 @@ css_style = """
308
  }
309
  .gradio-button:hover {
310
  background-color: #5a5a5a !important;
 
 
311
  }
312
  """
313
-
314
- # Tab 1: DeBERTa
315
- deberta_interface = gr.Interface(
316
- fn=predict_fn,
317
- inputs=gr.Textbox(label="URL or text", placeholder="Example: http://suspicious-site.example or paste any text"),
318
- outputs=[
319
- gr.Label(label="Prediction result"),
320
- gr.Markdown(label="Detailed token analysis")
321
- ],
322
- title="Phishing Detector (DeBERTa + LSTM)",
323
- description="""
324
- Enter a URL or text for analysis.
325
- **Features:**
326
- - **URL Analysis**: For URLs, the system will fetch HTML content and combine both URL and content analysis
327
- - **Combined Prediction**: Uses weighted combination of URL structure and webpage content analysis
328
- - **Visual Analysis**: Predict phishing/benign probability with visual charts
329
- - **Token Importance**: Display the most important tokens in classification
330
- - **Detailed Insights**: Comprehensive analysis of the impact of each token
331
- - **Dark Theme**: Beautiful interface with colorful charts optimized for dark themes
332
- """,
333
- examples=[
334
- ["http://rendmoiunserviceeee.com"],
335
- ["https://www.google.com"],
336
- ["Dear customer, your account has been suspended. Click here to verify your identity immediately."],
337
- ["https://mail-secure-login-verify.example/path?token=suspicious"],
338
- ["http://paypaI-security-update.net/login"],
339
- ["Your package has been delivered successfully. Thank you for using our service."],
340
- ["https://github.com/user/repo"]
341
- ]
342
- )
343
-
344
- # Tab 2: LLM + RAG
345
- with gr.Blocks() as rag_interface:
346
- gr.Markdown("# 🤖 AI Assistant (RAG)")
347
- gr.Markdown("""
348
- **AI Assistant** uses **Qwen2.5-1.5B** + **LangChain** to explain *why* a message is suspicious.
349
 
350
- **Features:**
351
- - 🌐 Multilingual support (English + Vietnamese)
352
- - 📚 Knowledge Base retrieval (Auto-sync)
353
- - � **Auto-Fetch URL**: Automatically reads website content for analysis
354
- """)
355
-
356
- with gr.Row():
357
- with gr.Column(scale=1):
358
- rag_input = gr.Textbox(
359
- label="Suspicious Text/URL",
360
- placeholder="Paste the email content or URL here...",
361
- lines=5
362
- )
 
 
 
 
 
 
 
363
  with gr.Row():
364
- btn_rag = gr.Button("🤖 Ask AI Assistant", variant="primary")
365
- btn_refresh = gr.Button("♻️ Refresh Knowledge Base")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
366
 
367
- gr.Examples(
368
- examples=[
369
- ["Your PayPal account has been suspended. Click http://paypal-verify.com to unlock."],
370
- ["Tài khoản ngân hàng của bạn bị khóa. Nhấn vào đây để mở khóa ngay."],
371
- ["Your package is ready for delivery. Track here: https://fedex-track.com"],
372
- ],
373
- inputs=rag_input
374
- )
375
-
376
- with gr.Column(scale=1):
377
- rag_output = gr.Markdown(label="AI Analysis")
378
- refresh_output = gr.Markdown(label="Status")
379
-
380
- btn_rag.click(fn=rag_predict_fn, inputs=[rag_input], outputs=rag_output)
381
- btn_refresh.click(fn=refresh_kb, inputs=[], outputs=refresh_output)
382
-
383
-
384
- # Combine Tabs
385
- demo = gr.TabbedInterface(
386
- [deberta_interface, rag_interface],
387
- ["DeBERTa + LSTM", "AI Assistant (RAG)"]
388
- )
 
 
 
 
 
 
 
 
 
 
389
 
390
  if __name__ == "__main__":
391
  demo.launch(ssr_mode=False)
 
6
  import requests
7
  import re
8
  from urllib.parse import urlparse
 
9
  import time
 
10
  import os
11
 
12
  # --- import your architecture ---
13
+ # Make sure this file is in the repo (e.g., models/deberta_lstm_classifier.py)
14
+ # and update the import path accordingly.
15
+ from model import DeBERTaLSTMClassifier # <-- your class
16
 
17
  # --- Import RAG modules ---
18
  from rag_engine import RAGEngine
19
  from llm_client import LLMClient
20
 
21
  # --------- Config ----------
22
+ REPO_ID = "dungeon29/phishing-deberta-lstm" # HF repo that holds the checkpoint
23
+ CKPT_NAME = "pytorch_model.bin" # checkpoint file name inside the HF repo
24
  MODEL_NAME = "microsoft/deberta-base" # base tokenizer/backbone
25
  LABELS = ["benign", "phishing"] # adjust to your classes
26
 
27
+ # If your checkpoint contains hyperparams, you can fetch them like:
28
+ # checkpoint.get("config") or checkpoint.get("model_args")
29
+ # and pass into DeBERTaLSTMClassifier(**model_args)
30
+
31
+ # --------- Load model/tokenizer once (global) ----------
32
  device = "cuda" if torch.cuda.is_available() else "cpu"
33
  tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
34
 
35
+ ckpt_path = hf_hub_download(repo_id=REPO_ID, filename=CKPT_NAME)
 
 
 
 
 
 
 
 
 
 
 
 
36
  checkpoint = torch.load(ckpt_path, map_location=device)
37
 
38
+ # If you saved hyperparams in the checkpoint, use them:
39
+ model_args = checkpoint.get("model_args", {}) # e.g., {"lstm_hidden":256, "num_labels":2, ...}
 
 
 
40
  model = DeBERTaLSTMClassifier(**model_args)
41
 
42
  # Load weights
43
  try:
44
+ state_dict = torch.load(ckpt_path, map_location=device)
45
+
46
+ # Handle files saved as a full checkpoint (containing a "model_state_dict" key)
47
+ if "model_state_dict" in state_dict:
48
+ state_dict = state_dict["model_state_dict"]
 
49
 
50
  model.load_state_dict(state_dict, strict=False)
51
 
52
+ # Check whether the checkpoint provides the attention layer weights
53
  if hasattr(model, 'attention') and 'attention.weight' not in state_dict:
54
  print("⚠️ Loaded model without attention layer, using newly initialized attention weights")
55
  else:
56
+ print("✅ Load weights successfully!")
57
 
58
  except Exception as e:
59
  print(f"❌ Error when loading weights: {e}")
60
+ raise e
61
 
62
  model.to(device).eval()
63
 
64
  # --------- Initialize RAG & LLM ----------
65
+ print("Initializing RAG Engine (LangChain)...")
66
  rag_engine = RAGEngine()
67
+ print("RAG Engine ready.")
68
 
69
+ print("Initializing Qwen2.5-3B LLM (LangChain)...")
70
  # Pass vector_store to LLMClient for RetrievalQA
71
  llm_client = LLMClient(vector_store=rag_engine.vector_store)
72
+ print("LLM ready.")
 
73
 
74
  # --------- Helper functions ----------
75
  def is_url(text):
 
91
  }
92
  response = requests.get(url, headers=headers, timeout=timeout, verify=False)
93
  response.raise_for_status()
94
+
95
  return response.text, response.status_code
96
  except requests.exceptions.RequestException as e:
97
  return None, f"Request error: {str(e)}"
 
100
 
101
  def predict_single_text(text, text_type="text"):
102
  """Predict for a single text input"""
103
+ # Tokenize
104
  inputs = tokenizer(
105
  text,
106
  return_tensors="pt",
 
108
  padding=True,
109
  max_length=256
110
  )
111
+ # DeBERTa typically doesn't use token_type_ids
112
  inputs.pop("token_type_ids", None)
113
+ # Move to device
114
  inputs = {k: v.to(device) for k, v in inputs.items()}
115
 
116
  with torch.no_grad():
117
  try:
118
+ # Try to get predictions with attention weights
119
  result = model(**inputs, return_attention=True)
120
  if isinstance(result, tuple) and len(result) == 3:
121
  logits, attention_weights, deberta_attentions = result
 
124
  logits = result
125
  has_attention = False
126
  except TypeError:
127
+ # Fallback for older model without return_attention parameter
128
  logits = model(**inputs)
129
  has_attention = False
130
 
131
  probs = F.softmax(logits, dim=-1).squeeze(0).tolist()
132
 
133
+ # Get tokens for visualization
134
  tokens = tokenizer.convert_ids_to_tokens(inputs['input_ids'].squeeze(0).tolist())
135
 
136
  return probs, tokens, has_attention, attention_weights if has_attention else None
 
143
  ]
144
  return combined_probs
145
 
146
+ # --------- Inference function ----------
147
  def predict_fn(text: str):
148
  if not text or not text.strip():
149
  return {"error": "Please enter a URL or text."}, ""
150
 
151
  # Check if input is URL
152
  if is_url(text.strip()):
153
+ # Process URL
154
  url = text.strip()
155
+
156
+ # Get prediction for URL itself
157
  url_probs, url_tokens, url_has_attention, url_attention = predict_single_text(url, "URL")
158
+
159
+ # Try to fetch HTML content
160
  html_content, status = fetch_html_content(url)
161
 
162
  if html_content:
163
+ # Get prediction for HTML content
164
  html_probs, html_tokens, html_has_attention, html_attention = predict_single_text(html_content, "HTML")
165
+
166
+ # Combine predictions
167
  combined_probs = combine_predictions(url_probs, html_probs)
168
+
169
+ # Use combined probabilities but show analysis for both
170
  probs = combined_probs
171
+ tokens = url_tokens + ["[SEP]"] + html_tokens[:50] # Limit HTML tokens for display
172
  has_attention = url_has_attention or html_has_attention
173
  attention_weights = url_attention if url_has_attention else html_attention
174
+
175
  analysis_type = "Combined URL + HTML Analysis"
176
  fetch_status = f"✅ Successfully fetched HTML content (Status: {status})"
177
+
178
  else:
179
+ # Fallback to URL-only analysis
180
  probs = url_probs
181
  tokens = url_tokens
182
  has_attention = url_has_attention
183
  attention_weights = url_attention
184
+
185
  analysis_type = "URL-only Analysis"
186
  fetch_status = f"⚠️ Could not fetch HTML content: {status}"
187
  else:
188
+ # Process as regular text
189
  probs, tokens, has_attention, attention_weights = predict_single_text(text, "text")
190
  analysis_type = "Text Analysis"
191
  fetch_status = ""
192
 
193
+ # Create detailed analysis
194
  predicted_class = "phishing" if probs[1] > probs[0] else "benign"
195
  confidence = max(probs)
196
 
 
219
 
220
  if has_attention and attention_weights is not None:
221
  attention_scores = attention_weights.squeeze(0).tolist()
222
+
223
  token_analysis = []
224
  for i, (token, score) in enumerate(zip(tokens, attention_scores)):
225
+ # More lenient filtering - include more tokens for text analysis
226
  if token not in ['[CLS]', '[SEP]', '[PAD]', '<s>', '</s>'] and len(token.strip()) > 0 and score > 0.005:
227
+ clean_token = token.replace(' ', '').replace('Ġ', '').strip() # Handle different tokenizer prefixes
228
+ if clean_token: # Only add if token has content after cleaning
229
+ token_analysis.append({
230
+ 'token': clean_token,
231
+ 'importance': score,
232
+ 'position': i
233
+ })
234
 
235
+ # Sort by importance
236
  token_analysis.sort(key=lambda x: x['importance'], reverse=True)
237
 
238
  detailed_analysis += f"""
 
242
  </div>
243
  <div style="font-family: Arial, sans-serif;">
244
  """
245
+
246
+ for i, token_info in enumerate(token_analysis[:10]): # Top 10 tokens
247
  bar_width = int(token_info['importance'] * 100)
248
  color = "#ff4444" if predicted_class == "phishing" else "#44ff44"
249
+
250
  detailed_analysis += f"""
251
  <div style="margin: 8px 0; display: flex; align-items: center; background: #2d2d2d; padding: 8px; border-radius: 8px; border-left: 4px solid {color};">
252
+ <div style="width: 30px; text-align: right; margin-right: 10px; font-weight: bold; color: #ffffff;">
253
+ {i+1}.
254
+ </div>
255
+ <div style="width: 120px; margin-right: 10px; font-weight: bold; color: #e0e0e0; text-align: right;">
256
+ {token_info['token']}
257
+ </div>
258
  <div style="width: 300px; background-color: #404040; border-radius: 10px; overflow: hidden; margin-right: 10px; border: 1px solid #555;">
259
  <div style="width: {bar_width}%; background-color: {color}; height: 20px; border-radius: 10px; transition: width 0.3s ease;"></div>
260
  </div>
261
+ <div style="color: #cccccc; font-size: 12px; font-weight: bold;">
262
+ {token_info['importance']:.1%}
263
+ </div>
264
  </div>
265
  """
266
+
267
  detailed_analysis += "</div>\n"
268
+
269
+ detailed_analysis += f"""
270
+ ## Detailed analysis:
271
+ <div style="font-family: Arial, sans-serif; background: linear-gradient(135deg, #1a237e 0%, #3949ab 100%); padding: 20px; border-radius: 15px; color: white; margin: 15px 0; border: 2px solid #3f51b5;">
272
+ <h3 style="margin: 0 0 15px 0; color: white;">Statistical Overview</h3>
273
+ <div style="display: grid; grid-template-columns: repeat(2, 1fr); gap: 15px;">
274
+ <div style="background: rgba(255,255,255,0.1); padding: 15px; border-radius: 10px; border: 1px solid rgba(255,255,255,0.2);">
275
+ <div style="font-size: 24px; font-weight: bold; color: white;">{len([t for t in tokens if t not in ['[CLS]', '[SEP]', '[PAD]']])}</div>
276
+ <div style="font-size: 14px; color: #e0e0e0;">Total tokens</div>
277
+ </div>
278
+ <div style="background: rgba(255,255,255,0.1); padding: 15px; border-radius: 10px; border: 1px solid rgba(255,255,255,0.2);">
279
+ <div style="font-size: 24px; font-weight: bold; color: white;">{len([t for t in token_analysis if t['importance'] > 0.05])}</div>
280
+ <div style="font-size: 14px; color: #e0e0e0;">High impact tokens (>5%)</div>
281
+ </div>
282
+ </div>
283
+ </div>
284
+ <div style="font-family: Arial, sans-serif; margin: 15px 0; background: #2d2d2d; padding: 20px; border-radius: 15px; border: 1px solid #555;">
285
+ <h3 style="color: #ffffff; margin-bottom: 15px;"> Prediction Confidence</h3>
286
+ <div style="display: flex; justify-content: space-between; margin-bottom: 10px;">
287
+ <span style="font-weight: bold; color: #ff4444;">Phishing</span>
288
+ <span style="font-weight: bold; color: #44ff44;">Benign</span>
289
+ </div>
290
+ <div style="width: 100%; background-color: #404040; border-radius: 25px; overflow: hidden; height: 30px; border: 1px solid #666;">
291
+ <div style="width: {probs[1]*100:.1f}%; background: linear-gradient(90deg, #ff4444 0%, #ff6666 100%); height: 100%; display: flex; align-items: center; justify-content: center; color: white; font-weight: bold; font-size: 14px;">
292
+ {probs[1]:.1%}
293
+ </div>
294
+ </div>
295
+ <div style="margin-top: 10px; text-align: center; color: #cccccc; font-size: 14px;">
296
+ Benign: {probs[0]:.1%}
297
+ </div>
298
+ </div>
299
+ """
300
+ else:
301
+ # Fallback analysis without attention weights
302
+ detailed_analysis += f"""
303
+ <div style="background: linear-gradient(135deg, #1a237e 0%, #3949ab 100%); padding: 20px; border-radius: 15px; color: white; margin: 15px 0; border: 2px solid #3f51b5;">
304
+ <h3 style="margin: 0 0 15px 0; color: white;">Basic Analysis</h3>
305
+ <div style="display: grid; grid-template-columns: repeat(3, 1fr); gap: 15px;">
306
+ <div style="background: rgba(255,255,255,0.1); padding: 15px; border-radius: 10px; text-align: center; border: 1px solid rgba(255,255,255,0.2);">
307
+ <div style="font-size: 24px; font-weight: bold; color: white;">{probs[1]:.1%}</div>
308
+ <div style="font-size: 14px; color: #e0e0e0;">Phishing</div>
309
+ </div>
310
+ <div style="background: rgba(255,255,255,0.1); padding: 15px; border-radius: 10px; text-align: center; border: 1px solid rgba(255,255,255,0.2);">
311
+ <div style="font-size: 24px; font-weight: bold; color: white;">{probs[0]:.1%}</div>
312
+ <div style="font-size: 14px; color: #e0e0e0;">Benign</div>
313
+ </div>
314
+ <div style="background: rgba(255,255,255,0.1); padding: 15px; border-radius: 10px; text-align: center; border: 1px solid rgba(255,255,255,0.2);">
315
+ <div style="font-size: 24px; font-weight: bold; color: white;">{len([t for t in tokens if t not in ['[CLS]', '[SEP]', '[PAD]']])}</div>
316
+ <div style="font-size: 14px; color: #e0e0e0;">Tokens</div>
317
+ </div>
318
+ </div>
319
+ </div>
320
+ <div style="font-family: Arial, sans-serif; margin: 15px 0; background: #2d2d2d; padding: 20px; border-radius: 15px; border: 1px solid #555;">
321
+ <h3 style="color: #ffffff; margin: 0 0 15px 0;">🔤 Tokens in text:</h3>
322
+ <div style="display: flex; flex-wrap: wrap; gap: 8px;">""" + ''.join([f'<span style="background: #404040; color: #64b5f6; padding: 4px 8px; border-radius: 15px; font-size: 12px; border: 1px solid #666;">{token.replace(" ", "")}</span>' for token in tokens if token not in ['[CLS]', '[SEP]', '[PAD]']]) + f"""</div>
323
+ <div style="margin-top: 15px; padding: 10px; background: #3d2914; border-radius: 8px; border-left: 4px solid #ff9800;">
324
+ <strong style="color: #ffcc02;">Debug info:</strong> <span style="color: #e0e0e0;">Found {len(tokens)} total tokens, {len([t for t in tokens if t not in ['[CLS]', '[SEP]', '[PAD]']])} content tokens</span>
325
+ </div>
326
+ </div>
327
+ <div style="background: #3d2914; padding: 15px; border-radius: 10px; border-left: 4px solid #ff9800; margin: 15px 0;">
328
+ <p style="margin: 0; color: #ffcc02; font-size: 14px;">
329
+ <strong>Note:</strong> Detailed attention weights analysis is not available for the current model.
330
+ </p>
331
+ </div>
332
+ """
333
 
334
  # Build label->prob mapping for Gradio Label output
335
  if len(LABELS) == len(probs):
 
356
 
357
  if fetched_content:
358
  # Limit content length to avoid token overflow
359
+ truncated_content = fetched_content[:4000]
360
  analysis_context = f"URL: {input_text}\n\nWebsite Content:\n{truncated_content}\n..."
361
  print(f"✅ Successfully fetched {len(fetched_content)} chars from URL.")
362
  else:
 
368
 
369
  return response
370
 
371
+ # --------- Refresh Knowledge Base function ----------
372
  def refresh_kb():
373
  return rag_engine.refresh_knowledge_base()
374
 
375
  # --------- Gradio UI ----------
376
+ css_style="""
377
  .gradio-container {
378
  font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
379
  background-color: #1e1e1e !important;
 
399
  }
400
  .gradio-button:hover {
401
  background-color: #5a5a5a !important;
402
+ color: #ffffff !important;
403
+ border: 1px solid #666 !important;
404
  }
405
  """
406
+ with gr.Blocks() as demo:
407
+ gr.Markdown("# 🛡️ Phishing Detector (DeBERTa + LSTM + RAG)")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
408
 
409
+ with gr.Tabs():
410
+ # --- Tab 1: Standard Detection ---
411
+ with gr.TabItem("🔍 Standard Detection"):
412
+ gr.Markdown("""
413
+ Enter a URL or text for analysis using the DeBERTa + LSTM model.
414
+
415
+ **Features:**
416
+ - **URL Analysis**: For URLs, the system will fetch HTML content and combine both URL and content analysis
417
+ - **Combined Prediction**: Uses weighted combination of URL structure and webpage content analysis
418
+ - **Visual Analysis**: Predict phishing/benign probability with visual charts
419
+ - **Token Importance**: Display the most important tokens in classification
420
+ - **Detailed Insights**: Comprehensive analysis of the impact of each token
421
+
422
+ **How it works for URLs:**
423
+ 1. Analyze the URL structure itself
424
+ 2. Fetch the webpage HTML content
425
+ 3. Analyze the webpage content
426
+ 4. Combine both results for final prediction (30% URL + 70% content)
427
+ """)
428
+
429
  with gr.Row():
430
+ with gr.Column(scale=2):
431
+ input_box = gr.Textbox(
432
+ label="URL or text",
433
+ placeholder="Example: http://suspicious-site.example or paste any text",
434
+ lines=3
435
+ )
436
+ btn_submit = gr.Button("🔍 Analyze", variant="primary")
437
+
438
+ gr.Examples(
439
+ examples=[
440
+ ["http://rendmoiunserviceeee.com"],
441
+ ["https://www.google.com"],
442
+ ["Dear customer, your account has been suspended. Click here to verify your identity immediately."],
443
+ ["https://mail-secure-login-verify.example/path?token=suspicious"],
444
+ ["http://paypaI-security-update.net/login"],
445
+ ["Your package has been delivered successfully. Thank you for using our service."],
446
+ ["https://github.com/user/repo"],
447
+ ["Dear customer, your account has been suspended. Click here to verify."],
448
+ ],
449
+ inputs=input_box
450
+ )
451
+
452
+ with gr.Column(scale=3):
453
+ output_html = gr.HTML(label="Analysis Result")
454
+
455
+ btn_submit.click(fn=predict_fn, inputs=input_box, outputs=output_html)
456
+
457
+ # --- Tab 2: LLM + RAG Analysis ---
458
+ with gr.TabItem("🤖 AI Assistant (RAG)"):
459
+ gr.Markdown("""
460
+ **AI Assistant** uses **Qwen2.5-3B** + **LangChain** to explain *why* a message is suspicious.
461
 
462
+ **Features:**
463
+ - 🌐 Multilingual support (English + Vietnamese)
464
+ - 📚 Knowledge Base retrieval (Auto-sync)
465
+ - 🚀 No rate limits (self-hosted)
466
+ """)
467
+
468
+ with gr.Row():
469
+ with gr.Column(scale=1):
470
+ rag_input = gr.Textbox(
471
+ label="Suspicious Text/URL",
472
+ placeholder="Paste the email content or URL here...",
473
+ lines=5
474
+ )
475
+ with gr.Row():
476
+ btn_rag = gr.Button("🤖 Ask AI Assistant", variant="primary")
477
+ btn_refresh = gr.Button("♻️ Refresh Knowledge Base")
478
+
479
+ gr.Examples(
480
+ examples=[
481
+ ["Your PayPal account has been suspended. Click http://paypal-verify.com to unlock."],
482
+ ["Tài khoản ngân hàng của bạn bị khóa. Nhấn vào đây để mở khóa ngay."],
483
+ ["Your package is ready for delivery. Track here: https://fedex-track.com"],
484
+ ],
485
+ inputs=rag_input
486
+ )
487
+
488
+ with gr.Column(scale=1):
489
+ rag_output = gr.Markdown(label="AI Analysis")
490
+ refresh_output = gr.Markdown(label="Status")
491
+
492
+ btn_rag.click(fn=rag_predict_fn, inputs=[rag_input], outputs=rag_output)
493
+ btn_refresh.click(fn=refresh_kb, inputs=[], outputs=refresh_output)
494
 
495
  if __name__ == "__main__":
496
  demo.launch(ssr_mode=False)