k-l-lambda's picture
feat: add Python ML services (CPU mode) with model download
2b7aae2
#!/usr/bin/env python3
"""
Unified entry point for STARRY ML prediction services.
Usage:
python main.py -m layout -w models/layout.pt -p 12022 -dv cuda
python main.py -m semantic -w models/semantic.pt -p 12025 -dv cuda --config config.yaml
Available modes:
layout - Page layout detection (port 12022)
mask - Staff mask generation (port 12024)
semantic - Symbol semantic detection (port 12025)
gauge - Staff gauge prediction (port 12023)
loc - Text location detection (port 12026)
ocr - Text recognition (port 12027)
brackets - Bracket recognition (port 12028)
"""
import argparse
import importlib
import logging
import yaml
import os
# Service mode name -> dotted import path of the class that implements it.
# The key order here is also the order shown for --mode choices.
SERVICE_MAP = dict(
    layout='services.layout_service.LayoutService',
    mask='services.mask_service.MaskService',
    semantic='services.semantic_service.SemanticService',
    gauge='services.gauge_service.GaugeService',
    loc='services.loc_service.LocService',
    ocr='services.ocr_service.OcrService',
    brackets='services.brackets_service.BracketsService',
)
# Service mode name -> ZeroMQ port used when -p/--port is not supplied.
DEFAULT_PORTS = dict(
    layout=12022,
    gauge=12023,
    mask=12024,
    semantic=12025,
    loc=12026,
    ocr=12027,
    brackets=12028,
)
def import_class(class_path):
    """Return the object named by a dotted path such as ``pkg.module.ClassName``.

    Everything before the last dot is imported as a module; the final
    component is then fetched from it as an attribute. A path without any
    dot raises ValueError; an unknown module or attribute propagates the
    ImportError / AttributeError from the lookup.
    """
    pieces = class_path.rsplit('.', 1)
    module_name, attr_name = pieces
    owner = importlib.import_module(module_name)
    return getattr(owner, attr_name)
def load_config(config_path):
    """Load the optional service configuration YAML file.

    Args:
        config_path: Path to a YAML file, or None/empty to skip loading.

    Returns:
        The parsed top-level mapping. Returns ``{}`` when no path is given,
        when the file does not exist, or when the file is empty.
        ``yaml.safe_load`` yields None for an empty document; normalising
        that to ``{}`` lets callers always use ``in`` and ``.items()``.

    Raises:
        ValueError: If the file's top-level YAML value is not a mapping.
    """
    if not config_path:
        return {}
    if not os.path.exists(config_path):
        # Stay lenient as before, but make a likely typo visible.
        logging.warning('Config file not found, ignoring: %s', config_path)
        return {}
    with open(config_path, 'r', encoding='utf-8') as f:
        data = yaml.safe_load(f)
    if data is None:
        return {}
    if not isinstance(data, dict):
        raise ValueError(
            f'Config file {config_path} must contain a YAML mapping, '
            f'got {type(data).__name__}'
        )
    return data
def resolve_ocr_config(yaml_path):
    """Parse an OCR config YAML and resolve model/alphabet paths.

    Recognised keys (all optional):
        generalOCR_weight_path, generalOCR_alphabet_path,
        temponumOCR_weight_path, temponumOCR_alphabet_path,
        chord_config_weight_path
    Bracket keys are handled by resolve_brackets_config, not here.

    Relative paths are resolved against the YAML file's directory. If
    generalOCR_weight_path is absent, the first ``.h5`` file (sorted by name)
    in the general alphabet file's directory is used instead.

    Returns:
        A dict of service kwargs. It holds only the entries that could be
        resolved, from among: model_path, alphabet, tempo_model_path,
        tempo_alphabet, chord_model_path.
    """
    base_dir = os.path.dirname(os.path.abspath(yaml_path))
    with open(yaml_path, 'r', encoding='utf-8') as f:
        cfg = yaml.safe_load(f) or {}

    def abs_path(rel):
        # Falsy values (missing keys) and absolute paths pass through unchanged.
        if rel and not os.path.isabs(rel):
            return os.path.join(base_dir, rel)
        return rel

    def read_alphabet(path):
        # The alphabet is the first line of the file; None if unavailable.
        if path and os.path.exists(path):
            with open(path, 'r', encoding='utf-8') as f:
                return f.readline().strip()
        return None

    result = {}

    # General OCR model
    gen_alpha_path = abs_path(cfg.get('generalOCR_alphabet_path'))
    gen_weight = abs_path(cfg.get('generalOCR_weight_path'))
    if not gen_weight and gen_alpha_path:
        # Auto-detect an .h5 model file stored alongside the alphabet.
        alpha_dir = os.path.dirname(gen_alpha_path)
        if os.path.isdir(alpha_dir):
            # os.listdir order is arbitrary; sort so the choice is deterministic.
            h5_files = sorted(
                name for name in os.listdir(alpha_dir) if name.endswith('.h5')
            )
            if h5_files:
                gen_weight = os.path.join(alpha_dir, h5_files[0])
    if gen_weight:
        result['model_path'] = gen_weight
    alpha = read_alphabet(gen_alpha_path)
    if alpha:
        result['alphabet'] = alpha

    # Tempo numeral model
    tempo_weight = abs_path(cfg.get('temponumOCR_weight_path'))
    if tempo_weight:
        result['tempo_model_path'] = tempo_weight
    tempo_alpha = read_alphabet(abs_path(cfg.get('temponumOCR_alphabet_path')))
    if tempo_alpha:
        result['tempo_alphabet'] = tempo_alpha

    # Chord model
    chord_weight = abs_path(cfg.get('chord_config_weight_path'))
    if chord_weight:
        result['chord_model_path'] = chord_weight

    return result
