"""
Image processing utilities for ML services.

Helpers to decode encoded images into numpy arrays, slice wide images into
overlapping fixed-width pieces, splice model-output pieces back together,
convert model outputs to 8-bit images, and encode PIL images for transport.
"""

import base64
import io
import math

import cv2
import numpy as np
import PIL.Image


# Default divisor used by ``splice_output_tensor`` to derive margin/overlap
# widths from the piece height (see ``splice_pieces`` / ``soft_splice_pieces``).
MARGIN_DIVIDER = 8


def array_from_image_stream(buffer):
    """Decode encoded image bytes into a numpy array of shape (H, W, C).

    Args:
        buffer: Encoded image data that PIL can open, as ``bytes``,
            ``bytearray`` or ``memoryview``.

    Returns:
        ``np.ndarray`` of shape (H, W, C). Arrays with three or more
        dimensions keep only their first three channels (e.g. an alpha
        channel is dropped). 2-D (single-channel) images get a trailing
        channel axis of size 1. Returns ``None`` when ``buffer`` is not a
        bytes-like object.
    """
    if not isinstance(buffer, (bytes, bytearray, memoryview)):
        return None

    # np.array() copies the pixels, so the PIL image can be closed right away.
    with PIL.Image.open(io.BytesIO(buffer)) as image:
        arr = np.array(image)

    if arr.ndim >= 3:
        return arr[:, :, :3]  # first three channels only
    return np.expand_dims(arr, -1)


def slice_feature(source, width, overlapping=0.25, padding=False):
    """Slice an image horizontally into overlapping fixed-width pieces.

    Consecutive pieces start ``floor(width - height * overlapping)`` columns
    apart, i.e. the overlap is measured relative to the image *height*.

    Args:
        source: Array of shape (height, width, channel).
        width: Width of every yielded piece.
        overlapping: Overlap between consecutive pieces, as a fraction of the
            source height.
        padding: If true, extend both horizontal ends by
            ``floor(height * overlapping) // 2`` columns that replicate the
            first/last column before slicing (no-op when that margin is 0).

    Yields:
        Arrays of shape (height, width, channel). A piece that would run past
        the right edge is a float32 array filled with 255 beyond the
        remaining image content.

    Raises:
        ValueError: When iteration starts, if the computed step is smaller
            than one column.
    """
    height = source.shape[0]
    step = math.floor(width - height * overlapping)
    if step < 1:
        raise ValueError(
            f'slice step must be >= 1, got {step} '
            f'(width={width}, height={height}, overlapping={overlapping})')

    if padding:
        margin = math.floor(height * overlapping) // 2
        # Skip padding for a zero margin: ``a[:, 0:-0]`` is empty and
        # ``a[:, -0:]`` is the whole axis, so the slicing below would break.
        if margin > 0:
            padded = np.zeros((height, source.shape[1] + margin * 2, source.shape[2]))
            padded[:, margin:margin + source.shape[1], :] = source
            padded[:, :margin, :] = source[:, :1, :]    # replicate first column
            padded[:, -margin:, :] = source[:, -1:, :]  # replicate last column
            source = padded

    total_width = source.shape[1]
    for x in range(0, total_width, step):
        if x + width <= total_width:
            yield source[:, x:x + width, :]
        else:
            # Tail piece: copy what is left and fill the rest with 255.
            sliced = np.full((height, width, source.shape[2]), 255, dtype=np.float32)
            sliced[:, :total_width - x, :] = source[:, x:, :]
            yield sliced


def splice_pieces(pieces, margin_divider, keep_margin=False):
    """Splice output tensor pieces back together by trimming their margins.

    Each piece loses ``piece_height // margin_divider`` columns on both sides
    and the remaining patches are laid side by side.

    Args:
        pieces: Array of shape (batch, channel, height, width).
        margin_divider: Divisor applied to the piece height to get the margin.
        keep_margin: If true, prepend the left margin of the first piece and
            append the right margin of the last piece.

    Returns:
        Array of shape (channel, height, total_width); the spliced patches
        are float32.
    """
    piece_height, piece_width = pieces.shape[2:]
    margin_width = piece_height // margin_divider
    # Explicit end index: ``-margin_width`` would mean "everything" when 0.
    right = piece_width - margin_width
    patch_width = right - margin_width

    result = np.zeros((pieces.shape[1], piece_height, patch_width * pieces.shape[0]), dtype=np.float32)
    for i, piece in enumerate(pieces):
        result[:, :, i * patch_width:(i + 1) * patch_width] = piece[:, :, margin_width:right]

    if keep_margin:
        return np.concatenate(
            (pieces[0, :, :, :margin_width], result, pieces[-1, :, :, right:]),
            axis=2)
    return result


