import base64
import html
import io
import os
import tempfile
import uuid
from typing import List, Optional, Tuple
from urllib.parse import unquote, urlsplit

import requests
from PIL import Image, UnidentifiedImageError  # noqa: F401
import pandas as pd  # noqa: F401  (not used in this file)
import gradio as gr
from transformers import pipeline

# --- Model choices (default = a key string) ---
MODEL_CHOICES = {
    "shadowlilac/aesthetic-shadow (v1, fp32)": {
        "repo": "shadowlilac/aesthetic-shadow",
        "precision": "fp32",
    },
    "NeoChen1024/aesthetic-shadow-v2-backup (fp32)": {
        "repo": "NeoChen1024/aesthetic-shadow-v2-backup",
        "precision": "fp32",
    },
    "Disty0/aesthetic-shadow-v2 (fp16)": {
        "repo": "Disty0/aesthetic-shadow-v2",
        "precision": "fp16",
    },
    # keep a default key string for the dropdown
    "default": "Disty0/aesthetic-shadow-v2 (fp16)",
}
DEFAULT_MODEL_KEY = MODEL_CHOICES["default"]

# Cache of size one: the currently loaded pipeline and the repo it came from.
pipe = None
current_model_repo = None

# Downloaded images and full-resolution copies live in their own temp
# subdirectory so that only this directory is exposed to the browser
# (see demo.launch(allowed_paths=...) at the bottom of the file).
DOWNLOAD_DIR = os.path.join(tempfile.gettempdir(), "aesthetic_shadow")


def _ensure_download_dir() -> str:
    """Create DOWNLOAD_DIR if it does not exist yet and return its path."""
    os.makedirs(DOWNLOAD_DIR, exist_ok=True)
    return DOWNLOAD_DIR


def _gradio_file_url(path: str) -> str:
    """Return the URL under which Gradio serves the local file *path*.

    Gradio < 5 serves files at "/file=<path>"; Gradio 5 moved the route to
    "/gradio_api/file=<path>". Either way the file must live inside a
    directory passed to ``launch(allowed_paths=...)``.
    """
    try:
        major = int(str(gr.__version__).split(".")[0])
    except (AttributeError, ValueError):
        major = 4
    prefix = "/gradio_api/file=" if major >= 5 else "/file="
    return prefix + path


def load_model(model_key: str):
    """Return the image-classification pipeline for *model_key*.

    The previously loaded pipeline is reused when the requested repo is the
    same one. Keys that do not map to a model entry (e.g. the "default" alias
    entry, or None from a cleared dropdown) fall back to DEFAULT_MODEL_KEY.
    """
    global pipe, current_model_repo
    info = MODEL_CHOICES.get(model_key)
    if not isinstance(info, dict):
        info = MODEL_CHOICES[DEFAULT_MODEL_KEY]
    repo = info["repo"]
    if repo == current_model_repo and pipe is not None:
        return pipe
    # device=-1 runs on the CPU. info["precision"] is only a display label;
    # it is not passed to the pipeline.
    pipe = pipeline("image-classification", model=repo, device=-1)
    current_model_repo = repo
    return pipe


def _open_rgb(path) -> Optional[Image.Image]:
    """Open the image at *path* and return an RGB copy, or None if unreadable."""
    try:
        # The context manager closes the underlying file even for
        # multi-frame formats, where Pillow would otherwise keep it open.
        with Image.open(path) as im:
            return im.convert("RGB")
    except Exception:  # deliberate best-effort: one bad file is skipped, not fatal
        return None


def pil_from_uploaded(uploaded) -> Optional[Image.Image]:
    """Convert an upload (file path string or PIL image) to RGB, else return None."""
    if uploaded is None:
        return None
    if isinstance(uploaded, str):
        return _open_rgb(uploaded)
    if isinstance(uploaded, Image.Image):
        return uploaded.convert("RGB")
    return None


def download_url_to_temp(url: str, timeout: float = 15) -> Tuple[Optional[str], Optional[str]]:
    """Download *url* into DOWNLOAD_DIR.

    Returns ``(saved_path, filename)``, where *filename* is the basename
    derived from the URL path (percent-decoded, query and fragment removed).
    When there is no usable basename, *filename* is a random ``<hex>.jpg``.
    The file on disk gets an extra random prefix, so different URLs that
    share a basename do not overwrite each other.
    Returns ``(None, None)`` for an empty URL or on any network or write failure.
    """
    if not url:
        return None, None
    try:
        response = requests.get(url, timeout=timeout)
        response.raise_for_status()
    except (requests.RequestException, ValueError):
        return None, None

    url_path = unquote(urlsplit(url).path)
    base = os.path.basename(url_path.replace("\\", "/"))
    # Fallback extension guess: JPEG.
    filename = base if base and "." in base else f"{uuid.uuid4().hex}.jpg"
    save_path = os.path.join(_ensure_download_dir(), f"{uuid.uuid4().hex}_{filename}")
    try:
        with open(save_path, "wb") as fh:
            fh.write(response.content)
    except (OSError, ValueError):  # e.g. name too long, embedded NUL
        return None, None
    return save_path, filename


