# SceneWeaver / image_blender.py
import cv2
import numpy as np
import traceback
from PIL import Image
import logging
from typing import Dict, Optional
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
class ImageBlender:
"""
Advanced image blending with aggressive spill suppression and color replacement.
Supports two primary modes:
- Background generation: Foreground preservation with edge refinement
- Inpainting: Seamless blending with adaptive color correction
Attributes:
enable_multi_scale: Whether multi-scale edge refinement is enabled
"""
EDGE_EROSION_PIXELS = 1 # Pixels to erode from mask edge
ALPHA_BINARIZE_THRESHOLD = 0.5 # Alpha threshold for binarization
DARK_LUMINANCE_THRESHOLD = 60 # Luminance threshold for dark foreground
FOREGROUND_PROTECTION_THRESHOLD = 140 # Mask value for strong protection
BACKGROUND_COLOR_TOLERANCE = 30 # DeltaE tolerance for background detection
# Inpainting-specific parameters
INPAINT_FEATHER_SCALE = 1.2 # Scale factor for inpainting feathering
INPAINT_COLOR_BLEND_RADIUS = 10 # Radius for color adaptation zone
def __init__(self, enable_multi_scale: bool = True):
"""
Initialize ImageBlender.
Parameters
----------
enable_multi_scale : bool
Whether to enable multi-scale edge refinement (default True)
"""
self.enable_multi_scale = enable_multi_scale
self._debug_info = {}
self._adaptive_strength_map = None
def _erode_mask_edges(
self,
mask_array: np.ndarray,
erosion_pixels: int = 2
) -> np.ndarray:
"""
Erode mask edges to remove contaminated boundary pixels.
This removes the outermost pixels of the foreground mask where
color contamination from the original background is most likely.
Args:
mask_array: Input mask as numpy array (uint8, 0-255)
erosion_pixels: Number of pixels to erode (default 2)
Returns:
Eroded mask array (uint8)
"""
if erosion_pixels <= 0:
return mask_array
# Use elliptical kernel for natural-looking erosion
kernel_size = max(2, erosion_pixels)
kernel = cv2.getStructuringElement(
cv2.MORPH_ELLIPSE,
(kernel_size, kernel_size)
)
# Apply erosion
eroded = cv2.erode(mask_array, kernel, iterations=1)
# Slight blur to smooth the eroded edges
eroded = cv2.GaussianBlur(eroded, (3, 3), 0)
logger.debug(f"Mask erosion applied: {erosion_pixels}px, kernel size: {kernel_size}")
return eroded
def _binarize_edge_alpha(
self,
alpha: np.ndarray,
mask_array: np.ndarray,
orig_array: np.ndarray,
threshold: float = 0.45
) -> np.ndarray:
"""
Binarize semi-transparent edge pixels to eliminate color bleeding.
Semi-transparent pixels at edges cause visible contamination because
they blend the original (potentially dark) foreground with the new
background. This method forces edge pixels to be either fully opaque
or fully transparent.
Args:
alpha: Current alpha channel (float32, 0.0-1.0)
mask_array: Original mask array (uint8, 0-255)
orig_array: Original foreground image array (uint8, RGB)
threshold: Alpha threshold for binarization decision (default 0.45)
Returns:
Modified alpha array with binarized edges (float32)
"""
# Identify semi-transparent edge zone (not fully opaque, not fully transparent)
edge_zone = (alpha > 0.05) & (alpha < 0.95)
if not np.any(edge_zone):
return alpha
# Calculate local foreground luminance for adaptive thresholding
gray = np.mean(orig_array, axis=2)
        # For dark foreground pixels, use a slightly lower threshold so that
        # more of them stay fully opaque, preserving the dark subject
        is_dark = gray < self.DARK_LUMINANCE_THRESHOLD
        # Create adaptive threshold map
        adaptive_threshold = np.full_like(alpha, threshold)
        adaptive_threshold[is_dark] = threshold - 0.1  # Keep more dark pixels opaque
# Binarize: above threshold -> opaque, below -> transparent
alpha_binarized = alpha.copy()
# Pixels above threshold become fully opaque
make_opaque = edge_zone & (alpha > adaptive_threshold)
alpha_binarized[make_opaque] = 1.0
# Pixels below threshold become fully transparent
make_transparent = edge_zone & (alpha <= adaptive_threshold)
alpha_binarized[make_transparent] = 0.0
# Log statistics
num_opaque = np.sum(make_opaque)
num_transparent = np.sum(make_transparent)
logger.info(f"Edge binarization: {num_opaque} pixels -> opaque, {num_transparent} pixels -> transparent")
return alpha_binarized
def _apply_edge_cleanup(
self,
result_array: np.ndarray,
bg_array: np.ndarray,
alpha: np.ndarray,
cleanup_width: int = 2
) -> np.ndarray:
"""
Final cleanup pass to remove any remaining edge artifacts.
Detects remaining semi-transparent edges and replaces them with
either pure foreground or pure background colors.
Args:
result_array: Current blended result (uint8, RGB)
bg_array: Background image array (uint8, RGB)
alpha: Final alpha channel (float32, 0.0-1.0)
cleanup_width: Width of edge zone to clean (default 2)
Returns:
Cleaned result array (uint8)
"""
        # Find edge pixels that might still have artifacts:
        # pixels with alpha close to, but not exactly, 0 or 1
        residual_edge = (alpha > 0.01) & (alpha < 0.99)
if not np.any(residual_edge):
return result_array
result_cleaned = result_array.copy()
        # For residual edge pixels, snap to the nearest pure state
        snap_to_bg = residual_edge & (alpha < 0.5)
        # Replace with background
        result_cleaned[snap_to_bg] = bg_array[snap_to_bg]
        # Pixels with alpha >= 0.5 keep their blended foreground color;
        # the blend itself already handles them, so no further action is needed
num_cleaned = np.sum(residual_edge)
if num_cleaned > 0:
logger.debug(f"Edge cleanup: {num_cleaned} residual pixels cleaned")
return result_cleaned
def _remove_background_color_contamination(
self,
image_array: np.ndarray,
mask_array: np.ndarray,
orig_bg_color_lab: np.ndarray,
tolerance: float = 30.0
) -> np.ndarray:
"""
Remove original background color contamination from foreground pixels.
Scans the foreground area for pixels that match the original background
color and replaces them with nearby clean foreground colors.
Args:
image_array: Foreground image array (uint8, RGB)
mask_array: Mask array (uint8, 0-255)
orig_bg_color_lab: Original background color in Lab space
tolerance: DeltaE tolerance for detecting contaminated pixels
Returns:
Cleaned image array (uint8)
"""
# Convert to Lab for color comparison
image_lab = cv2.cvtColor(image_array, cv2.COLOR_RGB2LAB).astype(np.float32)
# Only process foreground pixels (mask > 50)
foreground_mask = mask_array > 50
if not np.any(foreground_mask):
return image_array
# Calculate deltaE from original background color for all pixels
delta_l = image_lab[:, :, 0] - orig_bg_color_lab[0]
delta_a = image_lab[:, :, 1] - orig_bg_color_lab[1]
delta_b = image_lab[:, :, 2] - orig_bg_color_lab[2]
delta_e = np.sqrt(delta_l**2 + delta_a**2 + delta_b**2)
# Find contaminated pixels: in foreground but color similar to original background
contaminated = foreground_mask & (delta_e < tolerance)
if not np.any(contaminated):
logger.debug("No background color contamination detected in foreground")
return image_array
num_contaminated = np.sum(contaminated)
logger.info(f"Found {num_contaminated} pixels with background color contamination")
# Create output array
result = image_array.copy()
# For contaminated pixels, use inpainting to replace with surrounding colors
inpaint_mask = contaminated.astype(np.uint8) * 255
try:
# Use inpainting to fill contaminated areas with surrounding foreground colors
result = cv2.inpaint(result, inpaint_mask, inpaintRadius=3, flags=cv2.INPAINT_TELEA)
logger.info(f"Inpainted {num_contaminated} contaminated pixels")
except Exception as e:
logger.warning(f"Inpainting failed: {e}, using median filter fallback")
# Fallback: apply median filter to contaminated areas
median_filtered = cv2.medianBlur(image_array, 5)
result[contaminated] = median_filtered[contaminated]
return result
def _protect_foreground_core(
self,
result_array: np.ndarray,
orig_array: np.ndarray,
mask_array: np.ndarray,
protection_threshold: int = 140
) -> np.ndarray:
"""
Strongly protect core foreground pixels from any background influence.
For pixels with high mask confidence, directly use the original foreground
color without any blending, ensuring faces and bodies are not affected.
Args:
result_array: Current blended result (uint8, RGB)
orig_array: Original foreground image (uint8, RGB)
mask_array: Mask array (uint8, 0-255)
protection_threshold: Mask value above which pixels are fully protected
Returns:
Protected result array (uint8)
"""
# Identify strongly protected foreground pixels
strong_foreground = mask_array >= protection_threshold
if not np.any(strong_foreground):
return result_array
# For these pixels, use original foreground color directly
result_protected = result_array.copy()
result_protected[strong_foreground] = orig_array[strong_foreground]
num_protected = np.sum(strong_foreground)
logger.info(f"Protected {num_protected} core foreground pixels from background influence")
return result_protected
def multi_scale_edge_refinement(
self,
original_image: Image.Image,
background_image: Image.Image,
mask: Image.Image
) -> Image.Image:
"""
Multi-scale edge refinement for better edge quality.
Uses image pyramid to handle edges at different scales.
Args:
original_image: Foreground PIL Image
background_image: Background PIL Image
mask: Current mask PIL Image
Returns:
Refined mask PIL Image
"""
logger.info("🔍 Starting multi-scale edge refinement...")
try:
# Convert to numpy arrays
orig_array = np.array(original_image.convert('RGB'))
mask_array = np.array(mask).astype(np.float32)
height, width = mask_array.shape
# Define scales for pyramid
scales = [1.0, 0.5, 0.25] # Original, half, quarter
scale_masks = []
scale_complexities = []
# Convert to grayscale for edge detection
gray = cv2.cvtColor(orig_array, cv2.COLOR_RGB2GRAY)
for scale in scales:
if scale == 1.0:
scaled_gray = gray
scaled_mask = mask_array
else:
new_h = int(height * scale)
new_w = int(width * scale)
scaled_gray = cv2.resize(gray, (new_w, new_h), interpolation=cv2.INTER_LANCZOS4)
scaled_mask = cv2.resize(mask_array, (new_w, new_h), interpolation=cv2.INTER_LANCZOS4)
# Compute local complexity using gradient standard deviation
sobel_x = cv2.Sobel(scaled_gray, cv2.CV_64F, 1, 0, ksize=3)
sobel_y = cv2.Sobel(scaled_gray, cv2.CV_64F, 0, 1, ksize=3)
gradient_mag = np.sqrt(sobel_x**2 + sobel_y**2)
# Calculate local complexity in 5x5 regions
kernel_size = 5
complexity = cv2.blur(gradient_mag, (kernel_size, kernel_size))
# Resize back to original size
if scale != 1.0:
scaled_mask = cv2.resize(scaled_mask, (width, height), interpolation=cv2.INTER_LANCZOS4)
complexity = cv2.resize(complexity, (width, height), interpolation=cv2.INTER_LANCZOS4)
scale_masks.append(scaled_mask)
scale_complexities.append(complexity)
# Compute weights based on complexity
# High complexity -> use high resolution mask
# Low complexity -> use low resolution mask (smoother)
weights = np.zeros((len(scales), height, width), dtype=np.float32)
# Normalize complexities
max_complexity = max(c.max() for c in scale_complexities) + 1e-6
normalized_complexities = [c / max_complexity for c in scale_complexities]
# Weight assignment: higher complexity at each scale means that scale is more reliable
for i, complexity in enumerate(normalized_complexities):
if i == 0: # High resolution - prefer for high complexity regions
weights[i] = complexity
elif i == 1: # Medium resolution - moderate complexity
weights[i] = 0.5 * (1 - complexity) + 0.5 * complexity * 0.5
else: # Low resolution - prefer for low complexity regions
weights[i] = 1 - complexity
# Normalize weights so they sum to 1 at each pixel
weight_sum = weights.sum(axis=0, keepdims=True) + 1e-6
weights = weights / weight_sum
# Weighted blend of masks from different scales
refined_mask = np.zeros((height, width), dtype=np.float32)
for i, mask_i in enumerate(scale_masks):
refined_mask += weights[i] * mask_i
# Clip and convert to uint8
refined_mask = np.clip(refined_mask, 0, 255).astype(np.uint8)
logger.info("✅ Multi-scale edge refinement completed")
return Image.fromarray(refined_mask, mode='L')
except Exception as e:
logger.error(f"❌ Multi-scale refinement failed: {e}, using original mask")
return mask
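    # The refinement step can also be used on its own to clean up a raw
    # segmentation mask before blending. A minimal sketch; fg_img, bg_img,
    # and raw_mask are hypothetical PIL Images supplied by the caller:
    #
    #     blender = ImageBlender()
    #     refined = blender.multi_scale_edge_refinement(fg_img, bg_img, raw_mask)
    #     refined.save("refined_mask.png")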
def simple_blend_images(
self,
original_image: Image.Image,
background_image: Image.Image,
combination_mask: Image.Image,
use_multi_scale: Optional[bool] = None
) -> Image.Image:
"""
        Aggressive spill suppression and color replacement: eliminate residual background-color fringes (e.g. yellow edges) while keeping edges sharp.
Args:
original_image: Foreground PIL Image
background_image: Background PIL Image
combination_mask: Mask PIL Image (L mode)
use_multi_scale: Override for multi-scale refinement (None = use class default)
Returns:
Blended PIL Image
"""
logger.info("🎨 Starting advanced image blending process...")
# Apply multi-scale edge refinement if enabled
should_use_multi_scale = use_multi_scale if use_multi_scale is not None else self.enable_multi_scale
if should_use_multi_scale:
combination_mask = self.multi_scale_edge_refinement(
original_image, background_image, combination_mask
)
# Convert to numpy arrays
        orig_array = np.array(original_image.convert('RGB'), dtype=np.uint8)
        bg_array = np.array(background_image.convert('RGB'), dtype=np.uint8)
        mask_array = np.array(combination_mask.convert('L'), dtype=np.uint8)
logger.info(f"📊 Image dimensions - Original: {orig_array.shape}, Background: {bg_array.shape}, Mask: {mask_array.shape}")
logger.info(f"📊 Mask statistics (before erosion) - Mean: {mask_array.mean():.1f}, Min: {mask_array.min()}, Max: {mask_array.max()}")
# === NEW: Apply mask erosion to remove contaminated edge pixels ===
mask_array = self._erode_mask_edges(mask_array, self.EDGE_EROSION_PIXELS)
logger.info(f"📊 Mask statistics (after erosion) - Mean: {mask_array.mean():.1f}, Min: {mask_array.min()}, Max: {mask_array.max()}")
# Enhanced parameters for better spill suppression
RING_WIDTH_PX = 4 # Increased ring width for better coverage
SPILL_STRENGTH = 0.85 # Stronger spill suppression
L_MATCH_STRENGTH = 0.65 # Stronger luminance matching
DELTAE_THRESHOLD = 18 # More aggressive contamination detection
HARD_EDGE_PROTECT = True # Black edge protection
INPAINT_FALLBACK = True # inpaint fallback repair
MULTI_PASS_CORRECTION = True # Enable multi-pass correction
        # === Estimate original background color and foreground representative color ===
height, width = orig_array.shape[:2]
# Take 15px from each side to estimate original background color
edge_width = 15
border_pixels = []
# Collect border pixels (excluding foreground areas)
border_mask = np.zeros((height, width), dtype=bool)
border_mask[:edge_width, :] = True # Top edge
border_mask[-edge_width:, :] = True # Bottom edge
border_mask[:, :edge_width] = True # Left edge
border_mask[:, -edge_width:] = True # Right edge
# Exclude foreground areas
fg_binary = mask_array > 50
border_mask = border_mask & (~fg_binary)
if np.any(border_mask):
border_pixels = orig_array[border_mask].reshape(-1, 3)
# Simplified background color estimation (no sklearn dependency)
try:
if len(border_pixels) > 100:
# Use histogram to find mode colors
# Quantize RGB to coarser grid to find main colors
quantized = (border_pixels // 32) * 32 # 8-level quantization
# Find most frequent color
unique_colors, counts = np.unique(quantized.reshape(-1, quantized.shape[-1]),
axis=0, return_counts=True)
most_common_idx = np.argmax(counts)
orig_bg_color_rgb = unique_colors[most_common_idx].astype(np.uint8)
else:
orig_bg_color_rgb = np.median(border_pixels, axis=0).astype(np.uint8)
            except Exception:
# Fallback: use four corners average
corners = np.array([orig_array[0,0], orig_array[0,-1],
orig_array[-1,0], orig_array[-1,-1]])
orig_bg_color_rgb = np.mean(corners, axis=0).astype(np.uint8)
else:
orig_bg_color_rgb = np.array([200, 180, 120], dtype=np.uint8) # Default yellow
# Convert to Lab space
orig_bg_color_lab = cv2.cvtColor(orig_bg_color_rgb.reshape(1,1,3), cv2.COLOR_RGB2LAB)[0,0].astype(np.float32)
logger.info(f"🎨 Detected original background color: RGB{tuple(orig_bg_color_rgb)}")
# Remove original background color contamination from foreground
orig_array = self._remove_background_color_contamination(
orig_array,
mask_array,
orig_bg_color_lab,
tolerance=self.BACKGROUND_COLOR_TOLERANCE
)
# Redefine trimap, optimized for cartoon characters
try:
kernel_3x3 = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3, 3))
# FG_CORE: Reduce erosion iterations from 2 to 1 to avoid losing thin limbs
mask_eroded_once = cv2.erode(mask_array, kernel_3x3, iterations=1)
fg_core = mask_eroded_once > 127 # Adjustable parameter: erosion iterations
            # RING: use the morphological gradient to define a thin edge band
            mask_dilated = cv2.dilate(mask_array, kernel_3x3, iterations=1)
            # Reuse the erosion computed above for FG_CORE;
            # cv2.subtract saturates, avoiding uint8 overflow
            morphological_gradient = cv2.subtract(mask_dilated, mask_eroded_once)
ring_zone = morphological_gradient > 0 # Areas with morphological gradient > 0 are edge bands
# BG: background area
bg_zone = mask_array < 30
logger.info(f"🔍 Trimap regions - FG_CORE: {fg_core.sum()}, RING: {ring_zone.sum()}, BG: {bg_zone.sum()}")
except Exception as e:
logger.error(f"❌ Trimap definition failed: {e}")
logger.error(f"📍 Traceback: {traceback.format_exc()}")
print(f"❌ TRIMAP ERROR: {e}")
print(f"Traceback: {traceback.format_exc()}")
# Fallback to simple definition
fg_core = mask_array > 200
ring_zone = (mask_array > 50) & (mask_array <= 200)
bg_zone = mask_array <= 50
# Foreground representative color: estimated from FG_CORE
if np.any(fg_core):
fg_pixels = orig_array[fg_core].reshape(-1, 3)
fg_rep_color_rgb = np.median(fg_pixels, axis=0).astype(np.uint8)
else:
fg_rep_color_rgb = np.array([80, 60, 40], dtype=np.uint8) # Default dark
fg_rep_color_lab = cv2.cvtColor(fg_rep_color_rgb.reshape(1,1,3), cv2.COLOR_RGB2LAB)[0,0].astype(np.float32)
# Edge band spill suppression and repair
if np.any(ring_zone):
# Convert to Lab space
orig_lab = cv2.cvtColor(orig_array, cv2.COLOR_RGB2LAB).astype(np.float32)
orig_array_working = orig_array.copy().astype(np.float32)
            # Detect contaminated pixels via ΔE
ring_pixels_lab = orig_lab[ring_zone]
# Calculate ΔE with original background color (simplified version)
delta_l = ring_pixels_lab[:, 0] - orig_bg_color_lab[0]
delta_a = ring_pixels_lab[:, 1] - orig_bg_color_lab[1]
delta_b = ring_pixels_lab[:, 2] - orig_bg_color_lab[2]
delta_e = np.sqrt(delta_l**2 + delta_a**2 + delta_b**2)
# Contaminated pixel mask
contaminated_mask = delta_e < DELTAE_THRESHOLD
if np.any(contaminated_mask):
# Calculate adaptive strength based on delta_e for each pixel
# Pixels closer to background color get stronger correction
contaminated_delta_e = delta_e[contaminated_mask]
# Adaptive strength formula: inverse relationship with delta_e
# Pixels very close to bg color (low delta_e) -> strong correction
# Pixels further from bg color (high delta_e) -> lighter correction
adaptive_strength = SPILL_STRENGTH * np.maximum(
0.0,
1.0 - (contaminated_delta_e / DELTAE_THRESHOLD)
)
# Clamp adaptive strength to reasonable range (30% - 100% of base strength)
min_strength = SPILL_STRENGTH * 0.3
adaptive_strength = np.clip(adaptive_strength, min_strength, SPILL_STRENGTH)
                # Store a full-resolution strength map for debug visualization
                strength_map_2d = np.zeros((height, width), dtype=np.float32)
                ring_rows, ring_cols = np.where(ring_zone)
                strength_map_2d[ring_rows[contaminated_mask], ring_cols[contaminated_mask]] = adaptive_strength
                self._adaptive_strength_map = strength_map_2d
logger.info(f"📊 Adaptive strength stats - Mean: {adaptive_strength.mean():.3f}, Min: {adaptive_strength.min():.3f}, Max: {adaptive_strength.max():.3f}")
                # Chroma vector deprojection
                # OpenCV's 8-bit Lab stores a/b with neutral at 128, so center before projecting
                bg_chroma = np.array([orig_bg_color_lab[1] - 128.0, orig_bg_color_lab[2] - 128.0])
                bg_chroma_norm = bg_chroma / (np.linalg.norm(bg_chroma) + 1e-6)
                # Color correction for contaminated pixels
                contaminated_pixels = ring_pixels_lab[contaminated_mask]
                # Remove the background chroma component with per-pixel adaptive strength
                pixel_chroma = contaminated_pixels[:, 1:3] - 128.0  # centered a, b channels
                projection = np.dot(pixel_chroma, bg_chroma_norm)[:, np.newaxis] * bg_chroma_norm
                # Apply adaptive strength per pixel
                adaptive_strength_2d = adaptive_strength[:, np.newaxis]
                corrected_chroma = pixel_chroma - projection * adaptive_strength_2d
                # Converge toward the foreground representative color with adaptive strength
                convergence_factor = adaptive_strength_2d * 0.6
                corrected_chroma = (corrected_chroma * (1 - convergence_factor) +
                                    (fg_rep_color_lab[1:3] - 128.0) * convergence_factor)
                corrected_chroma += 128.0  # back to OpenCV's offset a/b encoding
# Adaptive luminance matching
adaptive_l_strength = adaptive_strength * (L_MATCH_STRENGTH / SPILL_STRENGTH)
corrected_l = (contaminated_pixels[:, 0] * (1 - adaptive_l_strength) +
fg_rep_color_lab[0] * adaptive_l_strength)
# Update Lab values
ring_pixels_lab[contaminated_mask, 0] = corrected_l
ring_pixels_lab[contaminated_mask, 1:3] = corrected_chroma
# Write back to original image
orig_lab[ring_zone] = ring_pixels_lab
# Dark edge protection
if HARD_EDGE_PROTECT:
gray = np.mean(orig_array, axis=2)
# Detect dark and high gradient areas
sobel_x = cv2.Sobel(gray, cv2.CV_64F, 1, 0, ksize=3)
sobel_y = cv2.Sobel(gray, cv2.CV_64F, 0, 1, ksize=3)
gradient_mag = np.sqrt(sobel_x**2 + sobel_y**2)
dark_edge_zone = ring_zone & (gray < 60) & (gradient_mag > 20)
# Protect these areas from excessive modification, copy directly from original
if np.any(dark_edge_zone):
orig_lab[dark_edge_zone] = cv2.cvtColor(orig_array, cv2.COLOR_RGB2LAB)[dark_edge_zone]
# Multi-pass correction for stubborn spill
if MULTI_PASS_CORRECTION:
# Second pass for remaining contamination
ring_pixels_lab_pass2 = orig_lab[ring_zone]
delta_l_pass2 = ring_pixels_lab_pass2[:, 0] - orig_bg_color_lab[0]
delta_a_pass2 = ring_pixels_lab_pass2[:, 1] - orig_bg_color_lab[1]
delta_b_pass2 = ring_pixels_lab_pass2[:, 2] - orig_bg_color_lab[2]
delta_e_pass2 = np.sqrt(delta_l_pass2**2 + delta_a_pass2**2 + delta_b_pass2**2)
still_contaminated = delta_e_pass2 < (DELTAE_THRESHOLD * 0.8)
if np.any(still_contaminated):
# Apply stronger correction to remaining contaminated pixels
remaining_pixels = ring_pixels_lab_pass2[still_contaminated]
# More aggressive chroma neutralization
remaining_chroma = remaining_pixels[:, 1:3]
neutralized_chroma = remaining_chroma * 0.3 + fg_rep_color_lab[1:3] * 0.7
# Stronger luminance matching
neutralized_l = remaining_pixels[:, 0] * 0.4 + fg_rep_color_lab[0] * 0.6
ring_pixels_lab_pass2[still_contaminated, 0] = neutralized_l
ring_pixels_lab_pass2[still_contaminated, 1:3] = neutralized_chroma
orig_lab[ring_zone] = ring_pixels_lab_pass2
# Convert back to RGB
orig_lab_clipped = np.clip(orig_lab, 0, 255).astype(np.uint8)
orig_array_corrected = cv2.cvtColor(orig_lab_clipped, cv2.COLOR_LAB2RGB)
        # Inpaint fallback repair
        if INPAINT_FALLBACK:
            # Inpaint any outermost ring pixels that are still contaminated
# Check if there's still contamination after repair
final_lab = cv2.cvtColor(orig_array_corrected, cv2.COLOR_RGB2LAB).astype(np.float32)
final_ring_lab = final_lab[ring_zone]
final_delta_l = final_ring_lab[:, 0] - orig_bg_color_lab[0]
final_delta_a = final_ring_lab[:, 1] - orig_bg_color_lab[1]
final_delta_b = final_ring_lab[:, 2] - orig_bg_color_lab[2]
final_delta_e = np.sqrt(final_delta_l**2 + final_delta_a**2 + final_delta_b**2)
still_contaminated = final_delta_e < (DELTAE_THRESHOLD * 0.5)
if np.any(still_contaminated):
# Create inpaint mask
inpaint_mask = np.zeros((height, width), dtype=np.uint8)
ring_coords = np.where(ring_zone)
inpaint_coords = (ring_coords[0][still_contaminated], ring_coords[1][still_contaminated])
inpaint_mask[inpaint_coords] = 255
# Execute inpaint
try:
orig_array_corrected = cv2.inpaint(orig_array_corrected, inpaint_mask, 3, cv2.INPAINT_TELEA)
                except Exception:
# Fallback: directly cover with foreground representative color
orig_array_corrected[inpaint_coords] = fg_rep_color_rgb
orig_array = orig_array_corrected
# === Linear space blending (keep original logic) ===
def srgb_to_linear(img):
img_f = img.astype(np.float32) / 255.0
return np.where(img_f <= 0.04045, img_f / 12.92, np.power((img_f + 0.055) / 1.055, 2.4))
def linear_to_srgb(img):
img_clipped = np.clip(img, 0, 1)
return np.where(img_clipped <= 0.0031308,
12.92 * img_clipped,
1.055 * np.power(img_clipped, 1/2.4) - 0.055)
orig_linear = srgb_to_linear(orig_array)
bg_linear = srgb_to_linear(bg_array)
# Cartoon-optimized Alpha calculation
alpha = mask_array.astype(np.float32) / 255.0
# Core foreground region - fully opaque
alpha[fg_core] = 1.0
# Background region - fully transparent
alpha[bg_zone] = 0.0
        # [Key fix] Force pixels with mask >= 160 to alpha = 1.0 so that solid fill areas are not capped at 0.9 by the ring clamp below
high_confidence_pixels = mask_array >= 160
alpha[high_confidence_pixels] = 1.0
logger.info(f"💯 High confidence pixels set to full opacity: {high_confidence_pixels.sum()}")
        # Allow dehalo in the ring zone without touching pixels already marked high-confidence
ring_without_high_conf = ring_zone & (~high_confidence_pixels)
alpha[ring_without_high_conf] = np.clip(alpha[ring_without_high_conf], 0.2, 0.9)
# Retain existing black outline/strong edge protection
orig_gray = np.mean(orig_array, axis=2)
# Detect strong edge areas
sobel_x = cv2.Sobel(orig_gray, cv2.CV_64F, 1, 0, ksize=3)
sobel_y = cv2.Sobel(orig_gray, cv2.CV_64F, 0, 1, ksize=3)
gradient_mag = np.sqrt(sobel_x**2 + sobel_y**2)
# Black outline/strong edge protection: nearly fully opaque
black_edge_threshold = 60 # black edge threshold
gradient_threshold = 25 # gradient threshold
strong_edges = (orig_gray < black_edge_threshold) & (gradient_mag > gradient_threshold) & (mask_array > 10)
alpha[strong_edges] = np.maximum(alpha[strong_edges], 0.995) # black edge alpha
logger.info(f"🛡️ Protection applied - High conf: {high_confidence_pixels.sum()}, Strong edges: {strong_edges.sum()}")
# Apply edge alpha binarization to eliminate semi-transparent artifacts
alpha = self._binarize_edge_alpha(
alpha,
mask_array,
orig_array,
threshold=self.ALPHA_BINARIZE_THRESHOLD
)
# Final blending
alpha_3d = alpha[:, :, np.newaxis]
result_linear = orig_linear * alpha_3d + bg_linear * (1 - alpha_3d)
result_srgb = linear_to_srgb(result_linear)
result_array = (result_srgb * 255).astype(np.uint8)
# Final edge cleanup pass
result_array = self._apply_edge_cleanup(result_array, bg_array, alpha)
# Protect core foreground from any background influence
# This ensures faces and bodies retain original colors
result_array = self._protect_foreground_core(
result_array,
            np.array(original_image.convert('RGB'), dtype=np.uint8),  # Use the original, unprocessed image
mask_array,
protection_threshold=self.FOREGROUND_PROTECTION_THRESHOLD
)
# Store debug information (for debug output)
self._debug_info = {
'orig_bg_color_rgb': orig_bg_color_rgb,
'fg_rep_color_rgb': fg_rep_color_rgb,
'orig_bg_color_lab': orig_bg_color_lab,
'fg_rep_color_lab': fg_rep_color_lab,
'ring_zone': ring_zone,
'fg_core': fg_core,
'alpha_final': alpha
}
return Image.fromarray(result_array)
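    # Typical background-replacement call, shown as a hedged sketch; the three
    # PIL Images (RGB foreground, RGB generated background, L-mode mask) and
    # the output path are hypothetical:
    #
    #     blender = ImageBlender(enable_multi_scale=True)
    #     composite = blender.simple_blend_images(fg_img, new_bg_img, mask_img)
    #     composite.save("composite.png")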
def create_debug_images(
self,
original_image: Image.Image,
generated_background: Image.Image,
combination_mask: Image.Image,
combined_image: Image.Image
) -> Dict[str, Image.Image]:
"""
Generate debug images: (a) Final mask grayscale (b) Alpha heatmap (c) Ring visualization overlay
"""
debug_images = {}
# Final Mask grayscale
debug_images["mask_gray"] = combination_mask.convert('L')
# Alpha Heatmap
mask_array = np.array(combination_mask.convert('L'))
heatmap_colored = cv2.applyColorMap(mask_array, cv2.COLORMAP_JET)
heatmap_rgb = cv2.cvtColor(heatmap_colored, cv2.COLOR_BGR2RGB)
debug_images["alpha_heatmap"] = Image.fromarray(heatmap_rgb)
        # Ring visualization overlay - show ring areas on the original image
        if hasattr(self, '_debug_info') and 'ring_zone' in self._debug_info:
            ring_zone = self._debug_info['ring_zone']
            orig_array = np.array(original_image.convert('RGB'))
            # Do the overlay math in float to avoid uint8 truncation surprises
            ring_overlay = orig_array.astype(np.float32)
            # Mark ring areas with a semi-transparent red overlay
            ring_overlay[ring_zone] = ring_overlay[ring_zone] * 0.7 + np.array([255.0, 0.0, 0.0]) * 0.3
            debug_images["ring_visualization"] = Image.fromarray(np.clip(ring_overlay, 0, 255).astype(np.uint8))
else:
# If no ring information, use original image
debug_images["ring_visualization"] = original_image
# Adaptive strength heatmap - visualize per-pixel correction strength
if hasattr(self, '_adaptive_strength_map') and self._adaptive_strength_map is not None:
# Normalize adaptive strength to 0-255 for visualization
strength_map = self._adaptive_strength_map
if strength_map.max() > 0:
normalized_strength = (strength_map / strength_map.max() * 255).astype(np.uint8)
else:
normalized_strength = np.zeros_like(strength_map, dtype=np.uint8)
# Apply colormap
strength_heatmap = cv2.applyColorMap(normalized_strength, cv2.COLORMAP_VIRIDIS)
strength_heatmap_rgb = cv2.cvtColor(strength_heatmap, cv2.COLOR_BGR2RGB)
debug_images["adaptive_strength_heatmap"] = Image.fromarray(strength_heatmap_rgb)
return debug_images
# INPAINTING-SPECIFIC BLENDING METHODS
def blend_inpainting(
self,
original: Image.Image,
generated: Image.Image,
mask: Image.Image,
feather_radius: int = 8,
apply_color_correction: bool = True
) -> Image.Image:
"""
Blend inpainted region with original image.
Specialized blending for inpainting that focuses on seamless integration
rather than foreground protection. Performs blending in linear color space
with optional adaptive color correction at boundaries.
Parameters
----------
original : PIL.Image
Original image before inpainting
generated : PIL.Image
Generated/inpainted result from the model
mask : PIL.Image
Inpainting mask (white = inpainted area)
feather_radius : int
Feathering radius for smooth transitions
apply_color_correction : bool
Whether to apply adaptive color correction at boundaries
Returns
-------
PIL.Image
Blended result
"""
logger.info(f"Inpainting blend: feather={feather_radius}, color_correction={apply_color_correction}")
# Ensure same size
if generated.size != original.size:
generated = generated.resize(original.size, Image.LANCZOS)
if mask.size != original.size:
mask = mask.resize(original.size, Image.LANCZOS)
# Convert to arrays
orig_array = np.array(original.convert('RGB')).astype(np.float32)
gen_array = np.array(generated.convert('RGB')).astype(np.float32)
mask_array = np.array(mask.convert('L')).astype(np.float32) / 255.0
# Apply feathering to mask
if feather_radius > 0:
scaled_radius = int(feather_radius * self.INPAINT_FEATHER_SCALE)
kernel_size = scaled_radius * 2 + 1
mask_array = cv2.GaussianBlur(
mask_array,
(kernel_size, kernel_size),
scaled_radius / 2
)
# Apply adaptive color correction if enabled
if apply_color_correction:
gen_array = self._apply_inpaint_color_correction(
orig_array, gen_array, mask_array
)
# sRGB to linear conversion for accurate blending
def srgb_to_linear(img):
img_norm = img / 255.0
return np.where(
img_norm <= 0.04045,
img_norm / 12.92,
np.power((img_norm + 0.055) / 1.055, 2.4)
)
def linear_to_srgb(img):
img_clipped = np.clip(img, 0, 1)
return np.where(
img_clipped <= 0.0031308,
12.92 * img_clipped,
1.055 * np.power(img_clipped, 1/2.4) - 0.055
)
# Convert to linear space
orig_linear = srgb_to_linear(orig_array)
gen_linear = srgb_to_linear(gen_array)
# Alpha blending in linear space
alpha = mask_array[:, :, np.newaxis]
result_linear = gen_linear * alpha + orig_linear * (1 - alpha)
# Convert back to sRGB
result_srgb = linear_to_srgb(result_linear)
result_array = (result_srgb * 255).astype(np.uint8)
logger.debug("Inpainting blend completed in linear color space")
return Image.fromarray(result_array)
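    # Hedged usage sketch for the inpainting path; original_img, model_output,
    # and inpaint_mask are hypothetical PIL Images (mask: white = inpainted):
    #
    #     blender = ImageBlender()
    #     blended = blender.blend_inpainting(
    #         original_img, model_output, inpaint_mask,
    #         feather_radius=8, apply_color_correction=True,
    #     )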
def _apply_inpaint_color_correction(
self,
original: np.ndarray,
generated: np.ndarray,
mask: np.ndarray
) -> np.ndarray:
"""
Apply adaptive color correction to match generated region with surroundings.
Analyzes the boundary region and adjusts the generated content's
luminance and color to better match the original context.
Parameters
----------
original : np.ndarray
Original image (float32, 0-255)
generated : np.ndarray
Generated image (float32, 0-255)
mask : np.ndarray
Blend mask (float32, 0-1)
Returns
-------
np.ndarray
Color-corrected generated image
"""
# Find boundary region
mask_binary = (mask > 0.5).astype(np.uint8)
kernel = cv2.getStructuringElement(
cv2.MORPH_ELLIPSE,
(self.INPAINT_COLOR_BLEND_RADIUS * 2 + 1, self.INPAINT_COLOR_BLEND_RADIUS * 2 + 1)
)
dilated = cv2.dilate(mask_binary, kernel, iterations=1)
boundary_zone = (dilated > 0) & (mask < 0.3)
if not np.any(boundary_zone):
return generated
# Convert to Lab for perceptual color matching
orig_lab = cv2.cvtColor(
original.astype(np.uint8), cv2.COLOR_RGB2LAB
).astype(np.float32)
gen_lab = cv2.cvtColor(
generated.astype(np.uint8), cv2.COLOR_RGB2LAB
).astype(np.float32)
        # Calculate robust statistics (medians) in the boundary zone of the original
boundary_orig_l = orig_lab[boundary_zone, 0]
boundary_orig_a = orig_lab[boundary_zone, 1]
boundary_orig_b = orig_lab[boundary_zone, 2]
orig_mean_l = np.median(boundary_orig_l)
orig_mean_a = np.median(boundary_orig_a)
orig_mean_b = np.median(boundary_orig_b)
# Calculate statistics in generated inpaint region
inpaint_zone = mask > 0.5
if not np.any(inpaint_zone):
return generated
gen_inpaint_l = gen_lab[inpaint_zone, 0]
gen_inpaint_a = gen_lab[inpaint_zone, 1]
gen_inpaint_b = gen_lab[inpaint_zone, 2]
gen_mean_l = np.median(gen_inpaint_l)
gen_mean_a = np.median(gen_inpaint_a)
gen_mean_b = np.median(gen_inpaint_b)
# Calculate correction deltas
delta_l = orig_mean_l - gen_mean_l
delta_a = orig_mean_a - gen_mean_a
delta_b = orig_mean_b - gen_mean_b
# Limit correction to avoid over-adjustment
max_correction = 15
delta_l = np.clip(delta_l, -max_correction, max_correction)
delta_a = np.clip(delta_a, -max_correction * 0.5, max_correction * 0.5)
delta_b = np.clip(delta_b, -max_correction * 0.5, max_correction * 0.5)
logger.debug(f"Color correction deltas: L={delta_l:.1f}, a={delta_a:.1f}, b={delta_b:.1f}")
# Apply correction with spatial falloff from boundary
# Create distance map from boundary
distance = cv2.distanceTransform(
mask_binary, cv2.DIST_L2, 5
)
max_dist = np.max(distance)
if max_dist > 0:
# Correction strength falls off from boundary toward center
correction_strength = 1.0 - np.clip(distance / (max_dist * 0.5), 0, 1)
else:
correction_strength = np.ones_like(distance)
# Apply correction to Lab channels
corrected_lab = gen_lab.copy()
corrected_lab[:, :, 0] += delta_l * correction_strength * 0.7
corrected_lab[:, :, 1] += delta_a * correction_strength * 0.5
corrected_lab[:, :, 2] += delta_b * correction_strength * 0.5
# Clip to valid Lab ranges
corrected_lab[:, :, 0] = np.clip(corrected_lab[:, :, 0], 0, 255)
corrected_lab[:, :, 1] = np.clip(corrected_lab[:, :, 1], 0, 255)
corrected_lab[:, :, 2] = np.clip(corrected_lab[:, :, 2], 0, 255)
# Convert back to RGB
corrected_rgb = cv2.cvtColor(
corrected_lab.astype(np.uint8), cv2.COLOR_LAB2RGB
).astype(np.float32)
return corrected_rgb
def blend_inpainting_with_guided_filter(
self,
original: Image.Image,
generated: Image.Image,
mask: Image.Image,
feather_radius: int = 8,
guide_radius: int = 8,
guide_eps: float = 0.01
) -> Image.Image:
"""
Blend inpainted region using guided filter for edge-aware transitions.
Combines standard alpha blending with guided filtering to preserve
edges in the original image while seamlessly integrating new content.
Parameters
----------
original : PIL.Image
Original image
generated : PIL.Image
Generated/inpainted result
mask : PIL.Image
Inpainting mask
feather_radius : int
Base feathering radius
guide_radius : int
Guided filter radius
guide_eps : float
Guided filter regularization
Returns
-------
PIL.Image
Blended result with edge-aware transitions
"""
logger.info("Applying guided filter inpainting blend")
# Ensure same size
if generated.size != original.size:
generated = generated.resize(original.size, Image.LANCZOS)
if mask.size != original.size:
mask = mask.resize(original.size, Image.LANCZOS)
# Convert to arrays
orig_array = np.array(original.convert('RGB')).astype(np.float32)
gen_array = np.array(generated.convert('RGB')).astype(np.float32)
mask_array = np.array(mask.convert('L')).astype(np.float32) / 255.0
# Apply base feathering
if feather_radius > 0:
kernel_size = feather_radius * 2 + 1
mask_feathered = cv2.GaussianBlur(
mask_array,
(kernel_size, kernel_size),
feather_radius / 2
)
else:
mask_feathered = mask_array
# Use original image as guide for the filter
guide = cv2.cvtColor(orig_array.astype(np.uint8), cv2.COLOR_RGB2GRAY)
guide = guide.astype(np.float32) / 255.0
# Apply guided filter to the mask
try:
mask_guided = cv2.ximgproc.guidedFilter(
guide=guide,
src=mask_feathered,
radius=guide_radius,
eps=guide_eps
)
logger.debug("Guided filter applied successfully")
except Exception as e:
logger.warning(f"Guided filter failed: {e}, using standard feathering")
mask_guided = mask_feathered
# Alpha blending
alpha = mask_guided[:, :, np.newaxis]
result = gen_array * alpha + orig_array * (1 - alpha)
result = np.clip(result, 0, 255).astype(np.uint8)
return Image.fromarray(result)
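

# Minimal smoke test, assuming the three input files below exist on disk;
# the file names are hypothetical placeholders, not part of any SceneWeaver
# pipeline contract.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    fg = Image.open("foreground.png").convert("RGB")
    bg = Image.open("background.png").convert("RGB").resize(fg.size, Image.LANCZOS)
    mask = Image.open("mask.png").convert("L").resize(fg.size, Image.LANCZOS)

    blender = ImageBlender(enable_multi_scale=True)
    composite = blender.simple_blend_images(fg, bg, mask)
    composite.save("composite.png")

    # Optional debug outputs (mask grayscale, alpha heatmap, ring overlay, ...)
    for name, img in blender.create_debug_images(fg, bg, mask, composite).items():
        img.save(f"debug_{name}.png")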