#!/usr/bin/env python3
"""
Unified entry point for STARRY ML prediction services.

Usage:
    python main.py -m layout -w models/layout.pt -p 12022 -dv cuda
    python main.py -m semantic -w models/semantic.pt -p 12025 -dv cuda --config config.yaml

Available modes:
    layout   - Page layout detection (port 12022)
    mask     - Staff mask generation (port 12024)
    semantic - Symbol semantic detection (port 12025)
    gauge    - Staff gauge prediction (port 12023)
    loc      - Text location detection (port 12026)
    ocr      - Text recognition (port 12027)
    brackets - Bracket recognition (port 12028)
"""

import argparse
import importlib
import logging
import os
import sys

import yaml


# Service class mapping
SERVICE_MAP = {
    'layout': 'services.layout_service.LayoutService',
    'mask': 'services.mask_service.MaskService',
    'semantic': 'services.semantic_service.SemanticService',
    'gauge': 'services.gauge_service.GaugeService',
    'loc': 'services.loc_service.LocService',
    'ocr': 'services.ocr_service.OcrService',
    'brackets': 'services.brackets_service.BracketsService',
}

# Default ports
DEFAULT_PORTS = {
    'layout': 12022,
    'gauge': 12023,
    'mask': 12024,
    'semantic': 12025,
    'loc': 12026,
    'ocr': 12027,
    'brackets': 12028,
}

# File suffixes (compared case-insensitively) that mark a -w argument as a
# YAML config rather than a model file.
_YAML_SUFFIXES = ('.yaml', '.yml')


def import_class(class_path):
    """Dynamically import a class from a dotted ``package.module.ClassName`` path."""
    module_path, class_name = class_path.rsplit('.', 1)
    module = importlib.import_module(module_path)
    return getattr(module, class_name)


def _is_yaml_path(path):
    """Return True if *path* ends with ``.yaml`` or ``.yml`` (case-insensitive)."""
    return path.lower().endswith(_YAML_SUFFIXES)


def _read_yaml_mapping(path):
    """Read the YAML file at *path* and return its top-level mapping.

    An empty document yields ``{}`` (``yaml.safe_load`` returns None for it).

    Raises:
        ValueError: if the document's top level is not a mapping.
    """
    with open(path, 'r', encoding='utf-8') as f:
        data = yaml.safe_load(f)
    if data is None:
        return {}
    if not isinstance(data, dict):
        raise ValueError(
            f'Expected a mapping at the top level of {path}, '
            f'got {type(data).__name__}'
        )
    return data


def _abs_path(base_dir, rel):
    """Resolve *rel* against *base_dir*; falsy or absolute paths are returned unchanged."""
    if rel and not os.path.isabs(rel):
        return os.path.join(base_dir, rel)
    return rel


def _read_alphabet(path):
    """Return the stripped first line of the file at *path*.

    Returns None when *path* is falsy or does not name an existing regular file.
    """
    if path and os.path.isfile(path):
        with open(path, 'r', encoding='utf-8') as f:
            return f.readline().strip()
    return None


def _find_h5_model(directory):
    """Return the path of the first ``.h5`` file in *directory*, or None.

    Candidates are sorted by name so the choice is deterministic
    (``os.listdir`` order is arbitrary).
    """
    if not os.path.isdir(directory):
        return None
    h5_files = sorted(name for name in os.listdir(directory) if name.endswith('.h5'))
    return os.path.join(directory, h5_files[0]) if h5_files else None


def load_config(config_path):
    """Load the service configuration YAML.

    Returns:
        The top-level mapping, or ``{}`` when *config_path* is falsy, the file
        does not exist (a warning is logged), or the document is empty.

    Raises:
        ValueError: if the document's top level is not a mapping.
    """
    if not config_path:
        return {}
    if not os.path.exists(config_path):
        logging.warning('Config file not found, ignoring: %s', config_path)
        return {}
    return _read_yaml_mapping(config_path)


