"""
OCR prediction service.
Recognizes text content from detected regions using DenseNet-CTC.
"""
import math
import numpy as np
import cv2
import logging
from predictors.densenet_ctc import load_densenet_ctc, greedy_ctc_decode
from common.image_utils import array_from_image_stream
# Text type categories, index-aligned with the class ids emitted by
# loc_service (index 0 is 'Title', index 12 is 'Alter2').
TYPE_NAMES = [
    'Title', 'Author', 'TextualMark', 'TempoNumeral', 'MeasureNumber',   # 0-4
    'Times', 'Chord', 'PageMargin', 'Instrument', 'Other',               # 5-9
    'Lyric', 'Alter1', 'Alter2',                                         # 10-12
]

# Maps the latin letters the tempo model emits to the corresponding
# Unicode musical note symbols (U+1D15D .. U+1D162).
TEMPO_CHAR_DICT = {
    'a': '\U0001d15d',  # whole note
    'b': '\U0001d15e',  # half note
    'c': '\U0001d15f',  # quarter note
    'd': '\U0001d160',  # eighth note
    'e': '\U0001d161',  # sixteenth note
    'f': '\U0001d162',  # thirty-second note
}
def translate_string_by_dict(s, d):
    """Map each character of *s* through *d*; characters not in *d* pass through."""
    pieces = []
    for ch in s:
        pieces.append(d.get(ch, ch))
    return ''.join(pieces)
