# Source: starry/backend/python-services/services/semantic_service.py
# Author: k-l-lambda
# Commit 2b7aae2: "feat: add Python ML services (CPU mode) with model download"
"""
Semantic prediction service.
Detects and classifies musical symbols (notes, rests, clefs, etc.).
Supports both single-model and multi-model cluster directories.
"""
import os
import re
import math
import numpy as np
import torch
import cv2
import yaml
import logging
from predictors.torchscript_predictor import TorchScriptPredictor, resolve_model_path
from common.image_utils import (
array_from_image_stream, slice_feature, splice_output_tensor,
MARGIN_DIVIDER
)
from common.transform import Composer
# Number of vertical units spanned by a heatmap's height; detectors divide the
# heatmap row count by this to obtain pixels-per-unit.
VERTICAL_UNITS = 24.0
# detect_points keeps a blob only if its minimum enclosing circle has a radius
# (in pixels) strictly below this value.
POINT_RADIUS_MAX = 8
def detect_points(heatmap, vertical_units=24):
    """Detect point features (notes, symbols) in a heatmap.

    The heatmap is Gaussian-blurred when it is at least 128 rows tall (the
    kernel grows with the height). It is then binarized with a 3x3 mean
    adaptive threshold. Each external contour is fitted with a minimum
    enclosing circle. Circles whose radius is ``POINT_RADIUS_MAX`` pixels
    or more are discarded.

    Args:
        heatmap: 2-D single-channel ``uint8`` array (``cv2.adaptiveThreshold``
            requires 8-bit single-channel input).
        vertical_units: number of vertical units spanned by the heatmap
            height; used to convert pixel coordinates to units.

    Returns:
        A list of dicts with these keys:

        - ``x``, ``y``: position in units, with ``y`` measured from the
          heatmap's vertical centre.
        - ``confidence``: sum of ``heatmap / 255`` over the circle's
          bounding square, clipped to the image.
        - ``mark``: ``(x, y, radius)`` in pixels.
    """
    height, width = heatmap.shape[:2]
    unit = height / vertical_units
    y0 = height / 2.0

    blur_kernel = (height // 128) * 2 + 1
    if blur_kernel > 1:
        heatmap_blur = cv2.GaussianBlur(heatmap, (blur_kernel, blur_kernel), 0)
    else:
        heatmap_blur = heatmap

    thresh = cv2.adaptiveThreshold(
        heatmap_blur, 255,
        cv2.ADAPTIVE_THRESH_MEAN_C, cv2.THRESH_BINARY, 3, 0
    )
    contours, _ = cv2.findContours(thresh, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

    points = []
    for contour in contours:
        (x, y), radius = cv2.minEnclosingCircle(contour)
        # Reject oversized blobs before paying for the confidence sum.
        if radius >= POINT_RADIUS_MAX:
            continue

        x_lo = max(math.floor(x - radius), 0)
        x_hi = min(math.ceil(x + radius), width)
        y_lo = max(math.floor(y - radius), 0)
        y_hi = min(math.ceil(y + radius), height)
        # Vectorized replacement for the per-pixel Python loop.
        # Accumulate in float64; an inverted range yields an empty slice,
        # so the sum is 0.
        confidence = heatmap[y_lo:y_hi, x_lo:x_hi].sum(dtype=np.float64) / 255.

        points.append({
            'mark': (x, y, radius),
            'x': x / unit,
            'y': (y - y0) / unit,
            'confidence': float(confidence),
        })
    return points
def detect_vlines(heatmap, vertical_units=24):
    """Detect vertical line features (barlines, stems) in a heatmap.

    The heatmap is Gaussian-blurred when it is at least 128 rows tall, then
    binarized with a 3x3 mean adaptive threshold. Each external contour is
    reduced to its axis-aligned bounding rectangle.

    Args:
        heatmap: 2-D single-channel ``uint8`` array.
        vertical_units: number of vertical units spanned by the heatmap height.

    Returns:
        A list of dicts with these keys:

        - ``x``: rect centre, in units.
        - ``y``: top of the rect, in units.
        - ``extension``: ``{'y1': top, 'y2': bottom}``, in units.
        - ``confidence``: sum of ``heatmap / 255`` over the rect, divided
          by ``0.8 * max(rect_height_px, 2.5)``.
        - ``mark``: ``(centre_x_px, top_px, bottom_px)``.

        All ``y`` values are measured from the heatmap's vertical centre.
    """
    heatmap_height = heatmap.shape[0]
    unit = heatmap_height / vertical_units
    y0 = heatmap_height / 2.0

    blur_kernel = (heatmap_height // 128) * 2 + 1
    if blur_kernel > 1:
        heatmap_blur = cv2.GaussianBlur(heatmap, (blur_kernel, blur_kernel), 0)
    else:
        heatmap_blur = heatmap

    thresh = cv2.adaptiveThreshold(
        heatmap_blur, 255,
        cv2.ADAPTIVE_THRESH_MEAN_C, cv2.THRESH_BINARY, 3, 0
    )
    contours, _ = cv2.findContours(thresh, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

    lines = []
    for contour in contours:
        left, top, width, height = cv2.boundingRect(contour)
        bottom = top + height
        x = (left + width / 2) / unit
        y1 = (top - y0) / unit
        y2 = (bottom - y0) / unit

        # Vectorized replacement for the per-pixel Python loop, accumulated in float64.
        confidence = heatmap[top:bottom, left:left + width].sum(dtype=np.float64) / 255.
        # Normalize by line length, floored at 2.5 px.
        confidence /= max(height, 2.5) * 0.8

        lines.append({
            'x': x,
            'y': y1,
            'extension': {'y1': y1, 'y2': y2},
            'confidence': float(confidence),
            'mark': (left + width / 2, top, bottom),
        })
    return lines
def detect_rectangles(heatmap, vertical_units=24):
    """Detect rectangular features (text boxes) in a heatmap.

    The heatmap is binarized at a fixed level of 92. Each external contour
    is reduced to its axis-aligned bounding rectangle. Rectangles smaller
    than 2 square units are skipped.

    Args:
        heatmap: 2-D single-channel ``uint8`` array.
        vertical_units: number of vertical units spanned by the heatmap height.

    Returns:
        A list of dicts with these keys:

        - ``x``, ``y``: rect centre in units, with ``y`` measured from the
          heatmap's vertical centre.
        - ``extension``: ``{'width', 'height'}``, in units.
        - ``confidence``: mean of ``heatmap / 255`` over the rect.
        - ``mark``: ``(left, top, width, height)`` in pixels.
    """
    unit = heatmap.shape[0] / vertical_units
    y0 = heatmap.shape[0] / 2.0

    _, thresh = cv2.threshold(heatmap, 92, 255, cv2.THRESH_BINARY)
    contours, _ = cv2.findContours(thresh, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

    rects = []
    for contour in contours:
        left, top, width, height = cv2.boundingRect(contour)
        # Skip blobs smaller than 2 square units.
        if width * height / unit / unit < 2:
            continue

        x = (left + width / 2) / unit
        y = (top + height / 2) / unit

        # Vectorized replacement for the per-pixel Python loop, accumulated in float64.
        confidence = heatmap[top:top + height, left:left + width].sum(dtype=np.float64) / 255.
        # Mean activation; width * height > 0 is guaranteed by the area filter above.
        confidence /= width * height

        rects.append({
            'x': x,
            # Equivalent to (centre_y_px - y0) / unit.
            'y': y - y0 / unit,
            'extension': {'width': width / unit, 'height': height / unit},
            'confidence': float(confidence),
            'mark': (left, top, width, height),
        })
    return rects
def detect_boxes(heatmap, vertical_units=24):
    """Find rotated box features in a heatmap.

    The heatmap is binarized at a fixed level of 92. Each external contour
    is fitted with ``cv2.minAreaRect``. Boxes whose shorter side is under
    8 vertical units are dropped.

    Returns:
        A list of dicts with these keys, all in pixel space:

        - ``x``, ``y``: the rectangle position reported by
          ``cv2.minAreaRect``.
        - ``extension``: ``width``/``height``/``theta`` from the same call.
        - ``confidence``: geometric mean of the two side lengths.
        - ``mark``: the raw rotated rectangle.
    """
    pixels_per_unit = heatmap.shape[0] / vertical_units
    _, binary = cv2.threshold(heatmap, 92, 255, cv2.THRESH_BINARY)
    contours, _ = cv2.findContours(binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

    boxes = []
    for contour in contours:
        rotated = cv2.minAreaRect(contour)
        (cx, cy), (w, h), angle = rotated
        if min(w, h) / pixels_per_unit < 8:
            continue
        boxes.append({
            'x': cx,
            'y': cy,
            'extension': {'width': w, 'height': h, 'theta': angle},
            'confidence': math.sqrt(w * h),
            'mark': rotated,
        })
    return boxes
class ScoreSemantic:
    """Score semantic analysis results.

    Runs one feature detector per label over the matching heatmap and
    collects every detection into a ``SemanticGraph`` dict.

    The detector is chosen by label prefix:

    - ``vline_``: ``detect_vlines``
    - ``rect_``: ``detect_rectangles``
    - ``box_``: ``detect_boxes``
    - anything else: ``detect_points``

    Each detection's confidence is divided by the label's mean confidence
    when a table is supplied.
    """

    def __init__(self, heatmaps, labels, confidence_table=None):
        self.data = {
            '__prototype': 'SemanticGraph',
            'points': [],
            'staffY': None,
        }
        assert len(labels) == len(heatmaps), \
            f'classes - heatmaps count mismatch: {len(labels)} - {len(heatmaps)}'

        collected = self.data['points']
        for index, semantic in enumerate(labels):
            scale = self._mean_confidence(confidence_table, index, semantic)
            detector, has_extension = self._detector_for(semantic)
            for feature in detector(heatmaps[index], vertical_units=VERTICAL_UNITS):
                entry = {'semantic': semantic, 'x': feature['x'], 'y': feature['y']}
                # Only the line/rect/box detectors report an extent.
                if has_extension:
                    entry['extension'] = feature['extension']
                entry['confidence'] = feature['confidence'] / scale
                collected.append(entry)

    @staticmethod
    def _mean_confidence(confidence_table, index, semantic):
        """Return the divisor for label *index*: 1 without a table, else its mean (floored at 1e-4)."""
        if confidence_table is None:
            return 1
        item = confidence_table[index]
        assert item['semantic'] == semantic
        return max(item['mean_confidence'], 1e-4)

    @staticmethod
    def _detector_for(semantic):
        """Return ``(detector, has_extension)`` for a label based on its prefix."""
        if re.match(r'^vline_', semantic):
            return detect_vlines, True
        if re.match(r'^rect_', semantic):
            return detect_rectangles, True
        if re.match(r'^box_', semantic):
            return detect_boxes, True
        return detect_points, False

    def json(self):
        """Return the collected semantic graph dict (the live object, not a copy)."""
        return self.data
def _is_cluster_dir(model_path):
"""Check if model_path is a semantic cluster directory (has 'subs' in .state.yaml)."""
if not os.path.isdir(model_path):
return False
state_file = os.path.join(model_path, '.state.yaml')
if not os.path.exists(state_file):
return False
with open(state_file, 'r') as f:
state = yaml.safe_load(f)
return 'subs' in state
class SemanticService:
    """Semantic prediction service.

    Handles both single TorchScript models and multi-model cluster
    directories. A cluster directory has a ``.state.yaml`` whose ``subs``
    entry lists sub-model directories. In cluster mode:

    - each sub-model's output channels are concatenated in ``subs`` order;
    - each sub-model's labels are appended to ``self.labels`` in that
      same order.
    """
    DEFAULT_TRANS = ['Mono', 'HWC2CHW']
    DEFAULT_SLICING_WIDTH = 512

    def __init__(self, model_path, device='cuda', trans=None, slicing_width=None,
                 labels=None, confidence_table=None, **kwargs):
        """Load a single model or a cluster directory from *model_path*.

        Args:
            model_path: model path, or a cluster directory (see ``_is_cluster_dir``).
            device: torch device used for ``map_location`` and for input batches.
            trans: transform names for ``Composer``. Falls back to
                ``DEFAULT_TRANS``; cluster mode checks the cluster's
                ``predictor.trans`` first.
            slicing_width: width of the image slices fed to the model.
            labels: label list; single-model mode only (cluster mode reads
                labels from each sub-model's ``.state.yaml``).
            confidence_table: list of ``{'semantic', 'mean_confidence'}``
                dicts; single-model mode only.
            **kwargs: accepted for caller compatibility and ignored.
        """
        self.device = device
        if _is_cluster_dir(model_path):
            self._init_cluster(model_path, device, trans, slicing_width)
        else:
            self._init_single(model_path, device, trans, slicing_width, labels, confidence_table)

    def _init_single(self, model_path, device, trans, slicing_width, labels, confidence_table):
        """Initialize with a single TorchScript model."""
        resolved = resolve_model_path(model_path)
        self.model = torch.jit.load(resolved, map_location=device)
        self.model.eval()
        self.sub_models = None
        self.composer = Composer(trans or self.DEFAULT_TRANS)
        self.slicing_width = slicing_width or self.DEFAULT_SLICING_WIDTH
        self.labels = labels or []
        self.confidence_table = confidence_table
        logging.info('SemanticService: single model loaded: %s', resolved)

    def _init_cluster(self, model_path, device, trans, slicing_width):
        """Initialize from a cluster directory of TorchScript sub-models.

        YAML keys that are missing or present-but-null (e.g. ``predictor:``
        with no value) are treated as empty instead of raising
        ``AttributeError``/``TypeError``.
        """
        state_file = os.path.join(model_path, '.state.yaml')
        with open(state_file, 'r', encoding='utf-8') as f:
            cluster_state = yaml.safe_load(f) or {}

        # Predictor config from the cluster .state.yaml.
        predictor_config = cluster_state.get('predictor') or {}
        self.composer = Composer(
            trans or predictor_config.get('trans') or self.DEFAULT_TRANS
        )
        self.slicing_width = (
            slicing_width
            or predictor_config.get('slicing_width')
            or self.DEFAULT_SLICING_WIDTH
        )
        # label -> mean confidence mapping from the cluster config.
        ct_dict = predictor_config.get('confidence_table') or {}

        # Load each sub-model, concatenating labels in `subs` order.
        self.sub_models = []
        self.labels = []
        for sub_name in cluster_state.get('subs') or []:
            sub_dir = os.path.join(model_path, sub_name)
            sub_state_file = os.path.join(sub_dir, '.state.yaml')
            with open(sub_state_file, 'r', encoding='utf-8') as f:
                sub_state = yaml.safe_load(f) or {}
            sub_args = (sub_state.get('data') or {}).get('args') or {}
            sub_labels = sub_args.get('labels') or []

            sub_model_file = resolve_model_path(sub_dir)
            model = torch.jit.load(sub_model_file, map_location=device)
            model.eval()
            self.sub_models.append(model)
            self.labels.extend(sub_labels)
            logging.info(' sub-model %s: %d labels, file=%s',
                         sub_name, len(sub_labels), os.path.basename(sub_model_file))

        # Confidence table list aligned with self.labels; unknown labels get 1.0.
        self.confidence_table = None
        if ct_dict:
            self.confidence_table = [
                {'semantic': label, 'mean_confidence': ct_dict.get(label, 1.0)}
                for label in self.labels
            ]

        self.model = None  # not used for cluster
        logging.info('SemanticService: cluster loaded with %d sub-models, %d total labels',
                     len(self.sub_models), len(self.labels))

    def run_inference(self, batch):
        """Run model inference under ``torch.no_grad``.

        In cluster mode, each sub-model's output is used directly; when it
        is a tuple, its second element is used instead. The outputs are
        concatenated along dim 1. In single-model mode, the raw model output
        is returned and may be a tuple.
        """
        with torch.no_grad():
            if self.sub_models is not None:
                outputs = []
                for model in self.sub_models:
                    output = model(batch)
                    if isinstance(output, tuple):
                        _, semantic = output
                    else:
                        semantic = output
                    outputs.append(semantic)
                return torch.cat(outputs, dim=1)
            else:
                return self.model(batch)

    @staticmethod
    def _semantic_to_uint8(semantic):
        """Scale spliced model output by 255 into a ``uint8`` heatmap stack, saturating.

        ``np.uint8(x * 255)`` wraps values above 1.0 (e.g. 1.01 -> 1) and is
        platform-dependent for negatives. Clipping to [0, 255] first keeps
        saturated activations at 255 and negatives at 0. In-range values are
        truncated toward zero exactly as before.
        """
        return np.clip(np.asarray(semantic) * 255, 0, 255).astype(np.uint8)

    def predict(self, streams, **kwargs):
        """
        Predict semantic symbols from image streams.

        streams: list of image byte buffers
        yields: one semantic graph dict per stream, or ``{'error': 'Invalid image'}``
            when the stream cannot be decoded
        """
        for stream in streams:
            image = array_from_image_stream(stream)
            if image is None:
                yield {'error': 'Invalid image'}
                continue

            # Slice image into fixed-width overlapping pieces.
            pieces = list(slice_feature(
                image,
                width=self.slicing_width,
                overlapping=2 / MARGIN_DIVIDER,
                padding=True
            ))
            pieces = np.array(pieces, dtype=np.uint8)

            # Transform; the second argument is a placeholder target that is discarded.
            staves, _ = self.composer(pieces, np.ones((1, 4, 4, 2)))
            batch = torch.from_numpy(staves).to(self.device)

            output = self.run_inference(batch)
            # Single-model output may be a tuple; keep its second element.
            if isinstance(output, tuple):
                _, output = output

            semantic = splice_output_tensor(output)

            ss = ScoreSemantic(
                self._semantic_to_uint8(semantic),
                self.labels,
                confidence_table=self.confidence_table
            )
            yield ss.json()