File size: 6,004 Bytes
ef16f91
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
import torch, cv2, numpy as np
from transformers import AutoImageProcessor, AutoModelForImageClassification
from PIL import Image
from facenet_pytorch import MTCNN
from temporal_model import TemporalConsistencyModel
import warnings, logging
import os
from dotenv import load_dotenv

# Install a blanket "ignore" filter: no category or module is specified, so
# every warning raised after this point is suppressed.
warnings.filterwarnings("ignore")

# ---------- Logger Setup ----------
# Configure logging at INFO level with a timestamped format and a single
# stream handler; the module logger below is used by every function here.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(message)s",
    handlers=[logging.StreamHandler()]
)
logger = logging.getLogger(__name__)

# Must run before the os.getenv("VIDEO_MODEL_*") lookups further down so that
# values defined in a .env file are visible to them.
load_dotenv()

# Use the GPU when CUDA is available, otherwise fall back to the CPU. The same
# device is handed to the face detector and to every classification model.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ---------- Face Detector ----------
# NOTE(review): keep_all=False presumably limits the detector to one face per
# image; extract_face only ever reads the first returned box — confirm against
# the installed facenet_pytorch version.
face_detector = MTCNN(keep_all=False, device=device)

# ---------- Temporal Model ----------
# Single module-level instance; ensemble_predict_video calls .update() on it
# once per sampled frame.
# NOTE(review): because the instance is shared, any state it keeps would carry
# over between videos — verify whether TemporalConsistencyModel needs a reset.
temporal_model = TemporalConsistencyModel(window=7, alpha=0.75)

# ---------- Model Definitions ----------
# Model identifiers are read from the environment. An entry is None when the
# corresponding variable is not set; the list keeps all three slots so that
# positions still line up with VIDEO_MODEL_1..3.
MODEL_PATHS = [
    os.getenv("VIDEO_MODEL_1"),
    os.getenv("VIDEO_MODEL_2"),
    os.getenv("VIDEO_MODEL_3"),
]

# Parallel lists: processors[i] is the image processor belonging to models[i].
# A model is only registered once both it and its processor loaded, so the
# two lists always have the same length.
models, processors = [], []
for slot, mid in enumerate(MODEL_PATHS, start=1):
    # An unset variable is a configuration problem, not a download failure;
    # report it as such instead of letting from_pretrained(None) blow up.
    if not mid:
        logger.warning(f"⚠️ VIDEO_MODEL_{slot} is not set; skipping this model.")
        continue

    try:
        proc = AutoImageProcessor.from_pretrained(mid)
        model = AutoModelForImageClassification.from_pretrained(mid).to(device)
        model.eval()  # inference only
        models.append(model)
        processors.append(proc)
        logger.info(f"✅ Loaded model: {mid}")
    except Exception as e:
        # Best effort: one broken model must not prevent the others loading.
        logger.warning(f"⚠️ Failed to load {mid}: {e}")

if not models:
    logger.warning("⚠️ No video models were loaded; every frame will be skipped.")

# ---------- Heuristic ----------
def heuristic_texture_analysis(frame):
    """Score a frame from its frequency spectrum relative to its edge variance.

    The frame is converted from BGR to grayscale. The mean log-magnitude of
    its centred 2-D FFT is divided by the variance of its Laplacian, and the
    ratio is squashed with ``tanh(ratio / 60)`` and clipped to ``[0, 1]``.

    Args:
        frame: Image array that ``cv2.cvtColor`` can convert with
            ``cv2.COLOR_BGR2GRAY``.

    Returns:
        The normalised score as a Python ``float`` in ``[0, 1]``.
    """
    grayscale = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)

    # Log-magnitude spectrum, zero-frequency component shifted to the centre.
    spectrum = np.fft.fftshift(np.fft.fft2(grayscale))
    log_magnitude = np.log(np.abs(spectrum) + 1)

    # Variance of the Laplacian response of the grayscale image.
    laplacian = cv2.Laplacian(grayscale, cv2.CV_64F)
    laplacian_variance = np.var(laplacian)

    # The small epsilon keeps the division defined for a perfectly flat image.
    ratio = np.mean(log_magnitude) / (laplacian_variance + 1e-5)
    return float(np.clip(np.tanh(ratio / 60), 0, 1))