def resolve_brackets_config(yaml_path):
    """Turn a brackets config YAML into keyword arguments for the service.

    Keys consulted:
        bracket_weight_path   -> 'model_path'
        bracket_alphabet_path -> 'alphabet' (first line of that file, stripped)
    Relative paths are interpreted against the YAML file's directory. A missing
    key, or an alphabet file that does not exist, just leaves that entry out.
    """
    config_dir = os.path.dirname(os.path.abspath(yaml_path))
    with open(yaml_path, 'r', encoding='utf-8') as fh:
        settings = yaml.safe_load(fh) or {}

    def resolve(path):
        # Falsy values and absolute paths are returned untouched.
        return os.path.join(config_dir, path) if path and not os.path.isabs(path) else path

    resolved = {}

    weight_path = resolve(settings.get('bracket_weight_path'))
    if weight_path:
        resolved['model_path'] = weight_path

    alphabet_path = resolve(settings.get('bracket_alphabet_path'))
    if not alphabet_path or not os.path.exists(alphabet_path):
        return resolved
    with open(alphabet_path, 'r', encoding='utf-8') as fh:
        resolved['alphabet'] = fh.readline().strip()
    return resolved
def setup_logging(mode, level='INFO'):
    """Configure root logging so every record is tagged with the service *mode*.

    *level* is a logging level name, matched case-insensitively against the
    attributes of the logging module. An unknown name raises AttributeError
    before any configuration happens.
    """
    numeric_level = getattr(logging, level.upper())
    line_format = f'[%(asctime)s] [{mode}] %(levelname)s: %(message)s'
    logging.basicConfig(
        level=numeric_level,
        format=line_format,
        datefmt='%Y-%m-%d %H:%M:%S',
    )
def _build_arg_parser():
    """Create the argparse parser for the service launcher CLI."""
    parser = argparse.ArgumentParser(
        description='STARRY ML prediction service',
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog=__doc__,
    )
    parser.add_argument(
        '-m', '--mode',
        type=str,
        required=True,
        choices=list(SERVICE_MAP.keys()),
        help='Service mode to run',
    )
    parser.add_argument(
        '-w', '--weights',
        type=str,
        required=True,
        help='Path to model weights file (TorchScript .pt or SavedModel directory); '
             'for ocr/brackets this may also be a .yaml/.yml config listing the model files',
    )
    parser.add_argument(
        '-p', '--port',
        type=int,
        default=None,
        help='ZeroMQ server port (default: mode-specific)',
    )
    parser.add_argument(
        '-dv', '--device',
        type=str,
        default='cuda',
        help='Device to use: cuda or cpu (default: cuda)',
    )
    parser.add_argument(
        '--config',
        type=str,
        default=None,
        help='Path to service configuration YAML file',
    )
    parser.add_argument(
        '--log-level',
        type=str,
        default='INFO',
        choices=['DEBUG', 'INFO', 'WARNING', 'ERROR'],
        help='Logging level (default: INFO)',
    )
    # Service-specific arguments
    parser.add_argument(
        '--slicing-width',
        type=int,
        default=512,
        help='Slicing width for mask/semantic/gauge services',
    )
    parser.add_argument(
        '--labels',
        type=str,
        nargs='+',
        default=None,
        help='Semantic labels (for semantic service)',
    )
    parser.add_argument(
        '--image-short-side',
        type=int,
        default=736,
        help='Image short side for loc service',
    )
    parser.add_argument(
        '--alphabet',
        type=str,
        default=None,
        help='Character alphabet file for OCR/brackets services',
    )
    return parser


