# HW3_Image / app.py
# mohitk24's picture
# Upload app.py with huggingface_hub
# 8efa7a7 verified
# raw
# history blame
# 3.4 kB
import os
import shutil
import zipfile
import pathlib
import tempfile
import gradio as gr
import pandas as pd
import numpy as np
import PIL.Image
import huggingface_hub as h
import autogluon.multimodal
# Hugging Face model repository the predictor archive is downloaded from.
model_repo_id = "nadakandrew/sign-identification-autogluon"
# Name of the zip archive inside that repository.
zip_filename = "autogluon_image_predictor_dir.zip"
# Optional access token taken from the environment; None when it is unset.
HF_TOKEN = os.environ.get("HF_TOKEN")
# Local download directory, and the sub-directory the archive is unpacked into.
cache_dir = pathlib.Path("hf_assets")
extract_dir = cache_dir.joinpath("predictor_native")
def prepare_predictor_dir():
    """Download the predictor archive and unpack it into a fresh directory.

    The zip named by ``zip_filename`` is fetched from ``model_repo_id`` into
    ``cache_dir``; any existing ``extract_dir`` is removed and recreated
    before the archive is extracted into it.

    Returns:
        str: Path of the directory to load the predictor from. If the archive
        unpacked to exactly one top-level directory, that directory is
        returned; otherwise ``extract_dir`` itself is returned.
    """
    cache_dir.mkdir(parents=True, exist_ok=True)
    archive_path = h.hf_hub_download(
        repo_id=model_repo_id,
        filename=zip_filename,
        repo_type="model",
        token=HF_TOKEN,
        local_dir=str(cache_dir),
        local_dir_use_symlinks=False,
    )

    # Always extract into an empty directory so stale files from an earlier
    # run cannot mix with the new contents.
    if extract_dir.exists():
        shutil.rmtree(extract_dir)
    extract_dir.mkdir(parents=True, exist_ok=True)

    with zipfile.ZipFile(archive_path, "r") as archive:
        archive.extractall(str(extract_dir))

    # An archive that wraps everything in a single folder is unwrapped so the
    # returned path points at that folder rather than its parent.
    entries = list(extract_dir.iterdir())
    if len(entries) == 1 and entries[0].is_dir():
        return str(entries[0])
    return str(extract_dir)
# Runs at import time: fetch and unpack the predictor files, then load the
# predictor once so do_predict can reuse the same object for every request.
predictor_dir = prepare_predictor_dir()
predictor = autogluon.multimodal.MultiModalPredictor.load(predictor_dir)
def do_predict(pil_img, preprocess=True):
    """Score an image with the loaded predictor and format the result.

    Args:
        pil_img: PIL image to classify, or ``None`` when no image is present.
        preprocess: When true, the image is resized to 224x224 and converted
            to RGB before it is passed to the predictor; otherwise the image
            is passed through unchanged.

    Returns:
        tuple: ``(probabilities, original_img, preprocessed_img)``.
        ``probabilities`` is a dict mapping ``"Not a STOP sign"`` and
        ``"STOP sign"`` to floats, or the string ``"No image provided."``
        when ``pil_img`` is ``None`` (the two images are then ``None``).
        ``original_img`` is a copy of the input; ``preprocessed_img`` is the
        resized RGB image, or ``None`` when ``preprocess`` is false.
    """
    if pil_img is None:
        return "No image provided.", None, None

    original_img = pil_img.copy()
    preprocessed_img = None
    if preprocess:
        target_size = (224, 224)
        preprocessed_img = pil_img.resize(target_size).convert("RGB")
    model_input = pil_img if preprocessed_img is None else preprocessed_img

    # The predictor is handed a file path, so the image is written to a
    # scratch directory. The context manager removes that directory once the
    # prediction has returned; the earlier mkdtemp() directories were never
    # deleted and accumulated on disk with every request.
    with tempfile.TemporaryDirectory() as tmpdir:
        img_path = pathlib.Path(tmpdir) / "input.png"
        model_input.save(img_path)
        df = pd.DataFrame({"image": [str(img_path)]})
        proba_df = predictor.predict_proba(df)

    # Columns labelled 0 / 1 are renamed; a missing column falls back to 0.0.
    proba_df = proba_df.rename(columns={0: "class_0", 1: "class_1"})
    row = proba_df.iloc[0]
    pretty_dict = {
        "Not a STOP sign": float(row.get("class_0", 0.0)),
        "STOP sign": float(row.get("class_1", 0.0)),
    }
    return pretty_dict, original_img, preprocessed_img
# Example inputs offered in the UI: one row per example, each row holding the
# single value for the image input.
EXAMPLES = [
    [url]
    for url in (
        "https://universalsigns.com/wp-content/uploads/2022/08/StopSign-3.jpg",
        "https://images.roadtrafficsigns.com/img/pla/K/student-drop-off-area-sign-k-2459_pl.png",
        "https://hansonsign.com/wp-content/uploads/2024/05/donatos-662x646.jpg",
    )
]
# Gradio UI: an image input, two image previews, a preprocessing toggle and a
# label showing the class probabilities returned by do_predict.
with gr.Blocks() as demo:
    gr.Markdown("# Is this a STOP sign or not?")
    gr.Markdown("Upload a photo to see results.")

    with gr.Row():
        image_in = gr.Image(type="pil", label="Input image", sources=["upload", "webcam"])
        original_img_out = gr.Image(type="pil", label="Original image")
        preprocessed_img_out = gr.Image(type="pil", label="Preprocessed image")

    with gr.Row():
        preprocess_checkbox = gr.Checkbox(label="Apply Preprocessing", value=True)
        proba_pretty = gr.Label(num_top_classes=2, label="Class probabilities")

    # Predict whenever the input image changes.
    image_in.change(
        fn=do_predict,
        inputs=[image_in, preprocess_checkbox],
        outputs=[proba_pretty, original_img_out, preprocessed_img_out],
    )
    # Also predict when the checkbox is toggled. Without this listener the
    # checkbox value was only read on the next image change, so toggling it
    # left stale results on screen.
    preprocess_checkbox.change(
        fn=do_predict,
        inputs=[image_in, preprocess_checkbox],
        outputs=[proba_pretty, original_img_out, preprocessed_img_out],
    )

    gr.Examples(
        examples=EXAMPLES,
        inputs=[image_in],
        label="Choose any one",
        examples_per_page=8,
        cache_examples=False,
    )
# Launch the app only when this file is run as a script, not when imported.
if __name__ == "__main__":
    demo.launch()