"""
DenseNet-CTC model builder for OCR/brackets recognition.

Ports the DenseNet-CTC architecture from starry-ocr to TF 2.x Keras.
Builds the model graph, loads weights-only .h5 files, and provides
numpy-based CTC greedy decoding (no tf.Session needed).
"""

import os
import logging

import numpy as np

# Set before TensorFlow is imported so it is visible at import time; on newer
# TF releases TF_USE_LEGACY_KERAS=1 selects the legacy Keras 2 (tf_keras)
# implementation. NOTE(review): presumably required for the by_name .h5
# weight loading below — confirm against the pinned TF version.
os.environ.setdefault('TF_USE_LEGACY_KERAS', '1')

import tensorflow as tf  # noqa: E402  (must follow the env var above)

# Module logger: avoids the implicit root-logger basicConfig() that the
# module-level logging.info() helper performs when no handlers exist.
logger = logging.getLogger(__name__)

# Default architecture config (matches starry-ocr training)
DEFAULT_DENSENET_CONFIG = {
    'first_conv_filters': 64,
    'first_conv_size': 5,
    'first_conv_stride': 2,
    'dense_block_layers': [8, 8, 8],
    'dense_block_growth_rate': 8,
    'trans_block_filters': 128,
    'first_pool_size': 0,
    'first_pool_stride': 2,
    'last_conv_size': 0,
    'last_conv_filters': 0,
    'last_pool_size': 2,
}

DEFAULT_IMAGE_CONFIG = {
    'height': 32,
    'channel': 1,
}


def _conv_block(x, growth_rate):
    """Single dense-block convolution: BN → ReLU → Conv2D(3×3).

    Returns a feature map with ``growth_rate`` channels; the caller
    concatenates it onto its input.
    """
    x = tf.keras.layers.BatchNormalization(axis=-1, epsilon=1.1e-5)(x)
    x = tf.keras.layers.Activation('relu')(x)
    x = tf.keras.layers.Conv2D(
        growth_rate, (3, 3), kernel_initializer='he_normal', padding='same'
    )(x)
    return x


def _dense_block(x, nb_layers, nb_filter, growth_rate):
    """Dense block: stack of conv_blocks, each concatenated onto the running input.

    Returns ``(x, nb_filter)`` where ``nb_filter`` has grown by
    ``growth_rate`` per layer to track the concatenated channel count.
    """
    for _ in range(nb_layers):
        cb = _conv_block(x, growth_rate)
        x = tf.keras.layers.Concatenate()([x, cb])
        nb_filter += growth_rate
    return x, nb_filter


def _transition_block(x, nb_filter, weight_decay=1e-4):
    """Transition block: BN → ReLU → 1×1 Conv → AvgPool(2×2).

    Returns ``(x, nb_filter)``; the 2×2 average pooling halves both
    spatial dimensions.
    """
    x = tf.keras.layers.BatchNormalization(axis=-1, epsilon=1.1e-5)(x)
    x = tf.keras.layers.Activation('relu')(x)
    x = tf.keras.layers.Conv2D(
        nb_filter, (1, 1), kernel_initializer='he_normal', padding='same',
        use_bias=False,
        kernel_regularizer=tf.keras.regularizers.l2(weight_decay),
    )(x)
    x = tf.keras.layers.AveragePooling2D((2, 2), strides=(2, 2))(x)
    return x, nb_filter


