"""
OCR prediction service.

Recognizes text content from detected regions using DenseNet-CTC.
"""

import logging
import math

import cv2
import numpy as np

from common.image_utils import array_from_image_stream
from predictors.densenet_ctc import load_densenet_ctc, greedy_ctc_decode


# Text type categories (same as loc_service); the list index is the class id.
TYPE_NAMES = [
    'Title',           # 0
    'Author',          # 1
    'TextualMark',     # 2
    'TempoNumeral',    # 3
    'MeasureNumber',   # 4
    'Times',           # 5
    'Chord',           # 6
    'PageMargin',      # 7
    'Instrument',      # 8
    'Other',           # 9
    'Lyric',           # 10
    'Alter1',          # 11
    'Alter2',          # 12
]

# Label used when a detected box carries a class id that TYPE_NAMES cannot map.
FALLBACK_TYPE_NAME = 'Other'

# Tempo character mapping (musical note symbols)
TEMPO_CHAR_DICT = {
    'a': '\U0001d15d',  # Whole Note
    'b': '\U0001d15e',  # Half Note
    'c': '\U0001d15f',  # Quarter Note
    'd': '\U0001d160',  # Eighth Note
    'e': '\U0001d161',  # Sixteenth Note
    'f': '\U0001d162',  # Thirty-Second Note
}


def translate_string_by_dict(s, d):
    """Translate characters in string using dictionary.

    Each character of ``s`` that is a key of ``d`` is replaced by the mapped
    value; every other character is kept unchanged.
    """
    return ''.join(d.get(c, c) for c in s)


def _resolve_text_type(class_id):
    """Map a detector class id to its TYPE_NAMES label.

    Anything that is not an integer inside ``[0, len(TYPE_NAMES))`` falls back
    to FALLBACK_TYPE_NAME with a warning. Plain list indexing would raise for
    out-of-range or non-integer ids (aborting the whole page) and would
    silently wrap negative ids to an unrelated label.
    """
    is_integer = isinstance(class_id, (int, np.integer))
    if is_integer and 0 <= class_id < len(TYPE_NAMES):
        return TYPE_NAMES[class_id]
    logging.warning(
        'Unknown text class %r; falling back to %s',
        class_id, FALLBACK_TYPE_NAME,
    )
    return FALLBACK_TYPE_NAME


