"""
Utility functions for the Interior Style Transfer Pipeline.
"""

import json
import os
from typing import Callable, List, Optional, Tuple

import cv2
import numpy as np
from PIL import Image


def load_image_safe(image_path: str, target_size: Optional[Tuple[int, int]] = None) -> np.ndarray:
    """
    Safely load an image with error handling.

    Args:
        image_path: Path to the image file
        target_size: Optional target size (width, height)

    Returns:
        Loaded image as a BGR numpy array

    Raises:
        ValueError: If the image cannot be loaded
    """
    if not os.path.exists(image_path):
        raise ValueError(f"Image file not found: {image_path}")
    image = cv2.imread(image_path)
    if image is None:
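        # Fall back to PIL, which handles more formats, and convert to BGR.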
        try:
            pil_image = Image.open(image_path)
            image = np.array(pil_image)
            if len(image.shape) == 3 and image.shape[2] == 3:
                image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
            elif len(image.shape) == 3 and image.shape[2] == 4:
                image = cv2.cvtColor(image, cv2.COLOR_RGBA2BGR)
        except Exception as e:
            raise ValueError(f"Could not load image {image_path}: {e}") from e

    if target_size:
        image = cv2.resize(image, target_size)

    return image


def save_image_safe(image: np.ndarray, output_path: str,
                    quality: int = 95) -> bool:
    """
    Safely save an image with error handling.

    Args:
        image: Image to save as a BGR numpy array
        output_path: Output file path
        quality: JPEG quality (1-100); ignored by formats without a quality setting

    Returns:
        True if successful, False otherwise
    """
    try:
        output_dir = os.path.dirname(output_path)
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)
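        # cv2.imwrite picks the encoder from the file extension; pass the JPEG
        # quality explicitly (encoders that don't use it ignore the flag).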
        success = cv2.imwrite(output_path, image, [cv2.IMWRITE_JPEG_QUALITY, quality])
        if not success:
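            # Fall back to PIL if OpenCV cannot write the format.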
            if len(image.shape) == 3 and image.shape[2] == 3:
                pil_image = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
            else:
                pil_image = Image.fromarray(image)
            pil_image.save(output_path, quality=quality)
            success = True

        return success
    except Exception as e:
        print(f"Error saving image to {output_path}: {e}")
        return False


def validate_image_pair(user_room: np.ndarray, inspiration_room: np.ndarray) -> Tuple[bool, str]:
    """
    Validate that two images are suitable for style transfer.

    Args:
        user_room: User room image
        inspiration_room: Inspiration room image

    Returns:
        Tuple of (is_valid, error_message)
    """
    if user_room.shape != inspiration_room.shape:
        return False, f"Image dimensions don't match: {user_room.shape} vs {inspiration_room.shape}"
    min_size = 256
    if user_room.shape[0] < min_size or user_room.shape[1] < min_size:
        return False, f"Images too small. Minimum size: {min_size}x{min_size}"
    aspect_ratio = user_room.shape[1] / user_room.shape[0]
    if aspect_ratio < 0.5 or aspect_ratio > 2.0:
        return False, f"Extreme aspect ratio: {aspect_ratio:.2f}. Square images work best."
    user_brightness = np.mean(cv2.cvtColor(user_room, cv2.COLOR_BGR2GRAY))
    inspiration_brightness = np.mean(cv2.cvtColor(inspiration_room, cv2.COLOR_BGR2GRAY))

    if user_brightness < 30 or user_brightness > 225:
        return False, f"User room too {'dark' if user_brightness < 30 else 'bright'}: {user_brightness:.1f}"

    if inspiration_brightness < 30 or inspiration_brightness > 225:
        return False, f"Inspiration room too {'dark' if inspiration_brightness < 30 else 'bright'}: {inspiration_brightness:.1f}"

    return True, "Images are valid for style transfer"


def create_comparison_image(original: np.ndarray, result: np.ndarray,
                            title: str = "Style Transfer Comparison") -> np.ndarray:
    """
    Create a side-by-side comparison image.

    Args:
        original: Original user room image
        result: Style transfer result
        title: Title for the comparison

    Returns:
        Comparison image
    """
    if original.shape != result.shape:
        result = cv2.resize(result, (original.shape[1], original.shape[0]))

    comparison = np.hstack([original, result])
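    # Draw the title on a white banner so it stays legible on any background.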
    font = cv2.FONT_HERSHEY_SIMPLEX
    font_scale = 1.0
    thickness = 2

    text_size = cv2.getTextSize(title, font, font_scale, thickness)[0]
    text_x = (comparison.shape[1] - text_size[0]) // 2
    text_y = 50

    cv2.rectangle(comparison, (text_x - 10, text_y - 30),
                  (text_x + text_size[0] + 10, text_y + 10), (255, 255, 255), -1)
    cv2.putText(comparison, title, (text_x, text_y), font, font_scale, (0, 0, 0), thickness)
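    # Label each half along the bottom edge.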
    cv2.putText(comparison, "Original", (50, comparison.shape[0] - 30),
                font, 0.7, (255, 255, 255), 2)
    cv2.putText(comparison, "Result", (original.shape[1] + 50, comparison.shape[0] - 30),
                font, 0.7, (255, 255, 255), 2)

    return comparison


def create_multi_comparison_image(images: List[np.ndarray],
                                  titles: Optional[List[str]] = None,
                                  title: str = "Multi-Image Comparison") -> np.ndarray:
    """
    Create a comparison image with multiple images side by side.

    Args:
        images: List of images to compare
        titles: Optional list of per-image captions
        title: Main title for the comparison

    Returns:
        Comparison image
    """
    if not images:
        raise ValueError("At least one image is required")
    target_shape = images[0].shape
    resized_images = []
    for img in images:
        if img.shape != target_shape:
            img = cv2.resize(img, (target_shape[1], target_shape[0]))
        resized_images.append(img)

    comparison = np.hstack(resized_images)
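    # Draw the main title on a white banner, centered across the full width.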
    font = cv2.FONT_HERSHEY_SIMPLEX
    font_scale = 1.0
    thickness = 2

    text_size = cv2.getTextSize(title, font, font_scale, thickness)[0]
    text_x = (comparison.shape[1] - text_size[0]) // 2
    text_y = 50

    cv2.rectangle(comparison, (text_x - 10, text_y - 30),
                  (text_x + text_size[0] + 10, text_y + 10), (255, 255, 255), -1)
    cv2.putText(comparison, title, (text_x, text_y), font, font_scale, (0, 0, 0), thickness)
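    # Caption each panel, offset by the cumulative width of the panels before it.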
    if titles and len(titles) == len(images):
        font_scale_small = 0.7
        thickness_small = 1

        for i, (img, img_title) in enumerate(zip(resized_images, titles)):
            img_width = img.shape[1]
            start_x = sum(im.shape[1] for im in resized_images[:i])

            title_size = cv2.getTextSize(img_title, font, font_scale_small, thickness_small)[0]
            title_x = start_x + (img_width - title_size[0]) // 2
            title_y = comparison.shape[0] - 30

            cv2.rectangle(comparison, (title_x - 5, title_y - 20),
                          (title_x + title_size[0] + 5, title_y + 5), (255, 255, 255), -1)
            cv2.putText(comparison, img_title, (title_x, title_y),
                        font, font_scale_small, (0, 0, 0), thickness_small)

    return comparison


def enhance_image_quality(image: np.ndarray,
                          sharpness: float = 0.3,
                          contrast: float = 1.1,
                          saturation: float = 1.1) -> np.ndarray:
    """
    Enhance image quality with sharpening, contrast, and saturation filters.

    Args:
        image: Input image (BGR)
        sharpness: Sharpening strength (0.0 to 1.0)
        contrast: Contrast multiplier
        saturation: Saturation multiplier

    Returns:
        Enhanced image
    """
    enhanced = image.copy()

    if sharpness > 0:
        # Blend an identity kernel with a sharpening kernel; the blended
        # weights sum to 1, so overall brightness is preserved for any
        # sharpness in [0, 1] (a bare kernel scaled by `sharpness` would
        # darken the image).
        sharpen = np.array([[-1, -1, -1],
                            [-1, 9, -1],
                            [-1, -1, -1]], dtype=np.float32)
        identity = np.zeros((3, 3), dtype=np.float32)
        identity[1, 1] = 1.0
        kernel = (1.0 - sharpness) * identity + sharpness * sharpen
        enhanced = cv2.filter2D(enhanced, -1, kernel)
    if contrast != 1.0:
        enhanced = np.clip(enhanced * contrast, 0, 255).astype(np.uint8)

    if saturation != 1.0:
        hsv = cv2.cvtColor(enhanced, cv2.COLOR_BGR2HSV).astype(np.float32)
        hsv[:, :, 1] = np.clip(hsv[:, :, 1] * saturation, 0, 255)
        enhanced = cv2.cvtColor(hsv.astype(np.uint8), cv2.COLOR_HSV2BGR)

    return enhanced


def create_progress_bar(total: int, description: str = "Processing") -> Callable[[int], None]:
    """
    Create a simple text progress bar.

    Args:
        total: Total number of steps
        description: Description of the process

    Returns:
        Function that redraws the bar when called with the current step
    """
    def update_progress(current: int):
        percentage = (current / total) * 100
        bar_length = 30
        filled_length = int(bar_length * current // total)
        bar = '█' * filled_length + '-' * (bar_length - filled_length)
        print(f'\r{description}: |{bar}| {percentage:.1f}% ({current}/{total})', end='')
        if current == total:
            print()

    return update_progress


def save_metadata(metadata: dict, output_path: str) -> bool:
    """
    Save metadata to a JSON file.

    Args:
        metadata: Dictionary of metadata
        output_path: Output file path

    Returns:
        True if successful, False otherwise
    """
    try:
        output_dir = os.path.dirname(output_path)
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)

        with open(output_path, 'w') as f:
            json.dump(metadata, f, indent=2, default=str)

        return True
    except Exception as e:
        print(f"Error saving metadata to {output_path}: {e}")
        return False


def load_metadata(metadata_path: str) -> Optional[dict]:
    """
    Load metadata from a JSON file.

    Args:
        metadata_path: Path to the metadata file

    Returns:
        Loaded metadata dictionary, or None on failure
    """
    try:
        with open(metadata_path, 'r') as f:
            return json.load(f)
    except Exception as e:
        print(f"Error loading metadata from {metadata_path}: {e}")
        return None


def calculate_image_similarity(img1: np.ndarray, img2: np.ndarray) -> float:
    """
    Calculate similarity between two images using structural similarity (SSIM).

    Args:
        img1: First image
        img2: Second image

    Returns:
        Similarity score (0.0 to 1.0, higher is more similar)
    """
    try:
        from skimage.metrics import structural_similarity as ssim
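        # Match shapes, compare in grayscale, and clamp negative SSIM values to 0.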
        if img1.shape != img2.shape:
            img2 = cv2.resize(img2, (img1.shape[1], img1.shape[0]))

        gray1 = cv2.cvtColor(img1, cv2.COLOR_BGR2GRAY)
        gray2 = cv2.cvtColor(img2, cv2.COLOR_BGR2GRAY)

        similarity = ssim(gray1, gray2)
        return max(0.0, similarity)
    except ImportError:
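        # scikit-image is unavailable: fall back to a normalized-MSE score.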
        if img1.shape != img2.shape:
            img2 = cv2.resize(img2, (img1.shape[1], img1.shape[0]))

        mse = np.mean((img1.astype(np.float32) - img2.astype(np.float32)) ** 2)
        max_mse = 255 ** 2
        similarity = 1.0 - (mse / max_mse)
        return max(0.0, similarity)


def create_thumbnail(image: np.ndarray, max_size: int = 200) -> np.ndarray:
    """
    Create a thumbnail version of an image, preserving aspect ratio.

    Args:
        image: Input image
        max_size: Maximum dimension size

    Returns:
        Thumbnail image
    """
    height, width = image.shape[:2]

    if height <= max_size and width <= max_size:
        return image.copy()

    # Scale the longer side down to max_size and the other side proportionally.
    if height > width:
        new_height = max_size
        new_width = int(width * max_size / height)
    else:
        new_width = max_size
        new_height = int(height * max_size / width)

    return cv2.resize(image, (new_width, new_height))


def batch_resize_images(images: List[np.ndarray],
                        target_size: Tuple[int, int]) -> List[np.ndarray]:
    """
    Resize a list of images to the same target size.

    Args:
        images: List of input images
        target_size: Target size (width, height)

    Returns:
        List of resized images
    """
    return [cv2.resize(image, target_size) for image in images]


def create_image_grid(images: List[np.ndarray],
                      grid_size: Optional[Tuple[int, int]] = None) -> np.ndarray:
    """
    Create a grid layout of images.

    Args:
        images: List of images to arrange in a grid
        grid_size: Grid dimensions (rows, cols); if None, a near-square
            grid is calculated automatically

    Returns:
        Grid image
    """
    if not images:
        return np.array([])

    if grid_size is None:
        # Choose a near-square layout: cols = ceil(sqrt(n)), rows to fit the rest.
        n_images = len(images)
        cols = int(np.ceil(np.sqrt(n_images)))
        rows = int(np.ceil(n_images / cols))
        grid_size = (rows, cols)
    rows, cols = grid_size

    # Normalize every image to the first image's size.
    target_size = (images[0].shape[1], images[0].shape[0])
    resized_images = batch_resize_images(images, target_size)
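    # Fill the grid row by row, padding unused cells with black images.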
    grid_rows = []
    for i in range(rows):
        row_images = []
        for j in range(cols):
            idx = i * cols + j
            if idx < len(resized_images):
                row_images.append(resized_images[idx])
            else:
                empty_image = np.zeros((target_size[1], target_size[0], 3), dtype=np.uint8)
                row_images.append(empty_image)

        row = np.hstack(row_images)
        grid_rows.append(row)

    return np.vstack(grid_rows)
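

if __name__ == "__main__":
    # Minimal usage sketch of the helpers above. The file names
    # ("user_room.jpg", "inspiration_room.jpg") and output paths are
    # illustrative placeholders, not assets shipped with the project.
    user = load_image_safe("user_room.jpg", target_size=(512, 512))
    inspiration = load_image_safe("inspiration_room.jpg", target_size=(512, 512))

    is_valid, message = validate_image_pair(user, inspiration)
    print(message)

    if is_valid:
        # Stand in for the real style-transfer step with a simple enhancement,
        # just to exercise the comparison, similarity, and I/O helpers.
        result = enhance_image_quality(user, sharpness=0.3, contrast=1.1, saturation=1.1)

        comparison = create_comparison_image(user, result)
        save_image_safe(comparison, "output/comparison.jpg", quality=90)
        save_metadata({"similarity": calculate_image_similarity(user, result)},
                      "output/metadata.json")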