k-l-lambda's picture
update: export from starry-refactor 2026-02-20 15:25
1958836
"""
Mask prediction service.
Generates staff foreground/background masks.
Supports both TorchScript (.pt) and state_dict (.chkpt) model formats.
"""
import os
import logging
import numpy as np
import torch
import yaml
import PIL.Image
from predictors.torchscript_predictor import resolve_model_path
from predictors.unet import UNet
from common.image_utils import (
array_from_image_stream, slice_feature, splice_output_tensor,
mask_to_alpha, encode_image_base64, encode_image_bytes,
MARGIN_DIVIDER
)
from common.transform import Composer
class _ScoreWidgetsMask(torch.nn.Module):
    """UNet wrapper reproducing the ScoreWidgetsMask module layout.

    The UNet lives under the attribute ``mask``, so a state_dict taken from
    this wrapper has keys prefixed with ``mask.``. ``forward`` applies an
    element-wise sigmoid to the UNet output.
    """

    def __init__(self, in_channels=1, mask_channels=2, unet_depth=5, unet_init_width=32):
        super().__init__()
        self.mask = UNet(
            in_channels,
            mask_channels,
            depth=unet_depth,
            init_width=unet_init_width,
        )

    def forward(self, x):
        """Run the UNet on ``x`` and squash its output with a sigmoid."""
        raw = self.mask(x)
        return torch.sigmoid(raw)
def _read_unet_config(model_dir, default_depth=5, default_width=32):
    """Read UNet hyper-parameters from ``<model_dir>/.state.yaml``.

    Looks up ``model.args.mask.unet_depth`` and ``model.args.mask.unet_init_width``.
    A missing file, a missing key, or an empty/null YAML node falls back to the
    given defaults instead of raising.

    Returns a ``(unet_depth, unet_init_width)`` tuple.
    """
    state_file = os.path.join(model_dir, '.state.yaml')
    if not os.path.exists(state_file):
        return default_depth, default_width
    with open(state_file, 'r', encoding='utf-8') as f:
        # safe_load returns None for an empty document.
        state = yaml.safe_load(f) or {}
    # Walk model -> args -> mask, tolerating missing or null intermediate nodes.
    node = state
    for key in ('model', 'args', 'mask'):
        node = node.get(key) if isinstance(node, dict) else None
    mask_config = node if isinstance(node, dict) else {}
    depth = mask_config.get('unet_depth')
    width = mask_config.get('unet_init_width')
    return (default_depth if depth is None else depth,
            default_width if width is None else width)


def _load_mask_model(model_path, device):
    """Load the mask model, trying TorchScript first and then a state_dict checkpoint.

    model_path: passed through ``resolve_model_path``. The resolved file is
        first loaded with ``torch.jit.load``. If that fails, it is loaded with
        ``torch.load`` as a checkpoint for ``_ScoreWidgetsMask``.
    device: used as ``map_location`` for loading and as the target of ``model.to``.

    For the state_dict path, the UNet depth and width come from ``.state.yaml``
    next to the checkpoint (defaults: depth=5, width=32). A partial weight load
    is allowed but logged as a warning.

    Returns the model in eval mode.
    """
    resolved = resolve_model_path(model_path)

    # Try TorchScript first; a failure only means the file is in another format.
    try:
        model = torch.jit.load(resolved, map_location=device)
    except Exception as e:
        logging.info('MaskService: not TorchScript (%s), trying state_dict...', str(e)[:60])
    else:
        model.eval()
        logging.info('MaskService: TorchScript model loaded: %s', resolved)
        return model

    unet_depth, unet_init_width = _read_unet_config(os.path.dirname(resolved))
    model = _ScoreWidgetsMask(unet_depth=unet_depth, unet_init_width=unet_init_width)

    # NOTE(security): weights_only=False unpickles arbitrary objects; only load
    # checkpoints from trusted locations.
    checkpoint = torch.load(resolved, map_location=device, weights_only=False)

    # Handle different checkpoint formats: unwrap a {'model': ...} container.
    state_dict = checkpoint
    if isinstance(checkpoint, dict) and 'model' in checkpoint:
        state_dict = checkpoint['model']

    if isinstance(state_dict, dict) and 'mask' in state_dict:
        # ScoreWidgetsMask saves as {'mask': {UNet weights}}
        model.mask.load_state_dict(state_dict['mask'])
    else:
        # Flat mapping. Keys taken from the wrapper carry a 'mask.' prefix;
        # keys taken from a bare UNet do not. Load into the matching module so
        # unprefixed weights are not silently ignored.
        has_prefix = any(isinstance(k, str) and k.startswith('mask.') for k in state_dict)
        target = model if has_prefix else model.mask
        result = target.load_state_dict(state_dict, strict=False)
        if result.missing_keys or result.unexpected_keys:
            logging.warning(
                'MaskService: partial state_dict load from %s '
                '(missing=%d, unexpected=%d); missing sample: %s',
                resolved, len(result.missing_keys), len(result.unexpected_keys),
                list(result.missing_keys)[:5])

    model.eval()
    model.to(device)
    logging.info('MaskService: state_dict loaded: %s (depth=%d, width=%d)',
                 resolved, unet_depth, unet_init_width)
    return model
