Spaces:
Running
Running
| """ | |
| Gauge prediction service. | |
| Predicts staff gauge (height and slope) map. | |
| Supports both TorchScript (.pt) and state_dict (.chkpt) model formats. | |
| """ | |
| import os | |
| import logging | |
| from collections import OrderedDict | |
| import numpy as np | |
| import torch | |
| import yaml | |
| import PIL.Image | |
| from predictors.torchscript_predictor import resolve_model_path | |
| from predictors.unet import UNet | |
| from common.image_utils import ( | |
| array_from_image_stream, slice_feature, splice_output_tensor, | |
| gauge_to_rgb, encode_image_base64, encode_image_bytes, | |
| MARGIN_DIVIDER | |
| ) | |
| from common.transform import Composer | |
| class _ScoreRegression(torch.nn.Module): | |
| """ScoreRegression architecture for loading .chkpt checkpoints.""" | |
| def __init__(self, in_channels=1, out_channels=2, unet_depth=6, unet_init_width=32): | |
| super().__init__() | |
| self.backbone = UNet(in_channels, out_channels, depth=unet_depth, init_width=unet_init_width) | |
| def forward(self, input): | |
| return self.backbone(input) | |
def _dict_section(mapping, key):
    """Return ``mapping[key]`` when it is a dict, otherwise an empty dict."""
    value = mapping.get(key) if isinstance(mapping, dict) else None
    return value if isinstance(value, dict) else {}


def _read_unet_config(state_file):
    """Return ``(unet_depth, unet_init_width, out_channels)`` from a .state.yaml file.

    Falls back to the defaults (6, 32, 2) when the file does not exist or when
    the relevant sections are absent, empty or null.
    """
    unet_depth, unet_init_width, out_channels = 6, 32, 2
    if not os.path.exists(state_file):
        return unet_depth, unet_init_width, out_channels

    with open(state_file, 'r', encoding='utf-8') as f:
        state = yaml.safe_load(f)

    # safe_load yields None for an empty document, and a key that is present
    # but null also yields None; _dict_section normalises every level so the
    # lookups below cannot fail with AttributeError.
    model_args = _dict_section(_dict_section(state, 'model'), 'args')
    backbone = _dict_section(model_args, 'backbone')
    unet_depth = backbone.get('unet_depth', unet_depth)
    unet_init_width = backbone.get('unet_init_width', unet_init_width)
    out_channels = model_args.get('out_channels', out_channels)
    return unet_depth, unet_init_width, out_channels


def _extract_backbone_state_dict(checkpoint):
    """Reduce a loaded checkpoint to the ``backbone.*`` entries of its state_dict.

    A dict checkpoint with a 'model' entry is unwrapped first. A leading
    'deducer.' prefix (training wrapper) is stripped from each key, and keys
    that do not start with 'backbone.' afterwards are dropped. Non-dict input
    is returned unchanged.
    """
    state_dict = checkpoint
    if isinstance(checkpoint, dict) and 'model' in checkpoint:
        state_dict = checkpoint['model']
    if not isinstance(state_dict, dict):
        return state_dict

    wrapper_prefix = 'deducer.'
    cleaned = OrderedDict()
    for key, value in state_dict.items():
        if key.startswith(wrapper_prefix):
            key = key[len(wrapper_prefix):]
        # Non-model keys (e.g. channel_weights from the Loss wrapper) are skipped.
        if key.startswith('backbone.'):
            cleaned[key] = value
    return cleaned


