Spaces:
Sleeping
Sleeping
| import io | |
| import os | |
| import tempfile | |
| import uuid | |
| import base64 | |
| from typing import List, Tuple | |
| import requests | |
| from PIL import Image, UnidentifiedImageError | |
| import pandas as pd | |
| import gradio as gr | |
| from transformers import pipeline | |
# --- Model choices ---
# Maps each dropdown label to the Hugging Face repo to load and a precision tag.
# The extra "default" entry is not a model: its value is the label the dropdown
# starts on. Insertion order matters because it becomes the dropdown order.
MODEL_CHOICES = {}
MODEL_CHOICES["shadowlilac/aesthetic-shadow (v1, fp32)"] = dict(
    repo="shadowlilac/aesthetic-shadow",
    precision="fp32",
)
MODEL_CHOICES["NeoChen1024/aesthetic-shadow-v2-backup (fp32)"] = dict(
    repo="NeoChen1024/aesthetic-shadow-v2-backup",
    precision="fp32",
)
MODEL_CHOICES["Disty0/aesthetic-shadow-v2 (fp16)"] = dict(
    repo="Disty0/aesthetic-shadow-v2",
    precision="fp16",
)
MODEL_CHOICES["default"] = "Disty0/aesthetic-shadow-v2 (fp16)"

# Label of the model selected when the UI first loads.
DEFAULT_MODEL_KEY = MODEL_CHOICES["default"]
# Module-level cache: the most recently built classification pipeline and the
# repo id it was built from. load_model() reuses it when the repo is unchanged.
pipe = None
current_model_repo = None


def load_model(model_key: str):
    """Return an image-classification pipeline for the dropdown label *model_key*.

    The pipeline is cached in the module globals ``pipe`` / ``current_model_repo``
    and is only rebuilt when the requested repo differs from the cached one.

    ``None``, unknown labels and the special ``"default"`` entry (whose value is
    a label string, not a model description) all fall back to
    ``DEFAULT_MODEL_KEY`` instead of raising ``KeyError``/``TypeError``.

    Note: the ``precision`` tag in ``MODEL_CHOICES`` is not applied here; no
    dtype is passed to ``pipeline``.
    """
    global pipe, current_model_repo
    info = MODEL_CHOICES.get(model_key)
    if not isinstance(info, dict):
        # Covers "default" (maps to a str) and missing/None keys (map to None).
        info = MODEL_CHOICES[DEFAULT_MODEL_KEY]
    repo = info["repo"]
    if repo == current_model_repo and pipe is not None:
        return pipe
    # device=-1 selects the CPU for transformers pipelines.
    pipe = pipeline("image-classification", model=repo, device=-1)
    current_model_repo = repo
    return pipe
def pil_from_uploaded(uploaded) -> "Image.Image | None":
    """Coerce an uploaded value into an RGB PIL image.

    Args:
        uploaded: ``None``, a filesystem path (``str``), or a PIL image.

    Returns:
        An RGB ``Image.Image``; ``None`` when *uploaded* is ``None``, of an
        unsupported type, or a path that cannot be opened/decoded.
    """
    if uploaded is None:
        return None
    if isinstance(uploaded, str):
        try:
            # The context manager closes the source file as soon as the
            # independent RGB copy has been produced.
            with Image.open(uploaded) as im:
                return im.convert("RGB")
        except Exception:
            # Best-effort: any unreadable path yields None. (UnidentifiedImageError
            # is an Exception subclass, so listing it separately was redundant.)
            return None
    if isinstance(uploaded, Image.Image):
        return uploaded.convert("RGB")
    return None
def download_url_to_temp(url: str) -> "Tuple[str | None, str | None]":
    """Download *url* into a fresh temporary directory.

    The saved file keeps the URL path's basename when it looks like a filename
    (contains a dot); otherwise a random ``<hex>.jpg`` name is used. Every call
    writes into its own ``tempfile.mkdtemp()`` directory, so two URLs sharing a
    basename (e.g. ``.../image.jpg``) or concurrent users no longer overwrite
    each other's files, while ``os.path.basename(saved_path)`` still equals the
    returned filename.

    Returns:
        ``(saved_path, filename)`` on success, or ``(None, None)`` when *url* is
        empty or the download/write fails for any reason (best-effort).
    """
    from urllib.parse import unquote, urlsplit

    if not url:
        return None, None
    try:
        r = requests.get(url, timeout=15)
        r.raise_for_status()
        data = r.content
        # Use only the URL path (query string and fragment dropped); basename()
        # after unquote() also strips any encoded "../" components.
        base = os.path.basename(unquote(urlsplit(url).path))
        if base and base not in (".", "..") and "." in base:
            filename = base
        else:
            # Fallback extension guess: JPEG.
            filename = f"{uuid.uuid4().hex}.jpg"
        save_path = os.path.join(tempfile.mkdtemp(), filename)
        with open(save_path, "wb") as f:
            f.write(data)
        return save_path, filename
    except Exception:
        return None, None