class StaffMask:
    """Staff mask stored as a two-band PIL image (mode 'LA')."""

    def __init__(self, hotmap):
        """
        hotmap: (2, H, W) - channels: [foreground, background]

        The hotmap is converted with ``mask_to_alpha(..., frac_y=True)``. The
        resulting array is wrapped as a PIL 'LA' image in ``self.image``.
        """
        alpha_map = mask_to_alpha(hotmap, frac_y=True)
        self.image = PIL.Image.fromarray(alpha_map, 'LA')

    def json(self):
        """Return a dict holding the mask image encoded by ``encode_image_base64``."""
        encoded = encode_image_base64(self.image)
        return dict(image=encoded)
class MaskService:
    """Staff-mask prediction service backed by a TorchScript or state_dict model.

    Each image is cut into overlapping fixed-width slices. The slices run
    through the model as a single batch, and the outputs are spliced back into
    one mask image per input.
    """

    DEFAULT_TRANS = ['Mono', 'HWC2CHW']
    DEFAULT_SLICING_WIDTH = 512

    def __init__(self, model_path, device='cuda', trans=None, slicing_width=None):
        """
        model_path: model location, loaded by ``_load_mask_model``
        device: torch device used for loading and inference
        trans: transform names for ``Composer``; falsy -> DEFAULT_TRANS
        slicing_width: slice width in pixels; falsy -> DEFAULT_SLICING_WIDTH
        """
        self.device = device
        self.model = _load_mask_model(model_path, device)
        self.composer = Composer(trans if trans else self.DEFAULT_TRANS)
        self.slicing_width = slicing_width if slicing_width else self.DEFAULT_SLICING_WIDTH

    def predict(self, streams, by_buffer=False, **kwargs):
        """
        Predict a staff mask for each image stream.

        streams: list (or any iterable) of image byte buffers
        by_buffer: if True, encode the mask with ``encode_image_bytes``;
            otherwise with ``encode_image_base64``
        kwargs: accepted and ignored
        yields: one dict per stream -- ``{'image': <encoded mask>}``, or
            ``{'error': 'Invalid image'}`` when the stream cannot be decoded
        """
        encode = encode_image_bytes if by_buffer else encode_image_base64

        for stream in streams:
            image = array_from_image_stream(stream)
            if image is None:
                yield {'error': 'Invalid image'}
                continue

            # Overlapping fixed-width slices of the image, stacked as one float32 batch.
            slices = slice_feature(
                image,
                width=self.slicing_width,
                overlapping=2 / MARGIN_DIVIDER,
                padding=False,
            )
            batch_np = np.array(list(slices), dtype=np.float32)

            # Composer takes a second array next to the input; a dummy ones array
            # is supplied and its transformed counterpart is discarded.
            transformed, _ = self.composer(batch_np, np.ones((1, 4, 4, 2)))
            batch_tensor = torch.from_numpy(transformed).to(self.device)

            with torch.no_grad():
                prediction = self.model(batch_tensor)  # (batch, channel, height, width)

            # Stitch the per-slice outputs into a single (channel, height, width) map,
            # then trim any overhang past the original image width.
            stitched = splice_output_tensor(prediction, soft=True)
            page_width = image.shape[1]
            if stitched.shape[2] > page_width:
                stitched = stitched[:, :, :page_width]

            staff_mask = StaffMask(stitched)
            yield {'image': encode(staff_mask.image)}