Spaces:
Running
Running
| """ | |
| Location detection service. | |
| Detects text regions and classifies them into 13 categories. | |
| Uses DB_gc_loc architecture: ResNet-18 backbone + SegDetector decoder. | |
| Supports both TorchScript (.pt) models and state_dict checkpoints. | |
| """ | |
| import os | |
| import math | |
| import numpy as np | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| import cv2 | |
| import logging | |
| from collections import OrderedDict | |
| from predictors.torchscript_predictor import resolve_model_path | |
| from common.image_utils import array_from_image_stream | |
# Per-channel mean subtracted from the image during preprocessing
# (values taken from the original implementation).
RGB_MEAN = np.array((122.67891434, 116.66876762, 104.00698793))

# Text type categories. A name's position in the list is the class index
# shown in the trailing comments.
TYPE_NAMES = [
    'Title', 'Author', 'TextualMark', 'TempoNumeral',    # 0-3
    'MeasureNumber', 'Times', 'Chord', 'PageMargin',     # 4-7
    'Instrument', 'Other', 'Lyric', 'Alter1', 'Alter2',  # 8-12
]
| # ===================== ResNet-18 Backbone ===================== | |
def _conv3x3(in_planes, out_planes, stride=1):
    """Build a bias-free 3x3 convolution with padding 1.

    Args:
        in_planes: Number of input channels.
        out_planes: Number of output channels.
        stride: Convolution stride (default 1).

    Returns:
        The configured ``nn.Conv2d`` module.
    """
    conv_options = dict(kernel_size=3, stride=stride, padding=1, bias=False)
    return nn.Conv2d(in_planes, out_planes, **conv_options)
class _BasicBlock(nn.Module):
    """Residual block made of two 3x3 convolutions, each followed by batch norm.

    The input (optionally passed through ``downsample``) is added to the
    output of the second batch norm before the final ReLU.
    """

    # Output channels = planes * expansion.
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        """
        Args:
            inplanes: Channels of the incoming tensor.
            planes: Channels produced by both convolutions.
            stride: Stride of the first convolution.
            downsample: Optional module applied to the input so it can be
                added to the main path; ``None`` uses the input as is.
        """
        super().__init__()
        # Attribute names and creation order are kept as-is: they determine
        # the keys of the module's state_dict.
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=3, stride=stride,
                               padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample

    def forward(self, x):
        """Return ``relu(main_path(x) + shortcut(x))``."""
        shortcut = x if self.downsample is None else self.downsample(x)

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)

        out += shortcut
        return self.relu(out)
