File size: 12,754 Bytes
1bdc5be
 
360d4ea
79944db
7d4c327
1bdc5be
 
 
 
 
 
 
 
bcf1656
1bdc5be
 
 
 
 
 
 
bfdf521
1bdc5be
 
 
 
 
bcf1656
 
1bdc5be
bcf1656
1bdc5be
bcf1656
1bdc5be
 
 
 
 
 
 
 
 
 
f4a4f4b
bcf1656
1bdc5be
 
 
 
 
 
 
474db85
1bdc5be
 
360d4ea
474db85
1bdc5be
 
 
 
 
bcf1656
 
 
 
 
1bdc5be
bcf1656
1bdc5be
bcf1656
1bdc5be
bcf1656
 
 
 
 
 
 
 
 
 
 
 
 
1bdc5be
bcf1656
1bdc5be
 
 
 
 
 
 
 
 
 
 
 
 
 
79944db
1bdc5be
 
 
f4a4f4b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1bdc5be
f4a4f4b
 
 
bfdf521
f4a4f4b
 
 
5d88f96
bfdf521
 
bcf1656
 
5d88f96
bcf1656
 
 
 
 
 
bfdf521
f4a4f4b
 
 
 
 
 
 
bcf1656
 
 
5d88f96
bfdf521
 
 
f4a4f4b
360d4ea
5d88f96
f4a4f4b
360d4ea
ee1c7b9
360d4ea
 
 
 
bfdf521
ee1c7b9
bcf1656
 
ee1c7b9
f4a4f4b
 
360d4ea
bcf1656
360d4ea
 
ee1c7b9
 
 
 
f4a4f4b
 
360d4ea
bcf1656
360d4ea
 
 
 
 
f4a4f4b
bcf1656
 
 
 
 
f4a4f4b
 
360d4ea
 
 
bcf1656
f4a4f4b
 
bfdf521
360d4ea
f4a4f4b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bfdf521
360d4ea
bcf1656
360d4ea
bcf1656
360d4ea
f4a4f4b
bcf1656
 
 
 
 
 
 
ee1c7b9
bcf1656
 
 
 
 
 
 
79944db
1bdc5be
bcf1656
79944db
360d4ea
 
f4a4f4b
 
 
 
 
 
 
 
 
 
 
bfdf521
1bdc5be
f4a4f4b
1bdc5be
bcf1656
ee1c7b9
bcf1656
7d4c327
 
 
 
 
 
 
932f8cf
7d4c327
bfdf521
7d4c327
932f8cf
 
 
 
ee1c7b9
7d4c327
 
bcf1656
 
 
 
 
 
 
 
 
 
 
 
ee1c7b9
1bdc5be
ee1c7b9
7d4c327
 
1bdc5be
7d4c327
38131cb
 
 
 
 
0f3c03b
 
 
 
1bdc5be
 
ee1c7b9
1bdc5be
bcf1656
ee1c7b9
1bdc5be
 
 
932f8cf
bfdf521
1bdc5be
 
bfdf521
 
1bdc5be
 
 
38131cb
bfdf521
1bdc5be
 
 
 
5d88f96
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
import base64
import html
import io
import os
import tempfile
import uuid
from typing import List, Optional, Tuple
from urllib.parse import urlsplit

import gradio as gr
import pandas as pd
import requests
from PIL import Image, UnidentifiedImageError
from transformers import pipeline

# --- Model choices ---
# Dropdown label -> {"repo": Hugging Face repo id, "precision": precision the
# checkpoint is labelled with (currently informational only)}.
# The trailing "default" entry is NOT a model: its value is the label the
# dropdown preselects, so code listing models must skip it.
MODEL_CHOICES = {
    label: {"repo": repo, "precision": precision}
    for label, repo, precision in (
        ("shadowlilac/aesthetic-shadow (v1, fp32)", "shadowlilac/aesthetic-shadow", "fp32"),
        ("NeoChen1024/aesthetic-shadow-v2-backup (fp32)", "NeoChen1024/aesthetic-shadow-v2-backup", "fp32"),
        ("Disty0/aesthetic-shadow-v2 (fp16)", "Disty0/aesthetic-shadow-v2", "fp16"),
    )
}
MODEL_CHOICES["default"] = "Disty0/aesthetic-shadow-v2 (fp16)"
DEFAULT_MODEL_KEY = MODEL_CHOICES["default"]