def resolve_ocr_config(yaml_path):
    """Parse an OCR config YAML and resolve model/alphabet paths.

    Recognised keys (others are ignored here; the bracket_* keys are
    handled by ``resolve_brackets_config``):
        generalOCR_weight_path, generalOCR_alphabet_path,
        temponumOCR_weight_path, temponumOCR_alphabet_path,
        chord_config_weight_path

    Relative paths are resolved against the YAML file's directory. If
    ``generalOCR_weight_path`` is absent, the first ``.h5`` file (by name)
    in the general alphabet's directory is used instead.

    Returns:
        A dict with any of: ``model_path``, ``alphabet``, ``tempo_model_path``,
        ``tempo_alphabet``, ``chord_model_path``. Alphabets are the stripped
        first line of their file and are only included when non-empty.
    """
    base_dir = os.path.dirname(os.path.abspath(yaml_path))
    cfg = _read_yaml_mapping(yaml_path)
    result = {}

    # General OCR model
    gen_alpha_path = _abs_path(base_dir, cfg.get('generalOCR_alphabet_path'))
    gen_weight = _abs_path(base_dir, cfg.get('generalOCR_weight_path'))
    if not gen_weight and gen_alpha_path:
        # Auto-detect an h5 model file alongside the alphabet.
        gen_weight = _find_h5_model(os.path.dirname(gen_alpha_path))
        if gen_weight:
            logging.info('Auto-detected general OCR model: %s', gen_weight)
    if gen_weight:
        result['model_path'] = gen_weight
    alpha = _read_alphabet(gen_alpha_path)
    if alpha:
        result['alphabet'] = alpha

    # Tempo numeral model
    tempo_weight = _abs_path(base_dir, cfg.get('temponumOCR_weight_path'))
    if tempo_weight:
        result['tempo_model_path'] = tempo_weight
    tempo_alpha = _read_alphabet(_abs_path(base_dir, cfg.get('temponumOCR_alphabet_path')))
    if tempo_alpha:
        result['tempo_alphabet'] = tempo_alpha

    # Chord model
    chord_weight = _abs_path(base_dir, cfg.get('chord_config_weight_path'))
    if chord_weight:
        result['chord_model_path'] = chord_weight

    return result


def resolve_brackets_config(yaml_path):
    """Parse a brackets config YAML and resolve model/alphabet paths.

    Reads ``bracket_weight_path`` and ``bracket_alphabet_path``, resolving
    relative paths against the YAML file's directory.

    Returns:
        A dict with ``model_path`` (when the weight key is set) and
        ``alphabet`` (the stripped first line of the alphabet file, when
        that file exists; may be an empty string).
    """
    base_dir = os.path.dirname(os.path.abspath(yaml_path))
    cfg = _read_yaml_mapping(yaml_path)
    result = {}

    bracket_weight = _abs_path(base_dir, cfg.get('bracket_weight_path'))
    if bracket_weight:
        result['model_path'] = bracket_weight

    alphabet = _read_alphabet(_abs_path(base_dir, cfg.get('bracket_alphabet_path')))
    if alphabet is not None:
        result['alphabet'] = alphabet

    return result


def setup_logging(mode, level='INFO'):
    """Configure root logging with a format that tags every record with *mode*."""
    logging.basicConfig(
        level=getattr(logging, level.upper()),
        format=f'[%(asctime)s] [{mode}] %(levelname)s: %(message)s',
        datefmt='%Y-%m-%d %H:%M:%S'
    )


def _build_arg_parser():
    """Create the command-line parser shared by all service modes."""
    parser = argparse.ArgumentParser(
        description='STARRY ML prediction service',
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog=__doc__
    )
    parser.add_argument(
        '-m', '--mode', type=str, required=True,
        choices=list(SERVICE_MAP.keys()),
        help='Service mode to run'
    )
    parser.add_argument(
        '-w', '--weights', type=str, required=True,
        help='Path to model weights file (TorchScript .pt or SavedModel directory); '
             'for ocr/brackets this may instead be a YAML config listing model/alphabet paths'
    )
    parser.add_argument(
        '-p', '--port', type=int, default=None,
        help='ZeroMQ server port (default: mode-specific)'
    )
    parser.add_argument(
        '-dv', '--device', type=str, default='cuda',
        help='Device to use: cuda or cpu (default: cuda)'
    )
    parser.add_argument(
        '--config', type=str, default=None,
        help='Path to service configuration YAML file'
    )
    parser.add_argument(
        '--log-level', type=str, default='INFO',
        choices=['DEBUG', 'INFO', 'WARNING', 'ERROR'],
        help='Logging level (default: INFO)'
    )

    # Service-specific arguments
    parser.add_argument(
        '--slicing-width', type=int, default=512,
        help='Slicing width for mask/semantic/gauge services'
    )
    parser.add_argument(
        '--labels', type=str, nargs='+', default=None,
        help='Semantic labels (for semantic service)'
    )
    parser.add_argument(
        '--image-short-side', type=int, default=736,
        help='Image short side for loc service'
    )
    parser.add_argument(
        '--alphabet', type=str, default=None,
        help='Character alphabet file for OCR/brackets services'
    )
    return parser


