# HW3_Image / app.py
import gradio as gr
import PIL.Image
import pandas as pd
import pathlib
import tempfile
import os
import shutil
import zipfile
import huggingface_hub as h
import autogluon.multimodal
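# Hub repo that hosts the zipped AutoGluon predictor, plus local cache/extract paths.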
model_repo_id = "nadakandrew/sign-identification-autogluon"
zip_filename = "autogluon_image_predictor_dir.zip"
HF_TOKEN = os.getenv("HF_TOKEN", None)
cache_dir = pathlib.Path("hf_assets")
extract_dir = cache_dir / "predictor_native"
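# Download the zipped predictor from the Hub, unpack it, and return the directory
# that MultiModalPredictor.load() should point at.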
def prepare_predictor_dir() -> str:
    cache_dir.mkdir(parents=True, exist_ok=True)
    # Fetch the zipped predictor from the Hub (served from the local cache on later runs).
    local_zip = h.hf_hub_download(
        repo_id=model_repo_id,
        filename=zip_filename,
        repo_type="model",
        token=HF_TOKEN,
        local_dir=str(cache_dir),
        local_dir_use_symlinks=False,  # deprecated; recent huggingface_hub releases ignore it
    )
    # Re-extract from scratch so stale files never linger between restarts.
    if extract_dir.exists():
        shutil.rmtree(extract_dir)
    extract_dir.mkdir(parents=True, exist_ok=True)
    with zipfile.ZipFile(local_zip, "r") as zf:
        zf.extractall(str(extract_dir))
    # The zip may contain either the predictor files directly or a single top-level folder.
    contents = list(extract_dir.iterdir())
    predictor_root = contents[0] if (len(contents) == 1 and contents[0].is_dir()) else extract_dir
    return str(predictor_root)
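# Fetch/extract the predictor directory and load the model once at startup.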
predictor_dir = prepare_predictor_dir()
predictor = autogluon.multimodal.MultiModalPredictor.load(predictor_dir)
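# Classify one image: returns class probabilities plus the original and
# (optionally) preprocessed copies for display.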
def do_predict(pil_img: PIL.Image.Image, preprocess: bool = True):
    if pil_img is None:
        return "No image provided.", None, None
    original_img = pil_img.copy()
    preprocessed_img = None
    # Write the image to a temp file because the predictor expects image paths.
    tmpdir = pathlib.Path(tempfile.mkdtemp())
    img_path = tmpdir / "input.png"
    if preprocess:
        # Simple preprocessing: resize to 224x224 and force 3-channel RGB.
        target_size = (224, 224)
        preprocessed_img = pil_img.resize(target_size).convert("RGB")
        preprocessed_img.save(img_path)
    else:
        pil_img.save(img_path)
    # MultiModalPredictor takes a DataFrame with a column of image paths.
    df = pd.DataFrame({"image": [str(img_path)]})
    proba_df = predictor.predict_proba(df)
    # predict_proba columns are the class labels; this model uses 0 (not a STOP sign) and 1 (STOP sign).
    proba_df = proba_df.rename(columns={0: "class_0", 1: "class_1"})
    row = proba_df.iloc[0]
    pretty_dict = {
        "Not a STOP sign": float(row.get("class_0", 0.0)),
        "STOP sign": float(row.get("class_1", 0.0)),
    }
    return pretty_dict, original_img, preprocessed_img
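# Example image URLs shown in the Examples widget (one STOP sign, two other signs).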
EXAMPLES = [
["https://universalsigns.com/wp-content/uploads/2022/08/StopSign-3.jpg"],
["https://images.roadtrafficsigns.com/img/pla/K/student-drop-off-area-sign-k-2459_pl.png"],
["https://hansonsign.com/wp-content/uploads/2024/05/donatos-662x646.jpg"]
]
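# Gradio UI: image input with a preprocessing toggle, output previews, and class probabilities.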
with gr.Blocks() as demo:
gr.Markdown("# Is this a STOP sign or not?")
gr.Markdown("Upload a photo to see results.")
with gr.Row():
image_in = gr.Image(type="pil", label="Input image", sources=["upload", "webcam"])
original_img_out = gr.Image(type="pil", label="Original image")
preprocessed_img_out = gr.Image(type="pil", label="Preprocessed image")
with gr.Row():
preprocess_checkbox = gr.Checkbox(label="Apply Preprocessing", value=True)
proba_pretty = gr.Label(num_top_classes=2, label="Class probabilities")
    # Re-run prediction whenever a new image is provided.
    image_in.change(
        fn=do_predict,
        inputs=[image_in, preprocess_checkbox],
        outputs=[proba_pretty, original_img_out, preprocessed_img_out],
    )
gr.Examples(
examples=EXAMPLES,
inputs=[image_in],
label="Choose any one",
examples_per_page=8,
cache_examples=False,
)
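# Launch the app (Hugging Face Spaces runs this module directly).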
if __name__ == "__main__":
demo.launch()