class OcrService:
    """
    OCR service using DenseNet-CTC architecture.

    Loads a general text-recognition model plus optional specialised models
    for tempo numerals and chords, and recognises the text inside the
    quadrilateral boxes produced by the location service.
    """

    def __init__(self, model_path, device='gpu', alphabet=None,
                 tempo_model_path=None, tempo_alphabet=None,
                 chord_model_path=None, **kwargs):
        """
        Args:
            model_path: weights for the general DenseNet-CTC model.
            device: kept for interface compatibility; not used here.
            alphabet: character set of the general model.
            tempo_model_path: optional weights for the tempo-numeral model.
            tempo_alphabet: character set of the tempo model.
            chord_model_path: optional SavedModel directory for the chord
                model; only loaded when the path ends with '/'.
        """
        self.alphabet = alphabet or ''
        self.tempo_alphabet = tempo_alphabet or ''
        # nclass = alphabet length + 1 (extra class for the CTC blank token '卍')
        nclass = len(self.alphabet) + 1
        self.model = load_densenet_ctc(model_path, nclass)
        # Optional tempo model (same architecture, its own alphabet).
        self.tempo_model = None
        if tempo_model_path:
            tempo_nclass = len(self.tempo_alphabet) + 1
            self.tempo_model = load_densenet_ctc(tempo_model_path, tempo_nclass)
        # Optional chord model (SavedModel directory, different architecture).
        # Loading is best-effort: chord recognition simply falls back to the
        # general model when unavailable.
        self.chord_model = None
        if chord_model_path and chord_model_path.endswith('/'):
            try:
                import tensorflow as tf
                self.chord_model = tf.keras.models.load_model(chord_model_path, compile=False)
                logging.info('Chord model loaded: %s', chord_model_path)
            except Exception as e:
                logging.warning('Failed to load chord model: %s', e)

    def preprocess_image(self, image, target_height=32):
        """Preprocess an image crop for a DenseNet-CTC model.

        Converts to grayscale, rescales to ``target_height`` preserving the
        aspect ratio, and normalises pixel values to [-0.5, 0.5].

        Returns:
            A float32 array of shape (1, H, W, 1), or None when the scaled
            width would be too narrow (< 9 px) for the model to emit output.
        """
        h, w = image.shape[:2]
        if len(image.shape) == 3:
            image = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
        scale = target_height / h
        new_w = int(w * scale)
        if new_w < 9:
            return None
        image = cv2.resize(image, (new_w, target_height))
        image = image.astype(np.float32) / 255.0 - 0.5
        image = np.expand_dims(image, axis=(0, -1))  # (1, H, W, 1)
        return image

    def perspective_transform(self, image, box):
        """Warp the quadrilateral region described by *box* into an upright rectangle.

        Box corners (x0,y0)..(x3,y3) are taken in order top-left, top-right,
        bottom-right, bottom-left; the output size is derived from the top and
        left edge lengths. Returns None for degenerate (sub-pixel) boxes.
        """
        pts1 = np.float32([
            [box['x0'], box['y0']],
            [box['x1'], box['y1']],
            [box['x2'], box['y2']],
            [box['x3'], box['y3']]
        ])
        # Output width/height from the top edge (p0->p1) and left edge (p0->p3).
        trans_width = round(math.sqrt(
            (box['x0'] - box['x1']) ** 2 + (box['y0'] - box['y1']) ** 2
        ))
        trans_height = round(math.sqrt(
            (box['x0'] - box['x3']) ** 2 + (box['y0'] - box['y3']) ** 2
        ))
        if trans_width < 1 or trans_height < 1:
            return None
        pts2 = np.float32([
            [0, 0],
            [trans_width, 0],
            [trans_width, trans_height],
            [0, trans_height]
        ])
        M = cv2.getPerspectiveTransform(pts1, pts2)
        dst = cv2.warpPerspective(image, M, (trans_width, trans_height))
        return dst

    def _select_model(self, text_type):
        """Pick (model, alphabet, post_fn) for a text type.

        Specialised models are used only when they loaded successfully;
        otherwise the general model handles the region. ``post_fn`` is an
        optional decoded-text post-processor (or None).
        """
        if text_type == 'TempoNumeral' and self.tempo_model is not None:
            # Tempo output uses placeholder letters; map them to note symbols.
            post = lambda t: translate_string_by_dict(t, TEMPO_CHAR_DICT)
            return self.tempo_model, self.tempo_alphabet, post
        if text_type == 'Chord' and self.chord_model is not None:
            return self.chord_model, self.alphabet, None
        return self.model, self.alphabet, None

    def predict(self, buffers, location=None, **kwargs):
        """
        Recognize text from image with location info.
        buffers: list containing single image buffer
        location: list of detected boxes from loc_service
        yields: recognition results
        """
        if not buffers:
            yield {'error': 'No image provided'}
            return
        image = array_from_image_stream(buffers[0])
        if image is None:
            yield {'error': 'Invalid image'}
            return
        if not location:
            yield {'imageSize': list(image.shape[:2]), 'areas': []}
            return
        areas = []
        for box in location:
            dst_pic = self.perspective_transform(image, box)
            if dst_pic is None:
                continue
            # Center, axis sizes and rotation of the quadrilateral.
            cx = (box['x0'] + box['x1'] + box['x2'] + box['x3']) / 4
            cy = (box['y0'] + box['y1'] + box['y2'] + box['y3']) / 4
            width = (box['x1'] + box['x2'] - box['x0'] - box['x3']) / 2
            height = (box['y2'] + box['y3'] - box['y0'] - box['y1']) / 2
            theta = math.atan2(
                box['y1'] - box['y0'] + box['y2'] - box['y3'],
                box['x1'] - box['x0'] + box['x2'] - box['x3']
            )
            # Guard against out-of-range class ids from the location service:
            # previously TYPE_NAMES[idx] raised IndexError and killed the
            # whole generator. Unknown ids now fall back to 'Other'.
            cls_idx = box.get('class', 0)
            if 0 <= cls_idx < len(TYPE_NAMES):
                text_type = TYPE_NAMES[cls_idx]
            else:
                text_type = 'Other'
            text = ''
            feature_dict = None
            try:
                model, alphabet, post_fn = self._select_model(text_type)
                processed = self.preprocess_image(dst_pic)
                if processed is not None:
                    pred = model.predict(processed, verbose=0)
                    text = greedy_ctc_decode(pred, alphabet)
                    if post_fn is not None:
                        text = post_fn(text)
            except Exception as e:
                # Best-effort per box: a failed region yields empty text
                # instead of aborting the whole page.
                logging.warning('OCR prediction error: %s', str(e))
                text = ''
            areas.append({
                'score': box.get('score', 0),
                'text': text,
                'feature_dict': feature_dict,
                'cx': cx,
                'cy': cy,
                'width': width,
                'height': height,
                'theta': theta,
                'type': text_type,
            })
        yield {
            'imageSize': list(image.shape[:2]),
            'areas': areas,
        }