File size: 4,890 Bytes
2b7aae2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
"""
Brackets prediction service.
Recognizes bracket sequences from staff bracket images.

Note: This service requires Keras/TensorFlow models (DenseNet-CTC).
"""

import numpy as np
import cv2
import logging

from predictors.densenet_ctc import load_densenet_ctc, greedy_ctc_decode
from common.image_utils import array_from_image_stream


class BracketCorrector:
	"""
	Repairs bracket sequences by dropping unpaired and crossing brackets.

	Brackets of one kind are matched against each other like a stack: every
	closing bracket is paired with the nearest still-open bracket of the same
	kind to its left. When two pairs of different kinds cross each other
	(e.g. ``{<}>``), the pair with the lower priority is discarded.
	"""

	# Maps every bracket to its counterpart, in both directions.
	pair_dict = {
		'{': '}',
		'<': '>',
		'[': ']',
	}
	reverse_dict = {v: k for k, v in pair_dict.items()}
	pair_dict.update(reverse_dict)

	def __init__(self, vib='<', vvib='[', vvvib='{'):
		"""
		Define bracket priority.
		vvvib > vvib > vib (curly > square > angle by default)

		vib: opening symbol of the lowest-priority bracket kind
		vvib: opening symbol of the middle-priority bracket kind
		vvvib: opening symbol of the highest-priority bracket kind
		"""
		self.vib = vib
		self.vvib = vvib
		self.vvvib = vvvib
		# NOTE: the names look swapped (right_symbol holds the opening
		# brackets, left_symbol the closing ones). Neither attribute is read
		# inside this class; both are kept unchanged for external users.
		self.right_symbol = ['{', '[', '<']
		self.left_symbol = ['}', ']', '>']

	def _match_pairs(self, string, open_sym):
		"""
		Pair up brackets of a single kind.

		string: sequence to scan
		open_sym: symbol treated as the opening bracket; its counterpart is
			looked up in pair_dict
		returns: list of (open_index, close_index) tuples, ordered by
			close_index. Empty when open_sym is not in pair_dict.
		"""
		close_sym = self.pair_dict.get(open_sym)
		if close_sym is None:
			return []

		pairs = []
		pending = []  # indices of openers that have not been closed yet
		for index, char in enumerate(string):
			if char == open_sym:
				pending.append(index)
			elif char == close_sym and pending:
				# Each closer consumes exactly one opener, so no bracket can
				# ever take part in two pairs.
				pairs.append((pending.pop(), index))
		return pairs

	@staticmethod
	def _crosses(first, second):
		"""Return True when two index pairs overlap without nesting."""
		a_begin, a_end = first
		b_begin, b_end = second
		return (a_begin < b_begin < a_end < b_end) or \
			(b_begin < a_begin < b_end < a_end)

	def find_cp(self, string):
		"""
		Find paired brackets at each priority level.

		returns: tuple (vvvib_pairs, vvib_pairs, vib_pairs); each element is
			a list of (open_index, close_index) tuples.
		"""
		return tuple(
			self._match_pairs(string, sym)
			for sym in (self.vvvib, self.vvib, self.vib)
		)

	def clean_up(self, string):
		"""
		Remove crossing conflicts based on priority.

		A pair is dropped when it crosses a pair of a higher-priority kind.
		Middle-priority pairs that were themselves dropped no longer
		eliminate lowest-priority pairs.

		returns: flat list with the string indices of all surviving brackets.
		"""
		vvvib, vvib, vib = self.find_cp(string)

		# Highest priority wins against both lower kinds
		vvib = [y for y in vvib if not any(self._crosses(x, y) for x in vvvib)]
		vib = [z for z in vib if not any(self._crosses(x, z) for x in vvvib)]

		# Surviving middle-priority pairs win against the lowest kind
		vib = [z for z in vib if not any(self._crosses(x, z) for x in vvib)]

		# Collect all valid indices
		return [index for pair in vvvib + vvib + vib for index in pair]

	def correct(self, string):
		"""
		Correct bracket sequence.
		Returns only properly paired brackets and commas; every other
		character is dropped.
		"""
		keep = set(self.clean_up(string))
		return ''.join(
			char for index, char in enumerate(string)
			if char == ',' or index in keep
		)


class BracketsService:
	"""
	Service that reads bracket sequences from images with a DenseNet-CTC model.

	Every image is preprocessed, passed through the network, decoded with
	greedy CTC decoding and finally cleaned up by a BracketCorrector.
	"""

	def __init__(self, model_path, device='gpu', alphabet=None, **kwargs):
		"""
		Load the recognition model and create the corrector.

		model_path: path to bracket OCR weights (.h5)
		device: accepted but not used inside this class
		alphabet: character set for the model; falls back to
			'<>[]{},-.0123456789' when None or empty
		kwargs: accepted and ignored
		"""
		if not alphabet:
			alphabet = '<>[]{},-.0123456789'
		self.alphabet = alphabet
		# One class more than the alphabet has characters
		# (presumably the CTC blank; confirm against load_densenet_ctc).
		self.model = load_densenet_ctc(model_path, len(self.alphabet) + 1)
		self.corrector = BracketCorrector()

	def preprocess_image(self, image, target_height=32):
		"""
		Turn a bracket image into the input array for the OCR model.

		image: image array with 2 or 3 dimensions
		target_height: height the rotated image is resized to
		returns: float32 array with a leading batch axis and a trailing
			channel axis added, i.e. (1, H, W, 1)
		"""
		# Three-dimensional input is collapsed to a single gray channel
		if image.ndim == 3:
			image = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)

		# Brackets are vertical, so turn the image by 90 degrees
		rotated = np.rot90(image)
		height, width = rotated.shape[0], rotated.shape[1]

		# Scale to the target height while keeping the aspect ratio
		factor = target_height / height
		resized_width = int(width * factor)
		resized = cv2.resize(rotated, (resized_width, target_height))

		# Divide by 255 and shift by -0.5
		normalized = resized.astype(np.float32) / 255.0 - 0.5

		# Prepend the batch axis and append the channel axis
		return normalized[np.newaxis, ..., np.newaxis]

	def predict(self, buffers, **kwargs):
		"""
		Recognize bracket sequences from image buffers.

		buffers: iterable of bracket image buffers
		kwargs: accepted and ignored
		yields: one item per buffer, either the corrected bracket string or
			None when the buffer could not be read or recognition raised
		"""
		for buffer in buffers:
			image = array_from_image_stream(buffer)
			if image is not None:
				try:
					model_input = self.preprocess_image(image)
					raw_output = self.model.predict(model_input, verbose=0)

					# Greedy CTC decoding, then bracket pairing correction
					decoded = greedy_ctc_decode(raw_output, self.alphabet)
					yield self.corrector.correct(decoded)
				except Exception as err:
					# Best effort: a failing image must not stop the stream
					logging.warning('Bracket prediction error: %s', str(err))
					yield None
			else:
				yield None