def extract_hq_score(preds) -> float:
    """Return the "hq" probability from one image's list of pipeline predictions.

    If "hq" is missing but "lq" is present, returns ``1 - lq``. This assumes
    the model's two labels are softmax-normalised; TODO confirm per model.
    Otherwise falls back to the first prediction's score, or 0.0 if *preds*
    is empty.
    """
    lq_score = None
    for p in preds:
        label = str(p.get("label")).lower()
        if label == "hq":
            return float(p.get("score", 0.0))
        if label == "lq" and lq_score is None:
            lq_score = float(p.get("score", 0.0))
    if lq_score is not None:
        return 1.0 - lq_score
    if preds:
        return float(preds[0].get("score", 0.0))
    return 0.0


def classify_images(images: List[Image.Image], pipe) -> List[float]:
    """Run *pipe* on *images* and return one "hq" score per image, in order."""
    if not images:
        return []
    # Pass positionally: the keyword is `images` in older transformers and
    # `inputs` in newer ones.
    results = pipe(images)
    # Defensive: normalise a flat list of label dicts to one-list-per-image.
    if results and isinstance(results[0], dict):
        results = [results]
    return [extract_hq_score(r) for r in results]


def pil_to_data_uri(img: Image.Image, max_dim=256) -> str:
    """Encode a thumbnail of *img* (at most max_dim x max_dim) as a JPEG data URI.

    Returns "" for None or if encoding fails.
    """
    if img is None:
        return ""
    # convert() returns a new image (so *img* is untouched) and makes any
    # mode JPEG-compatible.
    preview = img.convert("RGB")
    preview.thumbnail((max_dim, max_dim))
    buffer = io.BytesIO()
    try:
        # JPEG keeps the preview small and is widely supported.
        preview.save(buffer, format="JPEG", quality=80)
    except (OSError, ValueError):
        return ""
    encoded = base64.b64encode(buffer.getvalue()).decode("ascii")
    return f"data:image/jpeg;base64,{encoded}"


def build_results_table_html(entries: List[Tuple[str, str, str, float]]):
    """Render the results table as an HTML string.

    entries: list of tuples (data_uri_for_preview, download_path, filename, score)
      - data_uri_for_preview: base64 data URI shown as the <img> src
      - download_path: local path of the full-resolution JPEG; it is linked
        through Gradio's file route
      - filename: name offered by the link's download attribute
      - score: shown with three decimals

    All interpolated values are HTML-escaped, because filenames come from
    user-supplied URLs and uploads.
    """
    th_style = "text-align:left; padding:6px; border-bottom:1px solid;"
    td_style = "padding:6px; vertical-align:middle;"
    parts = [
        "<div id='results_table_container'>",
        "<table style='width:100%; border-collapse:collapse;'>",
        f"<thead><tr><th style='{th_style}'>Source</th><th style='{th_style}'>Score</th></tr></thead>",
        "<tbody>",
    ]
    for data_uri, download_path, filename, score in entries:
        href = html.escape(_gradio_file_url(download_path) if download_path else "#", quote=True)
        name = html.escape(filename or "", quote=True)
        src = html.escape(data_uri or "", quote=True)
        # The thumbnail is a data URI, so it always displays; the link points
        # at the full-resolution copy.
        img_html = (
            f"<a href='{href}' download='{name}' target='_blank' rel='noopener'>"
            f"<img src='{src}' alt='{name}' style='max-height:256px; max-width:100%; "
            "object-fit:contain; outline:none; border:none;'/></a>"
        )
        score_html = f"<div style='font-size:1.2em; font-weight:600;'>{score:.3f}</div>"
        parts.append("<tr>")
        parts.append(f"<td style='{td_style}'>{img_html}</td>")
        parts.append(f"<td style='{td_style}'>{score_html}</td>")
        parts.append("</tr>")
    parts.append("</tbody></table></div>")
    return "\n".join(parts)


def _download_filename(source_name: Optional[str]) -> str:
    """Name for the download link: the source's stem plus ".jpg".

    The extension is always ".jpg" because the served copy is JPEG-encoded.
    """
    stem = os.path.splitext(os.path.basename(source_name))[0] if source_name else ""
    return f"{stem or uuid.uuid4().hex}.jpg"


def _make_entry(img: Image.Image, source_name: Optional[str], score) -> Tuple[str, str, str, float]:
    """Build one table row: preview URI, saved full-res path, download name, score."""
    data_uri = pil_to_data_uri(img)
    download_path = os.path.join(_ensure_download_dir(), f"download_{uuid.uuid4().hex}.jpg")
    img.save(download_path, "JPEG", quality=95)  # full-resolution, high quality
    return data_uri, download_path, _download_filename(source_name), float(score)


