"""
Main Style Transfer Pipeline for Interior Design
Combines segmentation, style extraction, and SDXL generation
"""
import cv2
import numpy as np
import torch
from PIL import Image
import os
from typing import Tuple, Dict, Any, Optional
from diffusers import StableDiffusionXLImg2ImgPipeline, StableDiffusionXLControlNetPipeline
from diffusers.utils import load_image
from controlnet_aux import OpenposeDetector
import cv2
from tqdm import tqdm
from config import Config
from segmentation import RoomSegmentation
from style_extractor import StyleExtractor

class InteriorStyleTransferPipeline:
    """End-to-end pipeline: segment the user's room, extract the inspiration
    room's style, and restyle the user's room with SDXL while preserving its
    structure."""

    def __init__(self, config: Config):
        self.config = config
        self.device = config.DEVICE

        # Initialize components
        self.segmentation = RoomSegmentation(device=self.device)
        self.style_extractor = StyleExtractor()

        # Initialize SDXL pipelines (loaded once, reused for every transfer)
        self.sdxl_pipeline = None
        self.controlnet_pipeline = None
        self._load_models()

    def _load_models(self):
        """Load SDXL and ControlNet models"""
        dtype = torch.float16 if self.device == "cuda" else torch.float32
        variant = "fp16" if self.device == "cuda" else None
        try:
            print("Loading SDXL Img2Img pipeline...")
            self.sdxl_pipeline = StableDiffusionXLImg2ImgPipeline.from_pretrained(
                self.config.SDXL_MODEL_ID,
                torch_dtype=dtype,
                use_safetensors=True,
                variant=variant,
            )
            if self.device == "cuda":
                self.sdxl_pipeline.enable_xformers_memory_efficient_attention()
                # enable_model_cpu_offload() manages device placement itself;
                # calling .to(device) afterwards would break the offload hooks
                self.sdxl_pipeline.enable_model_cpu_offload()
            else:
                self.sdxl_pipeline.to(self.device)

            print("Loading ControlNet pipeline...")
            # from_pretrained expects a ControlNetModel instance, not an ID string
            controlnet = ControlNetModel.from_pretrained(
                self.config.CONTROLNET_MODEL_ID,
                torch_dtype=dtype,
            )
            # The img2img ControlNet variant accepts both a source image and a
            # control image, plus a denoising strength
            self.controlnet_pipeline = StableDiffusionXLControlNetImg2ImgPipeline.from_pretrained(
                self.config.SDXL_MODEL_ID,
                controlnet=controlnet,
                torch_dtype=dtype,
                use_safetensors=True,
                variant=variant,
            )
            if self.device == "cuda":
                self.controlnet_pipeline.enable_xformers_memory_efficient_attention()
                self.controlnet_pipeline.enable_model_cpu_offload()
            else:
                self.controlnet_pipeline.to(self.device)
        except Exception as e:
            print(f"Warning: Could not load full pipeline: {e}")
            print("Falling back to basic SDXL pipeline")
            self.controlnet_pipeline = None

    def transfer_style(self, user_room_path: str, inspiration_room_path: str,
                       output_path: Optional[str] = None) -> Dict[str, Any]:
        """
        Main method to transfer style from inspiration room to user room

        Args:
            user_room_path: Path to user's room image
            inspiration_room_path: Path to inspiration room image
            output_path: Path to save output image

        Returns:
            Dictionary containing results and metadata
        """
        print("Starting interior style transfer...")

        # Load images
        user_room = cv2.imread(user_room_path)
        inspiration_room = cv2.imread(inspiration_room_path)
        if user_room is None or inspiration_room is None:
            raise ValueError("Could not load one or both images")

        # Resize images to a common square working resolution
        user_room = cv2.resize(user_room, (self.config.IMAGE_SIZE, self.config.IMAGE_SIZE))
        inspiration_room = cv2.resize(inspiration_room, (self.config.IMAGE_SIZE, self.config.IMAGE_SIZE))

        # Step 1: Segment user room to preserve structure
        print("Segmenting user room...")
        user_masks = self.segmentation.segment_room(user_room)
        preservation_mask = self.segmentation.create_preservation_mask(
            user_masks, self.config.PRESERVE_CLASSES
        )

        # Step 2: Extract style from inspiration room
        print("Extracting style from inspiration room...")
        style_info = self.style_extractor.extract_style(inspiration_room)

        # Step 3: Generate enhanced prompt
        enhanced_prompt = self._create_enhanced_prompt(style_info, user_room)

        # Step 4: Apply style transfer with structure preservation
        print("Applying style transfer...")
        result_image = self._apply_style_transfer(
            user_room, inspiration_room, preservation_mask,
            enhanced_prompt, style_info
        )

        # Step 5: Post-process and blend
        print("Post-processing result...")
        final_result = self._post_process_result(
            result_image, user_room, preservation_mask, style_info
        )

        # Save results (create the output directory if it does not exist,
        # since cv2.imwrite fails silently otherwise)
        if output_path is None:
            output_path = os.path.join(self.config.OUTPUT_DIR, "style_transfer_result.jpg")
        os.makedirs(os.path.dirname(output_path) or ".", exist_ok=True)
        cv2.imwrite(output_path, final_result)

        # Save style analysis (splitext is robust to non-.jpg output paths)
        stem = os.path.splitext(output_path)[0]
        style_analysis_path = stem + "_style_analysis.json"
        self.style_extractor.save_style_analysis(style_info, style_analysis_path)

        # Create visualization
        viz_path = stem + "_analysis.jpg"
        analysis_viz = self.style_extractor.visualize_style_analysis(inspiration_room, style_info)
        cv2.imwrite(viz_path, cv2.cvtColor(analysis_viz, cv2.COLOR_RGB2BGR))

        print(f"Style transfer completed! Results saved to {output_path}")
        return {
            'result_image': final_result,
            'style_info': style_info,
            'preservation_mask': preservation_mask,
            'enhanced_prompt': enhanced_prompt,
            'output_path': output_path
        }

    def _create_enhanced_prompt(self, style_info: Dict[str, Any],
                                user_room: np.ndarray) -> str:
        """Create an enhanced prompt combining style and room context"""
        base_prompt = style_info.get('style_category', 'interior design')
        color_temp = style_info['color_palette']['color_temperature']
        tone = style_info['color_palette']['overall_tone']

        # Analyze user room so the prompt respects its existing tone
        user_style = self.style_extractor.extract_style(user_room)
        user_tone = user_style['color_palette']['overall_tone']

        # Create context-aware prompt
        enhanced_prompt = (
            f"professional interior design photography, {base_prompt} style, "
            f"{color_temp} color palette, {tone} lighting, "
            f"harmonizing with the room's existing {user_tone} tones, "
            "high quality furniture, elegant decorations, "
            "realistic textures, professional lighting, "
            "architectural photography, 8k resolution, "
            "detailed interior design, magazine quality"
        )
        return enhanced_prompt

    def _apply_style_transfer(self, user_room: np.ndarray, inspiration_room: np.ndarray,
                              preservation_mask: np.ndarray, enhanced_prompt: str,
                              style_info: Dict[str, Any]) -> np.ndarray:
        """Apply style transfer using SDXL with structure preservation"""
        # Convert to PIL (the inspiration room contributes only through the
        # extracted style prompt, so it is not passed to the diffusion model)
        user_room_pil = Image.fromarray(cv2.cvtColor(user_room, cv2.COLOR_BGR2RGB))

        # Create negative prompt
        negative_prompt = ("blurry, low quality, distorted, unrealistic, poor lighting, "
                           "bad architecture, ugly furniture, cluttered, messy")

        # Use ControlNet if available for better structure preservation
        if self.controlnet_pipeline is not None:
            result = self._controlnet_style_transfer(
                user_room_pil, enhanced_prompt, negative_prompt, preservation_mask
            )
        else:
            result = self._basic_style_transfer(
                user_room_pil, enhanced_prompt, negative_prompt
            )
        return cv2.cvtColor(np.array(result), cv2.COLOR_RGB2BGR)

    def _controlnet_style_transfer(self, user_room_pil: Image.Image,
                                   enhanced_prompt: str, negative_prompt: str,
                                   preservation_mask: np.ndarray) -> Image.Image:
        """Use ControlNet for better structure preservation"""
        # Create a 3-channel control image from the binary {0, 1} preservation mask
        control_image = Image.fromarray(
            (preservation_mask * 255).astype(np.uint8)
        ).convert("RGB")

        # Generate with ControlNet img2img
        result = self.controlnet_pipeline(
            prompt=enhanced_prompt,
            negative_prompt=negative_prompt,
            image=user_room_pil,
            control_image=control_image,
            num_inference_steps=self.config.NUM_INFERENCE_STEPS,
            guidance_scale=self.config.GUIDANCE_SCALE,
            strength=self.config.STRENGTH,
            controlnet_conditioning_scale=0.8
        ).images[0]
        return result

    def _basic_style_transfer(self, user_room_pil: Image.Image,
                              enhanced_prompt: str, negative_prompt: str) -> Image.Image:
        """Basic style transfer using SDXL Img2Img"""
        result = self.sdxl_pipeline(
            prompt=enhanced_prompt,
            negative_prompt=negative_prompt,
            image=user_room_pil,
            num_inference_steps=self.config.NUM_INFERENCE_STEPS,
            guidance_scale=self.config.GUIDANCE_SCALE,
            strength=self.config.STRENGTH
        ).images[0]
        return result

    def _post_process_result(self, result_image: np.ndarray, user_room: np.ndarray,
                             preservation_mask: np.ndarray, style_info: Dict[str, Any]) -> np.ndarray:
        """Post-process the result for better integration"""
        # BLEND_ALPHA controls how strongly preserved regions keep the original pixels
        alpha = self.config.BLEND_ALPHA

        # Create a smooth blending mask from the binary {0, 1} preservation mask
        kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (20, 20))
        smooth_mask = cv2.dilate(preservation_mask.astype(np.uint8), kernel, iterations=1)
        smooth_mask = cv2.GaussianBlur(smooth_mask.astype(np.float32), (21, 21), 0)
        smooth_mask = np.clip(smooth_mask, 0.0, 1.0) * alpha

        # Blend images: preserved regions fall back toward the original room
        blended = (result_image * (1 - smooth_mask[..., np.newaxis]) +
                   user_room * smooth_mask[..., np.newaxis]).astype(np.uint8)

        # Color correction to match inspiration style
        corrected = self._apply_color_correction(blended, style_info)

        # Final enhancement
        enhanced = self._enhance_final_result(corrected)
        return enhanced

    def _apply_color_correction(self, image: np.ndarray,
                                style_info: Dict[str, Any]) -> np.ndarray:
        """Apply color correction to match inspiration style"""
        # Only the target color temperature drives the global correction
        target_temp = style_info['color_palette']['color_temperature']

        # Convert to LAB for better color manipulation; in 8-bit LAB the a/b
        # chroma channels are centered at 128, so warm/cool shifts are applied
        # as small offsets around that neutral point
        lab = cv2.cvtColor(image, cv2.COLOR_BGR2LAB).astype(np.float32)
        shift = 8.0  # conservative temperature shift
        if target_temp == "warm":
            lab[:, :, 1] += shift * 0.5  # a: toward red
            lab[:, :, 2] += shift        # b: toward yellow
        elif target_temp == "cool":
            lab[:, :, 1] -= shift * 0.5  # a: toward green
            lab[:, :, 2] -= shift        # b: toward blue
        lab = np.clip(lab, 0, 255).astype(np.uint8)

        # Convert back to BGR
        corrected = cv2.cvtColor(lab, cv2.COLOR_LAB2BGR)
        return corrected

    def _enhance_final_result(self, image: np.ndarray) -> np.ndarray:
        """Apply final enhancements to the result"""
        # Slight sharpening
        kernel = np.array([[-1, -1, -1],
                           [-1, 9, -1],
                           [-1, -1, -1]])
        sharpened = cv2.filter2D(image, -1, kernel)

        # Blend with original for natural look
        enhanced = cv2.addWeighted(image, 0.7, sharpened, 0.3, 0)

        # Slight local contrast enhancement on the L channel; CLAHE with a low
        # clip limit keeps the adjustment subtle
        lab = cv2.cvtColor(enhanced, cv2.COLOR_BGR2LAB)
        clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
        lab[:, :, 0] = clahe.apply(np.ascontiguousarray(lab[:, :, 0]))
        enhanced = cv2.cvtColor(lab, cv2.COLOR_LAB2BGR)
        return enhanced

    def batch_process(self, user_rooms: list, inspiration_rooms: list,
                      output_dir: Optional[str] = None) -> list:
        """Process multiple room pairs"""
        if output_dir is None:
            output_dir = self.config.OUTPUT_DIR
        os.makedirs(output_dir, exist_ok=True)

        results = []
        for i, (user_room, inspiration_room) in enumerate(zip(user_rooms, inspiration_rooms)):
            print(f"Processing pair {i + 1}/{len(user_rooms)}")
            try:
                result = self.transfer_style(
                    user_room, inspiration_room,
                    os.path.join(output_dir, f"result_{i + 1}.jpg")
                )
                results.append(result)
            except Exception as e:
                print(f"Error processing pair {i + 1}: {e}")
                results.append(None)
        return results
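

# Minimal usage sketch of the pipeline above. The Config attributes it relies
# on (DEVICE, SDXL_MODEL_ID, CONTROLNET_MODEL_ID, OUTPUT_DIR, ...) come from
# config.py; the image paths below are hypothetical placeholders for
# illustration only.
if __name__ == "__main__":
    config = Config()
    pipeline = InteriorStyleTransferPipeline(config)

    # Single pair: restyle one room from one inspiration image; with no
    # output_path given, the result lands in config.OUTPUT_DIR
    result = pipeline.transfer_style(
        "examples/user_room.jpg",         # hypothetical input path
        "examples/inspiration_room.jpg",  # hypothetical input path
    )
    print(f"Prompt used: {result['enhanced_prompt']}")
    print(f"Saved to: {result['output_path']}")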