Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| from PIL import Image | |
| import torch | |
| import torchvision.transforms as transforms | |
| import json | |
| import os | |
| import numpy as np | |
| import pandas as pd | |
| import random | |
| import onnxruntime as ort | |
| from huggingface_hub import HfApi, hf_hub_download | |
| from transformers import CLIPTokenizer, AutoImageProcessor, AutoModelForImageClassification | |
| from safetensors.torch import load_file as safe_load | |
| import subprocess | |
| from datetime import datetime | |
# --- Config ---
# Hugging Face dataset repo holding the leaderboard JSON and the submission log.
HUB_REPO_ID: str = "CDL-AMLRT/OpenArenaLeaderboard"
# Access token taken from the environment; None when HF_TOKEN is not set.
HF_TOKEN = os.environ.get("HF_TOKEN")
# Leaderboard filename on local disk and inside the Hub repo.
LOCAL_JSON: str = "leaderboard.json"
HUB_JSON: str = "leaderboard.json"
# Detector weights (updated filename) and the backbone they are loaded onto.
MODEL_PATH: str = "model.safetensors"
MODEL_BACKBONE: str = "microsoft/swinv2-small-patch4-window16-256"
# ONNX exports of the CLIP image / text encoders used for prompt matching.
CLIP_IMAGE_ENCODER_PATH: str = "clip_image_encoder.onnx"
CLIP_TEXT_ENCODER_PATH: str = "clip_text_encoder.onnx"
# CSV whose "prompt" column supplies the prompts shown to players.
PROMPT_CSV_PATH: str = "prompts_0.csv"
# Minimum CLIP similarity, in percent, required for synthetic submissions.
PROMPT_MATCH_THRESHOLD: int = 25
# --- No-op for HF Space ---
def load_assets():
    """Report that no asset download happens.

    The model and encoder files are assumed to already be present in the
    Space (via Git LFS), so this only prints a notice.
    """
    notice = "Skipping snapshot_download. Assuming files exist via Git LFS in HF Space."
    print(notice)


load_assets()
# --- Load leaderboard ---
def load_leaderboard():
    """Download and parse the leaderboard JSON from the Hub dataset repo.

    Returns:
        The parsed JSON content, or an empty dict when the download or the
        parse fails for any reason (the error is printed).
    """
    try:
        local_copy = hf_hub_download(
            repo_id=HUB_REPO_ID,
            repo_type="dataset",
            filename=HUB_JSON,
            token=HF_TOKEN,
        )
        with open(local_copy, "r", encoding="utf-8") as handle:
            scores = json.load(handle)
    except Exception as err:
        print(f"Failed to load leaderboard from HF Hub: {err}")
        return {}
    return scores
def load_entries():
    """Download the CSV log of past submissions from the Hub dataset repo.

    Returns:
        A DataFrame read from ``test/leaderboard_entries.csv`` in the dataset
        repo. If the download or parse fails, the error is printed and an
        empty DataFrame is returned instead. That DataFrame has the eight
        columns detect_with_model appends rows in.
    """
    try:
        csv_path = hf_hub_download(
            repo_id=HUB_REPO_ID,
            repo_type="dataset",
            filename="test/leaderboard_entries.csv",
            token=HF_TOKEN,
        )
        return pd.read_csv(csv_path)
    except Exception as e:
        print(f"Failed to load leaderboard entries from HF Hub: {e}")
        # pandas' class is DataFrame (capital F); the misspelling made this
        # fallback itself raise AttributeError.
        return pd.DataFrame(
            columns=[
                "file_name",
                "prompt",
                "label",
                "model",
                "split",
                "prediction",
                "user",
                "timestamp",
            ]
        )
