"""
DenseNet-CTC model builder for OCR/brackets recognition.
Ports the DenseNet-CTC architecture from starry-ocr to TF 2.x Keras.
Builds the model graph, loads weights-only .h5 files, and provides
numpy-based CTC greedy decoding (no tf.Session needed).
"""
import os
import logging
import numpy as np

# NOTE(review): set before the tensorflow import below so the legacy Keras
# implementation is selected at import time — keep this ordering.
os.environ.setdefault('TF_USE_LEGACY_KERAS', '1')
import tensorflow as tf

# Default architecture config (matches starry-ocr training).
# A value of 0 means "stage disabled" for the optional pool/conv stages
# (see build_densenet_ctc, which tests these keys for truthiness).
DEFAULT_DENSENET_CONFIG = {
    'first_conv_filters': 64,        # filters in the stem convolution
    'first_conv_size': 5,            # stem kernel size
    'first_conv_stride': 2,          # stem stride (halves H and W)
    'dense_block_layers': [8, 8, 8], # conv blocks per dense block
    'dense_block_growth_rate': 8,    # channels added per conv block
    'trans_block_filters': 128,      # 0 → fall back to nb_filter // 2
    'first_pool_size': 0,            # 0 → skip the post-stem avg-pool
    'first_pool_stride': 2,
    'last_conv_size': 0,             # 0 → skip the final conv stage
    'last_conv_filters': 0,          # 0 → fall back to current nb_filter
    'last_pool_size': 2,             # final avg-pool size (stride fixed at 2)
}

# Default input image config: grayscale line images of fixed height,
# variable width (width is the CTC time axis after the permute).
DEFAULT_IMAGE_CONFIG = {
    'height': 32,
    'channel': 1,
}
def _conv_block(x, growth_rate):
    """One DenseNet composite layer: BatchNorm -> ReLU -> 3x3 conv.

    Produces `growth_rate` new feature maps for the caller to concatenate
    onto its running feature tensor.
    """
    h = tf.keras.layers.BatchNormalization(axis=-1, epsilon=1.1e-5)(x)
    h = tf.keras.layers.Activation('relu')(h)
    return tf.keras.layers.Conv2D(
        growth_rate, (3, 3), kernel_initializer='he_normal', padding='same'
    )(h)
def _dense_block(x, nb_layers, nb_filter, growth_rate):
    """DenseNet block: stack `nb_layers` conv blocks with dense connectivity.

    Each conv block's output is concatenated onto the running tensor, so the
    channel count grows by `growth_rate` per layer. Returns the final tensor
    and the updated channel count.
    """
    for _ in range(nb_layers):
        new_features = _conv_block(x, growth_rate)
        x = tf.keras.layers.Concatenate()([x, new_features])
    # Same total as incrementing inside the loop: growth_rate per layer.
    return x, nb_filter + nb_layers * growth_rate
def _transition_block(x, nb_filter, weight_decay=1e-4):
    """Transition between dense blocks: BN -> ReLU -> 1x1 conv -> 2x2 avg-pool.

    Compresses the channel count to `nb_filter` and halves both spatial
    dimensions. Returns the pooled tensor and `nb_filter` unchanged.
    """
    h = tf.keras.layers.BatchNormalization(axis=-1, epsilon=1.1e-5)(x)
    h = tf.keras.layers.Activation('relu')(h)
    h = tf.keras.layers.Conv2D(
        nb_filter,
        (1, 1),
        kernel_initializer='he_normal',
        padding='same',
        use_bias=False,
        kernel_regularizer=tf.keras.regularizers.l2(weight_decay),
    )(h)
    h = tf.keras.layers.AveragePooling2D((2, 2), strides=(2, 2))(h)
    return h, nb_filter
