"""
OCR prediction service.
Recognizes text content from detected regions using DenseNet-CTC.
"""
import math
import numpy as np
import cv2
import logging
from predictors.densenet_ctc import load_densenet_ctc, greedy_ctc_decode
from common.image_utils import array_from_image_stream
# Text type categories (same as loc_service).
# Index position is the detector class id: predict() maps each box's
# 'class' field through this list, so order must stay in sync with the
# localization service.
TYPE_NAMES = [
    'Title',          # 0
    'Author',         # 1
    'TextualMark',    # 2
    'TempoNumeral',   # 3  - routed to the dedicated tempo model
    'MeasureNumber',  # 4
    'Times',          # 5
    'Chord',          # 6  - routed to the chord model when loaded
    'PageMargin',     # 7
    'Instrument',     # 8
    'Other',          # 9
    'Lyric',          # 10
    'Alter1',         # 11
    'Alter2',         # 12
]
# Tempo character mapping (musical note symbols).
# The tempo model's alphabet uses ASCII placeholder letters; decoded text is
# passed through translate_string_by_dict to swap them for the real Unicode
# musical-note glyphs (U+1D15D..U+1D162).
TEMPO_CHAR_DICT = {
    'a': '\U0001d15d',  # Whole Note
    'b': '\U0001d15e',  # Half Note
    'c': '\U0001d15f',  # Quarter Note
    'd': '\U0001d160',  # Eighth Note
    'e': '\U0001d161',  # Sixteenth Note
    'f': '\U0001d162',  # Thirty-Second Note
}
def translate_string_by_dict(s, d):
    """Return *s* with each character replaced via *d*.

    Characters without an entry in *d* pass through unchanged.
    """
    mapped = (d.get(char, char) for char in s)
    return ''.join(mapped)