# Module-level state updated by the handlers below:
# - scores keyed by username
# - the table of recorded submissions
leaderboard_scores, leaderboard_entries = load_leaderboard(), load_entries()
def save_leaderboard():
    """Persist ``leaderboard_scores`` locally and push it to the Hub dataset repo.

    The JSON is written to ``LOCAL_JSON`` on disk and uploaded to ``HUB_JSON``
    in ``HUB_REPO_ID``. The upload is skipped when ``HF_TOKEN`` is unset.
    Any failure, including the local write, is printed rather than raised.
    """
    try:
        # LOCAL_JSON is the on-disk path; HUB_JSON is only the path inside the
        # repo. They currently share a value, but they are distinct settings.
        with open(LOCAL_JSON, "w", encoding="utf-8") as f:
            json.dump(leaderboard_scores, f, ensure_ascii=False)
        if HF_TOKEN is None:
            print("HF_TOKEN not set. Skipping push to hub.")
            return
        api = HfApi()
        api.upload_file(
            path_or_fileobj=LOCAL_JSON,
            path_in_repo=HUB_JSON,
            repo_id=HUB_REPO_ID,
            repo_type="dataset",
            token=HF_TOKEN,
            commit_message="Update leaderboard",
        )
    except Exception as e:
        print(f"Failed to save leaderboard to HF Hub: {e}")
# --- Load prompts from CSV ---
def load_prompts():
    """Return the non-NaN values of the "prompt" column in ``PROMPT_CSV_PATH``.

    Prints the reason and returns an empty list when the file cannot be read
    or has no "prompt" column.
    """
    try:
        frame = pd.read_csv(PROMPT_CSV_PATH)
        if "prompt" not in frame.columns:
            print("CSV missing 'prompt' column.")
            return []
        return frame["prompt"].dropna().tolist()
    except Exception as exc:
        print(f"Failed to load prompts: {exc}")
        return []


PROMPT_LIST = load_prompts()
# --- Load model + processor ---
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
processor = AutoImageProcessor.from_pretrained(MODEL_BACKBONE)
model = AutoModelForImageClassification.from_pretrained(MODEL_BACKBONE)
# Replace the pretrained head with a 2-way head. detect_with_model reads
# class 0 as "Real" and class 1 as "Fake". Its weights are expected to come
# from MODEL_PATH below.
model.classifier = torch.nn.Linear(model.config.hidden_size, 2)
# strict=False tolerates key mismatches, but every skipped key keeps its
# current weights (pretrained backbone values, or random init for the new
# head). That would silently degrade the detector, so report mismatches.
_load_result = model.load_state_dict(safe_load(MODEL_PATH, device="cpu"), strict=False)
if _load_result.missing_keys:
    print(
        f"Warning: {len(_load_result.missing_keys)} model keys not found in {MODEL_PATH}, "
        f"e.g. {_load_result.missing_keys[:5]}"
    )
if _load_result.unexpected_keys:
    print(
        f"Warning: {len(_load_result.unexpected_keys)} keys in {MODEL_PATH} were ignored, "
        f"e.g. {_load_result.unexpected_keys[:5]}"
    )
model.to(device)
model.eval()
# --- CLIP prompt matching ---
def _cpu_onnx_session(onnx_path):
    """Open an ONNX Runtime inference session for ``onnx_path`` on the CPU provider."""
    return ort.InferenceSession(onnx_path, providers=["CPUExecutionProvider"])


clip_image_sess = _cpu_onnx_session(CLIP_IMAGE_ENCODER_PATH)
clip_text_sess = _cpu_onnx_session(CLIP_TEXT_ENCODER_PATH)
clip_tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")

