""" Main Style Transfer Pipeline for Interior Design Combines segmentation, style extraction, and SDXL generation """ import cv2 import numpy as np import torch from PIL import Image import os from typing import Tuple, Dict, Any, Optional from diffusers import StableDiffusionXLImg2ImgPipeline, StableDiffusionXLControlNetPipeline from diffusers.utils import load_image from controlnet_aux import OpenposeDetector import cv2 from tqdm import tqdm from config import Config from segmentation import RoomSegmentation from style_extractor import StyleExtractor class InteriorStyleTransferPipeline: def __init__(self, config: Config): self.config = config self.device = config.DEVICE # Initialize components self.segmentation = RoomSegmentation(device=self.device) self.style_extractor = StyleExtractor() # Initialize SDXL pipeline self.sdxl_pipeline = None self.controlnet_pipeline = None self._load_models() def _load_models(self): """Load SDXL and ControlNet models""" try: print("Loading SDXL Img2Img pipeline...") self.sdxl_pipeline = StableDiffusionXLImg2ImgPipeline.from_pretrained( self.config.SDXL_MODEL_ID, torch_dtype=torch.float16 if self.device == "cuda" else torch.float32, use_safetensors=True, variant="fp16" if self.device == "cuda" else None ) if self.device == "cuda": self.sdxl_pipeline.enable_xformers_memory_efficient_attention() self.sdxl_pipeline.enable_model_cpu_offload() self.sdxl_pipeline.to(self.device) print("Loading ControlNet pipeline...") self.controlnet_pipeline = StableDiffusionXLControlNetPipeline.from_pretrained( self.config.SDXL_MODEL_ID, controlnet=self.config.CONTROLNET_MODEL_ID, torch_dtype=torch.float16 if self.device == "cuda" else torch.float32, use_safetensors=True, variant="fp16" if self.device == "cuda" else None ) if self.device == "cuda": self.controlnet_pipeline.enable_xformers_memory_efficient_attention() self.controlnet_pipeline.enable_model_cpu_offload() self.controlnet_pipeline.to(self.device) except Exception as e: print(f"Warning: Could not load full pipeline: {e}") print("Falling back to basic SDXL pipeline") self.controlnet_pipeline = None def transfer_style(self, user_room_path: str, inspiration_room_path: str, output_path: str = None) -> Dict[str, Any]: """ Main method to transfer style from inspiration room to user room Args: user_room_path: Path to user's room image inspiration_room_path: Path to inspiration room image output_path: Path to save output image Returns: Dictionary containing results and metadata """ print("Starting interior style transfer...") # Load images user_room = cv2.imread(user_room_path) inspiration_room = cv2.imread(inspiration_room_path) if user_room is None or inspiration_room is None: raise ValueError("Could not load one or both images") # Resize images to match user_room = cv2.resize(user_room, (self.config.IMAGE_SIZE, self.config.IMAGE_SIZE)) inspiration_room = cv2.resize(inspiration_room, (self.config.IMAGE_SIZE, self.config.IMAGE_SIZE)) # Step 1: Segment user room to preserve structure print("Segmenting user room...") user_masks = self.segmentation.segment_room(user_room) preservation_mask = self.segmentation.create_preservation_mask( user_masks, self.config.PRESERVE_CLASSES ) # Step 2: Extract style from inspiration room print("Extracting style from inspiration room...") style_info = self.style_extractor.extract_style(inspiration_room) style_prompt = self.style_extractor.generate_style_prompt(style_info) # Step 3: Generate enhanced prompt enhanced_prompt = self._create_enhanced_prompt(style_info, user_room) # Step 4: Apply style transfer with structure preservation print("Applying style transfer...") result_image = self._apply_style_transfer( user_room, inspiration_room, preservation_mask, enhanced_prompt, style_info ) # Step 5: Post-process and blend print("Post-processing result...") final_result = self._post_process_result( result_image, user_room, preservation_mask, style_info ) # Save results if output_path is None: output_path = os.path.join(self.config.OUTPUT_DIR, "style_transfer_result.jpg") cv2.imwrite(output_path, final_result) # Save style analysis style_analysis_path = output_path.replace('.jpg', '_style_analysis.json') self.style_extractor.save_style_analysis(style_info, style_analysis_path) # Create visualization viz_path = output_path.replace('.jpg', '_analysis.jpg') analysis_viz = self.style_extractor.visualize_style_analysis(inspiration_room, style_info) cv2.imwrite(viz_path, cv2.cvtColor(analysis_viz, cv2.COLOR_RGB2BGR)) print(f"Style transfer completed! Results saved to {output_path}") return { 'result_image': final_result, 'style_info': style_info, 'preservation_mask': preservation_mask, 'enhanced_prompt': enhanced_prompt, 'output_path': output_path } def _create_enhanced_prompt(self, style_info: Dict[str, Any], user_room: np.ndarray) -> str: """Create an enhanced prompt combining style and room context""" base_prompt = style_info.get('style_category', 'interior design') color_temp = style_info['color_palette']['color_temperature'] tone = style_info['color_palette']['overall_tone'] # Analyze user room to add context user_style = self.style_extractor.extract_style(user_room) user_tone = user_style['color_palette']['overall_tone'] # Create context-aware prompt enhanced_prompt = f"professional interior design photography, {base_prompt} style, " enhanced_prompt += f"{color_temp} color palette, {tone} lighting, " enhanced_prompt += f"high quality furniture, elegant decorations, " enhanced_prompt += f"realistic textures, professional lighting, " enhanced_prompt += f"architectural photography, 8k resolution, " enhanced_prompt += f"detailed interior design, magazine quality" return enhanced_prompt def _apply_style_transfer(self, user_room: np.ndarray, inspiration_room: np.ndarray, preservation_mask: np.ndarray, enhanced_prompt: str, style_info: Dict[str, Any]) -> np.ndarray: """Apply style transfer using SDXL with structure preservation""" # Convert to PIL images user_room_pil = Image.fromarray(cv2.cvtColor(user_room, cv2.COLOR_BGR2RGB)) inspiration_room_pil = Image.fromarray(cv2.cvtColor(inspiration_room, cv2.COLOR_BGR2RGB)) # Create negative prompt negative_prompt = "blurry, low quality, distorted, unrealistic, poor lighting, " negative_prompt += "bad architecture, ugly furniture, cluttered, messy" # Use ControlNet if available for better structure preservation if self.controlnet_pipeline is not None: result = self._controlnet_style_transfer( user_room_pil, enhanced_prompt, negative_prompt, preservation_mask ) else: result = self._basic_style_transfer( user_room_pil, enhanced_prompt, negative_prompt ) return cv2.cvtColor(np.array(result), cv2.COLOR_RGB2BGR) def _controlnet_style_transfer(self, user_room_pil: Image.Image, enhanced_prompt: str, negative_prompt: str, preservation_mask: np.ndarray) -> Image.Image: """Use ControlNet for better structure preservation""" # Create control image from preservation mask control_image = Image.fromarray(preservation_mask * 255) # Generate with ControlNet result = self.controlnet_pipeline( prompt=enhanced_prompt, negative_prompt=negative_prompt, image=user_room_pil, control_image=control_image, num_inference_steps=self.config.NUM_INFERENCE_STEPS, guidance_scale=self.config.GUIDANCE_SCALE, strength=self.config.STRENGTH, controlnet_conditioning_scale=0.8 ).images[0] return result def _basic_style_transfer(self, user_room_pil: Image.Image, enhanced_prompt: str, negative_prompt: str) -> Image.Image: """Basic style transfer using SDXL Img2Img""" result = self.sdxl_pipeline( prompt=enhanced_prompt, negative_prompt=negative_prompt, image=user_room_pil, num_inference_steps=self.config.NUM_INFERENCE_STEPS, guidance_scale=self.config.GUIDANCE_SCALE, strength=self.config.STRENGTH ).images[0] return result def _post_process_result(self, result_image: np.ndarray, user_room: np.ndarray, preservation_mask: np.ndarray, style_info: Dict[str, Any]) -> np.ndarray: """Post-process the result for better integration""" # Blend with original structure where needed alpha = self.config.BLEND_ALPHA # Create smooth blending mask kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (20, 20)) smooth_mask = cv2.dilate(preservation_mask, kernel, iterations=1) smooth_mask = cv2.GaussianBlur(smooth_mask.astype(np.float32), (21, 21), 0) smooth_mask = smooth_mask / 255.0 # Blend images blended = (result_image * (1 - smooth_mask[..., np.newaxis]) + user_room * smooth_mask[..., np.newaxis]).astype(np.uint8) # Color correction to match inspiration style corrected = self._apply_color_correction(blended, style_info) # Final enhancement enhanced = self._enhance_final_result(corrected) return enhanced def _apply_color_correction(self, image: np.ndarray, style_info: Dict[str, Any]) -> np.ndarray: """Apply color correction to match inspiration style""" # Get dominant colors from inspiration inspiration_colors = np.array(style_info['color_palette']['dominant_colors']) inspiration_proportions = np.array(style_info['color_palette']['proportions']) # Calculate target color temperature target_temp = style_info['color_palette']['color_temperature'] # Convert to LAB for better color manipulation lab = cv2.cvtColor(image, cv2.COLOR_BGR2LAB) if target_temp == "warm": # Increase red channel, decrease blue lab[:, :, 1] = np.clip(lab[:, :, 1] * 1.1, 0, 255) # Increase a (red-green) lab[:, :, 2] = np.clip(lab[:, :, 2] * 0.9, 0, 255) # Decrease b (blue-yellow) elif target_temp == "cool": # Increase blue channel, decrease red lab[:, :, 1] = np.clip(lab[:, :, 1] * 0.9, 0, 255) # Decrease a lab[:, :, 2] = np.clip(lab[:, :, 2] * 1.1, 0, 255) # Increase b # Convert back to BGR corrected = cv2.cvtColor(lab, cv2.COLOR_LAB2BGR) return corrected def _enhance_final_result(self, image: np.ndarray) -> np.ndarray: """Apply final enhancements to the result""" # Slight sharpening kernel = np.array([[-1, -1, -1], [-1, 9, -1], [-1, -1, -1]]) sharpened = cv2.filter2D(image, -1, kernel) # Blend with original for natural look enhanced = cv2.addWeighted(image, 0.7, sharpened, 0.3, 0) # Slight contrast enhancement lab = cv2.cvtColor(enhanced, cv2.COLOR_BGR2LAB) lab[:, :, 0] = cv2.equalizeHist(lab[:, :, 0]) enhanced = cv2.cvtColor(lab, cv2.COLOR_LAB2BGR) return enhanced def batch_process(self, user_rooms: list, inspiration_rooms: list, output_dir: str = None) -> list: """Process multiple room pairs""" if output_dir is None: output_dir = self.config.OUTPUT_DIR results = [] for i, (user_room, inspiration_room) in enumerate(zip(user_rooms, inspiration_rooms)): print(f"Processing pair {i+1}/{len(user_rooms)}") try: result = self.transfer_style( user_room, inspiration_room, os.path.join(output_dir, f"result_{i+1}.jpg") ) results.append(result) except Exception as e: print(f"Error processing pair {i+1}: {e}") results.append(None) return results