def _load_gauge_model(model_path, device):
    """Load gauge model, handling both TorchScript and state_dict formats.

    TorchScript is tried first. If that fails for any reason, the file is
    treated as a state_dict checkpoint: the UNet shape is read from a
    '.state.yaml' file next to the model (defaults apply when it is absent),
    a _ScoreRegression is built and the checkpoint weights are loaded
    non-strictly.

    Args:
        model_path: model location, passed through resolve_model_path.
        device: device given to torch as map_location and as the model's target.

    Returns:
        The loaded model in eval mode.

    Raises:
        Whatever torch.load or load_state_dict raise when the state_dict
        fallback also fails; the TorchScript failure itself is only logged.
    """
    resolved = resolve_model_path(model_path)

    # Try TorchScript first. The broad except is deliberate: any failure here
    # simply means "fall back to the state_dict format".
    try:
        model = torch.jit.load(resolved, map_location=device)
        model.eval()
        logging.info('GaugeService: TorchScript model loaded: %s', resolved)
        return model
    except Exception as e:
        logging.info('GaugeService: not TorchScript (%s), trying state_dict...', str(e)[:60])

    # Read model config from .state.yaml
    state_file = os.path.join(os.path.dirname(resolved), '.state.yaml')
    unet_depth, unet_init_width, out_channels = _read_unet_config(state_file)

    model = _ScoreRegression(out_channels=out_channels, unet_depth=unet_depth, unet_init_width=unet_init_width)

    # SECURITY NOTE: weights_only=False lets torch.load unpickle arbitrary
    # objects, which can execute code. Only load checkpoints from trusted
    # sources; switch to weights_only=True if the checkpoints allow it.
    checkpoint = torch.load(resolved, map_location=device, weights_only=False)
    state_dict = _extract_backbone_state_dict(checkpoint)

    model.load_state_dict(state_dict, strict=False)
    model.eval()
    model.to(device)

    # Log key loading stats
    model_keys = set(model.state_dict().keys())
    loaded_keys = set(state_dict.keys())
    matched = model_keys & loaded_keys
    logging.info('GaugeService: state_dict loaded: %s (%d/%d keys matched, depth=%d, width=%d)',
                 resolved, len(matched), len(model_keys), unet_depth, unet_init_width)

    # strict=False silently skips absent keys, so make an incomplete load
    # visible: those tensors keep their initial values.
    missing = sorted(model_keys - matched)
    if missing:
        logging.warning(
            'GaugeService: %d/%d model keys not found in checkpoint %s; '
            'they keep their initial values (first missing: %s)',
            len(missing), len(model_keys), resolved, missing[0])
    return model
class StaffGauge:
    """Image form of a predicted staff gauge map."""

    def __init__(self, hotmap):
        """
        Render the gauge map and keep the result as ``self.image``.

        hotmap: (2, H, W) - channels: [Y (height), K (slope)]
        """
        colored = gauge_to_rgb(hotmap, frac_y=True)
        # The last (channel) axis is reversed before the array is handed to PIL.
        channel_reversed = colored[:, :, ::-1]
        self.image = PIL.Image.fromarray(channel_reversed, 'RGB')

    def json(self):
        """Return a dict whose 'image' entry is the base64-encoded image."""
        encoded = encode_image_base64(self.image)
        return {'image': encoded}
class GaugeService:
    """Gauge prediction service. Supports TorchScript and state_dict formats."""

    DEFAULT_TRANS = ['Mono', 'HWC2CHW']
    DEFAULT_SLICING_WIDTH = 512

    def __init__(self, model_path, device='cuda', trans=None, slicing_width=None):
        """Load the model and set up the input transform pipeline.

        model_path: model location handed to the gauge model loader
        device: torch device used for loading and inference
        trans: transform names for Composer; DEFAULT_TRANS when falsy
        slicing_width: slice width for inference; DEFAULT_SLICING_WIDTH when falsy
        """
        self.device = device
        self.model = _load_gauge_model(model_path, device)
        self.composer = Composer(trans if trans else self.DEFAULT_TRANS)
        self.slicing_width = slicing_width if slicing_width else self.DEFAULT_SLICING_WIDTH

    def _predict_hotmap(self, image):
        """Slice one decoded image, run the model on the slices and splice the output."""
        overlap = 2 / MARGIN_DIVIDER
        slices = np.array(
            list(slice_feature(image, width=self.slicing_width, overlapping=overlap, padding=False)),
            dtype=np.float32,
        )

        # The second argument is a placeholder array; its transformed value is discarded.
        staves, _ = self.composer(slices, np.ones((1, 4, 4, 2)))
        batch = torch.from_numpy(staves).to(self.device)

        with torch.no_grad():
            output = self.model(batch)  # (batch, channel, height, width)

        hotmap = splice_output_tensor(output, soft=True)  # (channel, height, width)

        # Trim the spliced map when it is wider than the source image.
        image_width = image.shape[1]
        if hotmap.shape[2] > image_width:
            hotmap = hotmap[:, :, :image_width]
        return hotmap

    def predict(self, streams, by_buffer=False, **kwargs):
        """
        Predict staff gauge from image streams.

        streams: list of image byte buffers
        by_buffer: if True, return raw bytes instead of base64
        yields: one dict per stream - {'image': ...} on success,
                {'error': 'Invalid image'} when the stream cannot be decoded
        """
        for stream in streams:
            image = array_from_image_stream(stream)
            if image is None:
                yield {'error': 'Invalid image'}
                continue

            gauge = StaffGauge(self._predict_hotmap(image))

            if by_buffer:
                encoder = encode_image_bytes
            else:
                encoder = encode_image_base64
            yield {'image': encoder(gauge.image)}