"""
Main Style Transfer Pipeline for Interior Design

Combines segmentation, style extraction, and SDXL generation.
"""

import os
from typing import Any, Dict, Optional

import cv2
import numpy as np
import torch
from PIL import Image
from diffusers import (
    ControlNetModel,
    StableDiffusionXLControlNetImg2ImgPipeline,
    StableDiffusionXLImg2ImgPipeline,
)

from config import Config
from segmentation import RoomSegmentation
from style_extractor import StyleExtractor
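
# The configuration object is expected to provide at least: DEVICE,
# SDXL_MODEL_ID, CONTROLNET_MODEL_ID, IMAGE_SIZE, PRESERVE_CLASSES,
# NUM_INFERENCE_STEPS, GUIDANCE_SCALE, STRENGTH, and OUTPUT_DIR
# (inferred from the attribute accesses in this module).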


class InteriorStyleTransferPipeline:
    """End-to-end pipeline: segment the user's room, extract the style of an
    inspiration room, and re-render the user's room with SDXL."""

    def __init__(self, config: Config):
        self.config = config
        self.device = config.DEVICE

        # Analysis components
        self.segmentation = RoomSegmentation(device=self.device)
        self.style_extractor = StyleExtractor()

        # Generative pipelines, populated by _load_models()
        self.sdxl_pipeline = None
        self.controlnet_pipeline = None
        self._load_models()

    def _load_models(self):
        """Load the SDXL img2img and ControlNet pipelines."""
        try:
            print("Loading SDXL Img2Img pipeline...")
            self.sdxl_pipeline = StableDiffusionXLImg2ImgPipeline.from_pretrained(
                self.config.SDXL_MODEL_ID,
                torch_dtype=torch.float16 if self.device == "cuda" else torch.float32,
                use_safetensors=True,
                variant="fp16" if self.device == "cuda" else None,
            )

            if self.device == "cuda":
                self.sdxl_pipeline.enable_xformers_memory_efficient_attention()
                # Model CPU offload manages device placement itself; calling
                # .to(device) afterwards would undo the offload hooks.
                self.sdxl_pipeline.enable_model_cpu_offload()
            else:
                self.sdxl_pipeline.to(self.device)

            print("Loading ControlNet pipeline...")
            # from_pretrained expects a ControlNetModel instance here, not a
            # model ID string, so load the ControlNet weights first.
            controlnet = ControlNetModel.from_pretrained(
                self.config.CONTROLNET_MODEL_ID,
                torch_dtype=torch.float16 if self.device == "cuda" else torch.float32,
            )
            # The img2img variant is required: the text-to-image ControlNet
            # pipeline accepts neither an init image nor a strength argument.
            self.controlnet_pipeline = StableDiffusionXLControlNetImg2ImgPipeline.from_pretrained(
                self.config.SDXL_MODEL_ID,
                controlnet=controlnet,
                torch_dtype=torch.float16 if self.device == "cuda" else torch.float32,
                use_safetensors=True,
                variant="fp16" if self.device == "cuda" else None,
            )

            if self.device == "cuda":
                self.controlnet_pipeline.enable_xformers_memory_efficient_attention()
                self.controlnet_pipeline.enable_model_cpu_offload()
            else:
                self.controlnet_pipeline.to(self.device)

        except Exception as e:
            print(f"Warning: Could not load full pipeline: {e}")
            print("Falling back to basic SDXL pipeline")
            self.controlnet_pipeline = None

    def transfer_style(self, user_room_path: str, inspiration_room_path: str,
                       output_path: Optional[str] = None) -> Dict[str, Any]:
        """
        Transfer the style of an inspiration room onto the user's room.

        Args:
            user_room_path: Path to the user's room image
            inspiration_room_path: Path to the inspiration room image
            output_path: Path to save the output image

        Returns:
            Dictionary containing the result image and metadata
        """
        print("Starting interior style transfer...")

        # Load and validate the input images
        user_room = cv2.imread(user_room_path)
        inspiration_room = cv2.imread(inspiration_room_path)
        if user_room is None or inspiration_room is None:
            raise ValueError("Could not load one or both images")

        # Resize both images to the working resolution
        user_room = cv2.resize(user_room, (self.config.IMAGE_SIZE, self.config.IMAGE_SIZE))
        inspiration_room = cv2.resize(inspiration_room, (self.config.IMAGE_SIZE, self.config.IMAGE_SIZE))

        # Segment the user's room and build a mask of regions to preserve
        print("Segmenting user room...")
        user_masks = self.segmentation.segment_room(user_room)
        preservation_mask = self.segmentation.create_preservation_mask(
            user_masks, self.config.PRESERVE_CLASSES
        )

        # Analyze the inspiration room and build the generation prompt
        print("Extracting style from inspiration room...")
        style_info = self.style_extractor.extract_style(inspiration_room)
        enhanced_prompt = self._create_enhanced_prompt(style_info)

        # Generate the restyled room
        print("Applying style transfer...")
        result_image = self._apply_style_transfer(
            user_room, inspiration_room, preservation_mask,
            enhanced_prompt, style_info
        )

        # Blend, color-correct, and sharpen the raw generation
        print("Post-processing result...")
        final_result = self._post_process_result(
            result_image, user_room, preservation_mask, style_info
        )

        # Save the result image
        if output_path is None:
            os.makedirs(self.config.OUTPUT_DIR, exist_ok=True)
            output_path = os.path.join(self.config.OUTPUT_DIR, "style_transfer_result.jpg")
        cv2.imwrite(output_path, final_result)

        # Save the style analysis and its visualization alongside the result
        base, _ = os.path.splitext(output_path)
        style_analysis_path = base + "_style_analysis.json"
        self.style_extractor.save_style_analysis(style_info, style_analysis_path)

        viz_path = base + "_analysis.jpg"
        analysis_viz = self.style_extractor.visualize_style_analysis(inspiration_room, style_info)
        cv2.imwrite(viz_path, cv2.cvtColor(analysis_viz, cv2.COLOR_RGB2BGR))

        print(f"Style transfer completed! Results saved to {output_path}")

        return {
            'result_image': final_result,
            'style_info': style_info,
            'preservation_mask': preservation_mask,
            'enhanced_prompt': enhanced_prompt,
            'output_path': output_path
        }

    def _create_enhanced_prompt(self, style_info: Dict[str, Any]) -> str:
        """Create an enhanced prompt combining the extracted style with quality cues."""
        base_prompt = style_info.get('style_category', 'interior design')
        color_temp = style_info['color_palette']['color_temperature']
        tone = style_info['color_palette']['overall_tone']

        return (
            f"professional interior design photography, {base_prompt} style, "
            f"{color_temp} color palette, {tone} lighting, "
            "high quality furniture, elegant decorations, "
            "realistic textures, professional lighting, "
            "architectural photography, 8k resolution, "
            "detailed interior design, magazine quality"
        )

    def _apply_style_transfer(self, user_room: np.ndarray, inspiration_room: np.ndarray,
                              preservation_mask: np.ndarray, enhanced_prompt: str,
                              style_info: Dict[str, Any]) -> np.ndarray:
        """Apply style transfer using SDXL with structure preservation."""
        # diffusers expects RGB PIL images; OpenCV arrays are BGR
        user_room_pil = Image.fromarray(cv2.cvtColor(user_room, cv2.COLOR_BGR2RGB))

        negative_prompt = ("blurry, low quality, distorted, unrealistic, poor lighting, "
                           "bad architecture, ugly furniture, cluttered, messy")

        # Prefer ControlNet for structure preservation; fall back to plain
        # img2img if the ControlNet pipeline failed to load.
        if self.controlnet_pipeline is not None:
            result = self._controlnet_style_transfer(
                user_room_pil, enhanced_prompt, negative_prompt, preservation_mask
            )
        else:
            result = self._basic_style_transfer(
                user_room_pil, enhanced_prompt, negative_prompt
            )

        return cv2.cvtColor(np.array(result), cv2.COLOR_RGB2BGR)

    def _controlnet_style_transfer(self, user_room_pil: Image.Image,
                                   enhanced_prompt: str, negative_prompt: str,
                                   preservation_mask: np.ndarray) -> Image.Image:
        """Use ControlNet for better structure preservation."""
        # Scale the (assumed binary 0/1) preservation mask to 0/255 for the
        # conditioning image; this presumes a segmentation-conditioned ControlNet.
        control_image = Image.fromarray((preservation_mask * 255).astype(np.uint8))

        result = self.controlnet_pipeline(
            prompt=enhanced_prompt,
            negative_prompt=negative_prompt,
            image=user_room_pil,
            control_image=control_image,
            num_inference_steps=self.config.NUM_INFERENCE_STEPS,
            guidance_scale=self.config.GUIDANCE_SCALE,
            strength=self.config.STRENGTH,
            controlnet_conditioning_scale=0.8,
        ).images[0]

        return result

    def _basic_style_transfer(self, user_room_pil: Image.Image,
                              enhanced_prompt: str, negative_prompt: str) -> Image.Image:
        """Basic style transfer using SDXL Img2Img."""
        result = self.sdxl_pipeline(
            prompt=enhanced_prompt,
            negative_prompt=negative_prompt,
            image=user_room_pil,
            num_inference_steps=self.config.NUM_INFERENCE_STEPS,
            guidance_scale=self.config.GUIDANCE_SCALE,
            strength=self.config.STRENGTH,
        ).images[0]

        return result

    def _post_process_result(self, result_image: np.ndarray, user_room: np.ndarray,
                             preservation_mask: np.ndarray, style_info: Dict[str, Any]) -> np.ndarray:
        """Post-process the result for better integration."""
        # Feather the preservation mask so preserved regions blend smoothly
        kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (20, 20))
        smooth_mask = cv2.dilate(preservation_mask, kernel, iterations=1)
        smooth_mask = cv2.GaussianBlur(smooth_mask.astype(np.float32), (21, 21), 0)
        # Normalize to [0, 1] regardless of whether the mask was 0/1 or 0/255
        max_val = smooth_mask.max()
        if max_val > 0:
            smooth_mask /= max_val

        # Keep original pixels where the mask is 1, generated pixels elsewhere
        blended = (result_image * (1 - smooth_mask[..., np.newaxis]) +
                   user_room * smooth_mask[..., np.newaxis]).astype(np.uint8)

        # Nudge colors toward the inspiration palette, then sharpen
        corrected = self._apply_color_correction(blended, style_info)
        enhanced = self._enhance_final_result(corrected)

        return enhanced

    def _apply_color_correction(self, image: np.ndarray,
                                style_info: Dict[str, Any]) -> np.ndarray:
        """Apply a simple color-temperature shift toward the inspiration style."""
        target_temp = style_info['color_palette']['color_temperature']

        # OpenCV's 8-bit LAB stores a (green-red) and b (blue-yellow) offset
        # by 128, so shift around that neutral point rather than scaling the
        # raw values, which would also tint neutral grays. The +/-5 and +/-10
        # shifts are a gentle heuristic; tune to taste.
        lab = cv2.cvtColor(image, cv2.COLOR_BGR2LAB).astype(np.float32)
        if target_temp == "warm":
            lab[:, :, 1] = np.clip(lab[:, :, 1] + 5, 0, 255)   # toward red
            lab[:, :, 2] = np.clip(lab[:, :, 2] + 10, 0, 255)  # toward yellow
        elif target_temp == "cool":
            lab[:, :, 1] = np.clip(lab[:, :, 1] - 5, 0, 255)   # toward green
            lab[:, :, 2] = np.clip(lab[:, :, 2] - 10, 0, 255)  # toward blue

        corrected = cv2.cvtColor(lab.astype(np.uint8), cv2.COLOR_LAB2BGR)

        return corrected

    def _enhance_final_result(self, image: np.ndarray) -> np.ndarray:
        """Apply final enhancements to the result."""
        # Unsharp-style sharpening: blend the image with a 3x3 sharpened copy
        kernel = np.array([[-1, -1, -1],
                           [-1, 9, -1],
                           [-1, -1, -1]])
        sharpened = cv2.filter2D(image, -1, kernel)
        enhanced = cv2.addWeighted(image, 0.7, sharpened, 0.3, 0)

        # Equalize the lightness channel in LAB to boost contrast; split the
        # channels first so equalizeHist receives a contiguous array
        lab = cv2.cvtColor(enhanced, cv2.COLOR_BGR2LAB)
        l_channel, a_channel, b_channel = cv2.split(lab)
        l_channel = cv2.equalizeHist(l_channel)
        enhanced = cv2.cvtColor(cv2.merge((l_channel, a_channel, b_channel)), cv2.COLOR_LAB2BGR)

        return enhanced

    def batch_process(self, user_rooms: list, inspiration_rooms: list,
                      output_dir: Optional[str] = None) -> list:
        """Process multiple (user room, inspiration room) pairs."""
        if output_dir is None:
            output_dir = self.config.OUTPUT_DIR
        os.makedirs(output_dir, exist_ok=True)

        results = []
        for i, (user_room, inspiration_room) in enumerate(zip(user_rooms, inspiration_rooms)):
            print(f"Processing pair {i + 1}/{len(user_rooms)}")
            try:
                result = self.transfer_style(
                    user_room, inspiration_room,
                    os.path.join(output_dir, f"result_{i + 1}.jpg")
                )
                results.append(result)
            except Exception as e:
                # Keep going on failure; record a placeholder for this pair
                print(f"Error processing pair {i + 1}: {e}")
                results.append(None)

        return results
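

# Minimal usage sketch: run the pipeline end to end. This assumes `Config`
# provides the attributes referenced above; the input paths are hypothetical
# example paths, so substitute your own images.
if __name__ == "__main__":
    config = Config()
    pipeline = InteriorStyleTransferPipeline(config)
    outputs = pipeline.transfer_style(
        "inputs/my_room.jpg",        # hypothetical example path
        "inputs/inspiration.jpg",    # hypothetical example path
    )
    print(f"Result saved to {outputs['output_path']}")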