"""
Layout prediction service.

Detects page layout: staff boxes, lines, intervals, rotation angle.
Includes system/staff detection for the full prediction pipeline.
"""

import numpy as np
import torch
import cv2
import PIL.Image
import hashlib
import io
import logging

from predictors.torchscript_predictor import TorchScriptPredictor
from common.image_utils import (
    array_from_image_stream,
    resize_page_image,
    normalize_image_dimension,
    encode_image_base64,
)
from common.transform import Composer


# Width (px) every page is resized to before being fed to the layout model
# in LayoutService.predictDetection.
RESIZE_WIDTH = 600
# Minimum working-canvas width (px) for detection; narrower pages are
# upscaled by powers of two until they reach it.
CANVAS_WIDTH_MIN = 1024
# System bounding-box margins, as fractions of the canvas width.
SYSTEM_HEIGHT_ENLARGE = 0.02
SYSTEM_LEFT_ENLARGE = 0.03
SYSTEM_RIGHT_ENLARGE = 0.01
# Staff image geometry: STAFF_PADDING_LEFT is in output pixels; one staff
# interval is rendered as UNIT_SIZE output pixels, and a staff image is
# STAFF_HEIGHT_UNITS intervals tall.
STAFF_PADDING_LEFT = 32
STAFF_HEIGHT_UNITS = 24
UNIT_SIZE = 8

logger = logging.getLogger(__name__)


def _pad_or_crop_rows(array, rows):
    """Return ``array`` with exactly ``rows`` rows: zero-pad at the bottom or crop the excess."""
    if array.shape[0] > rows:
        return array[:rows]
    if array.shape[0] < rows:
        pad_width = ((0, rows - array.shape[0]),) + ((0, 0),) * (array.ndim - 1)
        return np.pad(array, pad_width, mode='constant')
    return array


