""" Gauge prediction service. Predicts staff gauge (height and slope) map. Supports both TorchScript (.pt) and state_dict (.chkpt) model formats. """ import os import logging from collections import OrderedDict import numpy as np import torch import yaml import PIL.Image from predictors.torchscript_predictor import resolve_model_path from predictors.unet import UNet from common.image_utils import ( array_from_image_stream, slice_feature, splice_output_tensor, gauge_to_rgb, encode_image_base64, encode_image_bytes, MARGIN_DIVIDER ) from common.transform import Composer class _ScoreRegression(torch.nn.Module): """ScoreRegression architecture for loading .chkpt checkpoints.""" def __init__(self, in_channels=1, out_channels=2, unet_depth=6, unet_init_width=32): super().__init__() self.backbone = UNet(in_channels, out_channels, depth=unet_depth, init_width=unet_init_width) def forward(self, input): return self.backbone(input) def _load_gauge_model(model_path, device): """Load gauge model, handling both TorchScript and state_dict formats.""" resolved = resolve_model_path(model_path) # Try TorchScript first try: model = torch.jit.load(resolved, map_location=device) model.eval() logging.info('GaugeService: TorchScript model loaded: %s', resolved) return model except Exception as e: logging.info('GaugeService: not TorchScript (%s), trying state_dict...', str(e)[:60]) # Read model config from .state.yaml model_dir = os.path.dirname(resolved) state_file = os.path.join(model_dir, '.state.yaml') unet_depth = 6 unet_init_width = 32 out_channels = 2 if os.path.exists(state_file): with open(state_file, 'r') as f: state = yaml.safe_load(f) model_args = state.get('model', {}).get('args', {}) backbone = model_args.get('backbone', {}) unet_depth = backbone.get('unet_depth', 6) unet_init_width = backbone.get('unet_init_width', 32) out_channels = model_args.get('out_channels', 2) model = _ScoreRegression(out_channels=out_channels, unet_depth=unet_depth, unet_init_width=unet_init_width) checkpoint = torch.load(resolved, map_location=device, weights_only=False) # Handle different checkpoint formats state_dict = checkpoint if isinstance(checkpoint, dict): if 'model' in checkpoint: state_dict = checkpoint['model'] # Strip common prefixes from training wrapper (ScoreRegressionLoss.deducer.*) if isinstance(state_dict, dict): cleaned = OrderedDict() for key, value in state_dict.items(): new_key = key if new_key.startswith('deducer.'): new_key = new_key[len('deducer.'):] cleaned[new_key] = value # Remove non-model keys (e.g. channel_weights from Loss wrapper) cleaned = OrderedDict((k, v) for k, v in cleaned.items() if k.startswith('backbone.')) state_dict = cleaned model.load_state_dict(state_dict, strict=False) model.eval() model.to(device) # Log key loading stats model_keys = set(model.state_dict().keys()) loaded_keys = set(state_dict.keys()) matched = model_keys & loaded_keys logging.info('GaugeService: state_dict loaded: %s (%d/%d keys matched, depth=%d, width=%d)', resolved, len(matched), len(model_keys), unet_depth, unet_init_width) return model class StaffGauge: """Staff gauge representation.""" def __init__(self, hotmap): """ hotmap: (2, H, W) - channels: [Y (height), K (slope)] """ hotmap = gauge_to_rgb(hotmap, frac_y=True) self.image = PIL.Image.fromarray(hotmap[:, :, ::-1], 'RGB') def json(self): return { 'image': encode_image_base64(self.image), } class GaugeService: """Gauge prediction service. Supports TorchScript and state_dict formats.""" DEFAULT_TRANS = ['Mono', 'HWC2CHW'] DEFAULT_SLICING_WIDTH = 512 def __init__(self, model_path, device='cuda', trans=None, slicing_width=None): self.device = device self.model = _load_gauge_model(model_path, device) self.composer = Composer(trans or self.DEFAULT_TRANS) self.slicing_width = slicing_width or self.DEFAULT_SLICING_WIDTH def predict(self, streams, by_buffer=False, **kwargs): """ Predict staff gauge from image streams. streams: list of image byte buffers by_buffer: if True, return raw bytes instead of base64 yields: gauge results """ for stream in streams: image = array_from_image_stream(stream) if image is None: yield {'error': 'Invalid image'} continue # Slice image pieces = list(slice_feature( image, width=self.slicing_width, overlapping=2 / MARGIN_DIVIDER, padding=False )) pieces = np.array(pieces, dtype=np.float32) # Transform staves, _ = self.composer(pieces, np.ones((1, 4, 4, 2))) batch = torch.from_numpy(staves).to(self.device) # Inference with torch.no_grad(): output = self.model(batch) # (batch, channel, height, width) # Splice output hotmap = splice_output_tensor(output, soft=True) # (channel, height, width) if hotmap.shape[2] > image.shape[1]: hotmap = hotmap[:, :, :image.shape[1]] gauge = StaffGauge(hotmap) encoder = encode_image_bytes if by_buffer else encode_image_base64 yield { 'image': encoder(gauge.image), }