class OcrService:
    """
    OCR service using DenseNet-CTC architecture.

    Rectifies each detected quadrilateral out of the page image, then runs
    the crop through one of three models: a dedicated tempo model for
    'TempoNumeral' boxes (output mapped to musical glyphs), an optional
    SavedModel for 'Chord' boxes, and the main DenseNet-CTC model for
    everything else.
    """

    def __init__(self, model_path, device='gpu', alphabet=None,
                 tempo_model_path=None, tempo_alphabet=None,
                 chord_model_path=None, **kwargs):
        """
        Args:
            model_path: weights file for the main DenseNet-CTC model.
            device: accelerator hint (accepted for API compatibility;
                not used directly in this class).
            alphabet: character set of the main model.
            tempo_model_path: optional weights for the tempo model.
            tempo_alphabet: character set of the tempo model.
            chord_model_path: optional TF SavedModel directory for chords;
                must end with '/' to be loaded.
        """
        self.alphabet = alphabet or ''
        self.tempo_alphabet = tempo_alphabet or ''
        # nclass = alphabet length + 1 (blank token '卍')
        nclass = len(self.alphabet) + 1
        self.model = load_densenet_ctc(model_path, nclass)
        # Optional tempo model: same architecture, its own alphabet + blank.
        self.tempo_model = None
        if tempo_model_path:
            tempo_nclass = len(self.tempo_alphabet) + 1
            self.tempo_model = load_densenet_ctc(tempo_model_path, tempo_nclass)
        # Optional chord model: a SavedModel directory with a different
        # architecture, loaded via tensorflow.  Load failure is non-fatal;
        # chord boxes then fall back to the main model in _recognize().
        self.chord_model = None
        if chord_model_path and chord_model_path.endswith('/'):
            try:
                import tensorflow as tf
                self.chord_model = tf.keras.models.load_model(
                    chord_model_path, compile=False)
                logging.info('Chord model loaded: %s', chord_model_path)
            except Exception as e:
                logging.warning('Failed to load chord model: %s', e)

    def preprocess_image(self, image, target_height=32):
        """Convert a crop to the (1, H, W, 1) float tensor the model expects.

        Grayscales if needed, resizes to ``target_height`` preserving aspect
        ratio, and normalizes pixels to [-0.5, 0.5].  Returns None for
        unusable crops (zero height, or narrower than 9 px after scaling).
        """
        h, w = image.shape[:2]
        if h == 0:
            # Guard: a zero-height crop would otherwise raise
            # ZeroDivisionError below.
            return None
        if len(image.shape) == 3:
            image = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
        scale = target_height / h
        new_w = int(w * scale)
        if new_w < 9:
            return None
        image = cv2.resize(image, (new_w, target_height))
        image = image.astype(np.float32) / 255.0 - 0.5
        image = np.expand_dims(image, axis=(0, -1))  # (1, H, W, 1)
        return image

    def perspective_transform(self, image, box):
        """Rectify the quadrilateral ``box`` into an axis-aligned crop.

        ``box`` holds corner keys x0/y0 .. x3/y3 (0=top-left, going
        clockwise, judging by the width/height edges used below).  Returns
        the warped crop, or None for degenerate boxes (< 1 px on a side).
        """
        pts1 = np.float32([
            [box['x0'], box['y0']],
            [box['x1'], box['y1']],
            [box['x2'], box['y2']],
            [box['x3'], box['y3']]
        ])
        # Output size from the 0->1 (top) and 0->3 (left) edge lengths.
        trans_width = round(math.sqrt(
            (box['x0'] - box['x1']) ** 2 + (box['y0'] - box['y1']) ** 2
        ))
        trans_height = round(math.sqrt(
            (box['x0'] - box['x3']) ** 2 + (box['y0'] - box['y3']) ** 2
        ))
        if trans_width < 1 or trans_height < 1:
            return None
        pts2 = np.float32([
            [0, 0],
            [trans_width, 0],
            [trans_width, trans_height],
            [0, trans_height]
        ])
        M = cv2.getPerspectiveTransform(pts1, pts2)
        return cv2.warpPerspective(image, M, (trans_width, trans_height))

    @staticmethod
    def _box_geometry(box):
        """Return (cx, cy, width, height, theta) for a quadrilateral box.

        Center is the corner mean; width/height average the two opposing
        edges; theta is the rotation in radians of the box's long axis.
        """
        cx = (box['x0'] + box['x1'] + box['x2'] + box['x3']) / 4
        cy = (box['y0'] + box['y1'] + box['y2'] + box['y3']) / 4
        width = (box['x1'] + box['x2'] - box['x0'] - box['x3']) / 2
        height = (box['y2'] + box['y3'] - box['y0'] - box['y1']) / 2
        theta = math.atan2(
            box['y1'] - box['y0'] + box['y2'] - box['y3'],
            box['x1'] - box['x0'] + box['x2'] - box['x3']
        )
        return cx, cy, width, height, theta

    def _recognize(self, dst_pic, text_type):
        """Run the model matching ``text_type`` on a rectified crop.

        Returns the decoded text, or '' when the crop is unusable or the
        prediction raises (errors are logged, never propagated).
        """
        try:
            processed = self.preprocess_image(dst_pic)
            if processed is None:
                return ''
            if text_type == 'TempoNumeral' and self.tempo_model is not None:
                pred = self.tempo_model.predict(processed, verbose=0)
                text = greedy_ctc_decode(pred, self.tempo_alphabet)
                # Swap placeholder letters for real musical-note glyphs.
                return translate_string_by_dict(text, TEMPO_CHAR_DICT)
            if text_type == 'Chord' and self.chord_model is not None:
                pred = self.chord_model.predict(processed, verbose=0)
                return greedy_ctc_decode(pred, self.alphabet)
            pred = self.model.predict(processed, verbose=0)
            return greedy_ctc_decode(pred, self.alphabet)
        except Exception as e:
            logging.warning('OCR prediction error: %s', str(e))
            return ''

    def predict(self, buffers, location=None, **kwargs):
        """
        Recognize text from image with location info.
        buffers: list containing single image buffer
        location: list of detected boxes from loc_service
        yields: recognition results
        """
        if not buffers:
            yield {'error': 'No image provided'}
            return
        image = array_from_image_stream(buffers[0])
        if image is None:
            yield {'error': 'Invalid image'}
            return
        if not location:
            yield {'imageSize': list(image.shape[:2]), 'areas': []}
            return
        areas = []
        for box in location:
            dst_pic = self.perspective_transform(image, box)
            if dst_pic is None:
                continue
            cx, cy, width, height, theta = self._box_geometry(box)
            # Fall back to 'Other' instead of raising IndexError on an
            # out-of-range class id, which would kill the whole request.
            cls_id = box.get('class', 0)
            if 0 <= cls_id < len(TYPE_NAMES):
                text_type = TYPE_NAMES[cls_id]
            else:
                text_type = 'Other'
            text = self._recognize(dst_pic, text_type)
            areas.append({
                'score': box.get('score', 0),
                'text': text,
                'feature_dict': None,  # reserved field; never populated here
                'cx': cx,
                'cy': cy,
                'width': width,
                'height': height,
                'theta': theta,
                'type': text_type,
            })
        yield {
            'imageSize': list(image.shape[:2]),
            'areas': areas,
        }