Spaces:
Running
Running
| """ | |
| Layout prediction service. | |
| Detects page layout: staff boxes, lines, intervals, rotation angle. | |
| Includes system/staff detection for the full prediction pipeline. | |
| """ | |
| import numpy as np | |
| import torch | |
| import cv2 | |
| import PIL.Image | |
| import hashlib | |
| import io | |
| import logging | |
| from predictors.torchscript_predictor import TorchScriptPredictor | |
| from common.image_utils import ( | |
| array_from_image_stream, resize_page_image, normalize_image_dimension, | |
| encode_image_base64 | |
| ) | |
| from common.transform import Composer | |
# Width (px) pages are resized to before model inference; measured intervals
# are expressed relative to this width and rescaled for the source page.
RESIZE_WIDTH = 600
# Minimum working-canvas width for detection; smaller pages are doubled
# repeatedly until they reach it.
CANVAS_WIDTH_MIN = 1024
# System-box enlargement margins, as fractions of page width.
SYSTEM_HEIGHT_ENLARGE = 0.02
SYSTEM_LEFT_ENLARGE = 0.03
SYSTEM_RIGHT_ENLARGE = 0.01
# Left padding (output px) added to extracted staff images.
STAFF_PADDING_LEFT = 32
# Staff image height in units; one unit is UNIT_SIZE output pixels.
STAFF_HEIGHT_UNITS = 24
UNIT_SIZE = 8
def _detect_systems(image):
    """Detect musical systems (staff groups) from max-channel heatmap.

    image: single-channel heatmap (H, W).
    Returns {'areas': [...]} where each area has x/y/width/height in heatmap
    pixel coordinates, sorted top-to-bottom and vertically enlarged to
    include surrounding space.
    """
    height, width = image.shape
    # Smooth then binarize with a large adaptive window so whole systems
    # come out as connected blobs.
    blur = cv2.GaussianBlur(image, (5, 5), 0)
    thresh = cv2.adaptiveThreshold(blur, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C, cv2.THRESH_BINARY, 201, -40)
    contours, _ = cv2.findContours(thresh, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    marginLeft = SYSTEM_LEFT_ENLARGE * width
    marginRight = SYSTEM_RIGHT_ENLARGE * width
    areas = []
    for contour in contours:
        x, y, w, h = cv2.boundingRect(contour)
        # Both ratios are normalized by page WIDTH.
        # NOTE(review): rh divides height by width, not by page height —
        # presumably intentional so thresholds are resolution-independent;
        # confirm against the training data conventions.
        rw = w / width
        rh = h / width
        # Keep wide-and-shallow blobs (typical systems) or narrower-but-tall ones.
        if (rw > 0.6 and rh > 0.02) or (rw > 0.12 and rh > 0.2):
            left = max(x - marginLeft, 0)
            right = min(x + w + marginRight, width)
            areas.append({'x': left, 'y': y, 'width': right - left, 'height': h})
    areas.sort(key=lambda a: a['y'])
    # Enlarge heights to include surrounding space
    marginY = SYSTEM_HEIGHT_ENLARGE * width
    marginYMax = marginY * 8
    # ctx carries a running lower bound ('lastMargin') across sequential
    # enlarge() calls; the comprehension below evaluates in list order, so
    # each call sees the state left by the previous one.
    ctx = {'lastMargin': 0}
    def enlarge(i, area, ctx):
        # Expand area i vertically: by at least marginY, never past the
        # neighbouring areas, and by at most marginYMax in total.
        top = area['y']
        bottom = top + area['height']
        if i > 0:
            lastArea = areas[i - 1]
            ctx['lastMargin'] = max(ctx['lastMargin'], lastArea['y'] + lastArea['height'], top - marginYMax)
            top = max(0, min(top - marginY, ctx['lastMargin']))
        else:
            top = min(top, max(marginY, top - marginYMax))
        if i < len(areas) - 1:
            nextArea = areas[i + 1]
            # Grow down toward the next area's top, capped by marginYMax.
            bottom = min(height, max(bottom + marginY, nextArea['y']), bottom + marginYMax)
        else:
            bottom = min(height, bottom + marginYMax)
        return {'top': top, 'bottom': bottom}
    enlarges = [enlarge(i, area, ctx) for i, area in enumerate(areas)]
    for i, area in enumerate(areas):
        area['y'] = enlarges[i]['top']
        area['height'] = enlarges[i]['bottom'] - enlarges[i]['top']
    return {'areas': areas}
def _detect_staves_from_hbl(HB, HL, interval):
    """Detect individual staves within a system.

    HB: staff-box heatmap channel (H, W).
    HL: horizontal-line heatmap channel (H, W).
    interval: expected staff-line spacing in pixels of these maps.
    Returns {'interval', 'phi1', 'phi2', 'middleRhos'} on success, or
    {'reason': ...} when no staff rectangles are found.
    """
    # Otsu-binarize the staff-box map and take outer contours as candidates.
    _, HB = cv2.threshold(HB, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
    contours, _ = cv2.findContours(HB, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    height, width = HB.shape
    STAFF_HEIGHT_MIN = interval * 3
    STAFF_WIDTH_MIN = max(width * 0.6, width - interval * 12)
    UPSCALE = 4
    upInterval = interval * UPSCALE
    rects = map(cv2.boundingRect, contours)
    rects = filter(lambda rect: rect[2] > STAFF_WIDTH_MIN and rect[3] > STAFF_HEIGHT_MIN, rects)
    rects = sorted(rects, key=lambda rect: rect[1])
    # Merge rectangles whose vertical centers overlap into one per staff.
    preRects = []
    for rect in rects:
        x, y, w, h = rect
        ri = next((i for i, rc in enumerate(preRects)
                   if (y + h / 2) - (rc[1] + rc[3] / 2) < (h + rc[3]) / 2), -1)
        if ri < 0:
            preRects.append(rect)
        else:
            # BUGFIX: the original mutated a list() copy of the stored tuple,
            # so the narrow-rect branch was a no-op, and the wide-rect branch
            # computed width relative to x instead of the merged left edge.
            rc = preRects[ri]
            left = min(x, rc[0])
            right = max(x + w, rc[0] + rc[2])
            if w > rc[2]:
                # Wider incoming rect: adopt its vertical extent.
                preRects[ri] = (left, y, right - left, h)
            else:
                preRects[ri] = (left, rc[1], right - left, rc[3])
    if len(preRects) == 0:
        return {'reason': 'no block rectangles detected'}
    phi1 = min(rc[0] for rc in preRects)
    # NOTE(review): phi2 takes the MINIMUM right edge across staves —
    # confirm this conservative choice is intended (vs. max).
    phi2 = min(rc[0] + rc[2] for rc in preRects)
    middleRhos = []
    for rx, ry, rw, rh in preRects:
        # ROI slightly taller than the rect (one interval of slack), clamped
        # to the map bounds.
        x = max(rx, 0)
        y = max(round(ry - interval), 0)
        roi = (x, y, min(rw, width - x), min(round(rh + interval + ry - y), height - y))
        staffLines = HL[roi[1]:roi[1] + roi[3], roi[0]:roi[0] + roi[2]]
        # Collapse the ROI to a single upscaled column profile of line
        # brightness.  BUGFIX: interpolation must be passed by keyword —
        # positionally it lands in the dst/fx/fy slots of cv2.resize.
        hlineColumn = cv2.resize(
            staffLines, (1, staffLines.shape[0] * UPSCALE),
            interpolation=cv2.INTER_LINEAR
        ).flatten()
        i1 = round(upInterval)
        # Matched filter: 5 unit spikes spaced one interval apart, i.e. the
        # five lines of a staff.
        kernel = np.zeros(i1 * 4 + 1)
        kernel[::i1] = 1
        convolutionLine = np.convolve(hlineColumn, kernel)
        # Peak of the convolution marks the staff center (offset by half the
        # kernel, i.e. two intervals).
        middleY = int(np.argmax(convolutionLine)) - round(upInterval * 2)
        middleRhos.append(y + middleY / UPSCALE)
    return {
        'interval': interval,
        'phi1': phi1,
        'phi2': phi2,
        'middleRhos': middleRhos,
    }
| def _scale_detection(detection, scale): | |
| """Scale detection coordinates by given factor.""" | |
| areas = [{ | |
| 'x': a['x'] * scale, | |
| 'y': a['y'] * scale, | |
| 'width': a['width'] * scale, | |
| 'height': a['height'] * scale, | |
| 'staff_images': a.get('staff_images'), | |
| 'staves': { | |
| 'interval': a['staves']['interval'] * scale, | |
| 'phi1': a['staves']['phi1'] * scale, | |
| 'phi2': a['staves']['phi2'] * scale, | |
| 'middleRhos': [rho * scale for rho in a['staves']['middleRhos']], | |
| } if (a.get('staves') and a['staves'].get('middleRhos') is not None) else None, | |
| } for a in detection['areas']] | |
| return {'areas': areas} | |
class PageLayout:
    """Process a layout heatmap to extract page parameters.

    Derives the staff-line interval and page rotation angle from the model
    heatmap, and can run full system/staff detection on the source image.
    """

    def __init__(self, heatmap):
        """
        heatmap: (3, H, W) float array in [0, 1] - channels: [VL, StaffBox, HL]
        """
        lines_map = heatmap[2]
        # Staff-line interval measured on the RESIZE_WIDTH-scaled heatmap.
        self.interval = self.measure_interval(lines_map)
        heatmap_uint8 = np.uint8(heatmap * 255)
        heatmap_uint8 = np.moveaxis(heatmap_uint8, 0, -1)  # -> (H, W, 3)
        self.image = PIL.Image.fromarray(heatmap_uint8, 'RGB')
        self.heatmap = heatmap_uint8
        staves_map = heatmap_uint8[:, :, 1]
        # Rotation angle (radians), or None if no usable lines were found.
        self.theta = self.measure_theta(staves_map)

    def json(self):
        """Return basic layout parameters with the heatmap image encoded."""
        return {
            'image': encode_image_base64(self.image),
            'theta': self.theta,
            'interval': self.interval,
        }

    def detect(self, image, ratio):
        """Full detection: systems, staves, staff images.

        image: source page as (H, W, C) array.
        ratio: unified height/width ratio the page batch was aligned to.
        Returns a dict with sourceSize, theta, interval and the detection
        scaled back into source-image coordinates (detection is None when
        no rotation angle could be measured).
        """
        if self.theta is None:
            # Without a measured rotation the detection pipeline cannot run.
            return {
                'theta': self.theta,
                'interval': self.interval * image.shape[1] / RESIZE_WIDTH,
                'detection': None,
            }
        original_size = (image.shape[1], image.shape[0])
        # Pad or crop the page vertically to the unified aspect ratio.
        aligned_height = int(image.shape[1] * ratio)
        if image.shape[0] < aligned_height:
            image = np.pad(image, ((0, aligned_height - image.shape[0]), (0, 0), (0, 0)), mode='constant')
        elif image.shape[0] > aligned_height:
            image = image[:aligned_height]
        # Double the canvas until it is at least CANVAS_WIDTH_MIN wide.
        canvas_size = (original_size[0], aligned_height)
        while canvas_size[0] < CANVAS_WIDTH_MIN:
            canvas_size = (canvas_size[0] * 2, canvas_size[1] * 2)
        if canvas_size[0] > original_size[0]:
            image = cv2.resize(image, canvas_size)
        # De-rotate both page image and heatmap by the measured angle.
        rot_mat = cv2.getRotationMatrix2D((canvas_size[0] / 2, canvas_size[1] / 2), self.theta * 180 / np.pi, 1)
        image = cv2.warpAffine(image, rot_mat, canvas_size, flags=cv2.INTER_CUBIC)
        if len(image.shape) < 3:
            # warpAffine drops a singleton channel axis; restore it.
            image = np.expand_dims(image, -1)
        heatmap = cv2.resize(
            self.heatmap,
            (canvas_size[0], round(canvas_size[0] * self.heatmap.shape[0] / self.heatmap.shape[1])),
            interpolation=cv2.INTER_CUBIC)
        if heatmap.shape[0] > canvas_size[1]:
            heatmap = heatmap[:canvas_size[1]]
        elif heatmap.shape[0] < canvas_size[1]:
            heatmap = np.pad(heatmap, ((0, canvas_size[1] - heatmap.shape[0]), (0, 0), (0, 0)), mode='constant')
        heatmap = cv2.warpAffine(heatmap, rot_mat, canvas_size, flags=cv2.INTER_LINEAR)
        HB = heatmap[:, :, 1]  # staff-box channel
        HL = heatmap[:, :, 2]  # horizontal-line channel
        block = heatmap.max(axis=2)
        detection = _detect_systems(block)
        canvas_interval = self.interval * canvas_size[0] / RESIZE_WIDTH
        for area in detection['areas']:
            l, r, t, b = map(round, (area['x'], area['x'] + area['width'], area['y'], area['y'] + area['height']))
            system_image = image[t:b, l:r, :]
            hb = HB[t:b, l:r]
            hl = HL[t:b, l:r]
            area['staves'] = _detect_staves_from_hbl(hb, hl, canvas_interval)
            if not area.get('staves') or area['staves'].get('middleRhos') is None:
                continue
            area['staff_images'] = []
            interval = area['staves']['interval']
            # One staff-line interval maps to UNIT_SIZE output pixels.
            unit_scaling = UNIT_SIZE / interval
            staff_width = round(system_image.shape[1] * unit_scaling) + STAFF_PADDING_LEFT
            staff_size = (staff_width, STAFF_HEIGHT_UNITS * UNIT_SIZE)
            for rho in area['staves']['middleRhos']:
                # Sampling grids mapping each output pixel back into the
                # system image, centered vertically on the staff middle (rho).
                map_x = (np.tile(np.arange(staff_size[0], dtype=np.float32), (staff_size[1], 1)) - STAFF_PADDING_LEFT) / unit_scaling
                map_y = (np.tile(np.arange(staff_size[1], dtype=np.float32), (staff_size[0], 1)).T - staff_size[1] / 2) / unit_scaling + rho
                map_x, map_y = map_x.astype(np.float32), map_y.astype(np.float32)
                staff_image = cv2.remap(system_image, map_x, map_y, cv2.INTER_CUBIC, borderMode=cv2.BORDER_CONSTANT, borderValue=(255, 255, 255))
                # Encode staff image as PNG bytes for downstream predictors
                _, png_data = cv2.imencode('.png', staff_image)
                staff_bytes = png_data.tobytes()
                area['staff_images'].append({
                    # NOTE(review): hash is never populated even though
                    # hashlib is imported at module level — confirm whether
                    # downstream consumers expect a content hash here.
                    'hash': None,
                    'image': staff_bytes,
                    'position': {
                        # Position/size in unit coordinates of the staff.
                        'x': -STAFF_PADDING_LEFT / UNIT_SIZE,
                        'y': -STAFF_HEIGHT_UNITS / 2,
                        'width': staff_size[0] / UNIT_SIZE,
                        'height': staff_size[1] / UNIT_SIZE,
                    },
                })
        page_interval = self.interval * original_size[0] / RESIZE_WIDTH
        return {
            'sourceSize': {
                'width': original_size[0],
                'height': original_size[1],
            },
            'theta': self.theta,
            'interval': page_interval,
            # Scale canvas coordinates back to source-image coordinates.
            'detection': _scale_detection(detection, original_size[0] / canvas_size[0]),
        }

    # BUGFIX: __init__ calls self.measure_theta(map)/self.measure_interval(map);
    # without @staticmethod the instance would be bound as `heatmap` and the
    # call would raise TypeError (two arguments for a one-parameter function).
    @staticmethod
    def measure_theta(heatmap):
        """Measure page rotation angle (radians) using Hough lines.

        Returns None when no near-horizontal lines are found.
        """
        edges = cv2.Canny(heatmap, 50, 150, apertureSize=3)
        # Restrict to near-horizontal lines: theta in [0.48pi, 0.52pi],
        # with 0.01-degree angular resolution.
        lines = cv2.HoughLines(
            edges, 1, np.pi / 18000,
            round(heatmap.shape[1] * 0.4),
            min_theta=np.pi * 0.48,
            max_theta=np.pi * 0.52
        )
        if lines is None:
            return None
        avg_theta = sum(line[0][1] for line in lines) / len(lines)
        # pi/2 is a perfectly horizontal line, so the skew is the offset.
        return float(avg_theta - np.pi / 2)

    @staticmethod
    def measure_interval(heatmap):
        """Measure staff line interval using autocorrelation over y-shifts."""
        UPSCALE = 4
        width = heatmap.shape[1]
        # Shrink horizontally (noise reduction), upscale vertically
        # (sub-pixel resolution for the shift search).
        heatmap = cv2.resize(
            heatmap,
            (heatmap.shape[1] // UPSCALE, heatmap.shape[0] * UPSCALE),
            interpolation=cv2.INTER_LINEAR
        )
        interval_min = round(width * 0.002 * UPSCALE)
        interval_max = round(width * 0.025 * UPSCALE)
        # Autocorrelation: mean brightness of the map multiplied with itself
        # shifted by y rows, for each candidate interval y.
        brights = []
        for y in range(interval_min, interval_max):
            m1, m2 = heatmap[y:], heatmap[:-y]
            p = np.multiply(m1, m2)
            brights.append(np.mean(p))
        # Subtract 2x interval to weaken harmonics
        brights = np.array([brights])
        brights2 = cv2.resize(brights, (brights.shape[1] * 2, 1))
        brights = brights.flatten()[interval_min:]
        brights2 = brights2.flatten()[:len(brights)]
        brights -= brights2 * 0.5
        # argmax index is relative to 2*interval_min; undo the upscale.
        return (interval_min * 2 + int(np.argmax(brights))) / UPSCALE
class LayoutService(TorchScriptPredictor):
    """Layout prediction service using a TorchScript model.

    Produces page-layout heatmaps (predict) and optionally full
    system/staff detection (predictDetection) via PageLayout.
    """

    # Default transform pipeline applied to input batches:
    # 'Mono' (grayscale) then 'HWC2CHW' (channel-first layout).
    DEFAULT_TRANS = ['Mono', 'HWC2CHW']

    def __init__(self, model_path, device='cuda', trans=None, **kwargs):
        """
        model_path: path to the TorchScript model file.
        device: torch device string.
        trans: optional list of transform names overriding DEFAULT_TRANS.
        """
        super().__init__(model_path, device)
        self.composer = Composer(trans or self.DEFAULT_TRANS)

    def predict(self, streams, **kwargs):
        """
        Predict page layout from image streams (basic mode).
        streams: list of image byte buffers
        yields: layout JSON results ({'error': ...} for unreadable streams)
        """
        for stream in streams:
            image = array_from_image_stream(stream)
            if image is None:
                yield {'error': 'Invalid image'}
                continue
            image = np.expand_dims(image, 0)  # (1, H, W, C)
            image = normalize_image_dimension(image)
            batch, _ = self.composer(image, np.ones((1, 4, 4, 2)))
            batch = torch.from_numpy(batch).to(self.device)
            # Inference only: disable autograd bookkeeping, consistent with
            # predictDetection.
            with torch.no_grad():
                output = self.run_inference(batch)
            output = output.cpu().numpy()
            hotmap = output[0]  # (C, H, W)
            yield PageLayout(hotmap).json()

    def predictDetection(self, streams):
        """
        Predict layout with full system/staff detection.
        streams: list of image byte buffers
        yields: detection results with sourceSize, theta, interval, detection.areas
        Raises ValueError if any stream cannot be decoded (all pages are
        batched together, so one bad stream cannot be skipped per-item).
        """
        images = [array_from_image_stream(stream) for stream in streams]
        # Fail fast with a clear error instead of an AttributeError on
        # None.shape below.
        if any(img is None for img in images):
            raise ValueError('Invalid image stream in batch')
        # Unify all pages to the tallest aspect ratio so they stack into one
        # batch; height is rounded up to a multiple of 4.
        ratio = max(img.shape[0] / img.shape[1] for img in images)
        height = int(RESIZE_WIDTH * ratio)
        height += -height % 4
        unified_images = [resize_page_image(img, (RESIZE_WIDTH, height)) for img in images]
        image_array = np.stack(unified_images, axis=0)
        batch, _ = self.composer(image_array, np.ones((1, 4, 4, 2)))
        batch = torch.from_numpy(batch).to(self.device)
        with torch.no_grad():
            output = self.run_inference(batch)
        output = output.cpu().numpy()
        for image, heatmap in zip(images, output):
            layout = PageLayout(heatmap)
            yield layout.detect(image, ratio)