"""
Gauge prediction service.
Predicts staff gauge (height and slope) map.
Supports both TorchScript (.pt) and state_dict (.chkpt) model formats.
"""
import os
import logging
from collections import OrderedDict
import numpy as np
import torch
import yaml
import PIL.Image
from predictors.torchscript_predictor import resolve_model_path
from predictors.unet import UNet
from common.image_utils import (
array_from_image_stream, slice_feature, splice_output_tensor,
gauge_to_rgb, encode_image_base64, encode_image_bytes,
MARGIN_DIVIDER
)
from common.transform import Composer
class _ScoreRegression(torch.nn.Module):
    """Thin wrapper around a UNet, used to load ``.chkpt`` checkpoints.

    The network is held in the ``backbone`` attribute; ``forward`` simply
    delegates to it.
    """

    def __init__(self, in_channels=1, out_channels=2, unet_depth=6, unet_init_width=32):
        super().__init__()
        unet = UNet(
            in_channels,
            out_channels,
            depth=unet_depth,
            init_width=unet_init_width,
        )
        self.backbone = unet

    def forward(self, input):
        """Pass ``input`` through the backbone and return its output unchanged."""
        prediction = self.backbone(input)
        return prediction
def _read_unet_config(model_dir):
    """Read UNet hyper-parameters from ``<model_dir>/.state.yaml``.

    Returns a ``(unet_depth, unet_init_width, out_channels)`` tuple. The
    defaults ``(6, 32, 2)`` are used when the file is absent; individual
    defaults are used for any entry missing from the file.
    """
    unet_depth, unet_init_width, out_channels = 6, 32, 2

    state_file = os.path.join(model_dir, '.state.yaml')
    if not os.path.exists(state_file):
        return unet_depth, unet_init_width, out_channels

    with open(state_file, 'r', encoding='utf-8') as f:
        # safe_load returns None for an empty document, so fall back to {}.
        state = yaml.safe_load(f) or {}

    # A key written without a value ("model:") parses as None rather than
    # being absent, so every level is guarded with "or {}".
    model_args = (state.get('model') or {}).get('args') or {}
    backbone = model_args.get('backbone') or {}

    unet_depth = backbone.get('unet_depth', unet_depth)
    unet_init_width = backbone.get('unet_init_width', unet_init_width)
    out_channels = model_args.get('out_channels', out_channels)
    return unet_depth, unet_init_width, out_channels


def _extract_backbone_weights(checkpoint):
    """Reduce a loaded checkpoint to the ``backbone.*`` entries of its state_dict.

    If ``checkpoint`` is a dict with a ``'model'`` entry, that entry is used as
    the state_dict; otherwise the checkpoint itself is. A leading ``deducer.``
    prefix (training wrapper) is stripped from every key, then only keys
    starting with ``backbone.`` are kept. A non-dict state_dict is returned
    untouched.
    """
    state_dict = checkpoint
    if isinstance(checkpoint, dict) and 'model' in checkpoint:
        state_dict = checkpoint['model']
    if not isinstance(state_dict, dict):
        return state_dict

    prefix = 'deducer.'
    cleaned = OrderedDict()
    for key, value in state_dict.items():
        if key.startswith(prefix):
            key = key[len(prefix):]
        # Drop non-model keys (e.g. channel_weights from the loss wrapper).
        if key.startswith('backbone.'):
            cleaned[key] = value
    return cleaned


