import logging
import os
import warnings

import cv2
import numpy as np
import torch
from dotenv import load_dotenv
from facenet_pytorch import MTCNN
from PIL import Image
from transformers import AutoImageProcessor, AutoModelForImageClassification

from temporal_model import TemporalConsistencyModel

# Silence all warnings process-wide.
warnings.filterwarnings("ignore")

# ---------- Logger Setup ----------
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(message)s",
    handlers=[logging.StreamHandler()],
)
logger = logging.getLogger(__name__)

# Load VIDEO_MODEL_* settings from a .env file before they are read below.
load_dotenv()

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ---------- Face Detector ----------
face_detector = MTCNN(keep_all=False, device=device)

# ---------- Temporal Model ----------
# Smoother settings. ensemble_predict_video() builds a fresh smoother per video
# so state never carries over between videos; this module-level instance is
# kept only for code that imports it directly.
TEMPORAL_WINDOW = 7
TEMPORAL_ALPHA = 0.75
temporal_model = TemporalConsistencyModel(window=TEMPORAL_WINDOW, alpha=TEMPORAL_ALPHA)

# ---------- Model Definitions ----------
# Model ids / paths for from_pretrained(), read from the environment.
# Unset variables appear as None and are skipped during loading.
MODEL_PATHS = [
    os.getenv("VIDEO_MODEL_1"),
    os.getenv("VIDEO_MODEL_2"),
    os.getenv("VIDEO_MODEL_3"),
]

# Ensemble weight for each MODEL_PATHS slot (CNN:0.4, ViT:0.35, BEiT:0.25).
# Weights are bound to the slot rather than to load order, so a model that is
# missing or fails to load does not shift its weight onto a neighbour.
MODEL_WEIGHTS = [0.4, 0.35, 0.25]

# A predicted label containing any of these substrings is treated as "fake".
FAKE_LABEL_KEYWORDS = ("fake", "forged", "manipulated", "edited")

# Parallel lists: models[i] uses processors[i] and carries model_weights[i].
models, processors, model_weights = [], [], []
for mid, weight in zip(MODEL_PATHS, MODEL_WEIGHTS):
    if not mid:
        logger.warning("⚠️ Model path not configured; skipping this ensemble slot.")
        continue
    try:
        proc = AutoImageProcessor.from_pretrained(mid)
        model = AutoModelForImageClassification.from_pretrained(mid).to(device)
        model.eval()
    except Exception as e:  # best-effort: a model that fails to load shrinks the ensemble
        logger.warning("⚠️ Failed to load %s: %s", mid, e)
        continue
    models.append(model)
    processors.append(proc)
    model_weights.append(weight)
    logger.info("✅ Loaded model: %s", mid)


# ---------- Heuristic ----------
def heuristic_texture_analysis(frame):
    """Return a texture-based score in [0, 1] for a BGR frame.

    The score is the mean log-magnitude of the frame's centred 2-D FFT divided
    by the variance of its Laplacian (an edge-sharpness measure), squashed
    with tanh. Strong spectral energy relative to edge detail gives a higher
    score.
    """
    gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
    spectrum = np.fft.fftshift(np.fft.fft2(gray))
    log_mag = np.log(np.abs(spectrum) + 1)
    edge_var = np.var(cv2.Laplacian(gray, cv2.CV_64F))
    # The epsilon avoids division by zero on perfectly flat frames.
    texture_score = np.mean(log_mag) / (edge_var + 1e-5)
    # /60 is presumably an empirically tuned scale; tanh maps the ratio to [0, 1).
    return float(np.clip(np.tanh(texture_score / 60), 0, 1))


