# Fake_Detection / app.py
# Last commit by ElBeh: "Update app.py" (92f6098, verified)
import os
import random
from io import BytesIO

import keras
import numpy as np
import streamlit as st
from huggingface_hub import snapshot_download
from PIL import Image, ImageOps, UnidentifiedImageError
def random_crop(img, min_size=160, max_size=2048, ratio=5/8):
    """Return a randomly positioned crop of ``img`` with height/width ``ratio``.

    The crop width is drawn uniformly from ``[min_size, min(max_size, width)]``
    and the height is ``int(width * ratio)``. If that height does not fit into
    the image, the crop uses the full image height and derives its width from
    it instead.

    Args:
        img: The image to crop (used with PIL images; only ``.size`` and
            ``.crop(box)`` are needed).
        min_size: Smallest crop width to draw. It is clamped to the usable
            width, so images narrower than ``min_size`` are cropped at full
            width instead of raising.
        max_size: Largest crop width to draw.
        ratio: Target crop height divided by crop width.

    Returns:
        The result of ``img.crop((left, top, right, bottom))``.
    """
    width, height = img.size
    upper = min(max_size, width)
    # random.randint raises ValueError on an empty range, which happened for
    # any image narrower than min_size; clamp the lower bound to the upper one.
    lower = min(min_size, upper)
    crop_width = random.randint(lower, upper)
    crop_height = int(crop_width * ratio)
    if crop_height > height:
        # Too tall for the image: take the full height and derive the width.
        # int(height / ratio) < crop_width <= width here, so it still fits.
        crop_height = height
        crop_width = int(crop_height / ratio)
    left = random.randint(0, width - crop_width)
    top = random.randint(0, height - crop_height)
    return img.crop((left, top, left + crop_width, top + crop_height))
def jpg_compression(img):
    """Round-trip ``img`` through an in-memory JPEG at a random quality.

    The image is converted to RGB and encoded as JPEG with a quality drawn
    uniformly from 65 to 100 (inclusive). It is then decoded again, so the
    result carries real JPEG compression artifacts.

    Returns:
        The decoded image, opened from the in-memory buffer.
    """
    level = random.randint(65, 100)
    buffer = BytesIO()
    rgb_version = img.convert("RGB")
    rgb_version.save(buffer, 'JPEG', quality=level)
    # Rewind so Image.open reads from the start. The buffer is deliberately
    # left open: Image.open reads lazily from it.
    buffer.seek(0)
    return Image.open(buffer)
def get_prediction(img):
    """Run the module-level ``model`` on a single image.

    The image is converted to a NumPy array, given a leading batch axis of
    size 1, and passed to ``model.predict``. Returns the output row for that
    single sample, presumably one score per label in ``models``.
    """
    single = np.array(img)
    batch = single[np.newaxis, ...]  # same as np.expand_dims(single, axis=0)
    scores = model.predict(batch)
    return scores[0, :]
# Class labels; entry i names the i-th value of the prediction vector.
models = [
    'DDPM',
    'Glide',
    'Latent Diffusion',
    'Palette',
    'Stable Diffusion',
    'VQ Diffusion',
    'real',
    'unseen_fake',
]

st.title("Fake Detection")
st.divider()

st.subheader("Modelvariant")
# index=None starts with no selection, so the placeholder text is shown and
# `variant` is None until the user picks an option.
variant = st.selectbox(
    "Choose the model",
    ("ResNet50v2-Basemodel", "ResNet50v2-Finetuned"),
    placeholder="Choose a model",
    index=None,
    label_visibility="hidden",
)
st.write("You selected model: ", variant)

binary = st.toggle("Activate binary classification")
st.divider()
# Hugging Face Hub repository holding each selectable model variant.
_MODEL_REPOS = {
    "ResNet50v2-Basemodel": "ElBeh/ma_basemodel",
    "ResNet50v2-Finetuned": "ElBeh/ma_finetuned_model",
}


@st.cache_resource
def _load_model(repo_id):
    """Download ``repo_id`` from the Hugging Face Hub and load it with Keras.

    Streamlit re-executes this whole script on every widget interaction.
    ``st.cache_resource`` keeps one loaded model per ``repo_id``, shared
    across reruns and sessions. Without it, every interaction would
    re-download and re-load the model.
    """
    local_model_path = snapshot_download(repo_id=repo_id)
    return keras.models.load_model(local_model_path)


# Any value other than the base-model option loads the fine-tuned model. That
# includes None, i.e. nothing selected yet.
# NOTE(review): consider st.stop() until a variant is chosen instead.
model = _load_model(
    _MODEL_REPOS.get(variant, _MODEL_REPOS["ResNet50v2-Finetuned"])
)
st.subheader("Image Preprocessing")
crop = st.toggle("random crop")
compress = st.toggle("jpeg compression")
st.divider()

# Classification starts as soon as a file is uploaded; there is no separate
# "execute" button. The uploader returns None until a file is provided.
file_name = st.file_uploader("Choose an image...")
if file_name is not None:
    # Decode the upload; the uploader accepts any file type, so guard against
    # non-image content instead of crashing with a traceback.
    try:
        image = Image.open(file_name)
    except UnidentifiedImageError:
        st.error("The uploaded file is not a readable image.")
        st.stop()

    col1, col2 = st.columns(2)

    # Apply the EXIF orientation tag so the pixels are upright.
    image = ImageOps.exif_transpose(image)
    # The model is fed 3-channel arrays. Without this, RGBA/P/L uploads reach
    # np.array() with the wrong channel count unless JPEG compression (which
    # converts to RGB) happens to be enabled.
    image = image.convert("RGB")

    # Crop whenever the toggle is on, so that "re-crop" always has an effect.
    if crop:
        image = random_crop(image)
    if image.size != (200, 200):
        image = image.resize((200, 200), Image.LANCZOS)
    if compress:
        image = jpg_compression(image)

    col1.image(image, use_column_width=True)
    predictions = get_prediction(image)

    if binary:
        col2.header("Prediction")
        # "real" vs. everything else: real wins if it holds over half the mass.
        if predictions[models.index("real")] > 0.5:
            col2.markdown(":green[real image!]")
        else:
            col2.markdown(":red[fake image!]")
    else:
        col2.header("Probabilities")
        for label, p in zip(models, predictions):
            col2.text(f"{label}: {round(p * 100, 2)}%")

    if crop:
        # Clicking any button reruns the script, which draws a new random crop.
        st.button("re-crop")