"""
Image processing utilities for ML services.
"""
import io
import math
import base64
import numpy as np
import cv2
import PIL.Image
# Default margin divider for splicing: a piece's trim/overlap margin is
# derived as piece_height // MARGIN_DIVIDER (see splice_output_tensor).
MARGIN_DIVIDER = 8
def array_from_image_stream(buffer):
    """
    Decode raw image bytes into a numpy array of shape (H, W, C).

    buffer: image file contents as bytes; any non-bytes input yields None
    returns: (H, W, 3) array for multi-channel inputs (alpha stripped),
        (H, W, 1) for single-channel inputs, or None for non-bytes input
    """
    if not isinstance(buffer, bytes):
        return None
    arr = np.array(PIL.Image.open(io.BytesIO(buffer)))
    if arr.ndim >= 3:
        # Keep only the first three channels (drops alpha if present).
        return arr[:, :, :3]
    # Grayscale image: add an explicit trailing channel axis.
    return np.expand_dims(arr, -1)
def slice_feature(source, width, overlapping=0.25, padding=False):
    """
    Slice an image horizontally into overlapping pieces.

    source: (height, width, channel) array
    width: width of each yielded piece
    overlapping: overlap fraction, measured against the image HEIGHT
        (pieces are assumed roughly square, so height drives the step)
    padding: when True, extend both horizontal edges by replicating the
        border columns before slicing, so edge content can sit inside a
        piece's margin
    yields: (height, width, channel) pieces; the final partial piece is
        padded on the right with 255 (white)
    """
    step = math.floor(width - source.shape[0] * overlapping)
    if padding:
        margin = math.floor(source.shape[0] * overlapping) // 2
        # Guard margin == 0 (small heights): the unguarded version slices
        # source[:, 0:-0, :] — an EMPTY view — and the assignment raises
        # ValueError. With no margin there is simply nothing to pad.
        if margin > 0:
            center = source
            source = np.zeros((source.shape[0], source.shape[1] + margin * 2, source.shape[2]))
            source[:, margin:-margin, :] = center
            source[:, :margin, :] = center[:, :1, :]    # replicate left border column
            source[:, -margin:, :] = center[:, -1:, :]  # replicate right border column
    for x in range(0, source.shape[1], step):
        if x + width <= source.shape[1]:
            yield source[:, x:x + width, :]
        else:
            # Last partial piece: copy what remains, fill the tail with white.
            sliced = np.ones((source.shape[0], width, source.shape[2]), dtype=np.float32) * 255
            sliced[:, :source.shape[1] - x, :] = source[:, x:, :]
            yield sliced
def splice_pieces(pieces, margin_divider, keep_margin=False):
    """
    Splice output tensor pieces back together by trimming overlap margins.

    pieces: (batch, channel, height, width) array
    margin_divider: margin on each side of a piece is height // margin_divider
    keep_margin: when True, re-attach the outermost margins of the first and
        last pieces so no columns are lost at the image edges
    returns: (channel, height, width) float32 array
    """
    piece_height, piece_width = pieces.shape[2:]
    margin_width = piece_height // margin_divider
    patch_width = piece_width - margin_width * 2
    result = np.zeros((pieces.shape[1], pieces.shape[2], patch_width * pieces.shape[0]), dtype=np.float32)
    for i, piece in enumerate(pieces):
        # Guard margin_width == 0 (height < margin_divider): the unguarded
        # slice piece[:, :, 0:-0] is EMPTY and the assignment raises
        # ValueError; with no margin the whole piece is the patch.
        trimmed = piece[:, :, margin_width:-margin_width] if margin_width else piece
        result[:, :, i * patch_width : (i + 1) * patch_width] = trimmed
    if keep_margin and margin_width:
        # Note: pieces[-1, :, :, -0:] would be the WHOLE last piece, so the
        # concat must only run when a real margin exists.
        return np.concatenate((pieces[0, :, :, :margin_width], result, pieces[-1, :, :, -margin_width:]), axis=2)
    return result
def soft_splice_pieces(pieces, margin_divider):
    """
    Splice pieces with linear cross-fade blending in the overlap regions.

    pieces: (batch, channel, height, width) array. The input is NOT
        modified — the previous implementation scaled piece edges in
        place through views, silently clobbering the caller's array.
    margin_divider: overlap is height * 2 // margin_divider columns
    returns: (channel, height, width) float32 array
    """
    batches, channels, piece_height, piece_width = pieces.shape
    overlap_width = piece_height * 2 // margin_divider
    if overlap_width == 0:
        # No overlap to blend (height too small): pieces simply abut.
        # The unguarded loop would hit piece[:, :, -0:] *= inv_slope, a
        # broadcast error against a zero-width ramp.
        return np.concatenate(list(pieces), axis=2).astype(np.float32)
    segment_width = piece_width - overlap_width
    # Ramp 0 -> 1 across the overlap; shaped (1, 1, overlap) to broadcast
    # over the (channel, height, width) piece.
    slope = np.arange(overlap_width, dtype=np.float32) / overlap_width
    slope = slope.reshape((1, 1, overlap_width))
    inv_slope = 1 - slope
    result = np.zeros((channels, piece_height, segment_width * batches + overlap_width), dtype=np.float32)
    for i in range(batches):
        piece = pieces[i].astype(np.float32, copy=True)  # work on a copy, keep input intact
        if i > 0:
            piece[:, :, :overlap_width] *= slope         # fade in against the previous piece
        if i < batches - 1:
            piece[:, :, -overlap_width:] *= inv_slope    # fade out against the next piece
        result[:, :, segment_width * i:segment_width * i + piece_width] += piece
    return result
