# File size: 7,392 bytes
# Snapshot id (from the file viewer): 01c9ed4
"""
Room Segmentation Module for preserving structural elements
"""
import cv2
import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image
import segmentation_models_pytorch as smp
from typing import Tuple, Dict, List
import albumentations as A
from albumentations.pytorch import ToTensorV2
class RoomSegmentation:
    """Split a room photograph into per-class masks (wall, window, door, ...).

    A DeepLabV3+ network is used when it can be constructed and moved to the
    requested device; otherwise a colour/edge heuristic is used instead.

    Mask value conventions differ between the two back ends: the network path
    returns 0/1 masks, the heuristic path returns 0/255 masks (the 'door'
    mask holds gradient magnitudes in 0..255). Treat any non-zero pixel as
    "belongs to the class".
    """

    # Overlay colours used by visualize_segmentation, keyed by class name so
    # the colour of a class does not depend on dictionary ordering.
    # Colour names assume OpenCV's BGR channel order, matching the image
    # format documented on segment_room.
    _CLASS_COLORS = {
        'wall': (255, 0, 0),           # Blue
        'window': (0, 255, 0),         # Green
        'door': (0, 0, 255),           # Red
        'ceiling': (255, 255, 0),      # Cyan
        'floor': (255, 0, 255),        # Magenta
        'furniture': (0, 255, 255),    # Yellow
        'lighting': (128, 128, 128),   # Gray
        'decoration': (64, 64, 64),    # Dark gray
        'textile': (192, 192, 192),    # Light gray
    }

    def __init__(self, device: str = "cuda"):
        """
        Create the segmenter and try to build the network immediately.

        Args:
            device: Torch device string the model and inputs are moved to
        """
        self.device = device
        self.model = None
        self.class_names = [
            'wall', 'window', 'door', 'ceiling', 'floor',
            'furniture', 'lighting', 'decoration', 'textile'
        ]
        self.load_model()

    def load_model(self):
        """Build the segmentation network; on any failure fall back to rules.

        On failure ``self.model`` is left as ``None`` so that segment_room
        uses the rule-based path.
        """
        try:
            # Using DeepLabV3+ with ResNet50 backbone.
            # NOTE: encoder_weights only initialises the encoder; no
            # segmentation checkpoint is loaded here, so the decoder and head
            # are untrained until weights are loaded elsewhere.
            self.model = smp.DeepLabV3Plus(
                encoder_name="resnet50",
                encoder_weights="imagenet",
                classes=len(self.class_names),
                # Emit raw logits: _ml_segmentation applies softmax exactly
                # once. Applying it twice flattens every class probability
                # below the 0.5 threshold and produces empty masks.
                activation=None
            )
            self.model.to(self.device)
            self.model.eval()
        except Exception as e:
            # Deliberate best-effort: any construction/device error (missing
            # weights, no GPU, ...) downgrades to the heuristic segmenter.
            print(f"Warning: Could not load segmentation model: {e}")
            print("Falling back to rule-based segmentation")
            self.model = None

    def segment_room(self, image: np.ndarray) -> Dict[str, np.ndarray]:
        """
        Segment room image into different components

        Args:
            image: Input room image (BGR format)

        Returns:
            Dictionary with class names as keys and masks as values; non-zero
            pixels belong to the class
        """
        if self.model is None:
            return self._rule_based_segmentation(image)
        return self._ml_segmentation(image)

    def _ml_segmentation(self, image: np.ndarray) -> Dict[str, np.ndarray]:
        """
        Segment with the neural network.

        Args:
            image: Input room image (BGR format)

        Returns:
            Dictionary mapping every name in ``self.class_names`` to a uint8
            0/1 mask with the same height and width as ``image``
        """
        # A.Normalize() defaults to ImageNet statistics, which are given in
        # RGB order, so the BGR input has to be reordered first.
        rgb_image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        transform = A.Compose([
            A.Resize(512, 512),
            A.Normalize(),
            ToTensorV2()
        ])
        input_tensor = transform(image=rgb_image)['image']
        input_tensor = input_tensor.unsqueeze(0).to(self.device)

        with torch.no_grad():
            logits = self.model(input_tensor)
            # Index the single batch element explicitly instead of squeeze(),
            # which would also drop a class axis of size 1.
            probabilities = F.softmax(logits, dim=1)[0].cpu().numpy()

        height, width = image.shape[:2]
        masks = {}
        for index, class_name in enumerate(self.class_names):
            mask = (probabilities[index] > 0.5).astype(np.uint8)
            # Nearest-neighbour keeps the resized mask strictly binary.
            masks[class_name] = cv2.resize(
                mask, (width, height), interpolation=cv2.INTER_NEAREST
            )
        return masks

    def _rule_based_segmentation(self, image: np.ndarray) -> Dict[str, np.ndarray]:
        """
        Segment with colour thresholds, edges and position priors.

        Args:
            image: Input room image (BGR format)

        Returns:
            Dictionary with 'wall', 'window', 'floor', 'ceiling', 'door' and
            'furniture' uint8 masks (non-zero means the pixel is in the class)
        """
        hsv = cv2.cvtColor(image, cv2.COLOR_BGR2HSV)
        gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
        height, width = image.shape[:2]
        masks = {}

        # Wall detection (light colored, large regions)
        light_mask = cv2.inRange(hsv, np.array([0, 0, 200]), np.array([180, 30, 255]))
        masks['wall'] = self._morphological_cleanup(light_mask)

        # Window detection (bright regions that also lie on a Canny edge)
        bright_mask = cv2.inRange(hsv, np.array([0, 0, 220]), np.array([180, 30, 255]))
        edges = cv2.Canny(gray, 50, 150)
        windows = cv2.bitwise_and(bright_mask, edges)
        masks['window'] = self._morphological_cleanup(windows)

        # Floor detection (bottom half, darker)
        floor_region = np.zeros((height, width), dtype=np.uint8)
        floor_region[height // 2:, :] = 255
        dark_mask = cv2.inRange(hsv, np.array([0, 0, 0]), np.array([180, 255, 100]))
        floor = cv2.bitwise_and(floor_region, dark_mask)
        masks['floor'] = self._morphological_cleanup(floor)

        # Ceiling detection (top third, light)
        ceiling_region = np.zeros((height, width), dtype=np.uint8)
        ceiling_region[:height // 3, :] = 255
        ceiling = cv2.bitwise_and(ceiling_region, light_mask)
        masks['ceiling'] = self._morphological_cleanup(ceiling)

        # Door detection (vertical edges, middle band of the image)
        door_region = np.zeros((height, width), dtype=np.uint8)
        door_region[height // 4:3 * height // 4, :] = 255
        # Vertical edges are intensity changes along x, i.e. dx=1, dy=0.
        horizontal_gradient = cv2.Sobel(gray, cv2.CV_64F, 1, 0, ksize=3)
        # convertScaleAbs saturates at 255; a plain uint8 cast would wrap
        # strong gradients around to small values.
        vertical_edges = cv2.convertScaleAbs(horizontal_gradient)
        doors = cv2.bitwise_and(door_region, vertical_edges)
        masks['door'] = self._morphological_cleanup(doors)

        # Furniture (everything else). Build the union of the structural
        # masks from their non-zero pixels: adding uint8 masks would wrap
        # modulo 256 wherever two masks overlap.
        structural = np.zeros((height, width), dtype=np.uint8)
        for name in ('wall', 'window', 'floor', 'ceiling', 'door'):
            structural[masks[name] > 0] = 255
        masks['furniture'] = cv2.bitwise_not(structural)

        return masks

    def _morphological_cleanup(self, mask: np.ndarray) -> np.ndarray:
        """
        Clean up a mask with a morphological close followed by an open.

        Args:
            mask: Single-channel mask to clean

        Returns:
            The cleaned mask (5x5 elliptical structuring element)
        """
        kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5))
        closed = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, kernel)
        return cv2.morphologyEx(closed, cv2.MORPH_OPEN, kernel)

    def create_preservation_mask(self, masks: Dict[str, np.ndarray],
                                 preserve_classes: List[str]) -> np.ndarray:
        """
        Create a mask for elements that should be preserved

        Args:
            masks: Dictionary of segmentation masks
            preserve_classes: List of class names to preserve; names missing
                from ``masks`` are ignored

        Returns:
            Mask with the shape and dtype of the first entry of ``masks``;
            non-zero pixels mark preserved regions (the value is whatever the
            input masks use, e.g. 1 or 255)

        Raises:
            IndexError: If ``masks`` is empty
        """
        preservation_mask = np.zeros_like(list(masks.values())[0])
        for class_name in preserve_classes:
            if class_name in masks:
                preservation_mask = cv2.bitwise_or(preservation_mask, masks[class_name])

        # Dilate to ensure complete coverage
        kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (10, 10))
        return cv2.dilate(preservation_mask, kernel, iterations=2)

    def visualize_segmentation(self, image: np.ndarray, masks: Dict[str, np.ndarray]) -> np.ndarray:
        """
        Paint each class mask onto a copy of the image.

        Colours are looked up by class name, so the result does not depend on
        the order of ``masks``. Where masks overlap, the one that comes later
        in ``masks`` wins. Class names without a palette entry are skipped.

        Args:
            image: Image to annotate (left unmodified)
            masks: Dictionary of segmentation masks matching the image size

        Returns:
            Copy of ``image`` with non-zero mask pixels replaced by the
            class colour
        """
        result = image.copy()
        for class_name, mask in masks.items():
            color = self._CLASS_COLORS.get(class_name)
            if color is None:
                continue
            result[mask > 0] = color
        return result
|