# Module-level cache: the most recently built pipeline and the repo it came from.
pipe = None
current_model_repo = None


def load_model(model_key: str):
    """Return an image-classification pipeline for ``model_key``, reusing the cached one.

    Args:
        model_key: a model label from ``MODEL_CHOICES``. The alias ``"default"``
            and ``None`` (which a UI dropdown may send when nothing is
            selected) resolve to ``DEFAULT_MODEL_KEY``. Previously both crashed:
            ``MODEL_CHOICES["default"]`` is a string, not a model entry.

    Returns:
        The pipeline stored in the module global ``pipe``. It is rebuilt only
        when the resolved repo differs from ``current_model_repo``.

    Raises:
        KeyError: if ``model_key`` is not a known model label.
    """
    global pipe, current_model_repo
    if model_key is None or model_key == "default":
        model_key = DEFAULT_MODEL_KEY
    info = MODEL_CHOICES.get(model_key)
    if not isinstance(info, dict):
        raise KeyError(f"Unknown model key: {model_key!r}")
    repo = info["repo"]
    if repo == current_model_repo and pipe is not None:
        return pipe
    # device=-1 runs on CPU. The entry's "precision" field is not applied here.
    pipe = pipeline("image-classification", model=repo, device=-1)
    current_model_repo = repo
    return pipe


def pil_from_uploaded(uploaded) -> Optional[Image.Image]:
    """Best-effort conversion of an upload value to an RGB PIL image.

    Args:
        uploaded: a filesystem path (``str`` or ``os.PathLike``), an existing
            PIL image, or ``None``.

    Returns:
        An RGB image, or ``None`` for ``None``, unsupported types, or paths
        that cannot be opened/decoded.
    """
    if uploaded is None:
        return None
    if isinstance(uploaded, Image.Image):
        return uploaded.convert("RGB")
    if isinstance(uploaded, (str, os.PathLike)):
        try:
            # The context manager closes the file handle Image.open keeps.
            # convert() fully decodes into a new image, so the result does not
            # need the file afterwards.
            with Image.open(uploaded) as im:
                return im.convert("RGB")
        except Exception:  # deliberate best effort: any open/decode failure -> None
            return None
    return None


def download_url_to_temp(url: str) -> Tuple[str, str]:
    """
    Download the URL and save to tempdir with basename if possible.
    Returns (saved_path, filename).
    """
    if not url:
        return None, None
    try:
        r = requests.get(url, timeout=15)
        r.raise_for_status()
        data = r.content
        # try to derive filename from URL path
        base = os.path.basename(url.split("?")[0])
        if base and "." in base:
            filename = base
        else:
            # fallback extension guess - try JPEG
            filename = f"{uuid.uuid4().hex}.jpg"
        tempdir = tempfile.gettempdir()
        save_path = os.path.join(tempdir, filename)
        with open(save_path, "wb") as f:
            f.write(data)
        return save_path, filename
    except Exception:
        return None, None


def extract_hq_score(preds) -> float:
    """Pick the "hq" score out of one image's classifier predictions.

    ``preds`` is a sequence of dict-like entries read via ``.get("label")`` and
    ``.get("score")``. The first entry whose label is "hq" (case-insensitive)
    wins. Without one, the first entry is used, and an empty sequence gives
    0.0. A missing "score" counts as 0.0.
    """
    hq_entry = next(
        (p for p in preds if str(p.get("label")).lower() == "hq"),
        None,
    )
    if hq_entry is None:
        if len(preds) == 0:
            return 0.0
        hq_entry = preds[0]
    return float(hq_entry.get("score", 0.0))


