Spaces:
Running
Running
#!/usr/bin/env python3
"""
Export TensorFlow/Keras models to SavedModel format.

This script should be run in the original starry-ocr environment,
where the model definitions and weights are available.

Usage:
    cd /home/camus/work/starry-ocr
    python /path/to/export_tensorflow.py --mode ocr --config config.yaml --output ocr_savedmodel

The starry-ocr checkout is added to ``sys.path`` so its model packages can be
imported. Its location defaults to /home/camus/work/starry-ocr and can be
overridden with the STARRY_OCR_PATH environment variable.

Modes:
    ocr      - General OCR model (DenseNet-CTC)
    tempo    - Tempo numeral OCR model
    brackets - Bracket recognition model

The chord recognition model (Seq2Seq) is not exported by this script yet.
"""
import argparse
import os
import sys

# Make the starry-ocr model code (e.g. OCR_Test.densenet) importable.
# The environment variable lets the script run from a checkout at another path.
STARRY_OCR_PATH = os.environ.get('STARRY_OCR_PATH', '/home/camus/work/starry-ocr')
if STARRY_OCR_PATH not in sys.path:
    sys.path.insert(0, STARRY_OCR_PATH)
def export_ocr(config_path, output_path):
    """Export the general OCR model (DenseNet-CTC) to SavedModel.

    Args:
        config_path: Path to the starry-ocr YAML config. It must define
            ``generalOCR_alphabet_path`` and ``generalOCR_weight_path``.
            Relative paths in it are resolved against the config file's
            directory.
        output_path: Directory to write the SavedModel into. The alphabet is
            also written there as ``alphabet.txt``.

    Raises:
        KeyError: If a required key is missing from the config.
        OSError: If the config or alphabet file cannot be read.
    """
    import yaml
    import tensorflow as tf

    # Allocate GPU memory on demand instead of reserving it all up front.
    # This must happen before TensorFlow initializes the GPUs.
    for gpu in tf.config.experimental.list_physical_devices('GPU'):
        tf.config.experimental.set_memory_growth(gpu, True)

    from OCR_Test.densenet.model import Densenet

    with open(config_path, 'r', encoding='utf-8') as f:
        config = yaml.safe_load(f)
    config_dir = os.path.dirname(config_path)

    def resolve(path):
        # Config paths are relative to the config file, not the CWD.
        return path if os.path.isabs(path) else os.path.join(config_dir, path)

    # Load the alphabet from the first line of the file. The `with` closes the
    # handle deterministically (the original one-liner leaked it).
    # NOTE(review): strip() also drops a leading/trailing space character if
    # the alphabet contains one. Confirm this matches the runtime loader.
    with open(resolve(config['generalOCR_alphabet_path']), 'r', encoding='utf-8') as f:
        alphabet = f.readline().strip()

    weights_path = resolve(config['generalOCR_weight_path'])

    # Model config
    densenetconfig = {
        'first_conv_filters': 64,
        'first_conv_size': 5,
        'first_conv_stride': 2,
        'dense_block_layers': [8, 8, 8],
        'dense_block_growth_rate': 8,
        'trans_block_filters': 128,
        'first_pool_size': 0,
        'first_pool_stride': 2,
        'last_conv_size': 0,
        'last_conv_filters': 0,
        'last_pool_size': 2,
        'need_feature_vector': False,  # Disable for export
    }
    imageconfig = {
        'hight': 32,  # key spelling is what Densenet expects; do not "fix"
        'width': 400,
        'channel': 1,
    }

    model = Densenet(
        alphabet=alphabet,
        modelPath=weights_path,
        imageconfig=imageconfig,
        densenetconfig=densenetconfig
    )

    # Save the underlying Keras model as a TensorFlow SavedModel directory.
    keras_model = model.model
    keras_model.save(output_path, save_format='tf')
    print(f'OCR model exported to: {output_path}')

    # Ship the alphabet next to the model so inference can decode CTC output.
    with open(os.path.join(output_path, 'alphabet.txt'), 'w', encoding='utf-8') as f:
        f.write(alphabet)
    print(f'Alphabet saved to: {output_path}/alphabet.txt')
