# Page header captured from the Hugging Face file view (not part of the module):
# "Update utils.py" by Lasya18, commit 3b3c48f (verified), 15.1 kB
"""
Utility functions for the Interior Style Transfer Pipeline
"""
import cv2
import numpy as np
from PIL import Image
import os
from typing import Tuple, List, Optional, Union
import json
from pathlib import Path
def load_image_safe(image_path: str, target_size: Optional[Tuple[int, int]] = None) -> np.ndarray:
    """
    Safely load an image with error handling

    OpenCV is tried first. If it cannot decode the file, Pillow is used as a
    fallback and the result is converted to OpenCV's BGR channel order so that
    both paths yield a 3-channel BGR array.

    Args:
        image_path: Path to the image file
        target_size: Optional target size (width, height)

    Returns:
        Loaded image as numpy array

    Raises:
        ValueError: If the file does not exist or the image cannot be loaded
    """
    if not os.path.exists(image_path):
        raise ValueError(f"Image file not found: {image_path}")

    # Try to load with OpenCV first
    image = cv2.imread(image_path)

    if image is None:
        # Fallback to PIL
        try:
            # The context manager closes the underlying file handle, which
            # Image.open() otherwise keeps open until garbage collection.
            with Image.open(image_path) as pil_image:
                # Normalising to RGB first covers grayscale, palette, RGBA and
                # CMYK inputs with a single conversion below.
                rgb_pixels = np.array(pil_image.convert("RGB"))
            image = cv2.cvtColor(rgb_pixels, cv2.COLOR_RGB2BGR)
        except Exception as e:
            # Chain the cause so the original decoder error stays visible.
            raise ValueError(f"Could not load image {image_path}: {e}") from e

    if target_size:
        image = cv2.resize(image, target_size)

    return image
def save_image_safe(image: np.ndarray, output_path: str,
                    quality: int = 95) -> bool:
    """
    Safely save an image with error handling

    Args:
        image: Image to save as numpy array
        output_path: Output file path
        quality: JPEG quality (1-100); applied when the target is a JPEG file

    Returns:
        True if successful, False otherwise
    """
    try:
        # Ensure output directory exists. A bare filename has an empty
        # dirname, and os.makedirs("") raises, so skip it in that case.
        output_dir = os.path.dirname(output_path)
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)

        # Save with OpenCV, forwarding the requested quality for JPEG targets
        # (other encoders do not use this setting).
        extension = os.path.splitext(output_path)[1].lower()
        if extension in (".jpg", ".jpeg", ".jpe"):
            success = cv2.imwrite(
                output_path, image,
                [int(cv2.IMWRITE_JPEG_QUALITY), int(quality)])
        else:
            success = cv2.imwrite(output_path, image)

        if not success:
            # Fallback to PIL, which expects RGB channel order
            if len(image.shape) == 3 and image.shape[2] == 3:
                pil_image = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
            else:
                pil_image = Image.fromarray(image)
            pil_image.save(output_path, quality=quality)
            success = True

        return success

    except Exception as e:
        # Best-effort API: report the problem and signal failure to the caller.
        print(f"Error saving image to {output_path}: {e}")
        return False
def validate_image_pair(user_room: np.ndarray, inspiration_room: np.ndarray) -> Tuple[bool, str]:
    """
    Check whether a user/inspiration image pair is usable for style transfer

    The checks run in order: identical shapes, minimum side length, aspect
    ratio of the user image, then mean grayscale brightness of each image.
    The first failing check determines the returned message.

    Args:
        user_room: User room image
        inspiration_room: Inspiration room image

    Returns:
        Tuple of (is_valid, error_message)
    """
    def mean_gray_level(bgr_image):
        # Average intensity of the grayscale version of the image.
        return np.mean(cv2.cvtColor(bgr_image, cv2.COLOR_BGR2GRAY))

    # Shapes must match exactly
    if not user_room.shape == inspiration_room.shape:
        return False, f"Image dimensions don't match: {user_room.shape} vs {inspiration_room.shape}"

    # Neither side may be shorter than the minimum
    min_size = 256
    if any(side < min_size for side in user_room.shape[:2]):
        return False, f"Images too small. Minimum size: {min_size}x{min_size}"

    # Width / height must stay within [0.5, 2.0]
    aspect_ratio = user_room.shape[1] / user_room.shape[0]
    if aspect_ratio < 0.5 or aspect_ratio > 2.0:
        return False, f"Extreme aspect ratio: {aspect_ratio:.2f}. Square images work best."

    # Reject images whose average brightness is outside [30, 225]
    user_brightness = mean_gray_level(user_room)
    inspiration_brightness = mean_gray_level(inspiration_room)

    if user_brightness < 30 or user_brightness > 225:
        return False, f"User room too {'dark' if user_brightness < 30 else 'bright'}: {user_brightness:.1f}"

    if inspiration_brightness < 30 or inspiration_brightness > 225:
        return False, f"Inspiration room too {'dark' if inspiration_brightness < 30 else 'bright'}: {inspiration_brightness:.1f}"

    return True, "Images are valid for style transfer"
