""" Mask prediction service. Generates staff foreground/background mask. Supports both TorchScript (.pt) and state_dict (.chkpt) model formats. """ import os import logging import numpy as np import torch import yaml import PIL.Image from predictors.torchscript_predictor import resolve_model_path from predictors.unet import UNet from common.image_utils import ( array_from_image_stream, slice_feature, splice_output_tensor, mask_to_alpha, encode_image_base64, encode_image_bytes, MARGIN_DIVIDER ) from common.transform import Composer class _ScoreWidgetsMask(torch.nn.Module): """ScoreWidgetsMask architecture for loading .chkpt checkpoints.""" def __init__(self, in_channels=1, mask_channels=2, unet_depth=5, unet_init_width=32): super().__init__() self.mask = UNet(in_channels, mask_channels, depth=unet_depth, init_width=unet_init_width) def forward(self, x): return torch.sigmoid(self.mask(x)) def _load_mask_model(model_path, device): """Load mask model, handling both TorchScript and state_dict formats.""" resolved = resolve_model_path(model_path) # Try TorchScript first try: model = torch.jit.load(resolved, map_location=device) model.eval() logging.info('MaskService: TorchScript model loaded: %s', resolved) return model except Exception as e: logging.info('MaskService: not TorchScript (%s), trying state_dict...', str(e)[:60]) # Read model config from .state.yaml model_dir = os.path.dirname(resolved) state_file = os.path.join(model_dir, '.state.yaml') unet_depth = 5 unet_init_width = 32 if os.path.exists(state_file): with open(state_file, 'r') as f: state = yaml.safe_load(f) mask_config = state.get('model', {}).get('args', {}).get('mask', {}) unet_depth = mask_config.get('unet_depth', 5) unet_init_width = mask_config.get('unet_init_width', 32) model = _ScoreWidgetsMask(unet_depth=unet_depth, unet_init_width=unet_init_width) checkpoint = torch.load(resolved, map_location=device, weights_only=False) # Handle different checkpoint formats state_dict = checkpoint if isinstance(checkpoint, dict): if 'model' in checkpoint: state_dict = checkpoint['model'] # ScoreWidgetsMask saves as {'mask': {UNet weights}} if isinstance(state_dict, dict) and 'mask' in state_dict: model.mask.load_state_dict(state_dict['mask']) else: # Try loading directly (may have 'mask.' prefix from nn.Module default) model.load_state_dict(state_dict, strict=False) model.eval() model.to(device) logging.info('MaskService: state_dict loaded: %s (depth=%d, width=%d)', resolved, unet_depth, unet_init_width) return model class StaffMask: """Staff mask representation.""" def __init__(self, hotmap): """ hotmap: (2, H, W) - channels: [foreground, background] """ hotmap = mask_to_alpha(hotmap, frac_y=True) self.image = PIL.Image.fromarray(hotmap, 'LA') def json(self): return { 'image': encode_image_base64(self.image), } class MaskService: """Mask prediction service. Supports TorchScript and state_dict formats.""" DEFAULT_TRANS = ['Mono', 'HWC2CHW'] DEFAULT_SLICING_WIDTH = 512 def __init__(self, model_path, device='cuda', trans=None, slicing_width=None): self.device = device self.model = _load_mask_model(model_path, device) self.composer = Composer(trans or self.DEFAULT_TRANS) self.slicing_width = slicing_width or self.DEFAULT_SLICING_WIDTH def predict(self, streams, by_buffer=False, **kwargs): """ Predict staff mask from image streams. streams: list of image byte buffers by_buffer: if True, return raw bytes instead of base64 yields: mask results """ for stream in streams: image = array_from_image_stream(stream) if image is None: yield {'error': 'Invalid image'} continue # Slice image into overlapping pieces pieces = list(slice_feature( image, width=self.slicing_width, overlapping=2 / MARGIN_DIVIDER, padding=False )) pieces = np.array(pieces, dtype=np.float32) # Transform staves, _ = self.composer(pieces, np.ones((1, 4, 4, 2))) batch = torch.from_numpy(staves).to(self.device) # Inference with torch.no_grad(): output = self.model(batch) # (batch, channel, height, width) # Splice output hotmap = splice_output_tensor(output, soft=True) # (channel, height, width) if hotmap.shape[2] > image.shape[1]: hotmap = hotmap[:, :, :image.shape[1]] mask = StaffMask(hotmap) encoder = encode_image_bytes if by_buffer else encode_image_base64 yield { 'image': encoder(mask.image), }