# DraconicDragon's picture
# ok no commenting in then
# 0f3c03b verified
import base64
import html
import io
import os
import tempfile
import uuid
from typing import List, Optional, Tuple
from urllib.parse import urlsplit

import gradio as gr
import pandas as pd
import requests
from PIL import Image, UnidentifiedImageError
from transformers import pipeline
# --- Model registry ---
# Dropdown label that is preselected when the UI first loads.
DEFAULT_MODEL_KEY = "Disty0/aesthetic-shadow-v2 (fp16)"

# Maps each dropdown label to the model repo handed to the pipeline and the
# precision stated in the label. The trailing "default" entry is NOT a model
# description: it stores the default label as a plain string.
MODEL_CHOICES = {
    "shadowlilac/aesthetic-shadow (v1, fp32)": dict(
        repo="shadowlilac/aesthetic-shadow",
        precision="fp32",
    ),
    "NeoChen1024/aesthetic-shadow-v2-backup (fp32)": dict(
        repo="NeoChen1024/aesthetic-shadow-v2-backup",
        precision="fp32",
    ),
    "Disty0/aesthetic-shadow-v2 (fp16)": dict(
        repo="Disty0/aesthetic-shadow-v2",
        precision="fp16",
    ),
    "default": DEFAULT_MODEL_KEY,
}

# Module-level cache: the currently loaded pipeline and the repo it came from.
pipe = None
current_model_repo = None
def load_model(model_key: str):
    """Return the image-classification pipeline for ``model_key``.

    The pipeline is cached in the module-level ``pipe`` / ``current_model_repo``
    globals and is rebuilt only when a different repo is requested.

    Args:
        model_key: A key of ``MODEL_CHOICES`` whose value holds a ``"repo"`` entry.

    Returns:
        The cached or newly created pipeline object.

    Raises:
        KeyError: If ``model_key`` is not present in ``MODEL_CHOICES``.
    """
    global pipe, current_model_repo

    requested_repo = MODEL_CHOICES[model_key]["repo"]
    cache_hit = pipe is not None and requested_repo == current_model_repo
    if not cache_hit:
        # device=-1 selects the CPU.
        pipe = pipeline("image-classification", model=requested_repo, device=-1)
        current_model_repo = requested_repo
    return pipe
def pil_from_uploaded(uploaded) -> Optional[Image.Image]:
    """Return an RGB copy of an uploaded image, or ``None`` if it is unusable.

    Args:
        uploaded: A filesystem path (``str``), an already opened PIL image,
            or ``None``.

    Returns:
        A new RGB-mode image, or ``None`` when ``uploaded`` is ``None``, has an
        unsupported type, or names a file that cannot be opened as an image.
    """
    if isinstance(uploaded, Image.Image):
        return uploaded.convert("RGB")
    if isinstance(uploaded, str):
        try:
            # convert() yields an independent, fully loaded image, so the
            # underlying file can be closed as soon as the block exits.
            with Image.open(uploaded) as source:
                return source.convert("RGB")
        except Exception:
            # Best-effort by design: any failure to read or decode the file
            # (including UnidentifiedImageError) means "no image".
            return None
    return None
def download_url_to_temp(url: str) -> Tuple[Optional[str], Optional[str]]:
    """Download ``url`` into a fresh temporary directory.

    The file name is taken from the last segment of the URL *path* when that
    segment contains a dot; otherwise a random ``<hex>.jpg`` name is used.
    Each download gets its own directory so that two URLs sharing a basename
    (or two concurrent requests) can never overwrite each other's file.

    Args:
        url: The address to fetch. Empty or ``None`` yields ``(None, None)``.

    Returns:
        ``(saved_path, filename)`` on success, ``(None, None)`` on any failure
        (network error, non-2xx status, unwritable file, ...).
    """
    if not url:
        return None, None
    try:
        response = requests.get(url, timeout=15)
        response.raise_for_status()

        # Use only the path component: the query string, the fragment and the
        # host name (for URLs without a path) must not end up in the file name.
        base = os.path.basename(urlsplit(url).path)
        if base and "." in base:
            filename = base
        else:
            # No usable name in the URL: fall back to a random name, guessing JPEG.
            filename = f"{uuid.uuid4().hex}.jpg"

        # A private, uniquely named directory keeps the original basename
        # intact while avoiding collisions in the shared temp directory.
        save_dir = tempfile.mkdtemp(prefix="url_download_")
        save_path = os.path.join(save_dir, filename)
        with open(save_path, "wb") as out_file:
            out_file.write(response.content)
        return save_path, filename
    except Exception:
        # Best-effort by design: callers treat (None, None) as "skip this URL".
        return None, None