def extract_hq_score(preds) -> float:
    """Pick the "hq" probability out of one image's pipeline predictions.

    *preds* is a sequence of ``{"label": ..., "score": ...}`` dicts. The label
    comparison is case-insensitive and the first "hq" entry wins. Without any
    "hq" label the first prediction's score is used; an empty sequence gives
    0.0. A missing "score" key counts as 0.0.
    """
    hq_entry = next(
        (pred for pred in preds if str(pred.get("label")).lower() == "hq"),
        None,
    )
    if hq_entry is not None:
        return float(hq_entry.get("score", 0.0))
    return float(preds[0].get("score", 0.0)) if len(preds) else 0.0
def classify_images(images: "List[Image.Image]", pipe) -> List[float]:
    """Score each image with *pipe* and return the per-image "hq" scores.

    Args:
        images: PIL images to classify, in order.
        pipe: a transformers image-classification pipeline (any callable that
            maps a list of images to a list of per-image prediction lists).

    Returns:
        One float per input image, in input order; ``[]`` for no images.
    """
    if not images:
        return []
    # Pass the batch positionally: the name of the pipeline's first __call__
    # parameter differs between transformers releases ("images" vs "inputs"),
    # so a keyword argument only works on some versions. list() guarantees the
    # list-in / list-of-lists-out batch behaviour even for tuples.
    results = pipe(list(images))
    return [extract_hq_score(r) for r in results]
def pil_to_data_uri(img: "Image.Image", max_dim=256, quality=80) -> str:
    """Encode a downscaled JPEG copy of *img* as a ``data:`` URI for previews.

    Args:
        img: source PIL image; it is never modified (work happens on a copy).
        max_dim: maximum preview width/height; aspect ratio is preserved.
        quality: JPEG quality used for the preview.

    Returns:
        ``"data:image/jpeg;base64,..."``, or ``""`` if *img* is falsy or the
        JPEG encode fails.
    """
    if not img:
        return ""
    # JPEG cannot store alpha/palette data (saving e.g. RGBA or P raises), and
    # browsers render RGB most reliably, so flatten other modes to RGB first.
    # convert() already returns a new image; otherwise copy before thumbnail(),
    # which resizes in place.
    preview = img.convert("RGB") if img.mode not in ("RGB", "L") else img.copy()
    preview.thumbnail((max_dim, max_dim))
    buffer = io.BytesIO()
    try:
        preview.save(buffer, format="JPEG", quality=quality)
    except Exception:
        return ""
    encoded = base64.b64encode(buffer.getvalue()).decode("ascii")
    return f"data:image/jpeg;base64,{encoded}"
