# feat: add Python ML services (CPU mode) with model download
# (commit 2b7aae2, author k-l-lambda)
"""
Image processing utilities for ML services.
"""
import io
import math
import base64
import numpy as np
import cv2
import PIL.Image
# Default divider used to derive crop/overlap margins from piece height
# (margin = piece_height // MARGIN_DIVIDER); used by splice_output_tensor.
MARGIN_DIVIDER = 8
def array_from_image_stream(buffer):
    """
    Decode raw encoded image bytes into a numpy array of shape (H, W, C).

    buffer: encoded image bytes; any non-`bytes` value yields None
    returns: (H, W, 3) for color images (alpha channel dropped),
        (H, W, 1) for grayscale, or None for invalid input
    """
    if not isinstance(buffer, bytes):
        return None
    image = PIL.Image.open(io.BytesIO(buffer))
    arr = np.array(image)
    if arr.ndim >= 3:
        # keep only the first three channels (RGB), dropping alpha
        return arr[:, :, :3]
    # grayscale: add an explicit trailing channel axis
    return arr[..., np.newaxis]
def slice_feature(source, width, overlapping=0.25, padding=False):
    """
    Slice an image horizontally into overlapping pieces.

    source: (height, width, channel) array
    width: width of each yielded piece
    overlapping: overlap expressed as a fraction of the source HEIGHT —
        the slice position advances by floor(width - height * overlapping)
    padding: when True (and the margin is non-zero), pad the left/right
        edges by replicating the border columns before slicing
    yields: (height, width, channel) pieces; the final piece is
        right-padded with 255 (white) when the source width is not an
        exact multiple of the step
    """
    step = math.floor(width - source.shape[0] * overlapping)
    if padding:
        margin = math.floor(source.shape[0] * overlapping) // 2
        # Guard: with margin == 0 the slices below degenerate
        # (source[:, 0:-0, :] is empty) and the assignment raises a
        # broadcast ValueError — skip padding entirely in that case.
        if margin > 0:
            center = source
            source = np.zeros((source.shape[0], source.shape[1] + margin * 2, source.shape[2]))
            source[:, margin:-margin, :] = center
            # replicate the border columns into the padding
            source[:, :margin, :] = center[:, :1, :]
            source[:, -margin:, :] = center[:, -1:, :]
    for x in range(0, source.shape[1], step):
        if x + width <= source.shape[1]:
            yield source[:, x:x + width, :]
        else:
            # last piece overshoots the source: fill the tail with white
            sliced = np.ones((source.shape[0], width, source.shape[2]), dtype=np.float32) * 255
            sliced[:, :source.shape[1] - x, :] = source[:, x:, :]
            yield sliced
def splice_pieces(pieces, margin_divider, keep_margin=False):
    """
    Reassemble output tensor pieces into one wide map by cropping margins.

    pieces: (batch, channel, height, width)
    margin_divider: the cropped margin is height // margin_divider
    keep_margin: when True, prepend/append the outermost margin of the
        first/last piece instead of dropping it
    returns: (channel, height, width) float32 array
    """
    _, _, height, width = pieces.shape
    margin = height // margin_divider
    # crop each piece's left/right margin, then butt them together
    trimmed = [piece[:, :, margin:width - margin] for piece in pieces]
    result = np.concatenate(trimmed, axis=2).astype(np.float32)
    if keep_margin:
        left_edge = pieces[0, :, :, :margin]
        right_edge = pieces[-1, :, :, -margin:]
        result = np.concatenate((left_edge, result, right_edge), axis=2)
    return result
def soft_splice_pieces(pieces, margin_divider):
    """
    Splice pieces with soft (linear cross-fade) blending at the overlaps.

    pieces: (batch, channel, height, width); the input array is NOT
        modified (the previous implementation scaled it in place, and
        also failed on integer dtypes due to the in-place float multiply)
    margin_divider: overlap width is height * 2 // margin_divider
    returns: (channel, height, width) float32 array
    """
    batches, channels, piece_height, piece_width = pieces.shape
    overlap_width = piece_height * 2 // margin_divider
    segment_width = piece_width - overlap_width
    # 0 -> 1 ramp across the overlap, broadcastable over (channel, height, width);
    # a piece's fade-out (inv_slope) and its neighbor's fade-in (slope) sum to 1.
    slope = np.arange(overlap_width, dtype=np.float32) / overlap_width
    slope = slope.reshape((1, 1, overlap_width))
    inv_slope = 1 - slope
    result = np.zeros((channels, piece_height, segment_width * batches + overlap_width), dtype=np.float32)
    for i, piece in enumerate(pieces):
        # work on a float32 copy so the caller's array is left untouched
        piece = piece.astype(np.float32)
        if i > 0:
            piece[:, :, :overlap_width] *= slope
        if i < batches - 1:
            piece[:, :, -overlap_width:] *= inv_slope
        result[:, :, segment_width * i:segment_width * i + piece_width] += piece
    return result