# ---------- Face Cropper (Fixed) ----------
def extract_face(frame):
    """Detect a face in a BGR frame and return it as a 224x224 BGR crop.

    Only the first box returned by the detector is used.

    Args:
        frame: BGR image array (the layout OpenCV's ``VideoCapture.read``
            produces); its first two dimensions are taken as height and width.

    Returns:
        The face region resized to 224x224 (still BGR), or ``None`` when no
        face is detected or the detected box does not overlap the image.
    """
    # The detector expects RGB input, whereas OpenCV frames are BGR. Only the
    # detection runs on the converted copy; the crop is taken from the
    # original frame so callers keep receiving BGR data.
    rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
    boxes, _ = face_detector.detect(rgb_frame)
    if boxes is None or len(boxes) == 0:
        logger.info("ℹ️ No face detected in this frame; skipping.")
        return None

    height, width = frame.shape[:2]
    x1, y1, x2, y2 = (int(coord) for coord in boxes[0])

    # Detector boxes can extend past the image. A negative start index would
    # wrap around and slice from the far edge, so clamp to the frame bounds.
    x1, x2 = max(x1, 0), min(x2, width)
    y1, y2 = max(y1, 0), min(y2, height)

    if x2 <= x1 or y2 <= y1:
        logger.warning("⚠️ Detected invalid face region; skipping frame.")
        return None

    return cv2.resize(frame[y1:y2, x1:x2], (224, 224))

# ---------- Prediction ----------
def predict_frame(frame):
    """Return the ensemble's fake-probability for one video frame.

    The face is cropped with ``extract_face``, converted to an RGB PIL image
    and passed through every loaded model. Each model contributes the
    probability of its top class if that class's label contains one of the
    "fake" keywords, and one minus that probability otherwise. Contributions
    are combined as a weighted average.

    Args:
        frame: BGR image array for a single frame.

    Returns:
        A ``float`` in ``[0, 1]`` (higher means more likely fake), or ``None``
        when no face was found or no model produced a prediction.
    """
    face_img = extract_face(frame)
    if face_img is None:
        return None  # skip frame gracefully

    frame_img = Image.fromarray(cv2.cvtColor(face_img, cv2.COLOR_BGR2RGB))

    # Ensemble weight for each position in `models` (CNN:0.4, ViT:0.35,
    # BEiT:0.25). Positions past the table reuse the last weight rather than
    # crashing the dot product.
    base_weights = (0.4, 0.35, 0.25)
    fake_keywords = ("fake", "forged", "manipulated", "edited")

    # Scores and weights are collected together so a model whose inference
    # fails drops out with its own weight, instead of shifting the weights
    # of the models that come after it.
    scores, weights = [], []

    for position, (model, proc) in enumerate(zip(models, processors)):
        try:
            inputs = proc(images=frame_img, return_tensors="pt").to(device)
            with torch.no_grad():
                logits = model(**inputs).logits
                probs = torch.nn.functional.softmax(logits, dim=-1)[0].cpu().numpy()

            label_idx = int(np.argmax(probs))

            # id2label may be keyed by str or by int depending on how the
            # config was built; try the string key first, then the int key.
            id2label = model.config.id2label
            if str(label_idx) in id2label:
                label = id2label[str(label_idx)].lower()
            elif label_idx in id2label:
                label = id2label[label_idx].lower()
            else:
                label = "unknown"

            # An unrecognised label matches no keyword and therefore counts
            # as the non-fake class.
            is_fake = any(keyword in label for keyword in fake_keywords)
            confidence = float(probs[label_idx])

            scores.append(confidence if is_fake else 1 - confidence)
            weights.append(
                base_weights[position]
                if position < len(base_weights)
                else base_weights[-1]
            )

        except Exception as e:
            # Best effort: one failing model must not discard the whole frame.
            logger.warning(f"⚠️ Model prediction failed for {model.__class__.__name__}: {e}")

    if not scores:
        logger.warning("⚠️ No valid model predictions; skipping frame.")
        return None

    # Renormalise so the weights of the models that answered sum to 1.
    weight_arr = np.asarray(weights, dtype=float)
    weight_arr /= weight_arr.sum()
    weighted_score = np.dot(scores, weight_arr)
    return float(np.clip(weighted_score, 0, 1))

