"""
Location detection service.

Detects text regions and classifies them into 13 categories.
Uses DB_gc_loc architecture: ResNet-18 backbone + SegDetector decoder.
Supports both TorchScript (.pt) models and state_dict checkpoints.
"""

import os
import math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import cv2
import logging
from collections import OrderedDict

from predictors.torchscript_predictor import resolve_model_path
from common.image_utils import array_from_image_stream


# RGB mean for normalization (from original implementation)
RGB_MEAN = np.array([122.67891434, 116.66876762, 104.00698793])

# Total downsampling factor of the backbone (stride-2 stem conv, stride-2
# max-pool, three stride-2 stages). The decoder adds x2-upsampled maps to the
# next finer level, which only lines up when both input sides are multiples
# of this value.
_SIZE_DIVISOR = 32

# Checkpoint key prefixes added by the training wrapper / DataParallel.
# Order matters: the longer prefix must be tried first.
_STATE_DICT_PREFIXES = ('model.module.', 'module.')

# Text type categories (list index = class id)
TYPE_NAMES = [
    'Title',          # 0
    'Author',         # 1
    'TextualMark',    # 2
    'TempoNumeral',   # 3
    'MeasureNumber',  # 4
    'Times',          # 5
    'Chord',          # 6
    'PageMargin',     # 7
    'Instrument',     # 8
    'Other',          # 9
    'Lyric',          # 10
    'Alter1',         # 11
    'Alter2',         # 12
]


# ===================== ResNet-18 Backbone =====================

def _conv3x3(in_planes, out_planes, stride=1):
    """Return a bias-free 3x3 convolution with padding 1."""
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=1, bias=False)


class _BasicBlock(nn.Module):
    """ResNet basic block: two 3x3 conv+BN layers plus a residual shortcut.

    `downsample`, when given, projects the input so it can be added to the
    block output (used when stride or channel count changes).
    """

    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super().__init__()
        self.conv1 = _conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = _conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample

    def forward(self, x):
        residual = x
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        if self.downsample is not None:
            residual = self.downsample(x)
        out += residual
        return self.relu(out)


class _ResNet(nn.Module):
    """ResNet backbone returning the outputs of layer1..layer4."""

    def __init__(self, block, layers):
        super().__init__()
        self.inplanes = 64
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
                               bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
        # Not used in forward but exist in original model (needed for weight loading)
        self.avgpool = nn.AvgPool2d(7, stride=1)
        self.fc = nn.Linear(512, 1000)
        self.smooth = nn.Conv2d(2048, 256, kernel_size=1, stride=1, padding=1)

    def _make_layer(self, block, planes, blocks, stride=1):
        """Stack `blocks` residual blocks; the first one may downsample."""
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes * block.expansion),
            )
        layers = [block(self.inplanes, planes, stride, downsample)]
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes))
        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.maxpool(self.relu(self.bn1(self.conv1(x))))
        x2 = self.layer1(x)
        x3 = self.layer2(x2)
        x4 = self.layer3(x3)
        x5 = self.layer4(x4)
        return x2, x3, x4, x5


# ===================== SegDetector Decoder =====================

def _upsampling_head(inner_channels, out_channels, bias, sigmoid):
    """Build a conv-BN-ReLU + two x2 transposed-conv head.

    The Sequential indices match the original hand-written heads, so
    checkpoint keys (e.g. `binarize.3.weight`) still load.
    """
    mid = inner_channels // 4
    layers = [
        nn.Conv2d(inner_channels, mid, 3, padding=1, bias=bias),
        nn.BatchNorm2d(mid),
        nn.ReLU(inplace=True),
        nn.ConvTranspose2d(mid, mid, 2, 2),
        nn.BatchNorm2d(mid),
        nn.ReLU(inplace=True),
        nn.ConvTranspose2d(mid, out_channels, 2, 2),
    ]
    if sigmoid:
        layers.append(nn.Sigmoid())
    return nn.Sequential(*layers)