def _is_yaml_path(path):
    """Return True if *path* ends in .yaml or .yml (case-insensitive)."""
    return path.lower().endswith(('.yaml', '.yml'))


def _read_first_line(path):
    """Return the first line of the UTF-8 text file at *path*, stripped."""
    with open(path, 'r', encoding='utf-8') as f:
        return f.readline().strip()


def _build_service_kwargs(args, config):
    """Assemble the keyword arguments passed to the service constructor.

    Values set explicitly here take precedence. These come from the YAML
    passed via -w (ocr/brackets) and from CLI flags. Entries from *config*
    only fill keys that are still unset.
    """
    service_kwargs = {
        'model_path': args.weights,
        'device': args.device,
    }
    weights_is_yaml = _is_yaml_path(args.weights)

    # OCR/brackets accept a YAML config via -w that points at the real model files.
    if weights_is_yaml and args.mode == 'ocr':
        logging.info('Resolving OCR config from: %s', args.weights)
        service_kwargs.update(resolve_ocr_config(args.weights))
    elif weights_is_yaml and args.mode == 'brackets':
        logging.info('Resolving brackets config from: %s', args.weights)
        service_kwargs.update(resolve_brackets_config(args.weights))

    if args.mode in ('mask', 'semantic', 'gauge'):
        service_kwargs['slicing_width'] = args.slicing_width

    if args.mode == 'semantic':
        if args.labels:
            service_kwargs['labels'] = args.labels
        elif 'labels' in config:
            service_kwargs['labels'] = config['labels']

    if args.mode == 'loc':
        service_kwargs['image_short_side'] = args.image_short_side

    # With a YAML -w, the alphabet was already resolved from that config.
    if args.mode in ('ocr', 'brackets') and not weights_is_yaml:
        if args.alphabet:
            service_kwargs['alphabet'] = _read_first_line(args.alphabet)
        elif 'alphabet' in config:
            service_kwargs['alphabet'] = config['alphabet']

    # Merge config: fills in missing keys only, never overrides.
    for key, value in config.items():
        service_kwargs.setdefault(key, value)
    return service_kwargs


def main():
    """Parse CLI arguments, build the selected service and serve it over ZeroMQ.

    Returns:
        None after ``server.bind(port)`` returns. Invalid arguments exit via
        argparse. Errors raised while constructing the service are logged
        and re-raised.
    """
    parser = _build_arg_parser()
    args = parser.parse_args()
    if args.port is not None and not 0 < args.port <= 65535:
        parser.error(f'--port must be between 1 and 65535, got {args.port}')

    # Setup logging
    setup_logging(args.mode, args.log_level)

    # load_config can yield None (e.g. an empty YAML file); normalise to {}.
    config = load_config(args.config) or {}

    # Fall back to the mode's default port only when -p was omitted.
    if args.port is not None:
        port = args.port
    else:
        port = DEFAULT_PORTS.get(args.mode, 12020)

    # --mode is restricted to SERVICE_MAP keys by argparse `choices`.
    service_class = import_class(SERVICE_MAP[args.mode])
    service_kwargs = _build_service_kwargs(args, config)

    # Create service instance
    logging.info('Initializing %s service...', args.mode)
    logging.info('Model path: %s', args.weights)
    resolved_model_path = service_kwargs.get('model_path')
    if resolved_model_path != args.weights:
        logging.info('Resolved model path: %s', resolved_model_path)
    logging.info('Device: %s', args.device)
    try:
        service = service_class(**service_kwargs)
    except Exception as e:
        logging.error('Failed to initialize service: %s', str(e))
        raise

    # Deferred import: argument parsing and service setup run before this loads.
    from common.zero_server import ZeroServer
    logging.info('Starting ZeroMQ server on port %d...', port)
    server = ZeroServer(service)
    server.bind(port)
if __name__ == '__main__':
    # Use main()'s return value as the process exit status (None -> 0).
    raise SystemExit(main())