Spaces:
Running
Running
| """ | |
| Mask prediction service. | |
| Generates staff foreground/background mask. | |
| Supports both TorchScript (.pt) and state_dict (.chkpt) model formats. | |
| """ | |
| import os | |
| import logging | |
| import numpy as np | |
| import torch | |
| import yaml | |
| import PIL.Image | |
| from predictors.torchscript_predictor import resolve_model_path | |
| from predictors.unet import UNet | |
| from common.image_utils import ( | |
| array_from_image_stream, slice_feature, splice_output_tensor, | |
| mask_to_alpha, encode_image_base64, encode_image_bytes, | |
| MARGIN_DIVIDER | |
| ) | |
| from common.transform import Composer | |
class _ScoreWidgetsMask(torch.nn.Module):
    """Minimal ScoreWidgetsMask wrapper used to restore .chkpt checkpoints.

    Holds a single UNet whose raw logits are squashed with a sigmoid so the
    output reads directly as per-pixel mask probabilities.
    """

    def __init__(self, in_channels=1, mask_channels=2, unet_depth=5, unet_init_width=32):
        super().__init__()
        # The sub-module is named `mask` to match the key layout of saved checkpoints.
        self.mask = UNet(in_channels, mask_channels, depth=unet_depth, init_width=unet_init_width)

    def forward(self, x):
        logits = self.mask(x)
        return torch.sigmoid(logits)
def _load_mask_model(model_path, device):
    """Load the mask model, accepting both TorchScript and state_dict formats.

    First attempts ``torch.jit.load`` (TorchScript ``.pt`` archives). If that
    fails, reconstructs a ``_ScoreWidgetsMask`` and loads a raw state_dict /
    training checkpoint (``.chkpt``), reading the UNet hyperparameters from an
    optional ``.state.yaml`` file next to the model.

    Args:
        model_path: path or identifier understood by ``resolve_model_path``.
        device: torch device (string or object) the weights are mapped to.

    Returns:
        The loaded model, in eval mode and on ``device``.
    """
    resolved = resolve_model_path(model_path)
    # Try TorchScript first -- the common production format.
    try:
        model = torch.jit.load(resolved, map_location=device)
        model.eval()
        logging.info('MaskService: TorchScript model loaded: %s', resolved)
        return model
    except Exception as e:
        logging.info('MaskService: not TorchScript (%s), trying state_dict...', str(e)[:60])
    # Read UNet hyperparameters from the sibling .state.yaml, if present.
    model_dir = os.path.dirname(resolved)
    state_file = os.path.join(model_dir, '.state.yaml')
    unet_depth = 5
    unet_init_width = 32
    if os.path.exists(state_file):
        with open(state_file, 'r') as f:
            # safe_load returns None for an empty file; guard before .get chains.
            state = yaml.safe_load(f) or {}
        mask_config = state.get('model', {}).get('args', {}).get('mask', {})
        unet_depth = mask_config.get('unet_depth', 5)
        unet_init_width = mask_config.get('unet_init_width', 32)
    model = _ScoreWidgetsMask(unet_depth=unet_depth, unet_init_width=unet_init_width)
    # NOTE(security): weights_only=False unpickles arbitrary objects -- only
    # load checkpoints from trusted sources.
    checkpoint = torch.load(resolved, map_location=device, weights_only=False)
    # Handle different checkpoint layouts.
    state_dict = checkpoint
    if isinstance(checkpoint, dict):
        if 'model' in checkpoint:
            # Full training checkpoint: weights live under the 'model' key.
            state_dict = checkpoint['model']
    if isinstance(state_dict, dict) and 'mask' in state_dict:
        # ScoreWidgetsMask saves as {'mask': {UNet weights}}.
        model.mask.load_state_dict(state_dict['mask'])
    else:
        # Keys may carry a 'mask.' prefix from the nn.Module default layout.
        model.load_state_dict(state_dict, strict=False)
    model.eval()
    model.to(device)
    logging.info('MaskService: state_dict loaded: %s (depth=%d, width=%d)',
                 resolved, unet_depth, unet_init_width)
    return model
class StaffMask:
    """Wraps a predicted staff mask as a grayscale+alpha PIL image."""

    def __init__(self, hotmap):
        """
        hotmap: (2, H, W) - channels: [foreground, background]
        """
        la_array = mask_to_alpha(hotmap, frac_y=True)
        self.image = PIL.Image.fromarray(la_array, 'LA')

    def json(self):
        """Return a JSON-serializable dict holding the base64-encoded mask image."""
        return {'image': encode_image_base64(self.image)}
class MaskService:
    """Mask prediction service. Supports TorchScript and state_dict formats."""

    DEFAULT_TRANS = ['Mono', 'HWC2CHW']
    DEFAULT_SLICING_WIDTH = 512

    def __init__(self, model_path, device='cuda', trans=None, slicing_width=None):
        self.device = device
        self.model = _load_mask_model(model_path, device)
        self.composer = Composer(trans or self.DEFAULT_TRANS)
        self.slicing_width = slicing_width or self.DEFAULT_SLICING_WIDTH

    def predict(self, streams, by_buffer=False, **kwargs):
        """
        Predict staff mask from image streams.
        streams: list of image byte buffers
        by_buffer: if True, return raw bytes instead of base64
        yields: mask results
        """
        # The encoder choice is loop-invariant, so pick it once up front.
        encoder = encode_image_bytes if by_buffer else encode_image_base64
        for stream in streams:
            image = array_from_image_stream(stream)
            if image is None:
                yield {'error': 'Invalid image'}
                continue
            # Cut the page into overlapping vertical slices for the model.
            slices = np.asarray(
                list(slice_feature(
                    image,
                    width=self.slicing_width,
                    overlapping=2 / MARGIN_DIVIDER,
                    padding=False,
                )),
                dtype=np.float32,
            )
            transformed, _ = self.composer(slices, np.ones((1, 4, 4, 2)))
            batch = torch.from_numpy(transformed).to(self.device)
            with torch.no_grad():
                prediction = self.model(batch)  # (batch, channel, height, width)
            # Reassemble the slice outputs into one (channel, height, width) map.
            hotmap = splice_output_tensor(prediction, soft=True)
            if hotmap.shape[2] > image.shape[1]:
                # Trim columns introduced by slicing beyond the original width.
                hotmap = hotmap[:, :, :image.shape[1]]
            mask = StaffMask(hotmap)
            yield {'image': encoder(mask.image)}