# Image preprocessing for the CLIP image encoder: resize to 256x256, convert
# to a tensor, then normalise with the ImageNet mean/std.
# NOTE(review): standard CLIP ViT-B/32 preprocessing uses 224px input and
# CLIP-specific mean/std. Confirm these values match how the ONNX encoder was
# exported. Changing them would also shift scores relative to
# PROMPT_MATCH_THRESHOLD.
_clip_steps = [
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
]
transform = transforms.Compose(_clip_steps)
def compute_prompt_match(image: Image.Image, prompt: str) -> float:
    """Score how well ``image`` matches ``prompt`` with the ONNX CLIP encoders.

    Both embeddings are L2-normalised. Their dot product (cosine similarity)
    is scaled by 100 and rounded to 2 decimals.

    Args:
        image: Submitted image in any PIL mode; it is converted to RGB first.
        prompt: Text to compare against (tokenised, truncated to 77 tokens).

    Returns:
        The similarity as a Python float. Returns 0.0 in two cases:
        encoding fails, or either embedding has a zero or non-finite norm.
        A failed match therefore can never pass the prompt threshold.
    """
    try:
        # Normalize() in `transform` uses 3-channel statistics, so RGBA/L/P
        # images would fail without this conversion.
        rgb_image = image.convert("RGB")
        img_tensor = transform(rgb_image).unsqueeze(0).numpy().astype(np.float32)
        image_input_name = clip_image_sess.get_inputs()[0].name
        image_features = clip_image_sess.run(None, {image_input_name: img_tensor})[0][0]

        tokens = clip_tokenizer(
            prompt,
            return_tensors="np",
            padding="max_length",
            truncation=True,
            max_length=77,
        )
        text_inputs = clip_text_sess.get_inputs()
        text_features = clip_text_sess.run(None, {
            text_inputs[0].name: tokens["input_ids"],
            text_inputs[1].name: tokens["attention_mask"],
        })[0][0]

        image_norm = np.linalg.norm(image_features)
        text_norm = np.linalg.norm(text_features)
        # A NaN similarity compares False against the threshold, which would
        # let the submission skip the prompt check; treat it as no match.
        if not (np.isfinite(image_norm) and np.isfinite(text_norm)) or image_norm == 0 or text_norm == 0:
            print("CLIP ONNX match failed: degenerate embedding")
            return 0.0
        sim = float(np.dot(image_features / image_norm, text_features / text_norm))
        return round(sim * 100, 2)
    except Exception as e:
        print(f"CLIP ONNX match failed: {e}")
        return 0.0
# --- Main prediction logic ---
def _sorted_leaderboard_table():
    """Return ``leaderboard_scores`` as ``[[name, points], ...]``, highest score first."""
    ranked = sorted(leaderboard_scores.items(), key=lambda item: item[1], reverse=True)
    return [[name, points] for name, points in ranked]


def _record_submission(image, prompt, username, model_key, type_image, prediction):
    """Save one submission locally and upload it to the Hub dataset repo.

    Steps:
    - Write the image to ``test/<type_image>/<timestamp>.jpg``.
    - Append a row to ``leaderboard_entries``.
    - Rewrite ``test/leaderboard_entries.csv``.

    The upload is skipped when ``HF_TOKEN`` is unset. Upload errors are
    printed rather than raised.
    """
    now = datetime.now()
    image_dir = os.path.join("test", type_image)
    os.makedirs(image_dir, exist_ok=True)
    # Microseconds stop two submissions in the same second from overwriting
    # each other's file and sharing one CSV file_name.
    image_filename = f"{now.strftime('%Y%m%d_%H%M%S_%f')}.jpg"
    image_path = os.path.join(image_dir, image_filename)
    repo_image_path = f"test/{type_image}/{image_filename}"
    image.save(image_path)

    csv_path = os.path.join("test", "leaderboard_entries.csv")
    leaderboard_entries.loc[len(leaderboard_entries)] = [
        repo_image_path,
        prompt,
        type_image,
        model_key,
        "test",
        prediction.lower(),
        username,
        now.isoformat(),
    ]
    leaderboard_entries.to_csv(csv_path, index=False)

    if HF_TOKEN is None:
        # Same policy as save_leaderboard(): without a token, do not push.
        print("HF_TOKEN not set. Skipping push to hub.")
        return
    try:
        api = HfApi()
        api.upload_file(
            path_or_fileobj=image_path,
            path_in_repo=repo_image_path,
            repo_id=HUB_REPO_ID,
            repo_type="dataset",
            token=HF_TOKEN,
            commit_message="Add passing image",
        )
        api.upload_file(
            path_or_fileobj=csv_path,
            path_in_repo="test/leaderboard_entries.csv",
            repo_id=HUB_REPO_ID,
            repo_type="dataset",
            token=HF_TOKEN,
            commit_message="Update leaderboard CSV",
        )
    except Exception as e:
        print(f"Failed to save image to HF Hub: {e}")


