# Provenance: commit 2b7aae2 by k-l-lambda —
# "feat: add Python ML services (CPU mode) with model download"
"""
OCR prediction service.
Recognizes text content from detected regions using DenseNet-CTC.
"""
import math
import numpy as np
import cv2
import logging
from predictors.densenet_ctc import load_densenet_ctc, greedy_ctc_decode
from common.image_utils import array_from_image_stream
# Text type categories (same as loc_service).
# List position corresponds to the detector's numeric class id, so this
# table maps `box['class']` -> human-readable type name.
TYPE_NAMES = [
    'Title',         # 0
    'Author',        # 1
    'TextualMark',   # 2
    'TempoNumeral',  # 3  — decoded with the dedicated tempo model
    'MeasureNumber', # 4
    'Times',         # 5
    'Chord',         # 6  — decoded with the dedicated chord model
    'PageMargin',    # 7
    'Instrument',    # 8
    'Other',         # 9
    'Lyric',         # 10
    'Alter1',        # 11
    'Alter2',        # 12
]
# Tempo character mapping: the tempo model emits ASCII placeholders
# 'a'..'f', which are translated here into Unicode musical note symbols
# (U+1D15D..U+1D162).
TEMPO_CHAR_DICT = {
    'a': '\U0001d15d',  # Whole Note
    'b': '\U0001d15e',  # Half Note
    'c': '\U0001d15f',  # Quarter Note
    'd': '\U0001d160',  # Eighth Note
    'e': '\U0001d161',  # Sixteenth Note
    'f': '\U0001d162',  # Thirty-Second Note
}
def translate_string_by_dict(s, d):
    """Return *s* with each character mapped through *d*.

    Characters without an entry in *d* are passed through unchanged.
    """
    mapped = [d[ch] if ch in d else ch for ch in s]
    return ''.join(mapped)
