dungeon29 commited on
Commit
d75c95a
·
verified ·
1 Parent(s): adbfcbd

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +399 -0
app.py ADDED
@@ -0,0 +1,399 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn.functional as F
3
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
4
+ from huggingface_hub import hf_hub_download
5
+ import gradio as gr
6
+ import requests
7
+ import re
8
+ from urllib.parse import urlparse
9
+ from bs4 import BeautifulSoup
10
+ import time
11
+ import joblib
12
+
13
+
14
+
15
+ # --- import your architecture ---
16
+ # Make sure this file is in the repo (e.g., models/deberta_lstm_classifier.py)
17
+ # and update the import path accordingly.
18
+ from model import DeBERTaLSTMClassifier # <-- your class
19
+
20
+ # --------- Config ----------
21
+ REPO_ID = "dungeon29/DetectPhishing" # HF repo that holds the checkpoint
22
+ CKPT_NAME = "deberta_lstm_checkpoint.pt" # the .pt file name
23
+ MODEL_NAME = "microsoft/deberta-base" # base tokenizer/backbone
24
+ LABELS = ["benign", "phishing"] # adjust to your classes
25
+
26
+ # If your checkpoint contains hyperparams, you can fetch them like:
27
+ # checkpoint.get("config") or checkpoint.get("model_args")
28
+ # and pass into DeBERTaLSTMClassifier(**model_args)
29
+
30
+ # --------- Load model/tokenizer once (global) ----------
31
+ device = "cuda" if torch.cuda.is_available() else "cpu"
32
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
33
+
34
+ ckpt_path = hf_hub_download(repo_id=REPO_ID, filename=CKPT_NAME)
35
+ checkpoint = torch.load(ckpt_path, map_location=device)
36
+
37
+ # If you saved hyperparams in the checkpoint, use them:
38
+ model_args = checkpoint.get("model_args", {}) # e.g., {"lstm_hidden":256, "num_labels":2, ...}
39
+ model = DeBERTaLSTMClassifier(**model_args)
40
+
41
+ # Load state dict and handle missing attention layer for older models
42
+ try:
43
+ model.load_state_dict(checkpoint["model_state_dict"])
44
+ except RuntimeError as e:
45
+ if "attention" in str(e):
46
+ # Old model without attention layer - initialize attention layer and load partial state
47
+ state_dict = checkpoint["model_state_dict"]
48
+ model_dict = model.state_dict()
49
+ # Filter out attention layer parameters
50
+ filtered_dict = {k: v for k, v in state_dict.items() if "attention" not in k}
51
+ model_dict.update(filtered_dict)
52
+ model.load_state_dict(model_dict)
53
+ print("Loaded model without attention layer, using newly initialized attention weights")
54
+ else:
55
+ raise e
56
+
57
+ model.to(device).eval()
58
+
59
+ # --------- Helper functions ----------
60
+ def is_url(text):
61
+ """Check if text is a URL"""
62
+ url_pattern = re.compile(
63
+ r'^https?://' # http:// or https://
64
+ r'(?:(?:[A-Z0-9](?:[A-Z0-9-]{0,61}[A-Z0-9])?\.)+[A-Z]{2,6}\.?|' # domain...
65
+ r'localhost|' # localhost...
66
+ r'\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3})' # ...or ip
67
+ r'(?::\d+)?' # optional port
68
+ r'(?:/?|[/?]\S+)$', re.IGNORECASE)
69
+ return url_pattern.match(text) is not None
70
+
71
+ def fetch_html_content(url, timeout=10):
72
+ """Fetch HTML content from URL"""
73
+ try:
74
+ headers = {
75
+ 'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36'
76
+ }
77
+ response = requests.get(url, headers=headers, timeout=timeout, verify=False)
78
+ response.raise_for_status()
79
+ return response.text, response.status_code
80
+ except requests.exceptions.RequestException as e:
81
+ return None, f"Request error: {str(e)}"
82
+ except Exception as e:
83
+ return None, f"General error: {str(e)}"
84
+
85
+ def predict_single_text(text, text_type="text"):
86
+ """Predict for a single text input"""
87
+ # Tokenize
88
+ inputs = tokenizer(
89
+ text,
90
+ return_tensors="pt",
91
+ truncation=True,
92
+ padding=True,
93
+ max_length=256
94
+ )
95
+ # DeBERTa typically doesn't use token_type_ids
96
+ inputs.pop("token_type_ids", None)
97
+ # Move to device
98
+ inputs = {k: v.to(device) for k, v in inputs.items()}
99
+
100
+ with torch.no_grad():
101
+ try:
102
+ # Try to get predictions with attention weights
103
+ result = model(**inputs, return_attention=True)
104
+ if isinstance(result, tuple) and len(result) == 3:
105
+ logits, attention_weights, deberta_attentions = result
106
+ has_attention = True
107
+ else:
108
+ logits = result
109
+ has_attention = False
110
+ except TypeError:
111
+ # Fallback for older model without return_attention parameter
112
+ logits = model(**inputs)
113
+ has_attention = False
114
+
115
+ probs = F.softmax(logits, dim=-1).squeeze(0).tolist()
116
+
117
+ # Get tokens for visualization
118
+ tokens = tokenizer.convert_ids_to_tokens(inputs['input_ids'].squeeze(0).tolist())
119
+
120
+ return probs, tokens, has_attention, attention_weights if has_attention else None
121
+
122
+ def combine_predictions(url_probs, html_probs, url_weight=0.3, html_weight=0.7):
123
+ """Combine URL and HTML content predictions"""
124
+ combined_probs = [
125
+ url_weight * url_probs[0] + html_weight * html_probs[0], # benign
126
+ url_weight * url_probs[1] + html_weight * html_probs[1] # phishing
127
+ ]
128
+ return combined_probs
129
+
130
+ # --------- Inference function ----------
131
+ def predict_fn(text: str):
132
+ if not text or not text.strip():
133
+ return {"error": "Please enter a URL or text."}, ""
134
+
135
+ # Check if input is URL
136
+ if is_url(text.strip()):
137
+ # Process URL
138
+ url = text.strip()
139
+
140
+ # Get prediction for URL itself
141
+ url_probs, url_tokens, url_has_attention, url_attention = predict_single_text(url, "URL")
142
+
143
+ # Try to fetch HTML content
144
+ html_content, status = fetch_html_content(url)
145
+
146
+ if html_content:
147
+ # Get prediction for HTML content
148
+ html_probs, html_tokens, html_has_attention, html_attention = predict_single_text(html_content, "HTML")
149
+
150
+ # Combine predictions
151
+ combined_probs = combine_predictions(url_probs, html_probs)
152
+
153
+ # Use combined probabilities but show analysis for both
154
+ probs = combined_probs
155
+ tokens = url_tokens + ["[SEP]"] + html_tokens[:50] # Limit HTML tokens for display
156
+ has_attention = url_has_attention or html_has_attention
157
+ attention_weights = url_attention if url_has_attention else html_attention
158
+
159
+ analysis_type = "Combined URL + HTML Analysis"
160
+ fetch_status = f"✅ Successfully fetched HTML content (Status: {status})"
161
+
162
+ else:
163
+ # Fallback to URL-only analysis
164
+ probs = url_probs
165
+ tokens = url_tokens
166
+ has_attention = url_has_attention
167
+ attention_weights = url_attention
168
+
169
+ analysis_type = "URL-only Analysis"
170
+ fetch_status = f"⚠️ Could not fetch HTML content: {status}"
171
+ else:
172
+ # Process as regular text
173
+ probs, tokens, has_attention, attention_weights = predict_single_text(text, "text")
174
+ analysis_type = "Text Analysis"
175
+ fetch_status = ""
176
+
177
+ # Get tokens for visualization
178
+
179
+ # Create detailed analysis
180
+ predicted_class = "phishing" if probs[1] > probs[0] else "benign"
181
+ confidence = max(probs)
182
+
183
+ detailed_analysis = f"""
184
+ <div style="font-family: Arial, sans-serif; max-width: 800px; margin: 0 auto; background: #1e1e1e; padding: 20px; border-radius: 15px;">
185
+ <div style="background: linear-gradient(135deg, {'#8b0000' if predicted_class == 'phishing' else '#006400'} 0%, {'#dc143c' if predicted_class == 'phishing' else '#228b22'} 100%); padding: 25px; border-radius: 20px; color: white; text-align: center; margin-bottom: 20px; box-shadow: 0 8px 32px rgba(0,0,0,0.5); border: 2px solid {'#ff4444' if predicted_class == 'phishing' else '#44ff44'};">
186
+ <h2 style="margin: 0 0 10px 0; font-size: 28px; color: white;">🔍 {analysis_type}</h2>
187
+ <div style="font-size: 36px; font-weight: bold; margin: 10px 0; color: white;">
188
+ {predicted_class.upper()}
189
+ </div>
190
+ <div style="font-size: 18px; color: #f0f0f0;">
191
+ Confidence: {confidence:.1%}
192
+ </div>
193
+ <div style="margin-top: 15px; font-size: 14px; color: #e0e0e0;">
194
+ {'This appears to be a phishing attempt!' if predicted_class == 'phishing' else '✅ This appears to be legitimate content.'}
195
+ </div>
196
+ </div>
197
+ """
198
+
199
+ if fetch_status:
200
+ detailed_analysis += f"""
201
+ <div style="background: #2d2d2d; padding: 15px; border-radius: 10px; margin: 15px 0; border-left: 4px solid #4caf50; color: #e0e0e0;">
202
+ <strong>Fetch Status:</strong> {fetch_status}
203
+ </div>
204
+ """
205
+
206
+ if has_attention and attention_weights is not None:
207
+ attention_scores = attention_weights.squeeze(0).tolist()
208
+
209
+ token_analysis = []
210
+ for i, (token, score) in enumerate(zip(tokens, attention_scores)):
211
+ # More lenient filtering - include more tokens for text analysis
212
+ if token not in ['[CLS]', '[SEP]', '[PAD]', '<s>', '</s>'] and len(token.strip()) > 0 and score > 0.005:
213
+ clean_token = token.replace('▁', '').replace('Ġ', '').strip() # Handle different tokenizer prefixes
214
+ if clean_token: # Only add if token has content after cleaning
215
+ token_analysis.append({
216
+ 'token': clean_token,
217
+ 'importance': score,
218
+ 'position': i
219
+ })
220
+
221
+ # Sort by importance
222
+ token_analysis.sort(key=lambda x: x['importance'], reverse=True)
223
+
224
+ detailed_analysis += f"""
225
+ ## Top important tokens:
226
+ <div style="background: #2d2d2d; padding: 15px; border-radius: 10px; margin: 15px 0; border-left: 4px solid #4caf50; color: #e0e0e0;">
227
+ <strong>Analysis Info:</strong> Found {len(token_analysis)} important tokens out of {len(tokens)} total tokens
228
+ </div>
229
+ <div style="font-family: Arial, sans-serif;">
230
+ """
231
+
232
+ for i, token_info in enumerate(token_analysis[:10]): # Top 10 tokens
233
+ bar_width = int(token_info['importance'] * 100)
234
+ color = "#ff4444" if predicted_class == "phishing" else "#44ff44"
235
+
236
+ detailed_analysis += f"""
237
+ <div style="margin: 8px 0; display: flex; align-items: center; background: #2d2d2d; padding: 8px; border-radius: 8px; border-left: 4px solid {color};">
238
+ <div style="width: 30px; text-align: right; margin-right: 10px; font-weight: bold; color: #ffffff;">
239
+ {i+1}.
240
+ </div>
241
+ <div style="width: 120px; margin-right: 10px; font-weight: bold; color: #e0e0e0; text-align: right;">
242
+ {token_info['token']}
243
+ </div>
244
+ <div style="width: 300px; background-color: #404040; border-radius: 10px; overflow: hidden; margin-right: 10px; border: 1px solid #555;">
245
+ <div style="width: {bar_width}%; background-color: {color}; height: 20px; border-radius: 10px; transition: width 0.3s ease;"></div>
246
+ </div>
247
+ <div style="color: #cccccc; font-size: 12px; font-weight: bold;">
248
+ {token_info['importance']:.1%}
249
+ </div>
250
+ </div>
251
+ """
252
+
253
+ detailed_analysis += "</div>\n"
254
+
255
+ detailed_analysis += f"""
256
+ ## Detailed analysis:
257
+ <div style="font-family: Arial, sans-serif; background: linear-gradient(135deg, #1a237e 0%, #3949ab 100%); padding: 20px; border-radius: 15px; color: white; margin: 15px 0; border: 2px solid #3f51b5;">
258
+ <h3 style="margin: 0 0 15px 0; color: white;">Statistical Overview</h3>
259
+ <div style="display: grid; grid-template-columns: repeat(2, 1fr); gap: 15px;">
260
+ <div style="background: rgba(255,255,255,0.1); padding: 15px; border-radius: 10px; border: 1px solid rgba(255,255,255,0.2);">
261
+ <div style="font-size: 24px; font-weight: bold; color: white;">{len([t for t in tokens if t not in ['[CLS]', '[SEP]', '[PAD]']])}</div>
262
+ <div style="font-size: 14px; color: #e0e0e0;">Total tokens</div>
263
+ </div>
264
+ <div style="background: rgba(255,255,255,0.1); padding: 15px; border-radius: 10px; border: 1px solid rgba(255,255,255,0.2);">
265
+ <div style="font-size: 24px; font-weight: bold, color: white;">{len([t for t in token_analysis if t['importance'] > 0.05])}</div>
266
+ <div style="font-size: 14px, color: #e0e0e0;">High impact tokens (>5%)</div>
267
+ </div>
268
+ </div>
269
+ </div>
270
+ <div style="font-family: Arial, sans-serif; margin: 15px 0; background: #2d2d2d; padding: 20px; border-radius: 15px; border: 1px solid #555;">
271
+ <h3 style="color: #ffffff; margin-bottom: 15px;"> Prediction Confidence</h3>
272
+ <div style="display: flex; justify-content: space-between; margin-bottom: 10px;">
273
+ <span style="font-weight: bold; color: #ff4444;">Phishing</span>
274
+ <span style="font-weight: bold; color: #44ff44;">Benign</span>
275
+ </div>
276
+ <div style="width: 100%; background-color: #404040; border-radius: 25px; overflow: hidden; height: 30px; border: 1px solid #666;">
277
+ <div style="width: {probs[1]*100:.1f}%; background: linear-gradient(90deg, #ff4444 0%, #ff6666 100%); height: 100%; display: flex; align-items: center; justify-content: center; color: white; font-weight: bold; font-size: 14px;">
278
+ {probs[1]:.1%}
279
+ </div>
280
+ </div>
281
+ <div style="margin-top: 10px; text-align: center; color: #cccccc; font-size: 14px;">
282
+ Benign: {probs[0]:.1%}
283
+ </div>
284
+ </div>
285
+ """
286
+ else:
287
+ # Fallback analysis without attention weights
288
+ detailed_analysis += f"""
289
+ <div style="background: linear-gradient(135deg, #1a237e 0%, #3949ab 100%); padding: 20px; border-radius: 15px; color: white; margin: 15px 0; border: 2px solid #3f51b5;">
290
+ <h3 style="margin: 0 0 15px 0; color: white;">Basic Analysis</h3>
291
+ <div style="display: grid; grid-template-columns: repeat(3, 1fr); gap: 15px;">
292
+ <div style="background: rgba(255,255,255,0.1); padding: 15px; border-radius: 10px; text-align: center; border: 1px solid rgba(255,255,255,0.2);">
293
+ <div style="font-size: 24px; font-weight: bold; color: white;">{probs[1]:.1%}</div>
294
+ <div style="font-size: 14px; color: #e0e0e0;">Phishing</div>
295
+ </div>
296
+ <div style="background: rgba(255,255,255,0.1); padding: 15px; border-radius: 10px; text-align: center; border: 1px solid rgba(255,255,255,0.2);">
297
+ <div style="font-size: 24px; font-weight: bold; color: white;">{probs[0]:.1%}</div>
298
+ <div style="font-size: 14px; color: #e0e0e0;">Benign</div>
299
+ </div>
300
+ <div style="background: rgba(255,255,255,0.1); padding: 15px; border-radius: 10px; text-align: center; border: 1px solid rgba(255,255,255,0.2);">
301
+ <div style="font-size: 24px; font-weight: bold; color: white;">{len([t for t in tokens if t not in ['[CLS]', '[SEP]', '[PAD]']])}</div>
302
+ <div style="font-size: 14px; color: #e0e0e0;">Tokens</div>
303
+ </div>
304
+ </div>
305
+ </div>
306
+ <div style="background: #2d2d2d; padding: 20px; border-radius: 15px; margin: 15px 0; border: 1px solid #555;">
307
+ <h3 style="color: #ffffff; margin: 0 0 15px 0;">🔤 Tokens in text:</h3>
308
+ <div style="display: flex; flex-wrap: wrap; gap: 8px;">""" + ''.join([f'<span style="background: #404040; color: #64b5f6; padding: 4px 8px; border-radius: 15px; font-size: 12px; border: 1px solid #666;">{token.replace("▁", "")}</span>' for token in tokens if token not in ['[CLS]', '[SEP]', '[PAD]']]) + f"""</div>
309
+ <div style="margin-top: 15px; padding: 10px; background: #3d2914; border-radius: 8px; border-left: 4px solid #ff9800;">
310
+ <strong style="color: #ffcc02;">Debug info:</strong> <span style="color: #e0e0e0;">Found {len(tokens)} total tokens, {len([t for t in tokens if t not in ['[CLS]', '[SEP]', '[PAD]']])} content tokens</span>
311
+ </div>
312
+ </div>
313
+ <div style="background: #3d2914; padding: 15px; border-radius: 10px; border-left: 4px solid #ff9800; margin: 15px 0;">
314
+ <p style="margin: 0; color: #ffcc02; font-size: 14px;">
315
+ <strong>Note:</strong> Detailed attention weights analysis is not available for the current model.
316
+ </p>
317
+ </div>
318
+ """
319
+
320
+ # Build label->prob mapping for Gradio Label output
321
+ if len(LABELS) == len(probs):
322
+ prediction_result = {LABELS[i]: float(probs[i]) for i in range(len(LABELS))}
323
+ else:
324
+ prediction_result = {f"class_{i}": float(p) for i, p in enumerate(probs)}
325
+
326
+ return prediction_result, detailed_analysis
327
+
328
+ # --------- Gradio UI ----------
329
+ deberta_interface = gr.Interface(
330
+ fn=predict_fn,
331
+ inputs=gr.Textbox(label="URL or text", placeholder="Example: http://suspicious-site.example or paste any text"),
332
+ outputs=[
333
+ gr.Label(label="Prediction result"),
334
+ gr.Markdown(label="Detailed token analysis")
335
+ ],
336
+ title="Phishing Detector (DeBERTa + LSTM)",
337
+ description="""
338
+ Enter a URL or text for analysis.
339
+ **Features:**
340
+ - **URL Analysis**: For URLs, the system will fetch HTML content and combine both URL and content analysis
341
+ - **Combined Prediction**: Uses weighted combination of URL structure and webpage content analysis
342
+ - **Visual Analysis**: Predict phishing/benign probability with visual charts
343
+ - **Token Importance**: Display the most important tokens in classification
344
+ - **Detailed Insights**: Comprehensive analysis of the impact of each token
345
+ - **Dark Theme**: Beautiful interface with colorful charts optimized for dark themes
346
+
347
+ **How it works for URLs:**
348
+ 1. Analyze the URL structure itself
349
+ 2. Fetch the webpage HTML content
350
+ 3. Analyze the webpage content
351
+ 4. Combine both results for final prediction (30% URL + 70% content)
352
+ """,
353
+ examples=[
354
+ ["http://rendmoiunserviceeee.com"],
355
+ ["https://www.google.com"],
356
+ ["Dear customer, your account has been suspended. Click here to verify your identity immediately."],
357
+ ["https://mail-secure-login-verify.example/path?token=suspicious"],
358
+ ["http://paypaI-security-update.net/login"],
359
+ ["Your package has been delivered successfully. Thank you for using our service."],
360
+ ["https://github.com/user/repo"]
361
+ ],
362
+ theme=gr.themes.Soft(),
363
+ css="""
364
+ .gradio-container {
365
+ font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
366
+ background-color: #1e1e1e !important;
367
+ color: #ffffff !important;
368
+ }
369
+ .dark .gradio-container {
370
+ background-color: #1e1e1e !important;
371
+ }
372
+ /* Dark theme for all components */
373
+ .block {
374
+ background-color: #2d2d2d !important;
375
+ border: 1px solid #444 !important;
376
+ }
377
+ .gradio-textbox {
378
+ background-color: #3d3d3d !important;
379
+ color: #ffffff !important;
380
+ border: 1px solid #666 !important;
381
+ }
382
+ .gradio-button {
383
+ background-color: #4a4a4a !important;
384
+ color: #ffffff !important;
385
+ border: 1px solid #666 !important;
386
+ }
387
+ .gradio-button:hover {
388
+ background-color: #5a5a5a !important;
389
+ }
390
+ """
391
+ )
392
+
393
+ demo = gr.TabbedInterface(
394
+ [deberta_interface,],
395
+ ["DeBERTa + LSTM"]
396
+ )
397
+
398
+ if __name__ == "__main__":
399
+ demo.launch()