# update: export from starry-refactor 2026-02-20 15:25 (commit 1958836)
"""
Gauge prediction service.
Predicts staff gauge (height and slope) map.
Supports both TorchScript (.pt) and state_dict (.chkpt) model formats.
"""
import os
import logging
from collections import OrderedDict
import numpy as np
import torch
import yaml
import PIL.Image
from predictors.torchscript_predictor import resolve_model_path
from predictors.unet import UNet
from common.image_utils import (
array_from_image_stream, slice_feature, splice_output_tensor,
gauge_to_rgb, encode_image_base64, encode_image_bytes,
MARGIN_DIVIDER
)
from common.transform import Composer
class _ScoreRegression(torch.nn.Module):
    """Minimal stand-in for the ScoreRegression training architecture.

    Exists so that .chkpt state_dicts saved with `backbone.*` parameter keys
    can be loaded into an identically named attribute here.
    """

    def __init__(self, in_channels=1, out_channels=2, unet_depth=6, unet_init_width=32):
        super().__init__()
        self.backbone = UNet(
            in_channels,
            out_channels,
            depth=unet_depth,
            init_width=unet_init_width,
        )

    def forward(self, input):
        # Pure delegation: the wrapper adds no computation of its own.
        return self.backbone(input)
def _load_gauge_model(model_path, device):
    """Load gauge model, handling both TorchScript and state_dict formats.

    Args:
        model_path: model path/identifier, resolved via resolve_model_path.
        device: torch device (string or object) to map weights onto.

    Returns:
        The loaded model, in eval mode, on `device`.
    """
    resolved = resolve_model_path(model_path)

    # Fast path: a TorchScript archive. torch.jit.load raises on any other
    # format, which we treat as the signal to fall back to state_dict loading.
    try:
        model = torch.jit.load(resolved, map_location=device)
        model.eval()
        logging.info('GaugeService: TorchScript model loaded: %s', resolved)
        return model
    except Exception as e:
        logging.info('GaugeService: not TorchScript (%s), trying state_dict...', str(e)[:60])

    # Architecture hyper-parameters come from a sibling .state.yaml when
    # present; otherwise the defaults below apply.
    unet_depth = 6
    unet_init_width = 32
    out_channels = 2
    state_file = os.path.join(os.path.dirname(resolved), '.state.yaml')
    if os.path.exists(state_file):
        with open(state_file, 'r') as f:
            # `or {}` guards an empty YAML file: safe_load returns None for it,
            # and the chained .get() calls below would raise AttributeError.
            state = yaml.safe_load(f) or {}
        model_args = (state.get('model') or {}).get('args') or {}
        backbone = model_args.get('backbone') or {}
        unet_depth = backbone.get('unet_depth', 6)
        unet_init_width = backbone.get('unet_init_width', 32)
        out_channels = model_args.get('out_channels', 2)

    model = _ScoreRegression(out_channels=out_channels, unet_depth=unet_depth, unet_init_width=unet_init_width)
    # SECURITY: weights_only=False unpickles arbitrary objects — only load
    # checkpoints from trusted sources.
    checkpoint = torch.load(resolved, map_location=device, weights_only=False)

    # Unwrap common checkpoint container layouts ({'model': ...} from this
    # project's trainer, {'state_dict': ...} from generic training wrappers —
    # without the second branch such checkpoints matched 0 keys silently
    # because of strict=False below).
    state_dict = checkpoint
    if isinstance(checkpoint, dict):
        if 'model' in checkpoint:
            state_dict = checkpoint['model']
        elif 'state_dict' in checkpoint:
            state_dict = checkpoint['state_dict']

    # Strip the training-wrapper prefix (ScoreRegressionLoss.deducer.*) and
    # drop non-model keys (e.g. channel_weights from the Loss wrapper).
    if isinstance(state_dict, dict):
        cleaned = OrderedDict()
        for key, value in state_dict.items():
            new_key = key[len('deducer.'):] if key.startswith('deducer.') else key
            if new_key.startswith('backbone.'):
                cleaned[new_key] = value
        state_dict = cleaned

    model.load_state_dict(state_dict, strict=False)
    model.eval()
    model.to(device)

    # Report how many of the architecture's keys the checkpoint actually
    # filled — with strict=False a mismatch would otherwise pass unnoticed.
    model_keys = set(model.state_dict().keys())
    loaded_keys = set(state_dict.keys())
    matched = model_keys & loaded_keys
    logging.info('GaugeService: state_dict loaded: %s (%d/%d keys matched, depth=%d, width=%d)',
                 resolved, len(matched), len(model_keys), unet_depth, unet_init_width)

    return model
class StaffGauge:
    """A staff gauge prediction, rendered as an in-memory RGB image."""

    def __init__(self, hotmap):
        """
        hotmap: (2, H, W) - channels: [Y (height), K (slope)]
        """
        colored = gauge_to_rgb(hotmap, frac_y=True)
        # Reverse the channel axis before handing to PIL's RGB mode
        # (gauge_to_rgb output appears to be BGR — confirm against helper).
        self.image = PIL.Image.fromarray(colored[:, :, ::-1], 'RGB')

    def json(self):
        """Serialize the gauge as a dict with a base64-encoded image."""
        return {'image': encode_image_base64(self.image)}
class GaugeService:
    """Gauge prediction service. Supports TorchScript and state_dict formats."""

    DEFAULT_TRANS = ['Mono', 'HWC2CHW']
    DEFAULT_SLICING_WIDTH = 512

    def __init__(self, model_path, device='cuda', trans=None, slicing_width=None):
        self.device = device
        self.model = _load_gauge_model(model_path, device)
        self.composer = Composer(trans if trans else self.DEFAULT_TRANS)
        self.slicing_width = slicing_width if slicing_width else self.DEFAULT_SLICING_WIDTH

    def predict(self, streams, by_buffer=False, **kwargs):
        """
        Predict staff gauge from image streams.

        streams: list of image byte buffers
        by_buffer: if True, return raw bytes instead of base64
        yields: gauge results
        """
        # by_buffer is fixed for the whole call, so pick the encoder once.
        encoder = encode_image_bytes if by_buffer else encode_image_base64

        for stream in streams:
            source = array_from_image_stream(stream)
            if source is None:
                yield {'error': 'Invalid image'}
                continue

            # Cut the image into fixed-width, overlapping slices.
            slices = np.array(
                list(slice_feature(
                    source,
                    width=self.slicing_width,
                    overlapping=2 / MARGIN_DIVIDER,
                    padding=False,
                )),
                dtype=np.float32,
            )

            # Run the configured transforms; the ones-array is a placeholder
            # for the Composer's second input, whose output is discarded here.
            staves, _ = self.composer(slices, np.ones((1, 4, 4, 2)))
            batch = torch.from_numpy(staves).to(self.device)

            with torch.no_grad():
                output = self.model(batch)  # (batch, channel, height, width)

            # Stitch slice outputs back together, then trim anything past the
            # original image width.
            hotmap = splice_output_tensor(output, soft=True)  # (channel, height, width)
            if hotmap.shape[2] > source.shape[1]:
                hotmap = hotmap[:, :, :source.shape[1]]

            yield {'image': encoder(StaffGauge(hotmap).image)}