Spaces:
Running
Running
File size: 4,890 Bytes
2b7aae2 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 | """
Brackets prediction service.
Recognizes bracket sequences from staff bracket images.
Note: This service requires Keras/TensorFlow models (DenseNet-CTC).
"""
import numpy as np
import cv2
import logging
from predictors.densenet_ctc import load_densenet_ctc, greedy_ctc_decode
from common.image_utils import array_from_image_stream
class BracketCorrector:
    """
    Corrects unpaired and crossing brackets in bracket sequences.

    Each bracket type is paired independently: a closing bracket matches the
    nearest preceding still-unmatched opening bracket of the same type, and
    brackets left unmatched are dropped. When two pairs of different types
    interleave (e.g. ``{<}>``), the lower-priority pair is dropped.
    Priority, highest first: ``vvvib`` > ``vvib`` > ``vib``
    (by default ``{`` > ``[`` > ``<``).
    """
    pair_dict = {
        '{': '}',
        '<': '>',
        '[': ']',
    }
    reverse_dict = {v: k for k, v in pair_dict.items()}
    # pair_dict maps in both directions: opening -> closing and closing -> opening.
    pair_dict.update(reverse_dict)

    def __init__(self, vib='<', vvib='[', vvvib='{'):
        """
        Define bracket priority.

        vvvib > vvib > vib (default: curly > square > angle). Each argument
        is the opening symbol of that level; its partner is looked up in
        ``pair_dict``.
        """
        self.vib = vib
        self.vvib = vvib
        self.vvvib = vvvib
        # NOTE: despite the names, right_symbol lists the opening brackets and
        # left_symbol the closing ones; neither is read by this class.
        self.right_symbol = ['{', '[', '<']
        self.left_symbol = ['}', ']', '>']

    def find_cp(self, string):
        """
        Find paired brackets at each priority level.

        For each level, a closing symbol is paired with the nearest preceding
        opening symbol of the same level that is not yet paired; every index
        is used at most once.

        Returns a tuple ``(vvvib_cp, vvib_cp, vib_cp)`` of lists of
        ``(open_index, close_index)`` tuples, each list ordered by closing
        index.
        """
        def pairs_for(open_sym):
            # Unknown symbols have no partner, so nothing can be paired.
            close_sym = self.pair_dict.get(open_sym)
            pending, pairs = [], []
            for index, char in enumerate(string):
                if char == open_sym:
                    pending.append(index)
                elif char == close_sym and pending:
                    # The top of the stack is the nearest unmatched opening.
                    pairs.append((pending.pop(), index))
            return pairs

        return pairs_for(self.vvvib), pairs_for(self.vvib), pairs_for(self.vib)

    @staticmethod
    def _crosses(a, b):
        """Return True if pairs ``a`` and ``b`` interleave (partially overlap)."""
        (a_begin, a_end), (b_begin, b_end) = a, b
        return (a_begin < b_begin < a_end < b_end) or \
            (b_begin < a_begin < b_end < a_end)

    def clean_up(self, string):
        """
        Remove crossing pairs based on priority.

        A square-level pair crossing any curly-level pair is dropped; an
        angle-level pair crossing any curly-level pair or any surviving
        square-level pair is dropped. Fully nested pairs are kept.

        Returns a flat list of the string indices that belong to the
        surviving pairs.
        """
        vvvib, vvib, vib = self.find_cp(string)
        vvib = [y for y in vvib
                if not any(self._crosses(x, y) for x in vvvib)]
        higher = vvvib + vvib
        vib = [z for z in vib
               if not any(self._crosses(x, z) for x in higher)]
        return [index for pair in vvvib + vvib + vib for index in pair]

    def correct(self, string):
        """
        Correct bracket sequence.

        Returns only properly paired brackets and commas, in their original
        order; every other character is dropped.
        """
        keep = set(self.clean_up(string))
        return ''.join(
            char for index, char in enumerate(string)
            if char == ',' or index in keep
        )
class BracketsService:
    """
    Bracket recognition service using DenseNet-CTC.

    Each input image is turned into a normalized single-channel tensor, run
    through the DenseNet-CTC model, greedily CTC-decoded against
    ``self.alphabet`` and finally cleaned up by ``BracketCorrector``.
    """

    def __init__(self, model_path, device='gpu', alphabet=None, **kwargs):
        """
        Initialize brackets service.

        model_path: path to bracket OCR weights (.h5)
        device: accepted for interface compatibility; not used here
        alphabet: character set for the model
            (default '<>[]{},-.0123456789')
        kwargs: accepted and ignored
        """
        self.alphabet = alphabet or '<>[]{},-.0123456789'
        # One class beyond the alphabet — presumably the CTC blank; TODO confirm
        # against load_densenet_ctc.
        nclass = len(self.alphabet) + 1
        self.model = load_densenet_ctc(model_path, nclass)
        self.corrector = BracketCorrector()

    def preprocess_image(self, image, target_height=32):
        """
        Preprocess bracket image for OCR model.

        image: numpy image array, 2-D, or 3-D with 1, 3 or 4 channels
            (multi-channel input is treated as RGB(A) — NOTE(review): confirm
            the channel order produced by array_from_image_stream)
        target_height: height in pixels of the model input after rotation
        returns: float32 array shaped (1, target_height, W, 1), W >= 1; for
            uint8 input the values lie in [-0.5, 0.5]
        """
        # Convert to a 2-D grayscale array.
        if len(image.shape) == 3:
            if image.shape[2] == 1:
                # cvtColor rejects single-channel input; just drop the axis.
                image = image[:, :, 0]
            else:
                image = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
        # Rotate 90 degrees (brackets are vertical); rot90 returns a strided
        # view, so make it contiguous before handing it to OpenCV.
        image = np.ascontiguousarray(np.rot90(image))
        h, w = image.shape[:2]
        # Resize to target height, keeping aspect ratio; very thin crops must
        # not round to a zero width, which cv2.resize rejects.
        scale = target_height / h
        new_w = max(1, int(w * scale))
        image = cv2.resize(image, (new_w, target_height))
        # Normalize
        image = image.astype(np.float32) / 255.0 - 0.5
        # Add batch and channel dimensions
        return np.expand_dims(image, axis=(0, -1))  # (1, H, W, 1)

    def predict(self, buffers, **kwargs):
        """
        Recognize bracket sequence from images.

        buffers: iterable of bracket image buffers
        yields: one item per buffer — the corrected bracket string, or None
            when the buffer cannot be decoded into an image or when
            preprocessing/inference/decoding raises (the error is logged)
        """
        for buffer in buffers:
            image = array_from_image_stream(buffer)
            if image is None:
                yield None
                continue
            # Keep the try body to the work that can fail; the yield stays
            # outside so exceptions thrown into the generator are not swallowed.
            try:
                processed = self.preprocess_image(image)
                pred = self.model.predict(processed, verbose=0)
                content = greedy_ctc_decode(pred, self.alphabet)
                content = self.corrector.correct(content)
            except Exception as e:
                logging.warning('Bracket prediction error: %s', str(e))
                content = None
            yield content
|