|
|
"""
Room Segmentation Module for preserving structural elements
"""
|
|
|
import cv2
|
|
|
import numpy as np
|
|
|
import torch
|
|
|
import torch.nn.functional as F
|
|
|
from PIL import Image
|
|
|
import segmentation_models_pytorch as smp
|
|
|
from typing import Tuple, Dict, List
|
|
|
import albumentations as A
|
|
|
from albumentations.pytorch import ToTensorV2
|
|
|
|
|
|
class RoomSegmentation:
    """Segment a room image into structural and decorative classes.

    Two back ends are available. When the DeepLabV3+ network can be built
    (see ``load_model``) it is used for inference; otherwise a set of
    colour/edge heuristics is applied instead.

    Mask value conventions differ between the back ends: the network path
    produces ``uint8`` masks holding 0/1, while the heuristic path produces
    ``uint8`` masks holding 0/255. Treat any non-zero value as "set".
    """

    # Side length (pixels) of the square image fed to the network.
    _MODEL_INPUT_SIZE = 512

    # A pixel belongs to a class when its probability exceeds this value.
    _PROBABILITY_THRESHOLD = 0.5

    # Heuristic: minimum horizontal-gradient magnitude (0-255, saturated)
    # for a pixel to count as part of a vertical edge in the door detector.
    _DOOR_EDGE_THRESHOLD = 50

    # Classes that the heuristic back end treats as room structure; every
    # pixel outside their union is reported as furniture.
    _STRUCTURAL_CLASSES = ('wall', 'window', 'floor', 'ceiling', 'door')

    # Overlay colours, indexed in the same order as ``class_names``. The
    # triples are written to the image as-is, so their meaning (RGB vs BGR)
    # follows the channel order of the image being annotated.
    _PALETTE = (
        (255, 0, 0),
        (0, 255, 0),
        (0, 0, 255),
        (255, 255, 0),
        (255, 0, 255),
        (0, 255, 255),
        (128, 128, 128),
        (64, 64, 64),
        (192, 192, 192),
    )

    def __init__(self, device: str = "cuda"):
        """Store the target device and try to build the network.

        Args:
            device: Torch device string the model and inputs are moved to.
        """
        self.device = device
        self.model = None
        self.class_names = [
            'wall', 'window', 'door', 'ceiling', 'floor',
            'furniture', 'lighting', 'decoration', 'textile'
        ]
        self.load_model()

    def load_model(self):
        """Build the DeepLabV3+ network, or fall back to heuristics.

        On any failure (missing package, weight download error, device not
        available, ...) a warning is printed and ``self.model`` is set to
        ``None`` so that ``segment_room`` uses the rule-based path.

        NOTE(review): only the encoder is created with ImageNet weights; no
        segmentation checkpoint is loaded here, so the decoder and head keep
        whatever initialisation the library gives them — confirm whether
        trained weights are loaded elsewhere.
        """
        try:
            self.model = smp.DeepLabV3Plus(
                encoder_name="resnet50",
                encoder_weights="imagenet",
                classes=len(self.class_names),
                # Softmax over the channel axis. The model therefore emits
                # per-class probabilities and callers must not apply a
                # second softmax on top of its output.
                activation="softmax2d",
            )
            self.model.to(self.device)
            self.model.eval()
        except Exception as e:
            # Deliberately broad: this is a best-effort load and every
            # failure mode leads to the same rule-based fallback.
            print(f"Warning: Could not load segmentation model: {e}")
            print("Falling back to rule-based segmentation")
            self.model = None

    def segment_room(self, image: np.ndarray) -> Dict[str, np.ndarray]:
        """
        Segment room image into different components

        Args:
            image: Input room image (BGR format)

        Returns:
            Dictionary with class names as keys and binary masks as values.
            Masks have the same height and width as ``image``; non-zero
            marks pixels belonging to the class.

        Raises:
            ValueError: If ``image`` is ``None`` or has no pixels.
        """
        if image is None or image.size == 0:
            raise ValueError("image must be a non-empty array")

        if self.model is not None:
            return self._ml_segmentation(image)
        return self._rule_based_segmentation(image)

    def _ml_segmentation(self, image: np.ndarray) -> Dict[str, np.ndarray]:
        """ML-based segmentation using the network built by ``load_model``.

        Args:
            image: Input room image (BGR format).

        Returns:
            One ``uint8`` mask (values 0/1) per entry of ``class_names``,
            resized back to the input resolution.
        """
        # The encoder and the default normalisation statistics are defined
        # for RGB input, while the caller supplies BGR.
        rgb_image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        transform = A.Compose([
            A.Resize(self._MODEL_INPUT_SIZE, self._MODEL_INPUT_SIZE),
            A.Normalize(),
            ToTensorV2()
        ])
        input_tensor = transform(image=rgb_image)['image']
        input_tensor = input_tensor.unsqueeze(0).to(self.device)

        with torch.no_grad():
            # Already probabilities: the model ends in a softmax layer.
            probabilities = self.model(input_tensor)

        # Drop only the batch axis; a blanket squeeze() would also remove
        # any other axis that happens to have length one.
        probabilities = probabilities[0].cpu().numpy()

        height, width = image.shape[:2]
        masks = {}
        for i, class_name in enumerate(self.class_names):
            mask = probabilities[i] > self._PROBABILITY_THRESHOLD
            mask = mask.astype(np.uint8)
            # Nearest-neighbour keeps the mask strictly binary when scaling.
            masks[class_name] = cv2.resize(
                mask, (width, height), interpolation=cv2.INTER_NEAREST
            )

        return masks

    def _rule_based_segmentation(self, image: np.ndarray) -> Dict[str, np.ndarray]:
        """Rule-based segmentation using color and edge detection.

        Args:
            image: Input room image (BGR format).

        Returns:
            One ``uint8`` mask (values 0/255) per entry of ``class_names``.
            Classes without a heuristic get an all-zero mask so the result
            has the same keys as the ML path.
        """
        hsv = cv2.cvtColor(image, cv2.COLOR_BGR2HSV)
        gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
        height, width = image.shape[:2]

        masks = {}

        # Walls: bright, weakly saturated pixels anywhere in the frame.
        light_mask = cv2.inRange(hsv, np.array([0, 0, 200]), np.array([180, 30, 255]))
        masks['wall'] = self._morphological_cleanup(light_mask)

        # Windows: very bright, weakly saturated pixels that lie on an edge.
        bright_mask = cv2.inRange(hsv, np.array([0, 0, 220]), np.array([180, 30, 255]))
        edges = cv2.Canny(gray, 50, 150)
        windows = cv2.bitwise_and(bright_mask, edges)
        masks['window'] = self._morphological_cleanup(windows)

        # Floor: dark pixels in the lower half of the frame.
        floor_region = np.zeros((height, width), dtype=np.uint8)
        floor_region[height // 2:, :] = 255
        dark_mask = cv2.inRange(hsv, np.array([0, 0, 0]), np.array([180, 255, 100]))
        floor = cv2.bitwise_and(floor_region, dark_mask)
        masks['floor'] = self._morphological_cleanup(floor)

        # Ceiling: light pixels in the upper third of the frame.
        ceiling_region = np.zeros((height, width), dtype=np.uint8)
        ceiling_region[:height // 3, :] = 255
        ceiling = cv2.bitwise_and(ceiling_region, light_mask)
        masks['ceiling'] = self._morphological_cleanup(ceiling)

        # Doors: vertical edges in the middle band of the frame. Vertical
        # edges are intensity changes along x, hence dx=1, dy=0.
        door_region = np.zeros((height, width), dtype=np.uint8)
        door_region[height // 4:3 * height // 4, :] = 255
        horizontal_gradient = cv2.Sobel(gray, cv2.CV_64F, 1, 0, ksize=3)
        # convertScaleAbs saturates at 255; a plain uint8 cast would wrap
        # strong gradients around to small values.
        gradient_magnitude = cv2.convertScaleAbs(horizontal_gradient)
        _, vertical_edges = cv2.threshold(
            gradient_magnitude, self._DOOR_EDGE_THRESHOLD, 255, cv2.THRESH_BINARY
        )
        doors = cv2.bitwise_and(door_region, vertical_edges)
        masks['door'] = self._morphological_cleanup(doors)

        # Furniture: everything not claimed by a structural class. The union
        # is a bitwise OR because adding uint8 masks overflows on overlap.
        structural = np.zeros((height, width), dtype=np.uint8)
        for class_name in self._STRUCTURAL_CLASSES:
            structural = cv2.bitwise_or(structural, masks[class_name])
        masks['furniture'] = cv2.bitwise_not(structural)

        # No heuristic exists for the remaining classes; report them empty.
        for class_name in self.class_names:
            if class_name not in masks:
                masks[class_name] = np.zeros((height, width), dtype=np.uint8)

        return masks

    def _morphological_cleanup(self, mask: np.ndarray) -> np.ndarray:
        """Clean up mask using morphological operations.

        Applies a closing followed by an opening, both with a 5x5 elliptical
        kernel.

        Args:
            mask: Single-channel mask to clean.

        Returns:
            The cleaned mask.
        """
        kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5))
        cleaned = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, kernel)
        cleaned = cv2.morphologyEx(cleaned, cv2.MORPH_OPEN, kernel)
        return cleaned

    def create_preservation_mask(self, masks: Dict[str, np.ndarray],
                                 preserve_classes: List[str]) -> np.ndarray:
        """
        Create a mask for elements that should be preserved

        Args:
            masks: Dictionary of segmentation masks
            preserve_classes: List of class names to preserve. Names that
                are not keys of ``masks`` are ignored.

        Returns:
            Mask with the shape and dtype of the first entry of ``masks``;
            non-zero indicates preserved regions. The union of the selected
            masks is dilated (10x10 elliptical kernel, two iterations) so
            the preserved area extends slightly past each region's border.

        Raises:
            ValueError: If ``masks`` is empty, since the output shape is
                taken from its first entry.
        """
        if not masks:
            raise ValueError("masks must contain at least one segmentation mask")

        preservation_mask = np.zeros_like(next(iter(masks.values())))

        for class_name in preserve_classes:
            if class_name in masks:
                preservation_mask = cv2.bitwise_or(preservation_mask, masks[class_name])

        kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (10, 10))
        preservation_mask = cv2.dilate(preservation_mask, kernel, iterations=2)

        return preservation_mask

    def visualize_segmentation(self, image: np.ndarray, masks: Dict[str, np.ndarray]) -> np.ndarray:
        """Visualize segmentation results.

        Paints every non-zero mask pixel with a solid colour on a copy of
        ``image``. Masks are painted in dictionary order, so later entries
        overwrite earlier ones where they overlap.

        A class listed in ``class_names`` always gets the colour at its
        index there, regardless of where it appears in ``masks``. Unknown
        names fall back to their position in ``masks``. Entries whose
        colour index is outside the palette are skipped.

        Args:
            image: Three-channel image to annotate; it is not modified.
            masks: Dictionary of segmentation masks matching ``image`` in
                height and width.

        Returns:
            Annotated copy of ``image``.
        """
        result = image.copy()
        for position, (class_name, mask) in enumerate(masks.items()):
            if class_name in self.class_names:
                color_index = self.class_names.index(class_name)
            else:
                color_index = position

            if color_index < len(self._PALETTE):
                result[mask > 0] = self._PALETTE[color_index]

        return result
|
|
|
|