def build_results_table_html(entries: List[Tuple[str, str, str, float]]) -> str:
    """Render classification results as a dark-themed two-column HTML table.

    Args:
        entries: ``(data_uri, download_path, filename, score)`` tuples:
            - ``data_uri``: inline preview used as the ``<img src>``;
            - ``download_path``: href of the link around the preview
              (``"#"`` when empty);
            - ``filename``: suggested download name and ``alt`` text;
            - ``score``: printed with three decimals.

    Returns:
        An HTML fragment wrapped in ``<div id='results_table_container'>``.

    Every interpolated value is HTML-escaped. ``filename`` is derived from
    user-controlled input (URL basenames, upload names), and an unescaped
    ``'`` would break out of the single-quoted attributes (HTML injection).

    NOTE(review): ``download_path`` is a server filesystem path. A browser
    resolves it relative to the page URL, so the link may not download
    anything unless the app actually serves that path -- confirm against the
    Gradio version in use (it exposes files through its own file route).
    """
    from html import escape

    parts = [
        "<div id='results_table_container' style='padding:8px; border-radius:8px; background:#0f1113; color:#e8e8e8;'>",
        "<table style='width:100%; border-collapse:collapse; font-family:Arial,Helvetica,sans-serif;'>",
        (
            "<thead><tr>"
            "<th style='text-align:left; padding:10px; border-bottom:1px solid rgba(255,255,255,0.6); color:#e8e8e8;'>Source</th>"
            "<th style='text-align:left; padding:10px; border-bottom:1px solid rgba(255,255,255,0.6); color:#e8e8e8; width:110px;'>Score</th>"
            "</tr></thead>"
        ),
        "<tbody>",
    ]
    for data_uri, download_path, filename, score in entries:
        href = escape(download_path or "#")
        src = escape(data_uri or "")
        name = escape(str(filename))
        # Preview image (no outline, max 256px tall) wrapped in a download link.
        img_html = (
            f"<a href='{href}' target='_blank' rel='noopener noreferrer' download='{name}'>"
            f"<img src='{src}' alt='{name}' style='max-height:256px; width:auto; display:block; border:none; box-shadow:none;'></a>"
        )
        score_html = f"<div style='font-weight:600; padding:4px 0; color:#f0f0f0;'>{score:.3f}</div>"
        parts.append("<tr style='vertical-align:middle;'>")
        parts.append(f"<td style='padding:10px; border-bottom:1px solid rgba(255,255,255,0.06);'>{img_html}</td>")
        parts.append(f"<td style='padding:10px; border-bottom:1px solid rgba(255,255,255,0.06);'>{score_html}</td>")
        parts.append("</tr>")
    parts.append("</tbody></table></div>")
    return "\n".join(parts)
def _open_rgb(path):
    """Open an image file as RGB; return None if it cannot be opened or decoded."""
    try:
        with Image.open(path) as im:
            return im.convert("RGB")
    except Exception:
        return None


def _collect_batch_images(batch_files, batch_urls_text):
    """Load batch uploads, then batch URLs (one per line); unreadable items are skipped.

    Returns parallel lists ``(images, source_paths)``.
    """
    images, paths = [], []
    for f in batch_files or []:
        img = _open_rgb(f)
        if img is not None:
            images.append(img)
            paths.append(f)
    for line in (batch_urls_text or "").splitlines():
        line = line.strip()
        if not line:
            continue
        saved_path, _ = download_url_to_temp(line)
        if not saved_path:
            continue
        img = _open_rgb(saved_path)
        if img is not None:
            images.append(img)
            paths.append(saved_path)
    return images, paths


def _collect_single_image(url_input, uploaded_image_path):
    """Load the single-mode image, preferring the URL over the upload.

    Returns ``(image, source_path)``, or ``(None, None)`` if neither is usable.
    """
    if url_input:
        saved_path, _ = download_url_to_temp(url_input.strip())
        if saved_path:
            img = _open_rgb(saved_path)
            if img is not None:
                return img, saved_path
    if uploaded_image_path:
        img = _open_rgb(uploaded_image_path)
        if img is not None:
            return img, uploaded_image_path
    return None, None


def _build_entry(img, src_path, score):
    """Turn one scored image into a ``(data_uri, download_path, filename, score)`` row."""
    data_uri = pil_to_data_uri(img)
    # Full-resolution JPEG copy (quality 95) written to the temp dir for the link.
    download_path = os.path.join(tempfile.gettempdir(), f"download_{uuid.uuid4().hex}.jpg")
    img.save(download_path, "JPEG", quality=95)
    # Prefer the source file's name for the download attribute.
    filename = os.path.basename(src_path) if src_path else f"{uuid.uuid4().hex}.jpg"
    return data_uri, download_path, filename, float(score)


def run_classify(
    uploaded_image_path,
    url_input,
    batch_files,
    batch_urls_text,
    model_key,
) -> str:
    """Classify the provided image(s) and return the results table as HTML.

    Batch inputs (uploaded files, then URLs) take precedence. Only when none of
    them yields a readable image does single mode run, using the URL first and
    then the uploaded file.

    Args:
        uploaded_image_path: filepath of the single uploaded image, or None.
        url_input: single image URL, or empty/None.
        batch_files: list of filepaths from the batch upload, or None.
        batch_urls_text: newline-separated image URLs, or empty/None.
        model_key: dropdown label passed to ``load_model``.

    Returns:
        HTML for the results table, or an error ``<div>`` when no valid image
        was provided.
    """
    images, paths = _collect_batch_images(batch_files, batch_urls_text)
    if not images:
        img, src_path = _collect_single_image(url_input, uploaded_image_path)
        if img is None:
            return "<div style='color:#ff7b7b;'>No valid image(s) provided. Please upload or supply a URL.</div>"
        images, paths = [img], [src_path]
    # Load (and possibly download) the model only once we know there is work.
    pipe = load_model(model_key)
    scores = classify_images(images, pipe)
    entries = [_build_entry(img, src, score) for img, src, score in zip(images, paths, scores)]
    return build_results_table_html(entries)