def _collect_batch_images(batch_files, batch_urls_text) -> List[Tuple[Image.Image, str]]:
    """Load every readable batch file and batch URL as (RGB image, source name)."""
    collected = []
    for f in batch_files or []:
        # Some Gradio versions pass file wrappers with a .name path instead of str.
        path = getattr(f, "name", f)
        img = _open_rgb(path)
        if img is not None:
            collected.append((img, path))
    for line in (batch_urls_text or "").splitlines():
        url = line.strip()
        if not url:
            continue
        saved_path, filename = download_url_to_temp(url)
        if saved_path:
            img = _open_rgb(saved_path)
            if img is not None:
                collected.append((img, filename or saved_path))
    return collected


def _load_single_image(uploaded_image_path, url_input) -> Tuple[Optional[Image.Image], Optional[str]]:
    """Load the single-mode image, preferring the URL over the upload."""
    if url_input:
        saved_path, filename = download_url_to_temp(url_input.strip())
        if saved_path:
            img = _open_rgb(saved_path)
            if img is not None:
                return img, filename or saved_path
    if uploaded_image_path:
        img = _open_rgb(uploaded_image_path)
        if img is not None:
            return img, uploaded_image_path
    return None, None


def run_classify(
    uploaded_image_path,
    url_input,
    batch_files,
    batch_urls_text,
    model_key,
) -> str:
    """Classify the provided image(s) and return the results-table HTML.

    Batch inputs take precedence: if any batch file or URL loads, the
    single-image inputs are ignored. Otherwise the single URL is tried, then
    the upload. The model is only loaded once there is something to classify.
    """
    sources = _collect_batch_images(batch_files, batch_urls_text)
    if not sources:
        img, name = _load_single_image(uploaded_image_path, url_input)
        if img is None:
            return (
                "<div style='padding:8px;'>"
                "No valid image(s) provided. Please upload or supply a URL."
                "</div>"
            )
        sources = [(img, name)]

    classifier = load_model(model_key)
    scores = classify_images([img for img, _ in sources], classifier)
    entries = [_make_entry(img, name, score) for (img, name), score in zip(sources, scores)]
    return build_results_table_html(entries)


# Build the Gradio UI with CSS tweaks
css = """
/* make left column scroll if content would grow too tall, and limit to viewport */
#left_column {
    max-height: 90vh; /* entire column won't exceed viewport */
    overflow: auto;
}
/* clamp the upload preview to 40% of the viewport height */
#left_input_image img, #left_input_image canvas {
    max-height: 40vh !important;
    height: auto !important;
    width: auto !important;
    object-fit: contain !important;
    display: block !important;
    margin: 0 auto !important;
}
/* results table tweaks */
#results_table_container table th { border-color: rgba(255,255,255,0.6) !important; }
#results_table_container table td { border-color: rgba(255,255,255,0.06) !important; }
#results_table_container img {
    outline: none !important;
    border: none !important;
    box-shadow: none !important;
    max-height: 256px;
}
"""

with gr.Blocks(title="Aesthetic Shadow - Anime Image Quality Classifier (CPU)", css=css) as demo:
    gr.Markdown("Aesthetic Shadow - Anime Image Quality Classifier (running on CPU | ETA for single image: ~50s (fp16 & fp32 are likely same speed on cpu))")
    gr.Markdown("All Aesthetic Shadow models are by shadowlilac. V2 is using reuploads by other people.")
    with gr.Row():
        with gr.Column(scale=2, elem_id="left_column"):
            gr.Markdown("### Input")
            with gr.Tabs():
                with gr.TabItem("Single"):
                    uploaded_image = gr.Image(label="Drop or click to upload", type="filepath", elem_id="left_input_image")
                    url_input = gr.Textbox(label="Image URL (optional)", placeholder="https://...", lines=1)
                # NOTE: batch runs are slow on CPU. Inputs from both tabs are
                # sent on every run, whichever tab is active: if any batch
                # image loads, the Single tab's inputs are ignored.
                with gr.TabItem("Batch"):
                    batch_files = gr.File(label="Upload multiple images (batch)", file_count="multiple", type="filepath")
                    batch_urls_text = gr.Textbox(label="Batch URLs (one per line)", placeholder="https://...", lines=4)
        with gr.Column(scale=1):
            gr.Markdown("### Model & Run")
            model_dropdown = gr.Dropdown(
                choices=[k for k in MODEL_CHOICES if k != "default"],
                value=DEFAULT_MODEL_KEY,
                label="Model Selection",
            )
            run_button = gr.Button("Run", variant="primary")
    gr.Markdown("### Results")
    result_table_html = gr.HTML(
        "<div style='padding:8px; opacity:0.7;'>Results will appear here after running.</div>"
    )

    def run_handler(up_img, url_txt, b_files, b_urls, model_key):
        """Click handler: forward all inputs to run_classify."""
        return run_classify(up_img, url_txt, b_files, b_urls, model_key)

    run_button.click(
        fn=run_handler,
        inputs=[uploaded_image, url_input, batch_files, batch_urls_text, model_dropdown],
        outputs=[result_table_html],
    )

if __name__ == "__main__":
    # Only the app's own temp subdirectory is exposed, so the
    # full-resolution download links resolve.
    demo.launch(allowed_paths=[_ensure_download_dir()])