starry / backend /python-services /services /brackets_service.py
k-l-lambda's picture
feat: add Python ML services (CPU mode) with model download
2b7aae2
"""
Brackets prediction service.
Recognizes bracket sequences from staff bracket images.
Note: This service requires Keras/TensorFlow models (DenseNet-CTC).
"""
import numpy as np
import cv2
import logging
from predictors.densenet_ctc import load_densenet_ctc, greedy_ctc_decode
from common.image_utils import array_from_image_stream
class BracketCorrector:
    """
    Corrects unpaired and nested parentheses in bracket sequences.

    Brackets are matched per type in priority order (``vvvib`` > ``vvib`` >
    ``vib``; curly > square > angle with the defaults). A lower-priority pair
    that partially overlaps ("crosses") a higher-priority pair is discarded,
    and so is every unmatched bracket.
    """

    # Bidirectional lookup: opening -> closing and closing -> opening.
    pair_dict = {
        '{': '}',
        '<': '>',
        '[': ']',
    }
    reverse_dict = {v: k for k, v in pair_dict.items()}
    pair_dict.update(reverse_dict)

    def __init__(self, vib='<', vvib='[', vvvib='{'):
        """
        Define bracket priority.

        vvvib > vvib > vib (curly > square > angle with the defaults).
        Each argument is the *opening* symbol of one bracket type; its closing
        symbol is looked up in ``pair_dict``.
        """
        self.vib = vib
        self.vvib = vvib
        self.vvvib = vvvib
        # Opening / closing symbol lists. The attribute names look swapped
        # relative to their contents; they are kept unchanged for
        # compatibility and are not used inside this class.
        self.right_symbol = ['{', '[', '<']
        self.left_symbol = ['}', ']', '>']

    def find_cp(self, string):
        """
        Find paired brackets at each priority level.

        Each closing bracket is paired with the nearest preceding opening
        bracket of the same type that is still unpaired (stack matching).
        Every index takes part in at most one pair, across all levels.

        Returns:
            (vvvib_cp, vvib_cp, vib_cp): three lists of
            ``(open_index, close_index)`` tuples, each ordered by close index.
        """
        used = set()  # indices already consumed by a pair at any level
        levels = []
        for open_sym in (self.vvvib, self.vvib, self.vib):
            close_sym = self.pair_dict.get(open_sym)
            pending = []  # stack of still-unmatched opener indices
            pairs = []
            for index, char in enumerate(string):
                if index in used:
                    continue
                if char == open_sym:
                    pending.append(index)
                elif char == close_sym and pending:
                    # Pop so that a closer can never be matched twice and an
                    # opener is consumed exactly once.
                    begin = pending.pop()
                    pairs.append((begin, index))
                    used.update((begin, index))
            levels.append(pairs)
        vvvib_cp, vvib_cp, vib_cp = levels
        return vvvib_cp, vvib_cp, vib_cp

    @staticmethod
    def _crosses(a, b):
        """Return True if pairs *a* and *b* partially overlap (neither nests the other)."""
        a_begin, a_end = a
        b_begin, b_end = b
        return (a_begin < b_begin < a_end < b_end) or \
            (b_begin < a_begin < b_end < a_end)

    def clean_up(self, string):
        """
        Remove nested conflicts based on priority.

        Pairs of a lower-priority type that cross a pair of a higher-priority
        type are dropped. Square pairs are compared only against curly pairs
        that survive, which is all of them. Angle pairs are compared against
        the curly pairs and the surviving square pairs.

        Returns:
            list of int: indices of all surviving brackets, flattened from
            the curly, then square, then angle pairs.
        """
        vvvib, vvib, vib = self.find_cp(string)
        # Curly pairs win over crossing square and angle pairs.
        vvib = [y for y in vvib if not any(self._crosses(x, y) for x in vvvib)]
        vib = [z for z in vib if not any(self._crosses(x, z) for x in vvvib)]
        # Surviving square pairs win over crossing angle pairs.
        vib = [z for z in vib if not any(self._crosses(x, z) for x in vvib)]
        return [index for pair in vvvib + vvib + vib for index in pair]

    def correct(self, string):
        """
        Correct bracket sequence.

        Returns only properly paired brackets and commas, in their original
        order. Every other character (unmatched or conflicting brackets,
        digits, '-', '.', ...) is dropped.
        """
        keep = set(self.clean_up(string))
        return ''.join(
            char for index, char in enumerate(string)
            if char == ',' or index in keep
        )
