# Hugging Face Spaces status: Running on Zero (ZeroGPU)
| import logging | |
| import time | |
| from pathlib import Path | |
| from typing import Optional, Tuple | |
| from PIL import Image | |
| import numpy as np | |
| import cv2 | |
| import gradio as gr | |
| import spaces | |
| from scene_weaver_core import SceneWeaverCore | |
| from css_styles import CSSStyles | |
| from scene_templates import SceneTemplateManager | |
# Configure the root logger at import time: INFO and above, with a compact
# "HH:MM:SS [logger name] LEVEL: message" layout.
logging.basicConfig(
    format="%(asctime)s [%(name)s] %(levelname)s: %(message)s",
    datefmt="%H:%M:%S",
    level=logging.INFO,
)

# Module-level logger; its level is pinned to INFO explicitly so it does not
# depend on whether basicConfig() above actually took effect.
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
class UIManager:
    """Gradio UI with enhanced memory management and professional design.

    Wraps a ``SceneWeaverCore`` pipeline and a ``SceneTemplateManager`` and
    exposes them through a Gradio Blocks interface.

    Attributes:
        sceneweaver: Core pipeline used for model loading and generation.
        template_manager: Provider of scene templates for the dropdown.
        generation_history: List of ``{"prompt", "timestamp"}`` dicts, trimmed
            to ``sceneweaver.max_history`` entries after every save.
    """

    # Single source of truth for where the latest composite is written and
    # where the download button reads it back from.
    _OUTPUT_DIR = Path("outputs")
    _LATEST_RESULT_NAME = "latest_combined.png"

    # Guidance scale returned when no template can be applied.
    _DEFAULT_GUIDANCE = 7.5

    # Longest side (in pixels) of the working image used by quick_preview().
    _MAX_PREVIEW_SIZE = 512

    # (exclusive upper bound in percent, message) pairs for progress display;
    # the matching *_FINAL message is used once every bound has been passed.
    _INIT_STAGES = (
        (30, "Loading image analysis models..."),
        (60, "Loading Stable Diffusion XL..."),
        (90, "Applying memory optimizations..."),
    )
    _INIT_FINAL = "Almost ready..."
    _GEN_STAGES = (
        (20, "Analyzing your image..."),
        (50, "Generating background scene..."),
        (80, "Blending foreground and background..."),
        (95, "Applying final touches..."),
    )
    _GEN_FINAL = "Complete!"

    def __init__(self):
        self.sceneweaver = SceneWeaverCore()
        self.template_manager = SceneTemplateManager()
        self.generation_history = []
        self._preview_sensitivity = 0.5

    # ------------------------------------------------------------------
    # Small internal helpers
    # ------------------------------------------------------------------

    @classmethod
    def _latest_result_path(cls) -> Path:
        """Return the path of the most recently saved combined image."""
        return cls._OUTPUT_DIR / cls._LATEST_RESULT_NAME

    @staticmethod
    def _describe_stage(pct, stages, final: str) -> str:
        """Return the message of the first stage whose bound exceeds ``pct``."""
        for upper_bound, description in stages:
            if pct < upper_bound:
                return description
        return final

    @staticmethod
    def _reset_outputs(status: str):
        """Build handler outputs that blank all three images.

        Returns:
            Tuple of (combined, generated, original, status, download update)
            with the images set to None and the download button hidden.
        """
        return None, None, None, status, gr.update(visible=False)

    def _clean_memory(self) -> str:
        """Run the core's memory cleanup and return the status message.

        The return value of the cleanup call is deliberately ignored so the
        status box always shows the confirmation text.
        """
        self.sceneweaver._ultra_memory_cleanup()
        return "Memory cleaned!"

    def _download_button_update(self, img):
        """Show the download button (pointing at the saved file) iff an image exists."""
        if img is None:
            return gr.update(visible=False)
        return gr.update(value=str(self._latest_result_path()), visible=True)

    # ------------------------------------------------------------------
    # Handlers
    # ------------------------------------------------------------------

    def apply_template(self, display_name: str, current_negative: str) -> Tuple[str, str, float]:
        """
        Apply a scene template to the prompt fields.

        Args:
            display_name: The display name from dropdown (e.g., "🏢 Modern Office")
            current_negative: Current negative prompt value

        Returns:
            Tuple of (prompt, negative_prompt, guidance_scale). When no
            template can be resolved, the prompt is empty, the negative
            prompt is left unchanged and the guidance scale is 7.5.
        """
        fallback = ("", current_negative, self._DEFAULT_GUIDANCE)

        if not display_name:
            return fallback

        # Convert display name to template key
        template_key = self.template_manager.get_template_key_from_display(display_name)
        if not template_key:
            return fallback

        template = self.template_manager.get_template(template_key)
        if not template:
            return fallback

        prompt = template.prompt
        negative = self.template_manager.get_negative_prompt_for_template(
            template_key, current_negative
        )
        return prompt, negative, template.guidance_scale

    def quick_preview(
        self,
        uploaded_image: Optional[Image.Image],
        sensitivity: float = 0.5
    ) -> Optional[Image.Image]:
        """
        Generate quick foreground preview using lightweight traditional methods.

        The image is downscaled so its longest side is at most 512 px, edges
        are found with Canny, dilated, and the resulting external contours
        above a minimum area are filled to form a foreground mask. Foreground
        pixels are tinted green and background pixels red (50% blend).

        Args:
            uploaded_image: Uploaded PIL Image
            sensitivity: Detection sensitivity (0.0 - 1.0); values outside
                that range are clamped.

        Returns:
            Preview image (same size as the upload) with colored overlay, or
            None when no image was given or processing failed.
        """
        if uploaded_image is None:
            return None

        try:
            # Clamp so thresholds and min_area stay within their intended
            # ranges (an unclamped value > 1 would make min_area negative).
            sensitivity = min(1.0, max(0.0, float(sensitivity)))
            logger.info("Generating quick preview (sensitivity=%.2f)", sensitivity)

            img_array = np.array(uploaded_image.convert('RGB'))
            height, width = img_array.shape[:2]

            longest_side = max(width, height)
            if longest_side > self._MAX_PREVIEW_SIZE:
                scale = self._MAX_PREVIEW_SIZE / longest_side
                new_w = int(width * scale)
                new_h = int(height * scale)
                img_array = cv2.resize(img_array, (new_w, new_h), interpolation=cv2.INTER_AREA)
                height, width = new_h, new_w

            gray = cv2.cvtColor(img_array, cv2.COLOR_RGB2GRAY)
            blurred = cv2.GaussianBlur(gray, (5, 5), 0)

            # Higher sensitivity -> lower Canny thresholds -> more edges kept.
            insensitivity = 1 - sensitivity
            low_threshold = int(30 + insensitivity * 50)
            high_threshold = int(100 + insensitivity * 100)
            edges = cv2.Canny(blurred, low_threshold, high_threshold)

            # Thicken edges so nearby fragments merge into closed outlines.
            kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (7, 7))
            dilated = cv2.dilate(edges, kernel, iterations=2)
            contours, _ = cv2.findContours(dilated, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

            mask = np.zeros((height, width), dtype=np.uint8)
            if contours:
                # Contours covering less than (1 - sensitivity)% of the frame
                # are ignored. Fill order is irrelevant: every kept contour is
                # painted with the same value, so no sorting is required.
                min_area = (width * height) * 0.01 * insensitivity
                for contour in contours:
                    if cv2.contourArea(contour) > min_area:
                        cv2.fillPoly(mask, [contour], 255)

                kernel_close = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (11, 11))
                mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, kernel_close)

            fg_mask = mask > 127
            bg_mask = ~fg_mask
            overlay = img_array.astype(np.float32)
            # 50/50 blend with green for foreground and red for background.
            overlay[fg_mask] = overlay[fg_mask] * 0.5 + np.array([0, 255, 0]) * 0.5
            overlay[bg_mask] = overlay[bg_mask] * 0.5 + np.array([255, 0, 0]) * 0.5
            overlay = np.clip(overlay, 0, 255).astype(np.uint8)

            # Scale the preview back up so it matches the uploaded image.
            original_size = uploaded_image.size
            preview_image = Image.fromarray(overlay)
            if preview_image.size != original_size:
                preview_image = preview_image.resize(original_size, Image.LANCZOS)

            logger.info("Quick preview generated successfully")
            return preview_image

        except Exception as e:
            # Best-effort feature: log with traceback and show no preview.
            logger.exception("Quick preview failed: %s", e)
            return None

    def _save_result(self, combined_image: Optional[Image.Image], prompt: str):
        """Save result with memory-conscious history management.

        Writes the image to ``outputs/latest_combined.png`` (overwriting any
        previous result), records the prompt and a timestamp in
        ``generation_history`` and trims the history to
        ``sceneweaver.max_history`` entries.

        Args:
            combined_image: Final composited image; nothing happens if None.
            prompt: Prompt that produced the image.
        """
        if combined_image is None:
            return

        self._OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
        combined_image.save(self._latest_result_path())

        self.generation_history.append({
            "prompt": prompt,
            "timestamp": time.time()
        })

        # Drop the oldest entries. Computing the excess explicitly avoids the
        # ``history[-0:]`` pitfall, which would keep everything when
        # max_history is 0.
        max_history = max(self.sceneweaver.max_history, 0)
        excess = len(self.generation_history) - max_history
        if excess > 0:
            self.generation_history = self.generation_history[excess:]

    def generate_handler(
        self,
        uploaded_image: Optional[Image.Image],
        prompt: str,
        combination_mode: str,
        focus_mode: str,
        negative_prompt: str,
        steps: int,
        guidance: float,
        progress=gr.Progress()
    ):
        """Enhanced generation handler with memory management and ZeroGPU support.

        Loads the models on first use, runs generation and compositing, saves
        the result and reports progress through ``progress``.

        Args:
            uploaded_image: Source image; required.
            prompt: Background description; required, must not be blank.
            combination_mode: Composition mode forwarded to the core.
            focus_mode: Focus mode forwarded to the core.
            negative_prompt: Negative prompt forwarded to the core.
            steps: Number of inference steps (cast to int).
            guidance: Guidance scale (cast to float).
            progress: Gradio progress tracker.

        Returns:
            Tuple of (combined image, generated scene, original image,
            status message, download button update). On validation failure or
            error the images are None and the download button is hidden.
        """
        if uploaded_image is None:
            return self._reset_outputs("Please upload an image to get started!")

        # ``not prompt`` also covers None, which would otherwise raise on
        # ``.strip()`` before reaching the try block.
        if not prompt or not prompt.strip():
            return self._reset_outputs("Please describe the background scene you'd like!")

        try:
            if not self.sceneweaver.is_initialized:
                progress(0.05, desc="Loading AI models (first time may take 2-3 minutes)...")

                def init_progress(msg, pct):
                    # Model loading occupies the 5%-25% band of the bar.
                    desc = self._describe_stage(pct, self._INIT_STAGES, self._INIT_FINAL)
                    progress(0.05 + (pct / 100) * 0.2, desc=desc)

                self.sceneweaver.load_models(progress_callback=init_progress)

            def gen_progress(msg, pct):
                # Generation occupies the remaining 25%-100% band.
                desc = self._describe_stage(pct, self._GEN_STAGES, self._GEN_FINAL)
                progress(0.25 + (pct / 100) * 0.75, desc=desc)

            result = self.sceneweaver.generate_and_combine(
                original_image=uploaded_image,
                prompt=prompt,
                combination_mode=combination_mode,
                focus_mode=focus_mode,
                negative_prompt=negative_prompt,
                num_inference_steps=int(steps),
                guidance_scale=float(guidance),
                progress_callback=gen_progress
            )

            if not result["success"]:
                error_msg = result.get("error", "Something went wrong")
                return self._reset_outputs(f"Error: {error_msg}")

            combined = result["combined_image"]
            generated = result["generated_scene"]
            original = result["original_image"]

            self._save_result(combined, prompt)

            status_msg = "Image created successfully!"
            return combined, generated, original, status_msg, gr.update(visible=True)

        except Exception as e:
            # UI boundary: log message plus full traceback, then surface a
            # short error string instead of crashing the event handler.
            logger.exception("Generation handler error: %s", e)
            return self._reset_outputs(f"Error: {str(e)}")

    # ------------------------------------------------------------------
    # Interface construction
    # ------------------------------------------------------------------

    def create_interface(self):
        """Create professional user interface.

        Builds the Blocks layout (input column, results column, footer) and
        wires the event handlers.

        NOTE(review): the component nesting below was reconstructed from a
        source whose indentation had been flattened; confirm it against the
        deployed layout.

        Returns:
            The constructed ``gr.Blocks`` interface (not yet launched).
        """
        css = CSSStyles.get_main_css()

        with gr.Blocks(
            css=css,
            title="SceneWeaver - AI Background Generator",
            theme=gr.themes.Soft()
        ) as interface:

            # Header
            gr.HTML("""
                <div class="main-header">
                    <h1 class="main-title">
                        <span class="title-emoji">🎨</span>
                        SceneWeaver
                    </h1>
                    <p class="main-subtitle">AI-powered background generation with professional edge processing</p>
                </div>
            """)

            with gr.Row():
                # Left Column - Input controls
                with gr.Column(scale=1, min_width=350, elem_classes=["feature-card"]):
                    gr.HTML("""
                        <div class="card-content">
                            <h3 class="card-title">
                                <span class="section-emoji">📸</span>
                                Upload & Generate
                            </h3>
                        </div>
                    """)

                    uploaded_image = gr.Image(
                        label="Upload Your Image",
                        type="pil",
                        height=280,
                        elem_classes=["input-field"]
                    )

                    # Scene Template Selector
                    with gr.Accordion("Scene Templates", open=False):
                        template_dropdown = gr.Dropdown(
                            label="Select a Scene",
                            choices=[""] + self.template_manager.get_template_choices_sorted(),
                            value="",
                            info="24 curated scenes sorted A-Z",
                            elem_classes=["template-dropdown"]
                        )

                    prompt_input = gr.Textbox(
                        label="Background Scene Description",
                        placeholder="Select a template above or describe your own scene...",
                        lines=3,
                        elem_classes=["input-field"]
                    )

                    combination_mode = gr.Dropdown(
                        label="Composition Mode",
                        choices=["center", "left_half", "right_half", "full"],
                        value="center",
                        info="center=Smart Center | left_half=Left Half | right_half=Right Half | full=Full Image",
                        elem_classes=["input-field"]
                    )

                    focus_mode = gr.Dropdown(
                        label="Focus Mode",
                        choices=["person", "scene"],
                        value="person",
                        info="person=Tight Crop | scene=Include Surrounding Objects",
                        elem_classes=["input-field"]
                    )

                    with gr.Accordion("Advanced Options", open=False):
                        negative_prompt = gr.Textbox(
                            label="Negative Prompt",
                            value="blurry, low quality, distorted, people, characters",
                            lines=2,
                            elem_classes=["input-field"]
                        )

                        steps_slider = gr.Slider(
                            label="Quality Steps",
                            minimum=15,
                            maximum=50,
                            value=25,
                            step=5,
                            elem_classes=["input-field"]
                        )

                        guidance_slider = gr.Slider(
                            label="Guidance Scale",
                            minimum=5.0,
                            maximum=15.0,
                            value=7.5,
                            step=0.5,
                            elem_classes=["input-field"]
                        )

                    generate_btn = gr.Button(
                        "Generate Background",
                        variant="primary",
                        size="lg",
                        elem_classes=["primary-button"]
                    )

                # Right Column - Results display
                with gr.Column(scale=2, elem_classes=["feature-card"], elem_id="results-gallery-centered"):
                    gr.HTML("""
                        <div class="card-content">
                            <h3 class="card-title">
                                <span class="section-emoji">🎭</span>
                                Results Gallery
                            </h3>
                        </div>
                    """)

                    # Loading notice
                    gr.HTML("""
                        <div class="loading-notice">
                            <span class="loading-notice-icon">⏱️</span>
                            <span class="loading-notice-text">
                                <strong>First-time users:</strong> Initial model loading takes 1-2 minutes.
                                Subsequent generations are much faster (~30s).
                            </span>
                        </div>
                    """)

                    # Quick start guide
                    gr.HTML("""
                        <details class="user-guidance-panel">
                            <summary class="guidance-summary">
                                <span class="emoji-enhanced">💡</span>
                                Quick Start Guide
                            </summary>
                            <div class="guidance-content">
                                <p><strong>Step 1:</strong> Upload any image with a clear subject</p>
                                <p><strong>Step 2:</strong> Describe or Choose your desired background scene</p>
                                <p><strong>Step 3:</strong> Choose composition mode (center works best)</p>
                                <p><strong>Step 4:</strong> Click Generate and wait for the magic!</p>
                                <p><strong>Tip:</strong> For dark clothing, ensure good lighting in original photo.</p>
                            </div>
                        </details>
                    """)

                    with gr.Tabs():
                        with gr.TabItem("Final Result"):
                            combined_output = gr.Image(
                                label="Your Generated Image",
                                elem_classes=["result-gallery"],
                                show_label=False
                            )
                        with gr.TabItem("Background"):
                            generated_output = gr.Image(
                                label="Generated Background",
                                elem_classes=["result-gallery"],
                                show_label=False
                            )
                        with gr.TabItem("Original"):
                            original_output = gr.Image(
                                label="Processed Original",
                                elem_classes=["result-gallery"],
                                show_label=False
                            )

                    status_output = gr.Textbox(
                        label="Status",
                        value="Ready to create! Upload an image and describe your vision.",
                        interactive=False,
                        elem_classes=["status-panel", "status-ready"]
                    )

                    with gr.Row():
                        download_btn = gr.DownloadButton(
                            "Download Result",
                            value=None,
                            visible=False,
                            elem_classes=["secondary-button"]
                        )
                        clear_btn = gr.Button(
                            "Clear All",
                            elem_classes=["secondary-button"]
                        )
                        memory_btn = gr.Button(
                            "Clean Memory",
                            elem_classes=["secondary-button"]
                        )

            # Footer with tech credits
            gr.HTML("""
                <div class="app-footer">
                    <div class="footer-powered">
                        <p class="footer-powered-title">Powered By</p>
                        <div class="footer-tech-grid">
                            <span class="footer-tech-item">Stable Diffusion XL</span>
                            <span class="footer-tech-item">OpenCLIP</span>
                            <span class="footer-tech-item">BiRefNet</span>
                            <span class="footer-tech-item">rembg</span>
                            <span class="footer-tech-item">PyTorch</span>
                            <span class="footer-tech-item">Gradio</span>
                        </div>
                    </div>
                    <div class="footer-divider"></div>
                    <p class="footer-copyright">
                        SceneWeaver © 2025 |
                        Built with <a href="https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0" target="_blank">SDXL</a>
                        and <a href="https://github.com/mlfoundations/open_clip" target="_blank">OpenCLIP</a>
                    </p>
                </div>
            """)

            # Event handlers

            # Template selection handler
            template_dropdown.change(
                fn=self.apply_template,
                inputs=[template_dropdown, negative_prompt],
                outputs=[prompt_input, negative_prompt, guidance_slider]
            )

            generate_btn.click(
                fn=self.generate_handler,
                inputs=[
                    uploaded_image,
                    prompt_input,
                    combination_mode,
                    focus_mode,
                    negative_prompt,
                    steps_slider,
                    guidance_slider
                ],
                outputs=[
                    combined_output,
                    generated_output,
                    original_output,
                    status_output,
                    download_btn
                ]
            )

            # Clears the three result images and hides the download button;
            # the inputs are left untouched.
            clear_btn.click(
                fn=lambda: self._reset_outputs("Ready to create!"),
                outputs=[combined_output, generated_output, original_output, status_output, download_btn]
            )

            memory_btn.click(
                fn=self._clean_memory,
                outputs=[status_output]
            )

            # Keep the download button in sync with the final-result image.
            combined_output.change(
                fn=self._download_button_update,
                inputs=[combined_output],
                outputs=[download_btn]
            )

        return interface

    def launch(self, share: bool = True, debug: bool = False):
        """Launch the UI interface.

        Args:
            share: Forwarded to ``Blocks.launch(share=...)``.
            debug: Forwarded to ``Blocks.launch(debug=...)``.

        Returns:
            Whatever ``Blocks.launch`` returns.
        """
        interface = self.create_interface()
        # NOTE(review): ssl_verify=False skips certificate validation —
        # confirm this is intended for the deployment target.
        return interface.launch(
            share=share,
            debug=debug,
            show_error=True,
            height=800,
            favicon_path=None,
            ssl_verify=False,
            quiet=False
        )