# Spaces: Running on Zero
| import logging | |
| import numpy as np | |
| import cv2 | |
| from PIL import Image | |
| from typing import Dict, Any, Tuple, Optional | |
| from dataclasses import dataclass | |
# Module logger. Its level is pinned to INFO so the logger.info() progress
# messages in this module are not dropped by this logger's own threshold
# (whether they are printed still depends on the handlers configured elsewhere).
logger: logging.Logger = logging.getLogger(name=__name__)
logger.setLevel(level="INFO")
@dataclass
class QualityResult:
    """Result of a single quality check.

    Instances are built with keyword arguments
    (``QualityResult(score=..., passed=..., issue=..., details=...)``),
    which requires the ``@dataclass``-generated ``__init__``.
    """

    score: float  # 0-100
    passed: bool  # whether the check passed
    issue: str  # human-readable problem description; "" when there is none
    details: Dict[str, Any]  # check-specific diagnostic values
class QualityChecker:
    """
    Automated quality validation system for generated images.

    Provides checks for mask coverage, edge continuity, and color harmony,
    an aggregate runner (``run_all_checks``) and a human-readable summary
    (``get_quality_summary``).
    """

    # Quality thresholds.
    # NOTE(review): not referenced by any method of this class; kept because
    # external callers may read them.
    THRESHOLD_PASS = 70
    THRESHOLD_WARNING = 50

    # Per-strictness presets: (min_coverage ratio, min_edge_score, min_harmony_score)
    _STRICTNESS_THRESHOLDS = {
        "lenient": (0.03, 40, 40),
        "standard": (0.05, 60, 60),
        "strict": (0.10, 75, 75),
    }

    # Grayscale value above which a mask pixel counts as foreground.
    _FG_THRESHOLD = 127
    # Morphological-gradient value above which a pixel counts as an edge.
    _EDGE_THRESHOLD = 20

    def __init__(self, strictness: str = "standard"):
        """
        Initialize QualityChecker.

        Args:
            strictness: Threshold preset. All checks always run; the preset
                only changes the pass thresholds.
                "lenient"  - coverage >= 3%, edge/harmony scores >= 40
                "standard" - coverage >= 5%, edge/harmony scores >= 60
                "strict"   - coverage >= 10%, edge/harmony scores >= 75
                Any other value falls back to "standard" (a warning is logged).
        """
        self.strictness = strictness
        self._set_thresholds()

    def _set_thresholds(self):
        """Set min_coverage / min_edge_score / min_harmony_score from ``self.strictness``."""
        thresholds = self._STRICTNESS_THRESHOLDS.get(self.strictness)
        if thresholds is None:
            logger.warning(
                "Unknown strictness %r; using 'standard' thresholds", self.strictness
            )
            thresholds = self._STRICTNESS_THRESHOLDS["standard"]
        self.min_coverage, self.min_edge_score, self.min_harmony_score = thresholds

    def check_mask_coverage(self, mask: Image.Image) -> QualityResult:
        """
        Verify mask coverage is adequate.

        Foreground is every pixel whose grayscale value exceeds 127. The score
        is ``min(100, coverage_ratio * 200)`` (50% coverage scores 100) minus
        2 points for every 8-connected foreground region that is not larger
        than 1% of the image.

        Args:
            mask: Mask image; converted to grayscale ("L") before analysis.

        Returns:
            QualityResult with coverage analysis. ``passed`` requires coverage
            of at least ``self.min_coverage`` and at least one region larger
            than 1% of the image. On any error, a failing result with score 0.
        """
        try:
            mask_array = np.array(mask.convert('L'))
            height, width = mask_array.shape
            total_pixels = height * width
            if total_pixels == 0:
                raise ValueError("Mask image is empty (0 pixels)")

            # Binary foreground (0/255); used for both the pixel count and the
            # connected-component analysis.
            _, binary = cv2.threshold(mask_array, self._FG_THRESHOLD, 255, cv2.THRESH_BINARY)
            fg_pixels = int(np.count_nonzero(binary))
            coverage_ratio = fg_pixels / total_pixels

            # Label 0 is the background, so real regions start at index 1.
            num_labels, _, stats, _ = cv2.connectedComponentsWithStats(binary, connectivity=8)
            total_regions = num_labels - 1

            # Regions larger than 1% of the image are significant; the rest are noise.
            min_region_size = total_pixels * 0.01
            significant_regions = int(
                np.count_nonzero(stats[1:, cv2.CC_STAT_AREA] > min_region_size)
            )

            # Many small regions = bad: 2 points per noise region.
            fragmentation_penalty = max(0, (total_regions - significant_regions) * 2)

            coverage_score = min(100.0, coverage_ratio * 200)  # 50% coverage = 100 score
            final_score = max(0.0, coverage_score - fragmentation_penalty)

            passed = coverage_ratio >= self.min_coverage and significant_regions >= 1

            issue = ""
            if coverage_ratio < self.min_coverage:
                issue = f"Low foreground coverage ({coverage_ratio:.1%})"
            elif significant_regions == 0:
                issue = "No significant foreground regions detected"
            elif fragmentation_penalty > 20:
                # Lowers the score but does not affect `passed`, so it surfaces as a warning.
                issue = f"Fragmented mask with {total_regions} isolated regions"

            return QualityResult(
                score=float(final_score),
                passed=bool(passed),
                issue=issue,
                details={
                    "coverage_ratio": float(coverage_ratio),
                    "foreground_pixels": fg_pixels,
                    "total_regions": int(total_regions),
                    "significant_regions": significant_regions,
                },
            )
        except Exception as e:
            # Best-effort check: report the failure as a failing result instead of raising.
            logger.exception("❌ Mask coverage check failed: %s", e)
            return QualityResult(score=0, passed=False, issue=str(e), details={})

    def check_edge_continuity(self, mask: Image.Image) -> QualityResult:
        """
        Check if mask edges are continuous and smooth.

        Edges are pixels whose morphological gradient (3x3 elliptical kernel)
        exceeds 20. Smoothness is ``100 - min(100, 0.5 * std(|Laplacian|))``
        over the edge pixels. Gap pixels are pixels that a morphological
        closing of the gradient turns into edges, i.e. small breaks that the
        closing bridges. ``gap_ratio`` is the number of gap pixels divided by
        the number of edge pixels; the gap penalty is ``min(40, 100 * gap_ratio)``.

        Args:
            mask: Mask image; converted to grayscale ("L") before analysis.

        Returns:
            QualityResult with edge analysis; ``passed`` when the score is at
            least ``self.min_edge_score``. A mask with no edges scores 50 and
            fails. On any error, a failing result with score 0.
        """
        try:
            mask_array = np.array(mask.convert('L'))

            # Edge band via morphological gradient (dilation - erosion).
            kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3, 3))
            gradient = cv2.morphologyEx(mask_array, cv2.MORPH_GRADIENT, kernel)

            edge_pixels = gradient > self._EDGE_THRESHOLD
            edge_count = int(np.count_nonzero(edge_pixels))

            if edge_count == 0:
                return QualityResult(
                    score=50,
                    passed=False,
                    issue="No edges detected in mask",
                    details={"edge_count": 0},
                )

            # A large spread of Laplacian magnitudes along the edge indicates jagged edges.
            laplacian = cv2.Laplacian(mask_array, cv2.CV_64F)
            edge_laplacian = np.abs(laplacian[edge_pixels])
            smoothness = float(100 - min(100, np.std(edge_laplacian) * 0.5))

            # Gap detection: closing (dilate then erode) bridges small breaks in
            # the edge band, so pixels that become edges only after closing are
            # gaps. Measuring dilation minus closing would instead count the
            # halo around every edge, which is about as large as the edge band
            # itself and would penalize even perfectly clean masks.
            closed = cv2.morphologyEx(gradient, cv2.MORPH_CLOSE, kernel)
            gap_pixels = (closed > self._EDGE_THRESHOLD) & ~edge_pixels
            gap_ratio = np.count_nonzero(gap_pixels) / edge_count

            gap_penalty = min(40, gap_ratio * 100)
            final_score = max(0.0, smoothness - gap_penalty)

            passed = final_score >= self.min_edge_score

            issue = ""
            if not passed:
                if smoothness < 60:
                    issue = "Jagged or rough edges detected"
                elif gap_ratio > 0.3:
                    issue = "Discontinuous edges with gaps"
                else:
                    issue = "Poor edge quality"

            return QualityResult(
                score=float(final_score),
                passed=bool(passed),
                issue=issue,
                details={
                    "edge_count": edge_count,
                    "smoothness": smoothness,
                    "gap_ratio": float(gap_ratio),
                },
            )
        except Exception as e:
            # Best-effort check: report the failure as a failing result instead of raising.
            logger.exception("❌ Edge continuity check failed: %s", e)
            return QualityResult(score=0, passed=False, issue=str(e), details={})

    def check_color_harmony(
        self,
        foreground: Image.Image,
        background: Image.Image,
        mask: Image.Image
    ) -> QualityResult:
        """
        Evaluate color harmony between foreground and background.

        Compares the mean LAB color of the foreground region (mask > 127) of
        ``foreground`` with the mean LAB color of the remaining region of
        ``background``.

        Args:
            foreground: Original foreground image.
            background: Generated background image.
            mask: Combination mask; must have the same size as both images.

        Returns:
            QualityResult with harmony analysis; ``passed`` when the score is
            at least ``self.min_harmony_score``. If either region is empty the
            check is skipped (score 50, passed). On any error, including a
            size mismatch, a failing result with score 0.
        """
        try:
            fg_array = np.array(foreground.convert('RGB'))
            bg_array = np.array(background.convert('RGB'))
            mask_array = np.array(mask.convert('L'))

            # Boolean-mask indexing below requires identical sizes; fail with a
            # clear message instead of an opaque IndexError.
            if fg_array.shape[:2] != mask_array.shape or bg_array.shape[:2] != mask_array.shape:
                raise ValueError(
                    f"Image size mismatch: foreground {foreground.size}, "
                    f"background {background.size}, mask {mask.size}"
                )

            fg_region = mask_array > self._FG_THRESHOLD
            bg_region = ~fg_region

            if not fg_region.any() or not bg_region.any():
                return QualityResult(
                    score=50,
                    passed=True,
                    issue="Cannot analyze harmony - insufficient regions",
                    details={},
                )

            # NOTE: for 8-bit input OpenCV's RGB->LAB scales L to 0-255
            # (L*255/100) and offsets a/b by +128, so the differences below are
            # in that scaled space, not true CIE Delta E. All thresholds in this
            # method are expressed in the same scale (e.g. a brightness
            # difference above 100 is only reachable because L spans 0-255).
            fg_lab = cv2.cvtColor(fg_array, cv2.COLOR_RGB2LAB).astype(np.float32)
            bg_lab = cv2.cvtColor(bg_array, cv2.COLOR_RGB2LAB).astype(np.float32)

            # Mean (L, a, b) of each region, accumulated in float64 for accuracy.
            fg_mean = fg_lab[fg_region].mean(axis=0, dtype=np.float64)
            bg_mean = bg_lab[bg_region].mean(axis=0, dtype=np.float64)

            delta_l, delta_a, delta_b = (float(v) for v in np.abs(fg_mean - bg_mean))
            # Euclidean distance in (scaled) LAB, a Delta E approximation.
            delta_e = float(np.sqrt(delta_l ** 2 + delta_a ** 2 + delta_b ** 2))

            # Moderate difference is good (20-60); too similar or too different
            # is problematic.
            if delta_e < 10:
                harmony_score = 60  # Too similar, foreground may get lost
                issue = "Foreground and background colors too similar"
            elif delta_e > 80:
                harmony_score = 50  # Too different, may look unnatural
                issue = "High color contrast may look unnatural"
            elif 20 <= delta_e <= 60:
                harmony_score = 100  # Ideal range
                issue = ""
            else:
                harmony_score = 80
                issue = ""

            # Extreme brightness contrast (very dark fg on very bright bg or vice versa).
            if delta_l > 100:
                harmony_score = max(40, harmony_score - 30)
                issue = "Extreme brightness contrast between foreground and background"

            passed = harmony_score >= self.min_harmony_score

            return QualityResult(
                score=harmony_score,
                passed=passed,
                issue=issue,
                details={
                    "delta_e": delta_e,
                    "delta_l": delta_l,
                    "delta_a": delta_a,
                    "delta_b": delta_b,
                    "fg_luminance": float(fg_mean[0]),
                    "bg_luminance": float(bg_mean[0]),
                },
            )
        except Exception as e:
            # Best-effort check: report the failure as a failing result instead of raising.
            logger.exception("❌ Color harmony check failed: %s", e)
            return QualityResult(score=0, passed=False, issue=str(e), details={})

    def run_all_checks(
        self,
        foreground: Image.Image,
        background: Image.Image,
        mask: Image.Image,
        combined: Optional[Image.Image] = None
    ) -> Dict[str, Any]:
        """
        Run all quality checks and return comprehensive results.

        Args:
            foreground: Original foreground image.
            background: Generated background.
            mask: Combination mask.
            combined: Final combined image. Currently unused; accepted for
                interface compatibility.

        Returns:
            Dictionary with keys:
              "checks"        - per-check {"score", "passed", "issue", "details"}
              "overall_score" - weighted average (coverage 0.4, edge 0.3,
                                harmony 0.3), rounded to one decimal
              "passed"        - True only if every check passed
              "warnings"      - "<check>: <issue>" for passed checks with an issue
              "errors"        - "<check>: <issue>" for failed checks with an issue
        """
        logger.info("🔍 Running quality checks...")

        # Checks run in this order; the same order is used for the entries of
        # "checks", "warnings" and "errors".
        check_results = {
            "mask_coverage": self.check_mask_coverage(mask),
            "edge_continuity": self.check_edge_continuity(mask),
            "color_harmony": self.check_color_harmony(foreground, background, mask),
        }
        weights = {
            "mask_coverage": 0.4,
            "edge_continuity": 0.3,
            "color_harmony": 0.3,
        }

        total_score = sum(
            check_results[name].score * weight for name, weight in weights.items()
        )

        results = {
            "checks": {
                name: {
                    "score": result.score,
                    "passed": result.passed,
                    "issue": result.issue,
                    "details": result.details,
                }
                for name, result in check_results.items()
            },
            "overall_score": round(total_score, 1),
            "passed": all(result.passed for result in check_results.values()),
            # Issues on passing checks are warnings; on failing checks, errors.
            "warnings": [
                f"{name}: {result.issue}"
                for name, result in check_results.items()
                if result.issue and result.passed
            ],
            "errors": [
                f"{name}: {result.issue}"
                for name, result in check_results.items()
                if result.issue and not result.passed
            ],
        }

        logger.info(
            "📊 Quality check complete - Score: %s, Passed: %s",
            results["overall_score"],
            results["passed"],
        )
        return results

    def get_quality_summary(self, results: Dict[str, Any]) -> str:
        """
        Generate a human-readable quality summary.

        Args:
            results: Dictionary returned by ``run_all_checks``.

        Returns:
            "Quality: <grade> (<score>/100)", followed by an "Issues:" line
            listing the errors if there are any, otherwise a "Notes:" line
            listing the warnings if there are any.
        """
        score = results["overall_score"]

        if score >= 90:
            grade = "Excellent"
        elif score >= 75:
            grade = "Good"
        elif score >= 60:
            grade = "Acceptable"
        elif score >= 40:
            grade = "Needs Improvement"
        else:
            grade = "Poor"

        summary = f"Quality: {grade} ({score:.0f}/100)"

        if results["errors"]:
            summary += f"\nIssues: {'; '.join(results['errors'])}"
        elif results["warnings"]:
            summary += f"\nNotes: {'; '.join(results['warnings'])}"

        return summary