import gc
import logging
import os
import time
import traceback
from dataclasses import dataclass, field
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import cv2
import numpy as np
import torch
from PIL import Image, ImageFilter
from diffusers import ControlNetModel, DPMSolverMultistepScheduler
from diffusers import StableDiffusionXLControlNetInpaintPipeline
from diffusers import StableDiffusionXLInpaintPipeline
from transformers import AutoImageProcessor, AutoModelForDepthEstimation
from transformers import DPTImageProcessor, DPTForDepthEstimation

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)


@dataclass
class InpaintingConfig:
    """Configuration for inpainting operations."""

    # ControlNet settings
    controlnet_conditioning_scale: float = 0.7
    conditioning_type: str = "canny"  # "canny" or "depth"

    # Canny edge detection parameters
    canny_low_threshold: int = 100
    canny_high_threshold: int = 200

    # Mask settings
    feather_radius: int = 8
    min_mask_coverage: float = 0.01
    max_mask_coverage: float = 0.95

    # Generation settings
    num_inference_steps: int = 25
    guidance_scale: float = 7.5
    strength: float = 1.0  # Inpainting strength (0.0-1.0), 1.0 = full repaint
    preview_steps: int = 15
    preview_guidance_scale: float = 8.0

    # Quality settings
    enable_auto_optimization: bool = True
    max_optimization_retries: int = 3
    min_quality_score: float = 70.0

    # Memory settings
    enable_vae_tiling: bool = True
    enable_attention_slicing: bool = True
    max_resolution: int = 1024


@dataclass
class InpaintingResult:
    """Result container for inpainting operations."""

    success: bool
    result_image: Optional[Image.Image] = None
    preview_image: Optional[Image.Image] = None
    control_image: Optional[Image.Image] = None
    blended_image: Optional[Image.Image] = None
    quality_score: float = 0.0
    quality_details: Dict[str, Any] = field(default_factory=dict)
    generation_time: float = 0.0
    retries: int = 0
    error_message: str = ""
    metadata: Dict[str, Any] = field(default_factory=dict)


class InpaintingModule:
    """
    ControlNet-based Inpainting Module for SceneWeaver.

    Implements StableDiffusionXLControlNetInpaintPipeline with support for
    Canny edge and depth map conditioning. Features two-stage generation
    (preview + full quality) and automatic quality optimization.

    Attributes:
        device: Computation device (cuda/mps/cpu)
        config: InpaintingConfig instance
        is_initialized: Whether pipeline is loaded

    Example:
        >>> module = InpaintingModule(device="cuda")
        >>> module.load_inpainting_pipeline(progress_callback=my_callback)
        >>> result = module.execute_inpainting(
        ...     image=my_image,
        ...     mask=my_mask,
        ...     prompt="a beautiful garden"
        ... )
    """

    # Model identifiers
    CONTROLNET_CANNY_MODEL = "diffusers/controlnet-canny-sdxl-1.0"
    CONTROLNET_DEPTH_MODEL = "diffusers/controlnet-depth-sdxl-1.0"
    DEPTH_MODEL_PRIMARY = "LiheYoung/depth-anything-small-hf"
    DEPTH_MODEL_FALLBACK = "Intel/dpt-hybrid-midas"
    BASE_MODEL = "stabilityai/stable-diffusion-xl-base-1.0"

    def __init__(
        self,
        device: str = "auto",
        config: Optional[InpaintingConfig] = None
    ):
        """
        Initialize the InpaintingModule.

        Parameters
        ----------
        device : str, optional
            Computation device. "auto" for automatic detection.
        config : InpaintingConfig, optional
            Configuration object. Uses defaults if not provided.
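
        Examples
        --------
        Construction with a custom configuration (illustrative values; any
        field of ``InpaintingConfig`` defined above may be overridden):

        >>> cfg = InpaintingConfig(num_inference_steps=30, feather_radius=12)
        >>> module = InpaintingModule(device="auto", config=cfg)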
""" self.device = self._setup_device(device) self.config = config or InpaintingConfig() # Pipeline instances (lazy loaded) self._inpaint_pipeline = None self._controlnet_canny = None self._controlnet_depth = None self._depth_estimator = None self._depth_processor = None # State tracking self.is_initialized = False self._current_conditioning_type = None self._last_seed = None self._cached_latents = None self._use_controlnet = True # Track if ControlNet is available # Reference to model manager (set by SceneWeaverCore) self._model_manager = None logger.info(f"InpaintingModule initialized on {self.device}") def _setup_device(self, device: str) -> str: """ Setup computation device. Parameters ---------- device : str Device specification or "auto" Returns ------- str Resolved device name """ if device == "auto": if torch.cuda.is_available(): return "cuda" elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available(): return "mps" return "cpu" return device def set_model_manager(self, manager: Any) -> None: """ Set reference to ModelManager for coordinated model lifecycle. Parameters ---------- manager : ModelManager The global model manager instance """ self._model_manager = manager logger.info("ModelManager reference set for InpaintingModule") def _memory_cleanup(self, aggressive: bool = False) -> None: """ Perform memory cleanup. Parameters ---------- aggressive : bool If True, perform multiple GC rounds and sync CUDA """ rounds = 5 if aggressive else 2 for _ in range(rounds): gc.collect() # On Hugging Face Spaces, avoid CUDA operations in main process # CUDA operations must only happen within @spaces.GPU decorated functions is_spaces = os.getenv('SPACE_ID') is not None if not is_spaces and torch.cuda.is_available(): torch.cuda.empty_cache() if aggressive: torch.cuda.ipc_collect() torch.cuda.synchronize() logger.debug(f"Memory cleanup completed (aggressive={aggressive}, spaces={is_spaces})") def _check_memory_status(self) -> Dict[str, float]: """ Check current GPU memory status. Returns ------- dict Memory statistics including allocated, total, and usage ratio """ # On Spaces, skip CUDA checks in main process is_spaces = os.getenv('SPACE_ID') is not None if is_spaces or not torch.cuda.is_available(): return {"available": True, "usage_ratio": 0.0} allocated = torch.cuda.memory_allocated() / 1024**3 total = torch.cuda.get_device_properties(0).total_memory / 1024**3 usage_ratio = allocated / total return { "allocated_gb": round(allocated, 2), "total_gb": round(total, 2), "free_gb": round(total - allocated, 2), "usage_ratio": round(usage_ratio, 3), "available": usage_ratio < 0.9 } def load_inpainting_pipeline( self, conditioning_type: str = "canny", progress_callback: Optional[Callable[[str, int], None]] = None ) -> Tuple[bool, str]: """ Load the ControlNet inpainting pipeline. Implements mutual exclusion with background generation pipeline. Only one pipeline can be loaded at a time. 

        Parameters
        ----------
        conditioning_type : str
            Type of ControlNet conditioning: "canny" or "depth"
        progress_callback : callable, optional
            Function(message, percentage) for progress updates

        Returns
        -------
        tuple
            (success: bool, error_message: str)
        """
        if self.is_initialized and self._current_conditioning_type == conditioning_type:
            logger.info(f"Inpainting pipeline already loaded with {conditioning_type}")
            return True, ""

        logger.info(f"Loading inpainting pipeline with {conditioning_type} conditioning...")

        try:
            self._memory_cleanup(aggressive=True)

            if progress_callback:
                progress_callback("Preparing to load inpainting models...", 5)

            # Unload existing pipeline if different conditioning type
            if self._inpaint_pipeline is not None:
                self._unload_pipeline()

            # Use ControlNet inpainting by default
            use_controlnet_inpaint = True
            logger.info("Using StableDiffusionXLControlNetInpaintPipeline")

            if progress_callback:
                progress_callback("Loading ControlNet model...", 20)

            # Load appropriate ControlNet
            dtype = torch.float16 if self.device == "cuda" else torch.float32
            controlnet = None

            if use_controlnet_inpaint:
                if conditioning_type == "canny":
                    controlnet = ControlNetModel.from_pretrained(
                        self.CONTROLNET_CANNY_MODEL,
                        torch_dtype=dtype,
                        use_safetensors=True
                    )
                    self._controlnet_canny = controlnet
                    logger.info("Loaded ControlNet Canny model")
                elif conditioning_type == "depth":
                    controlnet = ControlNetModel.from_pretrained(
                        self.CONTROLNET_DEPTH_MODEL,
                        torch_dtype=dtype,
                        use_safetensors=True
                    )
                    self._controlnet_depth = controlnet

                    # Load depth estimator
                    if progress_callback:
                        progress_callback("Loading depth estimation model...", 35)
                    self._load_depth_estimator()
                    logger.info("Loaded ControlNet Depth model")
                else:
                    raise ValueError(f"Unknown conditioning type: {conditioning_type}")
            else:
                # Skip ControlNet loading for fallback mode
                logger.info("Skipping ControlNet loading (fallback mode)")

            if progress_callback:
                progress_callback("Loading SDXL Inpainting pipeline...", 50)

            # Load the inpainting pipeline
            if use_controlnet_inpaint and controlnet is not None:
                self._inpaint_pipeline = StableDiffusionXLControlNetInpaintPipeline.from_pretrained(
                    self.BASE_MODEL,
                    controlnet=controlnet,
                    torch_dtype=dtype,
                    use_safetensors=True,
                    variant="fp16" if dtype == torch.float16 else None
                )
            else:
                # Fallback: Use dedicated inpainting model without ControlNet
                self._inpaint_pipeline = StableDiffusionXLInpaintPipeline.from_pretrained(
                    "diffusers/stable-diffusion-xl-1.0-inpainting-0.1",
                    torch_dtype=dtype,
                    use_safetensors=True,
                    variant="fp16" if dtype == torch.float16 else None
                )
                self._use_controlnet = False

            # Track ControlNet usage
            self._use_controlnet = use_controlnet_inpaint and controlnet is not None

            if progress_callback:
                progress_callback("Configuring scheduler...", 70)

            # Configure scheduler for faster generation
            self._inpaint_pipeline.scheduler = DPMSolverMultistepScheduler.from_config(
                self._inpaint_pipeline.scheduler.config
            )

            # Move to device
            self._inpaint_pipeline = self._inpaint_pipeline.to(self.device)

            if progress_callback:
                progress_callback("Applying optimizations...", 85)

            # Apply memory optimizations
            self._apply_pipeline_optimizations()

            # Set eval mode
            self._inpaint_pipeline.unet.eval()
            if hasattr(self._inpaint_pipeline, 'vae'):
                self._inpaint_pipeline.vae.eval()

            self.is_initialized = True
            self._current_conditioning_type = conditioning_type if self._use_controlnet else "none"

            if progress_callback:
                progress_callback("Inpainting pipeline ready!", 100)

            # Log memory status
            mem_status = self._check_memory_status()
            logger.info(f"Pipeline loaded. GPU memory: {mem_status.get('allocated_gb', 0):.1f}GB used")

            return True, ""

        except Exception as e:
            error_msg = str(e)
            logger.error(f"Failed to load inpainting pipeline: {error_msg}")
            traceback.print_exc()
            self._unload_pipeline()
            return False, error_msg

    def _load_depth_estimator(self) -> None:
        """
        Load depth estimation model with fallback strategy.

        Tries Depth-Anything first, falls back to MiDaS if unavailable.
        """
        try:
            logger.info(f"Attempting to load depth model: {self.DEPTH_MODEL_PRIMARY}")
            self._depth_processor = AutoImageProcessor.from_pretrained(
                self.DEPTH_MODEL_PRIMARY
            )
            self._depth_estimator = AutoModelForDepthEstimation.from_pretrained(
                self.DEPTH_MODEL_PRIMARY,
                torch_dtype=torch.float16 if self.device == "cuda" else torch.float32
            )
            self._depth_estimator.to(self.device)
            self._depth_estimator.eval()
            logger.info("Successfully loaded Depth-Anything model")
        except Exception as e:
            logger.warning(f"Primary depth model failed: {e}, trying fallback...")
            try:
                self._depth_processor = DPTImageProcessor.from_pretrained(
                    self.DEPTH_MODEL_FALLBACK
                )
                self._depth_estimator = DPTForDepthEstimation.from_pretrained(
                    self.DEPTH_MODEL_FALLBACK,
                    torch_dtype=torch.float16 if self.device == "cuda" else torch.float32
                )
                self._depth_estimator.to(self.device)
                self._depth_estimator.eval()
                logger.info("Successfully loaded MiDaS fallback model")
            except Exception as fallback_e:
                logger.error(f"Fallback depth model also failed: {fallback_e}")
                raise RuntimeError("Unable to load any depth estimation model")

    def _apply_pipeline_optimizations(self) -> None:
        """Apply memory and performance optimizations to the pipeline."""
        if self._inpaint_pipeline is None:
            return

        # Try xformers first
        try:
            self._inpaint_pipeline.enable_xformers_memory_efficient_attention()
            logger.info("Enabled xformers memory efficient attention")
        except Exception:
            try:
                self._inpaint_pipeline.enable_attention_slicing()
                logger.info("Enabled attention slicing")
            except Exception:
                logger.warning("No attention optimization available")

        # VAE optimizations
        if self.config.enable_vae_tiling:
            if hasattr(self._inpaint_pipeline, 'enable_vae_tiling'):
                self._inpaint_pipeline.enable_vae_tiling()
                logger.debug("Enabled VAE tiling")
            if hasattr(self._inpaint_pipeline, 'enable_vae_slicing'):
                self._inpaint_pipeline.enable_vae_slicing()
                logger.debug("Enabled VAE slicing")

    def _unload_pipeline(self) -> None:
        """Unload the inpainting pipeline and free memory."""
        logger.info("Unloading inpainting pipeline...")

        if self._inpaint_pipeline is not None:
            del self._inpaint_pipeline
            self._inpaint_pipeline = None

        if self._controlnet_canny is not None:
            del self._controlnet_canny
            self._controlnet_canny = None

        if self._controlnet_depth is not None:
            del self._controlnet_depth
            self._controlnet_depth = None

        if self._depth_estimator is not None:
            del self._depth_estimator
            self._depth_estimator = None

        if self._depth_processor is not None:
            del self._depth_processor
            self._depth_processor = None

        self.is_initialized = False
        self._current_conditioning_type = None
        self._cached_latents = None

        self._memory_cleanup(aggressive=True)
        logger.info("Inpainting pipeline unloaded")

    def prepare_control_image(
        self,
        image: Image.Image,
        mode: str = "canny",
        mask: Optional[Image.Image] = None,
        preserve_structure: bool = False
    ) -> Image.Image:
        """
        Generate ControlNet conditioning image.

        Parameters
        ----------
        image : PIL.Image
            Input image
        mode : str
            Conditioning mode: "canny" or "depth"
        mask : PIL.Image, optional
            If provided, can suppress edges in masked region
            (when preserve_structure=False).
        preserve_structure : bool
            If True, keep edges in masked region (for color change tasks).
            If False, suppress edges in masked region (for replacement/removal tasks).

        Returns
        -------
        PIL.Image
            Generated control image (edges or depth map)
        """
        logger.info(f"Preparing control image with mode: {mode}, preserve_structure: {preserve_structure}")

        # Convert to RGB if needed
        if image.mode != 'RGB':
            image = image.convert('RGB')

        img_array = np.array(image)

        if mode == "canny":
            canny_image = self._generate_canny_edges(img_array)

            # Mask-aware processing: suppress edges in masked region ONLY if not preserving structure
            if mask is not None and not preserve_structure:
                canny_array = np.array(canny_image)
                mask_array = np.array(mask.convert('L'))

                # In masked region, completely suppress Canny edges
                # This allows complete replacement/removal of the object
                mask_region = mask_array > 128  # White = masked area
                canny_array[mask_region] = 0

                canny_image = Image.fromarray(canny_array)
                logger.info("Suppressed edges in masked region for replacement/removal")
            elif preserve_structure:
                logger.info("Preserving edges in masked region for color change")

            return canny_image
        elif mode == "depth":
            return self._generate_depth_map(image)
        else:
            raise ValueError(f"Unknown control mode: {mode}")

    def _generate_canny_edges(self, img_array: np.ndarray) -> Image.Image:
        """
        Generate Canny edge detection image.

        Parameters
        ----------
        img_array : np.ndarray
            Input image as RGB numpy array

        Returns
        -------
        PIL.Image
            Edge map replicated to three channels (RGB) for ControlNet
        """
        # Convert to grayscale
        gray = cv2.cvtColor(img_array, cv2.COLOR_RGB2GRAY)

        # Apply Gaussian blur to reduce noise
        blurred = cv2.GaussianBlur(gray, (5, 5), 1.4)

        # Canny edge detection
        edges = cv2.Canny(
            blurred,
            self.config.canny_low_threshold,
            self.config.canny_high_threshold
        )

        # Convert to 3-channel for ControlNet
        edges_3ch = cv2.cvtColor(edges, cv2.COLOR_GRAY2RGB)

        logger.debug(f"Generated Canny edges with thresholds "
                     f"{self.config.canny_low_threshold}/{self.config.canny_high_threshold}")

        return Image.fromarray(edges_3ch)

    def _generate_depth_map(self, image: Image.Image) -> Image.Image:
        """
        Generate depth map using depth estimation model.
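
        The raw prediction is resized to the input resolution and min-max
        normalized to an 8-bit range before being replicated to three
        channels, i.e. roughly
        ``depth_norm = (d - d.min()) / (d.max() - d.min()) * 255``.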

        Parameters
        ----------
        image : PIL.Image
            Input RGB image

        Returns
        -------
        PIL.Image
            Depth map replicated to three channels (RGB) for ControlNet
        """
        if self._depth_estimator is None or self._depth_processor is None:
            raise RuntimeError("Depth estimator not loaded")

        # Preprocess
        inputs = self._depth_processor(images=image, return_tensors="pt")
        inputs = {k: v.to(self.device) for k, v in inputs.items()}

        # Inference
        with torch.no_grad():
            outputs = self._depth_estimator(**inputs)
            predicted_depth = outputs.predicted_depth

        # Interpolate to original size
        prediction = torch.nn.functional.interpolate(
            predicted_depth.unsqueeze(1),
            size=image.size[::-1],  # (H, W)
            mode="bicubic",
            align_corners=False
        )

        # Normalize to 0-255
        depth_array = prediction.squeeze().cpu().numpy()
        depth_min = depth_array.min()
        depth_max = depth_array.max()
        if depth_max - depth_min > 0:
            depth_normalized = ((depth_array - depth_min) / (depth_max - depth_min) * 255)
        else:
            depth_normalized = np.zeros_like(depth_array)

        depth_normalized = depth_normalized.astype(np.uint8)

        # Convert to 3-channel for ControlNet
        depth_3ch = cv2.cvtColor(depth_normalized, cv2.COLOR_GRAY2RGB)

        logger.debug(f"Generated depth map, range: {depth_min:.2f} - {depth_max:.2f}")

        return Image.fromarray(depth_3ch)

    def prepare_mask(
        self,
        mask: Image.Image,
        target_size: Tuple[int, int],
        feather_radius: Optional[int] = None
    ) -> Tuple[Image.Image, Dict[str, Any]]:
        """
        Prepare and validate mask for inpainting.

        Parameters
        ----------
        mask : PIL.Image
            Input mask (white = inpaint area)
        target_size : tuple
            Target (width, height) to match input image
        feather_radius : int, optional
            Feathering radius in pixels. Uses config default if None.

        Returns
        -------
        tuple
            (processed_mask, validation_info)

        Notes
        -----
        Out-of-range mask coverage is not raised as an exception; it is
        reported via ``validation_info["valid"]`` and ``validation_info["warning"]``.
        """
        feather = feather_radius if feather_radius is not None else self.config.feather_radius

        # Convert to grayscale
        if mask.mode != 'L':
            mask = mask.convert('L')

        # Resize to match target
        if mask.size != target_size:
            mask = mask.resize(target_size, Image.LANCZOS)

        # Convert to array for processing
        mask_array = np.array(mask)

        # Calculate coverage
        total_pixels = mask_array.size
        white_pixels = np.count_nonzero(mask_array > 127)
        coverage = white_pixels / total_pixels

        validation_info = {
            "coverage": coverage,
            "white_pixels": white_pixels,
            "total_pixels": total_pixels,
            "feather_radius": feather,
            "valid": True,
            "warning": ""
        }

        # Validate coverage
        if coverage < self.config.min_mask_coverage:
            validation_info["valid"] = False
            validation_info["warning"] = (
                f"Mask coverage too low ({coverage:.1%}). "
                f"Please select a larger area to inpaint."
            )
            logger.warning(f"Mask coverage {coverage:.1%} below minimum {self.config.min_mask_coverage:.1%}")
        elif coverage > self.config.max_mask_coverage:
            validation_info["valid"] = False
            validation_info["warning"] = (
                f"Mask coverage too high ({coverage:.1%}). "
                f"Consider using background generation instead."
            )
            logger.warning(f"Mask coverage {coverage:.1%} above maximum {self.config.max_mask_coverage:.1%}")

        # Apply feathering
        if feather > 0:
            mask_array = cv2.GaussianBlur(
                mask_array,
                (feather * 2 + 1, feather * 2 + 1),
                feather / 2
            )
            logger.debug(f"Applied {feather}px feathering to mask")

        processed_mask = Image.fromarray(mask_array, mode='L')

        return processed_mask, validation_info

    def enhance_prompt_for_inpainting(
        self,
        prompt: str,
        image: Image.Image,
        mask: Image.Image
    ) -> Tuple[str, str]:
        """
        Enhance prompt based on non-masked region analysis.

        Analyzes the surrounding context to generate appropriate lighting
        and color descriptors.
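
        For example, a prompt such as "a stone fountain" analyzed against a
        bright, warm-toned surrounding might become (illustrative output):

            "a stone fountain, bright, warm golden tones, high quality,
            detailed, photorealistic, seamless integration"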

        Parameters
        ----------
        prompt : str
            User-provided prompt
        image : PIL.Image
            Original image
        mask : PIL.Image
            Inpainting mask

        Returns
        -------
        tuple
            (enhanced_prompt, negative_prompt)
        """
        logger.info("Enhancing prompt for inpainting context...")

        # Convert to arrays
        img_array = np.array(image.convert('RGB'))
        mask_array = np.array(mask.convert('L'))

        # Analyze non-masked regions
        non_masked = mask_array < 127
        if not np.any(non_masked):
            # No context available
            enhanced_prompt = f"{prompt}, high quality, detailed, photorealistic"
            negative_prompt = self._get_inpainting_negative_prompt()
            return enhanced_prompt, negative_prompt

        # Extract context pixels
        context_pixels = img_array[non_masked]

        # Convert to Lab for analysis
        context_lab = cv2.cvtColor(
            context_pixels.reshape(-1, 1, 3),
            cv2.COLOR_RGB2LAB
        ).reshape(-1, 3)

        # Use robust statistics (median) to avoid outlier influence
        median_l = np.median(context_lab[:, 0])
        median_a = np.median(context_lab[:, 1])
        median_b = np.median(context_lab[:, 2])

        # Analyze lighting conditions
        lighting_descriptors = []

        if median_l > 170:
            lighting_descriptors.append("bright")
        elif median_l > 130:
            lighting_descriptors.append("well-lit")
        elif median_l > 80:
            lighting_descriptors.append("moderate lighting")
        else:
            lighting_descriptors.append("dim lighting")

        # Analyze color temperature (b channel: blue(-) to yellow(+))
        if median_b > 140:
            lighting_descriptors.append("warm golden tones")
        elif median_b > 120:
            lighting_descriptors.append("warm afternoon light")
        elif median_b < 110:
            lighting_descriptors.append("cool neutral tones")

        # Calculate saturation from context
        hsv = cv2.cvtColor(context_pixels.reshape(-1, 1, 3), cv2.COLOR_RGB2HSV)
        median_saturation = np.median(hsv[:, :, 1])

        if median_saturation > 150:
            lighting_descriptors.append("vibrant colors")
        elif median_saturation < 80:
            lighting_descriptors.append("subtle muted colors")

        # Build enhanced prompt
        lighting_desc = ", ".join(lighting_descriptors) if lighting_descriptors else ""
        quality_suffix = "high quality, detailed, photorealistic, seamless integration"

        if lighting_desc:
            enhanced_prompt = f"{prompt}, {lighting_desc}, {quality_suffix}"
        else:
            enhanced_prompt = f"{prompt}, {quality_suffix}"

        negative_prompt = self._get_inpainting_negative_prompt()

        logger.info(f"Enhanced prompt with context: {lighting_desc}")
        return enhanced_prompt, negative_prompt

    def _get_inpainting_negative_prompt(self) -> str:
        """Get standard negative prompt for inpainting."""
        return (
            "inconsistent lighting, wrong perspective, mismatched colors, "
            "visible seams, blending artifacts, color bleeding, "
            "blurry, low quality, distorted, deformed, "
            "harsh edges, unnatural transition"
        )

    def execute_inpainting(
        self,
        image: Image.Image,
        mask: Image.Image,
        prompt: str,
        preview_only: bool = False,
        seed: Optional[int] = None,
        progress_callback: Optional[Callable[[str, int], None]] = None,
        **kwargs
    ) -> InpaintingResult:
        """
        Execute the inpainting operation.

        Implements two-stage generation: fast preview followed by full
        quality generation if requested.
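
        A typical flow (illustrative; ``my_image``/``my_mask`` as in the
        class example) is to request a quick preview first, then rerun at
        full quality reusing the preview's seed:

        >>> preview = module.execute_inpainting(
        ...     my_image, my_mask, "a wooden bench", preview_only=True
        ... )
        >>> final = module.execute_inpainting(
        ...     my_image, my_mask, "a wooden bench",
        ...     seed=preview.metadata["seed"]
        ... )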

        Parameters
        ----------
        image : PIL.Image
            Original image to inpaint
        mask : PIL.Image
            Inpainting mask (white = area to regenerate)
        prompt : str
            Text description of desired content
        preview_only : bool
            If True, only generate preview (faster)
        seed : int, optional
            Random seed for reproducibility
        progress_callback : callable, optional
            Progress update function(message, percentage)
        **kwargs
            Additional parameters:
            - controlnet_conditioning_scale: float
            - feather_radius: int
            - num_inference_steps: int
            - guidance_scale: float
            - strength: float
            - preserve_structure_in_mask: bool
            - enhance_prompt: bool

        Returns
        -------
        InpaintingResult
            Result container with generated images and metadata
        """
        start_time = time.time()

        if not self.is_initialized:
            return InpaintingResult(
                success=False,
                error_message="Inpainting pipeline not initialized. Call load_inpainting_pipeline() first."
            )

        logger.info(f"Starting inpainting: prompt='{prompt[:50]}...', preview_only={preview_only}")

        try:
            # Update config with kwargs
            conditioning_scale = kwargs.get(
                'controlnet_conditioning_scale',
                self.config.controlnet_conditioning_scale
            )
            feather_radius = kwargs.get('feather_radius', self.config.feather_radius)
            strength = kwargs.get('strength', self.config.strength)
            preserve_structure = kwargs.get('preserve_structure_in_mask', False)

            if progress_callback:
                progress_callback("Preparing images...", 5)

            # Prepare image
            if image.mode != 'RGB':
                image = image.convert('RGB')

            # Ensure dimensions are multiple of 8
            width, height = image.size
            new_width = (width // 8) * 8
            new_height = (height // 8) * 8
            if new_width != width or new_height != height:
                image = image.resize((new_width, new_height), Image.LANCZOS)

            # Check and potentially reduce resolution for memory
            max_res = self.config.max_resolution
            if max(new_width, new_height) > max_res:
                scale = max_res / max(new_width, new_height)
                new_width = int(new_width * scale) // 8 * 8
                new_height = int(new_height * scale) // 8 * 8
                image = image.resize((new_width, new_height), Image.LANCZOS)
                logger.info(f"Reduced resolution to {new_width}x{new_height} for memory")

            # Prepare mask
            if progress_callback:
                progress_callback("Processing mask...", 10)

            processed_mask, mask_info = self.prepare_mask(
                mask,
                (new_width, new_height),
                feather_radius
            )

            if not mask_info["valid"]:
                return InpaintingResult(
                    success=False,
                    error_message=mask_info["warning"]
                )

            # Generate control image
            if progress_callback:
                progress_callback("Generating control image...", 20)

            control_image = self.prepare_control_image(
                image,
                self._current_conditioning_type,
                mask=processed_mask,
                preserve_structure=preserve_structure  # True for color change, False for replacement/removal
            )

            # Conditional prompt enhancement based on template
            # Check if we should enhance the prompt or use it directly
            should_enhance = kwargs.get('enhance_prompt', False)  # Default: no enhancement

            if should_enhance:
                if progress_callback:
                    progress_callback("Enhancing prompt...", 25)
                enhanced_prompt, negative_prompt = self.enhance_prompt_for_inpainting(
                    prompt, image, processed_mask
                )
                logger.info("Prompt enhanced with surrounding-context analysis")
            else:
                # Use prompt directly without enhancement
                enhanced_prompt = prompt
                negative_prompt = self._get_inpainting_negative_prompt()
                logger.info("Prompt enhancement disabled for this template")

            # Setup generator for reproducibility
            if seed is None:
                seed = int(time.time() * 1000) % (2**32)
            self._last_seed = seed

            generator = torch.Generator(device=self.device).manual_seed(seed)

            # Check if running on Hugging Face Spaces
            is_spaces = os.getenv('SPACE_ID') is not None

            # Stage 1: Preview generation
            # On Spaces, skip preview to save time (300s hard limit)
            preview_result = None
            if preview_only or not is_spaces:
                if progress_callback:
                    progress_callback("Generating preview...", 30)

                # Optimize preview steps for Hugging Face Spaces
                preview_steps = self.config.preview_steps
                if is_spaces:
                    # On Spaces, use minimal preview steps
                    preview_steps = min(preview_steps, 8)
                    logger.debug(f"Spaces environment - using {preview_steps} preview steps")

                preview_result = self._generate_inpaint(
                    image=image,
                    mask=processed_mask,
                    control_image=control_image,
                    prompt=enhanced_prompt,
                    negative_prompt=negative_prompt,
                    num_inference_steps=preview_steps,
                    guidance_scale=self.config.preview_guidance_scale,
                    controlnet_conditioning_scale=conditioning_scale,
                    strength=strength,
                    generator=generator
                )
            else:
                logger.debug("Spaces environment - skipping preview to fit 300s limit")

            if preview_only:
                generation_time = time.time() - start_time
                return InpaintingResult(
                    success=True,
                    preview_image=preview_result,
                    control_image=control_image,
                    generation_time=generation_time,
                    metadata={
                        "seed": seed,
                        "prompt": enhanced_prompt,
                        "conditioning_type": self._current_conditioning_type,
                        "conditioning_scale": conditioning_scale,
                        "preview_only": True
                    }
                )

            # Stage 2: Full quality generation
            if progress_callback:
                progress_callback("Generating full quality...", 60)

            # Use same seed for reproducibility
            generator = torch.Generator(device=self.device).manual_seed(seed)

            num_steps = kwargs.get('num_inference_steps', self.config.num_inference_steps)
            guidance = kwargs.get('guidance_scale', self.config.guidance_scale)

            # Optimize for Hugging Face Spaces ZeroGPU (stateless, 300s hard limit)
            if is_spaces:
                # ZeroGPU timing breakdown with model caching (actual measurements):
                # - Model loading from cache: ~60s (cached models, CPU to GPU transfer)
                # - Inference: ~28-29s/step (observed on shared H200)
                # - Blending & overhead: ~35s
                # - Platform limit: 300s hard limit (Pro tier)
                #
                # Strategy with unified 10-step approach:
                # - Skip preview completely (done above)
                # - Use 10 steps for balance of quality and speed
                # - Time budget: 60s (load) + 285s (10 steps) + 35s (blend) = 380s
                # - Note: Still may timeout, but parameter optimization is more important than step count
                # - Quality comes from correct conditioning_scale, not high step count
                spaces_max_steps = 10  # Optimized: 10 steps sufficient with proper parameters
                if num_steps > spaces_max_steps:
                    num_steps = spaces_max_steps
                    logger.debug(f"Spaces deployment: using {num_steps} steps (optimized for parameter quality)")

            full_result = self._generate_inpaint(
                image=image,
                mask=processed_mask,
                control_image=control_image,
                prompt=enhanced_prompt,
                negative_prompt=negative_prompt,
                num_inference_steps=num_steps,
                guidance_scale=guidance,
                controlnet_conditioning_scale=conditioning_scale,
                strength=strength,
                generator=generator
            )

            if progress_callback:
                progress_callback("Blending result...", 90)

            # Blend result
            blended = self.blend_result(image, full_result, processed_mask)

            generation_time = time.time() - start_time

            if progress_callback:
                progress_callback("Complete!", 100)

            return InpaintingResult(
                success=True,
                result_image=full_result,
                preview_image=preview_result,
                control_image=control_image,
                blended_image=blended,
                generation_time=generation_time,
                metadata={
                    "seed": seed,
                    "prompt": enhanced_prompt,
                    "negative_prompt": negative_prompt,
                    "conditioning_type": self._current_conditioning_type,
                    "conditioning_scale": conditioning_scale,
                    "strength": strength,
                    "preserve_structure": preserve_structure,
                    "num_inference_steps": num_steps,
                    "guidance_scale": guidance,
                    "feather_radius": feather_radius,
                    "mask_coverage": mask_info["coverage"],
                    "preview_only": False
                }
            )

        except torch.cuda.OutOfMemoryError:
            logger.error("CUDA out of memory during inpainting")
            self._memory_cleanup(aggressive=True)
            return InpaintingResult(
                success=False,
                error_message="GPU memory exhausted. Try reducing image size or closing other applications."
            )
        except Exception as e:
            logger.error(f"Inpainting failed: {e}")
            logger.error(traceback.format_exc())
            return InpaintingResult(
                success=False,
                error_message=f"Inpainting failed: {str(e)}"
            )

    def _generate_inpaint(
        self,
        image: Image.Image,
        mask: Image.Image,
        control_image: Image.Image,
        prompt: str,
        negative_prompt: str,
        num_inference_steps: int,
        guidance_scale: float,
        controlnet_conditioning_scale: float,
        strength: float,
        generator: torch.Generator
    ) -> Image.Image:
        """
        Internal method to run the inpainting pipeline.

        Supports both ControlNet and non-ControlNet pipelines.

        Parameters
        ----------
        image : PIL.Image
            Original image
        mask : PIL.Image
            Processed mask
        control_image : PIL.Image
            ControlNet conditioning image (ignored if ControlNet not available)
        prompt : str
            Enhanced prompt
        negative_prompt : str
            Negative prompt
        num_inference_steps : int
            Number of denoising steps
        guidance_scale : float
            Classifier-free guidance scale
        controlnet_conditioning_scale : float
            ControlNet influence strength (ignored if ControlNet not available)
        strength : float
            Inpainting strength (0.0-1.0). 1.0 = fully repaint masked area.
        generator : torch.Generator
            Random generator for reproducibility

        Returns
        -------
        PIL.Image
            Generated image
        """
        with torch.inference_mode():
            if self._use_controlnet:
                # Full ControlNet inpainting pipeline
                result = self._inpaint_pipeline(
                    prompt=prompt,
                    negative_prompt=negative_prompt,
                    image=image,
                    mask_image=mask,
                    control_image=control_image,
                    num_inference_steps=num_inference_steps,
                    guidance_scale=guidance_scale,
                    controlnet_conditioning_scale=controlnet_conditioning_scale,
                    strength=strength,
                    generator=generator
                )
            else:
                # Fallback: Standard SDXL inpainting without ControlNet
                result = self._inpaint_pipeline(
                    prompt=prompt,
                    negative_prompt=negative_prompt,
                    image=image,
                    mask_image=mask,
                    num_inference_steps=num_inference_steps,
                    guidance_scale=guidance_scale,
                    strength=strength,
                    generator=generator
                )

        return result.images[0]

    def blend_result(
        self,
        original: Image.Image,
        generated: Image.Image,
        mask: Image.Image
    ) -> Image.Image:
        """
        Blend generated content with original image.

        Uses linear color space blending for accurate results.
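
        Both inputs are converted from sRGB to linear light before alpha
        blending and converted back afterwards, i.e. per pixel:

            result_linear = generated_linear * alpha + original_linear * (1 - alpha)

        where ``alpha`` is the (feathered) mask scaled to [0, 1].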

        Parameters
        ----------
        original : PIL.Image
            Original image
        generated : PIL.Image
            Generated inpainted image
        mask : PIL.Image
            Blending mask (white = use generated)

        Returns
        -------
        PIL.Image
            Blended result
        """
        logger.info("Blending inpainting result...")

        # Ensure same size
        if generated.size != original.size:
            generated = generated.resize(original.size, Image.LANCZOS)
        if mask.size != original.size:
            mask = mask.resize(original.size, Image.LANCZOS)

        # Convert to arrays
        orig_array = np.array(original.convert('RGB')).astype(np.float32)
        gen_array = np.array(generated.convert('RGB')).astype(np.float32)
        mask_array = np.array(mask.convert('L')).astype(np.float32) / 255.0

        # sRGB to linear conversion
        def srgb_to_linear(img):
            img_norm = img / 255.0
            return np.where(
                img_norm <= 0.04045,
                img_norm / 12.92,
                np.power((img_norm + 0.055) / 1.055, 2.4)
            )

        def linear_to_srgb(img):
            img_clipped = np.clip(img, 0, 1)
            return np.where(
                img_clipped <= 0.0031308,
                12.92 * img_clipped,
                1.055 * np.power(img_clipped, 1/2.4) - 0.055
            )

        # Convert to linear space
        orig_linear = srgb_to_linear(orig_array)
        gen_linear = srgb_to_linear(gen_array)

        # Alpha blending in linear space
        alpha = mask_array[:, :, np.newaxis]
        result_linear = gen_linear * alpha + orig_linear * (1 - alpha)

        # Convert back to sRGB
        result_srgb = linear_to_srgb(result_linear)
        result_array = (result_srgb * 255).astype(np.uint8)

        logger.debug("Blending completed in linear color space")
        return Image.fromarray(result_array)

    def execute_with_auto_optimization(
        self,
        image: Image.Image,
        mask: Image.Image,
        prompt: str,
        quality_checker: Any,
        progress_callback: Optional[Callable[[str, int], None]] = None,
        **kwargs
    ) -> InpaintingResult:
        """
        Execute inpainting with automatic quality-based optimization.

        Retries with adjusted parameters if quality score is below threshold.
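
        On each retry the parameters are nudged from the quality report: a
        low edge-continuity score raises the feather radius and lowers the
        ControlNet conditioning scale, a low color-harmony score appends
        color-consistency phrasing to the prompt and raises the guidance
        scale, and optimization stops early once the score improves by less
        than five points between attempts.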

        Parameters
        ----------
        image : PIL.Image
            Original image
        mask : PIL.Image
            Inpainting mask
        prompt : str
            Text prompt
        quality_checker : QualityChecker
            Quality assessment instance
        progress_callback : callable, optional
            Progress update function
        **kwargs
            Additional inpainting parameters

        Returns
        -------
        InpaintingResult
            Best result achieved (may include retry information)
        """
        if not self.config.enable_auto_optimization:
            return self.execute_inpainting(
                image, mask, prompt,
                progress_callback=progress_callback,
                **kwargs
            )

        best_result = None
        best_score = 0.0
        retry_count = 0
        prev_score = 0.0
        quality_results: Dict[str, Any] = {}  # Keeps the adjustment step safe if no quality report was produced

        # Mutable parameters for optimization
        current_feather = kwargs.get('feather_radius', self.config.feather_radius)
        current_scale = kwargs.get(
            'controlnet_conditioning_scale',
            self.config.controlnet_conditioning_scale
        )
        current_guidance = kwargs.get('guidance_scale', self.config.guidance_scale)
        current_prompt = prompt

        while retry_count <= self.config.max_optimization_retries:
            if progress_callback and retry_count > 0:
                progress_callback(f"Optimizing (attempt {retry_count + 1})...", 5)

            # Execute inpainting
            result = self.execute_inpainting(
                image, mask, current_prompt,
                preview_only=False,
                feather_radius=current_feather,
                controlnet_conditioning_scale=current_scale,
                guidance_scale=current_guidance,
                progress_callback=progress_callback if retry_count == 0 else None,
                **{k: v for k, v in kwargs.items()
                   if k not in ['feather_radius', 'controlnet_conditioning_scale', 'guidance_scale']}
            )

            if not result.success:
                return result

            # Evaluate quality
            if result.blended_image is not None:
                quality_results = quality_checker.run_all_checks(
                    foreground=image,
                    background=result.result_image,
                    mask=mask,
                    combined=result.blended_image
                )
                quality_score = quality_results.get("overall_score", 0)
            else:
                quality_score = 50.0  # Default if no blended image

            result.quality_score = quality_score
            result.quality_details = quality_results if result.blended_image else {}
            result.retries = retry_count

            logger.info(f"Quality score: {quality_score:.1f} (attempt {retry_count + 1})")

            # Track best result
            if quality_score > best_score:
                best_score = quality_score
                best_result = result

            # Check if quality is acceptable
            if quality_score >= self.config.min_quality_score:
                logger.info(f"Quality threshold met: {quality_score:.1f}")
                return best_result

            # Check for minimal improvement (early termination)
            if retry_count > 0 and abs(quality_score - prev_score) < 5.0:
                logger.info("Minimal improvement, stopping optimization")
                return best_result

            prev_score = quality_score
            retry_count += 1

            if retry_count > self.config.max_optimization_retries:
                break

            # Adjust parameters based on quality issues
            checks = quality_results.get("checks", {})
            edge_score = checks.get("edge_continuity", {}).get("score", 100)
            harmony_score = checks.get("color_harmony", {}).get("score", 100)

            if edge_score < 60:
                # Edge issues: increase feathering, decrease control strength
                current_feather = min(20, current_feather + 3)
                current_scale = max(0.5, current_scale - 0.1)
                logger.debug(f"Adjusting for edges: feather={current_feather}, scale={current_scale}")

            if harmony_score < 60:
                # Color harmony issues: emphasize consistency in prompt
                if "color consistent" not in current_prompt.lower():
                    current_prompt = f"{current_prompt}, color consistent with surroundings, matching lighting"
                current_guidance = min(12.0, current_guidance + 1.0)
                logger.debug(f"Adjusting for harmony: guidance={current_guidance}")

            if edge_score < 60 and harmony_score < 60:
                # Both issues: stronger guidance
                current_guidance = min(12.0, current_guidance + 1.5)
logger.info(f"Optimization complete. Best score: {best_score:.1f}") return best_result def get_status(self) -> Dict[str, Any]: """ Get current module status. Returns ------- dict Status information including initialization state and memory usage """ status = { "initialized": self.is_initialized, "device": self.device, "conditioning_type": self._current_conditioning_type, "last_seed": self._last_seed, "config": { "controlnet_conditioning_scale": self.config.controlnet_conditioning_scale, "feather_radius": self.config.feather_radius, "num_inference_steps": self.config.num_inference_steps, "guidance_scale": self.config.guidance_scale } } status["memory"] = self._check_memory_status() return status