def _detect_systems(image):
    """Detect musical systems (staff groups) from the max-channel heatmap.

    Args:
        image: 2-D uint8 map (the per-pixel max over the layout heatmap channels).

    Returns:
        ``{'areas': [...]}`` with one dict per system, sorted top to bottom, holding
        ``x``, ``y``, ``width``, ``height`` in pixels of ``image``. Boxes are widened
        by fixed left/right margins, and their vertical extent is grown towards the
        neighbouring systems by at least SYSTEM_HEIGHT_ENLARGE * width and at most
        8x that amount.
    """
    height, width = image.shape
    blur = cv2.GaussianBlur(image, (5, 5), 0)
    thresh = cv2.adaptiveThreshold(
        blur, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C, cv2.THRESH_BINARY, 201, -40
    )
    contours, _ = cv2.findContours(thresh, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

    margin_left = SYSTEM_LEFT_ENLARGE * width
    margin_right = SYSTEM_RIGHT_ENLARGE * width

    areas = []
    for contour in contours:
        x, y, w, h = cv2.boundingRect(contour)
        # Both extents are normalised by the page *width*.
        rel_w, rel_h = w / width, h / width
        # Only wide-and-not-too-thin or moderately-wide-and-tall blobs are kept.
        if (rel_w > 0.6 and rel_h > 0.02) or (rel_w > 0.12 and rel_h > 0.2):
            left = max(x - margin_left, 0)
            right = min(x + w + margin_right, width)
            areas.append({'x': left, 'y': y, 'width': right - left, 'height': h})

    areas.sort(key=lambda a: a['y'])

    # Enlarge heights to include surrounding space. All new bounds are computed
    # from the un-enlarged boxes before any box is updated.
    margin_y = SYSTEM_HEIGHT_ENLARGE * width
    margin_y_max = margin_y * 8

    last_margin = 0
    bounds = []
    for i, area in enumerate(areas):
        top = area['y']
        bottom = top + area['height']

        if i > 0:
            prev = areas[i - 1]
            # Grow upwards towards the previous system's bottom, capped at margin_y_max.
            last_margin = max(last_margin, prev['y'] + prev['height'], top - margin_y_max)
            top = max(0, min(top - margin_y, last_margin))
        else:
            top = min(top, max(margin_y, top - margin_y_max))

        if i < len(areas) - 1:
            nxt = areas[i + 1]
            # Grow downwards towards the next system's top, capped at margin_y_max.
            bottom = min(height, max(bottom + margin_y, nxt['y']), bottom + margin_y_max)
        else:
            bottom = min(height, bottom + margin_y_max)

        bounds.append((top, bottom))

    for area, (top, bottom) in zip(areas, bounds):
        area['y'] = top
        area['height'] = bottom - top

    return {'areas': areas}


def _detect_staves_from_hbl(HB, HL, interval):
    """Detect individual staves within a system.

    Args:
        HB: uint8 staff-box heatmap of the system region (binarised here with Otsu).
        HL: uint8 horizontal-line heatmap of the same region.
        interval: expected staff-line interval, in pixels of HB/HL.

    Returns:
        ``{'interval', 'phi1', 'phi2', 'middleRhos'}`` where ``phi1``/``phi2`` are the
        leftmost left edge / rightmost right edge of the detected staff boxes and
        ``middleRhos`` holds, per staff, the y of its middle (3rd of 5) line, found
        by matching a 5-tap comb kernel against the line heatmap. When no staff box
        is found, ``{'reason': ...}`` is returned instead.
    """
    _, HB = cv2.threshold(HB, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
    contours, _ = cv2.findContours(HB, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

    height, width = HB.shape

    staff_height_min = interval * 3
    staff_width_min = max(width * 0.6, width - interval * 12)
    upscale = 4
    up_interval = interval * upscale

    rects = sorted(
        (
            rect for rect in map(cv2.boundingRect, contours)
            if rect[2] > staff_width_min and rect[3] > staff_height_min
        ),
        key=lambda rect: rect[1],
    )

    # Merge vertically overlapping rectangles into one box per staff. The wider
    # rectangle decides the vertical extent; the horizontal extent is the union.
    boxes = []
    for x, y, w, h in rects:
        hit = next(
            (
                i for i, (_, by, _, bh) in enumerate(boxes)
                if (y + h / 2) - (by + bh / 2) < (h + bh) / 2
            ),
            -1,
        )
        if hit < 0:
            boxes.append((x, y, w, h))
            continue

        bx, by, bw, bh = boxes[hit]
        left = min(x, bx)
        right = max(x + w, bx + bw)
        if w > bw:
            boxes[hit] = (left, y, right - left, h)
        else:
            boxes[hit] = (left, by, right - left, bh)

    if not boxes:
        return {'reason': 'no block rectangles detected'}

    phi1 = min(bx for bx, _, _, _ in boxes)
    phi2 = max(bx + bw for bx, _, bw, _ in boxes)

    # Comb kernel with 5 taps spaced one (upscaled) interval apart: one per staff line.
    tap = max(1, round(up_interval))
    kernel = np.zeros(tap * 4 + 1)
    kernel[::tap] = 1

    middle_rhos = []
    for bx, by, bw, bh in boxes:
        x = max(bx, 0)
        y = max(round(by - interval), 0)
        roi_w = min(bw, width - x)
        roi_h = min(round(bh + interval + by - y), height - y)

        staff_lines = HL[y:y + roi_h, x:x + roi_w]
        # Collapse the ROI to a single column, upscaled vertically for sub-pixel precision.
        column = cv2.resize(
            staff_lines,
            (1, staff_lines.shape[0] * upscale),
            interpolation=cv2.INTER_LINEAR,
        ).flatten()

        response = np.convolve(column, kernel)
        # In full-mode convolution the peak index marks the lowest of the 5 lines;
        # step back two intervals to reach the middle line.
        middle_y = int(np.argmax(response)) - round(up_interval * 2)
        middle_rhos.append(y + middle_y / upscale)

    return {
        'interval': interval,
        'phi1': phi1,
        'phi2': phi2,
        'middleRhos': middle_rhos,
    }


def _scale_staves(staves, scale):
    """Scale a staves dict by ``scale``; return None when it carries no staff positions."""
    if not staves or staves.get('middleRhos') is None:
        return None
    return {
        'interval': staves['interval'] * scale,
        'phi1': staves['phi1'] * scale,
        'phi2': staves['phi2'] * scale,
        'middleRhos': [rho * scale for rho in staves['middleRhos']],
    }


def _scale_detection(detection, scale):
    """Return a copy of ``detection`` with all area/staff coordinates multiplied by ``scale``."""
    return {
        'areas': [
            {
                'x': area['x'] * scale,
                'y': area['y'] * scale,
                'width': area['width'] * scale,
                'height': area['height'] * scale,
                'staff_images': area.get('staff_images'),
                'staves': _scale_staves(area.get('staves'), scale),
            }
            for area in detection['areas']
        ]
    }


def _render_staff_images(system_image, middle_rhos, interval):
    """Render one PNG-encoded, size-normalised image per staff of a system.

    Each staff is resampled so that one staff interval spans UNIT_SIZE output
    pixels, the staff's middle line sits at the vertical centre, and the first
    STAFF_PADDING_LEFT output columns (left of the system edge) are filled with
    the white border value.

    Args:
        system_image: (H, W, C) crop of the canvas covering one system.
        middle_rhos: middle-line y coordinates, relative to ``system_image``.
        interval: staff-line interval in pixels of ``system_image``.
    """
    unit_scaling = UNIT_SIZE / interval
    staff_width = round(system_image.shape[1] * unit_scaling) + STAFF_PADDING_LEFT
    staff_height = STAFF_HEIGHT_UNITS * UNIT_SIZE

    # Source-pixel sampling grid; only the vertical offset (rho) varies per staff.
    src_x = (np.arange(staff_width, dtype=np.float32) - STAFF_PADDING_LEFT) / unit_scaling
    src_dy = (np.arange(staff_height, dtype=np.float32) - staff_height / 2) / unit_scaling
    map_x = np.tile(src_x, (staff_height, 1)).astype(np.float32)

    staff_images = []
    for rho in middle_rhos:
        map_y = np.tile((src_dy + rho)[:, np.newaxis], (1, staff_width)).astype(np.float32)
        staff_image = cv2.remap(
            system_image, map_x, map_y, cv2.INTER_CUBIC,
            borderMode=cv2.BORDER_CONSTANT, borderValue=(255, 255, 255),
        )

        # Encode staff image as PNG bytes for downstream predictors
        ok, png_data = cv2.imencode('.png', staff_image)
        if not ok:
            logger.warning('PNG encoding failed for staff at rho=%s', rho)

        staff_images.append({
            # NOTE(review): 'hash' is always None here (hashlib is imported but
            # unused) — confirm whether a downstream caller fills it in.
            'hash': None,
            'image': png_data.tobytes(),
            'position': {
                'x': -STAFF_PADDING_LEFT / UNIT_SIZE,
                'y': -STAFF_HEIGHT_UNITS / 2,
                'width': staff_width / UNIT_SIZE,
                'height': staff_height / UNIT_SIZE,
            },
        })
    return staff_images


class PageLayout:
    """Page parameters (rotation angle, staff interval) derived from a layout heatmap."""

    def __init__(self, heatmap):
        """
        Args:
            heatmap: (3, H, W) float map with channels [VL, StaffBox, HL]. Values
                are clipped to [0, 1] before conversion to 8 bits.
        """
        self.interval = self.measure_interval(heatmap[2])

        # Clip first: float -> uint8 casts of out-of-range values wrap around.
        heatmap_uint8 = np.uint8(np.clip(heatmap, 0, 1) * 255)
        heatmap_uint8 = np.ascontiguousarray(np.moveaxis(heatmap_uint8, 0, -1))
        self.image = PIL.Image.fromarray(heatmap_uint8, 'RGB')
        self.heatmap = heatmap_uint8

        self.theta = self.measure_theta(heatmap_uint8[:, :, 1])

    def json(self):
        return {
            'image': encode_image_base64(self.image),
            'theta': self.theta,
            'interval': self.interval,
        }

    def detect(self, image, ratio):
        """Run full detection: systems, staves and per-staff images.

        Args:
            image: source page image, (H, W) or (H, W, C).
            ratio: height/width ratio the page is padded/cropped to before detection
                (predictDetection passes the batch-wide maximum ratio).

        Returns:
            dict with 'sourceSize', 'theta', 'interval' (in source pixels) and
            'detection' (areas in source-pixel coordinates, or None when no rotation
            angle could be measured).
        """
        if image.ndim == 2:
            image = np.expand_dims(image, -1)

        source_width, source_height = image.shape[1], image.shape[0]
        # Heatmap pixels map to page pixels by the ratio of widths; the heatmap is
        # resized with its aspect ratio preserved, so heights scale identically.
        heatmap_width = self.heatmap.shape[1]

        result = {
            'sourceSize': {
                'width': source_width,
                'height': source_height,
            },
            'theta': self.theta,
            'interval': self.interval * source_width / heatmap_width,
            'detection': None,
        }
        if self.theta is None:
            return result

        aligned_height = int(source_width * ratio)
        image = _pad_or_crop_rows(image, aligned_height)

        canvas_size = (source_width, aligned_height)
        while 0 < canvas_size[0] < CANVAS_WIDTH_MIN:
            canvas_size = (canvas_size[0] * 2, canvas_size[1] * 2)
        if canvas_size[0] > source_width:
            image = cv2.resize(image, canvas_size)

        rot_mat = cv2.getRotationMatrix2D(
            (canvas_size[0] / 2, canvas_size[1] / 2), self.theta * 180 / np.pi, 1
        )
        image = cv2.warpAffine(image, rot_mat, canvas_size, flags=cv2.INTER_CUBIC)
        if image.ndim < 3:
            # OpenCV drops a single-channel axis; restore it for uniform slicing.
            image = np.expand_dims(image, -1)

        heatmap_height = round(canvas_size[0] * self.heatmap.shape[0] / heatmap_width)
        heatmap = cv2.resize(
            self.heatmap, (canvas_size[0], heatmap_height), interpolation=cv2.INTER_CUBIC
        )
        heatmap = _pad_or_crop_rows(heatmap, canvas_size[1])
        heatmap = cv2.warpAffine(heatmap, rot_mat, canvas_size, flags=cv2.INTER_LINEAR)

        staves_map = heatmap[:, :, 1]
        lines_map = heatmap[:, :, 2]
        detection = _detect_systems(heatmap.max(axis=2))

        canvas_interval = self.interval * canvas_size[0] / heatmap_width

        for area in detection['areas']:
            left, right, top, bottom = (
                round(v) for v in (
                    area['x'], area['x'] + area['width'],
                    area['y'], area['y'] + area['height'],
                )
            )

            staves = _detect_staves_from_hbl(
                staves_map[top:bottom, left:right],
                lines_map[top:bottom, left:right],
                canvas_interval,
            )
            area['staves'] = staves
            if not staves or staves.get('middleRhos') is None:
                continue

            area['staff_images'] = _render_staff_images(
                image[top:bottom, left:right, :], staves['middleRhos'], staves['interval']
            )

        result['detection'] = _scale_detection(detection, source_width / canvas_size[0])
        return result

    @staticmethod
    def measure_theta(heatmap):
        """Estimate page rotation (radians) from near-horizontal Hough lines.

        Only lines within +/-0.02*pi of horizontal that collect votes from at least
        40% of the width are considered. Returns None when none are found.
        """
        edges = cv2.Canny(heatmap, 50, 150, apertureSize=3)
        lines = cv2.HoughLines(
            edges, 1, np.pi / 18000, round(heatmap.shape[1] * 0.4),
            min_theta=np.pi * 0.48, max_theta=np.pi * 0.52,
        )
        if lines is None:
            return None

        return float(np.mean(lines[:, 0, 1]) - np.pi / 2)

    @staticmethod
    def measure_interval(heatmap):
        """Estimate the staff-line interval, in heatmap rows, by vertical autocorrelation.

        The line heatmap is upscaled 4x vertically (and downscaled 4x horizontally)
        and multiplied with itself shifted by every row offset between 0.002 and
        0.025 of the width (in upscaled rows). The best offset is picked after
        suppressing harmonics.
        """
        upscale = 4

        width = heatmap.shape[1]
        heatmap = cv2.resize(
            heatmap,
            (width // upscale, heatmap.shape[0] * upscale),
            interpolation=cv2.INTER_LINEAR,
        )

        # A shift of 0 would correlate against an empty slice (heatmap[:-0]).
        interval_min = max(1, round(width * 0.002 * upscale))
        interval_max = round(width * 0.025 * upscale)

        # brights[0, k] is the mean self-product at shift interval_min + k.
        brights = np.array([
            np.mean(heatmap[shift:] * heatmap[:-shift])
            for shift in range(interval_min, interval_max)
        ])[np.newaxis, :]

        # halves[k] approximates the score at shift interval_min + k / 2. At a
        # shift of 2T (first harmonic of the true interval T) that half-shift
        # score is the peak at T, so subtracting it weakens the harmonic.
        halves = cv2.resize(brights, (brights.shape[1] * 2, 1)).flatten()
        brights = brights.flatten()[interval_min:]  # shifts from 2 * interval_min on
        brights = brights - 0.5 * halves[:len(brights)]

        return (interval_min * 2 + int(np.argmax(brights))) / upscale


class LayoutService(TorchScriptPredictor):
    """Layout prediction service using a TorchScript model."""

    # Default transform pipeline
    DEFAULT_TRANS = ['Mono', 'HWC2CHW']

    def __init__(self, model_path, device='cuda', trans=None, **kwargs):
        super().__init__(model_path, device)
        self.composer = Composer(trans or self.DEFAULT_TRANS)

    def _infer(self, images):
        """Run the model on an (N, H, W, C) image batch and return its output as numpy.

        Output is indexed per image as (C, H, W) by the callers.
        """
        # NOTE(review): Composer is called with a dummy second argument, as in the
        # original code — confirm what it expects.
        batch, _ = self.composer(images, np.ones((1, 4, 4, 2)))
        batch = torch.from_numpy(batch).to(self.device)
        with torch.no_grad():
            output = self.run_inference(batch)
        return output.cpu().numpy()

    def predict(self, streams, **kwargs):
        """
        Predict page layout from image streams (basic mode).

        streams: list of image byte buffers
        yields: layout JSON results, or {'error': 'Invalid image'} per undecodable stream
        """
        for stream in streams:
            image = array_from_image_stream(stream)
            if image is None:
                yield {'error': 'Invalid image'}
                continue

            image = normalize_image_dimension(np.expand_dims(image, 0))  # (1, H, W, C)
            heatmap = self._infer(image)[0]  # (C, H, W)
            yield PageLayout(heatmap).json()

    def predictDetection(self, streams):
        """
        Predict layout with full system/staff detection.

        streams: list of image byte buffers
        yields: per stream, a detection result with sourceSize, theta, interval,
            detection.areas — or {'error': 'Invalid image'} for an undecodable stream.
        """
        images = [array_from_image_stream(stream) for stream in streams]
        valid_images = [img for img in images if img is not None]

        ratio = None
        heatmaps = iter(())
        if valid_images:
            # All pages share one input size: fixed width, height from the tallest ratio.
            ratio = max(img.shape[0] / img.shape[1] for img in valid_images)
            height = int(RESIZE_WIDTH * ratio)
            height += -height % 4  # round up to a multiple of 4

            batch = np.stack(
                [resize_page_image(img, (RESIZE_WIDTH, height)) for img in valid_images],
                axis=0,
            )
            heatmaps = iter(self._infer(batch))

        for image in images:
            if image is None:
                yield {'error': 'Invalid image'}
                continue
            yield PageLayout(next(heatmaps)).detect(image, ratio)