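"""Gradio UI layer for SceneWeaver: background generation and ControlNet inpainting."""
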
import logging
import time
import traceback
from pathlib import Path
from typing import Optional, Tuple, Dict, Any, List

from PIL import Image
import numpy as np
import cv2
import gradio as gr
import spaces

from scene_weaver_core import SceneWeaverCore
from css_styles import CSSStyles
from scene_templates import SceneTemplateManager
from inpainting_templates import InpaintingTemplateManager

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s [%(name)s] %(levelname)s: %(message)s',
    datefmt='%H:%M:%S'
)

class UIManager:
    """
    Gradio UI Manager with support for background generation and inpainting.

    Provides a professional interface with mode switching, template selection,
    and advanced parameter controls.

    Attributes:
        sceneweaver: SceneWeaverCore instance
        template_manager: Scene template manager
        inpainting_template_manager: Inpainting template manager
    """

    def __init__(self):
        self.sceneweaver = SceneWeaverCore()
        self.template_manager = SceneTemplateManager()
        self.inpainting_template_manager = InpaintingTemplateManager()
        self.generation_history = []
        self.inpainting_history = []
        self._preview_sensitivity = 0.5
        self._current_mode = "background"  # "background" or "inpainting"

    def apply_template(self, display_name: str, current_negative: str) -> Tuple[str, str, float]:
        """
        Apply a scene template to the prompt fields.

        Args:
            display_name: The display name from the dropdown (e.g., "🏢 Modern Office")
            current_negative: Current negative prompt value

        Returns:
            Tuple of (prompt, negative_prompt, guidance_scale)
        """
        if not display_name:
            return "", current_negative, 7.5

        # Convert the display name to a template key
        template_key = self.template_manager.get_template_key_from_display(display_name)
        if not template_key:
            return "", current_negative, 7.5

        template = self.template_manager.get_template(template_key)
        if template:
            prompt = template.prompt
            negative = self.template_manager.get_negative_prompt_for_template(
                template_key, current_negative
            )
            guidance = template.guidance_scale
            return prompt, negative, guidance
        return "", current_negative, 7.5

    def quick_preview(
        self,
        uploaded_image: Optional[Image.Image],
        sensitivity: float = 0.5
    ) -> Optional[Image.Image]:
        """
        Generate a quick foreground preview using lightweight traditional methods.

        Args:
            uploaded_image: Uploaded PIL Image
            sensitivity: Detection sensitivity (0.0 - 1.0)

        Returns:
            Preview image with a colored overlay, or None
        """
        if uploaded_image is None:
            return None
        try:
            logger.info(f"Generating quick preview (sensitivity={sensitivity:.2f})")
            img_array = np.array(uploaded_image.convert('RGB'))
            height, width = img_array.shape[:2]
            max_preview_size = 512
            if max(width, height) > max_preview_size:
                scale = max_preview_size / max(width, height)
                new_w = int(width * scale)
                new_h = int(height * scale)
                img_array = cv2.resize(img_array, (new_w, new_h), interpolation=cv2.INTER_AREA)
                height, width = new_h, new_w
            gray = cv2.cvtColor(img_array, cv2.COLOR_RGB2GRAY)
            blurred = cv2.GaussianBlur(gray, (5, 5), 0)
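            # Map sensitivity to Canny thresholds: higher sensitivity lowers both
            # thresholds (low: 30-80, high: 100-200), so more edges are detected.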
            low_threshold = int(30 + (1 - sensitivity) * 50)
            high_threshold = int(100 + (1 - sensitivity) * 100)
            edges = cv2.Canny(blurred, low_threshold, high_threshold)
            kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (7, 7))
            dilated = cv2.dilate(edges, kernel, iterations=2)
            contours, _ = cv2.findContours(dilated, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
            mask = np.zeros((height, width), dtype=np.uint8)
            if contours:
                sorted_contours = sorted(contours, key=cv2.contourArea, reverse=True)
                min_area = (width * height) * 0.01 * (1 - sensitivity)
                for contour in sorted_contours:
                    if cv2.contourArea(contour) > min_area:
                        cv2.fillPoly(mask, [contour], 255)
                kernel_close = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (11, 11))
                mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, kernel_close)
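            # Tint the detected foreground green and the background red for the overlay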
            overlay = img_array.copy().astype(np.float32)
            fg_mask = mask > 127
            overlay[fg_mask] = overlay[fg_mask] * 0.5 + np.array([0, 255, 0]) * 0.5
            bg_mask = mask <= 127
            overlay[bg_mask] = overlay[bg_mask] * 0.5 + np.array([255, 0, 0]) * 0.5
            overlay = np.clip(overlay, 0, 255).astype(np.uint8)
            original_size = uploaded_image.size
            preview_image = Image.fromarray(overlay)
            if preview_image.size != original_size:
                preview_image = preview_image.resize(original_size, Image.LANCZOS)
            logger.info("Quick preview generated successfully")
            return preview_image
        except Exception as e:
            logger.error(f"Quick preview failed: {e}")
            return None

    def _save_result(self, combined_image: Image.Image, prompt: str):
        """Save the result with memory-conscious history management."""
        if not combined_image:
            return
        output_dir = Path("outputs")
        output_dir.mkdir(exist_ok=True)
        combined_image.save(output_dir / "latest_combined.png")
        self.generation_history.append({
            "prompt": prompt,
            "timestamp": time.time()
        })
        max_history = self.sceneweaver.max_history
        if len(self.generation_history) > max_history:
            self.generation_history = self.generation_history[-max_history:]

    def generate_handler(
        self,
        uploaded_image: Optional[Image.Image],
        prompt: str,
        combination_mode: str,
        focus_mode: str,
        negative_prompt: str,
        steps: int,
        guidance: float,
        progress=gr.Progress()
    ):
        """Enhanced generation handler with memory management and ZeroGPU support."""
        if uploaded_image is None:
            return None, None, None, "Please upload an image to get started!", gr.update(visible=False)
        if not prompt.strip():
            return None, None, None, "Please describe the background scene you'd like!", gr.update(visible=False)
        try:
            if not self.sceneweaver.is_initialized:
                progress(0.05, desc="Loading AI models (first time may take 2-3 minutes)...")

                def init_progress(msg, pct):
                    if pct < 30:
                        desc = "Loading image analysis models..."
                    elif pct < 60:
                        desc = "Loading Stable Diffusion XL..."
                    elif pct < 90:
                        desc = "Applying memory optimizations..."
                    else:
                        desc = "Almost ready..."
                    progress(0.05 + (pct / 100) * 0.2, desc=desc)

                self.sceneweaver.load_models(progress_callback=init_progress)
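
            # Progress is split into two phases: model loading (above) maps to
            # 5%-25% of the bar; generation (below) maps to 25%-100%.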
            def gen_progress(msg, pct):
                if pct < 20:
                    desc = "Analyzing your image..."
                elif pct < 50:
                    desc = "Generating background scene..."
                elif pct < 80:
                    desc = "Blending foreground and background..."
                elif pct < 95:
                    desc = "Applying final touches..."
                else:
                    desc = "Complete!"
                progress(0.25 + (pct / 100) * 0.75, desc=desc)

            result = self.sceneweaver.generate_and_combine(
                original_image=uploaded_image,
                prompt=prompt,
                combination_mode=combination_mode,
                focus_mode=focus_mode,
                negative_prompt=negative_prompt,
                num_inference_steps=int(steps),
                guidance_scale=float(guidance),
                progress_callback=gen_progress
            )
            if result["success"]:
                combined = result["combined_image"]
                generated = result["generated_scene"]
                original = result["original_image"]
                self._save_result(combined, prompt)
                status_msg = "Image created successfully!"
                return combined, generated, original, status_msg, gr.update(visible=True)
            else:
                error_msg = result.get("error", "Something went wrong")
                return None, None, None, f"Error: {error_msg}", gr.update(visible=False)
        except Exception as e:
            error_traceback = traceback.format_exc()
            logger.error(f"Generation handler error: {str(e)}")
            logger.error(f"Traceback:\n{error_traceback}")
            return None, None, None, f"Error: {str(e)}", gr.update(visible=False)

    def create_interface(self):
        """Create the professional user interface."""
        self._css = CSSStyles.get_main_css()
        # Check the Gradio version for API compatibility
        self._gradio_version = gr.__version__
        self._gradio_major = int(self._gradio_version.split('.')[0])
        # Compatible with Gradio 4.44.0+.
        # Use minimal constructor arguments for maximum compatibility.
        with gr.Blocks() as interface:
            # Inject CSS (compatible with all Gradio versions)
            gr.HTML(f"<style>{self._css}</style>")
            # Header
            gr.HTML("""
                <div class="main-header">
                    <h1 class="main-title">
                        <span class="title-emoji">🎨</span>
                        SceneWeaver
                    </h1>
                    <p class="main-subtitle">AI-powered background generation and inpainting with professional edge processing</p>
                </div>
            """)
            # Main tabs for mode selection
            with gr.Tabs(elem_id="main-mode-tabs") as main_tabs:
                # Background Generation tab
                with gr.Tab("Background Generation", elem_id="bg-gen-tab"):
                    with gr.Row():
                        # Left column - input controls
                        with gr.Column(scale=1, min_width=350, elem_classes=["feature-card"]):
                            gr.HTML("""
                                <div class="card-content">
                                    <h3 class="card-title">
                                        <span class="section-emoji">📸</span>
                                        Upload & Generate
                                    </h3>
                                </div>
                            """)
                            uploaded_image = gr.Image(
                                label="Upload Your Image",
                                type="pil",
                                height=280,
                                elem_classes=["input-field"]
                            )
                            # Scene template selector (no Accordion, to fix dropdown positioning in Gradio 5.x)
                            template_dropdown = gr.Dropdown(
                                label="Scene Templates",
                                choices=[""] + self.template_manager.get_template_choices_sorted(),
                                value="",
                                info="24 curated scenes sorted A-Z (optional)",
                                elem_classes=["template-dropdown"]
                            )
                            prompt_input = gr.Textbox(
                                label="Background Scene Description",
                                placeholder="Select a template above or describe your own scene...",
                                lines=3,
                                elem_classes=["input-field"]
                            )
                            combination_mode = gr.Dropdown(
                                label="Composition Mode",
                                choices=["center", "left_half", "right_half", "full"],
                                value="center",
                                info="center=Smart Center | left_half=Left Half | right_half=Right Half | full=Full Image",
                                elem_classes=["input-field"]
                            )
                            focus_mode = gr.Dropdown(
                                label="Focus Mode",
                                choices=["person", "scene"],
                                value="person",
                                info="person=Tight Crop | scene=Include Surrounding Objects",
                                elem_classes=["input-field"]
                            )
                            with gr.Accordion("Advanced Options", open=False):
                                negative_prompt = gr.Textbox(
                                    label="Negative Prompt",
                                    value="blurry, low quality, distorted, people, characters",
                                    lines=2,
                                    elem_classes=["input-field"]
                                )
                                steps_slider = gr.Slider(
                                    label="Quality Steps",
                                    minimum=15,
                                    maximum=50,
                                    value=25,
                                    step=5,
                                    elem_classes=["input-field"]
                                )
                                guidance_slider = gr.Slider(
                                    label="Guidance Scale",
                                    minimum=5.0,
                                    maximum=15.0,
                                    value=7.5,
                                    step=0.5,
                                    elem_classes=["input-field"]
                                )
                            generate_btn = gr.Button(
                                "Generate Background",
                                variant="primary",
                                size="lg",
                                elem_classes=["primary-button"]
                            )
                        # Right column - results display
                        with gr.Column(scale=2, elem_classes=["feature-card"], elem_id="results-gallery-centered"):
                            gr.HTML("""
                                <div class="card-content">
                                    <h3 class="card-title">
                                        <span class="section-emoji">🎭</span>
                                        Results Gallery
                                    </h3>
                                </div>
                            """)
                            # Loading notice
                            gr.HTML("""
                                <div class="loading-notice">
                                    <span class="loading-notice-icon">⏱️</span>
                                    <span class="loading-notice-text">
                                        <strong>First-time users:</strong> Initial model loading takes 2-3 minutes.
                                        Subsequent generations are much faster (~30s).
                                    </span>
                                </div>
                            """)
                            # Quick start guide
                            gr.HTML("""
                                <details class="user-guidance-panel">
                                    <summary class="guidance-summary">
                                        <span class="emoji-enhanced">💡</span>
                                        Quick Start Guide
                                    </summary>
                                    <div class="guidance-content">
                                        <p><strong>Step 1:</strong> Upload any image with a clear subject</p>
                                        <p><strong>Step 2:</strong> Describe or choose your desired background scene</p>
                                        <p><strong>Step 3:</strong> Choose a composition mode (center works best)</p>
                                        <p><strong>Step 4:</strong> Click Generate and wait for the magic!</p>
                                        <p><strong>Tip:</strong> For dark clothing, ensure good lighting in the original photo.</p>
                                    </div>
                                </details>
                            """)
                            with gr.Tabs():
                                with gr.TabItem("Final Result"):
                                    combined_output = gr.Image(
                                        label="Your Generated Image",
                                        elem_classes=["result-gallery"],
                                        show_label=False
                                    )
                                with gr.TabItem("Background"):
                                    generated_output = gr.Image(
                                        label="Generated Background",
                                        elem_classes=["result-gallery"],
                                        show_label=False
                                    )
                                with gr.TabItem("Original"):
                                    original_output = gr.Image(
                                        label="Processed Original",
                                        elem_classes=["result-gallery"],
                                        show_label=False
                                    )
                            status_output = gr.Textbox(
                                label="Status",
                                value="Ready to create! Upload an image and describe your vision.",
                                interactive=False,
                                elem_classes=["status-panel", "status-ready"]
                            )
                            with gr.Row():
                                download_btn = gr.DownloadButton(
                                    "Download Result",
                                    value=None,
                                    visible=False,
                                    elem_classes=["secondary-button"]
                                )
                                clear_btn = gr.Button(
                                    "Clear All",
                                    elem_classes=["secondary-button"]
                                )
                                memory_btn = gr.Button(
                                    "Clean Memory",
                                    elem_classes=["secondary-button"]
                                )
                    # Event handlers for the Background Generation tab
                    # Template selection handler
                    template_dropdown.change(
                        fn=self.apply_template,
                        inputs=[template_dropdown, negative_prompt],
                        outputs=[prompt_input, negative_prompt, guidance_slider]
                    )
                    generate_btn.click(
                        fn=self.generate_handler,
                        inputs=[
                            uploaded_image,
                            prompt_input,
                            combination_mode,
                            focus_mode,
                            negative_prompt,
                            steps_slider,
                            guidance_slider
                        ],
                        outputs=[
                            combined_output,
                            generated_output,
                            original_output,
                            status_output,
                            download_btn
                        ]
                    )
                    clear_btn.click(
                        fn=lambda: (None, None, None, "Ready to create!", gr.update(visible=False)),
                        outputs=[combined_output, generated_output, original_output, status_output, download_btn]
                    )
                    memory_btn.click(
                        fn=lambda: self.sceneweaver._ultra_memory_cleanup() or "Memory cleaned!",
                        outputs=[status_output]
                    )
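                    # Point the download button at the file saved by _save_result
                    # whenever a new combined result lands in the gallery.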
                    combined_output.change(
                        fn=lambda img: gr.update(value="outputs/latest_combined.png", visible=True) if img is not None else gr.update(visible=False),
                        inputs=[combined_output],
                        outputs=[download_btn]
                    )
                # End of Background Generation tab

                # Inpainting tab
                self.create_inpainting_tab()
            # Footer with tech credits (outside the tabs)
            gr.HTML("""
                <div class="app-footer">
                    <div class="footer-powered">
                        <p class="footer-powered-title">Powered By</p>
                        <div class="footer-tech-grid">
                            <span class="footer-tech-item">Stable Diffusion XL</span>
                            <span class="footer-tech-item">OpenCLIP</span>
                            <span class="footer-tech-item">BiRefNet</span>
                            <span class="footer-tech-item">rembg</span>
                            <span class="footer-tech-item">PyTorch</span>
                            <span class="footer-tech-item">Gradio</span>
                        </div>
                    </div>
                    <div class="footer-divider"></div>
                    <p class="footer-copyright">
                        SceneWeaver © 2025 |
                        Built with <a href="https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0" target="_blank">SDXL</a>
                        and <a href="https://github.com/mlfoundations/open_clip" target="_blank">OpenCLIP</a>
                    </p>
                </div>
            """)
        return interface

    def launch(self, share: bool = True, debug: bool = False):
        """Launch the UI interface."""
        interface = self.create_interface()
        # Launch kwargs compatible with Gradio 4.44.0+.
        # Keep these minimal for maximum compatibility.
        launch_kwargs = {
            "share": share,
            "debug": debug,
            "show_error": True,
            "quiet": False
        }
        return interface.launch(**launch_kwargs)

    # INPAINTING UI METHODS

    def apply_inpainting_template(
        self,
        display_name: str,
        current_prompt: str
    ) -> Tuple[str, float, int, str]:
        """
        Apply an inpainting template to the UI fields.

        Parameters
        ----------
        display_name : str
            Template display name from the dropdown
        current_prompt : str
            Current prompt content

        Returns
        -------
        tuple
            (prompt, conditioning_scale, feather_radius, conditioning_type)
        """
        if not display_name:
            return current_prompt, 0.7, 8, "canny"
        template_key = self.inpainting_template_manager.get_template_key_from_display(display_name)
        if not template_key:
            return current_prompt, 0.7, 8, "canny"
        template = self.inpainting_template_manager.get_template(template_key)
        if template:
            params = self.inpainting_template_manager.get_parameters_for_template(template_key)
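            # Templates tune generation parameters only; the user's prompt text
            # passes through unchanged.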
            return (
                current_prompt,
                params.get('controlnet_conditioning_scale', 0.7),
                params.get('feather_radius', 8),
                params.get('preferred_conditioning', 'canny')
            )
        return current_prompt, 0.7, 8, "canny"

    def extract_mask_from_editor(self, editor_output: Any) -> Optional[Image.Image]:
        """
        Extract a mask from Gradio ImageEditor output.

        Handles the output formats of different Gradio versions.

        Parameters
        ----------
        editor_output : dict, numpy.ndarray, or PIL.Image
            Output from a gr.ImageEditor component

        Returns
        -------
        PIL.Image or None
            Extracted mask as a grayscale image
        """
        if editor_output is None:
            return None
        try:
            # Gradio 5.x format
            if isinstance(editor_output, dict):
                # Check for the 'layers' key (Gradio 5.x ImageEditor)
                if 'layers' in editor_output and editor_output['layers']:
                    # Use the first layer as the mask
                    layer = editor_output['layers'][0]
                    if isinstance(layer, np.ndarray):
                        mask_array = layer
                    elif isinstance(layer, Image.Image):
                        mask_array = np.array(layer)
                    else:
                        return None
                # Fall back to the 'composite' key
                elif 'composite' in editor_output:
                    composite = editor_output['composite']
                    if isinstance(composite, np.ndarray):
                        mask_array = composite
                    elif isinstance(composite, Image.Image):
                        mask_array = np.array(composite)
                    else:
                        return None
                else:
                    return None
            elif isinstance(editor_output, np.ndarray):
                mask_array = editor_output
            elif isinstance(editor_output, Image.Image):
                mask_array = np.array(editor_output)
            else:
                logger.warning(f"Unexpected editor output type: {type(editor_output)}")
                return None
            # Convert to a grayscale mask
            if len(mask_array.shape) == 3:
                if mask_array.shape[2] == 4:
                    # RGBA format - extract white brush strokes from the RGB channels.
                    # White brush strokes have high RGB values AND high alpha.
                    rgb_part = mask_array[:, :, :3]
                    alpha_part = mask_array[:, :, 3]
                    # Convert RGB to grayscale to detect white areas
                    gray = cv2.cvtColor(rgb_part, cv2.COLOR_RGB2GRAY)
                    # Combine: white areas (high gray value) with opacity (high alpha).
                    # This captures the white brush strokes.
                    mask_gray = np.minimum(gray, alpha_part)
                elif mask_array.shape[2] == 3:
                    # RGB - convert to grayscale (white areas become white in the mask)
                    mask_gray = cv2.cvtColor(mask_array, cv2.COLOR_RGB2GRAY)
                else:
                    mask_gray = mask_array[:, :, 0]
            else:
                # Already grayscale
                mask_gray = mask_array
            return Image.fromarray(mask_gray.astype(np.uint8), mode='L')
        except Exception as e:
            logger.error(f"Failed to extract mask from editor: {e}")
            return None

    def inpainting_handler(
        self,
        image: Optional[Image.Image],
        mask_editor: Dict[str, Any],
        prompt: str,
        template_dropdown: str,
        conditioning_type: str,
        conditioning_scale: float,
        feather_radius: int,
        guidance_scale: float,
        num_steps: int,
        progress: gr.Progress = gr.Progress()
    ) -> Tuple[Optional[Image.Image], Optional[Image.Image], str]:
        """
        Handle an inpainting generation request.

        Parameters
        ----------
        image : PIL.Image
            Original image to inpaint
        mask_editor : dict
            Mask editor output
        prompt : str
            Text description of the desired content
        template_dropdown : str
            Selected template (optional)
        conditioning_type : str
            ControlNet conditioning type
        conditioning_scale : float
            ControlNet influence strength
        feather_radius : int
            Mask feathering radius
        guidance_scale : float
            Guidance scale for generation
        num_steps : int
            Number of inference steps
        progress : gr.Progress
            Progress callback

        Returns
        -------
        tuple
            (result_image, control_image, status_message)
        """
        if image is None:
            return None, None, "⚠️ Please upload an image first"
        # Extract the mask
        mask = self.extract_mask_from_editor(mask_editor)
        if mask is None:
            return None, None, "⚠️ Please draw a mask on the image"
        # Validate the mask
        mask_array = np.array(mask)
        coverage = np.count_nonzero(mask_array > 127) / mask_array.size
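        # Coverage is the fraction of painted pixels; reject degenerate masks
        # (under 1% or over 95% of the image).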
        if coverage < 0.01:
            return None, None, "⚠️ Mask too small - please select a larger area"
        if coverage > 0.95:
            return None, None, "⚠️ Mask too large - consider using background generation instead"

        def progress_callback(msg: str, pct: int):
            progress(pct / 100, desc=msg)

        try:
            start_time = time.time()
            # Get the template key if one is selected
            template_key = None
            if template_dropdown:
                template_key = self.inpainting_template_manager.get_template_key_from_display(
                    template_dropdown
                )
            # Execute inpainting through the SceneWeaverCore facade
            result = self.sceneweaver.execute_inpainting(
                image=image,
                mask=mask,
                prompt=prompt,
                preview_only=False,
                template_key=template_key,
                conditioning_type=conditioning_type,
                controlnet_conditioning_scale=conditioning_scale,
                feather_radius=feather_radius,
                guidance_scale=guidance_scale,
                num_inference_steps=num_steps,
                progress_callback=progress_callback
            )
            elapsed = time.time() - start_time
            if result.get('success'):
                # Store in history
                self.inpainting_history.append({
                    'result': result.get('combined_image'),
                    'prompt': prompt,
                    'time': elapsed
                })
                if len(self.inpainting_history) > 3:
                    self.inpainting_history.pop(0)
                quality_score = result.get('quality_score', 0)
                # Clean, simple status message
                status = f"✅ Inpainting complete in {elapsed:.1f}s"
                if quality_score > 0:
                    status += f" | Quality: {quality_score:.0f}/100"
                return (
                    result.get('combined_image'),
                    result.get('control_image'),
                    status
                )
            else:
                error_msg = result.get('error', 'Unknown error')
                return None, None, f"❌ Inpainting failed: {error_msg}"
        except Exception as e:
            logger.error(f"Inpainting handler error: {e}")
            logger.error(traceback.format_exc())
            return None, None, f"❌ Error: {str(e)}"

    def create_inpainting_tab(self) -> gr.Tab:
        """
        Create the inpainting tab UI.

        Returns
        -------
        gr.Tab
            Configured inpainting tab component
        """
        with gr.Tab("Inpainting", elem_id="inpainting-tab") as tab:
            gr.HTML("""
                <div class="inpainting-header">
                    <h3 style="display: flex; align-items: center; gap: 10px; margin-bottom: 8px;">
                        ControlNet Inpainting
                        <span style="background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
                                     color: white;
                                     padding: 3px 10px;
                                     border-radius: 12px;
                                     font-size: 0.65em;
                                     font-weight: 700;
                                     letter-spacing: 0.5px;
                                     box-shadow: 0 2px 4px rgba(102, 126, 234, 0.3);">
                            BETA
                        </span>
                    </h3>
                    <p style="color: #666; margin-bottom: 12px;">Draw a mask to select the area you want to regenerate</p>
                    <div style="background: linear-gradient(to right, #FFF4E6, #FFE8CC);
                                border-left: 4px solid #FF9500;
                                padding: 12px 15px;
                                border-radius: 6px;
                                margin-top: 10px;
                                box-shadow: 0 2px 4px rgba(255, 149, 0, 0.1);">
                        <p style="color: #8B4513; font-size: 0.9em; margin: 0; line-height: 1.5;">
                            <strong>⚠️ Beta Feature - Continuously Optimizing</strong><br>
                            Results may vary depending on complexity. Use templates and detailed prompts for best results.
                            Advanced features (like Add Accessories) may require multiple attempts.
                        </p>
                    </div>
                </div>
            """)
            with gr.Row():
                # Left column - input
                with gr.Column(scale=1):
                    # Image upload
                    inpaint_image = gr.Image(
                        label="Upload Image",
                        type="pil",
                        height=300
                    )
                    # Mask editor
                    mask_editor = gr.ImageEditor(
                        label="Draw Mask (white = area to inpaint)",
                        type="pil",
                        height=300,
                        brush=gr.Brush(colors=["#FFFFFF"], default_size=20),
                        eraser=gr.Eraser(default_size=20),
                        layers=True,
                        sources=["upload"],
                        image_mode="RGBA"
                    )
                    # Template selection
                    with gr.Accordion("Inpainting Templates", open=False):
                        inpaint_template = gr.Dropdown(
                            choices=[""] + self.inpainting_template_manager.get_template_choices_sorted(),
                            value="",
                            label="Select Template",
                            elem_classes=["template-dropdown"]
                        )
                        template_tips = gr.Markdown("")
                    # Prompt
                    inpaint_prompt = gr.Textbox(
                        label="Prompt",
                        placeholder="Describe what you want to generate in the masked area...",
                        lines=2
                    )
                # Right column - settings and output
                with gr.Column(scale=1):
                    # Settings
                    with gr.Accordion("Generation Settings", open=True):
                        conditioning_type = gr.Radio(
                            choices=["canny", "depth"],
                            value="canny",
                            label="ControlNet Mode"
                        )
                        conditioning_scale = gr.Slider(
                            minimum=0.05,
                            maximum=1.0,
                            value=0.7,
                            step=0.05,
                            label="ControlNet Strength"
                        )
                        feather_radius = gr.Slider(
                            minimum=0,
                            maximum=20,
                            value=8,
                            step=1,
                            label="Feather Radius (px)"
                        )
                    with gr.Accordion("Advanced Settings", open=False):
                        inpaint_guidance = gr.Slider(
                            minimum=5.0,
                            maximum=15.0,
                            value=7.5,
                            step=0.5,
                            label="Guidance Scale"
                        )
                        inpaint_steps = gr.Slider(
                            minimum=15,
                            maximum=50,
                            value=25,
                            step=5,
                            label="Inference Steps"
                        )
                    # Generate button
                    inpaint_btn = gr.Button(
                        "Generate Inpainting",
                        variant="primary",
                        elem_classes=["primary-button"]
                    )
                    # Processing time reminder
                    gr.Markdown(
                        """
                        <div style="background: linear-gradient(135deg, #fff8e1 0%, #ffecb3 100%);
                                    border-left: 4px solid #ffa000;
                                    padding: 12px 16px;
                                    border-radius: 8px;
                                    margin: 12px 0;">
                            <p style="margin: 0; color: #5d4037; font-size: 14px;">
                                ⏳ <strong>Please be patient!</strong> Inpainting typically takes <strong>5-7 minutes</strong>
                                depending on GPU availability and image complexity.
                                Please don't refresh the page while processing.
                            </p>
                        </div>
                        <div style="background: linear-gradient(135deg, #e3f2fd 0%, #bbdefb 100%);
                                    border-left: 4px solid #1976d2;
                                    padding: 12px 16px;
                                    border-radius: 8px;
                                    margin: 12px 0;">
                            <p style="margin: 0; color: #0d47a1; font-size: 14px;">
                                🔄 <strong>Want to make more changes?</strong> After each generation, please
                                <strong>re-upload your image</strong> and draw a new mask if you want to apply additional edits.
                            </p>
                        </div>
                        """
                    )
                    # Status
                    inpaint_status = gr.Textbox(
                        label="Status",
                        value="Ready for inpainting",
                        interactive=False
                    )
            # Output row
            with gr.Row():
                with gr.Column(scale=1):
                    inpaint_result = gr.Image(
                        label="Result",
                        type="pil",
                        height=400
                    )
                with gr.Column(scale=1):
                    # Control image (structure-guidance visualization)
                    inpaint_control = gr.Image(
                        label="Control Image (Structure Guidance)",
                        type="pil",
                        height=400
                    )
            # Event handlers
            inpaint_template.change(
                fn=self.apply_inpainting_template,
                inputs=[inpaint_template, inpaint_prompt],
                outputs=[inpaint_prompt, conditioning_scale, feather_radius, conditioning_type]
            )
            inpaint_template.change(
                fn=lambda x: self._get_template_tips(x),
                inputs=[inpaint_template],
                outputs=[template_tips]
            )
            # Copy the uploaded image into the mask editor
            inpaint_image.change(
                fn=lambda x: x,
                inputs=[inpaint_image],
                outputs=[mask_editor]
            )
            inpaint_btn.click(
                fn=self.inpainting_handler,
                inputs=[
                    inpaint_image,
                    mask_editor,
                    inpaint_prompt,
                    inpaint_template,
                    conditioning_type,
                    conditioning_scale,
                    feather_radius,
                    inpaint_guidance,
                    inpaint_steps
                ],
                outputs=[
                    inpaint_result,
                    inpaint_control,
                    inpaint_status
                ]
            )
        return tab

    def _get_template_tips(self, display_name: str) -> str:
        """Get usage tips for the selected template."""
        if not display_name:
            return ""
        template_key = self.inpainting_template_manager.get_template_key_from_display(display_name)
        if not template_key:
            return ""
        tips = self.inpainting_template_manager.get_usage_tips(template_key)
        if tips:
            return "**Tips:**\n" + "\n".join(f"- {tip}" for tip in tips)
        return ""