def create_comparison_image(original: np.ndarray, result: np.ndarray,
                            title: str = "Style Transfer Comparison") -> np.ndarray:
    """
    Build a side-by-side "before / after" image

    The result is resized to the original's size when their shapes differ,
    the two are stacked horizontally, and a centered title plus
    "Original" / "Result" labels are drawn onto the stacked canvas.

    Args:
        original: Original user room image
        result: Style transfer result
        title: Title for the comparison

    Returns:
        Comparison image
    """
    white = (255, 255, 255)
    black = (0, 0, 0)
    font = cv2.FONT_HERSHEY_SIMPLEX
    title_scale, title_thickness = 1.0, 2
    label_scale, label_thickness = 0.7, 2

    # Bring the result to the original's size before stacking
    if original.shape != result.shape:
        result = cv2.resize(result, (original.shape[1], original.shape[0]))

    # hstack allocates a new array, so the drawing below leaves inputs untouched
    canvas = np.hstack([original, result])
    canvas_height, canvas_width = canvas.shape[0], canvas.shape[1]

    # Center the title horizontally near the top
    title_width = cv2.getTextSize(title, font, title_scale, title_thickness)[0][0]
    title_x = (canvas_width - title_width) // 2
    title_y = 50

    # White box behind the title, then the title itself in black
    box_top_left = (title_x - 10, title_y - 30)
    box_bottom_right = (title_x + title_width + 10, title_y + 10)
    cv2.rectangle(canvas, box_top_left, box_bottom_right, white, -1)
    cv2.putText(canvas, title, (title_x, title_y), font, title_scale, black, title_thickness)

    # Per-half labels along the bottom edge
    label_y = canvas_height - 30
    for label_text, label_x in (("Original", 50), ("Result", original.shape[1] + 50)):
        cv2.putText(canvas, label_text, (label_x, label_y),
                    font, label_scale, white, label_thickness)

    return canvas
def create_multi_comparison_image(images: List[np.ndarray],
                                  titles: List[str] = None,
                                  title: str = "Multi-Image Comparison") -> np.ndarray:
    """
    Lay several images out in one horizontal strip with captions

    Every image is resized to the first image's size where shapes differ.
    A main title is drawn at the top; per-image titles are drawn along the
    bottom only when `titles` is given and has one entry per image.

    Args:
        images: List of images to compare
        titles: List of titles for each image (optional)
        title: Main title for the comparison

    Returns:
        Comparison image

    Raises:
        ValueError: If `images` is empty
    """
    if not images:
        raise ValueError("At least one image is required")

    white = (255, 255, 255)
    black = (0, 0, 0)
    font = cv2.FONT_HERSHEY_SIMPLEX

    # Normalise every image to the first image's shape
    target_shape = images[0].shape
    resized_images = [
        candidate if candidate.shape == target_shape
        else cv2.resize(candidate, (target_shape[1], target_shape[0]))
        for candidate in images
    ]

    # One horizontal strip containing all images
    comparison = np.hstack(resized_images)

    # Main title: centered, black text on a white box
    main_scale, main_thickness = 1.0, 2
    main_width = cv2.getTextSize(title, font, main_scale, main_thickness)[0][0]
    main_x = (comparison.shape[1] - main_width) // 2
    main_y = 50
    cv2.rectangle(comparison, (main_x - 10, main_y - 30),
                  (main_x + main_width + 10, main_y + 10), white, -1)
    cv2.putText(comparison, title, (main_x, main_y), font, main_scale, black, main_thickness)

    # Per-image titles are optional and need one entry per image
    if not titles or len(titles) != len(images):
        return comparison

    caption_scale, caption_thickness = 0.7, 1
    caption_y = comparison.shape[0] - 30
    left_edge = 0  # x coordinate where the current image starts in the strip
    for panel, caption in zip(resized_images, titles):
        panel_width = panel.shape[1]
        caption_width = cv2.getTextSize(caption, font, caption_scale, caption_thickness)[0][0]
        caption_x = left_edge + (panel_width - caption_width) // 2

        # White box behind the caption, then the caption in black
        cv2.rectangle(comparison, (caption_x - 5, caption_y - 20),
                      (caption_x + caption_width + 5, caption_y + 5), white, -1)
        cv2.putText(comparison, caption, (caption_x, caption_y),
                    font, caption_scale, black, caption_thickness)

        left_edge += panel_width

    return comparison
