# NOTE: removed web-page scrape residue ("Spaces: Running Running") that was
# not part of this Python module.
| """ | |
| Image processing utilities for ML services. | |
| """ | |
| import io | |
| import math | |
| import base64 | |
| import numpy as np | |
| import cv2 | |
| import PIL.Image | |
| MARGIN_DIVIDER = 8 | |
def array_from_image_stream(buffer):
    """
    Decode raw image bytes into a numpy array of shape (H, W, C).

    Returns None when the input is not a bytes object. Images with 3+
    channels are truncated to their first three channels (RGB, alpha
    dropped); grayscale images gain a trailing singleton channel axis.
    """
    if not isinstance(buffer, bytes):
        return None
    decoded = np.array(PIL.Image.open(io.BytesIO(buffer)))
    if decoded.ndim >= 3:
        return decoded[:, :, :3]  # keep RGB only, drop alpha if present
    return np.expand_dims(decoded, -1)
def slice_feature(source, width, overlapping=0.25, padding=False):
    """
    Cut an image into horizontal strips of the given width.

    The stride between strips is width minus (height * overlapping), so
    consecutive strips overlap by an amount proportional to the image
    height. With padding=True, the left/right borders are first extended
    by half the overlap, replicating the edge columns (the padded array
    is float64). A final partial strip is filled with white (255.0,
    float32) where the source runs out.

    source: (height, width, channel) array
    yields: (height, width, channel) strips
    """
    height, _, channels = source.shape
    step = math.floor(width - height * overlapping)
    if padding:
        margin = math.floor(height * overlapping) // 2
        padded = np.zeros((height, source.shape[1] + margin * 2, channels))
        padded[:, margin:-margin, :] = source
        padded[:, :margin, :] = source[:, :1, :]    # replicate left edge
        padded[:, -margin:, :] = source[:, -1:, :]  # replicate right edge
        source = padded
    total = source.shape[1]
    for start in range(0, total, step):
        end = start + width
        if end <= total:
            yield source[:, start:end, :]
        else:
            # Last, partial strip: pad the remainder with white.
            tail = np.full((height, width, channels), 255, dtype=np.float32)
            tail[:, :total - start, :] = source[:, start:, :]
            yield tail
def splice_pieces(pieces, margin_divider, keep_margin=False):
    """
    Stitch output pieces back into one image by trimming the horizontal
    margins of each piece and laying the cores side by side.

    The margin width is piece_height // margin_divider on each side.

    pieces: (batch, channel, height, width)
    keep_margin: when True, re-attach the leading margin of the first
        piece and the trailing margin of the last piece.
    returns: (channel, height, width) float32
    """
    count, channels, height, width = pieces.shape
    margin = height // margin_divider
    core = width - margin * 2
    result = np.zeros((channels, height, core * count), dtype=np.float32)
    offset = 0
    for piece in pieces:
        result[:, :, offset:offset + core] = piece[:, :, margin:-margin]
        offset += core
    if not keep_margin:
        return result
    head = pieces[0, :, :, :margin]
    tail = pieces[-1, :, :, -margin:]
    return np.concatenate((head, result, tail), axis=2)
def soft_splice_pieces(pieces, margin_divider):
    """
    Splice pieces with soft blending at overlaps.

    Adjacent pieces overlap by (2 * piece_height // margin_divider)
    columns; the overlapping columns are cross-faded with complementary
    linear ramps, so the two contributions sum back to full weight.

    pieces: (batch, channel, height, width)
    returns: (channel, height, width) float32

    Fix: the previous version multiplied the ramps into `pieces` in
    place (writing through the iteration view), mutating the caller's
    array and raising on integer inputs; each piece is now copied as
    float32 before the ramps are applied.
    """
    batches, channels, piece_height, piece_width = pieces.shape
    overlap_width = piece_height * 2 // margin_divider
    segment_width = piece_width - overlap_width
    slope = np.arange(overlap_width, dtype=np.float32) / overlap_width
    slope = slope.reshape((1, 1, overlap_width))
    inv_slope = 1 - slope
    result = np.zeros((channels, piece_height, segment_width * batches + overlap_width), dtype=np.float32)
    for i, piece in enumerate(pieces):
        # astype copies: keeps the caller's array intact and makes the
        # fractional ramps valid for integer inputs.
        piece = piece.astype(np.float32)
        if i > 0:
            piece[:, :, :overlap_width] *= slope       # fade-in on the left
        if i < batches - 1:
            piece[:, :, -overlap_width:] *= inv_slope  # fade-out on the right
        result[:, :, segment_width * i:segment_width * i + piece_width] += piece
    return result
