"""
DenseNet-CTC model builder for OCR/brackets recognition.
Ports the DenseNet-CTC architecture from starry-ocr to TF 2.x Keras.
Builds the model graph, loads weights-only .h5 files, and provides
numpy-based CTC greedy decoding (no tf.Session needed).
"""
import os
import logging
import numpy as np
os.environ.setdefault('TF_USE_LEGACY_KERAS', '1')
import tensorflow as tf
# Default architecture config (matches starry-ocr training)
# Every key is read by build_densenet_ctc(); a falsy value (0) switches off
# the optional stage or selects the fallback noted next to the key.
DEFAULT_DENSENET_CONFIG = {
    # Initial Conv2D: number of filters, kernel size and stride.
    'first_conv_filters': 64,
    'first_conv_size': 5,
    'first_conv_stride': 2,
    # Layer count of each dense block; every block except the last one is
    # followed by a transition block.
    'dense_block_layers': [8, 8, 8],
    # Filters produced by each conv block and concatenated onto the features.
    'dense_block_growth_rate': 8,
    # Filters of the transition 1x1 conv; a falsy value falls back to
    # nb_filter // 2.
    'trans_block_filters': 128,
    # Optional AveragePooling2D right after the first conv (0 = skipped);
    # the stride is only used when the pool is enabled.
    'first_pool_size': 0,
    'first_pool_stride': 2,
    # Optional BN + Conv2D after the last dense block (size 0 = skipped);
    # a falsy filter count falls back to the current nb_filter.
    'last_conv_size': 0,
    'last_conv_filters': 0,
    # Pool size of the AveragePooling2D placed after the optional last conv
    # (its stride is hard-coded to 2 in build_densenet_ctc).
    'last_pool_size': 2,
}
# Default input image config, read by build_densenet_ctc().
DEFAULT_IMAGE_CONFIG = {
    # Fixed input height; the attention Dense layer uses min(height, 64) units.
    'height': 32,
    # Number of input channels (last axis of the input tensor).
    'channel': 1,
}
def _conv_block(x, growth_rate):
    """
    One dense-block unit: BatchNorm → ReLU → 3×3 Conv2D.

    x: input feature tensor.
    growth_rate: number of filters produced by the convolution.
    Returns the convolved tensor (the caller concatenates it onto x).
    """
    layers = tf.keras.layers
    normalized = layers.BatchNormalization(axis=-1, epsilon=1.1e-5)(x)
    activated = layers.Activation('relu')(normalized)
    conv = layers.Conv2D(
        growth_rate,
        (3, 3),
        kernel_initializer='he_normal',
        padding='same',
    )
    return conv(activated)
def _dense_block(x, nb_layers, nb_filter, growth_rate):
"""Dense block: stack of conv_blocks with concatenation."""
for _ in range(nb_layers):
cb = _conv_block(x, growth_rate)
x = tf.keras.layers.Concatenate()([x, cb])
nb_filter += growth_rate
return x, nb_filter
def _transition_block(x, nb_filter, weight_decay=1e-4):
    """
    Transition block: BatchNorm → ReLU → 1×1 Conv2D → 2×2 average pooling.

    x: input feature tensor.
    nb_filter: filters of the 1×1 convolution.
    weight_decay: L2 regularisation factor for the convolution kernel.
    Returns (pooled tensor, nb_filter); nb_filter is passed through unchanged.
    """
    layers = tf.keras.layers
    normalized = layers.BatchNormalization(axis=-1, epsilon=1.1e-5)(x)
    activated = layers.Activation('relu')(normalized)
    pointwise = layers.Conv2D(
        nb_filter,
        (1, 1),
        kernel_initializer='he_normal',
        padding='same',
        use_bias=False,
        kernel_regularizer=tf.keras.regularizers.l2(weight_decay),
    )
    projected = pointwise(activated)
    pooled = layers.AveragePooling2D((2, 2), strides=(2, 2))(projected)
    return pooled, nb_filter