def enhance_image_quality(image: np.ndarray,
                          sharpness: float = 0.3,
                          contrast: float = 1.1,
                          saturation: float = 1.1) -> np.ndarray:
    """
    Enhance image quality with various filters

    The steps are applied in order: sharpening, contrast, saturation.
    The input array is not modified.

    Args:
        image: Input image (BGR)
        sharpness: Sharpening strength (0.0 to 1.0); 0 disables sharpening and
            1.0 is the full 3x3 sharpening kernel
        contrast: Contrast multiplier (1.0 leaves pixel values unchanged)
        saturation: Saturation multiplier (1.0 leaves saturation unchanged)

    Returns:
        Enhanced image
    """
    enhanced = image.copy()

    # Sharpening
    if sharpness > 0:
        # Blend the identity kernel with the classic sharpening kernel
        # ([-1 ... 9 ... -1]). The weights always sum to 1, so overall
        # brightness is preserved; scaling the whole kernel by `sharpness`
        # would instead dim the image to `sharpness` times its brightness.
        kernel = np.full((3, 3), -float(sharpness), dtype=np.float32)
        kernel[1, 1] = 1.0 + 8.0 * float(sharpness)
        enhanced = cv2.filter2D(enhanced, -1, kernel)

    # Contrast adjustment
    if contrast != 1.0:
        enhanced = np.clip(enhanced * contrast, 0, 255).astype(np.uint8)

    # Saturation adjustment
    if saturation != 1.0:
        # Work in float so the scaled S channel can be clipped before the
        # conversion back to 8-bit.
        hsv = cv2.cvtColor(enhanced, cv2.COLOR_BGR2HSV).astype(np.float32)
        hsv[:, :, 1] = np.clip(hsv[:, :, 1] * saturation, 0, 255)
        enhanced = cv2.cvtColor(hsv.astype(np.uint8), cv2.COLOR_HSV2BGR)

    return enhanced
def create_progress_bar(total: int, description: str = "Processing") -> callable:
    """
    Create a simple progress bar function

    Args:
        total: Total number of steps; a non-positive total is shown as
            already complete instead of raising ZeroDivisionError
        description: Description of the process

    Returns:
        Function to update progress; call it with the current step number
    """
    bar_length = 30

    def update_progress(current: int):
        if total > 0:
            percentage = (current / total) * 100
            filled_length = int(bar_length * current // total)
        else:
            # Nothing to do: render a full bar rather than dividing by zero.
            percentage = 100.0
            filled_length = bar_length

        # Keep the bar exactly bar_length cells wide even when `current`
        # falls outside [0, total].
        filled_length = max(0, min(bar_length, filled_length))
        bar = '█' * filled_length + '-' * (bar_length - filled_length)

        # '\r' returns to the line start so each update overwrites the last.
        print(f'\r{description}: |{bar}| {percentage:.1f}% ({current}/{total})', end='')

        if current == total:
            # Finish the line once the final step is reported.
            print()

    return update_progress
def save_metadata(metadata: dict, output_path: str) -> bool:
    """
    Save metadata to JSON file

    Values that are not JSON-serialisable are written as their `str()` form.

    Args:
        metadata: Dictionary of metadata
        output_path: Output file path

    Returns:
        True if successful, False otherwise
    """
    try:
        # A bare filename has an empty dirname, and os.makedirs("") raises,
        # so only create the parent directory when the path names one.
        output_dir = os.path.dirname(output_path)
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)

        with open(output_path, 'w') as f:
            json.dump(metadata, f, indent=2, default=str)
        return True
    except Exception as e:
        # Best-effort API: report the problem and signal failure to the caller.
        print(f"Error saving metadata to {output_path}: {e}")
        return False
def load_metadata(metadata_path: str) -> Optional[dict]:
    """
    Read a JSON metadata file back into Python objects

    Any failure (missing file, unreadable file, invalid JSON) is reported on
    stdout and turned into a `None` result instead of an exception.

    Args:
        metadata_path: Path to metadata file

    Returns:
        Loaded metadata dictionary or None if failed
    """
    try:
        with open(metadata_path, 'r') as handle:
            parsed = json.load(handle)
    except Exception as e:
        print(f"Error loading metadata from {metadata_path}: {e}")
        return None
    else:
        return parsed