def splice_output_tensor(tensor, keep_margin=False, soft=False, margin_divider=MARGIN_DIVIDER):
    """
    Splice model output (PyTorch tensor or numpy array) back together.

    None passes through as None. Anything with a `.cpu` attribute is
    treated as a torch tensor and converted to numpy first, then the
    work is delegated to soft_splice_pieces or splice_pieces.
    """
    if tensor is None:
        return None
    if hasattr(tensor, 'cpu'):
        pieces = tensor.cpu().numpy()
    else:
        pieces = tensor
    if soft:
        return soft_splice_pieces(pieces, margin_divider)
    return splice_pieces(pieces, margin_divider, keep_margin=keep_margin)
def mask_to_alpha(mask, frac_y=False):
    """
    Convert a foreground mask to an LA (luminance + alpha) image.

    mask: [fore(h, w), back(h, w)] — only the foreground plane is used;
    it is scaled to 0-255 and written to the alpha channel while the
    luminance channel stays zero (black).
    NOTE(review): `frac_y` and `mask[1]` are currently unused — kept for
    interface compatibility with gauge_to_rgb-style callers.
    """
    luminance = np.zeros(mask[0].shape, np.float32)
    alpha = mask[0] * 255
    stacked = np.stack([luminance, alpha], axis=2)
    return np.uint8(np.clip(stacked, 0, 255))
def gauge_to_rgb(gauge, frac_y=False):
    """
    Render a gauge map as an RGB uint8 image.

    gauge: [Y(h, w), K(h, w)] — Y is scaled by 8 and K by 127, both
    offset to center on 128. When frac_y is set, Y's fractional part
    (times 256) goes to the first channel and its integer part to the
    third; otherwise the first channel is zero and Y fills the third.
    """
    y_plane = gauge[0] * 8 + 128
    k_plane = gauge[1] * 127 + 128
    if frac_y:
        frac_part, int_part = np.modf(y_plane)
        channels = [frac_part * 256, k_plane, int_part]
    else:
        channels = [np.zeros(y_plane.shape, np.float32), k_plane, y_plane]
    return np.uint8(np.clip(np.stack(channels, axis=2), 0, 255))
def encode_image_bytes(image, ext='.png', quality=80):
    """
    Serialize a PIL Image to in-memory encoded bytes.

    ext picks the format through PIL's extension registry (e.g. '.png',
    '.jpg'); quality is forwarded to the encoder (PNG ignores it).
    """
    image_format = PIL.Image.registered_extensions()[ext]
    buffer = io.BytesIO()
    image.save(buffer, image_format, quality=quality)
    return buffer.getvalue()
def encode_image_base64(image, ext='.png', quality=80):
    """
    Encode a PIL Image as a base64 data URI (data:image/<fmt>;base64,...).

    Delegates byte encoding to encode_image_bytes; the MIME subtype is
    derived from ext by stripping the dot.
    """
    fmt = ext.replace('.', '')
    payload = base64.b64encode(encode_image_bytes(image, ext, quality=quality))
    return 'data:image/' + fmt + ';base64,' + payload.decode('ascii')
def resize_page_image(img, size):
    """
    Resize a page image to the target width, keeping aspect ratio.

    The image is scaled so its width matches size[0]; if the scaled
    height falls short of size[1] the bottom is padded with white
    (uint8 255), otherwise the bottom is cropped.

    img: (H, W[, C]) array; size: (width, height)
    returns: (height, width, C) array
    """
    target_w, target_h = size
    scaled_h = img.shape[0] * target_w // img.shape[1]
    scaled = cv2.resize(img, (target_w, scaled_h), interpolation=cv2.INTER_AREA)
    if scaled.ndim < 3:
        scaled = np.expand_dims(scaled, -1)  # cv2 drops the channel axis for grayscale
    if scaled_h >= target_h:
        return scaled[:target_h]
    canvas = np.full((target_h, target_w, scaled.shape[2]), 255, dtype=np.uint8)
    canvas[:scaled_h] = scaled
    return canvas
def normalize_image_dimension(image):
    """
    Ensure H and W are divisible by 4 (UNet downsampling requirement).

    image: (n, h, w, c) batch array.
    returns: the input unchanged when already aligned, otherwise a view
    cropped from the bottom/right edges down to the nearest multiples
    of 4.

    Fix: use logical `or` instead of bitwise `|` on the boolean
    comparisons (idiomatic and short-circuiting); drop unused unpacked
    names.
    """
    _, h, w, _ = image.shape
    if h % 4 != 0 or w % 4 != 0:
        return image[:, :h - h % 4, :w - w % 4, :]
    return image