def export_tempo(config_path, output_path):
    """Export the tempo numeral OCR model to SavedModel.

    Args:
        config_path: Path to the starry-ocr YAML config. It must define
            ``temponumOCR_alphabet_path`` and ``temponumOCR_weight_path``.
            Relative paths in it are resolved against the config file's
            directory.
        output_path: Directory to write the SavedModel into. The alphabet is
            also written there as ``alphabet.txt``.

    Raises:
        KeyError: If a required key is missing from the config.
        OSError: If the config or alphabet file cannot be read.
    """
    import yaml
    import tensorflow as tf

    # Allocate GPU memory on demand instead of reserving it all up front.
    for gpu in tf.config.experimental.list_physical_devices('GPU'):
        tf.config.experimental.set_memory_growth(gpu, True)

    from OCR_Test.densenet.model import Densenet

    with open(config_path, 'r', encoding='utf-8') as f:
        config = yaml.safe_load(f)
    config_dir = os.path.dirname(config_path)

    def resolve(path):
        # Config paths are relative to the config file, not the CWD.
        return path if os.path.isabs(path) else os.path.join(config_dir, path)

    # Read the first line as the alphabet. The `with` closes the handle
    # (the original one-liner leaked it).
    # NOTE(review): strip() also drops a leading/trailing space character if
    # the alphabet contains one. Confirm this matches the runtime loader.
    with open(resolve(config['temponumOCR_alphabet_path']), 'r', encoding='utf-8') as f:
        alphabet = f.readline().strip()

    weights_path = resolve(config['temponumOCR_weight_path'])

    # Same DenseNet architecture as the general OCR exporter.
    # NOTE(review): assumes the tempo weights were trained with this exact config.
    densenetconfig = {
        'first_conv_filters': 64,
        'first_conv_size': 5,
        'first_conv_stride': 2,
        'dense_block_layers': [8, 8, 8],
        'dense_block_growth_rate': 8,
        'trans_block_filters': 128,
        'first_pool_size': 0,
        'first_pool_stride': 2,
        'last_conv_size': 0,
        'last_conv_filters': 0,
        'last_pool_size': 2,
        'need_feature_vector': False,
    }
    imageconfig = {
        'hight': 32,  # key spelling is what Densenet expects; do not "fix"
        'width': 400,
        'channel': 1,
    }

    model = Densenet(
        alphabet=alphabet,
        modelPath=weights_path,
        imageconfig=imageconfig,
        densenetconfig=densenetconfig
    )

    keras_model = model.model
    keras_model.save(output_path, save_format='tf')
    print(f'Tempo OCR model exported to: {output_path}')

    with open(os.path.join(output_path, 'alphabet.txt'), 'w', encoding='utf-8') as f:
        f.write(alphabet)
    print(f'Alphabet saved to: {output_path}/alphabet.txt')
def export_brackets(config_path, output_path):
    """Export the brackets OCR model to SavedModel.

    Args:
        config_path: Path to the starry-ocr YAML config. It must define
            ``bracket_alphabet_path`` and ``bracket_weight_path``. Relative
            paths in it are resolved against the config file's directory.
        output_path: Directory to write the SavedModel into. The alphabet is
            also written there as ``alphabet.txt``.

    Raises:
        KeyError: If a required key is missing from the config.
        OSError: If the config or alphabet file cannot be read.
    """
    import yaml
    import tensorflow as tf

    # Allocate GPU memory on demand instead of reserving it all up front.
    for gpu in tf.config.experimental.list_physical_devices('GPU'):
        tf.config.experimental.set_memory_growth(gpu, True)

    from OCR_Test.densenet.model import Densenet

    with open(config_path, 'r', encoding='utf-8') as f:
        config = yaml.safe_load(f)
    config_dir = os.path.dirname(config_path)

    def resolve(path):
        # Config paths are relative to the config file, not the CWD.
        return path if os.path.isabs(path) else os.path.join(config_dir, path)

    # Read the first line as the alphabet. The `with` closes the handle
    # (the original one-liner leaked it).
    # NOTE(review): strip() also drops a leading/trailing space character if
    # the alphabet contains one. Confirm this matches the runtime loader.
    with open(resolve(config['bracket_alphabet_path']), 'r', encoding='utf-8') as f:
        alphabet = f.readline().strip()

    weights_path = resolve(config['bracket_weight_path'])

    densenetconfig = {
        'first_conv_filters': 64,
        'first_conv_size': 5,
        'first_conv_stride': 2,
        'dense_block_layers': [8, 8, 8],
        'dense_block_growth_rate': 8,
        'trans_block_filters': 128,
        'first_pool_size': 0,
        'first_pool_stride': 2,
        'last_conv_size': 0,
        'last_conv_filters': 0,
        'last_pool_size': 2,
        # Disabled for export, matching the ocr/tempo exporters. It was missing
        # here, so the export depended on Densenet's default.
        # NOTE(review): confirm the brackets model does not need this output.
        'need_feature_vector': False,
    }
    imageconfig = {
        'hight': 32,  # key spelling is what Densenet expects; do not "fix"
        'width': 400,
        'channel': 1,
    }

    model = Densenet(
        alphabet=alphabet,
        modelPath=weights_path,
        imageconfig=imageconfig,
        densenetconfig=densenetconfig
    )

    keras_model = model.model
    keras_model.save(output_path, save_format='tf')
    print(f'Brackets model exported to: {output_path}')

    with open(os.path.join(output_path, 'alphabet.txt'), 'w', encoding='utf-8') as f:
        f.write(alphabet)
    print(f'Alphabet saved to: {output_path}/alphabet.txt')
# Dispatch table: each --mode value maps to the function that exports it.
# Insertion order determines the order of the CLI's --mode choices.
EXPORTERS = dict(
    ocr=export_ocr,
    tempo=export_tempo,
    brackets=export_brackets,
)
def main():
    """Parse command-line options and run the exporter chosen by --mode."""
    cli = argparse.ArgumentParser(description='Export TensorFlow models to SavedModel')
    cli.add_argument(
        '--mode',
        type=str,
        required=True,
        choices=list(EXPORTERS.keys()),
        help='Model type to export',
    )
    cli.add_argument(
        '--config',
        type=str,
        required=True,
        help='Path to configuration YAML file',
    )
    cli.add_argument(
        '--output',
        type=str,
        required=True,
        help='Output SavedModel directory path',
    )
    options = cli.parse_args()

    run_export = EXPORTERS[options.mode]
    run_export(options.config, options.output)
if __name__ == '__main__':
    # Only run the CLI when executed as a script, not when imported.
    main()