"""

Main Style Transfer Pipeline for Interior Design

Combines segmentation, style extraction, and SDXL generation

"""
import os
from typing import Any, Dict, Optional

import cv2
import numpy as np
import torch
from PIL import Image
from diffusers import (
    ControlNetModel,
    StableDiffusionXLControlNetImg2ImgPipeline,
    StableDiffusionXLImg2ImgPipeline,
)

from config import Config
from segmentation import RoomSegmentation
from style_extractor import StyleExtractor
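
# Attributes this module expects on Config (defined in config.py; listed here
# for reference):
#   DEVICE                 - "cuda" or "cpu"
#   SDXL_MODEL_ID          - model id for the SDXL base checkpoint
#   CONTROLNET_MODEL_ID    - model id for the ControlNet weights
#   IMAGE_SIZE             - working resolution (inputs are resized square)
#   PRESERVE_CLASSES       - segmentation classes whose pixels are preserved
#   NUM_INFERENCE_STEPS    - diffusion sampling steps
#   GUIDANCE_SCALE         - classifier-free guidance weight
#   STRENGTH               - img2img denoising strength
#   BLEND_ALPHA            - post-processing blend weight
#   OUTPUT_DIR             - default directory for results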

class InteriorStyleTransferPipeline:
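    """End-to-end pipeline that segments the user's room, extracts the style
    of an inspiration room, and re-renders the user's room in that style
    with SDXL (optionally guided by ControlNet)."""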
    def __init__(self, config: Config):
        self.config = config
        self.device = config.DEVICE
        
        # Initialize components
        self.segmentation = RoomSegmentation(device=self.device)
        self.style_extractor = StyleExtractor()
        
        # Initialize SDXL pipeline
        self.sdxl_pipeline = None
        self.controlnet_pipeline = None
        self._load_models()
        
    def _load_models(self):
        """Load the SDXL and ControlNet pipelines."""
        dtype = torch.float16 if self.device == "cuda" else torch.float32
        variant = "fp16" if self.device == "cuda" else None

        try:
            print("Loading SDXL Img2Img pipeline...")
            self.sdxl_pipeline = StableDiffusionXLImg2ImgPipeline.from_pretrained(
                self.config.SDXL_MODEL_ID,
                torch_dtype=dtype,
                use_safetensors=True,
                variant=variant
            )
            self._place_pipeline(self.sdxl_pipeline)

            print("Loading ControlNet pipeline...")
            # from_pretrained expects a ControlNetModel instance for the
            # controlnet argument, not a model id string
            controlnet = ControlNetModel.from_pretrained(
                self.config.CONTROLNET_MODEL_ID,
                torch_dtype=dtype
            )
            # The Img2Img ControlNet variant is required here, since
            # generation starts from the user's room image (init image plus
            # control image plus strength)
            self.controlnet_pipeline = StableDiffusionXLControlNetImg2ImgPipeline.from_pretrained(
                self.config.SDXL_MODEL_ID,
                controlnet=controlnet,
                torch_dtype=dtype,
                use_safetensors=True,
                variant=variant
            )
            self._place_pipeline(self.controlnet_pipeline)

        except Exception as e:
            print(f"Warning: Could not load full pipeline: {e}")
            print("Falling back to basic SDXL pipeline")
            self.controlnet_pipeline = None

    def _place_pipeline(self, pipeline):
        """Move a pipeline to the target device, with CPU offload on CUDA."""
        if self.device == "cuda":
            pipeline.enable_xformers_memory_efficient_attention()
            # enable_model_cpu_offload manages device placement itself;
            # calling .to() afterwards would undo the offload hooks
            pipeline.enable_model_cpu_offload()
        else:
            pipeline.to(self.device)
    
    def transfer_style(self, user_room_path: str, inspiration_room_path: str,
                       output_path: Optional[str] = None) -> Dict[str, Any]:
        """
        Transfer the style of the inspiration room onto the user's room.

        Args:
            user_room_path: Path to the user's room image.
            inspiration_room_path: Path to the inspiration room image.
            output_path: Path to save the output image; defaults to
                OUTPUT_DIR/style_transfer_result.jpg.

        Returns:
            Dictionary containing results and metadata.
        """
        print("Starting interior style transfer...")
        
        # Load images
        user_room = cv2.imread(user_room_path)
        inspiration_room = cv2.imread(inspiration_room_path)
        
        if user_room is None or inspiration_room is None:
            raise ValueError("Could not load one or both images")
        
        # Resize images to match
        user_room = cv2.resize(user_room, (self.config.IMAGE_SIZE, self.config.IMAGE_SIZE))
        inspiration_room = cv2.resize(inspiration_room, (self.config.IMAGE_SIZE, self.config.IMAGE_SIZE))
        
        # Step 1: Segment user room to preserve structure
        print("Segmenting user room...")
        user_masks = self.segmentation.segment_room(user_room)
        preservation_mask = self.segmentation.create_preservation_mask(
            user_masks, self.config.PRESERVE_CLASSES
        )
        
        # Step 2: Extract style from inspiration room
        print("Extracting style from inspiration room...")
        style_info = self.style_extractor.extract_style(inspiration_room)
        
        # Step 3: Generate enhanced prompt
        enhanced_prompt = self._create_enhanced_prompt(style_info, user_room)
        
        # Step 4: Apply style transfer with structure preservation
        print("Applying style transfer...")
        result_image = self._apply_style_transfer(
            user_room, inspiration_room, preservation_mask, 
            enhanced_prompt, style_info
        )
        
        # Step 5: Post-process and blend
        print("Post-processing result...")
        final_result = self._post_process_result(
            result_image, user_room, preservation_mask, style_info
        )
        
        # Save results
        if output_path is None:
            output_path = os.path.join(self.config.OUTPUT_DIR, "style_transfer_result.jpg")
        
        cv2.imwrite(output_path, final_result)
        
        # Save style analysis
        style_analysis_path = output_path.replace('.jpg', '_style_analysis.json')
        self.style_extractor.save_style_analysis(style_info, style_analysis_path)
        
        # Create visualization
        viz_path = output_path.replace('.jpg', '_analysis.jpg')
        analysis_viz = self.style_extractor.visualize_style_analysis(inspiration_room, style_info)
        cv2.imwrite(viz_path, cv2.cvtColor(analysis_viz, cv2.COLOR_RGB2BGR))
        
        print(f"Style transfer completed! Results saved to {output_path}")
        
        return {
            'result_image': final_result,
            'style_info': style_info,
            'preservation_mask': preservation_mask,
            'enhanced_prompt': enhanced_prompt,
            'output_path': output_path
        }
    
    def _create_enhanced_prompt(self, style_info: Dict[str, Any],
                                user_room: np.ndarray) -> str:
        """Create an enhanced prompt from the extracted style descriptors."""
        # user_room is accepted so room-specific context can be folded into
        # the prompt later; the current prompt uses only the inspiration style
        base_prompt = style_info.get('style_category', 'interior design')
        color_temp = style_info['color_palette']['color_temperature']
        tone = style_info['color_palette']['overall_tone']

        return (
            f"professional interior design photography, {base_prompt} style, "
            f"{color_temp} color palette, {tone} lighting, "
            "high quality furniture, elegant decorations, "
            "realistic textures, professional lighting, "
            "architectural photography, 8k resolution, "
            "detailed interior design, magazine quality"
        )
    
    def _apply_style_transfer(self, user_room: np.ndarray, inspiration_room: np.ndarray,
                              preservation_mask: np.ndarray, enhanced_prompt: str,
                              style_info: Dict[str, Any]) -> np.ndarray:
        """Apply style transfer using SDXL with structure preservation."""

        # Convert to a PIL image (diffusers pipelines expect RGB/PIL input)
        user_room_pil = Image.fromarray(cv2.cvtColor(user_room, cv2.COLOR_BGR2RGB))

        negative_prompt = ("blurry, low quality, distorted, unrealistic, poor lighting, "
                           "bad architecture, ugly furniture, cluttered, messy")

        # Prefer ControlNet when available for better structure preservation
        if self.controlnet_pipeline is not None:
            result = self._controlnet_style_transfer(
                user_room_pil, enhanced_prompt, negative_prompt, preservation_mask
            )
        else:
            result = self._basic_style_transfer(
                user_room_pil, enhanced_prompt, negative_prompt
            )

        return cv2.cvtColor(np.array(result), cv2.COLOR_RGB2BGR)
    
    def _controlnet_style_transfer(self, user_room_pil: Image.Image,
                                   enhanced_prompt: str, negative_prompt: str,
                                   preservation_mask: np.ndarray) -> Image.Image:
        """Use ControlNet for better structure preservation."""

        # Build the control image from the (assumed binary {0, 1})
        # preservation mask, expanded to 3-channel RGB for the pipeline
        control_image = Image.fromarray(
            (preservation_mask * 255).astype(np.uint8)
        ).convert("RGB")

        result = self.controlnet_pipeline(
            prompt=enhanced_prompt,
            negative_prompt=negative_prompt,
            image=user_room_pil,
            control_image=control_image,
            num_inference_steps=self.config.NUM_INFERENCE_STEPS,
            guidance_scale=self.config.GUIDANCE_SCALE,
            strength=self.config.STRENGTH,
            controlnet_conditioning_scale=0.8
        ).images[0]

        return result
    
    def _basic_style_transfer(self, user_room_pil: Image.Image,
                              enhanced_prompt: str, negative_prompt: str) -> Image.Image:
        """Basic style transfer using SDXL Img2Img."""
        
        result = self.sdxl_pipeline(
            prompt=enhanced_prompt,
            negative_prompt=negative_prompt,
            image=user_room_pil,
            num_inference_steps=self.config.NUM_INFERENCE_STEPS,
            guidance_scale=self.config.GUIDANCE_SCALE,
            strength=self.config.STRENGTH
        ).images[0]
        
        return result
    
    def _post_process_result(self, result_image: np.ndarray, user_room: np.ndarray,
                             preservation_mask: np.ndarray, style_info: Dict[str, Any]) -> np.ndarray:
        """Post-process the result for better integration."""

        # Create a smooth blending mask from the binary {0, 1} preservation
        # mask; after dilation and blurring it already lies in [0, 1], so no
        # further /255 normalization is needed
        kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (20, 20))
        smooth_mask = cv2.dilate(preservation_mask, kernel, iterations=1)
        smooth_mask = cv2.GaussianBlur(smooth_mask.astype(np.float32), (21, 21), 0)

        # BLEND_ALPHA is taken to control how strongly preserved regions
        # keep the original pixels
        smooth_mask = np.clip(smooth_mask, 0.0, 1.0) * self.config.BLEND_ALPHA

        # Blend: where the mask is high, keep the original room structure
        blended = (result_image * (1 - smooth_mask[..., np.newaxis]) +
                   user_room * smooth_mask[..., np.newaxis]).astype(np.uint8)
        
        # Color correction to match inspiration style
        corrected = self._apply_color_correction(blended, style_info)
        
        # Final enhancement
        enhanced = self._enhance_final_result(corrected)
        
        return enhanced
    
    def _apply_color_correction(self, image: np.ndarray,
                                style_info: Dict[str, Any]) -> np.ndarray:
        """Apply color correction to match the inspiration style."""

        # Only the color temperature is used here; the dominant-color palette
        # remains available in style_info for finer matching
        target_temp = style_info['color_palette']['color_temperature']

        # Work in LAB; OpenCV stores the a and b channels with a 128 offset
        # in 8-bit images, so temperature shifts should be additive rather
        # than multiplicative
        lab = cv2.cvtColor(image, cv2.COLOR_BGR2LAB).astype(np.int16)

        if target_temp == "warm":
            lab[:, :, 1] += 8   # shift a (green-red) toward red
            lab[:, :, 2] += 8   # shift b (blue-yellow) toward yellow
        elif target_temp == "cool":
            lab[:, :, 1] -= 8   # shift a toward green
            lab[:, :, 2] -= 8   # shift b toward blue

        lab = np.clip(lab, 0, 255).astype(np.uint8)
        corrected = cv2.cvtColor(lab, cv2.COLOR_LAB2BGR)

        return corrected
    
    def _enhance_final_result(self, image: np.ndarray) -> np.ndarray:
        """Apply final enhancements to the result."""

        # Slight sharpening (kernel sums to 1, so brightness is preserved)
        kernel = np.array([[-1, -1, -1],
                          [-1,  9, -1],
                          [-1, -1, -1]])
        sharpened = cv2.filter2D(image, -1, kernel)

        # Blend with the original for a natural look
        enhanced = cv2.addWeighted(image, 0.7, sharpened, 0.3, 0)

        # Gentle local contrast enhancement on the L channel; CLAHE avoids
        # the harsh global shifts of full histogram equalization
        lab = cv2.cvtColor(enhanced, cv2.COLOR_BGR2LAB)
        clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
        lab[:, :, 0] = clahe.apply(lab[:, :, 0])
        enhanced = cv2.cvtColor(lab, cv2.COLOR_LAB2BGR)

        return enhanced
    
    def batch_process(self, user_rooms: list, inspiration_rooms: list,
                      output_dir: Optional[str] = None) -> list:
        """Process multiple room pairs"""
        
        if output_dir is None:
            output_dir = self.config.OUTPUT_DIR
        
        results = []
        
        for i, (user_room, inspiration_room) in enumerate(zip(user_rooms, inspiration_rooms)):
            print(f"Processing pair {i+1}/{len(user_rooms)}")
            
            try:
                result = self.transfer_style(
                    user_room, inspiration_room,
                    os.path.join(output_dir, f"result_{i+1}.jpg")
                )
                results.append(result)
            except Exception as e:
                print(f"Error processing pair {i+1}: {e}")
                results.append(None)
        
        return results
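

if __name__ == "__main__":
    # Minimal usage sketch. The image paths below are hypothetical
    # placeholders; Config is assumed to provide sensible defaults for the
    # attributes listed at the top of this module.
    config = Config()
    pipeline = InteriorStyleTransferPipeline(config)

    result = pipeline.transfer_style(
        user_room_path="examples/user_room.jpg",           # hypothetical path
        inspiration_room_path="examples/inspiration.jpg",  # hypothetical path
    )
    print(f"Done. Result written to {result['output_path']}")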