"""
Room Segmentation Module for preserving structural elements
"""
import cv2
import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image
import segmentation_models_pytorch as smp
from typing import Tuple, Dict, List
import albumentations as A
from albumentations.pytorch import ToTensorV2
class RoomSegmentation:
    def __init__(self, device: str = "cuda"):
        self.device = device
        self.model = None
        self.class_names = [
            'wall', 'window', 'door', 'ceiling', 'floor',
            'furniture', 'lighting', 'decoration', 'textile'
        ]
        self.load_model()
    def load_model(self):
        """Load pre-trained segmentation model"""
        try:
            # DeepLabV3+ with a ResNet50 encoder (ImageNet weights).
            # activation=None so the model returns raw logits; softmax is
            # applied once, explicitly, in _ml_segmentation.
            self.model = smp.DeepLabV3Plus(
                encoder_name="resnet50",
                encoder_weights="imagenet",
                classes=len(self.class_names),
                activation=None
            )
            self.model.to(self.device)
            self.model.eval()
        except Exception as e:
            print(f"Warning: Could not load segmentation model: {e}")
            print("Falling back to rule-based segmentation")
            self.model = None
    def segment_room(self, image: np.ndarray) -> Dict[str, np.ndarray]:
        """
        Segment room image into different components

        Args:
            image: Input room image (BGR format)

        Returns:
            Dictionary with class names as keys and binary masks as values
        """
        if self.model is not None:
            return self._ml_segmentation(image)
        else:
            return self._rule_based_segmentation(image)
    def _ml_segmentation(self, image: np.ndarray) -> Dict[str, np.ndarray]:
        """ML-based segmentation using pre-trained model"""
        # Preprocess: A.Normalize() uses ImageNet statistics, which assume RGB,
        # so convert the BGR input first.
        rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        transform = A.Compose([
            A.Resize(512, 512),
            A.Normalize(),
            ToTensorV2()
        ])
        transformed = transform(image=rgb)
        input_tensor = transformed['image'].unsqueeze(0).to(self.device)

        with torch.no_grad():
            output = self.model(input_tensor)
            predictions = F.softmax(output, dim=1)
            predictions = predictions.squeeze(0).cpu().numpy()

        # Convert per-class probabilities to binary masks
        masks = {}
        for i, class_name in enumerate(self.class_names):
            mask = (predictions[i] > 0.5).astype(np.uint8) * 255
            # Resize back to the original size; nearest-neighbour keeps the mask binary
            mask = cv2.resize(mask, (image.shape[1], image.shape[0]),
                              interpolation=cv2.INTER_NEAREST)
            masks[class_name] = mask
        return masks
    def _rule_based_segmentation(self, image: np.ndarray) -> Dict[str, np.ndarray]:
        """Rule-based segmentation using color and edge detection"""
        # Convert to different color spaces
        hsv = cv2.cvtColor(image, cv2.COLOR_BGR2HSV)
        gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
        masks = {}

        # Wall detection (light colored, large regions)
        light_mask = cv2.inRange(hsv, np.array([0, 0, 200]), np.array([180, 30, 255]))
        walls = self._morphological_cleanup(light_mask)
        masks['wall'] = walls

        # Window detection (bright, rectangular regions)
        bright_mask = cv2.inRange(hsv, np.array([0, 0, 220]), np.array([180, 30, 255]))
        edges = cv2.Canny(gray, 50, 150)
        windows = cv2.bitwise_and(bright_mask, edges)
        masks['window'] = self._morphological_cleanup(windows)

        # Floor detection (bottom portion, darker)
        height, width = image.shape[:2]
        floor_mask = np.zeros((height, width), dtype=np.uint8)
        floor_mask[height // 2:, :] = 255
        dark_mask = cv2.inRange(hsv, np.array([0, 0, 0]), np.array([180, 255, 100]))
        floor = cv2.bitwise_and(floor_mask, dark_mask)
        masks['floor'] = self._morphological_cleanup(floor)

        # Ceiling detection (top portion, light)
        ceiling_mask = np.zeros((height, width), dtype=np.uint8)
        ceiling_mask[:height // 3, :] = 255
        ceiling = cv2.bitwise_and(ceiling_mask, light_mask)
        masks['ceiling'] = self._morphological_cleanup(ceiling)

        # Door detection (vertical edges in the middle band of the image).
        # Sobel with dx=1, dy=0 responds to vertical edges; the gradient is
        # thresholded to a binary mask before combining.
        door_mask = np.zeros((height, width), dtype=np.uint8)
        door_mask[height // 4:3 * height // 4, :] = 255
        vertical_edges = cv2.Sobel(gray, cv2.CV_64F, 1, 0, ksize=3)
        vertical_edges = np.uint8(np.absolute(vertical_edges))
        _, vertical_edges = cv2.threshold(vertical_edges, 50, 255, cv2.THRESH_BINARY)
        doors = cv2.bitwise_and(door_mask, vertical_edges)
        masks['door'] = self._morphological_cleanup(doors)

        # Furniture: everything not covered by a structural mask.
        # bitwise_or avoids the uint8 overflow that summing the masks would cause.
        all_structural = masks['wall']
        for name in ('window', 'floor', 'ceiling', 'door'):
            all_structural = cv2.bitwise_or(all_structural, masks[name])
        masks['furniture'] = cv2.bitwise_not(all_structural)
        return masks
    def _morphological_cleanup(self, mask: np.ndarray) -> np.ndarray:
        """Clean up mask using morphological operations"""
        kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5))
        cleaned = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, kernel)
        cleaned = cv2.morphologyEx(cleaned, cv2.MORPH_OPEN, kernel)
        return cleaned
    def create_preservation_mask(self, masks: Dict[str, np.ndarray],
                                 preserve_classes: List[str]) -> np.ndarray:
        """
        Create a mask for elements that should be preserved

        Args:
            masks: Dictionary of segmentation masks
            preserve_classes: List of class names to preserve

        Returns:
            Binary mask where non-zero values indicate preserved regions
        """
        preservation_mask = np.zeros_like(list(masks.values())[0])
        for class_name in preserve_classes:
            if class_name in masks:
                preservation_mask = cv2.bitwise_or(preservation_mask, masks[class_name])

        # Dilate to ensure complete coverage of the preserved structures
        kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (10, 10))
        preservation_mask = cv2.dilate(preservation_mask, kernel, iterations=2)
        return preservation_mask
    def visualize_segmentation(self, image: np.ndarray, masks: Dict[str, np.ndarray]) -> np.ndarray:
        """Visualize segmentation results"""
        # Colors are in BGR order to match the OpenCV image and are keyed by
        # class name, so the mapping does not depend on dict iteration order.
        colors = {
            'wall': [0, 0, 255],           # red
            'window': [0, 255, 0],         # green
            'door': [255, 0, 0],           # blue
            'ceiling': [255, 255, 0],      # cyan
            'floor': [255, 0, 255],        # magenta
            'furniture': [0, 255, 255],    # yellow
            'lighting': [128, 128, 128],   # gray
            'decoration': [64, 64, 64],    # dark gray
            'textile': [192, 192, 192]     # light gray
        }
        result = image.copy()
        for class_name, mask in masks.items():
            if class_name in colors:
                result[mask > 0] = colors[class_name]
        return result
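

# Minimal usage sketch (the input and output file names below are placeholders,
# not part of the module): segment a room photo, build a preservation mask for
# the structural classes, and save a color-coded overlay.
if __name__ == "__main__":
    segmenter = RoomSegmentation(device="cuda" if torch.cuda.is_available() else "cpu")

    image = cv2.imread("room.jpg")  # hypothetical input path
    if image is None:
        raise FileNotFoundError("room.jpg not found")

    masks = segmenter.segment_room(image)
    preserve = segmenter.create_preservation_mask(
        masks, preserve_classes=['wall', 'window', 'door', 'ceiling', 'floor'])
    overlay = segmenter.visualize_segmentation(image, masks)

    cv2.imwrite("segmentation_overlay.png", overlay)
    cv2.imwrite("preservation_mask.png", preserve)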