import logging

import numpy as np
import cv2
from PIL import Image
from typing import Dict, Any, Tuple, Optional, List
from dataclasses import dataclass

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)


@dataclass
class QualityResult:
    """Result of a quality check."""
    score: float  # 0-100
    passed: bool
    issue: str
    details: Dict[str, Any]


class QualityChecker:
    """
    Automated quality validation system for generated images.

    Provides checks for mask coverage, edge continuity, and color harmony.
    """

    # Quality thresholds
    THRESHOLD_PASS = 70
    THRESHOLD_WARNING = 50

    def __init__(self, strictness: str = "standard"):
        """
        Initialize QualityChecker.

        Args:
            strictness: Quality check strictness level
                "lenient"  - Only check fatal issues
                "standard" - All checks with moderate thresholds
                "strict"   - High standards required
        """
        self.strictness = strictness
        self._set_thresholds()

    def _set_thresholds(self):
        """Set quality thresholds based on strictness level."""
        if self.strictness == "lenient":
            self.min_coverage = 0.03  # 3%
            self.min_edge_score = 40
            self.min_harmony_score = 40
        elif self.strictness == "strict":
            self.min_coverage = 0.10  # 10%
            self.min_edge_score = 75
            self.min_harmony_score = 75
        else:  # standard
            self.min_coverage = 0.05  # 5%
            self.min_edge_score = 60
            self.min_harmony_score = 60

    def check_mask_coverage(self, mask: Image.Image) -> QualityResult:
        """
        Verify mask coverage is adequate.

        Args:
            mask: Grayscale mask image (L mode)

        Returns:
            QualityResult with coverage analysis
        """
        try:
            mask_array = np.array(mask.convert('L'))
            height, width = mask_array.shape
            total_pixels = height * width

            # Count foreground pixels
            fg_pixels = np.count_nonzero(mask_array > 127)
            coverage_ratio = fg_pixels / total_pixels

            # Check for isolated small regions (noise)
            _, binary = cv2.threshold(mask_array, 127, 255, cv2.THRESH_BINARY)
            num_labels, labels, stats, _ = cv2.connectedComponentsWithStats(binary, connectivity=8)

            # Count significant regions (> 1% of image)
            min_region_size = total_pixels * 0.01
            significant_regions = sum(1 for i in range(1, num_labels)
                                      if stats[i, cv2.CC_STAT_AREA] > min_region_size)

            # Calculate fragmentation (many small regions = bad)
            fragmentation_penalty = max(0, (num_labels - 1 - significant_regions) * 2)

            # Score calculation
            coverage_score = min(100, coverage_ratio * 200)  # 50% coverage = 100 score
            final_score = max(0, coverage_score - fragmentation_penalty)

            # Determine pass/fail
            passed = coverage_ratio >= self.min_coverage and significant_regions >= 1

            issue = ""
            if coverage_ratio < self.min_coverage:
                issue = f"Low foreground coverage ({coverage_ratio:.1%})"
            elif significant_regions == 0:
                issue = "No significant foreground regions detected"
            elif fragmentation_penalty > 20:
                issue = f"Fragmented mask with {num_labels - 1} isolated regions"

            return QualityResult(
                score=final_score,
                passed=passed,
                issue=issue,
                details={
                    "coverage_ratio": coverage_ratio,
                    "foreground_pixels": fg_pixels,
                    "total_regions": num_labels - 1,
                    "significant_regions": significant_regions
                }
            )

        except Exception as e:
            logger.error(f"❌ Mask coverage check failed: {e}")
            return QualityResult(score=0, passed=False, issue=str(e), details={})
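    # Usage sketch (illustrative only, not part of the pipeline): running the
    # coverage check on a mask loaded from disk. The file name below is a
    # hypothetical placeholder.
    #
    #   checker = QualityChecker(strictness="standard")
    #   mask = Image.open("subject_mask.png").convert("L")
    #   result = checker.check_mask_coverage(mask)
    #   if not result.passed:
    #       logger.warning("Mask rejected: %s", result.issue)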
    def check_edge_continuity(self, mask: Image.Image) -> QualityResult:
        """
        Check if mask edges are continuous and smooth.

        Args:
            mask: Grayscale mask image

        Returns:
            QualityResult with edge analysis
        """
        try:
            mask_array = np.array(mask.convert('L'))

            # Find edges using morphological gradient
            kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3, 3))
            gradient = cv2.morphologyEx(mask_array, cv2.MORPH_GRADIENT, kernel)

            # Get edge pixels
            edge_pixels = gradient > 20
            edge_count = np.count_nonzero(edge_pixels)

            if edge_count == 0:
                return QualityResult(
                    score=50,
                    passed=False,
                    issue="No edges detected in mask",
                    details={"edge_count": 0}
                )

            # Check edge smoothness using Laplacian
            laplacian = cv2.Laplacian(mask_array, cv2.CV_64F)
            edge_laplacian = np.abs(laplacian[edge_pixels])

            # High Laplacian values indicate jagged edges
            smoothness = 100 - min(100, np.std(edge_laplacian) * 0.5)

            # Check for gaps in edges
            # Dilate and erode to find disconnections
            dilated = cv2.dilate(gradient, kernel, iterations=1)
            eroded = cv2.erode(dilated, kernel, iterations=1)
            gaps = cv2.subtract(dilated, eroded)
            gap_ratio = np.count_nonzero(gaps) / max(edge_count, 1)

            # Calculate final score
            gap_penalty = min(40, gap_ratio * 100)
            final_score = max(0, smoothness - gap_penalty)

            passed = final_score >= self.min_edge_score

            issue = ""
            if final_score < self.min_edge_score:
                if smoothness < 60:
                    issue = "Jagged or rough edges detected"
                elif gap_ratio > 0.3:
                    issue = "Discontinuous edges with gaps"
                else:
                    issue = "Poor edge quality"

            return QualityResult(
                score=final_score,
                passed=passed,
                issue=issue,
                details={
                    "edge_count": edge_count,
                    "smoothness": smoothness,
                    "gap_ratio": gap_ratio
                }
            )

        except Exception as e:
            logger.error(f"❌ Edge continuity check failed: {e}")
            return QualityResult(score=0, passed=False, issue=str(e), details={})
    def check_color_harmony(
        self,
        foreground: Image.Image,
        background: Image.Image,
        mask: Image.Image
    ) -> QualityResult:
        """
        Evaluate color harmony between foreground and background.

        Args:
            foreground: Original foreground image
            background: Generated background image
            mask: Combination mask

        Returns:
            QualityResult with harmony analysis
        """
        try:
            fg_array = np.array(foreground.convert('RGB'))
            bg_array = np.array(background.convert('RGB'))
            mask_array = np.array(mask.convert('L'))

            # Get foreground and background regions
            fg_region = mask_array > 127
            bg_region = mask_array <= 127

            if not np.any(fg_region) or not np.any(bg_region):
                return QualityResult(
                    score=50,
                    passed=True,
                    issue="Cannot analyze harmony - insufficient regions",
                    details={}
                )

            # Convert to LAB for perceptual analysis
            fg_lab = cv2.cvtColor(fg_array, cv2.COLOR_RGB2LAB).astype(np.float32)
            bg_lab = cv2.cvtColor(bg_array, cv2.COLOR_RGB2LAB).astype(np.float32)

            # Calculate average colors
            fg_avg_l = np.mean(fg_lab[fg_region, 0])
            fg_avg_a = np.mean(fg_lab[fg_region, 1])
            fg_avg_b = np.mean(fg_lab[fg_region, 2])

            bg_avg_l = np.mean(bg_lab[bg_region, 0])
            bg_avg_a = np.mean(bg_lab[bg_region, 1])
            bg_avg_b = np.mean(bg_lab[bg_region, 2])

            # Calculate color differences
            delta_l = abs(fg_avg_l - bg_avg_l)
            delta_a = abs(fg_avg_a - bg_avg_a)
            delta_b = abs(fg_avg_b - bg_avg_b)

            # Overall color difference (Delta E approximation)
            delta_e = np.sqrt(delta_l**2 + delta_a**2 + delta_b**2)

            # Score calculation
            # Moderate difference is good (20-60 Delta E)
            # Too similar or too different is problematic
            if delta_e < 10:
                harmony_score = 60  # Too similar, foreground may get lost
                issue = "Foreground and background colors too similar"
            elif delta_e > 80:
                harmony_score = 50  # Too different, may look unnatural
                issue = "High color contrast may look unnatural"
            elif 20 <= delta_e <= 60:
                harmony_score = 100  # Ideal range
                issue = ""
            else:
                harmony_score = 80
                issue = ""

            # Check for extreme contrast (very dark fg on very bright bg or vice versa)
            brightness_contrast = abs(fg_avg_l - bg_avg_l)
            if brightness_contrast > 100:
                harmony_score = max(40, harmony_score - 30)
                issue = "Extreme brightness contrast between foreground and background"

            passed = harmony_score >= self.min_harmony_score

            return QualityResult(
                score=harmony_score,
                passed=passed,
                issue=issue,
                details={
                    "delta_e": delta_e,
                    "delta_l": delta_l,
                    "delta_a": delta_a,
                    "delta_b": delta_b,
                    "fg_luminance": fg_avg_l,
                    "bg_luminance": bg_avg_l
                }
            )

        except Exception as e:
            logger.error(f"❌ Color harmony check failed: {e}")
            return QualityResult(score=0, passed=False, issue=str(e), details={})
    def run_all_checks(
        self,
        foreground: Image.Image,
        background: Image.Image,
        mask: Image.Image,
        combined: Optional[Image.Image] = None
    ) -> Dict[str, Any]:
        """
        Run all quality checks and return comprehensive results.

        Args:
            foreground: Original foreground image
            background: Generated background
            mask: Combination mask
            combined: Final combined image (optional)

        Returns:
            Dictionary with all check results and overall score
        """
        logger.info("🔍 Running quality checks...")

        results = {
            "checks": {},
            "overall_score": 0,
            "passed": True,
            "warnings": [],
            "errors": []
        }

        # Run individual checks
        coverage_result = self.check_mask_coverage(mask)
        results["checks"]["mask_coverage"] = {
            "score": coverage_result.score,
            "passed": coverage_result.passed,
            "issue": coverage_result.issue,
            "details": coverage_result.details
        }

        edge_result = self.check_edge_continuity(mask)
        results["checks"]["edge_continuity"] = {
            "score": edge_result.score,
            "passed": edge_result.passed,
            "issue": edge_result.issue,
            "details": edge_result.details
        }

        harmony_result = self.check_color_harmony(foreground, background, mask)
        results["checks"]["color_harmony"] = {
            "score": harmony_result.score,
            "passed": harmony_result.passed,
            "issue": harmony_result.issue,
            "details": harmony_result.details
        }

        # Calculate overall score (weighted average)
        weights = {
            "mask_coverage": 0.4,
            "edge_continuity": 0.3,
            "color_harmony": 0.3
        }

        total_score = (
            coverage_result.score * weights["mask_coverage"] +
            edge_result.score * weights["edge_continuity"] +
            harmony_result.score * weights["color_harmony"]
        )
        results["overall_score"] = round(total_score, 1)

        # Determine overall pass/fail
        results["passed"] = all([
            coverage_result.passed,
            edge_result.passed,
            harmony_result.passed
        ])

        # Collect warnings and errors
        for check_name, check_data in results["checks"].items():
            if check_data["issue"]:
                if check_data["passed"]:
                    results["warnings"].append(f"{check_name}: {check_data['issue']}")
                else:
                    results["errors"].append(f"{check_name}: {check_data['issue']}")

        logger.info(f"📊 Quality check complete - Score: {results['overall_score']}, Passed: {results['passed']}")

        return results

    def get_quality_summary(self, results: Dict[str, Any]) -> str:
        """
        Generate human-readable quality summary.

        Args:
            results: Results from run_all_checks

        Returns:
            Summary string
        """
        score = results["overall_score"]
        passed = results["passed"]

        if score >= 90:
            grade = "Excellent"
        elif score >= 75:
            grade = "Good"
        elif score >= 60:
            grade = "Acceptable"
        elif score >= 40:
            grade = "Needs Improvement"
        else:
            grade = "Poor"

        summary = f"Quality: {grade} ({score:.0f}/100)"

        if results["errors"]:
            summary += f"\nIssues: {'; '.join(results['errors'])}"
        elif results["warnings"]:
            summary += f"\nNotes: {'; '.join(results['warnings'])}"

        return summary
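    # Composite-image workflow sketch (assumed usage; the file names are
    # hypothetical placeholders, not part of this module):
    #
    #   checker = QualityChecker(strictness="strict")
    #   report = checker.run_all_checks(
    #       foreground=Image.open("subject.png"),
    #       background=Image.open("generated_bg.png"),
    #       mask=Image.open("subject_mask.png"),
    #   )
    #   print(checker.get_quality_summary(report))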
    # =========================================================================
    # INPAINTING-SPECIFIC QUALITY CHECKS
    # =========================================================================

    def check_inpainting_edge_continuity(
        self,
        original: Image.Image,
        inpainted: Image.Image,
        mask: Image.Image,
        ring_width: int = 5
    ) -> QualityResult:
        """
        Check edge continuity at inpainting boundary.

        Calculates color distribution similarity between the ring zones on
        each side of the mask boundary in Lab color space.

        Parameters
        ----------
        original : PIL.Image
            Original image before inpainting
        inpainted : PIL.Image
            Result after inpainting
        mask : PIL.Image
            Inpainting mask (white = inpainted area)
        ring_width : int
            Width in pixels for the ring zones on each side

        Returns
        -------
        QualityResult
            Edge continuity assessment
        """
        try:
            # Convert to arrays
            orig_array = np.array(original.convert('RGB'))
            inpaint_array = np.array(inpainted.convert('RGB'))
            mask_array = np.array(mask.convert('L'))

            # Find boundary using morphological gradient
            kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3, 3))
            dilated = cv2.dilate(mask_array, kernel, iterations=ring_width)
            eroded = cv2.erode(mask_array, kernel, iterations=ring_width)

            # Inner ring (inside inpainted region, near boundary)
            inner_ring = (mask_array > 127) & (eroded <= 127)
            # Outer ring (outside inpainted region, near boundary)
            outer_ring = (mask_array <= 127) & (dilated > 127)

            if not np.any(inner_ring) or not np.any(outer_ring):
                return QualityResult(
                    score=50,
                    passed=True,
                    issue="Unable to detect boundary rings",
                    details={"ring_width": ring_width}
                )

            # Convert to Lab for perceptual comparison
            inpaint_lab = cv2.cvtColor(inpaint_array, cv2.COLOR_RGB2LAB).astype(np.float32)

            # Get Lab values for each ring from the inpainted image
            inner_lab = inpaint_lab[inner_ring]
            outer_lab = inpaint_lab[outer_ring]

            # Calculate statistics for each channel
            inner_mean = np.mean(inner_lab, axis=0)
            outer_mean = np.mean(outer_lab, axis=0)
            inner_std = np.std(inner_lab, axis=0)
            outer_std = np.std(outer_lab, axis=0)

            # Calculate differences
            mean_diff = np.abs(inner_mean - outer_mean)
            std_diff = np.abs(inner_std - outer_std)

            # Calculate Delta E (simplified)
            delta_e = np.sqrt(np.sum(mean_diff ** 2))

            # Score calculation
            # Low Delta E = good continuity
            # Target: Delta E < 10 is excellent, < 20 is good
            if delta_e < 5:
                continuity_score = 100
            elif delta_e < 10:
                continuity_score = 90
            elif delta_e < 20:
                continuity_score = 75
            elif delta_e < 30:
                continuity_score = 60
            elif delta_e < 50:
                continuity_score = 40
            else:
                continuity_score = max(20, 100 - delta_e)

            # Penalize for large std differences (inconsistent textures)
            std_penalty = min(20, np.mean(std_diff) * 0.5)

            final_score = max(0, continuity_score - std_penalty)

            passed = final_score >= 60

            issue = ""
            if final_score < 60:
                if delta_e > 30:
                    issue = f"Visible color discontinuity at boundary (Delta E: {delta_e:.1f})"
                elif np.mean(std_diff) > 20:
                    issue = "Texture mismatch at boundary"
                else:
                    issue = "Poor edge blending"

            return QualityResult(
                score=final_score,
                passed=passed,
                issue=issue,
                details={
                    "delta_e": delta_e,
                    "mean_diff_l": mean_diff[0],
                    "mean_diff_a": mean_diff[1],
                    "mean_diff_b": mean_diff[2],
                    "std_diff_avg": np.mean(std_diff),
                    "inner_pixels": np.count_nonzero(inner_ring),
                    "outer_pixels": np.count_nonzero(outer_ring)
                }
            )

        except Exception as e:
            logger.error(f"Inpainting edge continuity check failed: {e}")
            return QualityResult(score=0, passed=False, issue=str(e), details={})
    def check_inpainting_color_harmony(
        self,
        original: Image.Image,
        inpainted: Image.Image,
        mask: Image.Image
    ) -> QualityResult:
        """
        Check color harmony between inpainted region and surrounding area.

        Compares color statistics of the inpainted region with adjacent
        non-inpainted regions to assess visual coherence.

        Parameters
        ----------
        original : PIL.Image
            Original image
        inpainted : PIL.Image
            Inpainted result
        mask : PIL.Image
            Inpainting mask

        Returns
        -------
        QualityResult
            Color harmony assessment
        """
        try:
            inpaint_array = np.array(inpainted.convert('RGB'))
            mask_array = np.array(mask.convert('L'))

            # Define regions
            inpaint_region = mask_array > 127

            # Get adjacent region (dilated mask minus original mask)
            kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (15, 15))
            dilated = cv2.dilate(mask_array, kernel, iterations=2)
            adjacent_region = (dilated > 127) & (mask_array <= 127)

            if not np.any(inpaint_region) or not np.any(adjacent_region):
                return QualityResult(
                    score=50,
                    passed=True,
                    issue="Insufficient regions for comparison",
                    details={}
                )

            # Convert to Lab
            inpaint_lab = cv2.cvtColor(inpaint_array, cv2.COLOR_RGB2LAB).astype(np.float32)

            # Extract region colors
            inpaint_colors = inpaint_lab[inpaint_region]
            adjacent_colors = inpaint_lab[adjacent_region]

            # Calculate color statistics
            inpaint_mean = np.mean(inpaint_colors, axis=0)
            adjacent_mean = np.mean(adjacent_colors, axis=0)
            inpaint_std = np.std(inpaint_colors, axis=0)
            adjacent_std = np.std(adjacent_colors, axis=0)

            # Color histogram comparison
            hist_scores = []
            for i in range(3):  # L, a, b channels
                hist_inpaint, _ = np.histogram(
                    inpaint_colors[:, i], bins=32, range=(0, 255)
                )
                hist_adjacent, _ = np.histogram(
                    adjacent_colors[:, i], bins=32, range=(0, 255)
                )

                # Normalize
                hist_inpaint = hist_inpaint.astype(np.float32) / (np.sum(hist_inpaint) + 1e-6)
                hist_adjacent = hist_adjacent.astype(np.float32) / (np.sum(hist_adjacent) + 1e-6)

                # Bhattacharyya coefficient (1 = identical, 0 = completely different)
                bc = np.sum(np.sqrt(hist_inpaint * hist_adjacent))
                hist_scores.append(bc)

            avg_hist_score = np.mean(hist_scores)

            # Calculate harmony score
            mean_diff = np.linalg.norm(inpaint_mean - adjacent_mean)

            if mean_diff < 10 and avg_hist_score > 0.8:
                harmony_score = 100
            elif mean_diff < 20 and avg_hist_score > 0.7:
                harmony_score = 85
            elif mean_diff < 30 and avg_hist_score > 0.6:
                harmony_score = 70
            elif mean_diff < 50:
                harmony_score = 55
            else:
                harmony_score = max(30, 100 - mean_diff)

            # Boost score if histogram similarity is high
            histogram_bonus = (avg_hist_score - 0.5) * 20  # -10 to +10
            final_score = max(0, min(100, harmony_score + histogram_bonus))

            passed = final_score >= 60

            issue = ""
            if final_score < 60:
                if mean_diff > 40:
                    issue = "Significant color mismatch with surrounding area"
                elif avg_hist_score < 0.5:
                    issue = "Color distribution differs from context"
                else:
                    issue = "Poor color integration"

            return QualityResult(
                score=final_score,
                passed=passed,
                issue=issue,
                details={
                    "mean_color_diff": mean_diff,
                    "histogram_similarity": avg_hist_score,
                    "inpaint_luminance": inpaint_mean[0],
                    "adjacent_luminance": adjacent_mean[0]
                }
            )

        except Exception as e:
            logger.error(f"Inpainting color harmony check failed: {e}")
            return QualityResult(score=0, passed=False, issue=str(e), details={})
    def check_inpainting_artifact_detection(
        self,
        inpainted: Image.Image,
        mask: Image.Image
    ) -> QualityResult:
        """
        Detect common inpainting artifacts like blurriness or color bleeding.

        Parameters
        ----------
        inpainted : PIL.Image
            Inpainted result
        mask : PIL.Image
            Inpainting mask

        Returns
        -------
        QualityResult
            Artifact detection results
        """
        try:
            inpaint_array = np.array(inpainted.convert('RGB'))
            mask_array = np.array(mask.convert('L'))

            inpaint_region = mask_array > 127
            if not np.any(inpaint_region):
                return QualityResult(
                    score=50,
                    passed=True,
                    issue="No inpainted region detected",
                    details={}
                )

            # Extract inpainted region pixels
            gray = cv2.cvtColor(inpaint_array, cv2.COLOR_RGB2GRAY)

            # Calculate sharpness (Laplacian variance)
            laplacian = cv2.Laplacian(gray, cv2.CV_64F)
            inpaint_laplacian = laplacian[inpaint_region]
            sharpness = np.var(inpaint_laplacian)

            # Get surrounding region for comparison
            kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (10, 10))
            dilated = cv2.dilate(mask_array, kernel, iterations=1)
            surrounding = (dilated > 127) & (mask_array <= 127)

            if np.any(surrounding):
                surrounding_laplacian = laplacian[surrounding]
                surrounding_sharpness = np.var(surrounding_laplacian)
                sharpness_ratio = sharpness / (surrounding_sharpness + 1e-6)
            else:
                sharpness_ratio = 1.0

            # Check for color bleeding (abnormal saturation at edges)
            hsv = cv2.cvtColor(inpaint_array, cv2.COLOR_RGB2HSV)
            saturation = hsv[:, :, 1]

            # Find boundary pixels
            boundary_kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3, 3))
            boundary = cv2.morphologyEx(mask_array, cv2.MORPH_GRADIENT, boundary_kernel) > 0

            if np.any(boundary):
                boundary_saturation = saturation[boundary]
                saturation_std = np.std(boundary_saturation)
            else:
                saturation_std = 0

            # Calculate score
            sharpness_score = 100
            if sharpness_ratio < 0.3:
                sharpness_score = 40  # Much blurrier than surroundings
            elif sharpness_ratio < 0.6:
                sharpness_score = 60
            elif sharpness_ratio < 0.8:
                sharpness_score = 80

            bleeding_penalty = min(20, saturation_std * 0.5)
            final_score = max(0, sharpness_score - bleeding_penalty)

            passed = final_score >= 60

            issue = ""
            if sharpness_ratio < 0.5:
                issue = "Inpainted region appears blurry"
            elif saturation_std > 40:
                issue = "Possible color bleeding at edges"
            elif final_score < 60:
                issue = "Detected visual artifacts"

            return QualityResult(
                score=final_score,
                passed=passed,
                issue=issue,
                details={
                    "sharpness": sharpness,
                    "sharpness_ratio": sharpness_ratio,
                    "boundary_saturation_std": saturation_std
                }
            )

        except Exception as e:
            logger.error(f"Inpainting artifact detection failed: {e}")
            return QualityResult(score=0, passed=False, issue=str(e), details={})
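    # Standalone usage sketch (assumed; "result.png" and "mask.png" are
    # hypothetical paths): the artifact detector needs only the inpainted
    # image and the mask, so it can run without the original image.
    #
    #   checker = QualityChecker()
    #   artifacts = checker.check_inpainting_artifact_detection(
    #       inpainted=Image.open("result.png"),
    #       mask=Image.open("mask.png"),
    #   )
    #   if not artifacts.passed:
    #       logger.warning("Artifact check: %s", artifacts.issue)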
    def run_inpainting_checks(
        self,
        original: Image.Image,
        inpainted: Image.Image,
        mask: Image.Image
    ) -> Dict[str, Any]:
        """
        Run all inpainting-specific quality checks.

        Parameters
        ----------
        original : PIL.Image
            Original image before inpainting
        inpainted : PIL.Image
            Result after inpainting
        mask : PIL.Image
            Inpainting mask

        Returns
        -------
        dict
            Comprehensive quality assessment for inpainting
        """
        logger.info("Running inpainting quality checks...")

        results = {
            "checks": {},
            "overall_score": 0,
            "passed": True,
            "warnings": [],
            "errors": []
        }

        # Run inpainting-specific checks
        edge_result = self.check_inpainting_edge_continuity(original, inpainted, mask)
        results["checks"]["edge_continuity"] = {
            "score": edge_result.score,
            "passed": edge_result.passed,
            "issue": edge_result.issue,
            "details": edge_result.details
        }

        harmony_result = self.check_inpainting_color_harmony(original, inpainted, mask)
        results["checks"]["color_harmony"] = {
            "score": harmony_result.score,
            "passed": harmony_result.passed,
            "issue": harmony_result.issue,
            "details": harmony_result.details
        }

        artifact_result = self.check_inpainting_artifact_detection(inpainted, mask)
        results["checks"]["artifact_detection"] = {
            "score": artifact_result.score,
            "passed": artifact_result.passed,
            "issue": artifact_result.issue,
            "details": artifact_result.details
        }

        # Calculate overall score (weighted)
        weights = {
            "edge_continuity": 0.4,
            "color_harmony": 0.35,
            "artifact_detection": 0.25
        }

        total_score = (
            edge_result.score * weights["edge_continuity"] +
            harmony_result.score * weights["color_harmony"] +
            artifact_result.score * weights["artifact_detection"]
        )
        results["overall_score"] = round(total_score, 1)

        # Determine overall pass/fail
        results["passed"] = all([
            edge_result.passed,
            harmony_result.passed,
            artifact_result.passed
        ])

        # Collect issues
        for check_name, check_data in results["checks"].items():
            if check_data["issue"]:
                if check_data["passed"]:
                    results["warnings"].append(f"{check_name}: {check_data['issue']}")
                else:
                    results["errors"].append(f"{check_name}: {check_data['issue']}")

        logger.info(f"Inpainting quality: {results['overall_score']:.1f}, Passed: {results['passed']}")

        return results
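

# Minimal command-line sketch (assumed usage, not part of the original module);
# the module name and the three file paths below are placeholders supplied by
# the caller.
if __name__ == "__main__":
    import sys

    logging.basicConfig(level=logging.INFO)

    if len(sys.argv) != 4:
        print("usage: python quality_checker.py <original> <inpainted> <mask>")
        sys.exit(1)

    original_img = Image.open(sys.argv[1])
    inpainted_img = Image.open(sys.argv[2])
    mask_img = Image.open(sys.argv[3])

    checker = QualityChecker(strictness="standard")
    report = checker.run_inpainting_checks(original_img, inpainted_img, mask_img)

    # Reuse the human-readable summary helper for a quick console report
    print(checker.get_quality_summary(report))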