def detect_with_model(image: Image.Image, prompt: str, username: str, model_name: str):
    """Run one arena round.

    The round validates input, checks the prompt match, classifies the
    image, updates the score and records the submission.

    Args:
        image: Uploaded image, or None if nothing was uploaded.
        prompt: Prompt shown to the player.
        username: Player name; must not be blank.
        model_name: Generator reported by the player. Blank or "real" (any
            case, surrounding whitespace ignored) marks a real photo. Real
            photos skip the prompt check, never score, and are not stored.

    Returns:
        A 6-tuple matching the submit button outputs:
        (message, displayed image, leaderboard rows, input-section update,
        try-again-button update, username).
    """
    if not (username or "").strip():
        return "Please enter your name.", None, _sorted_leaderboard_table(), gr.update(visible=True), gr.update(visible=False), username
    if image is None:
        return "Please upload an image.", None, _sorted_leaderboard_table(), gr.update(visible=True), gr.update(visible=False), username

    # JPEG cannot store alpha, and both the CLIP transform and the detector
    # processor expect 3 channels, so normalise the mode once up front.
    image = image.convert("RGB")

    model_key = (model_name or "").strip().lower()
    # One definition of "real upload" used everywhere below. Previously a
    # blank name skipped the prompt check yet could still earn a point.
    is_real_upload = model_key in ("", "real")

    prompt_score = compute_prompt_match(image, prompt)
    if not is_real_upload and prompt_score < PROMPT_MATCH_THRESHOLD:
        message = f"β οΈ Prompt match too low ({round(prompt_score, 2)}%). Please generate an image that better matches the prompt."
        return message, None, _sorted_leaderboard_table(), gr.update(visible=True), gr.update(visible=False), username

    # Run model inference
    inputs = processor(image, return_tensors="pt").to(device)
    with torch.no_grad():
        logits = model(**inputs).logits
    pred_class = torch.argmax(logits, dim=-1).item()
    prediction = "Real" if pred_class == 0 else "Fake"
    probs = torch.softmax(logits, dim=-1)[0]
    confidence = round(probs[pred_class].item() * 100, 2)

    message = f"π Prediction: {prediction} ({confidence}% confidence)\nπ§ Prompt match: {round(prompt_score, 2)}%"
    if is_real_upload:
        message += "\n You uploaded a real image, this does not count toward the leaderboard!"
    elif prediction == "Real":
        leaderboard_scores[username] = leaderboard_scores.get(username, 0) + 1
        # Push only when the leaderboard actually changed.
        save_leaderboard()
        message += "\nπ Nice! You fooled the AI. +1 point!"
    else:
        message += "\nπ The AI caught you this time. Try again!"

    # The UI rules tell players that images submitted as "real" are not
    # collected, so only synthetic submissions are saved and uploaded.
    if not is_real_upload:
        _record_submission(image, prompt, username, model_key, "fake", prediction)

    return (
        message,
        image,
        _sorted_leaderboard_table(),
        gr.update(visible=False),
        gr.update(visible=True),
        username,
    )
def get_random_prompt():
    """Pick a random prompt from PROMPT_LIST, or a fixed default if it is empty."""
    if not PROMPT_LIST:
        return "A synthetic scene with dramatic lighting"
    return random.choice(PROMPT_LIST)
def load_initial_state():
    """Build the page-load values.

    Returns a prompt-box update holding a random prompt, and the leaderboard
    rows sorted by score, highest first.
    """
    by_score = sorted(leaderboard_scores.items(), key=lambda pair: pair[1], reverse=True)
    rows = [list(pair) for pair in by_score]
    prompt_update = gr.update(value=get_random_prompt())
    return prompt_update, rows
# --- Gradio UI ---
def _current_leaderboard_rows():
    """Return ``leaderboard_scores`` as ``[[name, points], ...]``, highest score first."""
    ranked = sorted(leaderboard_scores.items(), key=lambda item: item[1], reverse=True)
    return [[name, points] for name, points in ranked]