def build_densenet_ctc(nclass, cfg=None, img_cfg=None):
    """
    Build the DenseNet-CTC model matching the starry-ocr architecture.

    nclass: number of softmax classes (alphabet size + 1 CTC blank).
    cfg: architecture dict (keys as in DEFAULT_DENSENET_CONFIG); None → defaults.
    img_cfg: dict with 'height' and 'channel'; None → defaults.
    Returns a Keras Model: input (B, H, W?, 1) → softmax (B, T, nclass)

    NOTE(review): layer creation order fixes the auto-generated layer names,
    which load_densenet_ctc relies on via by_name weight loading — do not
    reorder layer construction.
    """
    cfg = cfg or DEFAULT_DENSENET_CONFIG
    img_cfg = img_cfg or DEFAULT_IMAGE_CONFIG
    height = img_cfg['height']
    channels = img_cfg['channel']
    weight_decay = 1e-4
    # Width is None: variable-width line images; width becomes the time axis.
    inp = tf.keras.Input(shape=(height, None, channels), name='the_input')
    x = inp
    # Attention module: Dense acts on the last axis after the permute.
    # NOTE(review): the Multiply below appears to require attention_ratio to
    # broadcast/match against x's height axis — presumably only valid for
    # height <= 64 (ratio == height); confirm against training config.
    a = tf.keras.layers.Permute((2, 1, 3), name='permute_first')(x)
    attention_ratio = 64 if height > 64 else height
    a = tf.keras.layers.Dense(attention_ratio, activation='softmax')(a)
    a = tf.keras.layers.Permute((2, 1, 3), name='attention_vec')(a)
    x = tf.keras.layers.Multiply(name='attention_mul')([x, a])
    # Initial (stem) convolution
    nb_filter = cfg['first_conv_filters']
    x = tf.keras.layers.Conv2D(
        nb_filter, cfg['first_conv_size'], strides=cfg['first_conv_stride'],
        padding='same', use_bias=False,
        kernel_regularizer=tf.keras.regularizers.l2(weight_decay)
    )(x)
    # Optional post-stem pooling (disabled when first_pool_size == 0)
    if cfg['first_pool_size']:
        x = tf.keras.layers.AveragePooling2D(
            cfg['first_pool_size'], strides=cfg['first_pool_stride']
        )(x)
    # Dense blocks + transitions; the last dense block has no transition.
    nb_layers = cfg['dense_block_layers']
    growth_rate = cfg['dense_block_growth_rate']
    for n_layer in nb_layers[:-1]:
        x, nb_filter = _dense_block(x, n_layer, nb_filter, growth_rate)
        # trans_block_filters == 0 → halve the channel count instead.
        trans_filters = cfg['trans_block_filters'] or nb_filter // 2
        x, nb_filter = _transition_block(x, trans_filters)
    x, nb_filter = _dense_block(x, nb_layers[-1], nb_filter, growth_rate)
    # Optional final conv stage (disabled when last_conv_size == 0)
    if cfg['last_conv_size']:
        conv_filters = cfg['last_conv_filters'] or nb_filter
        x = tf.keras.layers.BatchNormalization(axis=-1, epsilon=1.1e-5)(x)
        x = tf.keras.layers.Conv2D(
            conv_filters, cfg['last_conv_size'], kernel_initializer='he_normal',
            padding='same', use_bias=False,
            kernel_regularizer=tf.keras.regularizers.l2(weight_decay)
        )(x)
        x = tf.keras.layers.AveragePooling2D(cfg['last_pool_size'], strides=2)(x)
    # Final BN + ReLU
    x = tf.keras.layers.BatchNormalization(axis=-1, epsilon=1.1e-5)(x)
    x = tf.keras.layers.Activation('relu')(x)
    # Reshape to sequence: (B, W, H*C) — time-distributed flatten over width
    x = tf.keras.layers.Permute((2, 1, 3), name='permute')(x)
    x = tf.keras.layers.TimeDistributed(tf.keras.layers.Flatten(), name='flatten')(x)
    # Per-timestep softmax over the nclass labels
    y_pred = tf.keras.layers.Dense(nclass, name='out', activation='softmax')(x)
    model = tf.keras.Model(inputs=inp, outputs=y_pred, name='densenet_ctc')
    return model
def load_densenet_ctc(model_path, nclass, cfg=None, img_cfg=None):
    """
    Construct the DenseNet-CTC graph and restore weights from an .h5 file.

    The .h5 files from starry-ocr are weights-only (no model_config). The
    original training model carried an extra input_length input and a
    CtcDecodeLayer; neither holds trainable weights, so by-name loading
    against this inference graph works.

    model_path: path to the weights-only .h5 file.
    nclass / cfg / img_cfg: forwarded to build_densenet_ctc.
    Raises FileNotFoundError when model_path does not exist.
    Returns the Keras Model with weights loaded.
    """
    net = build_densenet_ctc(nclass, cfg, img_cfg)
    if not os.path.exists(model_path):
        raise FileNotFoundError(f'Model weights not found: {model_path}')
    # skip_mismatch tolerates the training-only layers absent from this graph.
    net.load_weights(model_path, by_name=True, skip_mismatch=True)
    logging.info('DenseNet-CTC weights loaded: %s (%d classes)', model_path, nclass)
    return net
def greedy_ctc_decode(pred, alphabet):
    """
    Greedy (best-path) CTC decoding of raw softmax output.

    pred: (B, T, nclass) numpy array; only batch item 0 is decoded.
    alphabet: string of characters (len = nclass - 1; the last class is blank).
    Returns the decoded string: repeats collapsed, blanks removed.
    """
    blank = pred.shape[-1] - 1
    best_path = np.argmax(pred[0], axis=-1)  # (T,) best class per timestep
    decoded = []
    # Initializing to the blank index is equivalent to the usual -1 sentinel:
    # argmax never yields -1, and a leading blank is dropped either way.
    last = blank
    for label in best_path:
        # Emit only non-blank labels that differ from the previous timestep;
        # out-of-range indices are silently skipped as a safety guard.
        if label != blank and label != last and label < len(alphabet):
            decoded.append(alphabet[label])
        last = label
    return ''.join(decoded)