File size: 7,392 Bytes
01c9ed4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
"""Room Segmentation Module for preserving structural elements."""
import cv2
import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image
import segmentation_models_pytorch as smp
from typing import Tuple, Dict, List
import albumentations as A
from albumentations.pytorch import ToTensorV2

class RoomSegmentation:
    """Split a room image into per-class masks (walls, windows, doors, ...).

    A DeepLabV3+ network is used when it can be constructed; otherwise the
    class falls back to colour/edge heuristics built on OpenCV.

    Mask value convention (kept as in the original implementation):
    the ML path yields ``uint8`` masks holding 0/1, the rule-based path
    yields ``uint8`` masks holding 0/255. Treat "non-zero" as "set".
    """

    # Overlay colour per class, written into the image as-is. The colour
    # names in the comments assume a BGR image, which is what
    # ``segment_room`` documents as its input format.
    _CLASS_COLORS = {
        'wall': [255, 0, 0],          # blue
        'window': [0, 255, 0],        # green
        'door': [0, 0, 255],          # red
        'ceiling': [255, 255, 0],     # cyan
        'floor': [255, 0, 255],       # magenta
        'furniture': [0, 255, 255],   # yellow
        'lighting': [128, 128, 128],  # gray
        'decoration': [64, 64, 64],   # dark gray
        'textile': [192, 192, 192],   # light gray
    }

    # Minimum absolute horizontal gradient for a pixel to count as part of a
    # vertical edge in the rule-based door heuristic.
    _DOOR_EDGE_THRESHOLD = 50

    def __init__(self, device: str = "cuda"):
        """Store the target device and try to build the segmentation model.

        Args:
            device: Torch device string the model and inputs are moved to.
        """
        self.device = device
        self.model = None
        self.class_names = [
            'wall', 'window', 'door', 'ceiling', 'floor',
            'furniture', 'lighting', 'decoration', 'textile'
        ]
        self.load_model()

    def load_model(self):
        """Build the DeepLabV3+ model, or leave ``self.model`` as ``None``.

        Any failure (missing package, unavailable device, download error) is
        reported on stdout and switches the instance to rule-based mode.
        """
        try:
            # Using DeepLabV3+ with ResNet50 backbone.
            # The head emits raw logits (activation=None) because
            # _ml_segmentation applies softmax itself; activating in both
            # places would squash every class probability towards 1/N.
            # NOTE(review): no task-specific checkpoint is loaded here, only
            # ImageNet encoder weights are requested - confirm the decoder is
            # trained/loaded elsewhere before relying on the ML masks.
            self.model = smp.DeepLabV3Plus(
                encoder_name="resnet50",
                encoder_weights="imagenet",
                classes=len(self.class_names),
                activation=None
            )
            self.model.to(self.device)
            self.model.eval()
        except Exception as e:
            # Deliberate best-effort: the rule-based path needs no model.
            print(f"Warning: Could not load segmentation model: {e}")
            print("Falling back to rule-based segmentation")
            self.model = None

    def segment_room(self, image: np.ndarray) -> Dict[str, np.ndarray]:
        """Segment a room image into its components.

        Args:
            image: Input room image (BGR format).

        Returns:
            Dictionary with class names as keys and binary masks as values.
            The ML path returns one mask per entry of ``self.class_names``;
            the rule-based path returns only 'wall', 'window', 'floor',
            'ceiling', 'door' and 'furniture'.
        """
        if self.model is not None:
            return self._ml_segmentation(image)
        return self._rule_based_segmentation(image)

    def _ml_segmentation(self, image: np.ndarray) -> Dict[str, np.ndarray]:
        """Run the network and threshold its per-class probabilities.

        Args:
            image: Input room image (BGR format).

        Returns:
            Dictionary mapping each class name to a 0/1 ``uint8`` mask with
            the same height and width as ``image``.
        """
        # A.Normalize() defaults to ImageNet statistics given in RGB order,
        # so the BGR input has to be reordered first.
        rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        transform = A.Compose([
            A.Resize(512, 512),
            A.Normalize(),
            ToTensorV2()
        ])
        input_tensor = transform(image=rgb)['image'].unsqueeze(0).to(self.device)

        with torch.no_grad():
            logits = self.model(input_tensor)
            # Index the batch axis explicitly instead of squeeze(), which
            # would also drop any other axis that happens to have length 1.
            probabilities = F.softmax(logits, dim=1)[0].cpu().numpy()

        height, width = image.shape[:2]
        masks = {}
        for index, class_name in enumerate(self.class_names):
            mask = (probabilities[index] > 0.5).astype(np.uint8)
            # Nearest-neighbour keeps the mask strictly binary when scaling
            # back to the original resolution.
            masks[class_name] = cv2.resize(
                mask, (width, height), interpolation=cv2.INTER_NEAREST
            )

        return masks

    def _rule_based_segmentation(self, image: np.ndarray) -> Dict[str, np.ndarray]:
        """Approximate the segmentation with colour and edge heuristics.

        Args:
            image: Input room image (BGR format).

        Returns:
            Dictionary of 0/255 ``uint8`` masks for 'wall', 'window',
            'floor', 'ceiling', 'door' and 'furniture'.
        """
        hsv = cv2.cvtColor(image, cv2.COLOR_BGR2HSV)
        gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
        height, width = image.shape[:2]

        masks = {}

        # Wall detection (light colored, large regions)
        light_mask = cv2.inRange(hsv, np.array([0, 0, 200]), np.array([180, 30, 255]))
        masks['wall'] = self._morphological_cleanup(light_mask)

        # Window detection (bright regions that coincide with Canny edges)
        bright_mask = cv2.inRange(hsv, np.array([0, 0, 220]), np.array([180, 30, 255]))
        edges = cv2.Canny(gray, 50, 150)
        windows = cv2.bitwise_and(bright_mask, edges)
        masks['window'] = self._morphological_cleanup(windows)

        # Floor detection (bottom half, darker)
        floor_region = np.zeros((height, width), dtype=np.uint8)
        floor_region[height // 2:, :] = 255
        dark_mask = cv2.inRange(hsv, np.array([0, 0, 0]), np.array([180, 255, 100]))
        floor = cv2.bitwise_and(floor_region, dark_mask)
        masks['floor'] = self._morphological_cleanup(floor)

        # Ceiling detection (top third, light)
        ceiling_region = np.zeros((height, width), dtype=np.uint8)
        ceiling_region[:height // 3, :] = 255
        ceiling = cv2.bitwise_and(ceiling_region, light_mask)
        masks['ceiling'] = self._morphological_cleanup(ceiling)

        # Door detection (vertical edges in the middle band of the image)
        door_region = np.zeros((height, width), dtype=np.uint8)
        door_region[height // 4:3 * height // 4, :] = 255
        # Vertical edges are intensity changes along x, i.e. dx=1, dy=0.
        gradient_x = cv2.Sobel(gray, cv2.CV_64F, 1, 0, ksize=3)
        # convertScaleAbs saturates at 255; a plain uint8 cast would wrap
        # strong gradients around to small values.
        edge_strength = cv2.convertScaleAbs(gradient_x)
        _, vertical_edges = cv2.threshold(
            edge_strength, self._DOOR_EDGE_THRESHOLD, 255, cv2.THRESH_BINARY
        )
        doors = cv2.bitwise_and(door_region, vertical_edges)
        masks['door'] = self._morphological_cleanup(doors)

        # Furniture (everything else). The union is built with bitwise OR:
        # adding uint8 masks overflows wherever two of them overlap.
        all_structural = np.zeros((height, width), dtype=np.uint8)
        for structural_class in ('wall', 'window', 'floor', 'ceiling', 'door'):
            all_structural = cv2.bitwise_or(all_structural, masks[structural_class])
        masks['furniture'] = cv2.bitwise_not(all_structural)

        return masks

    def _morphological_cleanup(self, mask: np.ndarray) -> np.ndarray:
        """Close small holes, then remove small specks, with a 5x5 ellipse.

        Args:
            mask: Single-channel mask to clean.

        Returns:
            The cleaned mask (a new array; ``mask`` is not modified).
        """
        kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5))
        cleaned = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, kernel)
        cleaned = cv2.morphologyEx(cleaned, cv2.MORPH_OPEN, kernel)
        return cleaned

    def create_preservation_mask(self, masks: Dict[str, np.ndarray],
                                preserve_classes: List[str]) -> np.ndarray:
        """Create a mask for elements that should be preserved.

        Args:
            masks: Dictionary of segmentation masks (all the same shape).
            preserve_classes: List of class names to preserve. Names that
                are not present in ``masks`` are ignored.

        Returns:
            Mask with the shape and dtype of the input masks in which
            non-zero pixels mark preserved regions (the value is whatever
            the input masks use: 1 for ML masks, 255 for rule-based masks).

        Raises:
            ValueError: If ``masks`` is empty, because the output shape
                cannot be determined.
        """
        if not masks:
            raise ValueError("masks must contain at least one segmentation mask")

        preservation_mask = np.zeros_like(next(iter(masks.values())))

        for class_name in preserve_classes:
            if class_name in masks:
                preservation_mask = cv2.bitwise_or(preservation_mask, masks[class_name])

        # Dilate to ensure complete coverage
        kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (10, 10))
        preservation_mask = cv2.dilate(preservation_mask, kernel, iterations=2)

        return preservation_mask

    def visualize_segmentation(self, image: np.ndarray, masks: Dict[str, np.ndarray]) -> np.ndarray:
        """Paint every mask onto a copy of the image in its class colour.

        Colours are looked up by class name, so the result does not depend
        on the order in which ``masks`` was populated. Masks are painted in
        dictionary order; where masks overlap, the later one wins.

        Args:
            image: Three-channel image to annotate; it is not modified.
            masks: Dictionary of segmentation masks keyed by class name.
                Entries whose name has no assigned colour are skipped.

        Returns:
            Annotated copy of ``image``.
        """
        result = image.copy()
        for class_name, mask in masks.items():
            color = self._CLASS_COLORS.get(class_name)
            if color is None:
                continue
            result[mask > 0] = color

        return result