# ---------- Main Pipeline ----------
def ensemble_predict_video(video_path, frame_interval=10):
    """Classify a video as fake or real by sampling and scoring its frames.

    Every ``frame_interval``-th frame is scored with ``predict_frame``.
    Frames that yield a score are blended with the texture heuristic
    (0.8 model / 0.2 heuristic) and passed through the shared
    ``temporal_model``. The final score is the mean of those outputs.

    Args:
        video_path: Path (or other source) accepted by ``cv2.VideoCapture``.
        frame_interval: Score one frame out of every ``frame_interval``;
            must be at least 1.

    Returns:
        On success, a dict with ``"top"`` (``"label"`` is ``"fake"`` when the
        final score exceeds 0.55, else ``"real"``, plus the rounded
        ``"score"``), ``"model_score"`` and ``"heuristic_score"``. When no
        frame could be scored, ``{"top": {"label": "unknown", "score": 0.0}}``.

    Raises:
        ValueError: If ``frame_interval`` is smaller than 1.
    """
    # Reject a zero interval up front instead of failing with a
    # ZeroDivisionError part-way through the video.
    if frame_interval < 1:
        raise ValueError(f"frame_interval must be >= 1, got {frame_interval!r}")

    cap = cv2.VideoCapture(video_path)
    frame_preds, heuristics = [], []

    # NOTE(review): temporal_model is a module-level instance, so whatever
    # state it keeps is shared across calls — confirm that is intended.
    try:
        if not cap.isOpened():
            logger.warning(f"⚠️ Could not open video: {video_path}")

        frame_count = 0
        while True:
            ret, frame = cap.read()
            if not ret:
                break

            # Decide on sampling, then advance the counter exactly once per
            # frame regardless of which path is taken below.
            is_sampled = frame_count % frame_interval == 0
            frame_count += 1
            if not is_sampled:
                continue

            model_score = predict_frame(frame)
            if model_score is None:
                continue  # no face / no prediction for this frame

            heuristic_score = heuristic_texture_analysis(frame)
            combined_score = 0.8 * model_score + 0.2 * heuristic_score
            temporal_score = temporal_model.update(combined_score)

            frame_preds.append(temporal_score)
            heuristics.append(heuristic_score)
    finally:
        # Release the capture even if scoring a frame raised.
        cap.release()

    if not frame_preds:
        logger.error("❌ No valid frames processed. Returning unknown result.")
        return {"top": {"label": "unknown", "score": 0.0}}

    model_score = float(np.mean(frame_preds))
    heuristic_score = float(np.mean(heuristics))
    final_score = float(np.clip(model_score, 0, 1))

    logger.info(f"✅ Video processed | Final Score: {final_score:.4f}")

    return {
        "top": {
            "label": "fake" if final_score > 0.55 else "real",
            "score": round(final_score, 4)
        },
        "model_score": round(model_score, 4),
        "heuristic_score": round(heuristic_score, 4),
    }

# ---------- Compatibility Wrapper ----------
def ensemble_predict_from_path(video_path):
    """Run the video ensemble on ``video_path`` using the default frame interval.

    Thin alias around ``ensemble_predict_video``, kept for compatibility
    with main.py.
    """
    result = ensemble_predict_video(video_path)
    return result


#***********************************************************************************************************************************************************************************************************************