Spaces:
Running
Running
| """ | |
| Brackets prediction service. | |
| Recognizes bracket sequences from staff bracket images. | |
| Note: This service requires Keras/TensorFlow models (DenseNet-CTC). | |
| """ | |
| import numpy as np | |
| import cv2 | |
| import logging | |
| from predictors.densenet_ctc import load_densenet_ctc, greedy_ctc_decode | |
| from common.image_utils import array_from_image_stream | |
class BracketCorrector:
    """
    Repairs bracket sequences by dropping unpaired and crossing brackets.

    Three bracket types are handled, each with a priority level. Brackets
    are paired per type; when two pairs of different types cross each other
    (e.g. ``[ < ] >``) the pair with the lower priority is discarded.
    """

    # Maps every bracket character to its counterpart, in both directions
    # (opener -> closer and closer -> opener).
    pair_dict = {
        '{': '}',
        '<': '>',
        '[': ']',
    }
    reverse_dict = {v: k for k, v in pair_dict.items()}
    pair_dict.update(reverse_dict)

    def __init__(self, vib='<', vvib='[', vvvib='{'):
        """
        Define bracket priority.

        vvvib > vvib > vib (with the defaults: curly > square > angle).
        Each argument is the opening character of its bracket type.
        """
        self.vib = vib
        self.vvib = vvib
        self.vvvib = vvvib
        # Not used inside this class; kept unchanged for backward
        # compatibility. Note that, despite the names, right_symbol holds
        # the opening brackets and left_symbol the closing ones.
        self.right_symbol = ['{', '[', '<']
        self.left_symbol = ['}', ']', '>']

    def find_cp(self, string):
        """
        Find paired brackets at each priority level.

        Every bracket type is matched independently of the others: a closing
        bracket is paired with the nearest preceding opener of the same type
        that is still unpaired. Each index takes part in at most one pair.

        string: the raw bracket sequence
        returns: (vvvib_pairs, vvib_pairs, vib_pairs), three lists of
                 (open_index, close_index) tuples ordered by close_index
        """
        pairs_by_level = []
        for opener in (self.vvvib, self.vvib, self.vib):
            closer = self.pair_dict.get(opener)
            pending = []   # indices of openers still waiting for a closer
            matched = []
            for index, char in enumerate(string):
                if char == opener:
                    pending.append(index)
                elif char == closer and pending:
                    # Popping guarantees neither the opener nor this closer
                    # can be claimed a second time.
                    matched.append((pending.pop(), index))
            pairs_by_level.append(matched)
        vvvib_cp, vvib_cp, vib_cp = pairs_by_level
        return vvvib_cp, vvib_cp, vib_cp

    @staticmethod
    def _crosses(first, second):
        """Return True when two (begin, end) pairs overlap without nesting."""
        first_begin, first_end = first
        second_begin, second_end = second
        return (first_begin < second_begin < first_end < second_end
                or second_begin < first_begin < second_end < first_end)

    def clean_up(self, string):
        """
        Remove crossing conflicts based on priority.

        A pair is dropped when it crosses a surviving pair of higher
        priority. Properly nested or disjoint pairs are kept.

        string: the raw bracket sequence
        returns: flat list of the indices of all surviving brackets
        """
        vvvib, vvib, vib = self.find_cp(string)

        # Middle priority yields only to the highest priority.
        vvib = [pair for pair in vvib
                if not any(self._crosses(top, pair) for top in vvvib)]
        # Lowest priority yields to the highest and to the surviving middle.
        dominant = vvvib + vvib
        vib = [pair for pair in vib
               if not any(self._crosses(upper, pair) for upper in dominant)]

        # Collect all valid indices
        new_cp_list = []
        for begin, end in vvvib + vvib + vib:
            new_cp_list.append(begin)
            new_cp_list.append(end)
        return new_cp_list

    def correct(self, string):
        """
        Correct bracket sequence.

        Returns only properly paired brackets and commas, in their original
        order; every other character is dropped.
        """
        keep = set(self.clean_up(string))
        return ''.join(
            char for index, char in enumerate(string)
            if char == ',' or index in keep
        )
class BracketsService:
    """
    Bracket sequence recognizer backed by a DenseNet-CTC model.

    Each image is preprocessed, run through the model, decoded with greedy
    CTC and finally passed through a BracketCorrector.
    """

    def __init__(self, model_path, device='gpu', alphabet=None, **kwargs):
        """
        Load the model and create the corrector.

        model_path: path to bracket OCR weights (.h5)
        device: accepted for interface compatibility; not used here
        alphabet: character set for the model; a built-in default is used
                  when this is None or empty
        kwargs: accepted and ignored
        """
        if not alphabet:
            alphabet = '<>[]{},-.0123456789'
        self.alphabet = alphabet
        # One class more than the alphabet size (presumably the CTC blank;
        # confirm against load_densenet_ctc).
        class_count = len(self.alphabet) + 1
        self.model = load_densenet_ctc(model_path, class_count)
        self.corrector = BracketCorrector()

    def preprocess_image(self, image, target_height=32):
        """
        Turn a bracket image into a model input tensor.

        image: 2-D or 3-D image array (3-D input is converted to grayscale)
        target_height: height the rotated image is scaled to
        returns: float32 array with a leading and a trailing axis added
        """
        gray = image
        if len(gray.shape) == 3:
            gray = cv2.cvtColor(gray, cv2.COLOR_RGB2GRAY)

        # Brackets are vertical, so lay them on their side first.
        rotated = np.rot90(gray)
        height, width = rotated.shape[:2]

        # Scale to the target height, keeping the aspect ratio.
        scale = target_height / height
        scaled_width = int(width * scale)
        resized = cv2.resize(rotated, (scaled_width, target_height))

        # Map pixel values from [0, 255] to [-0.5, 0.5].
        normalized = resized.astype(np.float32) / 255.0 - 0.5

        # Add batch and channel dimensions: (1, H, W, 1).
        return np.expand_dims(normalized, axis=(0, -1))

    def _recognize(self, image):
        """Run preprocessing, inference, CTC decoding and correction on one image."""
        model_input = self.preprocess_image(image)
        scores = self.model.predict(model_input, verbose=0)
        decoded = greedy_ctc_decode(scores, self.alphabet)
        return self.corrector.correct(decoded)

    def predict(self, buffers, **kwargs):
        """
        Recognize bracket sequences from image buffers.

        buffers: iterable of bracket image buffers
        yields: one corrected bracket string per buffer, or None when the
                buffer could not be read as an image or recognition failed
        """
        for buffer in buffers:
            image = array_from_image_stream(buffer)
            if image is None:
                yield None
                continue

            try:
                yield self._recognize(image)
            except Exception as e:
                # Best effort: a failing image must not abort the batch.
                logging.warning('Bracket prediction error: %s', str(e))
                yield None