# Custom CSS passed to gr.Blocks(css=...) below:
# - #left_column: cap the input column at 90% of the viewport height and let it
#   scroll instead of growing.
# - #left_input_image: keep the single-image upload preview within 40% of the
#   viewport height, centred and aspect-preserving.
# - #results_table_container: border colours and image sizing for the HTML table
#   produced by build_results_table_html (mirrors its inline styles).
css = """
/* clamp the left upload image preview to 40vh */
/* make left column scroll if content would grow too tall, and limit to viewport */
#left_column {
max-height: 90vh; /* entire column won't exceed viewport*/
overflow: auto;
}
/* clamp the upload preview to 40% of the viewport height */
#left_input_image img,
#left_input_image canvas {
max-height: 40vh !important;
height: auto !important;
width: auto !important;
object-fit: contain !important;
display: block !important;
margin: 0 auto !important;
}
/* results table tweaks kept from before */
#results_table_container table th {
border-color: rgba(255,255,255,0.6) !important;
}
#results_table_container table td {
border-color: rgba(255,255,255,0.06) !important;
}
#results_table_container img {
outline: none !important;
border: none !important;
box-shadow: none !important;
max-height: 256px;
}
"""
# Gradio UI: inputs on the left (Single / Batch tabs), model picker, run button
# and the HTML results table on the right.
with gr.Blocks(title="Aesthetic Shadow - Anime Image Quality Classifier (CPU)", css=css) as demo:
    gr.Markdown("Aesthetic Shadow - Anime Image Quality Classifier (running on CPU | ETA for single image: ~50s (fp16 & fp32 are likely same speed on cpu))")
    gr.Markdown("All Aesthetic Shadow models are by shadowlilac. V2 is using reuploads by other people.")
    # Name of the input tab the user last selected. run_handler forwards only
    # that tab's inputs, so stale values left in the hidden tab are ignored
    # (previously batch inputs silently overrode the Single tab).
    active_tab = gr.State("Single")
    with gr.Row():
        with gr.Column(scale=2, elem_id="left_column"):
            gr.Markdown("### Input")
            with gr.Tabs():
                with gr.TabItem("Single") as single_tab:
                    uploaded_image = gr.Image(label="Drop or click to upload", type="filepath", elem_id="left_input_image")
                    url_input = gr.Textbox(label="Image URL (optional)", placeholder="https://...", lines=1)
                # Batch runs the CPU model over every image, so it can take a long time.
                with gr.TabItem("Batch") as batch_tab:
                    batch_files = gr.File(label="Upload multiple images (batch)", file_count="multiple", type="filepath")
                    batch_urls_text = gr.Textbox(label="Batch URLs (one per line)", placeholder="https://...", lines=4)
        with gr.Column(scale=1):
            gr.Markdown("### Model & Run")
            model_dropdown = gr.Dropdown(
                choices=[k for k in MODEL_CHOICES.keys() if k != "default"],
                value=DEFAULT_MODEL_KEY,
                label="Model Selection",
            )
            run_button = gr.Button("Run", variant="primary")
            gr.Markdown("### Results")
            result_table_html = gr.HTML("<div id='results_table_container'>Results will appear here after running.</div>")

    # Track which tab is active.
    single_tab.select(fn=lambda: "Single", inputs=None, outputs=active_tab)
    batch_tab.select(fn=lambda: "Batch", inputs=None, outputs=active_tab)

    def run_handler(up_img, url_txt, b_files, b_urls, model_key, tab="Single"):
        """Run classification using only the inputs of the active tab *tab*."""
        if tab == "Batch":
            return run_classify(None, None, b_files, b_urls, model_key)
        return run_classify(up_img, url_txt, None, None, model_key)

    run_button.click(
        fn=run_handler,
        inputs=[uploaded_image, url_input, batch_files, batch_urls_text, model_dropdown, active_tab],
        outputs=[result_table_html],
    )

if __name__ == "__main__":
    demo.launch()