"""
Semantic prediction service.

Detects and classifies musical symbols (notes, rests, clefs, etc.).
Supports both single-model and multi-model cluster directories.
"""

import os
import re
import math
import numpy as np
import torch
import cv2
import yaml
import logging

from predictors.torchscript_predictor import TorchScriptPredictor, resolve_model_path
from common.image_utils import (
    array_from_image_stream, slice_feature, splice_output_tensor, MARGIN_DIVIDER
)
from common.transform import Composer


# Number of units a heatmap's height is divided into. Detected coordinates are
# reported in these units: pixel / (heatmap height / VERTICAL_UNITS).
VERTICAL_UNITS = 24.
# Point candidates whose minimum enclosing circle has a radius (in pixels) at
# or above this value are discarded.
POINT_RADIUS_MAX = 8


def _adaptive_contours(heatmap):
    """Return the external contours of an adaptively-thresholded heatmap.

    A Gaussian blur is applied first when the heatmap is at least 128 pixels
    tall. The kernel size grows with height and is always odd. The image is
    then binarized with a 3x3 mean adaptive threshold. OpenCV requires an
    8-bit single-channel image here.
    """
    blur_kernel = (heatmap.shape[0] // 128) * 2 + 1
    if blur_kernel > 1:
        heatmap_blur = cv2.GaussianBlur(heatmap, (blur_kernel, blur_kernel), 0)
    else:
        heatmap_blur = heatmap
    thresh = cv2.adaptiveThreshold(
        heatmap_blur, 255, cv2.ADAPTIVE_THRESH_MEAN_C, cv2.THRESH_BINARY, 3, 0
    )
    contours, _ = cv2.findContours(thresh, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    return contours


def _binary_contours(heatmap, threshold=92):
    """Return the external contours of `heatmap` binarized at a fixed threshold."""
    _, thresh = cv2.threshold(heatmap, threshold, 255, cv2.THRESH_BINARY)
    contours, _ = cv2.findContours(thresh, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    return contours


def _heat_mass(heatmap, top, bottom, left, right):
    """Sum heatmap[top:bottom, left:right], scaling each pixel from 0..255 to 0..1.

    An empty region (top >= bottom or left >= right) yields 0.0.
    """
    return float(heatmap[top:bottom, left:right].sum(dtype=np.float64)) / 255.


def detect_points(heatmap, vertical_units=VERTICAL_UNITS):
    """Detect point features (notes, symbols) in heatmap.

    Args:
        heatmap: 2-D uint8 array indexed as [y, x].
        vertical_units: Number of units the heatmap height spans.

    Returns:
        A list of dicts, one per contour whose enclosing-circle radius is below
        POINT_RADIUS_MAX:
          - 'mark': (x, y, radius) in pixels.
          - 'x': centre x in units.
          - 'y': centre y in units, relative to the heatmap's vertical centre.
          - 'confidence': summed heatmap intensity (0..1 per pixel) over the
            circle's bounding square, clipped to the image.
    """
    height, width = heatmap.shape[0], heatmap.shape[1]
    unit = height / vertical_units
    y0 = height / 2.0

    points = []
    for contour in _adaptive_contours(heatmap):
        (x, y), radius = cv2.minEnclosingCircle(contour)
        # Oversized blobs are dropped, so don't bother summing them.
        if radius >= POINT_RADIUS_MAX:
            continue

        confidence = _heat_mass(
            heatmap,
            max(math.floor(y - radius), 0), min(math.ceil(y + radius), height),
            max(math.floor(x - radius), 0), min(math.ceil(x + radius), width),
        )
        points.append({
            'mark': (x, y, radius),
            'x': x / unit,
            'y': (y - y0) / unit,
            'confidence': float(confidence),
        })
    return points


def detect_vlines(heatmap, vertical_units=VERTICAL_UNITS):
    """Detect vertical line features (barlines, stems) in heatmap.

    Args:
        heatmap: 2-D uint8 array indexed as [y, x].
        vertical_units: Number of units the heatmap height spans.

    Returns:
        A list of dicts, one per contour:
          - 'x': bounding-box centre x in units.
          - 'y': top edge in units, relative to the vertical centre.
          - 'extension': {'y1': top, 'y2': bottom}, both in units and relative
            to the vertical centre.
          - 'confidence': summed intensity over the bounding box, divided by
            0.8 * max(box height in px, 2.5).
          - 'mark': (centre x, top, bottom) in pixels.
    """
    unit = heatmap.shape[0] / vertical_units
    y0 = heatmap.shape[0] / 2.0

    lines = []
    for contour in _adaptive_contours(heatmap):
        left, top, width, height = cv2.boundingRect(contour)
        x = (left + width / 2) / unit
        y1 = (top - y0) / unit
        y2 = (top + height - y0) / unit

        confidence = _heat_mass(heatmap, top, top + height, left, left + width)
        # Normalise by line length; very short boxes are floored at 2.5 px.
        confidence /= max(height, 2.5) * 0.8

        lines.append({
            'x': x,
            'y': y1,
            'extension': {'y1': y1, 'y2': y2},
            'confidence': float(confidence),
            'mark': (left + width / 2, top, top + height),
        })
    return lines


def detect_rectangles(heatmap, vertical_units=VERTICAL_UNITS):
    """Detect rectangular features (text boxes) in heatmap.

    The heatmap is binarized at a fixed threshold of 92. Boxes smaller than
    2 square units are skipped.

    Returns:
        A list of dicts:
          - 'x': box centre x in units.
          - 'y': box centre y in units, relative to the vertical centre.
          - 'extension': {'width', 'height'} in units.
          - 'confidence': mean intensity (0..1) over the box.
          - 'mark': (left, top, width, height) in pixels.
    """
    unit = heatmap.shape[0] / vertical_units
    y0 = heatmap.shape[0] / 2.0

    rects = []
    for contour in _binary_contours(heatmap):
        left, top, width, height = cv2.boundingRect(contour)
        if width * height / unit / unit < 2:
            continue

        x = (left + width / 2) / unit
        # Centre y in units, made relative to the heatmap's vertical centre.
        y = (top + height / 2) / unit - y0 / unit

        confidence = _heat_mass(heatmap, top, top + height, left, left + width)
        confidence /= width * height

        rects.append({
            'x': x,
            'y': y,
            'extension': {'width': width / unit, 'height': height / unit},
            'confidence': float(confidence),
            'mark': (left, top, width, height),
        })
    return rects


def detect_boxes(heatmap, vertical_units=VERTICAL_UNITS):
    """Detect rotated box features in heatmap.

    The heatmap is binarized at a fixed threshold of 92. Boxes whose shorter
    side is under 8 units are skipped.

    Unlike the other detectors, positions and sizes are reported in PIXELS.
    Confidence is sqrt(width * height), not an intensity measure.

    Returns:
        A list of dicts with:
          - 'x', 'y': box centre, in pixels.
          - 'extension': {'width', 'height', 'theta'} from cv2.minAreaRect.
          - 'confidence'.
          - 'mark': the raw minAreaRect tuple.
    """
    unit = heatmap.shape[0] / vertical_units

    rects = []
    for contour in _binary_contours(heatmap):
        rect = cv2.minAreaRect(contour)
        pos, size, theta = rect
        if min(*size) / unit < 8:
            continue

        rects.append({
            'x': pos[0],
            'y': pos[1],
            'extension': {'width': size[0], 'height': size[1], 'theta': theta},
            'confidence': math.sqrt(size[0] * size[1]),
            'mark': rect,
        })
    return rects


class ScoreSemantic:
    """Score semantic analysis results.

    Runs one detector per label and collects every detection into a
    'SemanticGraph' dict. The detector is chosen by label prefix:
      - 'vline_' -> detect_vlines
      - 'rect_' -> detect_rectangles
      - 'box_' -> detect_boxes
      - anything else -> detect_points
    """

    def __init__(self, heatmaps, labels, confidence_table=None):
        """Build the graph.

        Args:
            heatmaps: Indexable of 2-D uint8 heatmaps, one per label (same order).
            labels: Sequence of semantic label strings.
            confidence_table: Optional list aligned with `labels`. Each item is
                {'semantic': label, 'mean_confidence': float}. Detection
                confidences are divided by max(mean_confidence, 1e-4).
        """
        self.data = {
            '__prototype': 'SemanticGraph',
            'points': [],
            'staffY': None,
        }

        # NOTE(review): these asserts are stripped under `python -O`.
        assert len(labels) == len(heatmaps), \
            f'classes - heatmaps count mismatch: {len(labels)} - {len(heatmaps)}'

        for i, semantic in enumerate(labels):
            mean_confidence = 1
            if confidence_table is not None:
                item = confidence_table[i]
                assert item['semantic'] == semantic
                mean_confidence = max(item['mean_confidence'], 1e-4)

            heatmap = heatmaps[i]
            if semantic.startswith('vline_'):
                features = detect_vlines(heatmap, vertical_units=VERTICAL_UNITS)
            elif semantic.startswith('rect_'):
                features = detect_rectangles(heatmap, vertical_units=VERTICAL_UNITS)
            elif semantic.startswith('box_'):
                features = detect_boxes(heatmap, vertical_units=VERTICAL_UNITS)
            else:
                features = detect_points(heatmap, vertical_units=VERTICAL_UNITS)

            for feature in features:
                point = {
                    'semantic': semantic,
                    'x': feature['x'],
                    'y': feature['y'],
                }
                # Point detections carry no extension; the other kinds do.
                if 'extension' in feature:
                    point['extension'] = feature['extension']
                point['confidence'] = feature['confidence'] / mean_confidence
                self.data['points'].append(point)

    def json(self):
        """Return the underlying graph dict (not a JSON string)."""
        return self.data


def _load_yaml(path):
    """Parse a UTF-8 YAML file with yaml.safe_load.

    Returns None for an empty file.
    """
    with open(path, 'r', encoding='utf-8') as f:
        return yaml.safe_load(f)


def _is_cluster_dir(model_path):
    """Check if model_path is a semantic cluster directory.

    A cluster directory contains a .state.yaml mapping with a 'subs' key.
    Empty or non-mapping state files are treated as "not a cluster" rather
    than raising.
    """
    if not os.path.isdir(model_path):
        return False
    state_file = os.path.join(model_path, '.state.yaml')
    if not os.path.exists(state_file):
        return False
    state = _load_yaml(state_file)
    return isinstance(state, dict) and 'subs' in state


class SemanticService:
    """Semantic prediction service.

    Handles both single TorchScript models and multi-model cluster directories.
    A cluster directory has a .state.yaml with 'subs' listing sub-model
    directories.

    In cluster mode, the outputs of all sub-models are concatenated along
    dim 1. Labels are taken from each sub-model's .state.yaml, so the
    `labels` and `confidence_table` constructor arguments are ignored.
    """

    DEFAULT_TRANS = ['Mono', 'HWC2CHW']
    DEFAULT_SLICING_WIDTH = 512

    def __init__(self, model_path, device='cuda', trans=None, slicing_width=None,
                 labels=None, confidence_table=None, **kwargs):
        """Load the model(s) at model_path onto `device`.

        Extra keyword arguments are accepted and ignored.
        """
        self.device = device
        if _is_cluster_dir(model_path):
            self._init_cluster(model_path, device, trans, slicing_width)
        else:
            self._init_single(model_path, device, trans, slicing_width, labels, confidence_table)

    def _init_single(self, model_path, device, trans, slicing_width, labels, confidence_table):
        """Initialize with a single TorchScript model."""
        resolved = resolve_model_path(model_path)
        self.model = torch.jit.load(resolved, map_location=device)
        self.model.eval()
        self.sub_models = None

        self.composer = Composer(trans or self.DEFAULT_TRANS)
        self.slicing_width = slicing_width or self.DEFAULT_SLICING_WIDTH
        self.labels = labels or []
        self.confidence_table = confidence_table

        logging.info('SemanticService: single model loaded: %s', resolved)

    def _init_cluster(self, model_path, device, trans, slicing_width):
        """Initialize with a multi-model cluster directory.

        Explicit `trans` / `slicing_width` arguments override the cluster's
        'predictor' config. The class defaults apply when neither is set.
        YAML keys that are present but empty (null) are treated as missing.
        """
        cluster_state = _load_yaml(os.path.join(model_path, '.state.yaml')) or {}

        # Get predictor config from cluster .state.yaml
        predictor_config = cluster_state.get('predictor') or {}
        self.composer = Composer(
            trans or predictor_config.get('trans') or self.DEFAULT_TRANS
        )
        self.slicing_width = (
            slicing_width
            or predictor_config.get('slicing_width')
            or self.DEFAULT_SLICING_WIDTH
        )

        # Confidence table dict (label -> mean confidence) from cluster config
        ct_dict = predictor_config.get('confidence_table') or {}

        # Load each sub-model; label order follows the 'subs' order.
        self.sub_models = []
        self.labels = []
        for sub_name in cluster_state.get('subs') or []:
            sub_dir = os.path.join(model_path, sub_name)
            sub_state = _load_yaml(os.path.join(sub_dir, '.state.yaml')) or {}
            sub_args = (sub_state.get('data') or {}).get('args') or {}
            sub_labels = sub_args.get('labels') or []

            sub_model_file = resolve_model_path(sub_dir)
            model = torch.jit.load(sub_model_file, map_location=device)
            model.eval()

            self.sub_models.append(model)
            self.labels.extend(sub_labels)
            logging.info('  sub-model %s: %d labels, file=%s',
                         sub_name, len(sub_labels), os.path.basename(sub_model_file))

        # Build a confidence table list aligned with label order; labels
        # missing from the config get a neutral 1.0.
        self.confidence_table = None
        if ct_dict:
            self.confidence_table = [
                {'semantic': label, 'mean_confidence': ct_dict.get(label, 1.0)}
                for label in self.labels
            ]

        self.model = None  # not used for cluster

        logging.info('SemanticService: cluster loaded with %d sub-models, %d total labels',
                     len(self.sub_models), len(self.labels))

    def run_inference(self, batch):
        """Run model inference with no_grad context.

        In cluster mode, tuple outputs are reduced to their second element and
        the sub-model results are concatenated on dim 1. In single-model mode,
        the raw model output (possibly a tuple) is returned unchanged.
        """
        with torch.no_grad():
            if self.sub_models is None:
                return self.model(batch)

            outputs = []
            for model in self.sub_models:
                output = model(batch)
                if isinstance(output, tuple):
                    _, output = output
                outputs.append(output)
            return torch.cat(outputs, dim=1)

    def predict(self, streams, **kwargs):
        """
        Predict semantic symbols from image streams.

        streams: list of image byte buffers
        yields: semantic graph results. A {'error': 'Invalid image'} dict is
            yielded for a stream that cannot be decoded.
        """
        for stream in streams:
            image = array_from_image_stream(stream)
            if image is None:
                yield {'error': 'Invalid image'}
                continue

            # Slice image into overlapping, padded pieces
            pieces = np.array(list(slice_feature(
                image,
                width=self.slicing_width,
                overlapping=2 / MARGIN_DIVIDER,
                padding=True,
            )), dtype=np.uint8)

            # Transform. The second Composer argument looks like a placeholder
            # label/mask; its transformed counterpart is discarded. TODO confirm.
            staves, _ = self.composer(pieces, np.ones((1, 4, 4, 2)))
            batch = torch.from_numpy(staves).to(self.device)

            # Inference
            output = self.run_inference(batch)
            # Handle tuple output (single model case)
            if isinstance(output, tuple):
                _, output = output

            semantic = splice_output_tensor(output)

            # Scale to 0..255 uint8 for the OpenCV detectors. Clip first:
            # casting out-of-range floats to uint8 is undefined in NumPy.
            # Values already in [0, 1] give the same result as a plain cast.
            heatmaps = np.clip(np.asarray(semantic) * 255, 0, 255).astype(np.uint8)

            # Build semantic result
            ss = ScoreSemantic(heatmaps, self.labels, confidence_table=self.confidence_table)
            yield ss.json()