"""
Main Style Transfer Pipeline for Interior Design
Combines segmentation, style extraction, and SDXL generation
"""
import os
from typing import Any, Dict, List, Optional

import cv2
import numpy as np
import torch
from PIL import Image
from diffusers import (
    ControlNetModel,
    StableDiffusionXLControlNetImg2ImgPipeline,
    StableDiffusionXLImg2ImgPipeline,
)

from config import Config
from segmentation import RoomSegmentation
from style_extractor import StyleExtractor


class InteriorStyleTransferPipeline:
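    """End-to-end interior style transfer: segments the user's room to find
    structure worth preserving, extracts a style description from the
    inspiration image, restyles the room with SDXL (ControlNet-guided when
    available), then blends and color-corrects the result."""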
    def __init__(self, config: Config):
        self.config = config
        self.device = config.DEVICE

        # Initialize components
        self.segmentation = RoomSegmentation(device=self.device)
        self.style_extractor = StyleExtractor()

        # Initialize SDXL pipelines
        self.sdxl_pipeline = None
        self.controlnet_pipeline = None
        self._load_models()

    def _load_models(self):
        """Load the SDXL Img2Img and ControlNet pipelines"""
        try:
            print("Loading SDXL Img2Img pipeline...")
            self.sdxl_pipeline = StableDiffusionXLImg2ImgPipeline.from_pretrained(
                self.config.SDXL_MODEL_ID,
                torch_dtype=torch.float16 if self.device == "cuda" else torch.float32,
                use_safetensors=True,
                variant="fp16" if self.device == "cuda" else None,
            )
            if self.device == "cuda":
                self.sdxl_pipeline.enable_xformers_memory_efficient_attention()
                # CPU offload manages device placement itself; calling
                # .to(self.device) afterwards would undo the offloading
                self.sdxl_pipeline.enable_model_cpu_offload()
            else:
                self.sdxl_pipeline.to(self.device)

            print("Loading ControlNet pipeline...")
            # The ControlNet weights must be loaded as a model object,
            # not passed to from_pretrained as a bare model-ID string
            controlnet = ControlNetModel.from_pretrained(
                self.config.CONTROLNET_MODEL_ID,
                torch_dtype=torch.float16 if self.device == "cuda" else torch.float32,
            )
            # Use the Img2Img ControlNet variant: it is the one that accepts
            # both an init image and a control image (plus a strength value)
            self.controlnet_pipeline = StableDiffusionXLControlNetImg2ImgPipeline.from_pretrained(
                self.config.SDXL_MODEL_ID,
                controlnet=controlnet,
                torch_dtype=torch.float16 if self.device == "cuda" else torch.float32,
                use_safetensors=True,
                variant="fp16" if self.device == "cuda" else None,
            )
            if self.device == "cuda":
                self.controlnet_pipeline.enable_xformers_memory_efficient_attention()
                self.controlnet_pipeline.enable_model_cpu_offload()
            else:
                self.controlnet_pipeline.to(self.device)
        except Exception as e:
            print(f"Warning: could not load the full pipeline: {e}")
            print("Falling back to basic SDXL pipeline")
            self.controlnet_pipeline = None

    def transfer_style(self, user_room_path: str, inspiration_room_path: str,
                       output_path: Optional[str] = None) -> Dict[str, Any]:
        """
        Transfer the style of an inspiration room onto the user's room.

        Args:
            user_room_path: Path to the user's room image
            inspiration_room_path: Path to the inspiration room image
            output_path: Path to save the output image (defaults to OUTPUT_DIR)

        Returns:
            Dictionary containing the result image and metadata
        """
        print("Starting interior style transfer...")

        # Load images
        user_room = cv2.imread(user_room_path)
        inspiration_room = cv2.imread(inspiration_room_path)
        if user_room is None or inspiration_room is None:
            raise ValueError("Could not load one or both images")

        # Resize both images to the model's working resolution (note: this
        # squares the images, so non-square inputs are distorted)
        user_room = cv2.resize(user_room, (self.config.IMAGE_SIZE, self.config.IMAGE_SIZE))
        inspiration_room = cv2.resize(inspiration_room, (self.config.IMAGE_SIZE, self.config.IMAGE_SIZE))

        # Step 1: Segment the user's room to preserve structure
        print("Segmenting user room...")
        user_masks = self.segmentation.segment_room(user_room)
        preservation_mask = self.segmentation.create_preservation_mask(
            user_masks, self.config.PRESERVE_CLASSES
        )

        # Step 2: Extract style from the inspiration room
        print("Extracting style from inspiration room...")
        style_info = self.style_extractor.extract_style(inspiration_room)
        style_prompt = self.style_extractor.generate_style_prompt(style_info)

        # Step 3: Generate an enhanced prompt (folds in the extracted style prompt)
        enhanced_prompt = self._create_enhanced_prompt(style_info, user_room, style_prompt)

        # Step 4: Apply style transfer with structure preservation
        print("Applying style transfer...")
        result_image = self._apply_style_transfer(
            user_room, inspiration_room, preservation_mask,
            enhanced_prompt, style_info
        )

        # Step 5: Post-process and blend
        print("Post-processing result...")
        final_result = self._post_process_result(
            result_image, user_room, preservation_mask, style_info
        )

        # Save results
        if output_path is None:
            output_path = os.path.join(self.config.OUTPUT_DIR, "style_transfer_result.jpg")
        cv2.imwrite(output_path, final_result)

        # Save the style analysis alongside the image
        base_path, _ = os.path.splitext(output_path)
        style_analysis_path = base_path + "_style_analysis.json"
        self.style_extractor.save_style_analysis(style_info, style_analysis_path)

        # Create a visualization of the extracted style
        viz_path = base_path + "_analysis.jpg"
        analysis_viz = self.style_extractor.visualize_style_analysis(inspiration_room, style_info)
        cv2.imwrite(viz_path, cv2.cvtColor(analysis_viz, cv2.COLOR_RGB2BGR))

        print(f"Style transfer completed! Results saved to {output_path}")

        return {
            'result_image': final_result,
            'style_info': style_info,
            'preservation_mask': preservation_mask,
            'enhanced_prompt': enhanced_prompt,
            'output_path': output_path
        }

    def _create_enhanced_prompt(self, style_info: Dict[str, Any],
                                user_room: np.ndarray, style_prompt: str = "") -> str:
        """Create an enhanced prompt combining extracted style and room context"""
        base_prompt = style_info.get('style_category', 'interior design')
        color_temp = style_info['color_palette']['color_temperature']
        tone = style_info['color_palette']['overall_tone']

        # Analyze the user's room so the prompt reflects its existing tone
        user_style = self.style_extractor.extract_style(user_room)
        user_tone = user_style['color_palette']['overall_tone']

        # Create a context-aware prompt
        enhanced_prompt = f"professional interior design photography, {base_prompt} style, "
        if style_prompt:
            enhanced_prompt += f"{style_prompt}, "
        enhanced_prompt += f"{color_temp} color palette, {tone} lighting, "
        enhanced_prompt += f"harmonizing with the room's {user_tone} base tone, "
        enhanced_prompt += "high quality furniture, elegant decorations, "
        enhanced_prompt += "realistic textures, professional lighting, "
        enhanced_prompt += "architectural photography, 8k resolution, "
        enhanced_prompt += "detailed interior design, magazine quality"
        return enhanced_prompt

    def _apply_style_transfer(self, user_room: np.ndarray, inspiration_room: np.ndarray,
                              preservation_mask: np.ndarray, enhanced_prompt: str,
                              style_info: Dict[str, Any]) -> np.ndarray:
        """Apply style transfer using SDXL with structure preservation"""
        # Convert to PIL (OpenCV is BGR, the diffusion pipelines expect RGB).
        # inspiration_room and style_info are kept in the signature for
        # interface parity; the style itself is carried by enhanced_prompt.
        user_room_pil = Image.fromarray(cv2.cvtColor(user_room, cv2.COLOR_BGR2RGB))

        # Negative prompt steering away from common failure modes
        negative_prompt = ("blurry, low quality, distorted, unrealistic, poor lighting, "
                           "bad architecture, ugly furniture, cluttered, messy")

        # Prefer ControlNet, when available, for better structure preservation
        if self.controlnet_pipeline is not None:
            result = self._controlnet_style_transfer(
                user_room_pil, enhanced_prompt, negative_prompt, preservation_mask
            )
        else:
            result = self._basic_style_transfer(
                user_room_pil, enhanced_prompt, negative_prompt
            )
        return cv2.cvtColor(np.array(result), cv2.COLOR_RGB2BGR)

    def _controlnet_style_transfer(self, user_room_pil: Image.Image,
                                   enhanced_prompt: str, negative_prompt: str,
                                   preservation_mask: np.ndarray) -> Image.Image:
        """Use ControlNet for better structure preservation"""
        # Build the control image from the preservation mask; thresholding
        # makes this robust to either a 0/1 or a 0/255 mask convention, and
        # ControlNet conditioning expects an RGB image
        control_image = Image.fromarray(
            (preservation_mask > 0).astype(np.uint8) * 255
        ).convert("RGB")

        # Generate with the ControlNet-guided Img2Img pipeline
        result = self.controlnet_pipeline(
            prompt=enhanced_prompt,
            negative_prompt=negative_prompt,
            image=user_room_pil,
            control_image=control_image,
            num_inference_steps=self.config.NUM_INFERENCE_STEPS,
            guidance_scale=self.config.GUIDANCE_SCALE,
            strength=self.config.STRENGTH,
            controlnet_conditioning_scale=0.8,
        ).images[0]
        return result

    def _basic_style_transfer(self, user_room_pil: Image.Image,
                              enhanced_prompt: str, negative_prompt: str) -> Image.Image:
        """Basic style transfer using SDXL Img2Img"""
        result = self.sdxl_pipeline(
            prompt=enhanced_prompt,
            negative_prompt=negative_prompt,
            image=user_room_pil,
            num_inference_steps=self.config.NUM_INFERENCE_STEPS,
            guidance_scale=self.config.GUIDANCE_SCALE,
            strength=self.config.STRENGTH,
        ).images[0]
        return result

    def _post_process_result(self, result_image: np.ndarray, user_room: np.ndarray,
                             preservation_mask: np.ndarray, style_info: Dict[str, Any]) -> np.ndarray:
        """Post-process the result for better integration"""
        # Create a smooth blending mask around the preserved regions
        kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (20, 20))
        smooth_mask = cv2.dilate(preservation_mask, kernel, iterations=1)
        smooth_mask = cv2.GaussianBlur(smooth_mask.astype(np.float32), (21, 21), 0)
        # Normalize to [0, 1] whatever the incoming mask convention; the old
        # unconditional division by 255 collapsed a binary 0/1 mask to ~0,
        # which silently disabled blending
        if smooth_mask.max() > 0:
            smooth_mask = smooth_mask / smooth_mask.max()

        # Blend: keep the original pixels where the mask preserves structure
        blended = (result_image * (1 - smooth_mask[..., np.newaxis]) +
                   user_room * smooth_mask[..., np.newaxis]).astype(np.uint8)

        # Color correction to match the inspiration style
        corrected = self._apply_color_correction(blended, style_info)

        # Final enhancement
        enhanced = self._enhance_final_result(corrected)
        return enhanced

    def _apply_color_correction(self, image: np.ndarray,
                                style_info: Dict[str, Any]) -> np.ndarray:
        """Shift the image's color temperature toward the inspiration style"""
        # Only the color temperature is corrected here; the dominant colors in
        # style_info['color_palette'] could drive a fuller palette match later
        target_temp = style_info['color_palette']['color_temperature']

        # Work in LAB: OpenCV stores a (green-red) and b (blue-yellow) as
        # uint8 with 128 as the neutral point, so temperature is shifted with
        # modest additive offsets around that neutral. The old multiplicative
        # version scaled the raw values (tinting even neutral pixels) and had
        # the b-axis direction inverted for warm/cool.
        lab = cv2.cvtColor(image, cv2.COLOR_BGR2LAB).astype(np.float32)
        if target_temp == "warm":
            lab[:, :, 1] = np.clip(lab[:, :, 1] + 5, 0, 255)   # toward red
            lab[:, :, 2] = np.clip(lab[:, :, 2] + 10, 0, 255)  # toward yellow
        elif target_temp == "cool":
            lab[:, :, 1] = np.clip(lab[:, :, 1] - 5, 0, 255)   # toward green
            lab[:, :, 2] = np.clip(lab[:, :, 2] - 10, 0, 255)  # toward blue

        # Convert back to BGR
        corrected = cv2.cvtColor(lab.astype(np.uint8), cv2.COLOR_LAB2BGR)
        return corrected

    def _enhance_final_result(self, image: np.ndarray) -> np.ndarray:
        """Apply final sharpening and contrast enhancements"""
        # Slight sharpening, blended back with the original for a natural look
        kernel = np.array([[-1, -1, -1],
                           [-1, 9, -1],
                           [-1, -1, -1]])
        sharpened = cv2.filter2D(image, -1, kernel)
        enhanced = cv2.addWeighted(image, 0.7, sharpened, 0.3, 0)

        # Slight contrast enhancement on the L channel; CLAHE is used instead
        # of full histogram equalization, which is too aggressive for "slight"
        lab = cv2.cvtColor(enhanced, cv2.COLOR_BGR2LAB)
        l_channel, a_channel, b_channel = cv2.split(lab)
        clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
        l_channel = clahe.apply(l_channel)
        enhanced = cv2.cvtColor(cv2.merge((l_channel, a_channel, b_channel)),
                                cv2.COLOR_LAB2BGR)
        return enhanced

    def batch_process(self, user_rooms: List[str], inspiration_rooms: List[str],
                      output_dir: Optional[str] = None) -> List[Optional[Dict[str, Any]]]:
        """Process multiple (user room, inspiration room) pairs"""
        if output_dir is None:
            output_dir = self.config.OUTPUT_DIR

        results = []
        for i, (user_room, inspiration_room) in enumerate(zip(user_rooms, inspiration_rooms)):
            print(f"Processing pair {i + 1}/{len(user_rooms)}")
            try:
                result = self.transfer_style(
                    user_room, inspiration_room,
                    os.path.join(output_dir, f"result_{i + 1}.jpg")
                )
                results.append(result)
            except Exception as e:
                print(f"Error processing pair {i + 1}: {e}")
                results.append(None)
        return results
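

# --- Minimal usage sketch ---
# Illustrative only: assumes `Config` is constructible with defaults (it must
# expose DEVICE, SDXL_MODEL_ID, OUTPUT_DIR, etc. as used above) and that the
# example image paths below exist; both paths are hypothetical placeholders.
if __name__ == "__main__":
    config = Config()
    pipeline = InteriorStyleTransferPipeline(config)
    result = pipeline.transfer_style(
        user_room_path="examples/user_room.jpg",                # hypothetical path
        inspiration_room_path="examples/inspiration_room.jpg",  # hypothetical path
    )
    print(f"Saved styled room to {result['output_path']}")
    print(f"Prompt used: {result['enhanced_prompt']}")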