def calculate_image_similarity(img1: np.ndarray, img2: np.ndarray) -> float:
"""
Calculate similarity between two images using structural similarity
Args:
img1: First image
img2: Second image
Returns:
Similarity score (0.0 to 1.0, higher is more similar)
"""
try:
from skimage.metrics import structural_similarity as ssim
# Ensure same dimensions
if img1.shape != img2.shape:
img2 = cv2.resize(img2, (img1.shape[1], img1.shape[0]))
# Convert to grayscale for SSIM
gray1 = cv2.cvtColor(img1, cv2.COLOR_BGR2GRAY)
gray2 = cv2.cvtColor(img2, cv2.COLOR_BGR2GRAY)
# Calculate SSIM
similarity = ssim(gray1, gray2)
return max(0.0, similarity) # Ensure non-negative
except ImportError:
# Fallback to simple MSE-based similarity
if img1.shape != img2.shape:
img2 = cv2.resize(img2, (img1.shape[1], img1.shape[0]))
mse = np.mean((img1.astype(np.float32) - img2.astype(np.float32)) ** 2)
max_mse = 255 ** 2
similarity = 1.0 - (mse / max_mse)
return max(0.0, similarity)
def create_thumbnail(image: np.ndarray, max_size: int = 200) -> np.ndarray:
    """
    Create a thumbnail version of an image

    Images that already fit within `max_size` on both sides are returned as a
    copy. Larger images are scaled down so the longer side equals `max_size`
    while the aspect ratio is kept.

    Args:
        image: Input image
        max_size: Maximum dimension size

    Returns:
        Thumbnail image
    """
    height, width = image.shape[:2]

    if height <= max_size and width <= max_size:
        return image.copy()

    # Calculate new dimensions maintaining aspect ratio. The shorter side is
    # floored at 1 pixel: for very elongated images the integer truncation
    # would otherwise yield 0, which cv2.resize rejects.
    if height > width:
        new_height = max_size
        new_width = max(1, int(width * max_size / height))
    else:
        new_width = max_size
        new_height = max(1, int(height * max_size / width))

    thumbnail = cv2.resize(image, (new_width, new_height))
    return thumbnail
def batch_resize_images(images: List[np.ndarray],
                        target_size: Tuple[int, int]) -> List[np.ndarray]:
    """
    Bring every image in a list to one common size

    Args:
        images: List of input images
        target_size: Target size (width, height)

    Returns:
        List of resized images, in the same order as the input
    """
    return [cv2.resize(source, target_size) for source in images]
def create_image_grid(images: List[np.ndarray],
                      grid_size: Optional[Tuple[int, int]] = None) -> np.ndarray:
    """
    Create a grid layout of images

    All images are resized to the first image's size and placed row by row.
    Unused cells are filled with black. If `grid_size` has fewer cells than
    there are images, the surplus images are left out.

    Args:
        images: List of images to arrange in grid
        grid_size: Grid dimensions (rows, cols). If None, auto-calculate

    Returns:
        Grid image, or an empty array when `images` is empty

    Raises:
        ValueError: If the grid has fewer than one row or column
    """
    if not images:
        return np.array([])

    if grid_size is None:
        # Auto-calculate a near-square grid that fits every image
        n_images = len(images)
        cols = int(np.ceil(np.sqrt(n_images)))
        rows = int(np.ceil(n_images / cols))
    else:
        rows, cols = grid_size

    if rows < 1 or cols < 1:
        raise ValueError(f"Grid must have at least one row and one column, got {(rows, cols)}")

    # Ensure all images have the same size (width, height of the first image)
    target_size = (images[0].shape[1], images[0].shape[0])
    resized_images = batch_resize_images(images, target_size)

    # Black filler with the same shape and dtype as the real cells, so that
    # grayscale, 4-channel or non-uint8 images stack without errors or
    # silent dtype changes.
    filler = np.zeros_like(resized_images[0])

    # Create grid
    grid_rows = []
    for row_index in range(rows):
        first = row_index * cols
        cells = list(resized_images[first:first + cols])
        # Pad a partially filled (or empty) row up to the full column count
        cells.extend([filler] * (cols - len(cells)))
        grid_rows.append(np.hstack(cells))

    return np.vstack(grid_rows)