thrimurthi2025 committed · verified
Commit e2caa40 · 1 Parent(s): e4f3529

Update app.py

Files changed (1)
  1. app.py +304 -29
app.py CHANGED
@@ -1,35 +1,230 @@
+# unreal_explain_gradio.py
 import gradio as gr
-from transformers import pipeline
+from transformers import pipeline, AutoImageProcessor, AutoModelForImageClassification
 from PIL import Image
 import traceback
 import time
 import threading
+import torch
+import torch.nn.functional as F
+import numpy as np
+import io
+import base64
+import cv2
+
+# ---------- Configuration ----------
+# If any of your Hugging Face models are private, set HF_TOKEN = "<YOUR_TOKEN>"
+HF_TOKEN = None  # or "hf_xxx" if needed

-# Models
 models = [
     ("Ateeqq/ai-vs-human-image-detector", "ateeq"),
     ("umm-maybe/AI-image-detector", "umm_maybe"),
     ("dima806/ai_vs_human_generated_image_detection", "dimma"),
 ]

-pipes = []
-for model_id, _ in models:
+# ---------- Helper functions for explainability ----------
+def find_last_conv(module):
+    last = None
+    for name, m in module.named_modules():
+        if isinstance(m, torch.nn.Conv2d):
+            last = m
+    return last
+
+class GradCAM:
+    def __init__(self, model, target_layer):
+        self.model = model
+        self.target_layer = target_layer
+        self.activations = None
+        self.gradients = None
+        # register hooks
+        target_layer.register_forward_hook(self._save_activation)
+        # backward hook signature differs by torch version
+        try:
+            target_layer.register_backward_hook(self._save_gradient)
+        except Exception:
+            target_layer.register_full_backward_hook(self._save_gradient)
+
+    def _save_activation(self, module, input, output):
+        self.activations = output.detach()
+
+    def _save_gradient(self, module, grad_input, grad_output):
+        # grad_output can be tuple
+        self.gradients = grad_output[0].detach()
+
+    def __call__(self, input_tensor, class_idx=None):
+        self.activations = None
+        self.gradients = None
+        # forward
+        logits = self.model(input_tensor.unsqueeze(0))
+        # transformers models return objects, handle both
+        if hasattr(logits, "logits"):
+            logits_tensor = logits.logits
+        else:
+            logits_tensor = logits
+        if class_idx is None:
+            class_idx = int(torch.argmax(logits_tensor, dim=1).item())
+        # backward
+        self.model.zero_grad()
+        score = logits_tensor[0, class_idx]
+        score.backward(retain_graph=False)
+        # compute weights
+        pooled_grads = torch.mean(self.gradients[0], dim=(1,2))  # C
+        activ = self.activations[0].cpu()
+        for i in range(activ.shape[0]):
+            activ[i, :, :] *= pooled_grads[i].cpu()
+        heatmap = torch.sum(activ, dim=0).cpu().numpy()
+        heatmap = np.maximum(heatmap, 0)
+        heatmap = heatmap - np.min(heatmap)
+        denom = (np.max(heatmap) + 1e-8)
+        heatmap = heatmap / denom
+        return heatmap, int(class_idx), logits_tensor
+
+def overlay_heatmap_on_pil(orig_pil, heatmap, alpha=0.45):
+    orig = np.array(orig_pil.convert("RGB"))
+    heatmap_resized = cv2.resize(heatmap, (orig.shape[1], orig.shape[0]))
+    heatmap_u8 = np.uint8(255 * heatmap_resized)
+    colored = cv2.applyColorMap(heatmap_u8, cv2.COLORMAP_JET)
+    colored = cv2.cvtColor(colored, cv2.COLOR_BGR2RGB)
+    overlay = np.uint8(orig * (1 - alpha) + colored * alpha)
+    return Image.fromarray(overlay)
+
+# Attention rollout for ViT-style models
+def attention_rollout_from_attentions(attentions, discard_ratio=0.9):
+    """
+    attentions: tuple/list of tensors, each shape (batch, heads, seq, seq)
+    returns token-to-token rollout matrix shape (seq, seq)
+    """
+    # Convert to numpy arrays, avg heads
+    result = None
+    for attn in attentions:
+        # attn shape (batch, heads, seq, seq)
+        a = attn[0].mean(0).detach().cpu().numpy()  # (seq, seq)
+        # optionally remove low weights
+        a = np.maximum(a, 0)
+        a = a / (a.sum(-1, keepdims=True) + 1e-8)
+        if result is None:
+            result = a
+        else:
+            result = a @ result
+    return result
+
+def vit_attention_heatmap(processor, model, image: Image.Image):
+    # preprocess
+    inputs = processor(images=image, return_tensors="pt")
+    # call model with output_attentions=True
+    outputs = model(**inputs, output_attentions=True)
+    if not hasattr(outputs, "attentions") or outputs.attentions is None:
+        return None
+    rollout = attention_rollout_from_attentions(outputs.attentions)
+    # rollout shape (seq, seq). First token is CLS — we use CLS attention to patches.
+    cls_attention = rollout[0, 1:]  # skip CLS->CLS token
+    # map patch attention to image heatmap
+    # get image size and patch grid shape from processor/model config
+    try:
+        config = model.config
+        if hasattr(config, "image_size"):
+            image_size = config.image_size
+        else:
+            image_size = processor.size.get("shortest_edge", 224) if hasattr(processor, "size") else 224
+        patch_size = config.patch_size if hasattr(config, "patch_size") else 16
+    except Exception:
+        image_size = 224
+        patch_size = 16
+    grid_size = int(image_size // patch_size)
+    # if tokens don't match product, try sqrt
+    if cls_attention.shape[0] != grid_size * grid_size:
+        # fallback: reshape by nearest square
+        n = int(np.sqrt(cls_attention.shape[0]))
+        grid_size = n
+    heatmap = cls_attention.reshape(grid_size, grid_size)
+    heatmap = heatmap - heatmap.min()
+    heatmap = heatmap / (heatmap.max() + 1e-8)
+    return heatmap
+
+# ---------- Load pipelines and also underlying models/processors ----------
+pipes = []  # (model_id, pipeline)
+hf_models = {}  # model_id -> (processor, model, explain_type)
+
+for model_id, short in models:
     try:
-        pipes.append((model_id, pipeline("image-classification", model=model_id)))
-        print(f"Loaded {model_id}")
+        p = pipeline("image-classification", model=model_id, use_auth_token=HF_TOKEN)
+        pipes.append((model_id, p))
+        print(f"Loaded pipeline {model_id}")
     except Exception as e:
-        print(f"Error loading {model_id}: {e}")
+        print(f"Error loading pipeline for {model_id}: {e}")

-def predict_image(image: Image.Image):
+    # try to load processor + raw model for explainability
     try:
+        processor = AutoImageProcessor.from_pretrained(model_id, use_auth_token=HF_TOKEN)
+    except Exception:
+        # older HF spacing: AutoFeatureExtractor fallback
+        try:
+            from transformers import AutoFeatureExtractor
+            processor = AutoFeatureExtractor.from_pretrained(model_id, use_auth_token=HF_TOKEN)
+        except Exception:
+            processor = None
+
+    try:
+        raw_model = AutoModelForImageClassification.from_pretrained(model_id, use_auth_token=HF_TOKEN)
+        raw_model.eval()
+        # attempt to detect conv layers
+        # try to find a backbone / base model
+        base = None
+        for candidate in ("base_model", "backbone", "model", "vit", "resnet", "conv_stem"):
+            if hasattr(raw_model, candidate):
+                base = getattr(raw_model, candidate)
+                break
+        if base is None:
+            base = raw_model
+
+        last_conv = find_last_conv(base)
+        if last_conv is not None:
+            explain_type = "gradcam"
+            explain_helper = GradCAM(raw_model, last_conv)
+            print(f"{model_id} -> Grad-CAM available")
+        else:
+            # try transformer attention route
+            # check config for is_vit
+            cfg = raw_model.config
+            if getattr(cfg, "architectures", None) and any("ViT" in a or "VisionTransformer" in a for a in cfg.architectures):
+                explain_type = "vit"
+                explain_helper = None
+                print(f"{model_id} -> ViT | will use attention rollout")
+            else:
+                # fallback: no explainability
+                explain_type = "none"
+                explain_helper = None
+                print(f"{model_id} -> No explainability (no convs and not ViT)")
+    except Exception as e:
+        print(f"Couldn't load raw hf model for {model_id}: {e}")
+        raw_model = None
+        processor = None
+        explain_type = "none"
+        explain_helper = None
+
+    hf_models[model_id] = {
+        "processor": processor,
+        "model": raw_model,
+        "explain_type": explain_type,
+        "helper": explain_helper
+    }
+
+# ---------- original predict function updated to produce overlay ----------
+def predict_image_with_explain(image: Image.Image):
+    try:
+        # run all pipelines to get consensus / first result for UI
         results = []
-        for _, pipe in pipes:
-            res = pipe(image)[0]
-            results.append(res)
+        for model_id, pipe in pipes:
+            try:
+                res = pipe(image)[0]
+                results.append((model_id, res))
+            except Exception as e:
+                results.append((model_id, {"label": "error", "score": 0.0}))

-        final_result = results[0]
-        label = final_result["label"].lower()
-        score = final_result["score"] * 100
+        # pick first result for the main verdict (like before)
+        final_model_id, final_res = results[0]
+        label = final_res.get("label", "").lower()
+        score = final_res.get("score", 0.0) * 100

         if "ai" in label or "fake" in label:
             verdict = f"🧠 AI-Generated ({score:.1f}% confidence)"
@@ -38,28 +233,89 @@ def predict_image(image: Image.Image):
             verdict = f"🧍 Human-Made ({score:.1f}% confidence)"
             color = "#4CAF50"

+        # Try to compute explainability overlay from the corresponding HF model if available
+        explain_entry = hf_models.get(final_model_id)
+        overlay_data_uri = None
+        explain_reason = None
+
+        if explain_entry and explain_entry["explain_type"] == "gradcam" and explain_entry["helper"] is not None:
+            try:
+                # preprocess: use processor if present, else fallback to torchvision transforms
+                proc = explain_entry["processor"]
+                raw_model = explain_entry["model"]
+                if proc is not None:
+                    inputs = proc(images=image, return_tensors="pt")
+                    input_tensor = inputs["pixel_values"][0] if "pixel_values" in inputs else inputs["input_tensor"][0]
+                else:
+                    # fallback resize + normalize similar to common models
+                    from torchvision import transforms
+                    pre = transforms.Compose([
+                        transforms.Resize((224,224)),
+                        transforms.ToTensor(),
+                        transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225])
+                    ])
+                    input_tensor = pre(image)
+
+                grad_helper = explain_entry["helper"]
+                heatmap, class_idx, logits = grad_helper(input_tensor)
+                # overlay
+                overlay_img = overlay_heatmap_on_pil(image, heatmap, alpha=0.45)
+                buf = io.BytesIO()
+                overlay_img.save(buf, format="PNG")
+                overlay_b64 = base64.b64encode(buf.getvalue()).decode("utf-8")
+                overlay_data_uri = "data:image/png;base64," + overlay_b64
+                explain_reason = "Grad-CAM heatmap (activations)"
+            except Exception as e:
+                traceback.print_exc()
+                explain_reason = f"Grad-CAM failed: {e}"
+
+        elif explain_entry and explain_entry["explain_type"] == "vit" and explain_entry["model"] is not None:
+            try:
+                proc = explain_entry["processor"]
+                raw_model = explain_entry["model"]
+                heatmap = vit_attention_heatmap(proc, raw_model, image)
+                if heatmap is not None:
+                    overlay_img = overlay_heatmap_on_pil(image, heatmap, alpha=0.45)
+                    buf = io.BytesIO()
+                    overlay_img.save(buf, format="PNG")
+                    overlay_b64 = base64.b64encode(buf.getvalue()).decode("utf-8")
+                    overlay_data_uri = "data:image/png;base64," + overlay_b64
+                    explain_reason = "ViT attention rollout heatmap"
+            except Exception as e:
+                traceback.print_exc()
+                explain_reason = f"ViT rollout failed: {e}"
+
+        # Build HTML for verdict box
         html = f"""
         <div class='result-box' style="
             background: linear-gradient(135deg, {color}33, #1a1a1a);
             border: 2px solid {color};
             border-radius: 15px;
-            padding: 25px;
+            padding: 20px;
             text-align: center;
             color: white;
-            font-size: 20px;
+            font-size: 18px;
             font-weight: 600;
             box-shadow: 0 0 20px {color}55;
             animation: fadeIn 0.6s ease-in-out;
         ">
            {verdict}
+            <div style="font-size:12px; margin-top:8px; font-weight:400; opacity:0.9;">
+                Model: <b>{final_model_id}</b> — Score by model: {score:.1f}%
+            </div>
         </div>
         """
-        return html
+
+        return {
+            "html": html,
+            "overlay": overlay_data_uri,
+            "explain_reason": explain_reason or ""
+        }
     except Exception as e:
         traceback.print_exc()
-        return f"<div style='color:red;'>Error analyzing image: {str(e)}</div>"
+        return {"html": f"<div style='color:red;'>Error analyzing image: {str(e)}</div>", "overlay": None, "explain_reason": ""}

-# CSS for sleek glowing pulse
+# ---------- Gradio UI ----------
 css = """
 body, .gradio-container {
     font-family: 'Poppins', sans-serif !important;
@@ -104,7 +360,7 @@ h1 {
 """

 with gr.Blocks(css=css, theme=gr.themes.Soft()) as demo:
-    gr.Markdown("<h1>🔍 AI Image Detector</h1>")
+    gr.Markdown("<h1>🔍 AI Image Detector w/ Explainability</h1>")

     with gr.Row():
         with gr.Column(scale=1):
@@ -112,20 +368,39 @@ with gr.Blocks(css=css, theme=gr.themes.Soft()) as demo:
             analyze_button = gr.Button("Analyze", variant="primary")
             clear_button = gr.Button("Clear", variant="secondary")
             loader = gr.HTML("")
+            gr.Markdown("Opacity:")
+            opacity = gr.Slider(minimum=0, maximum=1, value=0.6, step=0.05)
         with gr.Column(scale=1):
-            output = gr.HTML(label="Result")
+            # show original image plus overlay using HTML
+            image_display = gr.Image(type="pil", label="Original / Overlay", interactive=False)
+            output_html = gr.HTML(label="Result")
+            explanation_text = gr.Textbox(label="Explainability", interactive=False)

-    def analyze(img):
+    def analyze(img, op):
         if img is None:
-            return ("", "<div style='color:red;'>Please upload an image first!</div>")
+            return (None, "<div style='color:red;'>Please upload an image first!</div>", "")
         loader_html = "<div id='pulse-loader'></div>"
-        yield (loader_html, "")  # instantly show loader
+        # show loader
+        yield (None, loader_html, "")
+        # run analysis
+        out = predict_image_with_explain(img)
+        # overlay image if available
+        overlay_uri = out.get("overlay")
+        if overlay_uri:
+            # convert data uri to PIL for gr.Image output
+            header, b64 = overlay_uri.split(",", 1)
+            overlay_bytes = base64.b64decode(b64)
+            overlay_img = Image.open(io.BytesIO(overlay_bytes)).convert("RGB")
+        else:
+            overlay_img = img  # fallback: show orig

-        # do analysis in background
-        result = predict_image(img)
-        yield ("", result)  # hide loader, show result
+        # explanation text
+        explain_reason = out.get("explain_reason", "")
+        html = out.get("html", "")
+        # yield overlay image, html, explanation string
+        yield (overlay_img, html, explain_reason)

-    analyze_button.click(analyze, inputs=image_input, outputs=[loader, output])
-    clear_button.click(lambda: ("", ""), outputs=[loader, output])
+    analyze_button.click(analyze, inputs=[image_input, opacity], outputs=[image_display, output_html, explanation_text])
+    clear_button.click(lambda: (None, "", ""), outputs=[image_display, output_html, explanation_text])

 demo.launch()
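
A quick way to smoke-test the new GradCAM helper outside the Space is to point it at any small CNN. The sketch below is illustrative only and assumes the find_last_conv and GradCAM definitions from app.py above; the torchvision ResNet-18, its random weights, and the dummy tensor are stand-ins that do not appear in the commit.

import torch
from torchvision.models import resnet18

# Hypothetical smoke test: a random-weight ResNet-18 stands in for a Hub checkpoint.
model = resnet18().eval()
target_layer = find_last_conv(model)   # last Conv2d found by the helper in app.py
cam = GradCAM(model, target_layer)

dummy = torch.rand(3, 224, 224)        # unbatched RGB tensor; GradCAM.__call__ adds the batch dim
heatmap, class_idx, logits = cam(dummy)
print(heatmap.shape, class_idx, logits.shape)  # roughly (7, 7), an arbitrary class index, torch.Size([1, 1000])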
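
The attention_rollout_from_attentions helper is a chained product of head-averaged, row-normalised attention maps, so a tiny fabricated input is enough to check its shape contract; the tensors below are made up for illustration and are not part of the commit.

import torch

# Fabricated attentions: 3 layers, batch size 1, 2 heads, sequence length 4.
fake_attentions = [torch.softmax(torch.rand(1, 2, 4, 4), dim=-1) for _ in range(3)]

rollout = attention_rollout_from_attentions(fake_attentions)
print(rollout.shape)    # (4, 4): token-to-token rollout matrix
print(rollout.sum(-1))  # rows still sum to ~1, so row 0 (CLS) reads as a distribution over tokens

In vit_attention_heatmap, row 0 of this matrix, with the CLS-to-CLS entry dropped, is what gets reshaped into the patch-grid heatmap that overlay_heatmap_on_pil blends onto the uploaded image.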