# File: starry/backend/python-services/scripts/export_tensorflow.py
# Author: k-l-lambda
# Commit message: feat: add Python ML services (CPU mode) with model download
# Commit: 2b7aae2
#!/usr/bin/env python3
"""
Export TensorFlow/Keras models to SavedModel format.

This script must be run in the original starry-ocr environment, where the
model definitions (``OCR_Test.densenet``) and the trained weights are
available.

Usage:
    cd /home/camus/work/starry-ocr
    python /path/to/export_tensorflow.py --mode ocr --config config.yaml --output ocr_savedmodel

The starry-ocr checkout is prepended to ``sys.path``. Its location defaults to
/home/camus/work/starry-ocr and can be overridden with the STARRY_OCR_PATH
environment variable.

Modes:
    ocr      - General OCR model (DenseNet-CTC)
    tempo    - Tempo numeral OCR model
    brackets - Bracket recognition model

Each export writes a SavedModel directory plus an ``alphabet.txt`` file inside
it containing the model's alphabet.

Note: a chord recognition (Seq2Seq) exporter is not implemented in this script.
"""
import argparse
import os
import sys

# Make the starry-ocr package importable. The checkout location can be
# overridden with the STARRY_OCR_PATH environment variable; an unset or empty
# value falls back to the historical default.
STARRY_OCR_PATH = os.environ.get('STARRY_OCR_PATH') or '/home/camus/work/starry-ocr'
if STARRY_OCR_PATH not in sys.path:
    sys.path.insert(0, STARRY_OCR_PATH)