class BracketsService:
    """
    Bracket recognition service using DenseNet-CTC.

    Each image is converted to a normalized single-channel tensor, run
    through the DenseNet-CTC model, greedily CTC-decoded against
    ``alphabet`` and then cleaned up by ``BracketCorrector``.
    """

    def __init__(self, model_path, device='gpu', alphabet=None, **kwargs):
        """
        Initialize brackets service.

        model_path: path to bracket OCR weights (.h5)
        device: accepted for interface compatibility; not used here
        alphabet: character set for the model
            (defaults to '<>[]{},-.0123456789')
        kwargs: accepted and ignored
        """
        self.alphabet = alphabet or '<>[]{},-.0123456789'
        # One extra class beyond the alphabet -- presumably the CTC blank
        # label; confirm against load_densenet_ctc / greedy_ctc_decode.
        nclass = len(self.alphabet) + 1
        self.model = load_densenet_ctc(model_path, nclass)
        self.corrector = BracketCorrector()

    def preprocess_image(self, image, target_height=32):
        """
        Preprocess bracket image for OCR model.

        image: 2-D grayscale or 3-D (H, W, C) array. 3-channel and 4-channel
            inputs are converted with COLOR_RGB2GRAY.
            NOTE(review): confirm array_from_image_stream yields RGB, not BGR.
        target_height: output height in pixels; width keeps the aspect ratio.

        Returns a float32 array of shape (1, target_height, new_w, 1), scaled
        as ``x / 255 - 0.5`` (assumes 0-255 input values).

        Raises ValueError for an empty image.
        """
        # Collapse the channel axis to a single gray plane.
        if len(image.shape) == 3:
            if image.shape[2] == 1:
                # cvtColor rejects 1-channel input for RGB2GRAY.
                image = image[:, :, 0]
            else:
                image = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
        # Rotate 90 degrees counter-clockwise (brackets are vertical); make the
        # view contiguous before handing it to OpenCV.
        image = np.ascontiguousarray(np.rot90(image))
        h, w = image.shape[:2]
        if h == 0 or w == 0:
            raise ValueError('empty bracket image: shape=%s' % (image.shape,))
        # Resize to target height, never collapsing the width to zero.
        scale = target_height / h
        new_w = max(1, int(w * scale))
        image = cv2.resize(image, (new_w, target_height))
        # Normalize
        image = image.astype(np.float32) / 255.0 - 0.5
        # Add batch and channel dimensions
        image = np.expand_dims(image, axis=(0, -1))  # (1, H, W, 1)
        return image

    def predict(self, buffers, **kwargs):
        """
        Recognize bracket sequence from images.

        buffers: iterable of bracket image buffers
        kwargs: accepted and ignored
        yields: one corrected bracket string per buffer, or None when the
            image cannot be decoded or prediction fails
        """
        for buffer in buffers:
            image = array_from_image_stream(buffer)
            if image is None:
                yield None
                continue
            try:
                processed = self.preprocess_image(image)
                pred = self.model.predict(processed, verbose=0)
                # Decode using greedy CTC, then fix bracket pairing.
                content = greedy_ctc_decode(pred, self.alphabet)
                content = self.corrector.correct(content)
            except Exception as e:
                # Best-effort per image: record the failure with its
                # traceback and continue with the next buffer.
                logging.warning('Bracket prediction error: %s', str(e), exc_info=True)
                yield None
            else:
                yield content