def _reset_round(name):
    """Reset the UI for another attempt.

    Clears the result and image widgets, shows the input section again, hides
    the "Try Again" button, keeps the username and loads a new prompt. The
    leaderboard output receives the current standings; the handler used to
    return the Dataframe component object itself, which carries no table data.
    """
    return (
        "",  # prediction text
        None,  # submitted-image preview
        _current_leaderboard_rows(),  # leaderboard
        gr.update(visible=True),  # input section
        gr.update(visible=False),  # "Try Again" button
        name,  # username is kept
        gr.update(value=get_random_prompt()),  # new prompt
        None,  # image upload widget
    )


with gr.Blocks(css=".gr-button {font-size: 16px !important}") as demo:
    gr.Markdown("## π OpenFake Arena")
    gr.Markdown("Welcome to the OpenFake Arena!\n\n**Your mission:** Generate a synthetic image for the prompt, upload it, and try to fool the AI detector into thinking itβs real.\n\n**Rules:**\n\n- You can modify the prompt on your end, but the image needs to have the same content. We verify the content with a CLIP similarity threshold.\n\n- Enter \"real\" in the model used to upload and test a real image. You don't need to follow the prompt for real images. Tips: you can also enter \"real\" if you just want to test the detector! We won't be collecting those images. \n\n- It is important to enter the correct model name for licensing.\n\n- Only synthetic images count toward the leaderboard!\n\n\nNote: The detector is still in early development. The prompt is not used for prediction, only the image.")

    # Hidden after a submission (detect_with_model) and re-shown by "Try Again".
    with gr.Group(visible=True) as input_section:
        username_input = gr.Textbox(label="Your Name", placeholder="Enter your name", interactive=True)
        model_input = gr.Textbox(label="Model used, specify the version (e.g., Imagen 3, Dall-e 3, Midjourney 6). Write \"Real\" when uploading a real image.", placeholder="Name of the model used to generate the image", interactive=True)
        # Freeze this block: do not allow edits to the prompt input component's configuration.
        with gr.Row():
            prompt_input = gr.Textbox(
                interactive=False,
                label="Prompt to match",
                placeholder="e.g., ...",
                value="",
                lines=2
            )
        with gr.Row():
            image_input = gr.Image(type="pil", label="Upload Synthetic Image")
        with gr.Row():
            submit_btn = gr.Button("Upload")

    # Kept outside input_section: it must stay visible while that group is hidden.
    try_again_btn = gr.Button("Try Again", visible=False)

    with gr.Group():
        gr.Markdown("### π― Result")
        with gr.Row():
            prediction_output = gr.Textbox(label="Prediction", interactive=False, elem_id="prediction_box")
            image_output = gr.Image(label="Submitted Image", show_label=False)

    with gr.Group():
        gr.Markdown("### π Leaderboard")
        leaderboard = gr.Dataframe(
            headers=["Username", "Score"],
            datatype=["str", "number"],
            interactive=False,
            row_count=5,
            visible=True
        )

    submit_btn.click(
        fn=detect_with_model,
        inputs=[image_input, prompt_input, username_input, model_input],
        outputs=[
            prediction_output,
            image_output,
            leaderboard,
            input_section,
            try_again_btn,
            username_input
        ]
    )

    try_again_btn.click(
        fn=_reset_round,
        inputs=[username_input],
        outputs=[
            prediction_output,
            image_output,
            leaderboard,
            input_section,
            try_again_btn,
            username_input,
            prompt_input,
            image_input  # cleared so the next round starts empty
        ]
    )

    demo.load(
        fn=load_initial_state,
        outputs=[prompt_input, leaderboard]
    )

    # NOTE(review): <script> tags injected through gr.HTML are likely inserted
    # via innerHTML and not executed. DOMContentLoaded has also usually fired
    # before Gradio renders components, so this auto-scroll probably never
    # runs. Consider Blocks(js=...)/head=... depending on the Gradio version
    # in use.
    gr.HTML("""
<script>
document.addEventListener('DOMContentLoaded', function () {
const target = document.getElementById('prediction_box');
const observer = new MutationObserver(() => {
if (target && target.innerText.trim() !== '') {
window.scrollTo({ top: 0, behavior: 'smooth' });
}
});
if (target) {
observer.observe(target, { childList: true, subtree: true });
}
});
</script>
""")

if __name__ == "__main__":
    demo.launch()