class OcrService:
    """
    OCR service using DenseNet-CTC architecture.

    Holds a general text model plus optional dedicated models for tempo
    numerals and chord symbols; ``predict`` picks the model per box type.
    """

    def __init__(self, model_path, device='gpu', alphabet=None,
                 tempo_model_path=None, tempo_alphabet=None,
                 chord_model_path=None, **kwargs):
        """Load the recognition models.

        model_path: weights for the general DenseNet-CTC model.
        device: accepted for interface compatibility; not used in this class.
        alphabet: characters of the general model (defaults to '').
        tempo_model_path: optional weights for the tempo model.
        tempo_alphabet: characters of the tempo model (defaults to '').
        chord_model_path: optional chord model directory; it is only loaded
            when the path ends with '/'.
        kwargs: ignored.
        """
        self.alphabet = alphabet or ''
        self.tempo_alphabet = tempo_alphabet or ''

        # nclass = alphabet length + 1 (blank token '卍')
        nclass = len(self.alphabet) + 1
        self.model = load_densenet_ctc(model_path, nclass)

        # Load tempo model
        self.tempo_model = None
        if tempo_model_path:
            tempo_nclass = len(self.tempo_alphabet) + 1
            self.tempo_model = load_densenet_ctc(tempo_model_path, tempo_nclass)

        # Chord model (SavedModel directory, different architecture).
        # NOTE(review): a path without a trailing '/' is silently ignored --
        # confirm that this is the intended way to disable the chord model.
        self.chord_model = None
        if chord_model_path and chord_model_path.endswith('/'):
            try:
                # Imported lazily so tensorflow is only needed for chords.
                import tensorflow as tf
                self.chord_model = tf.keras.models.load_model(
                    chord_model_path, compile=False)
                logging.info('Chord model loaded: %s', chord_model_path)
            except Exception as e:
                # Best effort: chord boxes fall back to the general model.
                logging.warning('Failed to load chord model: %s', e)

    def preprocess_image(self, image, target_height=32):
        """Preprocess image for DenseNet-CTC model.

        The image is converted to grayscale, resized to ``target_height``
        keeping the aspect ratio, shifted to ``value / 255 - 0.5`` and
        reshaped to (1, H, W, 1).

        Returns None when the image has no rows or when the resized width
        would be smaller than 9 pixels.
        """
        h, w = image.shape[:2]
        if h < 1:
            # Nothing to scale; also avoids dividing by zero below.
            return None

        if image.ndim == 3:
            if image.shape[2] == 1:
                # Already single-channel; cvtColor rejects 1-channel input.
                image = image.reshape(h, w)
            else:
                image = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)

        scale = target_height / h
        new_w = int(w * scale)
        if new_w < 9:
            return None

        image = cv2.resize(image, (new_w, target_height))
        # assumes pixel values are in 0..255 -- TODO confirm against
        # array_from_image_stream
        image = image.astype(np.float32) / 255.0 - 0.5
        return np.expand_dims(image, axis=(0, -1))  # (1, H, W, 1)

    def perspective_transform(self, image, box):
        """Apply perspective transform to extract text region.

        ``box`` holds the four corners as x0/y0 .. x3/y3. Corner 0 is mapped
        to the top-left of the output patch, corner 1 to the top-right,
        corner 2 to the bottom-right and corner 3 to the bottom-left.

        Returns the rectified patch, or None when the rounded length of edge
        0-1 (width) or edge 0-3 (height) is below one pixel.
        """
        pts1 = np.float32([
            [box['x0'], box['y0']],
            [box['x1'], box['y1']],
            [box['x2'], box['y2']],
            [box['x3'], box['y3']],
        ])

        trans_width = round(math.hypot(
            box['x0'] - box['x1'], box['y0'] - box['y1']))
        trans_height = round(math.hypot(
            box['x0'] - box['x3'], box['y0'] - box['y3']))
        if trans_width < 1 or trans_height < 1:
            return None

        pts2 = np.float32([
            [0, 0],
            [trans_width, 0],
            [trans_width, trans_height],
            [0, trans_height],
        ])

        M = cv2.getPerspectiveTransform(pts1, pts2)
        return cv2.warpPerspective(image, M, (trans_width, trans_height))

    def _select_model(self, text_type):
        """Return (model, alphabet, char_map) used to read ``text_type``.

        ``char_map`` is None when the decoded text needs no translation.
        """
        if text_type == 'TempoNumeral' and self.tempo_model is not None:
            return self.tempo_model, self.tempo_alphabet, TEMPO_CHAR_DICT
        if text_type == 'Chord' and self.chord_model is not None:
            # NOTE(review): chord output is decoded with the general
            # alphabet -- confirm the chord model shares it.
            return self.chord_model, self.alphabet, None
        return self.model, self.alphabet, None

    def _recognize(self, patch, text_type):
        """Run OCR on one rectified patch; return '' on any failure."""
        try:
            processed = self.preprocess_image(patch)
            if processed is None:
                return ''
            model, alphabet, char_map = self._select_model(text_type)
            pred = model.predict(processed, verbose=0)
            text = greedy_ctc_decode(pred, alphabet)
            if char_map is not None:
                text = translate_string_by_dict(text, char_map)
            return text
        except Exception as e:
            # Best effort per box: one bad region must not fail the page.
            logging.warning('OCR prediction error: %s', e)
            return ''

    def predict(self, buffers, location=None, **kwargs):
        """
        Recognize text from image with location info.

        buffers: list containing single image buffer
        location: list of detected boxes from loc_service
        kwargs: ignored
        yields: recognition results -- a single dict holding either an
            'error' message or 'imageSize' and 'areas'
        """
        if not buffers:
            yield {'error': 'No image provided'}
            return

        image = array_from_image_stream(buffers[0])
        if image is None:
            yield {'error': 'Invalid image'}
            return

        if not location:
            yield {'imageSize': list(image.shape[:2]), 'areas': []}
            return

        areas = []
        for box in location:
            dst_pic = self.perspective_transform(image, box)
            if dst_pic is None:
                # Degenerate box (less than one pixel wide or high).
                continue

            # Centre: mean of the four corners.
            cx = (box['x0'] + box['x1'] + box['x2'] + box['x3']) / 4
            cy = (box['y0'] + box['y1'] + box['y2'] + box['y3']) / 4
            # Extent: mean x of corners 1,2 minus mean x of corners 0,3,
            # and mean y of corners 2,3 minus mean y of corners 0,1.
            width = (box['x1'] + box['x2'] - box['x0'] - box['x3']) / 2
            height = (box['y2'] + box['y3'] - box['y0'] - box['y1']) / 2
            # Angle of the summed edge vectors 0->1 and 3->2.
            theta = math.atan2(
                box['y1'] - box['y0'] + box['y2'] - box['y3'],
                box['x1'] - box['x0'] + box['x2'] - box['x3'],
            )

            text_type = _resolve_text_type(box.get('class', 0))
            text = self._recognize(dst_pic, text_type)

            areas.append({
                'score': box.get('score', 0),
                'text': text,
                # Always None here; the key is kept for consumers.
                'feature_dict': None,
                'cx': cx,
                'cy': cy,
                'width': width,
                'height': height,
                'theta': theta,
                'type': text_type,
            })

        yield {
            'imageSize': list(image.shape[:2]),
            'areas': areas,
        }