""" Brackets prediction service. Recognizes bracket sequences from staff bracket images. Note: This service requires Keras/TensorFlow models (DenseNet-CTC). """ import numpy as np import cv2 import logging from predictors.densenet_ctc import load_densenet_ctc, greedy_ctc_decode from common.image_utils import array_from_image_stream class BracketCorrector: """ Corrects unpaired and nested parentheses in bracket sequences. """ pair_dict = { '{': '}', '<': '>', '[': ']', } reverse_dict = {v: k for k, v in pair_dict.items()} pair_dict.update(reverse_dict) def __init__(self, vib='<', vvib='[', vvvib='{'): """ Define bracket priority. vvvib > vvib > vib (curly > square > angle) """ self.vib = vib self.vvib = vvib self.vvvib = vvvib self.right_symbol = ['{', '[', '<'] self.left_symbol = ['}', ']', '>'] def find_cp(self, string): """Find paired brackets at each priority level.""" str_len = len(string) cp_his = [] vvvib_cp, vvib_cp, vib_cp = [], [], [] for cp_sym in [self.vvvib, self.vvib, self.vib]: cur_cp = [] for index in range(str_len): if index not in cp_his and string[index] == cp_sym: for i in range(index + 1, str_len): cur_sym = string[i] if cur_sym == self.pair_dict.get(cp_sym): for j in range(i - 1, -1, -1): if j not in cp_his and string[j] == cp_sym: if i > j: cur_cp.append((j, i)) else: cur_cp.append((i, j)) cp_his.append(i) cp_his.append(j) break if cp_sym == self.vvvib: vvvib_cp = cur_cp elif cp_sym == self.vvib: vvib_cp = cur_cp elif cp_sym == self.vib: vib_cp = cur_cp return vvvib_cp, vvib_cp, vib_cp def clean_up(self, string): """Remove nested conflicts based on priority.""" vvvib, vvib, vib = self.find_cp(string) # Check curly vs square and angle brackets for x in vvvib: x_begin, x_end = x[0], x[1] for y in list(vvib): y_begin, y_end = y[0], y[1] if (x_begin < y_begin < x_end < y_end) or \ (y_begin < x_begin < y_end < x_end): vvib.remove(y) for z in list(vib): z_begin, z_end = z[0], z[1] if (x_begin < z_begin < x_end < z_end) or \ (z_begin < x_begin < z_end < x_end): vib.remove(z) # Check square vs angle brackets for x in vvib: x_begin, x_end = x[0], x[1] for y in list(vib): y_begin, y_end = y[0], y[1] if (x_begin < y_begin < x_end < y_end) or \ (y_begin < x_begin < y_end < x_end): vib.remove(y) # Collect all valid indices all_cp = vvvib + vvib + vib new_cp_list = [] for pair in all_cp: new_cp_list.append(pair[0]) new_cp_list.append(pair[1]) return new_cp_list def correct(self, string): """ Correct bracket sequence. Returns only properly paired brackets and commas. """ all_cp = self.clean_up(string) corrected = '' for index, char in enumerate(string): if char == ',': corrected += char elif index in all_cp: corrected += char return corrected class BracketsService: """ Bracket recognition service using DenseNet-CTC. Uses DenseNet-CTC architecture for bracket sequence recognition. """ def __init__(self, model_path, device='gpu', alphabet=None, **kwargs): """ Initialize brackets service. model_path: path to bracket OCR weights (.h5) alphabet: character set for the model """ self.alphabet = alphabet or '<>[]{},-.0123456789' nclass = len(self.alphabet) + 1 self.model = load_densenet_ctc(model_path, nclass) self.corrector = BracketCorrector() def preprocess_image(self, image, target_height=32): """Preprocess bracket image for OCR model.""" # Convert to grayscale if len(image.shape) == 3: image = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY) # Rotate 90 degrees (brackets are vertical) image = np.rot90(image) h, w = image.shape[:2] # Resize to target height scale = target_height / h new_w = int(w * scale) image = cv2.resize(image, (new_w, target_height)) # Normalize image = image.astype(np.float32) / 255.0 - 0.5 # Add batch and channel dimensions image = np.expand_dims(image, axis=(0, -1)) # (1, H, W, 1) return image def predict(self, buffers, **kwargs): """ Recognize bracket sequence from images. buffers: list of bracket image buffers yields: corrected bracket strings """ for buffer in buffers: image = array_from_image_stream(buffer) if image is None: yield None continue try: # Preprocess processed = self.preprocess_image(image) # Predict pred = self.model.predict(processed, verbose=0) # Decode using greedy CTC content = greedy_ctc_decode(pred, self.alphabet) # Correct bracket pairing content = self.corrector.correct(content) yield content except Exception as e: logging.warning('Bracket prediction error: %s', str(e)) yield None