""" Room Segmentation Module for preserving structural elements """ import cv2 import numpy as np import torch import torch.nn.functional as F from PIL import Image import segmentation_models_pytorch as smp from typing import Tuple, Dict, List import albumentations as A from albumentations.pytorch import ToTensorV2 class RoomSegmentation: def __init__(self, device: str = "cuda"): self.device = device self.model = None self.class_names = [ 'wall', 'window', 'door', 'ceiling', 'floor', 'furniture', 'lighting', 'decoration', 'textile' ] self.load_model() def load_model(self): """Load pre-trained segmentation model""" try: # Using DeepLabV3+ with ResNet50 backbone self.model = smp.DeepLabV3Plus( encoder_name="resnet50", encoder_weights="imagenet", classes=len(self.class_names), activation="softmax" ) self.model.to(self.device) self.model.eval() except Exception as e: print(f"Warning: Could not load segmentation model: {e}") print("Falling back to rule-based segmentation") self.model = None def segment_room(self, image: np.ndarray) -> Dict[str, np.ndarray]: """ Segment room image into different components Args: image: Input room image (BGR format) Returns: Dictionary with class names as keys and binary masks as values """ if self.model is not None: return self._ml_segmentation(image) else: return self._rule_based_segmentation(image) def _ml_segmentation(self, image: np.ndarray) -> Dict[str, np.ndarray]: """ML-based segmentation using pre-trained model""" # Preprocess image transform = A.Compose([ A.Resize(512, 512), A.Normalize(), ToTensorV2() ]) transformed = transform(image=image) input_tensor = transformed['image'].unsqueeze(0).to(self.device) with torch.no_grad(): output = self.model(input_tensor) predictions = F.softmax(output, dim=1) predictions = predictions.squeeze().cpu().numpy() # Convert to binary masks masks = {} for i, class_name in enumerate(self.class_names): mask = (predictions[i] > 0.5).astype(np.uint8) # Resize back to original size mask = cv2.resize(mask, (image.shape[1], image.shape[0])) masks[class_name] = mask return masks def _rule_based_segmentation(self, image: np.ndarray) -> Dict[str, np.ndarray]: """Rule-based segmentation using color and edge detection""" # Convert to different color spaces hsv = cv2.cvtColor(image, cv2.COLOR_BGR2HSV) gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY) masks = {} # Wall detection (light colored, large regions) light_mask = cv2.inRange(hsv, np.array([0, 0, 200]), np.array([180, 30, 255])) walls = self._morphological_cleanup(light_mask) masks['wall'] = walls # Window detection (bright, rectangular regions) bright_mask = cv2.inRange(hsv, np.array([0, 0, 220]), np.array([180, 30, 255])) edges = cv2.Canny(gray, 50, 150) windows = cv2.bitwise_and(bright_mask, edges) masks['window'] = self._morphological_cleanup(windows) # Floor detection (bottom portion, darker) height, width = image.shape[:2] floor_mask = np.zeros((height, width), dtype=np.uint8) floor_mask[height//2:, :] = 255 dark_mask = cv2.inRange(hsv, np.array([0, 0, 0]), np.array([180, 255, 100])) floor = cv2.bitwise_and(floor_mask, dark_mask) masks['floor'] = self._morphological_cleanup(floor) # Ceiling detection (top portion, light) ceiling_mask = np.zeros((height, width), dtype=np.uint8) ceiling_mask[:height//3, :] = 255 ceiling = cv2.bitwise_and(ceiling_mask, light_mask) masks['ceiling'] = self._morphological_cleanup(ceiling) # Door detection (vertical edges, medium height) door_mask = np.zeros((height, width), dtype=np.uint8) door_mask[height//4:3*height//4, :] = 255 vertical_edges 
= cv2.Sobel(gray, cv2.CV_64F, 0, 1, ksize=3) vertical_edges = np.uint8(np.absolute(vertical_edges)) doors = cv2.bitwise_and(door_mask, vertical_edges) masks['door'] = self._morphological_cleanup(doors) # Furniture (everything else) all_structural = sum([masks['wall'], masks['window'], masks['floor'], masks['ceiling'], masks['door']]) furniture = cv2.bitwise_not(all_structural) masks['furniture'] = furniture return masks def _morphological_cleanup(self, mask: np.ndarray) -> np.ndarray: """Clean up mask using morphological operations""" kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5)) cleaned = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, kernel) cleaned = cv2.morphologyEx(cleaned, cv2.MORPH_OPEN, kernel) return cleaned def create_preservation_mask(self, masks: Dict[str, np.ndarray], preserve_classes: List[str]) -> np.ndarray: """ Create a mask for elements that should be preserved Args: masks: Dictionary of segmentation masks preserve_classes: List of class names to preserve Returns: Binary mask where 1 indicates preserved regions """ preservation_mask = np.zeros_like(list(masks.values())[0]) for class_name in preserve_classes: if class_name in masks: preservation_mask = cv2.bitwise_or(preservation_mask, masks[class_name]) # Dilate to ensure complete coverage kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (10, 10)) preservation_mask = cv2.dilate(preservation_mask, kernel, iterations=2) return preservation_mask def visualize_segmentation(self, image: np.ndarray, masks: Dict[str, np.ndarray]) -> np.ndarray: """Visualize segmentation results""" colors = [ [255, 0, 0], # Red - walls [0, 255, 0], # Green - windows [0, 0, 255], # Blue - doors [255, 255, 0], # Cyan - ceiling [255, 0, 255], # Magenta - floor [0, 255, 255], # Yellow - furniture [128, 128, 128], # Gray - lighting [64, 64, 64], # Dark gray - decoration [192, 192, 192] # Light gray - textile ] result = image.copy() for i, (class_name, mask) in enumerate(masks.items()): if i < len(colors): color = colors[i] result[mask > 0] = color return result
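

# ---------------------------------------------------------------------------
# Minimal usage sketch. The input path "room.jpg", the output filenames, and
# the particular preserve_classes chosen below are illustrative assumptions,
# not part of the module itself.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    segmenter = RoomSegmentation(device="cuda" if torch.cuda.is_available() else "cpu")

    image = cv2.imread("room.jpg")  # BGR, as expected by segment_room()
    if image is None:
        raise FileNotFoundError("room.jpg not found - supply any room photo")

    masks = segmenter.segment_room(image)

    # Keep structural elements fixed while the rest of the scene may be edited
    preservation = segmenter.create_preservation_mask(
        masks, preserve_classes=['wall', 'window', 'door', 'ceiling', 'floor']
    )

    overlay = segmenter.visualize_segmentation(image, masks)
    cv2.imwrite("segmentation_overlay.png", overlay)
    cv2.imwrite("preservation_mask.png", preservation)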