thrimurthi2025 committed
Commit 78e86b6 · verified · 1 Parent(s): 513e05d

Update app.py

Files changed (1)
  1. app.py +50 -333
app.py CHANGED
@@ -1,353 +1,70 @@
- # app.py
  import gradio as gr
- from transformers import pipeline, AutoImageProcessor, AutoModelForImageClassification
  from PIL import Image
- import traceback, io, base64, time
- import torch
- import torch.nn.functional as F
- import numpy as np

- # -------- CONFIG --------
- HF_TOKEN = None # set to "hf_xxx" if you use private models
-
- # List the models you want to use. Keep one model if you need lower memory.
- models = [
-     ("Ateeqq/ai-vs-human-image-detector", "ateeq"),
-     ("umm-maybe/AI-image-detector", "umm_maybe"),
-     ("dima806/ai_vs_human_generated_image_detection", "dimma"),
- ]
-
- # ---------- Utilities: safe overlay WITHOUT OPENCV ----------
- def apply_colormap_numpy(heatmap):
-     h = np.clip(heatmap, 0.0, 1.0)
-     c = np.zeros((h.shape[0], h.shape[1], 3), dtype=np.float32)
-     c[..., 0] = np.clip(1.5 - 4.0 * np.abs(h - 0.25), 0, 1) # R
-     c[..., 1] = np.clip(1.5 - 4.0 * np.abs(h - 0.5), 0, 1) # G
-     c[..., 2] = np.clip(1.5 - 4.0 * np.abs(h - 0.75), 0, 1) # B
-     return (c * 255).astype(np.uint8)
-
- def overlay_heatmap_on_pil_no_cv(orig_pil, heatmap, alpha=0.45):
-     orig = np.array(orig_pil.convert("RGB"))
-     # resize heatmap to original size
-     heatmap_img = Image.fromarray((np.clip(heatmap,0,1)*255).astype(np.uint8))
-     heatmap_resized = np.array(heatmap_img.resize((orig.shape[1], orig.shape[0]), resample=Image.BILINEAR)) / 255.0
-     colored = apply_colormap_numpy(heatmap_resized)
-     overlay = (orig * (1 - alpha) + colored * alpha).astype(np.uint8)
-     return Image.fromarray(overlay)
-
- # ---------- Safe Grad-CAM helper (expects a helper object that returns heatmap) ----------
- def safe_gradcam_run(grad_helper, model, input_tensor):
-     """
-     grad_helper should be callable like: heatmap, class_idx, logits = grad_helper(input_tensor, class_idx)
-     This wrapper keeps things safe and returns tuple (heatmap_or_None, class_idx_or_None, logits_or_None, error_or_None)
-     """
-     try:
-         out = model(input_tensor.unsqueeze(0))
-         logits = out.logits if hasattr(out, "logits") else out
-         if logits is None:
-             return None, None, None, "no logits returned"
-         class_idx = int(torch.argmax(logits, dim=1).item())
-         # call grad helper (it may perform backward internally)
-         heatmap, idx, logits_tensor = grad_helper(input_tensor, class_idx)
-         if heatmap is None:
-             return None, class_idx, logits_tensor, "gradcam returned no heatmap"
-         return heatmap, class_idx, logits_tensor, None
-     except Exception as e:
-         return None, None, None, f"GradCAM error: {repr(e)}"
-
- # ---------- Safe ViT attention rollout ----------
- def safe_vit_attention_heatmap(processor, model, image: Image.Image):
-     try:
-         if processor is None or model is None:
-             return None, "processor or model missing"
-         inputs = processor(images=image, return_tensors="pt")
-         outputs = model(**inputs, output_attentions=True)
-         attentions = getattr(outputs, "attentions", None)
-         if not attentions:
-             return None, "no attentions in model output"
-         result = None
-         for attn in attentions:
-             a = attn[0].mean(0).detach().cpu().numpy() # (seq, seq)
-             a = np.maximum(a, 0)
-             a = a / (a.sum(-1, keepdims=True) + 1e-8)
-             result = a if result is None else a @ result
-         cls_attn = result[0, 1:]
-         n_tokens = cls_attn.shape[0]
-         grid = int(np.round(np.sqrt(n_tokens)))
-         if grid * grid != n_tokens:
-             # best-effort reshape
-             grid = int(np.round(np.sqrt(n_tokens)))
-         heatmap = cls_attn.reshape(grid, grid)
-         heatmap = heatmap - heatmap.min()
-         heatmap = heatmap / (heatmap.max() + 1e-8)
-         return heatmap, None
-     except Exception as e:
-         return None, f"ViT rollout error: {repr(e)}"
-
- # ---------- Load pipelines and raw models (defensive) ----------
- pipes = [] # list of (model_id, pipeline)
- hf_models = {} # model_id -> dict with processor/model/explain_type/helper
-
- for model_id, short in models:
-     # load inference pipeline (fast)
-     try:
-         p = pipeline("image-classification", model=model_id, use_auth_token=HF_TOKEN)
-         pipes.append((model_id, p))
-         print(f"[INFO] Loaded pipeline: {model_id}")
-     except Exception as e:
-         print(f"[WARN] Failed to load pipeline for {model_id}: {e}")
-
-     # try to load raw model + processor for explainability
-     proc = None
-     raw_model = None
-     explain_type = "none"
-     helper = None
-     try:
-         proc = AutoImageProcessor.from_pretrained(model_id, use_auth_token=HF_TOKEN)
-     except Exception:
-         # fallback: processor may not exist
-         proc = None
-
-     try:
-         raw_model = AutoModelForImageClassification.from_pretrained(model_id, use_auth_token=HF_TOKEN)
-         raw_model.eval()
-         # attempt to detect conv layers for Grad-CAM
-         base = None
-         for cand in ("base_model", "backbone", "model", "vit", "resnet", "conv_stem"):
-             if hasattr(raw_model, cand):
-                 base = getattr(raw_model, cand)
-                 break
-         if base is None:
-             base = raw_model
-
-         # find last conv2d if exists
-         last_conv = None
-         for name, m in base.named_modules():
-             if isinstance(m, torch.nn.Conv2d):
-                 last_conv = m
-         if last_conv is not None:
-             explain_type = "gradcam"
-             # Create a small grad-cam helper object that registers hooks on last_conv
-             class GradCAMHelper:
-                 def __init__(self, model, target_layer):
-                     self.model = model
-                     self.target_layer = target_layer
-                     self.activations = None
-                     self.gradients = None
-                     target_layer.register_forward_hook(self._save_activation)
-                     try:
-                         target_layer.register_backward_hook(self._save_gradient)
-                     except Exception:
-                         target_layer.register_full_backward_hook(self._save_gradient)
-
-                 def _save_activation(self, module, inp, out):
-                     self.activations = out.detach()
-
-                 def _save_gradient(self, module, grad_input, grad_output):
-                     self.gradients = grad_output[0].detach()
-
-                 def __call__(self, input_tensor, class_idx=None):
-                     # forward
-                     out = self.model(input_tensor.unsqueeze(0))
-                     logits = out.logits if hasattr(out, "logits") else out
-                     if class_idx is None:
-                         class_idx = int(torch.argmax(logits, dim=1).item())
-                     self.model.zero_grad()
-                     score = logits[0, class_idx]
-                     score.backward(retain_graph=False)
-                     if self.gradients is None or self.activations is None:
-                         raise RuntimeError("gradcam hooks did not capture activations/gradients")
-                     pooled_grads = torch.mean(self.gradients[0], dim=(1,2))
-                     activ = self.activations[0].cpu()
-                     for i in range(activ.shape[0]):
-                         activ[i,:,:] *= pooled_grads[i].cpu()
-                     heatmap = torch.sum(activ, dim=0).cpu().numpy()
-                     heatmap = np.maximum(heatmap, 0)
-                     heatmap = heatmap - heatmap.min()
-                     heatmap = heatmap / (heatmap.max() + 1e-8)
-                     return heatmap, class_idx, logits
-             helper = GradCAMHelper(raw_model, last_conv)
-             print(f"[INFO] {model_id} -> gradcam ready")
-         else:
-             # if model looks like ViT (common in config.architectures)
-             cfg = getattr(raw_model, "config", None)
-             archs = getattr(cfg, "architectures", None) if cfg is not None else None
-             if archs and any("ViT" in a or "VisionTransformer" in a for a in archs):
-                 explain_type = "vit"
-                 helper = None
-                 print(f"[INFO] {model_id} -> detected ViT, will use attention rollout")
-             else:
-                 explain_type = "none"
-                 helper = None
-                 print(f"[INFO] {model_id} -> no explainability detected")
-     except Exception as e:
-         print(f"[WARN] Could not load raw HF model for explainability {model_id}: {e}")
-         raw_model = None
-         proc = proc
-         explain_type = "none"
-         helper = None
-
-     hf_models[model_id] = {
-         "processor": proc,
-         "model": raw_model,
-         "explain_type": explain_type,
-         "helper": helper
-     }
-
- # ---------- Prediction + explain wrapper ----------
- def predict_image_with_explain(image: Image.Image):
      try:
-         results = []
-         for model_id, pipe in pipes:
-             try:
-                 res = pipe(image)
-                 results.append((model_id, res[0] if isinstance(res, list) and res else {"label":"error","score":0.0}))
-             except Exception as e:
-                 results.append((model_id, {"label":"error","score":0.0}))
-
-         if not results:
-             return {"html": "<div style='color:red;'>No models loaded</div>", "overlay": None, "explain_reason": "no pipelines"}
-
-         final_model_id, final_res = results[0]
-         label = final_res.get("label","").lower()
-         score = final_res.get("score",0.0) * 100
          if "ai" in label or "fake" in label:
              verdict = f"🧠 AI-Generated ({score:.1f}% confidence)"
              color = "#007BFF"
          else:
              verdict = f"🧍 Human-Made ({score:.1f}% confidence)"
              color = "#4CAF50"
-
-         overlay_data_uri = None
-         explain_reason = ""
-
-         explain_entry = hf_models.get(final_model_id)
-         if explain_entry:
-             etype = explain_entry.get("explain_type","none")
-             try:
-                 if etype == "gradcam" and explain_entry.get("helper") is not None:
-                     proc = explain_entry.get("processor")
-                     raw_model = explain_entry.get("model")
-                     # prepare input tensor robustly
-                     if proc is not None:
-                         inputs = proc(images=image, return_tensors="pt")
-                         # common key names
-                         input_tensor = inputs.get("pixel_values") or inputs.get("input_tensor") or list(inputs.values())[0]
-                         if isinstance(input_tensor, (list,tuple)):
-                             input_tensor = input_tensor[0]
-                         if isinstance(input_tensor, torch.Tensor) and input_tensor.dim()==4:
-                             input_tensor = input_tensor[0]
-                     else:
-                         # fallback preproc
-                         from torchvision import transforms
-                         pre = transforms.Compose([
-                             transforms.Resize((224,224)),
-                             transforms.ToTensor(),
-                             transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225])
-                         ])
-                         input_tensor = pre(image)
-                     heatmap, class_idx, logits, err = safe_gradcam_run(explain_entry["helper"], raw_model, input_tensor)
-                     if err:
-                         explain_reason = err
-                     elif heatmap is None:
-                         explain_reason = "gradcam returned no heatmap"
-                     else:
-                         overlay_img = overlay_heatmap_on_pil_no_cv(image, heatmap, alpha=0.45)
-                         buf = io.BytesIO()
-                         overlay_img.save(buf, format="PNG")
-                         overlay_b64 = base64.b64encode(buf.getvalue()).decode("utf-8")
-                         overlay_data_uri = "data:image/png;base64," + overlay_b64
-                         explain_reason = "Grad-CAM heatmap (activations)"
-                 elif etype == "vit" and explain_entry.get("model") is not None:
-                     proc = explain_entry.get("processor")
-                     raw_model = explain_entry.get("model")
-                     heatmap, err = safe_vit_attention_heatmap(proc, raw_model, image)
-                     if err:
-                         explain_reason = err
-                     elif heatmap is None:
-                         explain_reason = "vit produced no heatmap"
-                     else:
-                         overlay_img = overlay_heatmap_on_pil_no_cv(image, heatmap, alpha=0.45)
-                         buf = io.BytesIO()
-                         overlay_img.save(buf, format="PNG")
-                         overlay_b64 = base64.b64encode(buf.getvalue()).decode("utf-8")
-                         overlay_data_uri = "data:image/png;base64," + overlay_b64
-                         explain_reason = "ViT attention rollout heatmap"
-                 else:
-                     explain_reason = "No explainability available for this model"
-             except Exception as e:
-                 explain_reason = f"Explain pipeline failed: {repr(e)}"
-         else:
-             explain_reason = "No raw HF entry for model"
-
          html = f"""
-         <div class='result-box' style="
-             background: linear-gradient(135deg, {color}33, #1a1a1a);
-             border: 2px solid {color};
-             border-radius: 15px;
-             padding: 20px;
-             text-align: center;
-             color: white;
-             font-size: 18px;
-             font-weight: 600;
-             box-shadow: 0 0 20px {color}55;
-             animation: fadeIn 0.6s ease-in-out;
-         ">
-         {verdict}
-         <div style="font-size:12px; margin-top:8px; font-weight:400; opacity:0.9;">
-             Model: <b>{final_model_id}</b> — Score: {score:.1f}%
-         </div>
          </div>
          """
-
-         return {"html": html, "overlay": overlay_data_uri, "explain_reason": explain_reason}
      except Exception as e:
-         traceback.print_exc()
-         return {"html": f"<div style='color:red;'>Error analyzing image: {str(e)}</div>", "overlay": None, "explain_reason": ""}

- # ---------- Gradio UI ----------
  css = """
- body, .gradio-container { font-family: 'Poppins', sans-serif !important; background: transparent !important; }
- h1 { text-align: center; font-weight: 700; color: #007BFF; margin-bottom: 10px; }
- .gr-button-primary { background-color: #007BFF !important; color: white !important; font-weight: 600; border-radius: 10px; height: 45px; }
- .gr-button-secondary { background-color: #dc3545 !important; color: white !important; border-radius: 10px; height: 45px; }
- #pulse-loader { width: 100%; height: 4px; background: linear-gradient(90deg, #007BFF, #00C3FF); animation: pulse 1.2s infinite ease-in-out; border-radius: 2px; box-shadow: 0 0 10px #007BFF; }
- @keyframes pulse { 0% { transform: scaleX(0.1); opacity: 0.6; } 50% { transform: scaleX(1); opacity: 1; } 100% { transform: scaleX(0.1); opacity: 0.6; } }
- @keyframes fadeIn { from { opacity: 0; transform: scale(0.95); } to { opacity: 1; transform: scale(1); } }
  """

- with gr.Blocks(css=css, theme=gr.themes.Soft()) as demo:
-     gr.Markdown("<h1>🔍 AI Image Detector w/ Explainability</h1>")
      with gr.Row():
-         with gr.Column(scale=1):
-             image_input = gr.Image(type="pil", label="Upload an image")
-             analyze_button = gr.Button("Analyze", variant="primary")
-             clear_button = gr.Button("Clear", variant="secondary")
-             loader = gr.HTML("")
-             gr.Markdown("Opacity:")
-             opacity = gr.Slider(minimum=0, maximum=1, value=0.6, step=0.05)
-         with gr.Column(scale=1):
-             image_display = gr.Image(type="pil", label="Original / Overlay", interactive=False)
-             output_html = gr.HTML(label="Result")
-             explanation_text = gr.Textbox(label="Explainability", interactive=False)
-
-     def analyze(img, op):
-         if img is None:
-             return (None, "<div style='color:red;'>Please upload an image first!</div>", "")
-         yield (None, "<div id='pulse-loader'></div>", "")
-         out = predict_image_with_explain(img)
-         overlay_uri = out.get("overlay")
-         if overlay_uri:
-             header, b64 = overlay_uri.split(",",1)
-             overlay_bytes = base64.b64decode(b64)
-             overlay_img = Image.open(io.BytesIO(overlay_bytes)).convert("RGB")
-         else:
-             overlay_img = img
-         explain_reason = out.get("explain_reason","")
-         html = out.get("html","")
-         yield (overlay_img, html, explain_reason)
-
-     analyze_button.click(analyze, inputs=[image_input, opacity], outputs=[image_display, output_html, explanation_text])
-     clear_button.click(lambda: (None, "", ""), outputs=[image_display, output_html, explanation_text])
-
  demo.launch()
 
+ # app.py (Option B - Minimal local pipeline; may use more RAM)
+ import os, io, base64, traceback
  import gradio as gr
+ from transformers import pipeline
  from PIL import Image

+ MODEL_ID = "Ateeqq/ai-vs-human-image-detector"
+ HF_TOKEN = os.environ.get("HF_TOKEN") # set if model private
+
+ # Try to load pipeline (defensive)
+ pipes = []
+ load_error = None
+ try:
+     pipes.append((MODEL_ID, pipeline("image-classification", model=MODEL_ID, use_auth_token=HF_TOKEN)))
+     load_error = None
+     print(f"[INFO] Loaded {MODEL_ID}")
+ except Exception as e:
+     load_error = repr(e)
+     print("[ERROR] Failed to load pipeline:", load_error)
+
+ def predict(image: Image.Image):
+     if image is None:
+         return None, "<div style='color:red;'>Upload an image first</div>", load_error or ""
+     if not pipes:
+         # Show the exact load error to help debugging
+         return image, "<div style='color:red;'>No models loaded</div>", load_error or "No pipeline"
+     model_id, pipe = pipes[0]
      try:
+         res = pipe(image)
+         if not res:
+             return image, "<div style='color:red;'>Model returned no results</div>", ""
+         top = res[0]
+         label = top.get("label","").lower()
+         score = top.get("score", 0.0) * 100
          if "ai" in label or "fake" in label:
              verdict = f"🧠 AI-Generated ({score:.1f}% confidence)"
              color = "#007BFF"
          else:
              verdict = f"🧍 Human-Made ({score:.1f}% confidence)"
              color = "#4CAF50"
          html = f"""
+         <div style='background:linear-gradient(135deg,{color}33,#1a1a1a);
+         border:2px solid {color}; border-radius:12px; padding:18px;
+         text-align:center; color:white; font-weight:700;'>
+         {verdict}<div style="font-size:12px;opacity:0.85;margin-top:6px">Model: {model_id}</div>
          </div>
          """
+         return image, html, ""
      except Exception as e:
+         err = repr(e)
+         return image, f"<div style='color:red;'>Inference failed: {err}</div>", err

  css = """
+ .gradio-container { font-family: 'Poppins', sans-serif; }
  """

+ with gr.Blocks(css=css) as demo:
+     gr.Markdown("<h2>🔍 Unreal Eye (Local single-model)</h2>")
      with gr.Row():
+         with gr.Column():
+             inp = gr.Image(type="pil", label="Upload an image")
+             btn = gr.Button("Analyze")
+             btn_clear = gr.Button("Clear")
+         with gr.Column():
+             out_img = gr.Image(type="pil", label="Original / Overlay")
+             out_html = gr.HTML()
+             load_box = gr.Textbox(label="Load status / explainability", value=(load_error or "Model loaded" if pipes else "No model loaded"), interactive=False)
+     btn.click(predict, inputs=inp, outputs=[out_img, out_html, load_box])
+     btn_clear.click(lambda: (None, "", ""), outputs=[out_img, out_html, load_box])
  demo.launch()
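
For reference, a minimal smoke-test sketch of the same transformers call the new app.py wraps. It is not part of the commit: it assumes transformers, Pillow, and network access for the first model download, and "test.jpg" is a hypothetical placeholder for any local image file.

# smoke_test.py - illustrative sketch only, mirroring the pipeline used by the new app.py
from PIL import Image
from transformers import pipeline

MODEL_ID = "Ateeqq/ai-vs-human-image-detector"

clf = pipeline("image-classification", model=MODEL_ID)   # downloads weights on first run
preds = clf(Image.open("test.jpg").convert("RGB"))       # placeholder path; pipeline accepts PIL images

# Same decision rule as predict(): labels containing "ai" or "fake" count as AI-generated.
top = preds[0]
label, score = top["label"].lower(), top["score"] * 100
print(f"{'AI-Generated' if ('ai' in label or 'fake' in label) else 'Human-Made'} ({score:.1f}%)")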