# Provenance: commit 2b7aae2 by k-l-lambda — "feat: add Python ML services (CPU mode) with model download"
"""
Layout prediction service.
Detects page layout: staff boxes, lines, intervals, rotation angle.
Includes system/staff detection for the full prediction pipeline.
"""
import numpy as np
import torch
import cv2
import PIL.Image
import hashlib
import io
import logging
from predictors.torchscript_predictor import TorchScriptPredictor
from common.image_utils import (
array_from_image_stream, resize_page_image, normalize_image_dimension,
encode_image_base64
)
from common.transform import Composer
RESIZE_WIDTH = 600
CANVAS_WIDTH_MIN = 1024
SYSTEM_HEIGHT_ENLARGE = 0.02
SYSTEM_LEFT_ENLARGE = 0.03
SYSTEM_RIGHT_ENLARGE = 0.01
STAFF_PADDING_LEFT = 32
STAFF_HEIGHT_UNITS = 24
UNIT_SIZE = 8
def _detect_systems(image):
    """Detect musical systems (staff groups) from max-channel heatmap.

    image: single-channel uint8 heatmap (H, W) — the per-pixel max over the
    layout channels at canvas scale.
    Returns {'areas': [...]} with one {'x', 'y', 'width', 'height'} dict per
    system, sorted top-to-bottom, with heights enlarged to claim the empty
    space between neighboring systems.
    """
    height, width = image.shape
    blur = cv2.GaussianBlur(image, (5, 5), 0)
    # Large block size (201) with negative C extracts broad bright regions
    # rather than thin lines.
    thresh = cv2.adaptiveThreshold(blur, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C, cv2.THRESH_BINARY, 201, -40)
    contours, _ = cv2.findContours(thresh, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    marginLeft = SYSTEM_LEFT_ENLARGE * width
    marginRight = SYSTEM_RIGHT_ENLARGE * width
    areas = []
    for contour in contours:
        x, y, w, h = cv2.boundingRect(contour)
        # Both ratios are divided by width — the height ratio too — so the
        # thresholds below are independent of page aspect.
        rw = w / width
        rh = h / width
        # Keep regions that are either very wide, or moderately wide and tall.
        if (rw > 0.6 and rh > 0.02) or (rw > 0.12 and rh > 0.2):
            left = max(x - marginLeft, 0)
            right = min(x + w + marginRight, width)
            areas.append({'x': left, 'y': y, 'width': right - left, 'height': h})
    areas.sort(key=lambda a: a['y'])
    # Enlarge heights to include surrounding space
    marginY = SYSTEM_HEIGHT_ENLARGE * width
    marginYMax = marginY * 8
    # ctx carries state between iterations (previous systems' lower extent),
    # so enlarge() must be evaluated strictly top-to-bottom.
    ctx = {'lastMargin': 0}
    def enlarge(i, area, ctx):
        top = area['y']
        bottom = top + area['height']
        if i > 0:
            lastArea = areas[i - 1]
            # lastMargin tracks the deepest of: previous bound, previous
            # system's bottom edge, and this system's top minus the cap.
            ctx['lastMargin'] = max(ctx['lastMargin'], lastArea['y'] + lastArea['height'], top - marginYMax)
            top = max(0, min(top - marginY, ctx['lastMargin']))
        else:
            # First system: extend upward, capped at marginYMax.
            top = min(top, max(marginY, top - marginYMax))
        if i < len(areas) - 1:
            nextArea = areas[i + 1]
            # Extend downward toward the next system, capped at marginYMax.
            bottom = min(height, max(bottom + marginY, nextArea['y']), bottom + marginYMax)
        else:
            bottom = min(height, bottom + marginYMax)
        return {'top': top, 'bottom': bottom}
    enlarges = [enlarge(i, area, ctx) for i, area in enumerate(areas)]
    for i, area in enumerate(areas):
        area['y'] = enlarges[i]['top']
        area['height'] = enlarges[i]['bottom'] - enlarges[i]['top']
    return {'areas': areas}
def _detect_staves_from_hbl(HB, HL, interval):
    """Detect individual staves within a system using staff-box and horizontal-line heatmaps.

    HB: uint8 staff-box heatmap (H, W) cropped to the system region.
    HL: uint8 horizontal-line heatmap, same shape as HB.
    interval: estimated staff-line spacing in pixels at canvas scale.

    Returns a dict with 'interval', 'phi1'/'phi2' (horizontal staff extents)
    and 'middleRhos' (the y coordinate of each staff's middle line), or a
    dict carrying only a 'reason' key when no staff block is detected.
    """
    # Binarize the staff-box heatmap (Otsu) and take outer contours as candidates.
    _, HB = cv2.threshold(HB, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
    contours, _ = cv2.findContours(HB, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    height, width = HB.shape
    STAFF_HEIGHT_MIN = interval * 3
    STAFF_WIDTH_MIN = max(width * 0.6, width - interval * 12)
    UPSCALE = 4
    upInterval = interval * UPSCALE
    rects = map(cv2.boundingRect, contours)
    rects = filter(lambda rect: rect[2] > STAFF_WIDTH_MIN and rect[3] > STAFF_HEIGHT_MIN, rects)
    rects = sorted(rects, key=lambda rect: rect[1])
    # Merge rectangles whose vertical centers lie within each other's extent.
    preRects = []
    for rect in rects:
        x, y, w, h = rect
        ri = next((i for i, rc in enumerate(preRects)
                   if (y + h / 2) - (rc[1] + rc[3] / 2) < (h + rc[3]) / 2), -1)
        if ri < 0:
            preRects.append(rect)
        else:
            rc = preRects[ri]
            # Compute the merged horizontal extent from the ORIGINAL edges.
            # (Fixes two bugs: the old code computed the right edge after
            # mutating the left, and mutated a copy that was never stored.)
            left = min(x, rc[0])
            right = max(x + w, rc[0] + rc[2])
            if w > rc[2]:
                # The wider rectangle's vertical extent wins.
                preRects[ri] = (left, y, right - left, h)
            else:
                preRects[ri] = (left, rc[1], right - left, rc[3])
    if len(preRects) == 0:
        return {'reason': 'no block rectangles detected'}
    phi1 = min(rc[0] for rc in preRects)
    # NOTE(review): phi2 takes the min of right edges (narrowest common
    # extent) while phi1 takes the min of left edges — confirm this asymmetry
    # is intended before changing it to max().
    phi2 = min(rc[0] + rc[2] for rc in preRects)
    middleRhos = []
    for rect in preRects:
        rx, ry, rw, rh = rect
        # Expand the ROI by one interval upward for robustness, clamped to the map.
        x = max(rx, 0)
        y = max(round(ry - interval), 0)
        roi = (x, y, min(rw, width - x), min(round(rh + interval + ry - y), height - y))
        staffLines = HL[roi[1]:roi[1] + roi[3], roi[0]:roi[0] + roi[2]]
        # Collapse the ROI to a single column at UPSCALE vertical resolution.
        # (Fixed: the interpolation flag was previously passed positionally
        # into the fy slot of cv2.resize and silently ignored.)
        hlineColumn = cv2.resize(
            staffLines, (1, staffLines.shape[0] * UPSCALE),
            interpolation=cv2.INTER_LINEAR,
        ).flatten()
        # Correlate with a 5-tooth comb spaced one interval apart; the first
        # response peak marks where the 5 staff lines align.
        i1 = round(upInterval)
        kernel = np.zeros(i1 * 4 + 1)
        kernel[::i1] = 1
        convolutionLine = np.convolve(hlineColumn, kernel)
        # argmax returns the first maximum, matching np.where(...)[0][0].
        middleY = int(np.argmax(convolutionLine)) - round(upInterval * 2)
        middleRhos.append(y + middleY / UPSCALE)
    return {
        'interval': interval,
        'phi1': phi1,
        'phi2': phi2,
        'middleRhos': middleRhos,
    }
def _scale_detection(detection, scale):
"""Scale detection coordinates by given factor."""
areas = [{
'x': a['x'] * scale,
'y': a['y'] * scale,
'width': a['width'] * scale,
'height': a['height'] * scale,
'staff_images': a.get('staff_images'),
'staves': {
'interval': a['staves']['interval'] * scale,
'phi1': a['staves']['phi1'] * scale,
'phi2': a['staves']['phi2'] * scale,
'middleRhos': [rho * scale for rho in a['staves']['middleRhos']],
} if (a.get('staves') and a['staves'].get('middleRhos') is not None) else None,
} for a in detection['areas']]
return {'areas': areas}
class PageLayout:
    """Process layout heatmap to extract page parameters.

    Holds the staff-line interval, the page rotation angle (theta) and a
    uint8 RGB rendering of the heatmap; detect() performs full system/staff
    detection on the original page image.
    """

    def __init__(self, heatmap):
        """
        heatmap: (3, H, W) - channels: [VL, StaffBox, HL]
        """
        # Staff-line spacing is measured on the raw (float) HL channel.
        lines_map = heatmap[2]
        self.interval = self.measure_interval(lines_map)
        # assumes heatmap values are in [0, 1] — TODO confirm model output range
        heatmap_uint8 = np.uint8(heatmap * 255)
        heatmap_uint8 = np.moveaxis(heatmap_uint8, 0, -1)  # CHW -> HWC for image ops
        self.image = PIL.Image.fromarray(heatmap_uint8, 'RGB')
        self.heatmap = heatmap_uint8
        staves_map = heatmap_uint8[:, :, 1]
        # Rotation angle from the staff-box channel; None when no Hough lines found.
        self.theta = self.measure_theta(staves_map)

    def json(self):
        # Lightweight result for the basic predict() path.
        return {
            'image': encode_image_base64(self.image),
            'theta': self.theta,
            'interval': self.interval,
        }

    def detect(self, image, ratio):
        """Full detection: systems, staves, staff images.

        image: original page as an (H, W, C) uint8 array.
        ratio: batch-wide height/width ratio the page must conform to.
        Returns a dict with sourceSize, theta, interval and detection areas
        (None detection when theta could not be measured).
        """
        if self.theta is None:
            # No rotation measurable — report only the interval at page scale.
            return {
                'theta': self.theta,
                'interval': self.interval * image.shape[1] / RESIZE_WIDTH,
                'detection': None,
            }
        original_size = (image.shape[1], image.shape[0])
        # Pad or crop the page to the shared batch aspect ratio.
        aligned_height = int(image.shape[1] * ratio)
        if image.shape[0] < aligned_height:
            image = np.pad(image, ((0, aligned_height - image.shape[0]), (0, 0), (0, 0)), mode='constant')
        elif image.shape[0] > aligned_height:
            image = image[:aligned_height]
        canvas_size = (original_size[0], aligned_height)
        # Double the canvas until it is at least CANVAS_WIDTH_MIN wide.
        while canvas_size[0] < CANVAS_WIDTH_MIN:
            canvas_size = (canvas_size[0] * 2, canvas_size[1] * 2)
        if canvas_size[0] > original_size[0]:
            image = cv2.resize(image, canvas_size)
        # De-rotate the page by theta (radians -> degrees for OpenCV).
        rot_mat = cv2.getRotationMatrix2D((canvas_size[0] / 2, canvas_size[1] / 2), self.theta * 180 / np.pi, 1)
        image = cv2.warpAffine(image, rot_mat, canvas_size, flags=cv2.INTER_CUBIC)
        if len(image.shape) < 3:
            # cv2 drops the channel axis for grayscale; restore it.
            image = np.expand_dims(image, -1)
        # Bring the stored heatmap to canvas scale, matching width and
        # cropping/padding height, then apply the same de-rotation.
        heatmap = cv2.resize(self.heatmap, (canvas_size[0], round(canvas_size[0] * self.heatmap.shape[0] / self.heatmap.shape[1])), interpolation=cv2.INTER_CUBIC)
        if heatmap.shape[0] > canvas_size[1]:
            heatmap = heatmap[:canvas_size[1]]
        elif heatmap.shape[0] < canvas_size[1]:
            heatmap = np.pad(heatmap, ((0, canvas_size[1] - heatmap.shape[0]), (0, 0), (0, 0)), mode='constant')
        heatmap = cv2.warpAffine(heatmap, rot_mat, canvas_size, flags=cv2.INTER_LINEAR)
        HB = heatmap[:, :, 1]  # staff-box channel
        HL = heatmap[:, :, 2]  # horizontal-line channel
        # Channel-wise max highlights any detected content for system finding.
        block = heatmap.max(axis=2)
        detection = _detect_systems(block)
        canvas_interval = self.interval * canvas_size[0] / RESIZE_WIDTH
        for si, area in enumerate(detection['areas']):
            l, r, t, b = map(round, (area['x'], area['x'] + area['width'], area['y'], area['y'] + area['height']))
            system_image = image[t:b, l:r, :]
            hb = HB[t:b, l:r]
            hl = HL[t:b, l:r]
            area['staves'] = _detect_staves_from_hbl(hb, hl, canvas_interval)
            if not area.get('staves') or area['staves'].get('middleRhos') is None:
                # Staff detection failed for this system; leave it without images.
                continue
            area['staff_images'] = []
            interval = area['staves']['interval']
            # Scale so one staff-line interval maps to UNIT_SIZE output pixels.
            unit_scaling = UNIT_SIZE / interval
            # NOTE(review): this expression simplifies to round(STAFF_PADDING_LEFT)
            # and padding_left is never used below — looks like dead code.
            padding_left = round(STAFF_PADDING_LEFT * UNIT_SIZE / interval / unit_scaling)
            staff_width = round(system_image.shape[1] * unit_scaling) + STAFF_PADDING_LEFT
            staff_size = (staff_width, STAFF_HEIGHT_UNITS * UNIT_SIZE)
            for ssi, rho in enumerate(area['staves']['middleRhos']):
                # Build remap grids: x shifted by the left padding, y centered
                # on this staff's middle line (rho), both in source pixels.
                map_x = (np.tile(np.arange(staff_size[0], dtype=np.float32), (staff_size[1], 1)) - STAFF_PADDING_LEFT) / unit_scaling
                map_y = (np.tile(np.arange(staff_size[1], dtype=np.float32), (staff_size[0], 1)).T - staff_size[1] / 2) / unit_scaling + rho
                map_x, map_y = map_x.astype(np.float32), map_y.astype(np.float32)
                staff_image = cv2.remap(system_image, map_x, map_y, cv2.INTER_CUBIC, borderMode=cv2.BORDER_CONSTANT, borderValue=(255, 255, 255))
                # Encode staff image as PNG bytes for downstream predictors
                _, png_data = cv2.imencode('.png', staff_image)
                staff_bytes = png_data.tobytes()
                area['staff_images'].append({
                    'hash': None,
                    'image': staff_bytes,
                    # Position expressed in staff units (intervals), relative
                    # to the staff's middle line.
                    'position': {
                        'x': -STAFF_PADDING_LEFT / UNIT_SIZE,
                        'y': -STAFF_HEIGHT_UNITS / 2,
                        'width': staff_size[0] / UNIT_SIZE,
                        'height': staff_size[1] / UNIT_SIZE,
                    },
                })
        page_interval = self.interval * original_size[0] / RESIZE_WIDTH
        return {
            'sourceSize': {
                'width': original_size[0],
                'height': original_size[1],
            },
            'theta': self.theta,
            'interval': page_interval,
            # Scale detection coordinates back from canvas to original page size.
            'detection': _scale_detection(detection, original_size[0] / canvas_size[0]),
        }

    @staticmethod
    def measure_theta(heatmap):
        """Measure page rotation angle using Hough lines.

        heatmap: uint8 single-channel map (staff-box channel).
        Returns the mean deviation from horizontal in radians, or None when
        no near-horizontal lines are found.
        """
        edges = cv2.Canny(heatmap, 50, 150, apertureSize=3)
        # Search only near-horizontal lines (theta in [0.48pi, 0.52pi]) with
        # very fine angular resolution (pi/18000).
        lines = cv2.HoughLines(
            edges, 1, np.pi / 18000,
            round(heatmap.shape[1] * 0.4),
            min_theta=np.pi * 0.48,
            max_theta=np.pi * 0.52
        )
        if lines is None:
            return None
        avg_theta = sum(line[0][1] for line in lines) / len(lines)
        # Hough theta of pi/2 is a horizontal line; report deviation from it.
        return float(avg_theta - np.pi / 2)

    @staticmethod
    def measure_interval(heatmap):
        """Measure staff line interval using autocorrelation.

        heatmap: float single-channel map (HL channel).
        Returns the interval in pixels of the ORIGINAL heatmap (the internal
        UPSCALE factor is divided back out).
        """
        UPSCALE = 4
        width = heatmap.shape[1]
        # Shrink horizontally / stretch vertically to sharpen the vertical
        # periodicity of the staff lines.
        heatmap = cv2.resize(
            heatmap,
            (heatmap.shape[1] // UPSCALE, heatmap.shape[0] * UPSCALE),
            interpolation=cv2.INTER_LINEAR
        )
        # Plausible interval range, proportional to page width.
        interval_min = round(width * 0.002 * UPSCALE)
        interval_max = round(width * 0.025 * UPSCALE)
        brights = []
        # Autocorrelation: mean of the map multiplied by itself shifted by y.
        for y in range(interval_min, interval_max):
            m1, m2 = heatmap[y:], heatmap[:-y]
            p = np.multiply(m1, m2)
            brights.append(np.mean(p))
        # Subtract 2x interval to weaken harmonics
        brights = np.array([brights])
        brights2 = cv2.resize(brights, (brights.shape[1] * 2, 1))
        brights = brights.flatten()[interval_min:]
        brights2 = brights2.flatten()[:len(brights)]
        brights -= brights2 * 0.5
        # argmax index is offset by interval_min twice (range start + slice).
        return (interval_min * 2 + int(np.argmax(brights))) / UPSCALE
class LayoutService(TorchScriptPredictor):
    """Layout prediction service using TorchScript model."""

    # Default transform pipeline
    DEFAULT_TRANS = ['Mono', 'HWC2CHW']

    def __init__(self, model_path, device='cuda', trans=None, **kwargs):
        """
        model_path: path to the TorchScript layout model.
        device: torch device string (default 'cuda'; pass 'cpu' for CPU mode).
        trans: optional list of transform names; defaults to DEFAULT_TRANS.
        """
        super().__init__(model_path, device)
        self.composer = Composer(trans or self.DEFAULT_TRANS)

    def predict(self, streams, **kwargs):
        """
        Predict page layout from image streams (basic mode).
        streams: list of image byte buffers
        yields: layout JSON results
        """
        for stream in streams:
            image = array_from_image_stream(stream)
            if image is None:
                yield {'error': 'Invalid image'}
                continue
            image = np.expand_dims(image, 0)  # (1, H, W, C)
            image = normalize_image_dimension(image)
            batch, _ = self.composer(image, np.ones((1, 4, 4, 2)))
            batch = torch.from_numpy(batch).to(self.device)
            # Inference only: disable autograd so no graph is built
            # (fixes an inconsistency with predictDetection below).
            with torch.no_grad():
                output = self.run_inference(batch)
            output = output.cpu().numpy()
            hotmap = output[0]  # (C, H, W)
            yield PageLayout(hotmap).json()

    def predictDetection(self, streams):
        """
        Predict layout with full system/staff detection.
        streams: list of image byte buffers
        yields: detection results with sourceSize, theta, interval, detection.areas

        NOTE(review): unlike predict(), invalid streams are not handled here —
        a None from array_from_image_stream would raise on .shape; confirm
        callers guarantee decodable images.
        """
        images = [array_from_image_stream(stream) for stream in streams]
        # All pages share one batch height: the tallest aspect ratio, with the
        # height rounded up to a multiple of 4.
        ratio = max(img.shape[0] / img.shape[1] for img in images)
        height = int(RESIZE_WIDTH * ratio)
        height += -height % 4
        unified_images = [resize_page_image(img, (RESIZE_WIDTH, height)) for img in images]
        image_array = np.stack(unified_images, axis=0)
        batch, _ = self.composer(image_array, np.ones((1, 4, 4, 2)))
        batch = torch.from_numpy(batch).to(self.device)
        with torch.no_grad():
            output = self.run_inference(batch)
        output = output.cpu().numpy()
        for i, heatmap in enumerate(output):
            image = images[i]
            layout = PageLayout(heatmap)
            result = layout.detect(image, ratio)
            yield result