def classify_images(images: List[Image.Image], pipe) -> List[float]:
    """Run ``pipe`` on ``images`` and return one "hq" score per prediction, in order.

    An empty (or ``None``) ``images`` returns ``[]`` without calling ``pipe``.
    Each per-image prediction is reduced with ``extract_hq_score``.
    """
    if not images:
        return []
    predictions = pipe(inputs=images)
    return list(map(extract_hq_score, predictions))

# Thumbnail encoder for the results table.
def pil_to_data_uri(img: Image.Image, max_dim=256, quality=80) -> str:
    """Return a downscaled JPEG preview of ``img`` as a base64 ``data:`` URI.

    Args:
        img: image to preview. A falsy value (e.g. ``None``) returns "".
        max_dim: the preview fits within ``max_dim`` x ``max_dim`` pixels with
            the aspect ratio kept (``Image.thumbnail`` never upscales).
        quality: JPEG quality used for the preview.

    Returns:
        ``"data:image/jpeg;base64,..."``, or "" if ``img`` is falsy or
        conversion/encoding fails. ``img`` itself is never modified.
    """
    if not img:
        return ""
    buffer = io.BytesIO()
    try:
        # JPEG cannot store alpha or palette modes (RGBA, LA, P, ...). Saving
        # such images used to fail and the preview silently vanished.
        # convert() also returns a new image, so thumbnail() (which works in
        # place) cannot shrink the caller's image.
        preview = img.convert("RGB")
        preview.thumbnail((max_dim, max_dim))
        preview.save(buffer, format="JPEG", quality=quality)
    except Exception:
        return ""
    encoded = base64.b64encode(buffer.getvalue()).decode("ascii")
    return f"data:image/jpeg;base64,{encoded}"


# Results table renderer.
def build_results_table_html(entries: List[Tuple[str, str, str, float]]):
    """Render classification results as a dark-themed two-column HTML table.

    entries: list of tuples (data_uri_for_preview, download_path, filename, score)
    - data_uri_for_preview: value for the thumbnail's <img src> (e.g. a base64 data URI)
    - download_path: target of the thumbnail's <a href>; a falsy value becomes "#"
    - filename: used as the <a download> name and the <img alt> text
    - score: rendered with three decimals

    Every interpolated value is HTML-escaped. Filenames come from user uploads
    or arbitrary URLs, and unescaped quotes or angle brackets would break out
    of the attribute and inject markup.
    """
    parts = []
    parts.append("<div id='results_table_container' style='padding:8px; border-radius:8px; background:#0f1113; color:#e8e8e8;'>")
    parts.append("<table style='width:100%; border-collapse:collapse; font-family:Arial,Helvetica,sans-serif;'>")
    # table header
    parts.append(
        "<thead><tr>"
        "<th style='text-align:left; padding:10px; border-bottom:1px solid rgba(255,255,255,0.6); color:#e8e8e8;'>Source</th>"
        "<th style='text-align:left; padding:10px; border-bottom:1px solid rgba(255,255,255,0.6); color:#e8e8e8; width:110px;'>Score</th>"
        "</tr></thead>"
    )
    parts.append("<tbody>")
    for data_uri, download_path, filename, score in entries:
        # NOTE(review): download_path is a server-side filesystem path. Whether
        # the browser can fetch it via this href depends on how the app serves
        # files (e.g. a framework file route / allowed paths). TODO confirm.
        href = html.escape(str(download_path or "#"))
        src = html.escape(str(data_uri))
        name = html.escape(str(filename))
        # Image cell: no outline, image height capped at 256px.
        img_html = f"<a href='{href}' target='_blank' rel='noopener noreferrer' download='{name}'><img src='{src}' alt='{name}' style='max-height:256px; width:auto; display:block; border:none; box-shadow:none;'></a>"
        score_html = f"<div style='font-weight:600; padding:4px 0; color:#f0f0f0;'>{score:.3f}</div>"
        parts.append("<tr style='vertical-align:middle;'>")
        parts.append(f"<td style='padding:10px; border-bottom:1px solid rgba(255,255,255,0.06);'>{img_html}</td>")
        parts.append(f"<td style='padding:10px; border-bottom:1px solid rgba(255,255,255,0.06);'>{score_html}</td>")
        parts.append("</tr>")
    parts.append("</tbody></table></div>")
    return "\n".join(parts)


