LogicGoInfotechSpaces committed on
Commit
8a0a72a
·
1 Parent(s): f854294

fix: simplify mask processing to match reference model - direct white=remove detection

Browse files
Files changed (2) hide show
  1. api/main.py +15 -1
  2. src/core.py +21 -26
api/main.py CHANGED
@@ -322,10 +322,24 @@ def inpaint_multipart(
322
 
323
  log.info("Auto-converted painted image to black/white mask: %d white pixels (to remove)",
324
  int((binmask_clean > 0).sum()))
 
 
 
 
 
 
 
 
 
325
  else:
326
  mask_rgba = _load_rgba_mask_from_image(m)
327
 
328
- result = process_inpaint(np.array(img), mask_rgba, invert_mask=invert_mask)
 
 
 
 
 
329
  result_name = f"output_{uuid.uuid4().hex}.png"
330
  result_path = os.path.join(OUTPUT_DIR, result_name)
331
  Image.fromarray(result).save(result_path)
 
322
 
323
  log.info("Auto-converted painted image to black/white mask: %d white pixels (to remove)",
324
  int((binmask_clean > 0).sum()))
325
+
326
+ if int((binmask_clean > 0).sum()) < 50:
327
+ log.error("CRITICAL: Mask detection found very few pixels! Returning original image.")
328
+ # Return original image if mask is invalid
329
+ result = np.array(img.convert("RGB"))
330
+ result_name = f"output_{uuid.uuid4().hex}.png"
331
+ result_path = os.path.join(OUTPUT_DIR, result_name)
332
+ Image.fromarray(result).save(result_path)
333
+ return {"result": result_name, "error": "mask detection failed - very few pixels detected"}
334
  else:
335
  mask_rgba = _load_rgba_mask_from_image(m)
336
 
337
+ # When mask_is_painted=true, we create white=remove masks, so invert_mask should be False
338
+ # (white pixels should stay white to indicate removal)
339
+ actual_invert = invert_mask if not mask_is_painted else False
340
+ log.info("Using invert_mask=%s (mask_is_painted=%s)", actual_invert, mask_is_painted)
341
+
342
+ result = process_inpaint(np.array(img), mask_rgba, invert_mask=actual_invert)
343
  result_name = f"output_{uuid.uuid4().hex}.png"
344
  result_path = os.path.join(OUTPUT_DIR, result_name)
345
  Image.fromarray(result).save(result_path)
src/core.py CHANGED
@@ -460,40 +460,35 @@ def process_inpaint(image, mask, invert_mask=True):
460
 
461
  # Convert RGBA mask to single-channel mask.
462
  # Standard LaMa convention: 1 = remove, 0 = keep
463
- # The mask can come in different formats:
464
- # - RGBA with alpha channel encoding (alpha=0 means remove when invert_mask=True)
465
- # - RGBA with RGB encoding (white/colored areas mean remove)
466
 
467
  alpha_channel = mask[:,:,3]
468
  rgb_channels = mask[:,:,:3]
469
 
470
- # Check if alpha channel is meaningful (not all 255)
471
- alpha_mean = alpha_channel.mean()
 
 
 
 
472
 
 
 
 
 
 
 
473
  if alpha_mean < 50:
474
- # Alpha channel is mostly transparent - use alpha directly
475
- # Transparent (0) = remove, Opaque (255) = keep
476
  if invert_mask:
477
- mask = 255 - alpha_channel # transparent → white (remove)
478
  else:
479
- mask = alpha_channel # opaque → white (remove)
480
- elif alpha_mean > 200:
481
- # Alpha channel is mostly opaque - check RGB channels for paint colors
482
- # Detect magenta (255, 0, 255) or any bright colored paint
483
- gray = cv2.cvtColor(rgb_channels, cv2.COLOR_RGB2GRAY)
484
- # White or bright colors (>200) in RGB = remove
485
- mask_rgb = (gray > 200).astype(np.uint8) * 255
486
- # Also detect magenta specifically
487
- magenta = np.all(rgb_channels == [255, 0, 255], axis=2).astype(np.uint8) * 255
488
- mask = np.maximum(mask_rgb, magenta)
489
- if not invert_mask:
490
- mask = 255 - mask # invert if needed
491
- else:
492
- # Mixed alpha - use alpha channel with inversion logic
493
- if invert_mask:
494
- mask = 255 - alpha_channel
495
- else:
496
- mask = alpha_channel
497
 
498
  mask = resize_max_size(mask, size_limit=size_limit, interpolation=interpolation)
499
 
 
460
 
461
  # Convert RGBA mask to single-channel mask.
462
  # Standard LaMa convention: 1 = remove, 0 = keep
463
+ # Simple approach: white pixels in RGB = remove, black = keep
464
+ # This matches the reference model behavior
 
465
 
466
  alpha_channel = mask[:,:,3]
467
  rgb_channels = mask[:,:,:3]
468
 
469
+ # Convert RGB to grayscale to detect white/black
470
+ gray = cv2.cvtColor(rgb_channels, cv2.COLOR_RGB2GRAY)
471
+
472
+ # Standard: white (255) = remove, black (0) = keep
473
+ # Detect white pixels (>128) as removal areas
474
+ mask = (gray > 128).astype(np.uint8) * 255
475
 
476
+ # Also explicitly detect magenta (255, 0, 255) which is commonly used for painting
477
+ magenta = np.all(rgb_channels == [255, 0, 255], axis=2).astype(np.uint8) * 255
478
+ mask = np.maximum(mask, magenta)
479
+
480
+ # If alpha channel is mostly transparent (<50 mean), use it as mask source
481
+ alpha_mean = alpha_channel.mean()
482
  if alpha_mean < 50:
483
+ # Transparent areas (alpha=0) should be removed
 
484
  if invert_mask:
485
+ mask = np.maximum(mask, (255 - alpha_channel)) # transparent → white
486
  else:
487
+ mask = np.maximum(mask, alpha_channel) # opaque → white
488
+
489
+ # Apply invert_mask if needed (for special cases where black=remove)
490
+ if not invert_mask:
491
+ mask = 255 - mask
 
 
 
 
 
 
 
 
 
 
 
 
 
492
 
493
  mask = resize_max_size(mask, size_limit=size_limit, interpolation=interpolation)
494