DraconicDragon committed on
Commit 1bdc5be · verified · 1 Parent(s): 80d899c

Create app.py

Files changed (1)
  1. app.py +204 -0
app.py ADDED
@@ -0,0 +1,204 @@
+ # app.py
+ import io
+ import os
+ from typing import List, Tuple
+ import requests
+ from PIL import Image, UnidentifiedImageError
+ import pandas as pd
+
+ import gradio as gr
+ from transformers import pipeline
+
+ # Model choices
+ MODEL_CHOICES = {
+     "shadowlilac/aesthetic-shadow (v1, fp32)": {
+         "repo": "shadowlilac/aesthetic-shadow",
+         "precision": "fp32",
+     },
+     "NeoChen1024/aesthetic-shadow-v2-backup (fp32)": {
+         "repo": "NeoChen1024/aesthetic-shadow-v2-backup",
+         "precision": "fp32",
+     },
+     "Disty0/aesthetic-shadow-v2 (fp16)": {
+         "repo": "Disty0/aesthetic-shadow-v2",
+         "precision": "fp16",
+     },
+     "default": "Disty0/aesthetic-shadow-v2 (fp16)"
+ }
+
+ # Keep a global reference to the current pipeline
+ pipe = None
+ current_model_repo = None
+
+
+ def load_model(model_key: str):
+     """Load a model by dropdown key if not already loaded."""
+     global pipe, current_model_repo
+     info = MODEL_CHOICES[model_key]
+     repo = info["repo"]
+     if repo == current_model_repo and pipe is not None:
+         return pipe
+     # Load new pipeline on CPU
+     pipe = pipeline(
+         "image-classification",
+         model=repo,
+         device=-1,
+     )
+     current_model_repo = repo
+     return pipe
+
+
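+ # Turn a Gradio upload (file object or PIL image) into an RGB PIL image; return None if it cannot be decoded.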
+ def pil_from_uploaded(uploaded) -> Image.Image:
+     if uploaded is None:
+         return None
+     if hasattr(uploaded, "name"):
+         try:
+             return Image.open(uploaded).convert("RGB")
+         except UnidentifiedImageError:
+             return None
+     if isinstance(uploaded, Image.Image):
+         return uploaded.convert("RGB")
+     return None
+
+
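+ # Fetch an image from a URL (10 s timeout) and return it as RGB; return None on any network or decode error.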
+ def pil_from_url(url: str) -> Image.Image:
+     if not url:
+         return None
+     try:
+         r = requests.get(url, timeout=10)
+         r.raise_for_status()
+         return Image.open(io.BytesIO(r.content)).convert("RGB")
+     except Exception:
+         return None
+
+
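+ # The classifier returns a list of {"label", "score"} dicts per image; the "hq" (high quality) probability is used as the score.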
+ def extract_hq_score(preds) -> float:
+     for p in preds:
+         if str(p.get("label")).lower() == "hq":
+             return float(p.get("score", 0.0))
+     if len(preds):
+         return float(preds[0].get("score", 0.0))
+     return 0.0
+
+
+ def make_progress_html(score: float) -> str:
+     pct = int(round(score * 100))
+     return f"""
+     <div style="width:100%; border:1px solid #ddd; border-radius:6px; padding:6px;">
+       <div style="font-weight:600; margin-bottom:4px;">High-quality score: {score:.3f} ({pct}%)</div>
+       <div style="background:#eee; border-radius:4px; overflow:hidden;">
+         <div style="width:{pct}%; padding:6px 0; text-align:center; font-weight:600;">
+           {pct}%
+         </div>
+       </div>
+     </div>
+     """
+
+
+ def classify_images(images: List[Image.Image], pipe) -> List[float]:
+     if not images:
+         return []
+     # Pass the image list positionally; the keyword name differs across transformers versions.
+     results = pipe(images)
+     scores = [extract_hq_score(r) for r in results]
+     return scores
+
+
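+ # Batch files/URLs take priority when present; otherwise fall back to the single upload or URL (URL wins over upload).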
+ def run_classify(
+     uploaded_image,
+     url_input,
+     batch_files,
+     batch_urls_text,
+     model_key,
+ ) -> Tuple[str, List[List], dict]:
+     images_for_batch = []
+     names_for_batch = []
+
+     if batch_files:
+         for f in batch_files:
+             img = pil_from_uploaded(f)
+             if img is not None:
+                 images_for_batch.append(img)
+                 name = getattr(f, "name", "uploaded_file")
+                 names_for_batch.append(os.path.basename(name))
+     if batch_urls_text:
+         for line in batch_urls_text.splitlines():
+             line = line.strip()
+             if not line:
+                 continue
+             img = pil_from_url(line)
+             if img is not None:
+                 images_for_batch.append(img)
+                 names_for_batch.append(line)
+
+     pipe = load_model(model_key)
+
+     if images_for_batch:
+         scores = classify_images(images_for_batch, pipe)
+         rows = [
+             [names_for_batch[i] if i < len(names_for_batch) else f"img_{i}", float(scores[i])]
+             for i in range(len(scores))
+         ]
+         avg = sum(scores) / len(scores)
+         html = make_progress_html(avg)
+         return html, rows, {"mode": "batch"}
+
+     img = None
+     img_name = "input"
+     if url_input:
+         img = pil_from_url(url_input.strip())
+         img_name = url_input.strip()
+     if img is None and uploaded_image:
+         img = pil_from_uploaded(uploaded_image)
+         img_name = getattr(uploaded_image, "name", "uploaded_image")
+
+     if img is None:
+         return "<div style='color:#a00;'>No valid image(s) provided. Please upload or supply a URL.</div>", [], {"mode": "none"}
+
+     scores = classify_images([img], pipe)
+     score = float(scores[0]) if scores else 0.0
+     html = make_progress_html(score)
+     rows = [[img_name, score]]
+     return html, rows, {"mode": "single", "image": img}
+
+
+ # Build the Gradio UI
+ with gr.Blocks(title="Aesthetic Shadow - Anime Image Quality Classifier (CPU)") as demo:
+
+     with gr.Row():
+         with gr.Column(scale=2):
+             with gr.Tabs():
+                 with gr.TabItem("Image Upload"):
+                     uploaded_image = gr.File(label="Upload single image", file_count="single", type="file")
+                 with gr.TabItem("URL"):
+                     url_input = gr.Textbox(label="Image URL", placeholder="https://...", lines=1)
+                 with gr.TabItem("Batch"):
+                     batch_files = gr.File(label="Upload multiple images (batch)", file_count="multiple", type="file")
+                     batch_urls_text = gr.Textbox(label="Batch URLs (one per line)", placeholder="https://...", lines=4)
+             gr.Markdown("- If batch inputs are provided, batch mode is used; otherwise the single image/URL is used.")
+
+         with gr.Column(scale=1):
+             model_dropdown = gr.Dropdown(
+                 choices=[k for k in MODEL_CHOICES.keys() if k != "default"],
+                 value=MODEL_CHOICES["default"],
+                 label="Model Selection",
+             )
+             run_button = gr.Button("Run", variant="primary")
+
+     result_html = gr.HTML("<div>Result will appear here after running.</div>")
+     result_image = gr.Image(label="Input image (single mode)", interactive=False)
+     result_table = gr.Dataframe(headers=["source", "hq_score"], interactive=False)
+
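+     # Map the run_classify output onto the three output components: HTML score bar, optional single-image preview, and results table.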
+     def run_handler(up_img, url_txt, b_files, b_urls, model_key):
+         html, rows, meta = run_classify(up_img, url_txt, b_files, b_urls, model_key)
+         if meta.get("mode") == "single" and meta.get("image") is not None:
+             return html, meta.get("image"), pd.DataFrame(rows, columns=["source", "hq_score"])
+         else:
+             return html, None, pd.DataFrame(rows, columns=["source", "hq_score"])
+
+     run_button.click(
+         fn=run_handler,
+         inputs=[uploaded_image, url_input, batch_files, batch_urls_text, model_dropdown],
+         outputs=[result_html, result_image, result_table],
+     )
+
+     gr.Markdown("All Aesthetic Shadow models are by shadowlilac; the v2 models here use re-uploads by other people.")
+
+ if __name__ == "__main__":
+     demo.launch()