def extract_hq_score(preds) -> float:
    """Pick the score to report from one image's prediction list.

    Args:
        preds: A sequence of dict-like predictions with ``"label"`` and
            ``"score"`` entries.

    Returns:
        The score of the first prediction whose label equals ``"hq"``
        (case-insensitive); otherwise the score of the first prediction;
        otherwise ``0.0`` when ``preds`` is empty. A missing ``"score"``
        counts as ``0.0``.
    """
    hq_entry = next(
        (entry for entry in preds if str(entry.get("label")).lower() == "hq"),
        None,
    )
    if hq_entry is not None:
        return float(hq_entry.get("score", 0.0))
    if len(preds) > 0:
        # No "hq" label present: fall back to whatever is listed first.
        return float(preds[0].get("score", 0.0))
    return 0.0
def classify_images(images: List[Image.Image], pipe) -> List[float]:
    """Score a batch of images with the given pipeline.

    Args:
        images: Images to classify; an empty or ``None`` value short-circuits.
        pipe: Callable invoked as ``pipe(inputs=images)`` that yields one
            prediction list per image.

    Returns:
        One score per image, as selected by ``extract_hq_score``; ``[]`` when
        there is nothing to classify.
    """
    if images:
        return [extract_hq_score(prediction) for prediction in pipe(inputs=images)]
    return []
# --- NEW HELPER FUNCTION ---
def pil_to_data_uri(img: Image.Image, max_dim=256) -> str:
    """Encode a downscaled JPEG preview of ``img`` as a base64 data URI.

    Args:
        img: Image to preview; a falsy value yields ``""``.
        max_dim: Bounding-box size passed to ``thumbnail`` for both dimensions.

    Returns:
        A ``data:image/jpeg;base64,...`` string, or ``""`` when ``img`` is
        falsy or the JPEG encoding fails.
    """
    if not img:
        return ""

    # Work on a copy: thumbnail() resizes in place.
    preview = img.copy()
    preview.thumbnail((max_dim, max_dim))

    jpeg_bytes = io.BytesIO()
    try:
        # JPEG keeps the embedded preview small.
        preview.save(jpeg_bytes, format="JPEG", quality=80)
    except Exception:
        return ""

    payload = base64.b64encode(jpeg_bytes.getvalue()).decode()
    return "data:image/jpeg;base64," + payload
# --------------------------
# --- MODIFIED FUNCTION SIGNATURE AND LOGIC ---
def build_results_table_html(entries: List[Tuple[str, str, str, float]]):
    """Render classification results as an HTML table (Source | Score).

    Args:
        entries: Tuples of ``(data_uri, download_path, filename, score)``:
            - ``data_uri``: value for the preview ``<img src>``.
            - ``download_path``: value for the surrounding ``<a href>``;
              ``"#"`` is used when it is falsy.
            - ``filename``: used for the ``download`` and ``alt`` attributes.
            - ``score``: number rendered with three decimals.

    Returns:
        The HTML document fragment as a single newline-joined string. With no
        entries the table has a header and an empty body.
    """
    parts = [
        "<div id='results_table_container' style='padding:8px; border-radius:8px; background:#0f1113; color:#e8e8e8;'>",
        "<table style='width:100%; border-collapse:collapse; font-family:Arial,Helvetica,sans-serif;'>",
        (
            "<thead><tr>"
            "<th style='text-align:left; padding:10px; border-bottom:1px solid rgba(255,255,255,0.6); color:#e8e8e8;'>Source</th>"
            "<th style='text-align:left; padding:10px; border-bottom:1px solid rgba(255,255,255,0.6); color:#e8e8e8; width:110px;'>Score</th>"
            "</tr></thead>"
        ),
        "<tbody>",
    ]
    cell_style = "padding:10px; border-bottom:1px solid rgba(255,255,255,0.06);"

    for data_uri, download_path, filename, score in entries:
        # File names can come from user uploads or remote URLs, so every value
        # placed inside a quoted attribute is escaped (quotes included) to keep
        # it from breaking out of the attribute and injecting markup.
        href = html.escape(str(download_path or "#"), quote=True)
        src = html.escape(str(data_uri), quote=True)
        name = html.escape(str(filename), quote=True)

        # NOTE(review): href is emitted as given; if it is a raw filesystem
        # path the browser may not be able to fetch it — confirm how the
        # hosting app exposes these files.
        img_html = (
            f"<a href='{href}' target='_blank' rel='noopener noreferrer' download='{name}'>"
            f"<img src='{src}' alt='{name}' style='max-height:256px; width:auto; display:block; border:none; box-shadow:none;'></a>"
        )
        score_html = f"<div style='font-weight:600; padding:4px 0; color:#f0f0f0;'>{score:.3f}</div>"

        parts.append("<tr style='vertical-align:middle;'>")
        parts.append(f"<td style='{cell_style}'>{img_html}</td>")
        parts.append(f"<td style='{cell_style}'>{score_html}</td>")
        parts.append("</tr>")

    parts.append("</tbody></table></div>")
    return "\n".join(parts)