def export_ocr(config_path, output_path):
"""Export general OCR model to SavedModel."""
import yaml
import tensorflow as tf
# Limit GPU memory
gpus = tf.config.experimental.list_physical_devices('GPU')
if gpus:
for gpu in gpus:
tf.config.experimental.set_memory_growth(gpu, True)
from OCR_Test.densenet.model import Densenet
with open(config_path, 'r', encoding='utf-8') as f:
config = yaml.safe_load(f)
config_dir = os.path.dirname(config_path)
# Load alphabet
alphabet_path = config['generalOCR_alphabet_path']
if not os.path.isabs(alphabet_path):
alphabet_path = os.path.join(config_dir, alphabet_path)
alphabet = open(alphabet_path, 'r', encoding='utf-8').readline().strip()
# Load weights
weights_path = config['generalOCR_weight_path']
if not os.path.isabs(weights_path):
weights_path = os.path.join(config_dir, weights_path)
# Model config
densenetconfig = {
'first_conv_filters': 64,
'first_conv_size': 5,
'first_conv_stride': 2,
'dense_block_layers': [8, 8, 8],
'dense_block_growth_rate': 8,
'trans_block_filters': 128,
'first_pool_size': 0,
'first_pool_stride': 2,
'last_conv_size': 0,
'last_conv_filters': 0,
'last_pool_size': 2,
'need_feature_vector': False, # Disable for export
}
imageconfig = {
'hight': 32,
'width': 400,
'channel': 1,
}
# Create model
model = Densenet(
alphabet=alphabet,
modelPath=weights_path,
imageconfig=imageconfig,
densenetconfig=densenetconfig
)
# Get the underlying Keras model
keras_model = model.model
# Save as SavedModel
keras_model.save(output_path, save_format='tf')
print(f'OCR model exported to: {output_path}')
# Also save alphabet
with open(os.path.join(output_path, 'alphabet.txt'), 'w', encoding='utf-8') as f:
f.write(alphabet)
print(f'Alphabet saved to: {output_path}/alphabet.txt')
def export_tempo(config_path, output_path):
"""Export tempo numeral OCR model to SavedModel."""
import yaml
import tensorflow as tf
gpus = tf.config.experimental.list_physical_devices('GPU')
if gpus:
for gpu in gpus:
tf.config.experimental.set_memory_growth(gpu, True)
from OCR_Test.densenet.model import Densenet
with open(config_path, 'r', encoding='utf-8') as f:
config = yaml.safe_load(f)
config_dir = os.path.dirname(config_path)
alphabet_path = config['temponumOCR_alphabet_path']
if not os.path.isabs(alphabet_path):
alphabet_path = os.path.join(config_dir, alphabet_path)
alphabet = open(alphabet_path, 'r', encoding='utf-8').readline().strip()
weights_path = config['temponumOCR_weight_path']
if not os.path.isabs(weights_path):
weights_path = os.path.join(config_dir, weights_path)
densenetconfig = {
'first_conv_filters': 64,
'first_conv_size': 5,
'first_conv_stride': 2,
'dense_block_layers': [8, 8, 8],
'dense_block_growth_rate': 8,
'trans_block_filters': 128,
'first_pool_size': 0,
'first_pool_stride': 2,
'last_conv_size': 0,
'last_conv_filters': 0,
'last_pool_size': 2,
'need_feature_vector': False,
}
imageconfig = {
'hight': 32,
'width': 400,
'channel': 1,
}
model = Densenet(
alphabet=alphabet,
modelPath=weights_path,
imageconfig=imageconfig,
densenetconfig=densenetconfig
)
keras_model = model.model
keras_model.save(output_path, save_format='tf')
print(f'Tempo OCR model exported to: {output_path}')
with open(os.path.join(output_path, 'alphabet.txt'), 'w', encoding='utf-8') as f:
f.write(alphabet)
def export_brackets(config_path, output_path):
"""Export brackets OCR model to SavedModel."""
import yaml
import tensorflow as tf
gpus = tf.config.experimental.list_physical_devices('GPU')
if gpus:
for gpu in gpus:
tf.config.experimental.set_memory_growth(gpu, True)
from OCR_Test.densenet.model import Densenet
with open(config_path, 'r', encoding='utf-8') as f:
config = yaml.safe_load(f)
config_dir = os.path.dirname(config_path)
alphabet_path = config['bracket_alphabet_path']
if not os.path.isabs(alphabet_path):
alphabet_path = os.path.join(config_dir, alphabet_path)
alphabet = open(alphabet_path, 'r', encoding='utf-8').readline().strip()
weights_path = config['bracket_weight_path']
if not os.path.isabs(weights_path):
weights_path = os.path.join(config_dir, weights_path)
densenetconfig = {
'first_conv_filters': 64,
'first_conv_size': 5,
'first_conv_stride': 2,
'dense_block_layers': [8, 8, 8],
'dense_block_growth_rate': 8,
'trans_block_filters': 128,
'first_pool_size': 0,
'first_pool_stride': 2,
'last_conv_size': 0,
'last_conv_filters': 0,
'last_pool_size': 2,
}
imageconfig = {
'hight': 32,
'width': 400,
'channel': 1,
}
model = Densenet(
alphabet=alphabet,
modelPath=weights_path,
imageconfig=imageconfig,
densenetconfig=densenetconfig
)
keras_model = model.model
keras_model.save(output_path, save_format='tf')
print(f'Brackets model exported to: {output_path}')
with open(os.path.join(output_path, 'alphabet.txt'), 'w', encoding='utf-8') as f:
f.write(alphabet)
# Maps each --mode value to its exporter function. Insertion order is kept,
# and main() uses list(EXPORTERS.keys()) as the --mode choices.
EXPORTERS = dict(
    ocr=export_ocr,
    tempo=export_tempo,
    brackets=export_brackets,
)
def main(argv=None):
    """Command-line entry point: parse arguments and run the selected exporter.

    Args:
        argv: Arguments to parse, excluding the program name. ``None`` (the
            default) makes argparse read ``sys.argv[1:]``, so existing
            ``main()`` calls behave as before.
    """
    parser = argparse.ArgumentParser(description='Export TensorFlow models to SavedModel')
    parser.add_argument('--mode', type=str, required=True, choices=list(EXPORTERS.keys()),
                        help='Model type to export')
    parser.add_argument('--config', type=str, required=True,
                        help='Path to configuration YAML file')
    parser.add_argument('--output', type=str, required=True,
                        help='Output SavedModel directory path')
    args = parser.parse_args(argv)

    # Fail fast with a usage error, rather than a traceback raised after the
    # exporter has already imported TensorFlow.
    if not os.path.isfile(args.config):
        parser.error(f'config file not found: {args.config}')

    exporter = EXPORTERS[args.mode]
    exporter(args.config, args.output)


if __name__ == '__main__':
    main()