Spaces:
Running
on
Zero
Running
on
Zero
File size: 35,704 Bytes
ca80d1d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 
528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627 628 629 630 631 632 633 634 635 636 637 638 639 640 641 642 643 644 645 646 647 648 649 650 651 652 653 654 655 656 657 658 659 660 661 662 663 664 665 666 667 668 669 670 671 672 673 674 675 676 677 678 679 680 681 682 683 684 685 686 687 688 689 690 691 692 693 694 695 696 697 698 699 700 701 702 703 704 705 706 707 708 709 710 711 712 713 714 715 716 717 718 719 720 721 722 723 724 725 726 727 728 729 730 731 732 733 734 735 736 737 738 739 740 741 742 743 744 745 746 747 748 749 750 751 752 753 754 755 756 757 758 759 760 761 762 763 764 765 766 767 768 769 770 771 772 773 774 775 776 777 778 779 780 781 782 783 784 785 786 787 788 789 790 791 792 793 794 795 796 797 798 799 800 801 802 803 |
import cv2
import numpy as np
from PIL import Image
import logging
from typing import Dict, Any, Optional, Tuple
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
class ImageBlender:
    """
    Composite a masked foreground onto a new background with spill suppression.

    The main entry point is ``simple_blend_images``. It erodes the mask,
    estimates the colour of the original background from the image border,
    removes that colour from the foreground and from the thin ring around it
    (working in Lab space), and finally blends in linear light with a mostly
    binary alpha so that no halo of the old background survives.

    Every Lab value handled by this class uses OpenCV's 8-bit encoding
    (L scaled to 0-255, a and b shifted by +128). The "DeltaE" thresholds are
    plain Euclidean distances in that encoding.
    """

    EDGE_EROSION_PIXELS = 1  # Pixels to erode from the mask edge before blending
    ALPHA_BINARIZE_THRESHOLD = 0.5  # Alpha at which edge pixels snap to opaque/transparent
    DARK_LUMINANCE_THRESHOLD = 60  # Mean RGB below this counts as "dark foreground"
    FOREGROUND_PROTECTION_THRESHOLD = 140  # Mask value from which original pixels are restored
    BACKGROUND_COLOR_TOLERANCE = 30  # Lab distance below which a pixel counts as background-coloured

    # Offset OpenCV adds to the a/b channels of 8-bit Lab images.
    _LAB_CHROMA_OFFSET = 128.0

    def __init__(self, enable_multi_scale: bool = True):
        """
        Initialize ImageBlender.

        Args:
            enable_multi_scale: Whether to enable multi-scale edge refinement (default True)
        """
        self.enable_multi_scale = enable_multi_scale
        # Filled by simple_blend_images, read by create_debug_images.
        self._debug_info = {}
        self._adaptive_strength_map = None

    # ------------------------------------------------------------------
    # Small pure helpers
    # ------------------------------------------------------------------
    @staticmethod
    def _delta_e(lab: np.ndarray, reference_lab: np.ndarray) -> np.ndarray:
        """Return the Euclidean distance of every Lab pixel in ``lab`` (..., 3) to ``reference_lab`` (3,)."""
        diff = lab - reference_lab
        return np.sqrt(np.sum(diff * diff, axis=-1))

    @staticmethod
    def _luminance(rgb_array: np.ndarray) -> np.ndarray:
        """Return the per-pixel mean of the colour channels (alpha channel, if any, is ignored)."""
        if rgb_array.ndim == 2:
            return rgb_array.astype(np.float64)
        return np.mean(rgb_array[..., :3], axis=2)

    @staticmethod
    def _srgb_to_linear(img: np.ndarray) -> np.ndarray:
        """Decode 8-bit sRGB values to linear light in the range 0-1."""
        img_f = img.astype(np.float32) / 255.0
        return np.where(img_f <= 0.04045, img_f / 12.92, np.power((img_f + 0.055) / 1.055, 2.4))

    @staticmethod
    def _linear_to_srgb(img: np.ndarray) -> np.ndarray:
        """Encode linear-light values (clipped to 0-1) back to sRGB in the range 0-1."""
        img_clipped = np.clip(img, 0, 1)
        return np.where(img_clipped <= 0.0031308,
                        12.92 * img_clipped,
                        1.055 * np.power(img_clipped, 1 / 2.4) - 0.055)

    # ------------------------------------------------------------------
    # Mask / alpha helpers
    # ------------------------------------------------------------------
    def _erode_mask_edges(
        self,
        mask_array: np.ndarray,
        erosion_pixels: int = 2
    ) -> np.ndarray:
        """
        Erode mask edges to remove contaminated boundary pixels.

        This removes the outermost pixels of the foreground mask where
        color contamination from the original background is most likely.

        Args:
            mask_array: Input mask as numpy array (uint8, 0-255)
            erosion_pixels: Number of pixels to erode (default 2)

        Returns:
            Eroded mask array (uint8); the input itself when erosion_pixels <= 0
        """
        if erosion_pixels <= 0:
            return mask_array

        # An n-pixel erosion needs a (2n+1)-wide kernel. Even-sized kernels are
        # anchored off-centre and only erode some sides of the mask.
        kernel_size = 2 * erosion_pixels + 1
        kernel = cv2.getStructuringElement(
            cv2.MORPH_ELLIPSE,
            (kernel_size, kernel_size)
        )
        eroded = cv2.erode(mask_array, kernel, iterations=1)

        # Slight blur to smooth the eroded edges
        eroded = cv2.GaussianBlur(eroded, (3, 3), 0)
        logger.debug(f"Mask erosion applied: {erosion_pixels}px, kernel size: {kernel_size}")
        return eroded

    def _binarize_edge_alpha(
        self,
        alpha: np.ndarray,
        mask_array: np.ndarray,
        orig_array: np.ndarray,
        threshold: float = 0.45
    ) -> np.ndarray:
        """
        Binarize semi-transparent edge pixels to eliminate color bleeding.

        Pixels whose alpha lies strictly between 0.05 and 0.95 are forced to
        1.0 when above the threshold and to 0.0 otherwise. Dark pixels (mean
        RGB below ``DARK_LUMINANCE_THRESHOLD``) use a threshold lowered by 0.1
        so that more of a dark subject or outline stays opaque.

        Args:
            alpha: Current alpha channel (float32, 0.0-1.0)
            mask_array: Original mask array (currently unused, kept for interface compatibility)
            orig_array: Original foreground image array (uint8, RGB)
            threshold: Alpha threshold for binarization decision (default 0.45)

        Returns:
            Alpha array with binarized edges; the input itself when there is no edge zone
        """
        edge_zone = (alpha > 0.05) & (alpha < 0.95)
        if not np.any(edge_zone):
            return alpha

        is_dark = self._luminance(orig_array) < self.DARK_LUMINANCE_THRESHOLD

        # A LOWER threshold makes a pixel more likely to become opaque, which is
        # what "keep more dark pixels" requires.
        adaptive_threshold = np.full_like(alpha, threshold)
        adaptive_threshold[is_dark] = max(threshold - 0.1, 0.0)

        make_opaque = edge_zone & (alpha > adaptive_threshold)
        make_transparent = edge_zone & ~make_opaque

        alpha_binarized = alpha.copy()
        alpha_binarized[make_opaque] = 1.0
        alpha_binarized[make_transparent] = 0.0

        num_opaque = np.sum(make_opaque)
        num_transparent = np.sum(make_transparent)
        logger.info(f"Edge binarization: {num_opaque} pixels -> opaque, {num_transparent} pixels -> transparent")
        return alpha_binarized

    def _apply_edge_cleanup(
        self,
        result_array: np.ndarray,
        bg_array: np.ndarray,
        alpha: np.ndarray,
        cleanup_width: int = 2
    ) -> np.ndarray:
        """
        Final cleanup pass to remove any remaining edge artifacts.

        Pixels whose alpha is still fractional (between 0.01 and 0.99) and
        below 0.5 are replaced with the pure background colour. Fractional
        pixels at or above 0.5 keep their blended colour.

        Args:
            result_array: Current blended result (uint8, RGB)
            bg_array: Background image array (uint8, RGB)
            alpha: Final alpha channel (float32, 0.0-1.0)
            cleanup_width: Currently unused, kept for interface compatibility

        Returns:
            Cleaned result array (uint8); the input itself when nothing had to change
        """
        residual_edge = (alpha > 0.01) & (alpha < 0.99)
        snap_to_bg = residual_edge & (alpha < 0.5)
        if not np.any(snap_to_bg):
            return result_array

        result_cleaned = result_array.copy()
        result_cleaned[snap_to_bg] = bg_array[snap_to_bg]
        logger.debug(f"Edge cleanup: {np.sum(snap_to_bg)} residual pixels cleaned")
        return result_cleaned

    def _remove_background_color_contamination(
        self,
        image_array: np.ndarray,
        mask_array: np.ndarray,
        orig_bg_color_lab: np.ndarray,
        tolerance: float = 30.0
    ) -> np.ndarray:
        """
        Remove original background color contamination from foreground pixels.

        Foreground pixels (mask > 50) whose Lab distance to the original
        background colour is below ``tolerance`` are inpainted from their
        surroundings. If inpainting fails, a 5x5 median filter is used instead.

        Args:
            image_array: Foreground image array (uint8, RGB)
            mask_array: Mask array (uint8, 0-255)
            orig_bg_color_lab: Original background color in OpenCV 8-bit Lab
            tolerance: Lab distance tolerance for detecting contaminated pixels

        Returns:
            Cleaned image array (uint8); the input itself when nothing is contaminated
        """
        foreground_mask = mask_array > 50
        if not np.any(foreground_mask):
            return image_array

        image_lab = cv2.cvtColor(image_array, cv2.COLOR_RGB2LAB).astype(np.float32)
        contaminated = foreground_mask & (self._delta_e(image_lab, orig_bg_color_lab) < tolerance)
        if not np.any(contaminated):
            logger.debug("No background color contamination detected in foreground")
            return image_array

        num_contaminated = np.sum(contaminated)
        logger.info(f"Found {num_contaminated} pixels with background color contamination")

        result = image_array.copy()
        inpaint_mask = contaminated.astype(np.uint8) * 255
        try:
            result = cv2.inpaint(result, inpaint_mask, inpaintRadius=3, flags=cv2.INPAINT_TELEA)
            logger.info(f"Inpainted {num_contaminated} contaminated pixels")
        except Exception as e:
            # Best-effort step: never let a failed repair abort the whole blend.
            logger.warning(f"Inpainting failed: {e}, using median filter fallback")
            median_filtered = cv2.medianBlur(image_array, 5)
            result[contaminated] = median_filtered[contaminated]
        return result

    def _protect_foreground_core(
        self,
        result_array: np.ndarray,
        orig_array: np.ndarray,
        mask_array: np.ndarray,
        protection_threshold: int = 140
    ) -> np.ndarray:
        """
        Strongly protect core foreground pixels from any background influence.

        Every pixel whose mask value is at least ``protection_threshold`` is
        overwritten with the corresponding pixel of ``orig_array``.

        Args:
            result_array: Current blended result (uint8, RGB)
            orig_array: Original foreground image (uint8, RGB)
            mask_array: Mask array (uint8, 0-255)
            protection_threshold: Mask value from which pixels are fully protected

        Returns:
            Protected result array (uint8); the input itself when no pixel qualifies
        """
        strong_foreground = mask_array >= protection_threshold
        if not np.any(strong_foreground):
            return result_array

        result_protected = result_array.copy()
        result_protected[strong_foreground] = orig_array[strong_foreground]
        num_protected = np.sum(strong_foreground)
        logger.info(f"Protected {num_protected} core foreground pixels from background influence")
        return result_protected

    def multi_scale_edge_refinement(
        self,
        original_image: Image.Image,
        background_image: Image.Image,
        mask: Image.Image
    ) -> Image.Image:
        """
        Multi-scale edge refinement for better edge quality.

        The mask is resampled at full, half and quarter resolution. The three
        versions are then mixed per pixel, weighted by the local gradient
        energy of the foreground: busy regions favour the full-resolution
        mask, flat regions favour the smoother low-resolution ones.

        Args:
            original_image: Foreground PIL Image
            background_image: Background PIL Image (currently unused)
            mask: Current mask PIL Image (any mode; it is converted to 'L')

        Returns:
            Refined mask as an 'L' PIL Image, or the unmodified input mask if
            refinement fails for any reason (e.g. image/mask size mismatch)
        """
        logger.info("🔍 Starting multi-scale edge refinement...")
        try:
            orig_array = np.array(original_image.convert('RGB'))
            # Normalise the mode first: a bilevel ('1') mask would otherwise be
            # read as 0/1 instead of 0/255 and multi-channel masks could not
            # be unpacked into (height, width).
            mask_array = np.array(mask.convert('L')).astype(np.float32)
            height, width = mask_array.shape

            gray = cv2.cvtColor(orig_array, cv2.COLOR_RGB2GRAY)

            scales = [1.0, 0.5, 0.25]  # Original, half, quarter
            scale_masks = []
            scale_complexities = []
            for scale in scales:
                if scale == 1.0:
                    scaled_gray = gray
                    scaled_mask = mask_array
                else:
                    # Never ask OpenCV for a zero-sized image on tiny inputs.
                    new_h = max(1, int(height * scale))
                    new_w = max(1, int(width * scale))
                    scaled_gray = cv2.resize(gray, (new_w, new_h), interpolation=cv2.INTER_LANCZOS4)
                    scaled_mask = cv2.resize(mask_array, (new_w, new_h), interpolation=cv2.INTER_LANCZOS4)

                # Local complexity = gradient magnitude averaged over 5x5 windows
                sobel_x = cv2.Sobel(scaled_gray, cv2.CV_64F, 1, 0, ksize=3)
                sobel_y = cv2.Sobel(scaled_gray, cv2.CV_64F, 0, 1, ksize=3)
                gradient_mag = np.sqrt(sobel_x**2 + sobel_y**2)
                complexity = cv2.blur(gradient_mag, (5, 5))

                if scale != 1.0:
                    scaled_mask = cv2.resize(scaled_mask, (width, height), interpolation=cv2.INTER_LANCZOS4)
                    complexity = cv2.resize(complexity, (width, height), interpolation=cv2.INTER_LANCZOS4)

                scale_masks.append(scaled_mask)
                scale_complexities.append(complexity)

            # Normalise all complexities by the global maximum
            max_complexity = max(c.max() for c in scale_complexities) + 1e-6
            normalized_complexities = [c / max_complexity for c in scale_complexities]

            weights = np.zeros((len(scales), height, width), dtype=np.float32)
            for i, complexity in enumerate(normalized_complexities):
                if i == 0:  # High resolution - prefer for high complexity regions
                    weights[i] = complexity
                elif i == 1:  # Medium resolution - moderate complexity
                    weights[i] = 0.5 * (1 - complexity) + 0.5 * complexity * 0.5
                else:  # Low resolution - prefer for low complexity regions
                    weights[i] = 1 - complexity

            # Normalize weights so they sum to 1 at each pixel
            weights = weights / (weights.sum(axis=0, keepdims=True) + 1e-6)

            refined_mask = np.zeros((height, width), dtype=np.float32)
            for weight, mask_i in zip(weights, scale_masks):
                refined_mask += weight * mask_i

            # Lanczos resampling can overshoot, hence the clip.
            refined_mask = np.clip(refined_mask, 0, 255).astype(np.uint8)
            logger.info("✅ Multi-scale edge refinement completed")
            # A 2-D uint8 array is mapped to mode 'L' automatically.
            return Image.fromarray(refined_mask)
        except Exception as e:
            # Deliberate best-effort: refinement is optional, the blend is not.
            logger.error(f"❌ Multi-scale refinement failed: {e}, using original mask")
            return mask

    # ------------------------------------------------------------------
    # Stages of simple_blend_images
    # ------------------------------------------------------------------
    def _estimate_background_color(
        self,
        orig_array: np.ndarray,
        mask_array: np.ndarray
    ) -> np.ndarray:
        """Estimate the original background colour (uint8 RGB) from a 15px border that is not foreground."""
        height, width = orig_array.shape[:2]
        edge_width = 15

        border_mask = np.zeros((height, width), dtype=bool)
        border_mask[:edge_width, :] = True   # Top edge
        border_mask[-edge_width:, :] = True  # Bottom edge
        border_mask[:, :edge_width] = True   # Left edge
        border_mask[:, -edge_width:] = True  # Right edge
        border_mask &= ~(mask_array > 50)    # Exclude foreground areas

        if not np.any(border_mask):
            return np.array([200, 180, 120], dtype=np.uint8)  # Default yellow

        border_pixels = orig_array[border_mask].reshape(-1, 3)
        try:
            if len(border_pixels) <= 100:
                return np.median(border_pixels, axis=0).astype(np.uint8)

            # Histogram mode on a coarse grid (8 levels per channel)
            quantized = (border_pixels // 32) * 32
            _, inverse, counts = np.unique(
                quantized, axis=0, return_inverse=True, return_counts=True
            )
            # The bin's lower corner is up to 31 levels darker than the real
            # colour, so report the median of the pixels inside the winning bin.
            in_top_bin = inverse.reshape(-1) == np.argmax(counts)
            return np.median(border_pixels[in_top_bin], axis=0).astype(np.uint8)
        except Exception as e:
            logger.warning(f"Background color estimation failed: {e}, using corner average")
            corners = np.array([orig_array[0, 0], orig_array[0, -1],
                                orig_array[-1, 0], orig_array[-1, -1]])
            return np.mean(corners, axis=0).astype(np.uint8)

    def _build_trimap(
        self,
        mask_array: np.ndarray
    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        """Split the mask into boolean (fg_core, ring_zone, bg_zone) maps."""
        try:
            kernel_3x3 = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3, 3))
            # A single erosion keeps thin limbs inside the core.
            mask_eroded = cv2.erode(mask_array, kernel_3x3, iterations=1)
            mask_dilated = cv2.dilate(mask_array, kernel_3x3, iterations=1)

            fg_core = mask_eroded > 127
            # Morphological gradient: non-zero only in the thin edge band.
            # cv2.subtract saturates, so uint8 cannot wrap around.
            ring_zone = cv2.subtract(mask_dilated, mask_eroded) > 0
            bg_zone = mask_array < 30
        except Exception as e:
            logger.exception(f"❌ Trimap definition failed: {e}")
            # Fallback to simple threshold definition
            fg_core = mask_array > 200
            ring_zone = (mask_array > 50) & (mask_array <= 200)
            bg_zone = mask_array <= 50
        return fg_core, ring_zone, bg_zone

    def _suppress_ring_spill(
        self,
        orig_array: np.ndarray,
        ring_zone: np.ndarray,
        orig_bg_color_lab: np.ndarray,
        fg_rep_color_lab: np.ndarray,
        fg_rep_color_rgb: np.ndarray
    ) -> np.ndarray:
        """
        Remove old-background colour from the edge ring and return the corrected RGB image.

        Runs an adaptive chroma/luminance correction, restores dark outlines,
        applies a second stronger pass, and finally inpaints whatever is still
        close to the background colour. Also stores the per-pixel correction
        strength of the first pass in ``self._adaptive_strength_map``.
        """
        SPILL_STRENGTH = 0.85         # Base strength of the first correction pass
        L_MATCH_STRENGTH = 0.65       # Base strength of luminance matching
        DELTAE_THRESHOLD = 18         # Lab distance below which a ring pixel is contaminated
        HARD_EDGE_PROTECT = True      # Restore dark, high-gradient outlines
        MULTI_PASS_CORRECTION = True  # Second, stronger pass
        INPAINT_FALLBACK = True       # Inpaint what both passes could not fix

        height, width = orig_array.shape[:2]
        pristine_lab = cv2.cvtColor(orig_array, cv2.COLOR_RGB2LAB).astype(np.float32)
        orig_lab = pristine_lab.copy()

        # Boolean indexing copies, so edits to ring_pixels_lab are written back explicitly.
        ring_pixels_lab = orig_lab[ring_zone]
        delta_e = self._delta_e(ring_pixels_lab, orig_bg_color_lab)
        contaminated_mask = delta_e < DELTAE_THRESHOLD

        if np.any(contaminated_mask):
            # Closer to the background colour -> stronger correction,
            # clamped to 30%-100% of the base strength.
            adaptive_strength = SPILL_STRENGTH * np.maximum(
                0.0,
                1.0 - (delta_e[contaminated_mask] / DELTAE_THRESHOLD)
            )
            adaptive_strength = np.clip(adaptive_strength, SPILL_STRENGTH * 0.3, SPILL_STRENGTH)

            # Store as a full-size image so it can be rendered as a heatmap.
            ring_strength = np.zeros_like(delta_e)
            ring_strength[contaminated_mask] = adaptive_strength
            strength_map = np.zeros((height, width), dtype=np.float32)
            strength_map[ring_zone] = ring_strength
            self._adaptive_strength_map = strength_map
            logger.info(f"📊 Adaptive strength stats - Mean: {adaptive_strength.mean():.3f}, Min: {adaptive_strength.min():.3f}, Max: {adaptive_strength.max():.3f}")

            # Chroma deprojection. The a/b channels are stored with a +128
            # offset, so they must be centred before they can be treated as a
            # vector. Otherwise even a neutral grey background has a large
            # "chroma" and the projection removes real colour.
            offset = self._LAB_CHROMA_OFFSET
            bg_chroma = orig_bg_color_lab[1:3] - offset
            bg_chroma_norm = bg_chroma / (np.linalg.norm(bg_chroma) + 1e-6)

            contaminated_pixels = ring_pixels_lab[contaminated_mask]
            pixel_chroma = contaminated_pixels[:, 1:3] - offset
            projection = np.dot(pixel_chroma, bg_chroma_norm)[:, np.newaxis] * bg_chroma_norm

            strength_col = adaptive_strength[:, np.newaxis]
            corrected_chroma = pixel_chroma - projection * strength_col + offset

            # Pull chroma and luminance toward the representative foreground colour
            convergence_factor = strength_col * 0.6
            corrected_chroma = (corrected_chroma * (1 - convergence_factor) +
                                fg_rep_color_lab[1:3] * convergence_factor)
            adaptive_l_strength = adaptive_strength * (L_MATCH_STRENGTH / SPILL_STRENGTH)
            corrected_l = (contaminated_pixels[:, 0] * (1 - adaptive_l_strength) +
                           fg_rep_color_lab[0] * adaptive_l_strength)

            ring_pixels_lab[contaminated_mask, 0] = corrected_l
            ring_pixels_lab[contaminated_mask, 1:3] = corrected_chroma
            orig_lab[ring_zone] = ring_pixels_lab

        if HARD_EDGE_PROTECT:
            gray = self._luminance(orig_array)
            sobel_x = cv2.Sobel(gray, cv2.CV_64F, 1, 0, ksize=3)
            sobel_y = cv2.Sobel(gray, cv2.CV_64F, 0, 1, ksize=3)
            gradient_mag = np.sqrt(sobel_x**2 + sobel_y**2)
            dark_edge_zone = ring_zone & (gray < 60) & (gradient_mag > 20)
            if np.any(dark_edge_zone):
                # Dark outlines keep their original colour
                orig_lab[dark_edge_zone] = pristine_lab[dark_edge_zone]

        if MULTI_PASS_CORRECTION:
            ring_pixels_lab_pass2 = orig_lab[ring_zone]
            delta_e_pass2 = self._delta_e(ring_pixels_lab_pass2, orig_bg_color_lab)
            still_contaminated = delta_e_pass2 < (DELTAE_THRESHOLD * 0.8)
            if np.any(still_contaminated):
                remaining_pixels = ring_pixels_lab_pass2[still_contaminated]
                # Stronger luminance matching and chroma neutralization
                ring_pixels_lab_pass2[still_contaminated, 0] = (
                    remaining_pixels[:, 0] * 0.4 + fg_rep_color_lab[0] * 0.6
                )
                ring_pixels_lab_pass2[still_contaminated, 1:3] = (
                    remaining_pixels[:, 1:3] * 0.3 + fg_rep_color_lab[1:3] * 0.7
                )
                orig_lab[ring_zone] = ring_pixels_lab_pass2

        # Round instead of truncating so corrected values are not biased darker.
        orig_lab_u8 = np.clip(np.rint(orig_lab), 0, 255).astype(np.uint8)
        corrected = cv2.cvtColor(orig_lab_u8, cv2.COLOR_LAB2RGB)

        if INPAINT_FALLBACK:
            final_lab = cv2.cvtColor(corrected, cv2.COLOR_RGB2LAB).astype(np.float32)
            final_delta_e = self._delta_e(final_lab[ring_zone], orig_bg_color_lab)
            still_contaminated = final_delta_e < (DELTAE_THRESHOLD * 0.5)
            if np.any(still_contaminated):
                # np.nonzero walks the ring in the same row-major order as
                # boolean indexing, so the two index sets line up.
                ring_rows, ring_cols = np.nonzero(ring_zone)
                inpaint_coords = (ring_rows[still_contaminated], ring_cols[still_contaminated])
                inpaint_mask = np.zeros((height, width), dtype=np.uint8)
                inpaint_mask[inpaint_coords] = 255
                try:
                    corrected = cv2.inpaint(corrected, inpaint_mask, 3, cv2.INPAINT_TELEA)
                except Exception as e:
                    logger.warning(f"Ring inpainting failed: {e}, filling with foreground color")
                    corrected[inpaint_coords] = fg_rep_color_rgb

        return corrected

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------
    def simple_blend_images(
        self,
        original_image: Image.Image,
        background_image: Image.Image,
        combination_mask: Image.Image,
        use_multi_scale: Optional[bool] = None
    ) -> Image.Image:
        """
        Aggressive spill suppression + color replacement: eliminate edge
        residue of the old background while maintaining sharp edges.

        Args:
            original_image: Foreground PIL Image (converted to RGB)
            background_image: Background PIL Image (converted to RGB)
            combination_mask: Mask PIL Image (converted to L mode)
            use_multi_scale: Override for multi-scale refinement (None = use class default)

        Returns:
            Blended PIL Image (RGB)

        Raises:
            ValueError: If background or mask differ in size from the original image
        """
        logger.info("🎨 Starting advanced image blending process...")

        # Do not leak the strength map of a previous call into the debug output.
        self._adaptive_strength_map = None

        # Normalise modes up front: the maths below assumes 3-channel colour
        # images and a single-channel 0-255 mask.
        original_image = original_image.convert('RGB')
        background_image = background_image.convert('RGB')
        combination_mask = combination_mask.convert('L')
        if background_image.size != original_image.size or combination_mask.size != original_image.size:
            raise ValueError(
                f"Size mismatch - original: {original_image.size}, "
                f"background: {background_image.size}, mask: {combination_mask.size}"
            )

        should_use_multi_scale = use_multi_scale if use_multi_scale is not None else self.enable_multi_scale
        if should_use_multi_scale:
            combination_mask = self.multi_scale_edge_refinement(
                original_image, background_image, combination_mask
            )

        orig_array = np.array(original_image, dtype=np.uint8)
        # Kept for the final core protection. orig_array is only ever rebound
        # below, never modified in place, so no copy is needed.
        pristine_array = orig_array
        bg_array = np.array(background_image, dtype=np.uint8)
        mask_array = np.array(combination_mask, dtype=np.uint8)
        logger.info(f"📊 Image dimensions - Original: {orig_array.shape}, Background: {bg_array.shape}, Mask: {mask_array.shape}")
        logger.info(f"📊 Mask statistics (before erosion) - Mean: {mask_array.mean():.1f}, Min: {mask_array.min()}, Max: {mask_array.max()}")

        # Erode the mask to drop the most contaminated boundary pixels
        mask_array = self._erode_mask_edges(mask_array, self.EDGE_EROSION_PIXELS)
        logger.info(f"📊 Mask statistics (after erosion) - Mean: {mask_array.mean():.1f}, Min: {mask_array.min()}, Max: {mask_array.max()}")

        # Original background colour, in RGB and OpenCV 8-bit Lab
        orig_bg_color_rgb = self._estimate_background_color(orig_array, mask_array)
        orig_bg_color_lab = cv2.cvtColor(
            orig_bg_color_rgb.reshape(1, 1, 3), cv2.COLOR_RGB2LAB
        )[0, 0].astype(np.float32)
        logger.info(f"🎨 Detected original background color: RGB{tuple(orig_bg_color_rgb)}")

        orig_array = self._remove_background_color_contamination(
            orig_array,
            mask_array,
            orig_bg_color_lab,
            tolerance=self.BACKGROUND_COLOR_TOLERANCE
        )

        fg_core, ring_zone, bg_zone = self._build_trimap(mask_array)
        logger.info(f"🔍 Trimap regions - FG_CORE: {fg_core.sum()}, RING: {ring_zone.sum()}, BG: {bg_zone.sum()}")

        # Foreground representative color: median of the core
        if np.any(fg_core):
            fg_rep_color_rgb = np.median(orig_array[fg_core].reshape(-1, 3), axis=0).astype(np.uint8)
        else:
            fg_rep_color_rgb = np.array([80, 60, 40], dtype=np.uint8)  # Default dark
        fg_rep_color_lab = cv2.cvtColor(
            fg_rep_color_rgb.reshape(1, 1, 3), cv2.COLOR_RGB2LAB
        )[0, 0].astype(np.float32)

        if np.any(ring_zone):
            orig_array = self._suppress_ring_spill(
                orig_array, ring_zone, orig_bg_color_lab, fg_rep_color_lab, fg_rep_color_rgb
            )

        # === Alpha calculation ===
        alpha = mask_array.astype(np.float32) / 255.0
        alpha[fg_core] = 1.0   # Core foreground - fully opaque
        alpha[bg_zone] = 0.0   # Background - fully transparent

        # Pixels with mask >= 160 are always opaque so that the ring clamp
        # below cannot limit filled areas to 0.9.
        high_confidence_pixels = mask_array >= 160
        alpha[high_confidence_pixels] = 1.0
        logger.info(f"💯 High confidence pixels set to full opacity: {high_confidence_pixels.sum()}")

        ring_without_high_conf = ring_zone & (~high_confidence_pixels)
        alpha[ring_without_high_conf] = np.clip(alpha[ring_without_high_conf], 0.2, 0.9)

        # Black outline / strong edge protection: nearly fully opaque
        orig_gray = self._luminance(orig_array)
        sobel_x = cv2.Sobel(orig_gray, cv2.CV_64F, 1, 0, ksize=3)
        sobel_y = cv2.Sobel(orig_gray, cv2.CV_64F, 0, 1, ksize=3)
        gradient_mag = np.sqrt(sobel_x**2 + sobel_y**2)
        black_edge_threshold = 60
        gradient_threshold = 25
        strong_edges = (orig_gray < black_edge_threshold) & (gradient_mag > gradient_threshold) & (mask_array > 10)
        alpha[strong_edges] = np.maximum(alpha[strong_edges], 0.995)
        logger.info(f"🛡️ Protection applied - High conf: {high_confidence_pixels.sum()}, Strong edges: {strong_edges.sum()}")

        alpha = self._binarize_edge_alpha(
            alpha,
            mask_array,
            orig_array,
            threshold=self.ALPHA_BINARIZE_THRESHOLD
        )

        # === Blend in linear light ===
        orig_linear = self._srgb_to_linear(orig_array)
        bg_linear = self._srgb_to_linear(bg_array)
        alpha_3d = alpha[:, :, np.newaxis]
        result_linear = orig_linear * alpha_3d + bg_linear * (1 - alpha_3d)
        result_srgb = self._linear_to_srgb(result_linear)
        # Round: truncation would turn e.g. 199.99998 into 199 and darken
        # every pixel that merely went through the sRGB round trip.
        result_array = np.clip(np.rint(result_srgb * 255.0), 0, 255).astype(np.uint8)

        result_array = self._apply_edge_cleanup(result_array, bg_array, alpha)

        # Core foreground gets the untouched input pixels back
        result_array = self._protect_foreground_core(
            result_array,
            pristine_array,
            mask_array,
            protection_threshold=self.FOREGROUND_PROTECTION_THRESHOLD
        )

        # Store debug information (read by create_debug_images)
        self._debug_info = {
            'orig_bg_color_rgb': orig_bg_color_rgb,
            'fg_rep_color_rgb': fg_rep_color_rgb,
            'orig_bg_color_lab': orig_bg_color_lab,
            'fg_rep_color_lab': fg_rep_color_lab,
            'ring_zone': ring_zone,
            'fg_core': fg_core,
            'alpha_final': alpha
        }
        return Image.fromarray(result_array)

    def create_debug_images(
        self,
        original_image: Image.Image,
        generated_background: Image.Image,
        combination_mask: Image.Image,
        combined_image: Image.Image
    ) -> Dict[str, Image.Image]:
        """
        Generate debug images from the mask and the state of the last blend.

        Args:
            original_image: Foreground PIL Image the ring overlay is drawn on
            generated_background: Currently unused
            combination_mask: Mask PIL Image
            combined_image: Currently unused

        Returns:
            Dict with "mask_gray" (mask in L mode), "alpha_heatmap" (JET
            colormap of the mask), "ring_visualization" (ring zone tinted red
            on the original, or the original itself when no matching ring data
            exists) and, if a strength map was recorded,
            "adaptive_strength_heatmap".
        """
        debug_images = {}

        mask_gray = combination_mask.convert('L')
        debug_images["mask_gray"] = mask_gray

        heatmap_colored = cv2.applyColorMap(np.array(mask_gray), cv2.COLORMAP_JET)
        debug_images["alpha_heatmap"] = Image.fromarray(cv2.cvtColor(heatmap_colored, cv2.COLOR_BGR2RGB))

        debug_info = getattr(self, '_debug_info', None) or {}
        ring_zone = debug_info.get('ring_zone')

        # Ring visualization: red semi-transparent tint on the ring pixels.
        # Work on an RGB copy so RGBA/greyscale originals do not break the
        # 3-channel arithmetic, and skip the overlay if the stored ring
        # belongs to an image of a different size.
        orig_rgb = np.array(original_image.convert('RGB'))
        if ring_zone is not None and ring_zone.shape == orig_rgb.shape[:2]:
            ring_overlay = orig_rgb.astype(np.float32)
            ring_overlay[ring_zone] = (
                ring_overlay[ring_zone] * 0.7 + np.array([255, 0, 0], dtype=np.float32) * 0.3
            )
            debug_images["ring_visualization"] = Image.fromarray(ring_overlay.astype(np.uint8))
        else:
            debug_images["ring_visualization"] = original_image

        # Adaptive strength heatmap - per-pixel correction strength
        strength_map = getattr(self, '_adaptive_strength_map', None)
        if strength_map is not None:
            # A 1-D map holds one value per ring pixel; scatter it back onto
            # the image grid so the heatmap is spatially meaningful.
            if (strength_map.ndim == 1 and ring_zone is not None
                    and strength_map.size == int(ring_zone.sum())):
                spatial = np.zeros(ring_zone.shape, dtype=np.float32)
                spatial[ring_zone] = strength_map
                strength_map = spatial

            peak = strength_map.max() if strength_map.size else 0
            if peak > 0:
                normalized_strength = (strength_map / peak * 255).astype(np.uint8)
            else:
                normalized_strength = np.zeros(strength_map.shape, dtype=np.uint8)

            strength_heatmap = cv2.applyColorMap(normalized_strength, cv2.COLORMAP_VIRIDIS)
            debug_images["adaptive_strength_heatmap"] = Image.fromarray(
                cv2.cvtColor(strength_heatmap, cv2.COLOR_BGR2RGB)
            )

        return debug_images
|