def _build_service_kwargs(args, config):
    """Assemble constructor kwargs for the selected service.

    Values derived from the CLI (and from an OCR/brackets YAML passed via
    -w) take precedence. Any remaining *config* entries are added only for
    keys not already set.
    """
    service_kwargs = {
        'model_path': args.weights,
        'device': args.device,
    }
    weights_is_yaml = _is_yaml_path(args.weights)

    # OCR/brackets accept a YAML via -w that points at the real model files.
    if args.mode == 'ocr' and weights_is_yaml:
        logging.info('Resolving OCR config from: %s', args.weights)
        service_kwargs.update(resolve_ocr_config(args.weights))
    elif args.mode == 'brackets' and weights_is_yaml:
        logging.info('Resolving brackets config from: %s', args.weights)
        service_kwargs.update(resolve_brackets_config(args.weights))

    if args.mode in ('mask', 'semantic', 'gauge'):
        service_kwargs['slicing_width'] = args.slicing_width

    if args.mode == 'semantic':
        if args.labels:
            service_kwargs['labels'] = args.labels
        elif 'labels' in config:
            service_kwargs['labels'] = config['labels']

    if args.mode == 'loc':
        service_kwargs['image_short_side'] = args.image_short_side

    if args.mode in ('ocr', 'brackets'):
        if weights_is_yaml:
            # The YAML config already supplied the alphabet, if any.
            if args.alphabet:
                logging.warning(
                    '--alphabet is ignored when -w is a YAML config: %s', args.weights)
        elif args.alphabet:
            with open(args.alphabet, 'r', encoding='utf-8') as f:
                service_kwargs['alphabet'] = f.readline().strip()
        elif 'alphabet' in config:
            service_kwargs['alphabet'] = config['alphabet']

    # Merge config: fills gaps only, never overrides CLI-derived values.
    for key, value in config.items():
        service_kwargs.setdefault(key, value)

    return service_kwargs


def main():
    """Parse CLI arguments, build the requested service and serve it over ZeroMQ.

    Returns:
        1 if the mode is unknown. Otherwise returns None after
        ``ZeroServer.bind`` returns (presumably it blocks while serving;
        TODO confirm).
    """
    args = _build_arg_parser().parse_args()

    # Setup logging
    setup_logging(args.mode, args.log_level)

    # Load config if provided
    config = load_config(args.config)

    # Determine port
    port = args.port or DEFAULT_PORTS.get(args.mode, 12020)

    # Get service class (argparse `choices` already restricts modes; kept as a guard)
    if args.mode not in SERVICE_MAP:
        logging.error('Unknown service mode: %s', args.mode)
        return 1
    ServiceClass = import_class(SERVICE_MAP[args.mode])

    service_kwargs = _build_service_kwargs(args, config)

    # Create service instance
    logging.info('Initializing %s service...', args.mode)
    # Log the resolved model path (differs from -w when -w is an OCR/brackets YAML).
    logging.info('Model path: %s', service_kwargs.get('model_path'))
    logging.info('Device: %s', args.device)

    try:
        service = ServiceClass(**service_kwargs)
    except Exception as e:
        logging.error('Failed to initialize service: %s', str(e))
        raise

    # Start ZeroMQ server (imported lazily so --help works without it)
    from common.zero_server import ZeroServer

    logging.info('Starting ZeroMQ server on port %d...', port)
    server = ZeroServer(service)
    server.bind(port)


if __name__ == '__main__':
    # Propagate main()'s return value (e.g. 1 for an unknown mode) as the exit status.
    sys.exit(main())