def _load_gauge_model(model_path, device):
    """Load the gauge model from a TorchScript file or a state_dict checkpoint.

    model_path: path handed to ``resolve_model_path`` to locate the model file
    device: device the model is mapped/moved to
    returns: the model in eval mode

    TorchScript loading is attempted first. If it fails, a ``_ScoreRegression``
    is built from the hyper-parameters in the sibling ``.state.yaml`` (with
    defaults when absent) and the checkpoint weights are loaded non-strictly.
    Missing or unexpected keys are logged as a warning because, with
    ``strict=False``, missing weights silently keep their random initialisation.
    """
    resolved = resolve_model_path(model_path)

    # Try TorchScript first. Any failure is treated as "not a TorchScript
    # file" on purpose, so the state_dict path below gets its chance.
    try:
        model = torch.jit.load(resolved, map_location=device)
    except Exception as e:
        logging.info('GaugeService: not TorchScript (%s), trying state_dict...', str(e)[:60])
    else:
        model.eval()
        logging.info('GaugeService: TorchScript model loaded: %s', resolved)
        return model

    unet_depth, unet_init_width, out_channels = _read_unet_config(os.path.dirname(resolved))
    model = _ScoreRegression(
        out_channels=out_channels,
        unet_depth=unet_depth,
        unet_init_width=unet_init_width,
    )

    # NOTE(security): weights_only=False unpickles arbitrary Python objects and
    # can execute code embedded in the file. Only load checkpoints from a
    # trusted source.
    checkpoint = torch.load(resolved, map_location=device, weights_only=False)
    state_dict = _extract_backbone_weights(checkpoint)

    load_result = model.load_state_dict(state_dict, strict=False)
    model.eval()
    model.to(device)

    # Log key loading stats; "matched" = model keys that the checkpoint supplied.
    total_keys = len(model.state_dict())
    missing = list(load_result.missing_keys)
    unexpected = list(load_result.unexpected_keys)
    logging.info('GaugeService: state_dict loaded: %s (%d/%d keys matched, depth=%d, width=%d)',
                 resolved, total_keys - len(missing), total_keys, unet_depth, unet_init_width)
    if missing or unexpected:
        logging.warning(
            'GaugeService: %d missing / %d unexpected keys while loading %s; '
            'missing weights keep their random initialisation',
            len(missing), len(unexpected), resolved)
    return model
class StaffGauge:
    """Image form of a predicted staff gauge map."""

    def __init__(self, hotmap):
        """
        Render the gauge map and keep it as a PIL image in ``self.image``.

        hotmap: (2, H, W) - channels: [Y (height), K (slope)]
        """
        rendered = gauge_to_rgb(hotmap, frac_y=True)
        # Reverse the third axis before handing the array to PIL
        # (presumably a BGR -> RGB channel swap; confirm against gauge_to_rgb).
        channel_flipped = rendered[:, :, ::-1]
        self.image = PIL.Image.fromarray(channel_flipped, 'RGB')

    def json(self):
        """Return a dict holding the base64-encoded image under ``'image'``."""
        payload = {}
        payload['image'] = encode_image_base64(self.image)
        return payload
class GaugeService:
    """Staff gauge prediction service backed by a TorchScript or state_dict model."""

    DEFAULT_TRANS = ['Mono', 'HWC2CHW']
    DEFAULT_SLICING_WIDTH = 512

    def __init__(self, model_path, device='cuda', trans=None, slicing_width=None):
        """
        model_path: model location, forwarded to the model loader
        device: device used for the model and for input batches
        trans: transform names for the Composer (DEFAULT_TRANS when falsy)
        slicing_width: slice width in pixels (DEFAULT_SLICING_WIDTH when falsy)
        """
        self.device = device
        self.model = _load_gauge_model(model_path, device)
        self.composer = Composer(trans if trans else self.DEFAULT_TRANS)
        self.slicing_width = slicing_width if slicing_width else self.DEFAULT_SLICING_WIDTH

    def _infer_hotmap(self, image):
        """Slice ``image``, run the model on the slices and splice the outputs."""
        slices = slice_feature(
            image,
            width=self.slicing_width,
            overlapping=2 / MARGIN_DIVIDER,
            padding=False,
        )
        pieces = np.array(list(slices), dtype=np.float32)

        # The composer takes a (feature, target) pair; only the feature is
        # needed here, so a constant placeholder stands in for the target.
        placeholder_target = np.ones((1, 4, 4, 2))
        staves, _ = self.composer(pieces, placeholder_target)
        batch = torch.from_numpy(staves).to(self.device)

        with torch.no_grad():
            output = self.model(batch)  # (batch, channel, height, width)

        hotmap = splice_output_tensor(output, soft=True)  # (channel, height, width)

        # Trim the spliced map when it is wider than the source image.
        image_width = image.shape[1]
        if hotmap.shape[2] > image_width:
            hotmap = hotmap[:, :, :image_width]
        return hotmap

    def predict(self, streams, by_buffer=False, **kwargs):
        """
        Predict staff gauge from image streams.

        streams: list of image byte buffers
        by_buffer: if True, return raw bytes instead of base64
        yields: one dict per stream - {'image': ...} on success,
                {'error': 'Invalid image'} when the stream cannot be decoded
        """
        for stream in streams:
            image = array_from_image_stream(stream)
            if image is None:
                yield {'error': 'Invalid image'}
                continue

            gauge = StaffGauge(self._infer_hotmap(image))

            if by_buffer:
                encoder = encode_image_bytes
            else:
                encoder = encode_image_base64
            result = {'image': encoder(gauge.image)}
            yield result
|