def build_densenet_ctc(nclass, cfg=None, img_cfg=None):
"""
Build the DenseNet-CTC model matching the starry-ocr architecture.
Returns a Keras Model: input (B, H, W?, 1) → softmax (B, T, nclass)
"""
# Parameters:
#   nclass  - number of units of the final softmax Dense layer ('out').
#   cfg     - architecture dict with the keys of DEFAULT_DENSENET_CONFIG.
#   img_cfg - image dict with the keys of DEFAULT_IMAGE_CONFIG.
# NOTE(review): load_densenet_ctc() loads weights with by_name=True, so layer
# names matter. Presumably that includes the auto-generated names of the
# unnamed layers below — verify before reordering any layer construction.
# A None or empty dict falls back to the module defaults. A partially filled
# dict is used as-is, so any missing key raises KeyError further down.
cfg = cfg or DEFAULT_DENSENET_CONFIG
img_cfg = img_cfg or DEFAULT_IMAGE_CONFIG
height = img_cfg['height']
channels = img_cfg['channel']
weight_decay = 1e-4
# Height and channel count are fixed; the width dimension is left as None.
inp = tf.keras.Input(shape=(height, None, channels), name='the_input')
x = inp
# Attention module
# Swap the first two non-batch axes, apply a softmax-activated Dense layer
# with min(height, 64) units, swap the axes back, and multiply the result
# element-wise with the input.
a = tf.keras.layers.Permute((2, 1, 3), name='permute_first')(x)
attention_ratio = 64 if height > 64 else height
a = tf.keras.layers.Dense(attention_ratio, activation='softmax')(a)
a = tf.keras.layers.Permute((2, 1, 3), name='attention_vec')(a)
x = tf.keras.layers.Multiply(name='attention_mul')([x, a])
# Initial convolution
# Bias-free and L2-regularised; nb_filter tracks the running filter count.
nb_filter = cfg['first_conv_filters']
x = tf.keras.layers.Conv2D(
nb_filter, cfg['first_conv_size'], strides=cfg['first_conv_stride'],
padding='same', use_bias=False,
kernel_regularizer=tf.keras.regularizers.l2(weight_decay)
)(x)
# Optional pooling after the first conv; a pool size of 0 skips it.
if cfg['first_pool_size']:
x = tf.keras.layers.AveragePooling2D(
cfg['first_pool_size'], strides=cfg['first_pool_stride']
)(x)
# Dense blocks + transitions
# Every dense block except the last is followed by a transition block; the
# last dense block is applied on its own.
nb_layers = cfg['dense_block_layers']
growth_rate = cfg['dense_block_growth_rate']
for n_layer in nb_layers[:-1]:
x, nb_filter = _dense_block(x, n_layer, nb_filter, growth_rate)
# `//` binds tighter than `or`: use the configured filter count, or half of
# the current count when the configured value is falsy.
trans_filters = cfg['trans_block_filters'] or nb_filter // 2
# _transition_block returns its filter argument, so nb_filter becomes
# trans_filters here.
x, nb_filter = _transition_block(x, trans_filters)
x, nb_filter = _dense_block(x, nb_layers[-1], nb_filter, growth_rate)
# Optional last conv stage; a conv size of 0 skips it.
# NOTE(review): this file's indentation was lost, so it is not visible here
# whether the AveragePooling2D line below belongs to this `if` branch or runs
# unconditionally — confirm against the original source before re-indenting.
if cfg['last_conv_size']:
conv_filters = cfg['last_conv_filters'] or nb_filter
x = tf.keras.layers.BatchNormalization(axis=-1, epsilon=1.1e-5)(x)
x = tf.keras.layers.Conv2D(
conv_filters, cfg['last_conv_size'], kernel_initializer='he_normal',
padding='same', use_bias=False,
kernel_regularizer=tf.keras.regularizers.l2(weight_decay)
)(x)
# Pool size comes from the config; the stride is hard-coded to 2.
x = tf.keras.layers.AveragePooling2D(cfg['last_pool_size'], strides=2)(x)
# Final BN + ReLU
x = tf.keras.layers.BatchNormalization(axis=-1, epsilon=1.1e-5)(x)
x = tf.keras.layers.Activation('relu')(x)
# Reshape to sequence: (B, W, H*C) → time-distributed flatten
# The permute puts the width axis first, so each width position becomes one
# time step whose remaining axes are flattened into a single feature vector.
x = tf.keras.layers.Permute((2, 1, 3), name='permute')(x)
x = tf.keras.layers.TimeDistributed(tf.keras.layers.Flatten(), name='flatten')(x)
# Softmax output
# One class distribution of size nclass per time step.
y_pred = tf.keras.layers.Dense(nclass, name='out', activation='softmax')(x)
model = tf.keras.Model(inputs=inp, outputs=y_pred, name='densenet_ctc')
return model
def load_densenet_ctc(model_path, nclass, cfg=None, img_cfg=None):
    """
    Build the model and load its weights from a .h5 file.

    The .h5 files from starry-ocr are weights-only (no model_config).
    The original model had additional input_length + CtcDecodeLayer,
    but those layers have no trainable weights, so by_name loading works.

    model_path: path to the weights-only .h5 file.
    nclass: number of output classes, forwarded to build_densenet_ctc.
    cfg / img_cfg: optional architecture / image configs, forwarded to
        build_densenet_ctc (None selects the module defaults).
    Returns the Keras model with the weights loaded.
    Raises FileNotFoundError if model_path does not exist.
    """
    # Fail fast: check the path before building the graph, so a wrong path is
    # reported immediately instead of after a full model construction.
    if not os.path.exists(model_path):
        raise FileNotFoundError(f'Model weights not found: {model_path}')

    model = build_densenet_ctc(nclass, cfg, img_cfg)
    # by_name matches layers by name rather than by topology. skip_mismatch
    # makes Keras skip (with a warning only) any layer whose stored weight
    # shapes do not match — e.g. the 'out' layer when nclass differs from the
    # checkpoint — leaving that layer at its initial values.
    model.load_weights(model_path, by_name=True, skip_mismatch=True)
    logging.info('DenseNet-CTC weights loaded: %s (%d classes)', model_path, nclass)
    return model