# ---------- Face Cropper ----------
def extract_face(frame):
    """Crop the first detected face from a BGR frame, resized to 224x224.

    Returns None when no face is detected or the detected box lies entirely
    outside the frame.
    """
    # MTCNN expects RGB input, while OpenCV frames are BGR.
    boxes, _ = face_detector.detect(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
    if boxes is None or len(boxes) == 0:
        logger.info("ℹ️ No face detected in this frame; skipping.")
        return None

    # Detector boxes may extend past the image; clip so slicing never wraps
    # around on negative indices.
    height, width = frame.shape[:2]
    x1, y1, x2, y2 = (int(v) for v in boxes[0])
    x1, y1 = max(0, x1), max(0, y1)
    x2, y2 = min(width, x2), min(height, y2)
    if x2 <= x1 or y2 <= y1:
        logger.warning("⚠️ Detected invalid face region; skipping frame.")
        return None

    return cv2.resize(frame[y1:y2, x1:x2], (224, 224))


# ---------- Prediction ----------
def _lookup_label(id2label, idx):
    """Return the lower-cased label for class ``idx``, or "unknown".

    Both str and int keys are tried, since id2label may be keyed either way.
    """
    if str(idx) in id2label:
        return id2label[str(idx)].lower()
    if idx in id2label:
        return id2label[idx].lower()
    return "unknown"


def _fake_score(model, proc, image):
    """Return one model's fake score in [0, 1] for a PIL RGB image.

    The score is the top-1 softmax confidence when the top-1 label looks fake,
    otherwise 1 minus that confidence. For a two-class real/fake model this
    equals P(fake).
    """
    inputs = proc(images=image, return_tensors="pt").to(device)
    with torch.no_grad():
        logits = model(**inputs).logits
    probs = torch.nn.functional.softmax(logits, dim=-1)[0].cpu().numpy()
    label_idx = int(np.argmax(probs))
    label = _lookup_label(model.config.id2label, label_idx)
    confidence = float(probs[label_idx])
    is_fake = any(k in label for k in FAKE_LABEL_KEYWORDS)
    return confidence if is_fake else 1 - confidence


def predict_frame(frame):
    """Return the ensemble's weighted fake score in [0, 1] for one BGR frame.

    Returns None when no usable face is found or every model fails on it.
    Each model's score keeps its own weight from ``model_weights``. The
    weights are renormalised over the models that produced a score.
    """
    face_img = extract_face(frame)
    if face_img is None:
        return None  # skip frame gracefully
    face_pil = Image.fromarray(cv2.cvtColor(face_img, cv2.COLOR_BGR2RGB))

    scores, weights = [], []
    for model, proc, weight in zip(models, processors, model_weights):
        try:
            score = _fake_score(model, proc, face_pil)
        except Exception as e:  # one failing model should not discard the frame
            logger.warning(
                "⚠️ Model prediction failed for %s: %s", model.__class__.__name__, e
            )
            continue
        scores.append(score)
        weights.append(weight)

    if not scores:
        logger.warning("⚠️ No valid model predictions; skipping frame.")
        return None

    w = np.asarray(weights, dtype=float)
    weighted_score = float(np.dot(scores, w / w.sum()))
    return float(np.clip(weighted_score, 0, 1))


# ---------- Main Pipeline ----------
def _iter_sampled_frames(cap, frame_interval):
    """Yield every ``frame_interval``-th frame read from ``cap``, starting at frame 0."""
    index = 0
    while True:
        ok, frame = cap.read()
        if not ok:
            return
        if index % frame_interval == 0:
            yield frame
        index += 1


def ensemble_predict_video(video_path, frame_interval=10, threshold=0.55):
    """Classify a video as "fake" or "real" from sampled face frames.

    Every ``frame_interval``-th frame is scored as 0.8 * ensemble score +
    0.2 * texture heuristic. The combined score is passed through a per-video
    temporal smoother, and the smoothed scores are averaged.

    Args:
        video_path: Source passed to ``cv2.VideoCapture`` (e.g. a file path).
        frame_interval: Sample every N-th frame; must be >= 1.
        threshold: Final scores strictly above this are labelled "fake".

    Returns:
        ``{"top": {"label", "score"}, "model_score", "heuristic_score"}``, or
        ``{"top": {"label": "unknown", "score": 0.0}}`` when no frame could be
        scored (the video could not be read, or no faces were found).

    Raises:
        ValueError: If ``frame_interval`` is less than 1.
    """
    if frame_interval < 1:
        raise ValueError(f"frame_interval must be >= 1, got {frame_interval!r}")

    # Fresh smoother per video. The module-level one would carry state over
    # from previously processed videos and be shared between concurrent calls.
    temporal = TemporalConsistencyModel(window=TEMPORAL_WINDOW, alpha=TEMPORAL_ALPHA)
    frame_preds, heuristics = [], []

    cap = cv2.VideoCapture(video_path)
    try:
        if not cap.isOpened():
            logger.error("❌ Could not open video: %s", video_path)
        for frame in _iter_sampled_frames(cap, frame_interval):
            model_score = predict_frame(frame)
            if model_score is None:
                continue
            heuristic_score = heuristic_texture_analysis(frame)
            combined_score = 0.8 * model_score + 0.2 * heuristic_score
            frame_preds.append(temporal.update(combined_score))
            heuristics.append(heuristic_score)
    finally:
        cap.release()  # release the capture even if frame scoring raises

    if not frame_preds:
        logger.error("❌ No valid frames processed. Returning unknown result.")
        return {"top": {"label": "unknown", "score": 0.0}}

    # "model_score" is the mean of the smoothed combined scores, so it already
    # includes the 20% heuristic share. final_score is the same value clipped.
    model_score = float(np.mean(frame_preds))
    heuristic_score = float(np.mean(heuristics))
    final_score = float(np.clip(model_score, 0, 1))
    logger.info("✅ Video processed | Final Score: %.4f", final_score)
    return {
        "top": {
            "label": "fake" if final_score > threshold else "real",
            "score": round(final_score, 4),
        },
        "model_score": round(model_score, 4),
        "heuristic_score": round(heuristic_score, 4),
    }


# ---------- Compatibility Wrapper ----------
def ensemble_predict_from_path(video_path):
    """Compatibility wrapper for main.py: run ensemble_predict_video with defaults."""
    return ensemble_predict_video(video_path)