def soft_splice_pieces(pieces, margin_divider):
    """Splice pieces with linear cross-fade blending over their overlaps.

    Consecutive pieces overlap by ``piece_height * 2 // margin_divider``
    columns. Inside an overlap the earlier piece fades out and the later one
    fades in; the outer edges of the first and last piece are left unweighted.

    Args:
        pieces: Array of shape (batch, channel, height, width). Not modified.
        margin_divider: Divisor used to derive the overlap width.

    Returns:
        float32 array of shape
        (channel, height, (width - overlap) * batch + overlap).
    """
    batches, channels, piece_height, piece_width = pieces.shape
    overlap_width = piece_height * 2 // margin_divider
    segment_width = piece_width - overlap_width

    # Ramp 0 .. (overlap-1)/overlap fades a piece in; its complement fades out.
    slope = np.arange(overlap_width, dtype=np.float32) / max(overlap_width, 1)
    slope = slope.reshape((1, 1, overlap_width))
    inv_slope = 1 - slope

    result = np.zeros((channels, piece_height, segment_width * batches + overlap_width), dtype=np.float32)

    # Weight a copy of each piece: ``pieces`` may share memory with the caller
    # (e.g. a CPU tensor's ``.numpy()`` view), so in-place scaling would
    # corrupt it. The float working dtype also lets integer input be weighted.
    work_dtype = np.result_type(pieces.dtype, np.float32)
    for i, piece in enumerate(pieces):
        weighted = piece.astype(work_dtype)
        if overlap_width > 0:
            if i > 0:
                weighted[:, :, :overlap_width] *= slope
            if i < batches - 1:
                weighted[:, :, piece_width - overlap_width:] *= inv_slope
        start = segment_width * i
        result[:, :, start:start + piece_width] += weighted
    return result


def splice_output_tensor(tensor, keep_margin=False, soft=False, margin_divider=MARGIN_DIVIDER):
    """Splice a batched PyTorch tensor (or numpy array) output back together.

    Args:
        tensor: Tensor/array of shape (batch, channel, height, width), or
            ``None``. Objects with ``detach``/``cpu`` methods are converted
            with ``.detach().cpu().numpy()``; anything else is used as is.
        keep_margin: Forwarded to ``splice_pieces`` (ignored when ``soft``).
        soft: Use ``soft_splice_pieces`` blending instead of margin trimming.
        margin_divider: Divisor used to derive the margin/overlap width.

    Returns:
        The spliced (channel, height, width) array, or ``None`` if ``tensor``
        is ``None``.
    """
    if tensor is None:
        return None
    if hasattr(tensor, 'detach'):
        # .numpy() raises on tensors that require grad; detach them first.
        tensor = tensor.detach()
    arr = tensor.cpu().numpy() if hasattr(tensor, 'cpu') else tensor

    if soft:
        return soft_splice_pieces(arr, margin_divider)
    return splice_pieces(arr, margin_divider, keep_margin=keep_margin)


def mask_to_alpha(mask, frac_y=False):
    """Convert a mask to LA (grayscale + alpha) format.

    Args:
        mask: ``[fore(h, w), back(h, w)]``. Only ``mask[0]`` is read. It is
            scaled by 255, so values are presumably in [0, 1] — TODO confirm.
        frac_y: Unused here. It mirrors ``gauge_to_rgb``'s parameter,
            presumably to keep a uniform call signature.

    Returns:
        uint8 array of shape (h, w, 2). Channel 0 (luminance) is all zeros;
        channel 1 (alpha) is ``fore * 255`` clipped to [0, 255].
    """
    alpha = mask[0] * 255
    luminance = np.zeros(mask[0].shape, np.float32)
    la_image = np.stack([luminance, alpha], axis=2)
    return np.uint8(np.clip(la_image, 0, 255))


