# feat: add Python ML services (CPU mode) with model download (commit 2b7aae2)
"""
Location detection service.
Detects text regions and classifies them into 13 categories.
Uses DB_gc_loc architecture: ResNet-18 backbone + SegDetector decoder.
Supports both TorchScript (.pt) models and state_dict checkpoints.
"""
import os
import math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import cv2
import logging
from collections import OrderedDict
from predictors.torchscript_predictor import resolve_model_path
from common.image_utils import array_from_image_stream
# Per-channel RGB mean used for input normalization (values taken from the
# original DB implementation); subtracted from raw pixels in preprocess().
RGB_MEAN = np.array([122.67891434, 116.66876762, 104.00698793])

# Text type categories: the list index equals the class id produced by the
# model's classification head (13 classes total).
TYPE_NAMES = [
    'Title',         # 0
    'Author',        # 1
    'TextualMark',   # 2
    'TempoNumeral',  # 3
    'MeasureNumber', # 4
    'Times',         # 5
    'Chord',         # 6
    'PageMargin',    # 7
    'Instrument',    # 8
    'Other',         # 9
    'Lyric',         # 10
    'Alter1',        # 11
    'Alter2',        # 12
]
# ===================== ResNet-18 Backbone =====================
def _conv3x3(in_planes, out_planes, stride=1):
return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
padding=1, bias=False)
class _BasicBlock(nn.Module):
expansion = 1
def __init__(self, inplanes, planes, stride=1, downsample=None):
super().__init__()
self.conv1 = _conv3x3(inplanes, planes, stride)
self.bn1 = nn.BatchNorm2d(planes)
self.relu = nn.ReLU(inplace=True)
self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1, bias=False)
self.bn2 = nn.BatchNorm2d(planes)
self.downsample = downsample
def forward(self, x):
residual = x
out = self.relu(self.bn1(self.conv1(x)))
out = self.bn2(self.conv2(out))
if self.downsample is not None:
residual = self.downsample(x)
out += residual
return self.relu(out)
class _ResNet(nn.Module):
def __init__(self, block, layers):
super().__init__()
self.inplanes = 64
self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
self.bn1 = nn.BatchNorm2d(64)
self.relu = nn.ReLU(inplace=True)
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
self.layer1 = self._make_layer(block, 64, layers[0])
self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
# Not used in forward but exist in original model (needed for weight loading)
self.avgpool = nn.AvgPool2d(7, stride=1)
self.fc = nn.Linear(512, 1000)
self.smooth = nn.Conv2d(2048, 256, kernel_size=1, stride=1, padding=1)
def _make_layer(self, block, planes, blocks, stride=1):
downsample = None
if stride != 1 or self.inplanes != planes * block.expansion:
downsample = nn.Sequential(
nn.Conv2d(self.inplanes, planes * block.expansion,
kernel_size=1, stride=stride, bias=False),
nn.BatchNorm2d(planes * block.expansion),
)
layers = [block(self.inplanes, planes, stride, downsample)]
self.inplanes = planes * block.expansion
for _ in range(1, blocks):
layers.append(block(self.inplanes, planes))
return nn.Sequential(*layers)
def forward(self, x):
x = self.maxpool(self.relu(self.bn1(self.conv1(x))))
x2 = self.layer1(x)
x3 = self.layer2(x2)
x4 = self.layer3(x3)
x5 = self.layer4(x4)
return x2, x3, x4, x5
# ===================== SegDetector Decoder =====================
class _SegDetector(nn.Module):
def __init__(self, n_cls=13, in_channels=(64, 128, 256, 512),
inner_channels=256, k=50, bias=False, adaptive=True):
super().__init__()
self.k = k
self.up5 = nn.Upsample(scale_factor=2, mode='nearest')
self.up4 = nn.Upsample(scale_factor=2, mode='nearest')
self.up3 = nn.Upsample(scale_factor=2, mode='nearest')
self.in5 = nn.Conv2d(in_channels[-1], inner_channels, 1, bias=bias)
self.in4 = nn.Conv2d(in_channels[-2], inner_channels, 1, bias=bias)
self.in3 = nn.Conv2d(in_channels[-3], inner_channels, 1, bias=bias)
self.in2 = nn.Conv2d(in_channels[-4], inner_channels, 1, bias=bias)
self.out5 = nn.Sequential(
nn.Conv2d(inner_channels, inner_channels // 4, 3, padding=1, bias=bias),
nn.Upsample(scale_factor=8, mode='nearest'))
self.out4 = nn.Sequential(
nn.Conv2d(inner_channels, inner_channels // 4, 3, padding=1, bias=bias),
nn.Upsample(scale_factor=4, mode='nearest'))
self.out3 = nn.Sequential(
nn.Conv2d(inner_channels, inner_channels // 4, 3, padding=1, bias=bias),
nn.Upsample(scale_factor=2, mode='nearest'))
self.out2 = nn.Conv2d(inner_channels, inner_channels // 4, 3, padding=1, bias=bias)
self.binarize = nn.Sequential(
nn.Conv2d(inner_channels, inner_channels // 4, 3, padding=1, bias=bias),
nn.BatchNorm2d(inner_channels // 4),
nn.ReLU(inplace=True),
nn.ConvTranspose2d(inner_channels // 4, inner_channels // 4, 2, 2),
nn.BatchNorm2d(inner_channels // 4),
nn.ReLU(inplace=True),
nn.ConvTranspose2d(inner_channels // 4, 1, 2, 2),
nn.Sigmoid())
self.mulclass = nn.Sequential(
nn.Conv2d(inner_channels, inner_channels // 4, 3, padding=1, bias=bias),
nn.BatchNorm2d(inner_channels // 4),
nn.ReLU(inplace=True),
nn.ConvTranspose2d(inner_channels // 4, inner_channels // 4, 2, 2),
nn.BatchNorm2d(inner_channels // 4),
nn.ReLU(inplace=True),
nn.ConvTranspose2d(inner_channels // 4, n_cls, 2, 2))
if adaptive:
self.thresh = nn.Sequential(
nn.Conv2d(inner_channels, inner_channels // 4, 3, padding=1, bias=bias),
nn.BatchNorm2d(inner_channels // 4),
nn.ReLU(inplace=True),
nn.ConvTranspose2d(inner_channels // 4, inner_channels // 4, 2, 2),
nn.BatchNorm2d(inner_channels // 4),
nn.ReLU(inplace=True),
nn.ConvTranspose2d(inner_channels // 4, 1, 2, 2),
nn.Sigmoid())
def forward(self, features):
c2, c3, c4, c5 = features
in5 = self.in5(c5)
in4 = self.in4(c4)
in3 = self.in3(c3)
in2 = self.in2(c2)
out4 = self.up5(in5) + in4
out3 = self.up4(out4) + in3
out2 = self.up3(out3) + in2
p5 = self.out5(in5)
p4 = self.out4(out4)
p3 = self.out3(out3)
p2 = self.out2(out2)
fuse = torch.cat((p5, p4, p3, p2), 1)
binary = self.binarize(fuse)
mcls = self.mulclass(fuse)
return binary, mcls
class _LocModel(nn.Module):
    """Full localization network: ResNet-18 features fed into a DB decoder.

    forward() returns the decoder's (binary_map, class_logits) pair.
    """

    def __init__(self):
        super().__init__()
        self.backbone = _ResNet(_BasicBlock, [2, 2, 2, 2])
        self.decoder = _SegDetector(n_cls=13, adaptive=True, k=50)

    def forward(self, x):
        features = self.backbone(x)
        return self.decoder(features)
def _load_loc_model(model_path, device):
    """Load a loc model, handling both TorchScript and state_dict formats.

    Tries ``torch.jit.load`` first; if the file is not a TorchScript archive,
    builds the ``_LocModel`` architecture and loads the checkpoint into it.

    Args:
        model_path: Path (possibly unresolved) to the model file.
        device: Torch device string/object the model should live on.

    Returns:
        An eval-mode module on ``device``.
    """
    resolved = resolve_model_path(model_path)
    # Fast path: TorchScript archives carry their own architecture.
    try:
        model = torch.jit.load(resolved, map_location=device)
        model.eval()
        logging.info('LocService: TorchScript model loaded: %s', resolved)
        return model
    except Exception as e:
        logging.info('LocService: not TorchScript (%s), trying state_dict...', str(e)[:60])
    # Fall back to state_dict loading with the embedded architecture.
    # NOTE(security): weights_only=False unpickles arbitrary objects; only
    # load checkpoints from trusted sources.
    checkpoint = torch.load(resolved, map_location=device, weights_only=False)
    # Training checkpoints wrap the weights under a 'model' key; anything else
    # (including a plain OrderedDict) is treated as the state_dict itself.
    if isinstance(checkpoint, dict) and 'model' in checkpoint:
        state_dict = checkpoint['model']
    else:
        state_dict = checkpoint
    state_dict = _strip_wrapper_prefixes(state_dict)
    model = _LocModel()
    # strict=False tolerates keys the inference graph does not use.
    model.load_state_dict(state_dict, strict=False)
    model.eval()
    model.to(device)
    # Log how many architecture keys the checkpoint actually covered.
    model_keys = set(model.state_dict().keys())
    matched = model_keys & set(state_dict.keys())
    logging.info('LocService: state_dict loaded: %s (%d/%d keys matched)',
                 resolved, len(matched), len(model_keys))
    return model


def _strip_wrapper_prefixes(state_dict):
    """Strip 'model.module.' / 'module.' key prefixes left by SegDetectorModel
    and/or DataParallel wrapping, returning a new OrderedDict."""
    cleaned = OrderedDict()
    for key, value in state_dict.items():
        for prefix in ('model.module.', 'module.'):
            if key.startswith(prefix):
                key = key[len(prefix):]
                break
        cleaned[key] = value
    return cleaned
class LocService:
    """
    Location detection service.
    Uses DB_gc_loc architecture (ResNet-18 + SegDetector).
    Supports both TorchScript models and state_dict checkpoints.

    predict() yields, for each input image buffer, a list of rotated-box
    dicts with corner coordinates in original-image space, a detection
    'score' and a 'class' id indexing TYPE_NAMES.
    """

    def __init__(self, model_path, device='cuda', image_short_side=736,
                 box_thresh=0.01, class_num=13, **kwargs):
        # Extra kwargs are accepted and ignored so config dicts can be
        # passed through unchanged.
        self.device = device
        self.model = _load_loc_model(model_path, device)
        self.image_short_side = image_short_side  # target short side; long side snapped to /32
        self.box_thresh = box_thresh              # probability cutoff for the binary map
        self.class_num = class_num                # number of text categories

    def resize_image(self, img):
        """Resize keeping aspect ratio, with short side = image_short_side.

        The long side is rounded up to a multiple of 32, as required by the
        stride-32 backbone.
        """
        height, width = img.shape[:2]
        if height < width:
            new_height = self.image_short_side
            new_width = int(math.ceil(new_height / height * width / 32) * 32)
        else:
            new_width = self.image_short_side
            new_height = int(math.ceil(new_width / width * height / 32) * 32)
        return cv2.resize(img, (new_width, new_height))

    def preprocess(self, image):
        """Convert an HxWx3 image array to a normalized NCHW tensor.

        Returns:
            (tensor, original_shape) — original_shape is the (height, width)
            of the input, needed later to rescale detected boxes.
        """
        img = image.astype('float32')
        original_shape = img.shape[:2]
        img = self.resize_image(img)
        # Mean-subtract then scale (normalization from the original recipe).
        img -= RGB_MEAN
        img /= 255.
        img = torch.from_numpy(img).permute(2, 0, 1).float().unsqueeze(0)
        return img.to(self.device), original_shape

    def represent_boxes(self, pred, out_class, original_shape, resized_shape):
        """Post-process model output to extract bounding boxes.

        Thresholds the probability map, finds connected regions, and for each
        region emits its min-area rotated rectangle, its mean probability as
        'score', and the majority per-pixel class as 'class'.
        """
        pred_np = pred.cpu().numpy()[0, 0]
        class_np = out_class.cpu().numpy()[0, 0]
        binary = (pred_np > self.box_thresh).astype(np.uint8) * 255
        contours, _ = cv2.findContours(binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        boxes = []
        # Scale factors mapping resized-map coordinates back to the original.
        h_scale = original_shape[0] / resized_shape[0]
        w_scale = original_shape[1] / resized_shape[1]
        for contour in contours:
            if len(contour) < 4:
                continue  # too few points for a meaningful rotated rect
            rect = cv2.minAreaRect(contour)
            # astype(np.intp) replaces np.int0, which was deprecated in
            # NumPy 1.24 and removed in NumPy 2.0 (same truncation semantics).
            box_points = cv2.boxPoints(rect).astype(np.intp)
            # Rasterize the contour to average scores / vote classes inside it.
            mask = np.zeros(pred_np.shape, dtype=np.uint8)
            cv2.drawContours(mask, [contour], -1, 1, -1)
            if mask.sum() > 0:
                # Majority vote of the per-pixel class map inside the region.
                region_classes = class_np[mask > 0].astype(int)
                box_class = int(np.argmax(np.bincount(region_classes)))
            else:
                box_class = 0
            score = (pred_np * mask).sum() / max(mask.sum(), 1)
            scaled_points = box_points.astype(float)
            scaled_points[:, 0] *= w_scale
            scaled_points[:, 1] *= h_scale
            boxes.append({
                'x0': float(scaled_points[0, 0]),
                'y0': float(scaled_points[0, 1]),
                'x1': float(scaled_points[1, 0]),
                'y1': float(scaled_points[1, 1]),
                'x2': float(scaled_points[2, 0]),
                'y2': float(scaled_points[2, 1]),
                'x3': float(scaled_points[3, 0]),
                'y3': float(scaled_points[3, 1]),
                'score': float(score),
                'class': box_class,
            })
        return boxes

    def predict(self, buffers, **kwargs):
        """Detect text regions in images.

        Yields one list of box dicts per buffer; an empty list when the
        buffer cannot be decoded into an image.
        """
        for buffer in buffers:
            image = array_from_image_stream(buffer)
            if image is None:
                yield []
                continue
            img_tensor, original_shape = self.preprocess(image)
            resized_shape = (img_tensor.shape[2], img_tensor.shape[3])  # (H, W)
            with torch.no_grad():
                output = self.model(img_tensor)
            if isinstance(output, tuple) and len(output) == 2:
                pred, mcls = output
            else:
                # Single-output model: no class head. Allocate class_num
                # channels (zeros_like(pred) would have only 1 channel and
                # crash at the reshape below); everything becomes class 0.
                pred = output
                mcls = torch.zeros((pred.shape[0], self.class_num,
                                    pred.shape[2], pred.shape[3]),
                                   device=pred.device)
            b, c, h, w = mcls.shape
            # Per-pixel argmax over classes (softmax kept from the original;
            # it does not change the argmax).
            out_class = mcls.permute(0, 2, 3, 1).reshape(-1, self.class_num)
            out_class = F.softmax(out_class, -1)
            out_class = out_class.max(1)[1].reshape(b, h, w).unsqueeze(1)
            boxes = self.represent_boxes(pred, out_class, original_shape, resized_shape)
            # Drop degenerate boxes collapsed onto a vertical line.
            yield [
                box for box in boxes
                if not (box['x0'] == box['x1'] and box['x2'] == box['x1'])
            ]