def splice_output_tensor(tensor, keep_margin=False, soft=False, margin_divider=MARGIN_DIVIDER):
    """
    Splice a model output back into a single map.

    tensor: PyTorch tensor or numpy array (batch, channel, height, width),
        or None (passed through as None)
    keep_margin: forwarded to splice_pieces (ignored when soft=True)
    soft: use linear cross-fade blending instead of hard margin cropping
    """
    if tensor is None:
        return None
    # accept torch tensors transparently: anything with .cpu() is converted
    if hasattr(tensor, 'cpu'):
        tensor = tensor.cpu().numpy()
    if soft:
        return soft_splice_pieces(tensor, margin_divider)
    return splice_pieces(tensor, margin_divider, keep_margin=keep_margin)
def mask_to_alpha(mask, frac_y=False):
    """
    Convert a foreground/background mask to LA (luminance + alpha) format.

    mask: [fore(h, w), back(h, w)] — only the foreground plane is used
    frac_y: unused; kept for signature parity with gauge_to_rgb
    returns: uint8 array (h, w, 2) with luminance 0 and alpha = fore * 255
    """
    alpha = mask[0] * 255
    luminance = np.zeros(mask[0].shape, np.float32)
    la = np.stack([luminance, alpha], axis=2)
    return np.uint8(np.clip(la, 0, 255))
def gauge_to_rgb(gauge, frac_y=False):
    """
    Convert a gauge map to an RGB visualization.

    gauge: [Y(h, w), K(h, w)]; Y is scaled as Y*8+128, K as K*127+128
    frac_y: when True, split the scaled Y into integer (R channel) and
        fractional (B channel, scaled by 256) parts; otherwise B is zero
        and R carries the scaled Y directly
    returns: uint8 array (h, w, 3), values clipped to [0, 255]
    """
    mapy = gauge[0] * 8 + 128
    mapk = gauge[1] * 127 + 128
    if frac_y:
        frac, whole = np.modf(mapy)
        channels = [frac * 256, mapk, whole]
    else:
        channels = [np.zeros(mapy.shape, np.float32), mapk, mapy]
    rgb = np.stack(channels, axis=2)
    return np.uint8(np.clip(rgb, 0, 255))
def encode_image_bytes(image, ext='.png', quality=80):
    """
    Encode a PIL Image into compressed bytes.

    ext: file extension (with leading dot) used to look up the PIL format
        name via PIL.Image.registered_extensions()
    quality: encoder quality hint (honored by lossy formats such as JPEG)
    returns: the encoded image as bytes
    """
    fmt = PIL.Image.registered_extensions()[ext]
    buffer = io.BytesIO()
    image.save(buffer, fmt, quality=quality)
    return buffer.getvalue()
def encode_image_base64(image, ext='.png', quality=80):
    """
    Encode a PIL Image as a base64 data URI (e.g. 'data:image/png;base64,...').

    ext/quality are forwarded to encode_image_bytes.
    """
    payload = encode_image_bytes(image, ext, quality=quality)
    encoded = base64.b64encode(payload).decode('ascii')
    fmt = ext.replace('.', '')
    return f'data:image/{fmt};base64,{encoded}'
def resize_page_image(img, size):
    """
    Resize a page image to a target width, preserving aspect ratio.

    The image is scaled to width size[0]; if the scaled height is shorter
    than size[1] the bottom is padded with white, if taller it is cropped
    at the bottom.

    img: (H, W) or (H, W, C) array
    size: (width, height) target
    returns: (height, width, C) array
    """
    target_w, target_h = size
    scaled_h = img.shape[0] * target_w // img.shape[1]
    img = cv2.resize(img, (target_w, scaled_h), interpolation=cv2.INTER_AREA)
    if img.ndim < 3:
        # cv2.resize drops the channel axis for grayscale input — restore it
        img = np.expand_dims(img, -1)
    if scaled_h >= target_h:
        return img[:target_h]
    # pad the remaining rows with white
    canvas = np.ones((target_h, target_w, img.shape[2]), dtype=np.uint8) * 255
    canvas[:scaled_h] = img
    return canvas
def normalize_image_dimension(image):
    """
    Crop height/width down to the nearest multiple of 4 (for UNet strides).

    image: (N, H, W, C) array
    returns: the input object unchanged when already aligned, otherwise a
        view cropped at the bottom/right
    """
    _, h, w, _ = image.shape
    if h % 4 == 0 and w % 4 == 0:
        return image
    return image[:, :h - h % 4, :w - w % 4, :]