File size: 3,489 Bytes
727dbd6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116

import gradio as gr
import PIL.Image
import numpy as np
import pandas as pd
import pathlib
import tempfile
import os
import shutil
import zipfile
import huggingface_hub as h
from huggingface_hub import HfApi, Repository
import autogluon.multimodal

# Hugging Face Hub model repo and the zipped AutoGluon predictor stored in it.
model_repo_id = "nadakandrew/sign-identification-autogluon"
zip_filename = "autogluon_image_predictor_dir.zip"
# Access token taken from the environment; None when HF_TOKEN is not set.
HF_TOKEN = os.environ.get("HF_TOKEN")
# Local download location and the folder the archive is unpacked into.
cache_dir = pathlib.Path("hf_assets")
extract_dir = cache_dir.joinpath("predictor_native")

def prepare_predictor_dir() -> str:
    """Download the zipped predictor from the Hub and unpack it locally.

    The archive ``zip_filename`` is fetched from ``model_repo_id`` into
    ``cache_dir`` and extracted into a freshly emptied ``extract_dir``.

    Returns:
        Path (as ``str``) of the directory to pass to
        ``MultiModalPredictor.load``. If the archive wraps everything in a
        single top-level folder, that folder is returned; otherwise
        ``extract_dir`` itself is returned.
    """
    cache_dir.mkdir(parents=True, exist_ok=True)
    local_zip = h.hf_hub_download(
        repo_id=model_repo_id,
        filename=zip_filename,
        repo_type="model",
        token=HF_TOKEN,
        local_dir=str(cache_dir),
        local_dir_use_symlinks=False,
    )

    # Start from an empty directory so files from an earlier archive can't
    # mix with the new extraction.
    if extract_dir.exists():
        shutil.rmtree(extract_dir)
    extract_dir.mkdir(parents=True, exist_ok=True)
    with zipfile.ZipFile(local_zip, "r") as zf:
        zf.extractall(str(extract_dir))

    # Ignore archive metadata (macOS "__MACOSX" folders, dotfiles such as
    # ".DS_Store"). Otherwise a Finder-made zip looks like it has several
    # top-level entries, and the real predictor folder would be missed.
    contents = [
        entry
        for entry in extract_dir.iterdir()
        if entry.name != "__MACOSX" and not entry.name.startswith(".")
    ]
    if len(contents) == 1 and contents[0].is_dir():
        return str(contents[0])
    return str(extract_dir)

# Fetch and load the model once at import time; every request handled by
# do_predict then shares this single predictor instance.
predictor_dir = prepare_predictor_dir()
predictor = autogluon.multimodal.MultiModalPredictor.load(path=predictor_dir)

def do_predict(pil_img: PIL.Image.Image, preprocess: bool = True):
    """Classify an image as "STOP sign" vs. "Not a STOP sign".

    Args:
        pil_img: Image from the Gradio input component, or ``None`` when the
            input is empty/cleared.
        preprocess: When true, the image is resized to 224x224 and converted
            to RGB before prediction; otherwise it is passed through as-is.

    Returns:
        A 3-tuple ``(probabilities, original_img, preprocessed_img)``.
        ``probabilities`` maps the two display labels to floats, or is the
        string ``"No image provided."`` when ``pil_img`` is ``None``.
        ``original_img`` is a copy of the input. ``preprocessed_img`` is
        ``None`` when preprocessing is disabled.
    """
    if pil_img is None:
        return "No image provided.", None, None

    original_img = pil_img.copy()
    if preprocess:
        preprocessed_img = pil_img.resize((224, 224)).convert("RGB")
        model_input = preprocessed_img
    else:
        preprocessed_img = None
        model_input = pil_img

    # The predictor is given an image *path*, so write the input to a scratch
    # file. TemporaryDirectory deletes it once prediction returns, so repeated
    # requests don't accumulate directories on disk.
    with tempfile.TemporaryDirectory() as tmp:
        img_path = pathlib.Path(tmp) / "input.png"
        model_input.save(img_path)
        df = pd.DataFrame({"image": [str(img_path)]})
        proba_df = predictor.predict_proba(df)

    # Accept the 0/1 class labels whether they come back as ints or strings.
    # NOTE(review): the label dtype isn't visible here; this covers both.
    proba_df = proba_df.rename(
        columns={0: "class_0", 1: "class_1", "0": "class_0", "1": "class_1"}
    )
    row = proba_df.iloc[0]

    pretty_dict = {
        "Not a STOP sign": float(row.get("class_0", 0.0)),
        "STOP sign": float(row.get("class_1", 0.0)),
    }

    return pretty_dict, original_img, preprocessed_img


# Sample image URLs for gr.Examples; each entry is one row of inputs holding
# a value for the single image input.
_EXAMPLE_URLS = (
    "https://universalsigns.com/wp-content/uploads/2022/08/StopSign-3.jpg",
    "https://images.roadtrafficsigns.com/img/pla/K/student-drop-off-area-sign-k-2459_pl.png",
    "https://hansonsign.com/wp-content/uploads/2024/05/donatos-662x646.jpg",
)
EXAMPLES = [[url] for url in _EXAMPLE_URLS]

with gr.Blocks() as demo:
    # Layout: title, then input/original/preprocessed previews, the
    # preprocessing toggle, and the probability readout.
    gr.Markdown("# Is this a STOP sign or not?")
    gr.Markdown("Upload a photo to see results.")

    with gr.Row():
        image_in = gr.Image(type="pil", label="Input image", sources=["upload", "webcam"])
        original_img_out = gr.Image(type="pil", label="Original image")
        preprocessed_img_out = gr.Image(type="pil", label="Preprocessed image")

    with gr.Row():
        preprocess_checkbox = gr.Checkbox(label="Apply Preprocessing", value=True)

    proba_pretty = gr.Label(num_top_classes=2, label="Class probabilities")

    # Predict when a new image arrives (upload, webcam, or example click).
    image_in.change(
        fn=do_predict,
        inputs=[image_in, preprocess_checkbox],
        outputs=[proba_pretty, original_img_out, preprocessed_img_out],
    )
    # Also predict when the preprocessing toggle flips. Without this listener
    # the displayed results would go stale until the image itself changed.
    preprocess_checkbox.change(
        fn=do_predict,
        inputs=[image_in, preprocess_checkbox],
        outputs=[proba_pretty, original_img_out, preprocessed_img_out],
    )

    gr.Examples(
        examples=EXAMPLES,
        inputs=[image_in],
        label="Choose any one",
        examples_per_page=8,
        cache_examples=False,
    )


# Launch the Gradio app only when this file is executed directly, not when it
# is imported as a module.
if __name__ == "__main__":
    demo.launch()