# --- Classification entry point used by the Run button ---
def _open_rgb(source) -> Optional[Image.Image]:
    """Open ``source`` (path or file object); return a decoded RGB copy or None on failure."""
    try:
        # Closing right away avoids leaking one file handle per image.
        # convert() has already decoded the pixels into a new image.
        with Image.open(source) as im:
            return im.convert("RGB")
    except Exception:  # unreadable / not an image -> skipped by the caller
        return None


def _load_single_image(uploaded_image_path, url_input):
    """Single-tab input: try the URL first, then the upload. Returns (img, src_path) or None."""
    if url_input:
        saved_path, _ = download_url_to_temp(url_input.strip())
        if saved_path:
            img = _open_rgb(saved_path)
            if img is not None:
                return img, saved_path
    if uploaded_image_path:
        img = _open_rgb(uploaded_image_path)
        if img is not None:
            return img, uploaded_image_path
    return None


def _make_result_entry(img: Image.Image, src_path, score) -> Tuple[str, str, str, float]:
    """Build one (data_uri, download_path, filename, score) row for build_results_table_html."""
    data_uri = pil_to_data_uri(img)
    # Full-resolution copy for the row's download link. It is always
    # re-encoded as JPEG (quality 95), not a byte copy of the source file.
    download_path = os.path.join(tempfile.gettempdir(), f"download_{uuid.uuid4().hex}.jpg")
    img.save(download_path, "JPEG", quality=95)
    # Keep the source's base name, with a .jpg extension that matches the
    # re-encoded bytes. A PNG/WebP source previously downloaded as
    # "name.png" while containing JPEG data.
    stem = os.path.splitext(os.path.basename(src_path))[0] if src_path else ""
    filename = f"{stem or uuid.uuid4().hex}.jpg"
    return data_uri, download_path, filename, float(score)


def run_classify(
    uploaded_image_path,
    url_input,
    batch_files,
    batch_urls_text,
    model_key,
) -> str:
    """Classify the provided image(s) and return the results table as HTML.

    Batch-tab inputs come first: each path in ``batch_files``, then each
    non-blank line of ``batch_urls_text`` as a URL. If any of them load, only
    they are scored. Otherwise single mode is used: ``url_input`` first, then
    ``uploaded_image_path``. Inputs that fail to download or decode are
    skipped silently.

    Each result row shows a thumbnail linking to a re-encoded full-resolution
    JPEG copy in the temp dir.

    Returns:
        The table HTML, or an error <div> if no image could be loaded. In the
        error case the model is not loaded at all.
    """
    sources = []  # (rgb_image, source_path) pairs, in input order

    for f in batch_files or []:
        img = _open_rgb(f)
        if img is not None:
            sources.append((img, f))

    for line in (batch_urls_text or "").splitlines():
        line = line.strip()
        if not line:
            continue
        saved_path, _ = download_url_to_temp(line)
        if saved_path:
            img = _open_rgb(saved_path)
            if img is not None:
                sources.append((img, saved_path))

    if not sources:
        single = _load_single_image(uploaded_image_path, url_input)
        if single is None:
            return "<div style='color:#ff7b7b;'>No valid image(s) provided. Please upload or supply a URL.</div>"
        sources = [single]

    # Load (or reuse) the model only once there is something to score.
    pipe = load_model(model_key)
    scores = classify_images([img for img, _ in sources], pipe)
    entries = [
        _make_result_entry(img, src_path, score)
        for (img, src_path), score in zip(sources, scores)
    ]
    return build_results_table_html(entries)

