File size: 3,908 Bytes
a055f15
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93

import gradio as gr
from transformers import pipeline

# ---- Model registry (all are image-classification on HF) ----
# Sources:
# - prithivMLmods/Deep-Fake-Detector-v2-Model (labels: Realism / Deepfake)
# - Wvolf/ViT_Deepfake_Detection (real/fake)
# - yermandy/deepfake-detection (CLIP-encoder baseline for deepfake)
MODEL_REGISTRY = {
    "ViT Deepfake v2 (Prithiv)": "prithivMLmods/Deep-Fake-Detector-v2-Model",
    "ViT Deepfake (Wvolf)":      "Wvolf/ViT_Deepfake_Detection",
    "CLIP Deepfake (yermandy)":  "yermandy/deepfake-detection",
}

_pipes = {}

def _get_pipe(model_id: str):
    """Return the image-classification pipeline for *model_id*, loading it once.

    The first request builds the pipeline with ``device_map="auto"`` and
    stores it in the module-level ``_pipes`` dict; later requests for the
    same id return that stored object.
    """
    try:
        return _pipes[model_id]
    except KeyError:
        pass
    # Cache miss: build outside the except clause so load errors are not
    # chained to the KeyError.
    classifier = pipeline("image-classification", model=model_id, device_map="auto")
    _pipes[model_id] = classifier
    return classifier

def _fake_real_probs(result):
    # result: list[{'label': str, 'score': float}]
    fake, real = 0.0, 0.0
    for r in result:
        lbl = r["label"].strip().lower()
        s   = float(r["score"])
        if ("fake" in lbl) or ("deepfake" in lbl) or ("ai" in lbl):
            fake = max(fake, s)
        if ("real" in lbl) or ("realism" in lbl) or ("authentic" in lbl):
            real = max(real, s)
    if fake==0.0 and real==0.0:
        # fallback: take top-1 and mirror
        top = max(result, key=lambda x: x["score"])
        is_fake = ("fake" in top["label"].lower()) or ("deepfake" in top["label"].lower()) or ("ai" in top["label"].lower())
        if is_fake:
            fake, real = float(top["score"]), 1.0 - float(top["score"])
        else:
            real, fake = float(top["score"]), 1.0 - float(top["score"])
    # normalize to sum<=1 if both present
    s = fake + real
    if s > 1.0 and s > 0:
        fake, real = fake/s, real/s
    return fake, real

def predict(img, model_name, ensemble, top_k, threshold):
    """Classify *img* as FAKE or REAL with one model or the whole registry.

    Args:
        img: the uploaded image (the UI wires a ``gr.Image(type="pil")``
            input here), or ``None`` when nothing has been uploaded.
        model_name: key of ``MODEL_REGISTRY``; ignored when *ensemble* is true.
        ensemble: if true, run every registered model and average their scores.
        top_k: number of labels to request from each model. Gradio's Slider
            may deliver this as a float, so it is coerced to ``int``.
        threshold: FAKE is predicted when the averaged fake score >= threshold.

    Returns:
        ``(pred, chart, rows)``: the "FAKE"/"REAL" verdict, a
        ``{"FAKE": p, "REAL": q}`` dict for ``gr.Label``, and the per-model
        ``[model_id, label, score]`` rows sorted by descending score.

    Raises:
        gr.Error: if no image was provided.
    """
    if img is None:
        # Show a readable message in the UI instead of a pipeline traceback.
        raise gr.Error("Please upload an image first.")

    # top_k is used as a count (pipeline argument and list slice below), so
    # it must be an int even if the slider hands over e.g. 3.0.
    top_k = max(1, int(round(top_k)))

    model_ids = list(MODEL_REGISTRY.values()) if ensemble else [MODEL_REGISTRY[model_name]]

    fake_scores, real_scores, rows = [], [], []
    for mid in model_ids:
        out = _get_pipe(mid)(img, top_k=top_k)
        f, r = _fake_real_probs(out)
        fake_scores.append(f)
        real_scores.append(r)
        rows.extend([mid, item["label"], float(item["score"])] for item in out)

    # Unweighted mean across the selected model(s).
    fake_prob = sum(fake_scores) / len(fake_scores)
    real_prob = sum(real_scores) / len(real_scores)
    pred = "FAKE" if fake_prob >= threshold else "REAL"

    # Aggregate chart (gr.Label expects {label: score}).
    chart = {"FAKE": float(fake_prob), "REAL": float(real_prob)}
    # Top-k table across all models, highest score first.
    rows.sort(key=lambda row: row[2], reverse=True)
    rows = rows[: top_k * len(model_ids)]
    return pred, chart, rows

with gr.Blocks(theme="soft") as demo:
    gr.Markdown("# 🔎 Deepfake Detector\nChoose a model or use an **Ensemble** for a more robust score.")

    with gr.Row():
        # Left column: image input plus collapsible inference settings.
        with gr.Column(scale=3):
            img = gr.Image(type="pil", label="Upload image (face works best)")
            with gr.Accordion("Settings", open=False):
                model_name = gr.Dropdown(
                    choices=list(MODEL_REGISTRY),
                    value="ViT Deepfake v2 (Prithiv)",
                    label="Backbone",
                )
                ensemble = gr.Checkbox(value=False, label="Ensemble (use all models)")
                top_k = gr.Slider(minimum=1, maximum=5, value=3, step=1, label="Top-k per model")
                threshold = gr.Slider(minimum=0.1, maximum=0.9, value=0.5, step=0.01, label="FAKE threshold")
            btn = gr.Button("Predict", variant="primary")

        # Right column: verdict, aggregated scores, and per-model label table.
        with gr.Column(scale=2):
            pred = gr.Label(label="Prediction (FAKE vs REAL)")
            chart = gr.Label(label="Aggregated probabilities")
            table = gr.Dataframe(headers=["model", "label", "score"], wrap=True)

    # Inputs map positionally onto predict's parameters; outputs receive its
    # (pred, chart, rows) return tuple in order.
    btn.click(
        fn=predict,
        inputs=[img, model_name, ensemble, top_k, threshold],
        outputs=[pred, chart, table],
    )

if __name__ == "__main__":
    demo.launch()