def build_densenet_ctc(nclass, cfg=None, img_cfg=None):
    """
    Build the DenseNet-CTC model matching the starry-ocr architecture.

    Args:
        nclass: number of output classes (alphabet size + 1; the last
            class is the CTC blank, see ``greedy_ctc_decode``).
        cfg: architecture options. Keys missing from ``cfg`` fall back to
            ``DEFAULT_DENSENET_CONFIG``, so a partial dict may override
            selected values. ``None`` (or ``{}``) means all defaults.
        img_cfg: input options (``height``, ``channel``); missing keys fall
            back to ``DEFAULT_IMAGE_CONFIG``.

    Returns:
        A Keras Model mapping input (B, height, W, channel) with variable
        width W to softmax output (B, T, nclass).

    Most layers below are created without explicit names, so their Keras
    auto-generated names depend on creation order. ``load_densenet_ctc``
    loads weights by layer name, so do not reorder, add or remove layers.
    """
    # Merge over the defaults so partial configs work.
    cfg = {**DEFAULT_DENSENET_CONFIG, **(cfg or {})}
    img_cfg = {**DEFAULT_IMAGE_CONFIG, **(img_cfg or {})}
    height = img_cfg['height']
    channels = img_cfg['channel']
    weight_decay = 1e-4

    inp = tf.keras.Input(shape=(height, None, channels), name='the_input')
    x = inp

    # Attention module.
    # NOTE(review): Dense acts on the last (channel) axis, so the two Permutes
    # only swap H and W around it; the result has `attention_ratio` channels and
    # relies on Multiply broadcasting against the input (fine when channels == 1)
    # — confirm for other channel counts. Kept as-is to match trained weights.
    a = tf.keras.layers.Permute((2, 1, 3), name='permute_first')(x)
    attention_ratio = min(height, 64)
    a = tf.keras.layers.Dense(attention_ratio, activation='softmax')(a)
    a = tf.keras.layers.Permute((2, 1, 3), name='attention_vec')(a)
    x = tf.keras.layers.Multiply(name='attention_mul')([x, a])

    # Initial convolution
    nb_filter = cfg['first_conv_filters']
    x = tf.keras.layers.Conv2D(
        nb_filter, cfg['first_conv_size'], strides=cfg['first_conv_stride'],
        padding='same', use_bias=False,
        kernel_regularizer=tf.keras.regularizers.l2(weight_decay),
    )(x)
    if cfg['first_pool_size']:
        x = tf.keras.layers.AveragePooling2D(
            cfg['first_pool_size'], strides=cfg['first_pool_stride']
        )(x)

    # Dense blocks + transitions (no transition after the last dense block)
    nb_layers = cfg['dense_block_layers']
    growth_rate = cfg['dense_block_growth_rate']
    for n_layer in nb_layers[:-1]:
        x, nb_filter = _dense_block(x, n_layer, nb_filter, growth_rate)
        # 0/None in the config means "halve the current channel count".
        trans_filters = cfg['trans_block_filters'] or nb_filter // 2
        x, nb_filter = _transition_block(x, trans_filters, weight_decay)
    x, nb_filter = _dense_block(x, nb_layers[-1], nb_filter, growth_rate)

    # Optional extra conv + pooling stage (disabled when last_conv_size is 0).
    if cfg['last_conv_size']:
        conv_filters = cfg['last_conv_filters'] or nb_filter
        x = tf.keras.layers.BatchNormalization(axis=-1, epsilon=1.1e-5)(x)
        x = tf.keras.layers.Conv2D(
            conv_filters, cfg['last_conv_size'], kernel_initializer='he_normal',
            padding='same', use_bias=False,
            kernel_regularizer=tf.keras.regularizers.l2(weight_decay),
        )(x)
        x = tf.keras.layers.AveragePooling2D(cfg['last_pool_size'], strides=2)(x)

    # Final BN + ReLU
    x = tf.keras.layers.BatchNormalization(axis=-1, epsilon=1.1e-5)(x)
    x = tf.keras.layers.Activation('relu')(x)

    # Reshape to sequence: (B, H', W', C) -> (B, W', H', C) -> (B, W', H'*C),
    # so each horizontal position becomes one time step.
    x = tf.keras.layers.Permute((2, 1, 3), name='permute')(x)
    x = tf.keras.layers.TimeDistributed(tf.keras.layers.Flatten(), name='flatten')(x)

    # Softmax output
    y_pred = tf.keras.layers.Dense(nclass, name='out', activation='softmax')(x)

    return tf.keras.Model(inputs=inp, outputs=y_pred, name='densenet_ctc')


def load_densenet_ctc(model_path, nclass, cfg=None, img_cfg=None):
    """
    Build the model and load weights from an .h5 file.

    The .h5 files from starry-ocr are weights-only (no model_config).
    The original model had additional input_length + CtcDecodeLayer, but
    those layers have no trainable weights, so by_name loading works.

    Args:
        model_path: path to the weights file.
        nclass, cfg, img_cfg: forwarded to ``build_densenet_ctc``.

    Returns:
        The built Keras model with weights loaded.

    Raises:
        FileNotFoundError: if ``model_path`` does not exist.
    """
    # Check the path before building the graph so a bad path fails fast.
    if not os.path.exists(model_path):
        raise FileNotFoundError(f'Model weights not found: {model_path}')
    model = build_densenet_ctc(nclass, cfg, img_cfg)
    # NOTE(review): with by_name=True, layers whose names are absent from the
    # file are left untouched, and skip_mismatch=True skips (rather than raises
    # on) shape mismatches. Because most layer names are auto-generated, building
    # other Keras layers earlier in the same process may shift them — verify the
    # weights actually load when more than one model is built.
    model.load_weights(model_path, by_name=True, skip_mismatch=True)
    logger.info('DenseNet-CTC weights loaded: %s (%d classes)', model_path, nclass)
    return model


def greedy_ctc_decode(pred, alphabet):
    """
    Greedy (best-path) CTC decoding on raw softmax output.

    Takes the arg-max class at each time step, collapses consecutive repeats
    and drops the blank (the last class index). A blank between two identical
    labels keeps both.

    Args:
        pred: softmax output of shape (B, T, nclass), of which only the first
            batch item is decoded, or a single sequence of shape (T, nclass).
            Anything ``np.asarray`` accepts is allowed.
        alphabet: characters with len == nclass - 1; class i maps to
            ``alphabet[i]``. Indices beyond the alphabet are ignored.

    Returns:
        The decoded string.
    """
    pred = np.asarray(pred)
    if pred.ndim == 3:
        pred = pred[0]  # decode only the first batch item
    pred_indices = np.argmax(pred, axis=-1)  # (T,)
    blank = pred.shape[-1] - 1

    chars = []
    prev = -1
    for idx in pred_indices:
        # Skip blank and repeated indices; `prev` also tracks blanks so that
        # "a <blank> a" decodes to "aa".
        if idx != blank and idx != prev and idx < len(alphabet):
            chars.append(alphabet[idx])
        prev = idx
    return ''.join(chars)