def splice_output_tensor(tensor, keep_margin=False, soft=False, margin_divider=MARGIN_DIVIDER):
    """
    Splice model output pieces back into a single image.

    tensor: PyTorch tensor or numpy array of shape (batch, channel, h, w),
        or None (passed straight through)
    keep_margin: forwarded to splice_pieces (hard splicing only)
    soft: select cross-fade blending (soft_splice_pieces) instead of
        hard margin trimming (splice_pieces)
    margin_divider: forwarded to the chosen splicing helper
    """
    if tensor is None:
        return None
    pieces = tensor
    if hasattr(pieces, 'cpu'):
        # Torch tensor: move to host memory and view as numpy.
        pieces = pieces.cpu().numpy()
    if soft:
        return soft_splice_pieces(pieces, margin_divider)
    return splice_pieces(pieces, margin_divider, keep_margin=keep_margin)
def mask_to_alpha(mask, frac_y=False):
    """
    Convert a foreground mask to an LA (grayscale + alpha) uint8 image.

    mask: [fore(h, w), back(h, w)] — only the foreground plane is used here
    frac_y: unused; kept for signature parity with gauge_to_rgb
    returns: (h, w, 2) uint8 array with channel 0 zeroed and
        channel 1 = clip(fore * 255, 0, 255)
    """
    luminance = np.zeros(mask[0].shape, np.float32)
    alpha = mask[0] * 255
    la = np.stack([luminance, alpha], axis=2)
    return np.uint8(np.clip(la, 0, 255))
def gauge_to_rgb(gauge, frac_y=False):
    """
    Convert a two-plane gauge map to an RGB uint8 image.

    gauge: [Y(h, w), K(h, w)]
    frac_y: when True, split the scaled Y map into integer and fractional
        parts — the fraction (scaled by 256) goes to the blue channel and
        the integer part to the red channel; otherwise blue is zero and
        red holds the full Y value
    returns: (h, w, 3) uint8 array
    """
    y_map = gauge[0] * 8 + 128
    k_map = gauge[1] * 127 + 128
    if frac_y:
        frac, whole = np.modf(y_map)
        channels = [frac * 256, k_map, whole]
    else:
        channels = [np.zeros(y_map.shape, np.float32), k_map, y_map]
    return np.uint8(np.clip(np.stack(channels, axis=2), 0, 255))
def encode_image_bytes(image, ext='.png', quality=80):
    """
    Encode a PIL Image into raw file bytes for the given extension.

    image: PIL.Image instance
    ext: file extension including the dot, e.g. '.png' or '.jpg'
    quality: encoder quality hint (honored by lossy formats such as JPEG)
    """
    fmt = PIL.Image.registered_extensions()[ext]  # map '.png' -> 'PNG', etc.
    buffer = io.BytesIO()
    image.save(buffer, fmt, quality=quality)
    return buffer.getvalue()
def encode_image_base64(image, ext='.png', quality=80):
    """
    Encode a PIL Image as a base64 data URI.

    Returns a string of the form 'data:image/<fmt>;base64,<payload>'.
    """
    raw = encode_image_bytes(image, ext, quality=quality)
    payload = base64.b64encode(raw).decode('ascii')
    return 'data:image/{};base64,{}'.format(ext.replace('.', ''), payload)
def resize_page_image(img, size):
    """
    Resize a page image to a target width while preserving aspect ratio.

    img: (h, w) or (h, w, c) array
    size: (width, height) target
    Results shorter than the target height are bottom-padded with white
    (255); taller results are cropped at the bottom.
    """
    target_w, target_h = size
    scaled_h = img.shape[0] * target_w // img.shape[1]
    img = cv2.resize(img, (target_w, scaled_h), interpolation=cv2.INTER_AREA)
    if img.ndim < 3:
        img = np.expand_dims(img, -1)  # keep an explicit channel axis
    if scaled_h >= target_h:
        return img[:target_h]
    padded = np.ones((target_h, target_w, img.shape[2]), dtype=np.uint8) * 255
    padded[:scaled_h] = img
    return padded
def normalize_image_dimension(image):
    """
    Crop a batch of images so height and width are divisible by 4
    (required by UNet-style down/upsampling stages).

    image: (n, h, w, c) array
    returns: a view cropped at the bottom/right edges when needed,
        otherwise the input object unchanged
    """
    _, h, w, _ = image.shape
    # Use boolean `or`, not bitwise `|`: the original relied on `|` between
    # bools, which works but obscures the short-circuit boolean intent.
    if h % 4 != 0 or w % 4 != 0:
        return image[:, :h - h % 4, :w - w % 4, :]
    return image