class _ResNet(nn.Module):
    """ResNet feature extractor that returns the outputs of its four stages."""

    def __init__(self, block, layers):
        """
        Args:
            block: Residual block class; must expose an ``expansion`` attribute
                and accept ``(inplanes, planes, stride, downsample)``.
            layers: Sequence giving the number of blocks in each of the four
                stages.
        """
        super().__init__()
        # Channels entering the next stage; updated by _make_layer.
        self.inplanes = 64

        # Stem: 7x7 stride-2 convolution followed by a stride-2 max pool.
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        # Four residual stages; every stage after the first halves the resolution.
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)

        # Never called in forward(); they exist in the original model and are
        # kept so that its weights can be loaded.
        self.avgpool = nn.AvgPool2d(7, stride=1)
        self.fc = nn.Linear(512, 1000)
        self.smooth = nn.Conv2d(2048, 256, kernel_size=1, stride=1, padding=1)

    def _make_layer(self, block, planes, blocks, stride=1):
        """Assemble one stage of ``blocks`` residual blocks.

        Only the first block receives ``stride``; it also gets a 1x1
        convolution + batch-norm shortcut whenever the stride or channel
        count changes.
        """
        stage_channels = planes * block.expansion

        shortcut = None
        if not (stride == 1 and self.inplanes == stage_channels):
            shortcut = nn.Sequential(
                nn.Conv2d(self.inplanes, stage_channels,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(stage_channels),
            )

        stage = [block(self.inplanes, planes, stride, shortcut)]
        self.inplanes = stage_channels
        stage.extend(block(self.inplanes, planes) for _ in range(blocks - 1))
        return nn.Sequential(*stage)

    def forward(self, x):
        """Return the feature maps of layer1..layer4 as a 4-tuple."""
        stem = self.conv1(x)
        stem = self.bn1(stem)
        stem = self.relu(stem)
        stem = self.maxpool(stem)

        c2 = self.layer1(stem)
        c3 = self.layer2(c2)
        c4 = self.layer3(c3)
        c5 = self.layer4(c4)
        return c2, c3, c4, c5
| # ===================== SegDetector Decoder ===================== | |
class _SegDetector(nn.Module):
    """Decoder that fuses four backbone feature maps into two dense outputs.

    ``forward`` returns a single-channel sigmoid map and an ``n_cls``-channel
    map of raw class scores, both at 4x the resolution of the finest input
    feature map.
    """

    def __init__(self, n_cls=13, in_channels=(64, 128, 256, 512),
                 inner_channels=256, k=50, bias=False, adaptive=True):
        """
        Args:
            n_cls: Number of channels of the class-score output.
            in_channels: Channel counts of the backbone feature maps, finest
                first (the last four entries are used).
            inner_channels: Channel width after the 1x1 lateral convolutions.
            k: Stored on the instance; not used by ``forward``.
            bias: Whether the plain convolutions carry a bias term.
            adaptive: If true, also build the ``thresh`` head. ``forward``
                does not evaluate it.
        """
        super().__init__()
        self.k = k
        quarter = inner_channels // 4

        def lateral(source_channels):
            # 1x1 projection of a backbone map to the common width.
            return nn.Conv2d(source_channels, inner_channels, 1, bias=bias)

        def smooth():
            # 3x3 convolution reducing a merged map to a quarter of the width.
            return nn.Conv2d(inner_channels, quarter, 3, padding=1, bias=bias)

        def head(out_channels, with_sigmoid):
            # conv -> BN -> ReLU -> 2x deconv -> BN -> ReLU -> 2x deconv [-> sigmoid]
            stages = [
                smooth(),
                nn.BatchNorm2d(quarter),
                nn.ReLU(inplace=True),
                nn.ConvTranspose2d(quarter, quarter, 2, 2),
                nn.BatchNorm2d(quarter),
                nn.ReLU(inplace=True),
                nn.ConvTranspose2d(quarter, out_channels, 2, 2),
            ]
            if with_sigmoid:
                stages.append(nn.Sigmoid())
            return nn.Sequential(*stages)

        # Attribute names and creation order below are kept unchanged because
        # they define the state_dict keys of the module.
        self.up5 = nn.Upsample(scale_factor=2, mode='nearest')
        self.up4 = nn.Upsample(scale_factor=2, mode='nearest')
        self.up3 = nn.Upsample(scale_factor=2, mode='nearest')

        self.in5 = lateral(in_channels[-1])
        self.in4 = lateral(in_channels[-2])
        self.in3 = lateral(in_channels[-3])
        self.in2 = lateral(in_channels[-4])

        # Each pyramid level is upsampled to the resolution of the finest one.
        self.out5 = nn.Sequential(smooth(), nn.Upsample(scale_factor=8, mode='nearest'))
        self.out4 = nn.Sequential(smooth(), nn.Upsample(scale_factor=4, mode='nearest'))
        self.out3 = nn.Sequential(smooth(), nn.Upsample(scale_factor=2, mode='nearest'))
        self.out2 = smooth()

        self.binarize = head(1, with_sigmoid=True)
        self.mulclass = head(n_cls, with_sigmoid=False)
        if adaptive:
            self.thresh = head(1, with_sigmoid=True)

    def forward(self, features):
        """Decode backbone features.

        Args:
            features: Four feature maps ``(c2, c3, c4, c5)``, finest first,
                each half the resolution of the previous one.

        Returns:
            ``(binary, mcls)``: the sigmoid map from ``binarize`` and the raw
            class scores from ``mulclass``.
        """
        c2, c3, c4, c5 = features

        lat5 = self.in5(c5)
        lat4 = self.in4(c4)
        lat3 = self.in3(c3)
        lat2 = self.in2(c2)

        # Top-down pathway: upsample the coarser level and add the lateral one.
        merged4 = self.up5(lat5) + lat4
        merged3 = self.up4(merged4) + lat3
        merged2 = self.up3(merged3) + lat2

        pyramid = (
            self.out5(lat5),
            self.out4(merged4),
            self.out3(merged3),
            self.out2(merged2),
        )
        fused = torch.cat(pyramid, 1)

        return self.binarize(fused), self.mulclass(fused)
class _LocModel(nn.Module):
    """Text-localization network: a ``_ResNet`` backbone feeding a ``_SegDetector``."""

    def __init__(self):
        super().__init__()
        # [2, 2, 2, 2] basic blocks per stage is the ResNet-18 layout.
        self.backbone = _ResNet(_BasicBlock, [2, 2, 2, 2])
        self.decoder = _SegDetector(n_cls=13, adaptive=True, k=50)

    def forward(self, x):
        """Run the backbone on ``x`` and decode its feature maps."""
        features = self.backbone(x)
        return self.decoder(features)
def _load_loc_model(model_path, device):
    """Load the localization model from a TorchScript file or a state_dict checkpoint.

    TorchScript loading is attempted first. If that raises, the file is read
    with ``torch.load`` and its weights are loaded into a freshly built
    ``_LocModel``.

    Args:
        model_path: Model location; passed through ``resolve_model_path``.
        device: Device the model (and the loaded tensors) are mapped to.

    Returns:
        The model in eval mode on ``device``.

    Raises:
        Whatever ``torch.load`` / ``load_state_dict`` raise when the file is
        neither a TorchScript archive nor a compatible checkpoint.
    """
    resolved = resolve_model_path(model_path)

    # Try TorchScript first. Only the load itself is guarded; the exception
    # types raised for a non-TorchScript file are not uniform, hence the
    # broad catch.
    try:
        model = torch.jit.load(resolved, map_location=device)
    except Exception as e:
        logging.info('LocService: not TorchScript (%s), trying state_dict...', str(e)[:60])
    else:
        model.eval()
        logging.info('LocService: TorchScript model loaded: %s', resolved)
        return model

    # Fall back to state_dict loading with the embedded architecture.
    # SECURITY NOTE: weights_only=False unpickles arbitrary objects and can
    # execute code embedded in the file; only load checkpoints from trusted
    # sources.
    checkpoint = torch.load(resolved, map_location=device, weights_only=False)
    if isinstance(checkpoint, dict) and 'model' in checkpoint:
        state_dict = checkpoint['model']
    else:
        state_dict = checkpoint

    # Strip wrapper prefixes ('model.module.' from SegDetectorModel +
    # DataParallel, or a bare 'module.'). At most one prefix is removed per
    # key, the longer one being tried first.
    prefixes = ('model.module.', 'module.')
    cleaned = OrderedDict()
    for key, value in state_dict.items():
        for prefix in prefixes:
            if key.startswith(prefix):
                key = key[len(prefix):]
                break
        cleaned[key] = value

    model = _LocModel()
    # strict=False tolerates missing/unexpected keys; the result reports them.
    result = model.load_state_dict(cleaned, strict=False)
    model.eval()
    model.to(device)

    # Log key loading stats.
    total = len(model.state_dict())
    matched = total - len(result.missing_keys)
    logging.info('LocService: state_dict loaded: %s (%d/%d keys matched)',
                 resolved, matched, total)
    if matched == 0:
        # Nothing was loaded: the model is running on its initial weights.
        logging.error('LocService: no checkpoint key matched the model; '
                      'predictions will come from uninitialized weights: %s', resolved)
    elif result.missing_keys:
        logging.warning('LocService: %d model keys missing from checkpoint '
                        '(left at initial values), e.g. %s',
                        len(result.missing_keys), result.missing_keys[:5])
    if result.unexpected_keys:
        logging.warning('LocService: %d checkpoint keys ignored, e.g. %s',
                        len(result.unexpected_keys), result.unexpected_keys[:5])
    return model
class LocService:
    """
    Location detection service.
    Uses DB_gc_loc architecture (ResNet-18 + SegDetector).
    Supports both TorchScript models and state_dict checkpoints.
    """

    def __init__(self, model_path, device='cuda', image_short_side=736,
                 box_thresh=0.01, class_num=13, **kwargs):
        """
        Args:
            model_path: Model location handed to ``_load_loc_model``.
            device: Device the model and the input tensors are placed on.
            image_short_side: Length the shorter image side is resized to.
            box_thresh: Threshold applied to the probability map when
                binarizing it for contour extraction.
            class_num: Number of text classes. Stored for callers; per-pixel
                class decoding uses the channel count of the model output.
            **kwargs: Accepted and ignored.
        """
        self.device = device
        self.model = _load_loc_model(model_path, device)
        self.image_short_side = image_short_side
        self.box_thresh = box_thresh
        self.class_num = class_num

    def resize_image(self, img):
        """Resize image keeping aspect ratio, with short side = image_short_side.

        The longer side is rounded up to a multiple of 32.
        """
        height, width = img.shape[:2]
        if height < width:
            new_height = self.image_short_side
            new_width = int(math.ceil(new_height / height * width / 32) * 32)
        else:
            new_width = self.image_short_side
            new_height = int(math.ceil(new_width / width * height / 32) * 32)
        return cv2.resize(img, (new_width, new_height))

    def preprocess(self, image):
        """Preprocess image for model input.

        Returns:
            ``(tensor, original_shape)``: the resized, mean-subtracted and
            scaled image as a ``(1, C, H, W)`` float tensor on
            ``self.device``, and the ``(height, width)`` of the input image.
        """
        img = image.astype('float32')
        original_shape = img.shape[:2]
        img = self.resize_image(img)
        img -= RGB_MEAN
        img /= 255.
        img = torch.from_numpy(img).permute(2, 0, 1).float().unsqueeze(0)
        return img.to(self.device), original_shape

    def represent_boxes(self, pred, out_class, original_shape, resized_shape):
        """Post-process model output to extract bounding boxes.

        Args:
            pred: Probability map tensor; only element ``[0, 0]`` is used.
            out_class: Per-pixel class-index tensor; only ``[0, 0]`` is used.
            original_shape: ``(height, width)`` of the source image.
            resized_shape: ``(height, width)`` of the network input.

        Returns:
            A list of dicts with the four corner coordinates
            (``x0``/``y0`` .. ``x3``/``y3``, scaled back to the source
            image), the mean probability ``score`` inside the contour and
            the majority ``class`` inside the contour.
        """
        pred_np = pred.cpu().numpy()[0, 0]
        class_np = out_class.cpu().numpy()[0, 0]
        binary = (pred_np > self.box_thresh).astype(np.uint8) * 255
        contours, _ = cv2.findContours(binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

        h_scale = original_shape[0] / resized_shape[0]
        w_scale = original_shape[1] / resized_shape[1]

        boxes = []
        for contour in contours:
            if len(contour) < 4:
                continue
            rect = cv2.minAreaRect(contour)
            # np.int0 was removed in NumPy 2.0; np.intp is the type it aliased.
            box_points = cv2.boxPoints(rect).astype(np.intp)

            # Filled contour mask selects the pixels belonging to this region.
            mask = np.zeros(pred_np.shape, dtype=np.uint8)
            cv2.drawContours(mask, [contour], -1, 1, -1)
            inside = mask > 0
            area = int(inside.sum())
            if area > 0:
                # Majority vote over the per-pixel classes inside the region.
                box_class = int(np.argmax(np.bincount(class_np[inside].astype(int))))
                score = float(pred_np[inside].sum()) / area
            else:
                box_class = 0
                score = 0.0

            scaled_points = box_points.astype(float)
            scaled_points[:, 0] *= w_scale
            scaled_points[:, 1] *= h_scale
            box = {}
            for i in range(4):
                box['x%d' % i] = float(scaled_points[i, 0])
                box['y%d' % i] = float(scaled_points[i, 1])
            box['score'] = float(score)
            box['class'] = box_class
            boxes.append(box)
        return boxes

    @staticmethod
    def _split_output(output):
        """Split a model output into ``(pred, out_class)``.

        A 2-element tuple/list is treated as ``(pred, class_scores)`` and the
        class scores are reduced to a ``(b, 1, h, w)`` index map by argmax
        over the channel axis (softmax is monotonic, so it is not needed to
        find the winning class). Any other output is treated as the
        probability map alone and every pixel is assigned class 0.
        """
        if isinstance(output, (tuple, list)) and len(output) == 2:
            pred, mcls = output
            return pred, mcls.argmax(dim=1, keepdim=True)
        pred = output
        return pred, torch.zeros_like(pred, dtype=torch.long)

    def predict(self, buffers, **kwargs):
        """Detect text regions in images.

        Args:
            buffers: Iterable of image streams readable by
                ``array_from_image_stream``.
            **kwargs: Accepted and ignored.

        Yields:
            One list of box dicts per buffer (empty when the image could not
            be decoded).
        """
        for buffer in buffers:
            image = array_from_image_stream(buffer)
            if image is None:
                yield []
                continue

            img_tensor, original_shape = self.preprocess(image)
            resized_shape = (img_tensor.shape[2], img_tensor.shape[3])

            with torch.no_grad():
                output = self.model(img_tensor)
            pred, out_class = self._split_output(output)

            boxes = self.represent_boxes(pred, out_class, original_shape, resized_shape)
            # Drop degenerate boxes whose first three corners share one x coordinate.
            yield [
                box for box in boxes
                if not (box['x0'] == box['x1'] and box['x2'] == box['x1'])
            ]