Spaces:
Running
Running
| """ | |
| Semantic prediction service. | |
| Detects and classifies musical symbols (notes, rests, clefs, etc.). | |
| Supports both single-model and multi-model cluster directories. | |
| """ | |
| import os | |
| import re | |
| import math | |
| import numpy as np | |
| import torch | |
| import cv2 | |
| import yaml | |
| import logging | |
| from predictors.torchscript_predictor import TorchScriptPredictor, resolve_model_path | |
| from common.image_utils import ( | |
| array_from_image_stream, slice_feature, splice_output_tensor, | |
| MARGIN_DIVIDER | |
| ) | |
| from common.transform import Composer | |
# Number of logical vertical units spanned by a heatmap's height; detectors
# use it to convert pixel coordinates into staff-relative units.
VERTICAL_UNITS = 24.
# Contours whose min-enclosing-circle radius (in pixels) is >= this value
# are rejected as point symbols in detect_points.
POINT_RADIUS_MAX = 8
def detect_points(heatmap, vertical_units=24):
    """Detect point features (notes, symbols) in heatmap.

    Args:
        heatmap: 2D uint8 array; higher values mean stronger response.
        vertical_units: number of logical units the heatmap height spans.

    Returns:
        List of dicts with 'mark' (x, y, radius in pixels), 'x'/'y' in unit
        coordinates (y relative to the vertical center), and 'confidence'
        (total normalized heat mass inside the point's bounding square).
    """
    unit = heatmap.shape[0] / vertical_units
    y0 = heatmap.shape[0] / 2.0
    # Kernel size scales with image height and must be odd for GaussianBlur.
    blur_kernel = (heatmap.shape[0] // 128) * 2 + 1
    if blur_kernel > 1:
        heatmap_blur = cv2.GaussianBlur(heatmap, (blur_kernel, blur_kernel), 0)
    else:
        heatmap_blur = heatmap
    thresh = cv2.adaptiveThreshold(
        heatmap_blur, 255,
        cv2.ADAPTIVE_THRESH_MEAN_C, cv2.THRESH_BINARY, 3, 0
    )
    contours, _ = cv2.findContours(thresh, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    points = []
    for c in contours:
        (x, y), radius = cv2.minEnclosingCircle(c)
        # Oversized blobs are not point symbols; skip before the costly sum
        # (the original computed confidence even for rejected contours).
        if radius >= POINT_RADIUS_MAX:
            continue
        px0 = max(math.floor(x - radius), 0)
        px1 = min(math.ceil(x + radius), heatmap.shape[1])
        py0 = max(math.floor(y - radius), 0)
        py1 = min(math.ceil(y + radius), heatmap.shape[0])
        # Vectorized replacement for the per-pixel Python double loop.
        confidence = float(heatmap[py0:py1, px0:px1].sum()) / 255.0
        points.append({
            'mark': (x, y, radius),
            'x': x / unit,
            'y': (y - y0) / unit,
            'confidence': confidence,
        })
    return points
def detect_vlines(heatmap, vertical_units=24):
    """Detect vertical line features (barlines, stems) in heatmap.

    Args:
        heatmap: 2D uint8 array; higher values mean stronger response.
        vertical_units: number of logical units the heatmap height spans.

    Returns:
        List of dicts with 'x' and 'y' (top) in unit coordinates, an
        'extension' giving the line span {y1, y2}, a length-normalized
        'confidence', and the pixel 'mark' (x center, top, bottom).
    """
    unit = heatmap.shape[0] / vertical_units
    y0 = heatmap.shape[0] / 2.0
    # Kernel size scales with image height and must be odd for GaussianBlur.
    blur_kernel = (heatmap.shape[0] // 128) * 2 + 1
    if blur_kernel > 1:
        heatmap_blur = cv2.GaussianBlur(heatmap, (blur_kernel, blur_kernel), 0)
    else:
        heatmap_blur = heatmap
    thresh = cv2.adaptiveThreshold(
        heatmap_blur, 255,
        cv2.ADAPTIVE_THRESH_MEAN_C, cv2.THRESH_BINARY, 3, 0
    )
    contours, _ = cv2.findContours(thresh, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    lines = []
    for contour in contours:
        left, top, width, height = cv2.boundingRect(contour)
        x = (left + width / 2) / unit
        y1 = (top - y0) / unit
        y2 = (top + height - y0) / unit
        # Vectorized replacement for the per-pixel Python double loop:
        # total normalized heat mass inside the bounding rect.
        confidence = float(heatmap[top:top + height, left:left + width].sum()) / 255.0
        # Normalize by line length; the 2.5 floor keeps very short blobs from
        # inflating confidence. NOTE(review): height is in pixels while 2.5
        # looks like a unit-scale value — confirm the intended scale.
        length = max(height, 2.5)
        confidence /= length * 0.8
        lines.append({
            'x': x,
            'y': y1,
            'extension': {'y1': y1, 'y2': y2},
            'confidence': float(confidence),
            'mark': (left + width / 2, top, top + height),
        })
    return lines
def detect_rectangles(heatmap, vertical_units=24):
    """Detect rectangular features (text boxes) in heatmap.

    Args:
        heatmap: 2D uint8 array; higher values mean stronger response.
        vertical_units: number of logical units the heatmap height spans.

    Returns:
        List of dicts with center 'x'/'y' in unit coordinates (y relative to
        the vertical center), 'extension' width/height in units, 'confidence'
        (mean normalized intensity over the rect), and the pixel 'mark'.
    """
    unit = heatmap.shape[0] / vertical_units
    y0 = heatmap.shape[0] / 2.0
    _, thresh = cv2.threshold(heatmap, 92, 255, cv2.THRESH_BINARY)
    contours, _ = cv2.findContours(thresh, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    rects = []
    for contour in contours:
        left, top, width, height = cv2.boundingRect(contour)
        # Ignore regions smaller than 2 square units.
        if width * height / unit / unit < 2:
            continue
        x = (left + width / 2) / unit
        y = (top + height / 2) / unit
        # Vectorized replacement for the per-pixel Python double loop,
        # divided by the pixel area to get a mean fill ratio in [0, 1].
        confidence = float(heatmap[top:top + height, left:left + width].sum()) / 255.0
        confidence /= width * height
        rects.append({
            'x': x,
            # y is already in units, so the center offset is subtracted in
            # units too: y - y0/unit == (top + height/2 - y0) / unit.
            'y': y - y0 / unit,
            'extension': {'width': width / unit, 'height': height / unit},
            'confidence': float(confidence),
            'mark': (left, top, width, height),
        })
    return rects
def detect_boxes(heatmap, vertical_units=24):
    """Detect rotated box features in heatmap.

    Returns a list of dicts describing each accepted min-area rect: pixel
    center 'x'/'y', an 'extension' with width/height/theta, a size-based
    'confidence', and the raw rect as 'mark'.
    """
    unit = heatmap.shape[0] / vertical_units
    _, binary = cv2.threshold(heatmap, 92, 255, cv2.THRESH_BINARY)
    found, _ = cv2.findContours(binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    results = []
    for cnt in found:
        rotated = cv2.minAreaRect(cnt)
        (cx, cy), (w, h), angle = rotated
        # Discard boxes whose short side is under 8 units.
        if min(w, h) / unit < 8:
            continue
        results.append({
            'x': cx,
            'y': cy,
            'extension': {'width': w, 'height': h, 'theta': angle},
            # Geometric mean of the side lengths serves as the score.
            'confidence': math.sqrt(w * h),
            'mark': rotated,
        })
    return results
class ScoreSemantic:
    """Score semantic analysis results.

    Converts per-class heatmaps into a 'SemanticGraph' dict of detected
    symbol points, dispatching on the label prefix ('vline_', 'rect_',
    'box_', else point) to the matching feature detector.
    """

    def __init__(self, heatmaps, labels, confidence_table=None):
        """Build the semantic graph from aligned heatmaps and labels.

        Args:
            heatmaps: sequence of 2D uint8 heatmaps, one per label.
            labels: semantic class names, aligned with heatmaps.
            confidence_table: optional list aligned with labels; each item
                carries 'semantic' and 'mean_confidence' used to normalize
                the raw detection confidences.
        """
        self.data = {
            '__prototype': 'SemanticGraph',
            'points': [],
            'staffY': None,
        }
        assert len(labels) == len(heatmaps), \
            f'classes - heatmaps count mismatch: {len(labels)} - {len(heatmaps)}'
        for i, semantic in enumerate(labels):
            mean_confidence = 1
            if confidence_table is not None:
                item = confidence_table[i]
                assert item['semantic'] == semantic
                # Clamp to avoid division by zero on degenerate tables.
                mean_confidence = max(item['mean_confidence'], 1e-4)
            # Single loop replaces four nearly identical branch bodies.
            for found in self._detect(semantic, heatmaps[i]):
                point = {
                    'semantic': semantic,
                    'x': found['x'],
                    'y': found['y'],
                }
                # Line/rect/box detectors report an extension; points don't.
                if 'extension' in found:
                    point['extension'] = found['extension']
                point['confidence'] = found['confidence'] / mean_confidence
                self.data['points'].append(point)

    @staticmethod
    def _detect(semantic, heatmap):
        """Route a heatmap to the detector matching the label prefix."""
        if semantic.startswith('vline_'):
            return detect_vlines(heatmap, vertical_units=VERTICAL_UNITS)
        if semantic.startswith('rect_'):
            return detect_rectangles(heatmap, vertical_units=VERTICAL_UNITS)
        if semantic.startswith('box_'):
            return detect_boxes(heatmap, vertical_units=VERTICAL_UNITS)
        return detect_points(heatmap, vertical_units=VERTICAL_UNITS)

    def json(self):
        """Return the semantic graph as a plain dict."""
        return self.data
| def _is_cluster_dir(model_path): | |
| """Check if model_path is a semantic cluster directory (has 'subs' in .state.yaml).""" | |
| if not os.path.isdir(model_path): | |
| return False | |
| state_file = os.path.join(model_path, '.state.yaml') | |
| if not os.path.exists(state_file): | |
| return False | |
| with open(state_file, 'r') as f: | |
| state = yaml.safe_load(f) | |
| return 'subs' in state | |
class SemanticService:
    """Semantic prediction service.
    Handles both single TorchScript models and multi-model cluster directories.
    A cluster directory has a .state.yaml with 'subs' listing sub-model directories.
    """
    # Default preprocessing transform chain (see common.transform.Composer).
    DEFAULT_TRANS = ['Mono', 'HWC2CHW']
    # Default width in pixels of each horizontal slice fed to the model.
    DEFAULT_SLICING_WIDTH = 512

    def __init__(self, model_path, device='cuda', trans=None, slicing_width=None,
                 labels=None, confidence_table=None, **kwargs):
        """Load either a single model or a cluster of sub-models.

        Args:
            model_path: path to a TorchScript model, or a cluster directory
                whose .state.yaml lists sub-model directories under 'subs'.
            device: torch device string the model(s) are mapped onto.
            trans: optional transform names overriding defaults/config.
            slicing_width: optional slice width overriding defaults/config.
            labels: semantic class names (used in the single-model case).
            confidence_table: per-label mean confidences (single-model case).
            **kwargs: ignored; accepted for interface compatibility.
        """
        self.device = device
        if _is_cluster_dir(model_path):
            self._init_cluster(model_path, device, trans, slicing_width)
        else:
            self._init_single(model_path, device, trans, slicing_width, labels, confidence_table)

    def _init_single(self, model_path, device, trans, slicing_width, labels, confidence_table):
        """Initialize with a single TorchScript model."""
        resolved = resolve_model_path(model_path)
        self.model = torch.jit.load(resolved, map_location=device)
        self.model.eval()
        # None marks the single-model mode for run_inference.
        self.sub_models = None
        self.composer = Composer(trans or self.DEFAULT_TRANS)
        self.slicing_width = slicing_width or self.DEFAULT_SLICING_WIDTH
        self.labels = labels or []
        self.confidence_table = confidence_table
        logging.info('SemanticService: single model loaded: %s', resolved)

    def _init_cluster(self, model_path, device, trans, slicing_width):
        """Initialize with a multi-model cluster directory.

        Sub-models are loaded in the order listed under 'subs' and their
        labels extended in that same order, which must match the channel
        concatenation order used by run_inference.
        """
        state_file = os.path.join(model_path, '.state.yaml')
        with open(state_file, 'r') as f:
            cluster_state = yaml.safe_load(f)
        # Get predictor config from cluster .state.yaml
        predictor_config = cluster_state.get('predictor', {})
        # Explicit arguments win over config, which wins over class defaults.
        self.composer = Composer(
            trans or predictor_config.get('trans') or self.DEFAULT_TRANS
        )
        self.slicing_width = (
            slicing_width
            or predictor_config.get('slicing_width')
            or self.DEFAULT_SLICING_WIDTH
        )
        # Confidence table dict from cluster config
        ct_dict = predictor_config.get('confidence_table', {})
        # Load each sub-model
        self.sub_models = []
        self.labels = []
        subs = cluster_state.get('subs', [])
        for sub_name in subs:
            sub_dir = os.path.join(model_path, sub_name)
            sub_state_file = os.path.join(sub_dir, '.state.yaml')
            with open(sub_state_file, 'r') as f:
                sub_state = yaml.safe_load(f)
            # NOTE(review): assumes each sub .state.yaml nests labels under
            # data.args.labels; missing keys silently yield an empty list.
            sub_labels = sub_state.get('data', {}).get('args', {}).get('labels', [])
            sub_model_file = resolve_model_path(sub_dir)
            model = torch.jit.load(sub_model_file, map_location=device)
            model.eval()
            self.sub_models.append(model)
            self.labels.extend(sub_labels)
            logging.info('  sub-model %s: %d labels, file=%s',
                         sub_name, len(sub_labels), os.path.basename(sub_model_file))
        # Build confidence table list matching label order
        self.confidence_table = None
        if ct_dict:
            self.confidence_table = []
            for label in self.labels:
                # Labels absent from the config default to a neutral 1.0.
                mean_conf = ct_dict.get(label, 1.0)
                self.confidence_table.append({
                    'semantic': label,
                    'mean_confidence': mean_conf,
                })
        self.model = None  # not used for cluster
        logging.info('SemanticService: cluster loaded with %d sub-models, %d total labels',
                     len(self.sub_models), len(self.labels))

    def run_inference(self, batch):
        """Run model inference with no_grad context.

        In cluster mode each sub-model's semantic output is concatenated
        along the channel dimension (dim=1) in sub-model load order.
        """
        with torch.no_grad():
            if self.sub_models is not None:
                # Cluster: run each sub-model and concatenate channels
                outputs = []
                for model in self.sub_models:
                    output = model(batch)
                    # Some models return a tuple; keep only the semantic
                    # half (mirrors the tuple handling in predict()).
                    if isinstance(output, tuple):
                        _, semantic = output
                    else:
                        semantic = output
                    outputs.append(semantic)
                return torch.cat(outputs, dim=1)
            else:
                return self.model(batch)

    def predict(self, streams, **kwargs):
        """
        Predict semantic symbols from image streams.
        streams: list of image byte buffers
        yields: semantic graph results
        """
        for stream in streams:
            image = array_from_image_stream(stream)
            if image is None:
                # Undecodable stream: report an error item, keep going.
                yield {'error': 'Invalid image'}
                continue
            # Slice image
            pieces = list(slice_feature(
                image,
                width=self.slicing_width,
                overlapping=2 / MARGIN_DIVIDER,
                padding=True
            ))
            pieces = np.array(pieces, dtype=np.uint8)
            # Transform
            # NOTE(review): the second Composer argument looks like a dummy
            # label tensor whose transformed value is discarded — confirm.
            staves, _ = self.composer(pieces, np.ones((1, 4, 4, 2)))
            batch = torch.from_numpy(staves).to(self.device)
            # Inference
            output = self.run_inference(batch)
            # Handle tuple output (single model case)
            if isinstance(output, tuple):
                _, output = output
            # Re-assemble the overlapping slice predictions into one stack.
            semantic = splice_output_tensor(output)
            # Build semantic result
            # Presumably the model output lies in [0, 1]; scaled to uint8
            # heatmaps for the contour-based detectors.
            ss = ScoreSemantic(
                np.uint8(semantic * 255),
                self.labels,
                confidence_table=self.confidence_table
            )
            yield ss.json()