def greedy_ctc_decode(pred, alphabet):
    """
    Greedy (best-path) CTC decoding on raw softmax output.

    Takes the arg-max class of every time step, collapses runs of identical
    consecutive classes into one, then drops blanks. Because repeats are
    collapsed before blanks are removed, a blank between two identical
    characters keeps both of them ("a, blank, a" decodes to "aa").

    pred: (B, T, nclass) array-like; only the first batch item is decoded.
        A single (T, nclass) sample is accepted as well.
    alphabet: string of characters (len = nclass - 1, last class is blank).
        Class indices beyond the end of the alphabet are silently skipped.
    Returns the decoded string ('' for an empty batch or zero time steps).
    Raises ValueError if pred is neither 2-D nor 3-D.
    """
    scores = np.asarray(pred)
    if scores.ndim == 3:
        if scores.shape[0] == 0:
            # Empty batch: nothing to decode.
            return ''
        scores = scores[0]
    elif scores.ndim != 2:
        raise ValueError(
            f'pred must have shape (B, T, nclass) or (T, nclass), got {scores.shape}'
        )

    blank = scores.shape[-1] - 1
    best = scores.argmax(axis=-1)  # (T,)

    # Keep a step only if it differs from the step before it, which may be a
    # blank; the first step has no predecessor and is always a candidate.
    keep = np.ones(best.shape, dtype=bool)
    keep[1:] = best[1:] != best[:-1]
    keep &= best != blank
    # Guard against class indices that have no character in the alphabet.
    keep &= best < len(alphabet)
    return ''.join(alphabet[i] for i in best[keep])
|