class _SegDetector(nn.Module):
    """FPN-style decoder producing a probability map and per-class logits.

    `forward` returns (binary, mcls): `binary` passes through a Sigmoid,
    `mcls` has `n_cls` channels of raw scores. `thresh` (created when
    `adaptive`) and `k` are not used in `forward`; they exist so checkpoints
    from the original model load.
    """

    def __init__(self, n_cls=13, in_channels=(64, 128, 256, 512),
                 inner_channels=256, k=50, bias=False, adaptive=True):
        super().__init__()
        self.k = k
        self.up5 = nn.Upsample(scale_factor=2, mode='nearest')
        self.up4 = nn.Upsample(scale_factor=2, mode='nearest')
        self.up3 = nn.Upsample(scale_factor=2, mode='nearest')
        self.in5 = nn.Conv2d(in_channels[-1], inner_channels, 1, bias=bias)
        self.in4 = nn.Conv2d(in_channels[-2], inner_channels, 1, bias=bias)
        self.in3 = nn.Conv2d(in_channels[-3], inner_channels, 1, bias=bias)
        self.in2 = nn.Conv2d(in_channels[-4], inner_channels, 1, bias=bias)
        self.out5 = nn.Sequential(
            nn.Conv2d(inner_channels, inner_channels // 4, 3, padding=1, bias=bias),
            nn.Upsample(scale_factor=8, mode='nearest'))
        self.out4 = nn.Sequential(
            nn.Conv2d(inner_channels, inner_channels // 4, 3, padding=1, bias=bias),
            nn.Upsample(scale_factor=4, mode='nearest'))
        self.out3 = nn.Sequential(
            nn.Conv2d(inner_channels, inner_channels // 4, 3, padding=1, bias=bias),
            nn.Upsample(scale_factor=2, mode='nearest'))
        self.out2 = nn.Conv2d(inner_channels, inner_channels // 4, 3, padding=1, bias=bias)
        self.binarize = _upsampling_head(inner_channels, 1, bias, sigmoid=True)
        self.mulclass = _upsampling_head(inner_channels, n_cls, bias, sigmoid=False)
        if adaptive:
            self.thresh = _upsampling_head(inner_channels, 1, bias, sigmoid=True)

    def forward(self, features):
        c2, c3, c4, c5 = features
        in5 = self.in5(c5)
        in4 = self.in4(c4)
        in3 = self.in3(c3)
        in2 = self.in2(c2)

        # Top-down pathway: upsample coarser map and add the lateral one.
        out4 = self.up5(in5) + in4
        out3 = self.up4(out4) + in3
        out2 = self.up3(out3) + in2

        p5 = self.out5(in5)
        p4 = self.out4(out4)
        p3 = self.out3(out3)
        p2 = self.out2(out2)

        fuse = torch.cat((p5, p4, p3, p2), 1)
        binary = self.binarize(fuse)
        mcls = self.mulclass(fuse)
        return binary, mcls


class _LocModel(nn.Module):
    """ResNet-18 + SegDetector combined model for text localization."""

    def __init__(self):
        super().__init__()
        self.backbone = _ResNet(_BasicBlock, [2, 2, 2, 2])
        self.decoder = _SegDetector(n_cls=13, adaptive=True, k=50)

    def forward(self, x):
        return self.decoder(self.backbone(x))


# ===================== Model loading =====================

def _unwrap_state_dict(checkpoint):
    """Return the state_dict stored under 'model', or the checkpoint itself."""
    if isinstance(checkpoint, dict) and 'model' in checkpoint:
        return checkpoint['model']
    return checkpoint


def _strip_module_prefix(key):
    """Remove a leading 'model.module.' or 'module.' from a checkpoint key."""
    for prefix in _STATE_DICT_PREFIXES:
        if key.startswith(prefix):
            return key[len(prefix):]
    return key


def _load_loc_model(model_path, device):
    """Load loc model, handling both TorchScript and state_dict formats.

    First tries TorchScript loading; on failure, builds the architecture and
    loads a state_dict (non-strict, with missing keys logged as a warning).
    """
    resolved = resolve_model_path(model_path)

    # Try TorchScript first. Any failure is treated as "not TorchScript".
    try:
        model = torch.jit.load(resolved, map_location=device)
    except Exception as e:
        logging.info('LocService: not TorchScript (%s), trying state_dict...', str(e)[:60])
    else:
        model.eval()
        logging.info('LocService: TorchScript model loaded: %s', resolved)
        return model

    # Fall back to state_dict loading with embedded architecture.
    # NOTE(review): weights_only=False unpickles arbitrary objects; only load
    # checkpoints from trusted locations.
    checkpoint = torch.load(resolved, map_location=device, weights_only=False)
    state_dict = _unwrap_state_dict(checkpoint)

    # Strip 'model.module.' prefix (SegDetectorModel -> DataParallel wrapping)
    new_state_dict = OrderedDict(
        (_strip_module_prefix(key), value) for key, value in state_dict.items())

    model = _LocModel()
    load_result = model.load_state_dict(new_state_dict, strict=False)
    model.eval()
    model.to(device)

    # Log key loading stats
    model_keys = set(model.state_dict().keys())
    matched = model_keys & set(new_state_dict.keys())
    logging.info('LocService: state_dict loaded: %s (%d/%d keys matched)',
                 resolved, len(matched), len(model_keys))
    if load_result.missing_keys:
        # strict=False would otherwise leave these randomly initialised silently.
        logging.warning('LocService: %d model keys missing from checkpoint (e.g. %s)',
                        len(load_result.missing_keys), load_result.missing_keys[:5])
    return model


def _ensure_three_channels(img):
    """Return `img` as an (H, W, 3) array.

    Grayscale (H, W) or (H, W, 1) input is replicated to 3 channels; a
    4-channel input has its last channel (assumed alpha) dropped. Any other
    shape is returned unchanged.
    """
    if img.ndim == 2:
        return np.repeat(img[:, :, None], 3, axis=2)
    if img.ndim == 3 and img.shape[2] == 1:
        return np.repeat(img, 3, axis=2)
    if img.ndim == 3 and img.shape[2] == 4:
        return np.ascontiguousarray(img[:, :, :3])
    return img


class LocService:
    """
    Location detection service.
    Uses DB_gc_loc architecture (ResNet-18 + SegDetector).
    Supports both TorchScript models and state_dict checkpoints.
    """

    def __init__(self, model_path, device='cuda', image_short_side=736,
                 box_thresh=0.01, class_num=13, **kwargs):
        self.device = device
        self.model = _load_loc_model(model_path, device)
        self.image_short_side = image_short_side
        self.box_thresh = box_thresh
        # Kept for interface compatibility; class ids are taken from the
        # channel dimension of the model's class output.
        self.class_num = class_num

    def resize_image(self, img):
        """Resize keeping aspect ratio, with short side = image_short_side.

        Both sides are rounded up to a multiple of 32 so the decoder's
        feature maps line up (the short side is unchanged when
        image_short_side is already a multiple of 32, e.g. the default 736).
        """
        height, width = img.shape[:2]
        short_side = int(math.ceil(self.image_short_side / _SIZE_DIVISOR) * _SIZE_DIVISOR)
        if height < width:
            new_height = short_side
            new_width = int(math.ceil(new_height / height * width / _SIZE_DIVISOR) * _SIZE_DIVISOR)
        else:
            new_width = short_side
            new_height = int(math.ceil(new_width / width * height / _SIZE_DIVISOR) * _SIZE_DIVISOR)
        return cv2.resize(img, (new_width, new_height))

    def preprocess(self, image):
        """Preprocess image for model input.

        Returns:
            (tensor of shape (1, 3, H', W') on self.device,
             (height, width) of the original image)
        """
        img = image.astype('float32')
        original_shape = img.shape[:2]
        img = _ensure_three_channels(img)
        img = self.resize_image(img)
        img -= RGB_MEAN
        img /= 255.
        tensor = torch.from_numpy(img).permute(2, 0, 1).float().unsqueeze(0)
        return tensor.to(self.device), original_shape

    @staticmethod
    def _region_class_and_score(contour, pred_np, class_np):
        """Return (majority class id, mean probability) inside a filled contour.

        Works on the contour's bounding rectangle only, rather than a
        full-image mask per contour. Returns (0, 0.0) for an empty region.
        """
        x, y, w, h = cv2.boundingRect(contour)
        mask = np.zeros((h, w), dtype=np.uint8)
        cv2.drawContours(mask, [contour], -1, 1, -1, offset=(-x, -y))
        inside = mask > 0
        count = int(inside.sum())
        if count == 0:
            return 0, 0.0
        classes = class_np[y:y + h, x:x + w][inside].astype(int)
        box_class = int(np.argmax(np.bincount(classes)))
        score = float(pred_np[y:y + h, x:x + w][inside].sum()) / count
        return box_class, score

    def represent_boxes(self, pred, out_class, original_shape, resized_shape):
        """Post-process model output to extract bounding boxes.

        Args:
            pred: probability map tensor; element [0, 0] is used, so
                presumably shaped (1, 1, H, W).
            out_class: per-pixel class-id tensor indexed the same way.
            original_shape: (height, width) of the input image.
            resized_shape: (height, width) of the network input.

        Returns:
            List of dicts with the four corners ('x0'..'y3') of each region's
            minimum-area rectangle in original-image coordinates, the mean
            probability inside the region ('score') and the most frequent
            class id inside it ('class').
        """
        pred_np = pred.cpu().numpy()[0, 0]
        class_np = out_class.cpu().numpy()[0, 0]
        binary = (pred_np > self.box_thresh).astype(np.uint8) * 255
        contours, _ = cv2.findContours(binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

        h_scale = original_shape[0] / resized_shape[0]
        w_scale = original_shape[1] / resized_shape[1]

        boxes = []
        for contour in contours:
            if len(contour) < 4:
                continue
            rect = cv2.minAreaRect(contour)
            # np.int0 was removed in NumPy 2.0; np.intp is the type it aliased.
            box_points = cv2.boxPoints(rect).astype(np.intp)

            box_class, score = self._region_class_and_score(contour, pred_np, class_np)

            scaled_points = box_points.astype(float)
            scaled_points[:, 0] *= w_scale
            scaled_points[:, 1] *= h_scale
            boxes.append({
                'x0': float(scaled_points[0, 0]), 'y0': float(scaled_points[0, 1]),
                'x1': float(scaled_points[1, 0]), 'y1': float(scaled_points[1, 1]),
                'x2': float(scaled_points[2, 0]), 'y2': float(scaled_points[2, 1]),
                'x3': float(scaled_points[3, 0]), 'y3': float(scaled_points[3, 1]),
                'score': float(score),
                'class': box_class,
            })
        return boxes

    def predict(self, buffers, **kwargs):
        """Detect text regions in images.

        Yields one list of box dicts per input buffer (an empty list when the
        buffer cannot be decoded). Boxes whose x0, x1 and x2 are all equal
        are dropped.
        """
        for buffer in buffers:
            image = array_from_image_stream(buffer)
            if image is None:
                yield []
                continue

            img_tensor, original_shape = self.preprocess(image)
            resized_shape = (img_tensor.shape[2], img_tensor.shape[3])

            with torch.no_grad():
                output = self.model(img_tensor)
                if isinstance(output, (tuple, list)) and len(output) == 2:
                    pred, mcls = output
                    # Per-pixel class id = channel with the highest score
                    # (softmax is monotonic, so it cannot change the argmax).
                    out_class = mcls.argmax(dim=1, keepdim=True)
                else:
                    # Model without a class head: every pixel gets class 0.
                    pred = output
                    out_class = torch.zeros_like(pred, dtype=torch.long)

            boxes = self.represent_boxes(pred, out_class, original_shape, resized_shape)
            valid_boxes = [
                box for box in boxes
                if not (box['x0'] == box['x1'] and box['x2'] == box['x1'])
            ]
            yield valid_boxes