# ---------------------------------------------
# --- MODIFIED FUNCTION LOGIC ---
def _open_rgb(path):
    """Open ``path`` and return an RGB copy, or ``None`` if it cannot be read as an image."""
    try:
        # convert() returns an independent, loaded image, so the source file
        # handle can be released immediately.
        with Image.open(path) as source:
            return source.convert("RGB")
    except Exception:
        return None


def _result_entry(img, src_path, score):
    """Build one ``(data_uri, download_path, filename, score)`` row for the results table."""
    # Full-size copy for the download link. It is always re-encoded as JPEG,
    # while ``filename`` keeps the basename (and extension) of the source.
    download_path = os.path.join(tempfile.gettempdir(), f"download_{uuid.uuid4().hex}.jpg")
    img.save(download_path, "JPEG", quality=95)
    filename = os.path.basename(src_path) if src_path else f"{uuid.uuid4().hex}.jpg"
    return pil_to_data_uri(img), download_path, filename, float(score)


def run_classify(
    uploaded_image_path,
    url_input,
    batch_files,
    batch_urls_text,
    model_key,
) -> str:
    """Classify the supplied image(s) and return the results as HTML.

    Batch inputs take precedence: if any batch file or batch URL yields a
    readable image, only those are classified. Otherwise a single image is
    taken from ``url_input`` first and from ``uploaded_image_path`` second.
    Unreadable files and failed downloads are skipped silently.

    The model is loaded only once there is at least one image to classify, so
    an empty submission does not trigger a model load.

    Args:
        uploaded_image_path: File path of the single uploaded image, or falsy.
        url_input: URL of a single image, or falsy.
        batch_files: Iterable of image file paths, or falsy.
        batch_urls_text: Text with one image URL per line, or falsy.
        model_key: Key of ``MODEL_CHOICES`` selecting the model.

    Returns:
        The results table HTML, or an error ``<div>`` when no valid image was
        provided.
    """
    images = []
    sources = []  # source path of each image, used for the download filename

    # Batch files (file paths).
    for path in batch_files or []:
        img = _open_rgb(path)
        if img is not None:
            images.append(img)
            sources.append(path)

    # Batch URLs, one per line; blank lines are ignored.
    for line in (batch_urls_text or "").splitlines():
        url = line.strip()
        if not url:
            continue
        saved_path, _ = download_url_to_temp(url)
        if not saved_path:
            continue
        img = _open_rgb(saved_path)
        if img is not None:
            images.append(img)
            sources.append(saved_path)

    # Batch mode.
    if images:
        scores = classify_images(images, load_model(model_key))
        entries = [
            _result_entry(img, src_path, score)
            for img, src_path, score in zip(images, sources, scores)
        ]
        return build_results_table_html(entries)

    # Single-image mode: prefer the URL input, then the upload.
    img = None
    src_path = None
    if url_input:
        saved_path, _ = download_url_to_temp(url_input.strip())
        if saved_path:
            img = _open_rgb(saved_path)
            if img is not None:
                src_path = saved_path
    if img is None and uploaded_image_path:
        img = _open_rgb(uploaded_image_path)
        if img is not None:
            src_path = uploaded_image_path

    if img is None:
        return "<div style='color:#ff7b7b;'>No valid image(s) provided. Please upload or supply a URL.</div>"

    scores = classify_images([img], load_model(model_key))
    score = float(scores[0]) if scores else 0.0
    return build_results_table_html([_result_entry(img, src_path, score)])