class OcrService:
    """
    OCR service using the DenseNet-CTC architecture.

    Each detected quadrilateral region is rectified with a perspective
    transform and then recognized by the model matching the region's
    type: the main text model, an optional tempo-numeral model, or an
    optional chord model.
    """

    def __init__(self, model_path, device='gpu', alphabet=None,
                 tempo_model_path=None, tempo_alphabet=None,
                 chord_model_path=None, **kwargs):
        """
        Load the recognition models.

        Args:
            model_path: weights for the main DenseNet-CTC text model.
            device: requested compute device; stored for reference but not
                otherwise used in this class.
            alphabet: character set of the main model.
            tempo_model_path: optional weights for the tempo-numeral model.
            tempo_alphabet: character set of the tempo model.
            chord_model_path: optional SavedModel directory for the chord
                model; only loaded when the path ends with '/'.
        """
        self.device = device
        self.alphabet = alphabet or ''
        self.tempo_alphabet = tempo_alphabet or ''

        # nclass = alphabet length + 1 (blank token '卍')
        self.model = load_densenet_ctc(model_path, len(self.alphabet) + 1)

        # Optional tempo-numeral model (same architecture, own alphabet).
        self.tempo_model = None
        if tempo_model_path:
            self.tempo_model = load_densenet_ctc(
                tempo_model_path, len(self.tempo_alphabet) + 1)

        # Optional chord model (SavedModel directory, different architecture).
        # Loading is best-effort: on failure we fall back to the main model.
        self.chord_model = None
        if chord_model_path and chord_model_path.endswith('/'):
            try:
                import tensorflow as tf
                self.chord_model = tf.keras.models.load_model(
                    chord_model_path, compile=False)
                logging.info('Chord model loaded: %s', chord_model_path)
            except Exception as e:
                logging.warning('Failed to load chord model: %s', e)

    def preprocess_image(self, image, target_height=32):
        """
        Convert a region image into model input of shape (1, H, W, 1).

        The image is grayscaled (if RGB), rescaled to ``target_height``
        preserving aspect ratio, and normalized to [-0.5, 0.5].

        Returns:
            A float32 array of shape (1, target_height, new_w, 1), or
            None when the rescaled width would be under 9 pixels — too
            narrow for the CTC model to produce output.
        """
        h, w = image.shape[:2]
        if len(image.shape) == 3:
            image = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)

        scale = target_height / h
        new_w = int(w * scale)
        if new_w < 9:
            return None

        image = cv2.resize(image, (new_w, target_height))
        image = image.astype(np.float32) / 255.0 - 0.5
        return np.expand_dims(image, axis=(0, -1))  # (1, H, W, 1)

    def perspective_transform(self, image, box):
        """
        Rectify the quadrilateral ``box`` into an axis-aligned crop.

        Corners map as x0/y0 -> top-left, x1/y1 -> top-right,
        x2/y2 -> bottom-right, x3/y3 -> bottom-left.

        Returns:
            The warped region as an array, or None for degenerate
            (zero-width or zero-height) boxes.
        """
        pts1 = np.float32([
            [box['x0'], box['y0']],
            [box['x1'], box['y1']],
            [box['x2'], box['y2']],
            [box['x3'], box['y3']]
        ])

        # Output size taken from the top-edge and left-edge lengths.
        trans_width = round(math.hypot(
            box['x0'] - box['x1'], box['y0'] - box['y1']))
        trans_height = round(math.hypot(
            box['x0'] - box['x3'], box['y0'] - box['y3']))
        if trans_width < 1 or trans_height < 1:
            return None

        pts2 = np.float32([
            [0, 0],
            [trans_width, 0],
            [trans_width, trans_height],
            [0, trans_height]
        ])
        M = cv2.getPerspectiveTransform(pts1, pts2)
        return cv2.warpPerspective(image, M, (trans_width, trans_height))

    def _recognize_text(self, region, text_type):
        """
        Run the model matching ``text_type`` on a rectified region.

        Tempo numerals are additionally translated to musical note
        glyphs via TEMPO_CHAR_DICT.

        Returns:
            The decoded string, or '' when the region is too small or
            prediction fails (best-effort: one bad region must not abort
            the whole page).
        """
        try:
            processed = self.preprocess_image(region)
            if processed is None:
                return ''

            if text_type == 'TempoNumeral' and self.tempo_model is not None:
                pred = self.tempo_model.predict(processed, verbose=0)
                text = greedy_ctc_decode(pred, self.tempo_alphabet)
                return translate_string_by_dict(text, TEMPO_CHAR_DICT)

            if text_type == 'Chord' and self.chord_model is not None:
                pred = self.chord_model.predict(processed, verbose=0)
                return greedy_ctc_decode(pred, self.alphabet)

            pred = self.model.predict(processed, verbose=0)
            return greedy_ctc_decode(pred, self.alphabet)
        except Exception as e:
            logging.warning('OCR prediction error: %s', str(e))
            return ''

    def predict(self, buffers, location=None, **kwargs):
        """
        Recognize text from an image given detected region boxes.

        Args:
            buffers: list containing a single image buffer.
            location: list of detected boxes from loc_service; each box
                carries corner keys x0..x3 / y0..y3 and optional
                'class' and 'score' entries.

        Yields:
            A dict with 'imageSize' and recognized 'areas', or an
            {'error': ...} dict on invalid input.
        """
        if not buffers:
            yield {'error': 'No image provided'}
            return

        image = array_from_image_stream(buffers[0])
        if image is None:
            yield {'error': 'Invalid image'}
            return

        if not location:
            yield {'imageSize': list(image.shape[:2]), 'areas': []}
            return

        areas = []
        for box in location:
            dst_pic = self.perspective_transform(image, box)
            if dst_pic is None:
                continue

            # Centroid, edge lengths and rotation angle derived from the
            # four corners (opposite edges averaged to tolerate skew).
            cx = (box['x0'] + box['x1'] + box['x2'] + box['x3']) / 4
            cy = (box['y0'] + box['y1'] + box['y2'] + box['y3']) / 4
            width = (box['x1'] + box['x2'] - box['x0'] - box['x3']) / 2
            height = (box['y2'] + box['y3'] - box['y0'] - box['y1']) / 2
            theta = math.atan2(
                box['y1'] - box['y0'] + box['y2'] - box['y3'],
                box['x1'] - box['x0'] + box['x2'] - box['x3']
            )

            # Guard against class ids outside TYPE_NAMES: plain indexing
            # would raise IndexError for large ids and silently wrap for
            # negative ones. Unknown classes fall back to 'Other'.
            class_index = box.get('class', 0)
            if not 0 <= class_index < len(TYPE_NAMES):
                class_index = TYPE_NAMES.index('Other')
            text_type = TYPE_NAMES[class_index]

            areas.append({
                'score': box.get('score', 0),
                'text': self._recognize_text(dst_pic, text_type),
                'feature_dict': None,  # reserved field, always None here
                'cx': cx,
                'cy': cy,
                'width': width,
                'height': height,
                'theta': theta,
                'type': text_type,
            })

        yield {
            'imageSize': list(image.shape[:2]),
            'areas': areas,
        }