# Build the Gradio UI with CSS tweaks.
# Stylesheet passed to gr.Blocks(css=...). The selectors target the elem_id
# values given to UI components (#left_column, #left_input_image) and the id of
# the results container emitted in the results HTML (#results_table_container).
# NOTE(review): the first CSS comment inside the string mentions "40vh", but the
# rule right after it caps #left_column at 90vh. The 40vh clamp is the
# #left_input_image rule that follows.
css = """
/* clamp the left upload image preview to 40vh */
/* make left column scroll if content would grow too tall, and limit to viewport */
#left_column {
  max-height: 90vh;        /* entire column won't exceed viewport*/
  overflow: auto;
}

/* clamp the upload preview to 40% of the viewport height */
#left_input_image img,
#left_input_image canvas {
  max-height: 40vh !important;
  height: auto !important;
  width: auto !important;
  object-fit: contain !important;
  display: block !important;
  margin: 0 auto !important;
}

/* results table tweaks kept from before */
#results_table_container table th {
  border-color: rgba(255,255,255,0.6) !important;
}
#results_table_container table td {
  border-color: rgba(255,255,255,0.06) !important;
}
#results_table_container img {
  outline: none !important;
  border: none !important;
  box-shadow: none !important;
  max-height: 256px;
}
"""

# Two-column layout: inputs (Single / Batch tabs) on the left; model picker,
# Run button and the HTML results table on the right.
with gr.Blocks(title="Aesthetic Shadow - Anime Image Quality Classifier (CPU)", css=css) as demo:
    gr.Markdown("Aesthetic Shadow - Anime Image Quality Classifier (running on CPU | ETA for single image: ~50s (fp16 & fp32 are likely same speed on cpu))")
    gr.Markdown("All Aesthetic Shadow models are by shadowlilac. V2 is using reuploads by other people.")
    with gr.Row():
        with gr.Column(scale=2, elem_id="left_column"):
            gr.Markdown("### Input")
            with gr.Tabs():
                with gr.TabItem("Single"):
                    # type="filepath": the handler receives a path string, not image data.
                    uploaded_image = gr.Image(label="Drop or click to upload", type="filepath", elem_id="left_input_image")
                    url_input = gr.Textbox(label="Image URL (optional)", placeholder="https://...", lines=1)
                # NOTE: the Run button sends the inputs of BOTH tabs on every click,
                # whichever tab is visible (see the inputs list below). run_classify
                # scores the Batch-tab inputs whenever at least one of them loads, and
                # only falls back to the Single-tab inputs otherwise. Leftover batch
                # files/URLs therefore override a new single image. CPU inference is
                # also slow.
                with gr.TabItem("Batch"): 
                    batch_files = gr.File(label="Upload multiple images (batch)", file_count="multiple", type="filepath")
                    batch_urls_text = gr.Textbox(label="Batch URLs (one per line)", placeholder="https://...", lines=4)

        with gr.Column(scale=1):
            gr.Markdown("### Model & Run")
            # "default" is an alias entry in MODEL_CHOICES (its value is a label),
            # not a model, so it is left out of the choices.
            model_dropdown = gr.Dropdown(
                choices=[k for k in MODEL_CHOICES.keys() if k != "default"],
                value=DEFAULT_MODEL_KEY,
                label="Model Selection",
            )
            run_button = gr.Button("Run", variant="primary")
            gr.Markdown("### Results")
            # Placeholder shown until the first run replaces it with the results table.
            result_table_html = gr.HTML("<div id='results_table_container'>Results will appear here after running.</div>")

            # Click handler: forwards the five inputs to run_classify and returns
            # its HTML string, which becomes the new value of result_table_html.
            def run_handler(up_img, url_txt, b_files, b_urls, model_key):
                html = run_classify(up_img, url_txt, b_files, b_urls, model_key)
                return html

            run_button.click(
                fn=run_handler,
                inputs=[uploaded_image, url_input, batch_files, batch_urls_text, model_dropdown],
                outputs=[result_table_html],
            )


# Start the Gradio server only when this file is run as a script.
if __name__ == "__main__":
    demo.launch()