# ---------------------------------------------
# Custom CSS handed to gr.Blocks below. It targets the elements with ids
# "left_column", "left_input_image" and "results_table_container": the left
# column is capped at 90vh and scrolls, the upload preview is capped at 40vh,
# and the results table gets its border colours and borderless images.
css = """
/* clamp the left upload image preview to 40vh */
/* make left column scroll if content would grow too tall, and limit to viewport */
#left_column {
max-height: 90vh; /* entire column won't exceed viewport*/
overflow: auto;
}
/* clamp the upload preview to 40% of the viewport height */
#left_input_image img,
#left_input_image canvas {
max-height: 40vh !important;
height: auto !important;
width: auto !important;
object-fit: contain !important;
display: block !important;
margin: 0 auto !important;
}
/* results table tweaks kept from before */
#results_table_container table th {
border-color: rgba(255,255,255,0.6) !important;
}
#results_table_container table td {
border-color: rgba(255,255,255,0.06) !important;
}
#results_table_container img {
outline: none !important;
border: none !important;
box-shadow: none !important;
max-height: 256px;
}
"""
# Assemble the Gradio interface: image inputs in a wide left column, model
# selection, run button and results in a narrower right column.
# NOTE(review): placing the Results widgets inside the right-hand column is an
# assumption about the intended layout — confirm.
with gr.Blocks(css=css, title="Aesthetic Shadow - Anime Image Quality Classifier (CPU)") as demo:
    gr.Markdown("Aesthetic Shadow - Anime Image Quality Classifier (running on CPU | ETA for single image: ~50s (fp16 & fp32 are likely same speed on cpu))")
    gr.Markdown("All Aesthetic Shadow models are by shadowlilac. V2 is using reuploads by other people.")

    with gr.Row():
        with gr.Column(scale=2, elem_id="left_column"):
            gr.Markdown("### Input")
            with gr.Tabs():
                with gr.TabItem("Single"):
                    uploaded_image = gr.Image(
                        label="Drop or click to upload",
                        type="filepath",
                        elem_id="left_input_image",
                    )
                    url_input = gr.Textbox(
                        label="Image URL (optional)",
                        placeholder="https://...",
                        lines=1,
                    )
                # Batch runs are very slow. Known quirk: the inputs of both
                # tabs are submitted regardless of which tab is active, and
                # run_classify uses the batch inputs whenever any are present.
                with gr.TabItem("Batch"):
                    batch_files = gr.File(
                        label="Upload multiple images (batch)",
                        file_count="multiple",
                        type="filepath",
                    )
                    batch_urls_text = gr.Textbox(
                        label="Batch URLs (one per line)",
                        placeholder="https://...",
                        lines=4,
                    )

        with gr.Column(scale=1):
            gr.Markdown("### Model & Run")
            model_dropdown = gr.Dropdown(
                # "default" is a bookkeeping entry of MODEL_CHOICES, not a model.
                choices=[label for label in MODEL_CHOICES if label != "default"],
                value=DEFAULT_MODEL_KEY,
                label="Model Selection",
            )
            run_button = gr.Button("Run", variant="primary")

            gr.Markdown("### Results")
            result_table_html = gr.HTML("<div id='results_table_container'>Results will appear here after running.</div>")

    def run_handler(up_img, url_txt, b_files, b_urls, model_key):
        """Click callback: forward the widget values to run_classify and return its HTML."""
        return run_classify(up_img, url_txt, b_files, b_urls, model_key)

    run_button.click(
        fn=run_handler,
        inputs=[uploaded_image, url_input, batch_files, batch_urls_text, model_dropdown],
        outputs=[result_table_html],
    )

if __name__ == "__main__":
    demo.launch()