3v324v23 commited on
Commit
727dbd6
·
1 Parent(s): d13329c

Initial commit: add Gradio app and requirements

Browse files
Files changed (2) hide show
  1. app.py +115 -0
  2. requirements.txt +5 -0
app.py ADDED
@@ -0,0 +1,115 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import gradio as gr
3
+ import PIL.Image
4
+ import numpy as np
5
+ import pandas as pd
6
+ import pathlib
7
+ import tempfile
8
+ import os
9
+ import shutil
10
+ import zipfile
11
+ import huggingface_hub as h
12
+ from huggingface_hub import HfApi, Repository
13
+ import autogluon.multimodal
14
+
15
+ model_repo_id = "nadakandrew/sign-identification-autogluon"
16
+ zip_filename = "autogluon_image_predictor_dir.zip"
17
+ HF_TOKEN = os.getenv("HF_TOKEN", None)
18
+ cache_dir = pathlib.Path("hf_assets")
19
+ extract_dir = cache_dir / "predictor_native"
20
+
21
+ def prepare_predictor_dir() -> str:
22
+ cache_dir.mkdir(parents=True, exist_ok=True)
23
+ local_zip = h.hf_hub_download(
24
+ repo_id=model_repo_id,
25
+ filename=zip_filename,
26
+ repo_type="model",
27
+ token=HF_TOKEN,
28
+ local_dir=str(cache_dir),
29
+ local_dir_use_symlinks=False,
30
+ )
31
+ if extract_dir.exists():
32
+ shutil.rmtree(extract_dir)
33
+ extract_dir.mkdir(parents=True, exist_ok=True)
34
+ with zipfile.ZipFile(local_zip, "r") as zf:
35
+ zf.extractall(str(extract_dir))
36
+ contents = list(extract_dir.iterdir())
37
+ predictor_root = contents[0] if (len(contents) == 1 and contents[0].is_dir()) else extract_dir
38
+ return str(predictor_root)
39
+
40
+ predictor_dir = prepare_predictor_dir()
41
+ predictor = autogluon.multimodal.MultiModalPredictor.load(predictor_dir)
42
+
43
+ def do_predict(pil_img: PIL.Image.Image, preprocess: bool = True):
44
+ if pil_img is None:
45
+ return "No image provided.", None, None
46
+
47
+ original_img = pil_img.copy()
48
+ preprocessed_img = None
49
+
50
+ if preprocess:
51
+ target_size = (224, 224)
52
+ preprocessed_img = pil_img.resize(target_size).convert("RGB")
53
+ tmpdir = pathlib.Path(tempfile.mkdtemp())
54
+ img_path = tmpdir / "input.png"
55
+ preprocessed_img.save(img_path)
56
+ else:
57
+ tmpdir = pathlib.Path(tempfile.mkdtemp())
58
+ img_path = tmpdir / "input.png"
59
+ pil_img.save(img_path)
60
+
61
+
62
+ df = pd.DataFrame({"image": [str(img_path)]})
63
+
64
+ proba_df = predictor.predict_proba(df)
65
+
66
+ proba_df = proba_df.rename(columns={0: "class_0", 1: "class_1"})
67
+ row = proba_df.iloc[0]
68
+
69
+ pretty_dict = {
70
+ "Not a STOP sign": float(row.get("class_0", 0.0)),
71
+ "STOP sign": float(row.get("class_1", 0.0)),
72
+ }
73
+
74
+ return pretty_dict, original_img, preprocessed_img
75
+
76
+
77
+ EXAMPLES = [
78
+ ["https://universalsigns.com/wp-content/uploads/2022/08/StopSign-3.jpg"],
79
+ ["https://images.roadtrafficsigns.com/img/pla/K/student-drop-off-area-sign-k-2459_pl.png"],
80
+ ["https://hansonsign.com/wp-content/uploads/2024/05/donatos-662x646.jpg"]
81
+ ]
82
+
83
+ with gr.Blocks() as demo:
84
+
85
+ gr.Markdown("# Is this a STOP sign or not?")
86
+ gr.Markdown("Upload a photo to see results.")
87
+
88
+ with gr.Row():
89
+ image_in = gr.Image(type="pil", label="Input image", sources=["upload", "webcam"])
90
+ original_img_out = gr.Image(type="pil", label="Original image")
91
+ preprocessed_img_out = gr.Image(type="pil", label="Preprocessed image")
92
+
93
+ with gr.Row():
94
+ preprocess_checkbox = gr.Checkbox(label="Apply Preprocessing", value=True)
95
+
96
+ proba_pretty = gr.Label(num_top_classes=2, label="Class probabilities")
97
+
98
+ image_in.change(
99
+ fn=do_predict,
100
+ inputs=[image_in, preprocess_checkbox],
101
+ outputs=[proba_pretty, original_img_out, preprocessed_img_out]
102
+ )
103
+
104
+ gr.Examples(
105
+ examples=EXAMPLES,
106
+ inputs=[image_in],
107
+ label="Choose any one",
108
+ examples_per_page=8,
109
+ cache_examples=False,
110
+ )
111
+
112
+
113
+ if __name__ == "__main__":
114
+ demo.launch()
115
+
requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+
2
+ autogluon.multimodal
3
+ gradio
4
+ pillow
5
+ huggingface_hub