import torch
import numpy as np
import cv2
from PIL import Image
import logging
import gc
import time
from typing import Optional, Dict, Any, Tuple, List, Callable
from pathlib import Path
import warnings

warnings.filterwarnings("ignore")

from diffusers import StableDiffusionXLPipeline, DPMSolverMultistepScheduler
import open_clip
import traceback

from mask_generator import MaskGenerator
from image_blender import ImageBlender
from quality_checker import QualityChecker
from model_manager import get_model_manager, ModelPriority
from inpainting_module import InpaintingModule
from inpainting_templates import InpaintingTemplateManager

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)


class SceneWeaverCore:
    """
    SceneWeaver Core Engine - Facade for all AI generation subsystems.

    Integrates the SDXL pipeline, OpenCLIP analysis, mask generation, image
    blending, and inpainting functionality into a unified interface.

    Attributes:
        device: Computation device ("cuda", "mps", or "cpu")
        is_initialized: Whether models are loaded
        inpainting_module: Optional InpaintingModule instance

    Example:
        >>> core = SceneWeaverCore()
        >>> core.load_models()
        >>> result = core.generate_and_combine(image, prompt="sunset beach")
    """
    # Model registry names
    MODEL_SDXL_PIPELINE = "sdxl_background_pipeline"
    MODEL_OPENCLIP = "openclip_analyzer"
    MODEL_INPAINTING_PIPELINE = "inpainting_pipeline"

    # Style presets for the diversity generation mode
    STYLE_PRESETS = {
        "professional": {
            "name": "Professional Business",
            "modifier": "professional office environment, clean background, corporate setting, bright even lighting",
            "negative_extra": "casual, messy, cluttered",
            "guidance_scale": 8.0,
        },
        "casual": {
            "name": "Casual Lifestyle",
            "modifier": "casual outdoor setting, natural environment, relaxed atmosphere, warm natural lighting",
            "negative_extra": "formal, studio",
            "guidance_scale": 7.5,
        },
        "artistic": {
            "name": "Artistic Creative",
            "modifier": "artistic background, creative composition, vibrant colors, interesting lighting",
            "negative_extra": "boring, plain",
            "guidance_scale": 6.5,
        },
        "nature": {
            "name": "Natural Scenery",
            "modifier": "beautiful natural scenery, outdoor landscape, scenic view, natural lighting",
            "negative_extra": "urban, indoor",
            "guidance_scale": 7.5,
        },
    }
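
    # Illustrative note (not executed): a preset composes with the user's
    # prompt in generate_diversity_variants, roughly as
    #   styled_prompt   = f"{prompt}, {preset['modifier']}, high quality, detailed"
    #   styled_negative = f"{negative_prompt}, {preset['negative_extra']}, people, characters"
    # so "city skyline" under the "artistic" preset becomes
    #   "city skyline, artistic background, creative composition,
    #    vibrant colors, interesting lighting, high quality, detailed"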

    def __init__(self, device: str = "auto"):
        self.device = self._setup_device(device)

        # Model configurations - KEEP SAME FOR PERFECT GENERATION
        self.base_model_id = "stabilityai/stable-diffusion-xl-base-1.0"
        self.clip_model_name = "ViT-B-32"
        self.clip_pretrained = "openai"

        # Pipeline objects
        self.pipeline = None
        self.clip_model = None
        self.clip_preprocess = None
        self.clip_tokenizer = None
        self.is_initialized = False

        # Generation settings - KEEP SAME
        self.max_image_size = 1024
        self.default_steps = 25
        self.use_fp16 = True

        # Enhanced memory management
        self.generation_count = 0
        self.cleanup_frequency = 1  # Clean up on every generation
        self.max_history = 3  # Limit generation history

        # Initialize helper classes
        self.mask_generator = MaskGenerator(self.max_image_size)
        self.image_blender = ImageBlender()
        self.quality_checker = QualityChecker()

        # Model manager reference
        self._model_manager = get_model_manager()

        # Inpainting module (lazy loaded)
        self._inpainting_module = None
        self._inpainting_initialized = False
        self._last_inpainting_error = None  # Set when a mode switch fails

        # Current mode tracking
        self._current_mode = "background"  # "background" or "inpainting"

        logger.info(f"SceneWeaverCore initialized on {self.device}")

    def _setup_device(self, device: str) -> str:
        """Set up the computation device (ZeroGPU compatible)."""
        import os

        # On Hugging Face Spaces with ZeroGPU, initialize on CPU; the GPU is
        # allocated at runtime by the @spaces.GPU decorator.
        if os.getenv('SPACE_ID') is not None:
            logger.info("Running on Hugging Face Spaces - using CPU for initialization")
            return "cpu"

        if device == "auto":
            if torch.cuda.is_available():
                return "cuda"
            elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
                return "mps"
            else:
                return "cpu"
        return device

    def _ultra_memory_cleanup(self):
        """Aggressive memory cleanup for constrained environments (e.g. Colab)."""
        import os

        logger.debug("🧹 Ultra memory cleanup...")

        # Multiple rounds of garbage collection
        for _ in range(5):
            gc.collect()

        # On Hugging Face Spaces, skip CUDA operations in the main process
        is_spaces = os.getenv('SPACE_ID') is not None
        if not is_spaces and torch.cuda.is_available():
            # Clear all cached memory and force synchronization
            torch.cuda.empty_cache()
            torch.cuda.ipc_collect()
            torch.cuda.synchronize()

        logger.debug("✅ Ultra cleanup completed")

    def load_models(self, progress_callback: Optional[Callable] = None):
        """Load AI models - KEEP SAME FOR PERFECT GENERATION"""
        if self.is_initialized:
            logger.info("Models already loaded")
            return

        logger.info("📥 Loading AI models...")
        try:
            self._ultra_memory_cleanup()

            if progress_callback:
                progress_callback("Loading OpenCLIP for image understanding...", 20)

            # Load OpenCLIP - KEEP SAME
            self.clip_model, _, self.clip_preprocess = open_clip.create_model_and_transforms(
                self.clip_model_name,
                pretrained=self.clip_pretrained,
                device=self.device
            )
            self.clip_tokenizer = open_clip.get_tokenizer(self.clip_model_name)
            self.clip_model.eval()
            logger.info("✅ OpenCLIP loaded")

            if progress_callback:
                progress_callback("Loading SDXL text-to-image pipeline...", 60)

            # Load the standard SDXL text-to-image pipeline - KEEP SAME
            self.pipeline = StableDiffusionXLPipeline.from_pretrained(
                self.base_model_id,
                torch_dtype=torch.float16 if self.use_fp16 else torch.float32,
                use_safetensors=True,
                variant="fp16" if self.use_fp16 else None
            )

            # Use the DPM solver for faster generation - KEEP SAME
            self.pipeline.scheduler = DPMSolverMultistepScheduler.from_config(
                self.pipeline.scheduler.config
            )

            # Move to device
            self.pipeline = self.pipeline.to(self.device)

            if progress_callback:
                progress_callback("Applying optimizations...", 90)

            # Memory optimizations - ENHANCED
            try:
                self.pipeline.enable_xformers_memory_efficient_attention()
                logger.info("✅ xformers enabled")
            except Exception:
                try:
                    self.pipeline.enable_attention_slicing()
                    logger.info("✅ Attention slicing enabled")
                except Exception:
                    logger.warning("⚠️ No memory optimizations available")

            # Additional memory optimizations
            if hasattr(self.pipeline, 'enable_vae_tiling'):
                self.pipeline.enable_vae_tiling()
            if hasattr(self.pipeline, 'enable_vae_slicing'):
                self.pipeline.enable_vae_slicing()

            # Set to eval mode
            self.pipeline.unet.eval()
            if hasattr(self.pipeline, 'vae'):
                self.pipeline.vae.eval()

            # Enable sequential CPU offload when GPU memory is very low
            try:
                if torch.cuda.is_available():
                    free_memory = torch.cuda.get_device_properties(0).total_memory - torch.cuda.memory_allocated()
                    if free_memory < 4 * 1024**3:  # Less than 4 GB free
                        self.pipeline.enable_sequential_cpu_offload()
                        logger.info("✅ Sequential CPU offload enabled for low memory")
            except Exception:
                pass

            self.is_initialized = True

            if progress_callback:
                progress_callback("Models loaded successfully!", 100)

            # Memory status
            if torch.cuda.is_available():
                memory_used = torch.cuda.memory_allocated() / 1024**3
                memory_total = torch.cuda.get_device_properties(0).total_memory / 1024**3
                logger.info(f"📊 GPU Memory: {memory_used:.1f}GB / {memory_total:.1f}GB")

        except Exception as e:
            logger.error(f"❌ Model loading failed: {e}")
            raise RuntimeError(f"Failed to load models: {str(e)}") from e

    def analyze_image_with_clip(self, image: Image.Image) -> str:
        """Analyze the uploaded image using OpenCLIP - KEEP SAME"""
        if not self.clip_model:
            return "Image analysis not available"

        try:
            image_input = self.clip_preprocess(image).unsqueeze(0).to(self.device)

            categories = [
                "a photo of a person",
                "a photo of an animal",
                "a photo of an object",
                "a photo of a character",
                "a photo of a cartoon",
                "a photo of nature",
                "a photo of a building",
                "a photo of a landscape"
            ]
            text_inputs = self.clip_tokenizer(categories).to(self.device)

            with torch.no_grad():
                image_features = self.clip_model.encode_image(image_input)
                text_features = self.clip_model.encode_text(text_inputs)

                image_features /= image_features.norm(dim=-1, keepdim=True)
                text_features /= text_features.norm(dim=-1, keepdim=True)

                similarity = (100.0 * image_features @ text_features.T).softmax(dim=-1)
                best_match_idx = similarity.argmax().item()
                confidence = similarity[0, best_match_idx].item()

            category = categories[best_match_idx].replace("a photo of ", "")
            return f"Detected: {category} (confidence: {confidence:.1%})"
        except Exception as e:
            logger.error(f"CLIP analysis failed: {e}")
            return "Image analysis failed"

    def enhance_prompt(
        self,
        user_prompt: str,
        foreground_image: Image.Image
    ) -> str:
        """
        Smart prompt enhancement based on image analysis.

        Adds appropriate lighting, atmosphere, and quality descriptors.

        Args:
            user_prompt: Original user-provided prompt
            foreground_image: Foreground image for analysis

        Returns:
            Enhanced prompt string
        """
        logger.info("✨ Enhancing prompt based on image analysis...")

        try:
            # Analyze image characteristics
            img_array = np.array(foreground_image.convert('RGB'))

            # Convert to LAB to analyze color temperature
            lab = cv2.cvtColor(img_array, cv2.COLOR_RGB2LAB)
            avg_a = np.mean(lab[:, :, 1])  # a channel: green (-) to red (+)
            avg_b = np.mean(lab[:, :, 2])  # b channel: blue (-) to yellow (+)

            # Determine warm/cool tone. OpenCV stores 8-bit LAB with the a/b
            # channels offset by 128, so b > 128 means more yellow (warm).
            is_warm = avg_b > 128

            # Analyze brightness
            gray = cv2.cvtColor(img_array, cv2.COLOR_RGB2GRAY)
            avg_brightness = np.mean(gray)
            is_bright = avg_brightness > 127

            # Get the subject type from CLIP
            clip_analysis = self.analyze_image_with_clip(foreground_image)
            subject_type = "unknown"
            if "person" in clip_analysis.lower():
                subject_type = "person"
            elif "animal" in clip_analysis.lower():
                subject_type = "animal"
            elif "object" in clip_analysis.lower():
                subject_type = "object"
            elif "character" in clip_analysis.lower() or "cartoon" in clip_analysis.lower():
                subject_type = "character"
            elif "nature" in clip_analysis.lower() or "landscape" in clip_analysis.lower():
                subject_type = "nature"

            # Prompt fragment library
            lighting_options = {
                "warm_bright": "warm golden hour lighting, soft natural light",
                "warm_dark": "warm ambient lighting, cozy atmosphere",
                "cool_bright": "bright daylight, clear sky lighting",
                "cool_dark": "soft diffused light, gentle shadows"
            }
            atmosphere_options = {
                "person": "professional, elegant composition",
                "animal": "natural, harmonious setting",
                "object": "clean product photography style",
                "character": "artistic, vibrant, imaginative",
                "nature": "scenic, peaceful atmosphere",
                "unknown": "balanced composition"
            }
            quality_modifiers = "high quality, detailed, sharp focus, photorealistic"

            # Select lighting based on color temperature and brightness
            if is_warm and is_bright:
                lighting = lighting_options["warm_bright"]
            elif is_warm and not is_bright:
                lighting = lighting_options["warm_dark"]
            elif not is_warm and is_bright:
                lighting = lighting_options["cool_bright"]
            else:
                lighting = lighting_options["cool_dark"]

            # Atmosphere based on subject type
            atmosphere = atmosphere_options.get(subject_type, atmosphere_options["unknown"])

            # Avoid adding descriptions that conflict with the user prompt
            user_prompt_lower = user_prompt.lower()
            if "sunset" in user_prompt_lower or "golden" in user_prompt_lower:
                lighting = ""  # User already specified lighting
            if "dark" in user_prompt_lower or "night" in user_prompt_lower:
                lighting = lighting.replace("bright", "").replace("daylight", "")

            # Combine into the enhanced prompt
            fragments = [user_prompt]
            if lighting:
                fragments.append(lighting)
            if atmosphere:
                fragments.append(atmosphere)
            fragments.append(quality_modifiers)

            enhanced_prompt = ", ".join(filter(None, fragments))

            logger.info(f"📝 Original prompt: {user_prompt[:50]}...")
            logger.info(f"📝 Enhanced prompt: {enhanced_prompt[:80]}...")

            return enhanced_prompt

        except Exception as e:
            logger.warning(f"⚠️ Prompt enhancement failed: {e}, using original prompt")
            return user_prompt
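
    # Worked example (illustrative): for a warm, bright portrait with the user
    # prompt "sunset beach", the "sunset" keyword suppresses the lighting
    # fragment, so the result is roughly
    #   "sunset beach, professional, elegant composition,
    #    high quality, detailed, sharp focus, photorealistic"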

    def _prepare_image(self, image: Image.Image) -> Image.Image:
        """Prepare an image for processing - KEEP SAME"""
        # Convert to RGB
        if image.mode != 'RGB':
            image = image.convert('RGB')

        # Resize if too large
        width, height = image.size
        max_size = self.max_image_size
        if width > max_size or height > max_size:
            ratio = min(max_size / width, max_size / height)
            new_width = int(width * ratio)
            new_height = int(height * ratio)
            image = image.resize((new_width, new_height), Image.LANCZOS)

        # Ensure dimensions are multiples of 8 (required by the SDXL VAE)
        width, height = image.size
        new_width = (width // 8) * 8
        new_height = (height // 8) * 8
        if new_width != width or new_height != height:
            image = image.resize((new_width, new_height), Image.LANCZOS)

        return image
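
    # Sizing example (illustrative): a 3000x2000 input is scaled by
    # min(1024/3000, 1024/2000) ≈ 0.341 to 1024x682, then snapped down to the
    # nearest multiples of 8, giving 1024x680.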

    def generate_background(
        self,
        prompt: str,
        width: int,
        height: int,
        negative_prompt: str = "blurry, low quality, distorted",
        num_inference_steps: int = 25,
        guidance_scale: float = 7.5,
        progress_callback: Optional[Callable] = None
    ) -> Image.Image:
        """Generate a complete background using standard text-to-image - KEEP SAME"""
        if not self.is_initialized:
            raise RuntimeError("Models not loaded. Call load_models() first.")

        logger.info(f"🎨 Generating background: {prompt[:50]}...")

        try:
            with torch.inference_mode():
                if progress_callback:
                    progress_callback("Generating background with SDXL...", 50)

                # Standard text-to-image generation - KEEP SAME
                # The fixed seed keeps results reproducible across runs.
                result = self.pipeline(
                    prompt=prompt,
                    negative_prompt=negative_prompt,
                    width=width,
                    height=height,
                    num_inference_steps=num_inference_steps,
                    guidance_scale=guidance_scale,
                    generator=torch.Generator(device=self.device).manual_seed(42)
                )

                generated_image = result.images[0]

            if progress_callback:
                progress_callback("Background generated successfully!", 100)

            logger.info("✅ Background generation completed!")
            return generated_image

        except torch.cuda.OutOfMemoryError:
            logger.error("❌ GPU memory exhausted")
            self._ultra_memory_cleanup()
            raise RuntimeError("GPU memory insufficient")
        except Exception as e:
            logger.error(f"❌ Background generation failed: {e}")
            raise RuntimeError(f"Generation failed: {str(e)}") from e

    def generate_and_combine(
        self,
        original_image: Image.Image,
        prompt: str,
        combination_mode: str = "center",
        focus_mode: str = "person",
        negative_prompt: str = "blurry, low quality, distorted",
        num_inference_steps: int = 25,
        guidance_scale: float = 7.5,
        progress_callback: Optional[Callable] = None,
        enable_prompt_enhancement: bool = True
    ) -> Dict[str, Any]:
        """
        Generate a background and combine it with the foreground using advanced blending.

        Args:
            original_image: Foreground image
            prompt: User's background description
            combination_mode: How to position the foreground ("center", "left_half", "right_half", "full")
            focus_mode: Focus type ("person" for a tight crop, "scene" for wider context)
            negative_prompt: What to avoid in generation
            num_inference_steps: SDXL inference steps
            guidance_scale: Classifier-free guidance scale
            progress_callback: Progress reporting callback
            enable_prompt_enhancement: Whether to use smart prompt enhancement

        Returns:
            Dictionary containing results and metadata
        """
        if not self.is_initialized:
            raise RuntimeError("Models not loaded. Call load_models() first.")

        logger.info("🎨 Starting generation and combination with advanced features...")

        try:
            # Enhanced memory management
            if self.generation_count % self.cleanup_frequency == 0:
                self._ultra_memory_cleanup()

            if progress_callback:
                progress_callback("Analyzing uploaded image...", 5)

            # Analyze the original image
            image_analysis = self.analyze_image_with_clip(original_image)

            if progress_callback:
                progress_callback("Preparing images...", 10)

            # Prepare the original image
            processed_original = self._prepare_image(original_image)
            target_width, target_height = processed_original.size

            if progress_callback:
                progress_callback("Optimizing prompt...", 15)

            # Smart prompt enhancement
            if enable_prompt_enhancement:
                enhanced_prompt = self.enhance_prompt(prompt, processed_original)
            else:
                enhanced_prompt = f"{prompt}, high quality, detailed, photorealistic, beautiful scenery"
            enhanced_negative = f"{negative_prompt}, people, characters, cartoons, logos"

            if progress_callback:
                progress_callback("Generating complete background scene...", 25)

            def bg_progress(msg, pct):
                if progress_callback:
                    # Map background progress (0-100) into the 25-75 range
                    progress_callback(f"Background: {msg}", 25 + (pct / 100) * 50)

            generated_background = self.generate_background(
                prompt=enhanced_prompt,
                width=target_width,
                height=target_height,
                negative_prompt=enhanced_negative,
                num_inference_steps=num_inference_steps,
                guidance_scale=guidance_scale,
                progress_callback=bg_progress
            )

            if progress_callback:
                progress_callback("Creating intelligent mask for person detection...", 80)

            # Intelligent mask generation with enhanced logging
            logger.info("🎭 Starting intelligent mask generation...")
            combination_mask = self.mask_generator.create_gradient_based_mask(
                processed_original,
                combination_mode,
                focus_mode
            )

            # Log mask quality for debugging
            try:
                mask_array = np.array(combination_mask)
                logger.info(
                    f"📊 Generated mask stats - Mean: {mask_array.mean():.1f}, "
                    f"Non-zero pixels: {np.count_nonzero(mask_array)}"
                )
            except Exception as mask_debug_error:
                logger.warning(f"⚠️ Mask debug logging failed: {mask_debug_error}")

            if progress_callback:
                progress_callback("Advanced image blending...", 90)

            # Advanced image blending with logging
            logger.info("🖌️ Starting advanced image blending...")
            combined_image = self.image_blender.simple_blend_images(
                processed_original,
                generated_background,
                combination_mask
            )
            logger.info("✅ Image blending completed successfully")

            if progress_callback:
                progress_callback("Creating debug images...", 95)

            # Generate debug images
            debug_images = self.image_blender.create_debug_images(
                processed_original,
                generated_background,
                combination_mask,
                combined_image
            )

            # Memory cleanup after generation
            self._ultra_memory_cleanup()

            # Update the generation count
            self.generation_count += 1

            if progress_callback:
                progress_callback("Generation complete!", 100)

            logger.info("✅ Complete generation and combination successful!")

            return {
                "combined_image": combined_image,
                "generated_scene": generated_background,
                "original_image": processed_original,
                "combination_mask": combination_mask,
                "debug_mask_gray": debug_images["mask_gray"],
                "debug_alpha_heatmap": debug_images["alpha_heatmap"],
                "image_analysis": image_analysis,
                "enhanced_prompt": enhanced_prompt,
                "original_prompt": prompt,
                "success": True,
                "generation_count": self.generation_count
            }

        except Exception as e:
            error_traceback = traceback.format_exc()
            logger.error(f"❌ Generation and combination failed: {str(e)}")
            logger.error(f"📍 Full traceback:\n{error_traceback}")

            self._ultra_memory_cleanup()  # Clean up on error too

            return {
                "success": False,
                "error": f"Failed: {str(e)}"
            }
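
    # Usage sketch (illustrative):
    #   result = core.generate_and_combine(Image.open("portrait.png"), "sunset beach")
    #   if result["success"]:
    #       result["combined_image"].save("composite.png")
    #   else:
    #       print(result["error"])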

    def generate_diversity_variants(
        self,
        original_image: Image.Image,
        prompt: str,
        selected_styles: Optional[List[str]] = None,
        combination_mode: str = "center",
        focus_mode: str = "person",
        negative_prompt: str = "blurry, low quality, distorted",
        progress_callback: Optional[Callable] = None
    ) -> Dict[str, Any]:
        """
        Generate multiple style variants of the background.

        Uses reduced quality for faster preview generation.

        Args:
            original_image: Foreground image
            prompt: Base background description
            selected_styles: List of style keys to use (None = all styles)
            combination_mode: Foreground positioning mode
            focus_mode: Focus type for mask generation
            negative_prompt: Base negative prompt
            progress_callback: Progress callback function

        Returns:
            Dictionary containing variants and metadata
        """
        if not self.is_initialized:
            raise RuntimeError("Models not loaded. Call load_models() first.")

        logger.info("🎨 Starting diversity generation mode...")

        # Determine which styles to generate
        styles_to_generate = selected_styles or list(self.STYLE_PRESETS.keys())
        num_styles = len(styles_to_generate)

        results = {
            "variants": [],
            "success": True,
            "num_variants": 0
        }

        try:
            # Pre-process the image once
            processed_original = self._prepare_image(original_image)
            target_width, target_height = processed_original.size

            # Reduce resolution for faster generation
            preview_size = min(768, max(target_width, target_height))
            scale = preview_size / max(target_width, target_height)
            preview_width = int(target_width * scale) // 8 * 8
            preview_height = int(target_height * scale) // 8 * 8

            # Generate the mask once (reused for all variants)
            if progress_callback:
                progress_callback("Creating foreground mask...", 5)

            combination_mask = self.mask_generator.create_gradient_based_mask(
                processed_original, combination_mode, focus_mode
            )

            # Resize the mask and foreground to the preview resolution
            preview_mask = combination_mask.resize((preview_width, preview_height), Image.LANCZOS)
            preview_original = processed_original.resize((preview_width, preview_height), Image.LANCZOS)

            # Generate each style variant
            for idx, style_key in enumerate(styles_to_generate):
                if style_key not in self.STYLE_PRESETS:
                    logger.warning(f"⚠️ Unknown style: {style_key}, skipping")
                    continue

                style = self.STYLE_PRESETS[style_key]
                style_name = style["name"]

                if progress_callback:
                    base_pct = 10 + (idx / num_styles) * 80
                    progress_callback(f"Generating {style_name} variant...", int(base_pct))

                logger.info(f"🎨 Generating variant: {style_name}")

                try:
                    # Build the style-specific prompt
                    styled_prompt = f"{prompt}, {style['modifier']}, high quality, detailed"
                    styled_negative = f"{negative_prompt}, {style['negative_extra']}, people, characters"

                    # Generate the background with reduced steps for speed
                    background = self.generate_background(
                        prompt=styled_prompt,
                        width=preview_width,
                        height=preview_height,
                        negative_prompt=styled_negative,
                        num_inference_steps=15,  # Reduced for speed
                        guidance_scale=style["guidance_scale"]
                    )

                    # Blend images
                    combined = self.image_blender.simple_blend_images(
                        preview_original,
                        background,
                        preview_mask,
                        use_multi_scale=False  # Skipped for speed
                    )

                    results["variants"].append({
                        "style_key": style_key,
                        "style_name": style_name,
                        "combined_image": combined,
                        "background": background,
                        "prompt_used": styled_prompt
                    })

                    # Memory cleanup between variants
                    self._ultra_memory_cleanup()

                except Exception as variant_error:
                    logger.error(f"❌ Failed to generate {style_name} variant: {variant_error}")
                    continue

            results["num_variants"] = len(results["variants"])

            if progress_callback:
                progress_callback("Diversity generation complete!", 100)

            logger.info(f"✅ Generated {results['num_variants']} style variants")
            return results

        except Exception as e:
            logger.error(f"❌ Diversity generation failed: {e}")
            self._ultra_memory_cleanup()
            return {
                "variants": [],
                "success": False,
                "error": str(e),
                "num_variants": 0
            }
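
    # Usage sketch (illustrative): preview two styles, then render the
    # favorite at full quality via regenerate_high_quality.
    #   previews = core.generate_diversity_variants(img, "city park", ["casual", "nature"])
    #   for v in previews["variants"]:
    #       v["combined_image"].save(f"preview_{v['style_key']}.png")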

    def regenerate_high_quality(
        self,
        original_image: Image.Image,
        prompt: str,
        style_key: str,
        combination_mode: str = "center",
        focus_mode: str = "person",
        negative_prompt: str = "blurry, low quality, distorted",
        progress_callback: Optional[Callable] = None
    ) -> Dict[str, Any]:
        """
        Regenerate a specific style at full quality.

        Args:
            original_image: Original foreground image
            prompt: Base prompt
            style_key: Style preset key to use
            combination_mode: Foreground positioning
            focus_mode: Mask focus mode
            negative_prompt: Base negative prompt
            progress_callback: Progress callback

        Returns:
            Full-quality result dictionary
        """
        if style_key not in self.STYLE_PRESETS:
            return {"success": False, "error": f"Unknown style: {style_key}"}

        style = self.STYLE_PRESETS[style_key]

        # Build the styled prompt
        styled_prompt = f"{prompt}, {style['modifier']}"
        styled_negative = f"{negative_prompt}, {style['negative_extra']}"

        # Delegate to the full generate_and_combine with style parameters
        return self.generate_and_combine(
            original_image=original_image,
            prompt=styled_prompt,
            combination_mode=combination_mode,
            focus_mode=focus_mode,
            negative_prompt=styled_negative,
            num_inference_steps=25,  # Full quality
            guidance_scale=style["guidance_scale"],
            progress_callback=progress_callback,
            enable_prompt_enhancement=True
        )

    def get_memory_status(self) -> Dict[str, Any]:
        """Enhanced memory status reporting."""
        status = {"device": self.device}

        if torch.cuda.is_available():
            allocated = torch.cuda.memory_allocated() / 1024**3
            total = torch.cuda.get_device_properties(0).total_memory / 1024**3
            cached = torch.cuda.memory_reserved() / 1024**3

            status.update({
                "gpu_allocated_gb": round(allocated, 2),
                "gpu_total_gb": round(total, 2),
                "gpu_cached_gb": round(cached, 2),
                "gpu_free_gb": round(total - allocated, 2),
                "gpu_usage_percent": round((allocated / total) * 100, 1),
                "generation_count": self.generation_count
            })

        return status
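
    # Example return value on a CUDA device (illustrative numbers only):
    #   {"device": "cuda", "gpu_allocated_gb": 7.84, "gpu_total_gb": 16.0,
    #    "gpu_cached_gb": 8.12, "gpu_free_gb": 8.16,
    #    "gpu_usage_percent": 49.0, "generation_count": 3}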

    # INPAINTING FACADE METHODS

    def get_inpainting_module(self):
        """
        Get or create the InpaintingModule instance.

        Implements lazy loading - the module is only created when first accessed.

        Returns
        -------
        InpaintingModule
            The inpainting module instance
        """
        if self._inpainting_module is None:
            self._inpainting_module = InpaintingModule(device=self.device)
            self._inpainting_module.set_model_manager(self._model_manager)
            logger.info("InpaintingModule created (lazy load)")
        return self._inpainting_module

    def switch_to_inpainting_mode(
        self,
        conditioning_type: str = "canny",
        progress_callback: Optional[Callable[[str, int], None]] = None
    ) -> bool:
        """
        Switch to inpainting mode, unloading the background pipeline.

        Implements mutual exclusion between pipelines to conserve memory.

        Parameters
        ----------
        conditioning_type : str
            ControlNet conditioning type: "canny" or "depth"
        progress_callback : callable, optional
            Progress update function(message, percentage)

        Returns
        -------
        bool
            True if the switch was successful
        """
        logger.info(f"Switching to inpainting mode (conditioning: {conditioning_type})")

        try:
            # Unload the background pipeline first
            if self.pipeline is not None:
                if progress_callback:
                    progress_callback("Unloading background pipeline...", 10)
                del self.pipeline
                self.pipeline = None
                self._ultra_memory_cleanup()
                logger.info("Background pipeline unloaded")

            # Load the inpainting pipeline
            if progress_callback:
                progress_callback("Loading inpainting pipeline...", 20)

            inpaint_module = self.get_inpainting_module()

            def inpaint_progress(msg, pct):
                if progress_callback:
                    # Map inpainting progress (0-100) to (20-90)
                    mapped_pct = 20 + int(pct * 0.7)
                    progress_callback(msg, mapped_pct)

            success, error_msg = inpaint_module.load_inpainting_pipeline(
                conditioning_type=conditioning_type,
                progress_callback=inpaint_progress
            )

            if success:
                self._current_mode = "inpainting"
                self._inpainting_initialized = True
                if progress_callback:
                    progress_callback("Inpainting mode ready!", 100)
                logger.info("Successfully switched to inpainting mode")
            else:
                self._last_inpainting_error = error_msg
                logger.error(f"Failed to load inpainting pipeline: {error_msg}")

            return success

        except Exception as e:
            traceback.print_exc()
            self._last_inpainting_error = str(e)
            logger.error(f"Failed to switch to inpainting mode: {e}")
            if progress_callback:
                progress_callback(f"Error: {str(e)}", 0)
            return False
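
    # Mode-switch sketch (illustrative): the two pipelines are mutually
    # exclusive, so a typical inpainting session looks like
    #   if core.switch_to_inpainting_mode("canny"):
    #       result = core.execute_inpainting(image, mask, "a wooden table")
    #   core.switch_to_background_mode()  # restore text-to-image generation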

    def switch_to_background_mode(
        self,
        progress_callback: Optional[Callable[[str, int], None]] = None
    ) -> bool:
        """
        Switch back to background generation mode.

        Parameters
        ----------
        progress_callback : callable, optional
            Progress update function

        Returns
        -------
        bool
            True if the switch was successful
        """
        logger.info("Switching to background generation mode")

        try:
            # Unload the inpainting pipeline
            if self._inpainting_module is not None and self._inpainting_module.is_initialized:
                if progress_callback:
                    progress_callback("Unloading inpainting pipeline...", 10)
                self._inpainting_module._unload_pipeline()
                self._ultra_memory_cleanup()

            # Reload the background pipeline
            if progress_callback:
                progress_callback("Loading background pipeline...", 30)

            # Reset the initialization flag to force a reload
            self.is_initialized = False
            self.load_models(progress_callback=progress_callback)

            self._current_mode = "background"

            if progress_callback:
                progress_callback("Background mode ready!", 100)

            return True

        except Exception as e:
            logger.error(f"Failed to switch to background mode: {e}")
            return False

    def execute_inpainting(
        self,
        image: Image.Image,
        mask: Image.Image,
        prompt: str,
        preview_only: bool = False,
        template_key: Optional[str] = None,
        progress_callback: Optional[Callable[[str, int], None]] = None,
        **kwargs
    ) -> Dict[str, Any]:
        """
        Execute an inpainting operation through the Facade.

        This is the main entry point for inpainting functionality.

        Parameters
        ----------
        image : PIL.Image
            Original image to inpaint
        mask : PIL.Image
            Inpainting mask (white = area to regenerate)
        prompt : str
            Text description of the desired content
        preview_only : bool
            If True, generate a quick preview only
        template_key : str, optional
            Inpainting template key to use
        progress_callback : callable, optional
            Progress update function
        **kwargs
            Additional inpainting parameters

        Returns
        -------
        dict
            Result dictionary with images and metadata
        """
        # Ensure inpainting mode is active
        if self._current_mode != "inpainting" or not self._inpainting_initialized:
            conditioning = kwargs.get('conditioning_type', 'canny')
            if not self.switch_to_inpainting_mode(conditioning, progress_callback):
                error_detail = getattr(self, '_last_inpainting_error', 'Unknown error')
                return {
                    "success": False,
                    "error": f"Failed to initialize inpainting mode: {error_detail}"
                }

        inpaint_module = self.get_inpainting_module()

        # Apply the template if specified
        if template_key:
            template_mgr = InpaintingTemplateManager()
            template = template_mgr.get_template(template_key)
            if template:
                # Build the prompt from the template
                prompt = template_mgr.build_prompt(template_key, prompt)
                # Apply template parameters as defaults (explicit kwargs win)
                params = template_mgr.get_parameters_for_template(template_key)
                for key, value in params.items():
                    if key not in kwargs:
                        kwargs[key] = value
                # Pass the enhance_prompt flag to the inpainting module
                if 'enhance_prompt' not in kwargs:
                    kwargs['enhance_prompt'] = template.enhance_prompt

        # Execute inpainting
        result = inpaint_module.execute_inpainting(
            image=image,
            mask=mask,
            prompt=prompt,
            preview_only=preview_only,
            progress_callback=progress_callback,
            template_key=template_key,  # For conditional prompt enhancement
            **kwargs
        )

        # Convert the InpaintingResult to dictionary format
        return {
            "success": result.success,
            "combined_image": result.blended_image or result.result_image,
            "generated_image": result.result_image,
            "preview_image": result.preview_image,
            "control_image": result.control_image,
            "original_image": image,
            "mask": mask,
            "quality_score": result.quality_score,
            "generation_time": result.generation_time,
            "metadata": result.metadata,
            "error": result.error_message if not result.success else None
        }
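
    # Usage sketch (illustrative): a white-on-black mask marks the region to
    # regenerate; everything else is preserved.
    #   mask = Image.new("L", image.size, 0)  # start fully preserved
    #   # ... paint the target region white (255) ...
    #   out = core.execute_inpainting(image, mask, "a vase of tulips")
    #   if out["success"]:
    #       out["combined_image"].save("inpainted.png")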

    def execute_inpainting_with_optimization(
        self,
        image: Image.Image,
        mask: Image.Image,
        prompt: str,
        progress_callback: Optional[Callable[[str, int], None]] = None,
        **kwargs
    ) -> Dict[str, Any]:
        """
        Execute inpainting with automatic quality optimization.

        Retries with adjusted parameters if quality is below the threshold.

        Parameters
        ----------
        image : PIL.Image
            Original image
        mask : PIL.Image
            Inpainting mask
        prompt : str
            Text prompt
        progress_callback : callable, optional
            Progress callback
        **kwargs
            Additional parameters

        Returns
        -------
        dict
            Optimized result dictionary
        """
        # Ensure inpainting mode is active
        if self._current_mode != "inpainting" or not self._inpainting_initialized:
            conditioning = kwargs.get('conditioning_type', 'canny')
            if not self.switch_to_inpainting_mode(conditioning, progress_callback):
                error_detail = getattr(self, '_last_inpainting_error', 'Unknown error')
                return {
                    "success": False,
                    "error": f"Failed to initialize inpainting mode: {error_detail}"
                }

        inpaint_module = self.get_inpainting_module()

        result = inpaint_module.execute_with_auto_optimization(
            image=image,
            mask=mask,
            prompt=prompt,
            quality_checker=self.quality_checker,
            progress_callback=progress_callback,
            **kwargs
        )

        return {
            "success": result.success,
            "combined_image": result.blended_image or result.result_image,
            "generated_image": result.result_image,
            "preview_image": result.preview_image,
            "control_image": result.control_image,
            "quality_score": result.quality_score,
            "quality_details": result.quality_details,
            "retries": result.retries,
            "generation_time": result.generation_time,
            "metadata": result.metadata,
            "error": result.error_message if not result.success else None
        }

    def get_current_mode(self) -> str:
        """
        Get the current operation mode.

        Returns
        -------
        str
            "background" or "inpainting"
        """
        return self._current_mode

    def is_inpainting_ready(self) -> bool:
        """
        Check whether inpainting is ready to use.

        Returns
        -------
        bool
            True if the inpainting module is loaded and ready
        """
        return (
            self._inpainting_module is not None and
            self._inpainting_module.is_initialized
        )

    def get_inpainting_status(self) -> Dict[str, Any]:
        """
        Get the inpainting module status.

        Returns
        -------
        dict
            Status information
        """
        if self._inpainting_module is None:
            return {
                "initialized": False,
                "mode": self._current_mode
            }

        status = self._inpainting_module.get_status()
        status["mode"] = self._current_mode
        return status
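

# Minimal end-to-end sketch, assuming this module's sibling files are importable
# and "input.png" exists; the first run downloads model weights, and SDXL needs
# a GPU for practical speeds.
if __name__ == "__main__":
    core = SceneWeaverCore()
    core.load_models(progress_callback=lambda msg, pct: print(f"[{pct:3d}%] {msg}"))

    foreground = Image.open("input.png")
    result = core.generate_and_combine(foreground, prompt="sunset beach")
    if result["success"]:
        result["combined_image"].save("output.png")
        print(result["image_analysis"])
    else:
        print(f"Generation failed: {result['error']}")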