def gauge_to_rgb(gauge, frac_y=False):
    """Convert a gauge map to a 3-channel 8-bit image.

    Args:
        gauge: ``[Y(h, w), K(h, w)]``. Y is mapped to ``y * 8 + 128`` and K to
            ``k * 127 + 128``.
        frac_y: If true, store the integer part of the mapped Y in channel 2
            and its fractional part (times 256) in channel 0. Otherwise
            channel 0 is zero and channel 2 holds the mapped Y.

    Returns:
        uint8 array of shape (h, w, 3), channels ``[frac(Y) or 0, K, Y]``.
        Values are clipped to [0, 255] and truncated by the uint8 cast.
        NOTE(review): the B/R variable names suggest OpenCV BGR channel
        order despite the function name — confirm with consumers.
    """
    mapy = gauge[0] * 8 + 128
    mapk = gauge[1] * 127 + 128

    if frac_y:
        B, R = np.modf(mapy)  # (fractional part, integral part)
        result = np.stack([B * 256, mapk, R], axis=2)
    else:
        result = np.stack([np.zeros(mapy.shape, np.float32), mapk, mapy], axis=2)

    return np.uint8(np.clip(result, 0, 255))


def encode_image_bytes(image, ext='.png', quality=80):
    """Encode a PIL Image to bytes in the format registered for ``ext``.

    Args:
        image: ``PIL.Image.Image`` to encode.
        ext: File extension including the dot (e.g. ``'.png'``, ``'.jpg'``).
            It is matched case-insensitively against
            ``PIL.Image.registered_extensions()``.
        quality: Forwarded to ``Image.save`` as the ``quality`` option.

    Returns:
        The encoded image as ``bytes``.

    Raises:
        KeyError: If PIL has no format registered for ``ext``.
    """
    image_format = PIL.Image.registered_extensions()[ext.lower()]
    with io.BytesIO() as fp:
        image.save(fp, image_format, quality=quality)
        return fp.getvalue()


def encode_image_base64(image, ext='.png', quality=80):
    """Encode a PIL Image as a base64 ``data:`` URI.

    Args:
        image: ``PIL.Image.Image`` to encode.
        ext: File extension including the dot; selects the encoder (see
            ``encode_image_bytes``) and the MIME subtype.
        quality: Forwarded to ``encode_image_bytes``.

    Returns:
        A string ``'data:image/<subtype>;base64,<payload>'``. The subtype is
        the lower-cased extension, with ``jpg``/``tif`` mapped to the
        registered ``jpeg``/``tiff`` MIME subtypes.
    """
    image_bytes = encode_image_bytes(image, ext, quality=quality)
    payload = base64.b64encode(image_bytes).decode('ascii')

    fmt = ext.replace('.', '').lower()
    # 'jpg' and 'tif' are file extensions, not registered image MIME subtypes.
    fmt = {'jpg': 'jpeg', 'tif': 'tiff'}.get(fmt, fmt)
    return f'data:image/{fmt};base64,' + payload


def resize_page_image(img, size):
    """Resize a page image to a target width, keeping its aspect ratio.

    The image is scaled to width ``w`` with ``cv2.INTER_AREA``. If the scaled
    height is below ``h``, white (255) rows are appended at the bottom;
    otherwise the result is cropped to the first ``h`` rows.

    Args:
        img: Image array of shape (H, W) or (H, W, C).
        size: ``(w, h)`` target size.

    Returns:
        Array of shape (h, w, C). The white-padded branch is uint8.
    """
    w, h = size
    # Aspect-preserving height at width w; at least one row so cv2.resize
    # never receives an empty target size (e.g. for extremely wide images).
    filled_height = max(1, img.shape[0] * w // img.shape[1])
    resized = cv2.resize(img, (w, filled_height), interpolation=cv2.INTER_AREA)
    if resized.ndim < 3:
        # Restore the channel axis for 2-D (single-channel) results.
        resized = np.expand_dims(resized, -1)

    if filled_height < h:
        page = np.full((h, w, resized.shape[2]), 255, dtype=np.uint8)
        page[:filled_height] = resized
        return page
    return resized[:h]


def normalize_image_dimension(image):
    """Crop height and width down to multiples of 4 (for UNet).

    Args:
        image: Array of shape (n, h, w, c).

    Returns:
        A view cropped to (n, h - h % 4, w - w % 4, c). ``image`` itself is
        returned when both sizes are already divisible by 4.
    """
    _, h, w, _ = image.shape
    if h % 4 or w % 4:
        return image[:, :h - h % 4, :w - w % 4, :]
    return image