Spaces:
Running
Running
| #!/usr/bin/env python3 | |
| """ | |
| Unified entry point for STARRY ML prediction services. | |
| Usage: | |
| python main.py -m layout -w models/layout.pt -p 12022 -dv cuda | |
| python main.py -m semantic -w models/semantic.pt -p 12025 -dv cuda --config config.yaml | |
| Available modes: | |
| layout - Page layout detection (port 12022) | |
| mask - Staff mask generation (port 12024) | |
| semantic - Symbol semantic detection (port 12025) | |
| gauge - Staff gauge prediction (port 12023) | |
| loc - Text location detection (port 12026) | |
| ocr - Text recognition (port 12027) | |
| brackets - Bracket recognition (port 12028) | |
| """ | |
| import argparse | |
| import importlib | |
| import logging | |
| import yaml | |
| import os | |
# Registry of runnable services: CLI mode name -> dotted import path of the
# service class. The path is resolved with import_class() only for the mode
# that was selected, so only that service's module is imported.
SERVICE_MAP = dict(
    layout='services.layout_service.LayoutService',
    mask='services.mask_service.MaskService',
    semantic='services.semantic_service.SemanticService',
    gauge='services.gauge_service.GaugeService',
    loc='services.loc_service.LocService',
    ocr='services.ocr_service.OcrService',
    brackets='services.brackets_service.BracketsService',
)

# ZeroMQ port used for each mode when --port is not given on the command line.
DEFAULT_PORTS = dict(
    layout=12022,
    gauge=12023,
    mask=12024,
    semantic=12025,
    loc=12026,
    ocr=12027,
    brackets=12028,
)
def import_class(class_path):
    """Resolve a dotted ``package.module.ClassName`` string to the named object.

    Everything before the last dot is imported as a module with
    :func:`importlib.import_module`; the text after it is looked up as an
    attribute of that module.
    """
    module_name, _, attr_name = class_path.rpartition('.')
    return getattr(importlib.import_module(module_name), attr_name)
def load_config(config_path):
    """Load the optional service configuration YAML.

    Args:
        config_path: Path to a YAML file, or ``None``/empty when no config
            file was requested.

    Returns:
        The parsed top-level mapping. An empty dict is returned when no path
        is given, when the file does not exist (a warning is logged), or when
        the file is empty.

    Raises:
        ValueError: if the YAML document's top level is not a mapping.
            Callers test membership on the result and iterate ``.items()``.
    """
    if not config_path:
        return {}
    if not os.path.exists(config_path):
        # Best-effort: a missing config is not fatal, but make it visible.
        logging.warning('Config file not found, ignoring: %s', config_path)
        return {}
    with open(config_path, 'r', encoding='utf-8') as f:
        # yaml.safe_load returns None for an empty document; normalise to {}.
        data = yaml.safe_load(f) or {}
    if not isinstance(data, dict):
        raise ValueError(
            f'Config file {config_path} must contain a YAML mapping, '
            f'got {type(data).__name__}')
    return data
def _load_yaml_with_base(yaml_path):
    """Parse *yaml_path*; return ``(mapping, absolute directory of the file)``.

    Raises:
        ValueError: if the document's top level is not a mapping.
    """
    base_dir = os.path.dirname(os.path.abspath(yaml_path))
    with open(yaml_path, 'r', encoding='utf-8') as f:
        # yaml.safe_load returns None for an empty document.
        cfg = yaml.safe_load(f) or {}
    if not isinstance(cfg, dict):
        raise ValueError(
            f'{yaml_path}: expected a YAML mapping at the top level, '
            f'got {type(cfg).__name__}')
    return cfg, base_dir


def _resolve_path(base_dir, path):
    """Join a relative *path* onto *base_dir*; pass absolute or empty values through."""
    if path and not os.path.isabs(path):
        return os.path.join(base_dir, path)
    return path


def _read_alphabet(path):
    """Return the stripped first line of *path*, or None if the path is unset or missing."""
    if path and os.path.exists(path):
        with open(path, 'r', encoding='utf-8') as f:
            # NOTE(review): strip() also drops leading/trailing spaces; if a
            # space is meant to be an alphabet symbol it is lost here — confirm.
            return f.readline().strip()
    return None


def _find_h5_file(directory):
    """Return the name-sorted first ``*.h5`` file in *directory*, or None."""
    if not directory or not os.path.isdir(directory):
        return None
    # os.listdir order is arbitrary; sort so the selection is reproducible.
    candidates = sorted(name for name in os.listdir(directory) if name.endswith('.h5'))
    if not candidates:
        return None
    if len(candidates) > 1:
        logging.warning('Several .h5 files in %s; using %s', directory, candidates[0])
    return os.path.join(directory, candidates[0])


def resolve_ocr_config(yaml_path):
    """Parse an OCR config YAML and resolve model/alphabet paths.

    Keys read (paths are relative to the YAML file's directory):
        generalOCR_weight_path, generalOCR_alphabet_path,
        temponumOCR_weight_path, temponumOCR_alphabet_path,
        chord_config_weight_path
    (The bracket_* keys of a shared config are read by resolve_brackets_config.)

    If ``generalOCR_weight_path`` is absent, the first ``*.h5`` file (sorted
    by name) in the directory of the general alphabet file is used instead.

    Returns:
        dict containing whichever of ``model_path``, ``alphabet``,
        ``tempo_model_path``, ``tempo_alphabet`` and ``chord_model_path``
        could be resolved; unresolved entries are omitted.

    Raises:
        ValueError: if the YAML top level is not a mapping.
    """
    cfg, base_dir = _load_yaml_with_base(yaml_path)
    result = {}

    # General OCR model
    gen_alpha_path = _resolve_path(base_dir, cfg.get('generalOCR_alphabet_path'))
    gen_weight = _resolve_path(base_dir, cfg.get('generalOCR_weight_path'))
    if not gen_weight and gen_alpha_path:
        # Auto-detect an h5 model file alongside the alphabet.
        gen_weight = _find_h5_file(os.path.dirname(gen_alpha_path))
    if gen_weight:
        result['model_path'] = gen_weight
    else:
        logging.warning('No general OCR weights resolved from %s', yaml_path)
    alpha = _read_alphabet(gen_alpha_path)
    if alpha:
        result['alphabet'] = alpha

    # Tempo numeral model
    tempo_weight = _resolve_path(base_dir, cfg.get('temponumOCR_weight_path'))
    if tempo_weight:
        result['tempo_model_path'] = tempo_weight
    tempo_alpha = _read_alphabet(
        _resolve_path(base_dir, cfg.get('temponumOCR_alphabet_path')))
    if tempo_alpha:
        result['tempo_alphabet'] = tempo_alpha

    # Chord model
    chord_weight = _resolve_path(base_dir, cfg.get('chord_config_weight_path'))
    if chord_weight:
        result['chord_model_path'] = chord_weight
    return result


def resolve_brackets_config(yaml_path):
    """Parse a brackets config YAML and resolve model/alphabet paths.

    Keys read (paths are relative to the YAML file's directory):
        bracket_weight_path, bracket_alphabet_path

    Returns:
        dict containing ``model_path`` and/or ``alphabet`` when resolved.
        An empty alphabet line is treated as absent.

    Raises:
        ValueError: if the YAML top level is not a mapping.
    """
    cfg, base_dir = _load_yaml_with_base(yaml_path)
    result = {}
    bracket_weight = _resolve_path(base_dir, cfg.get('bracket_weight_path'))
    if bracket_weight:
        result['model_path'] = bracket_weight
    alpha = _read_alphabet(_resolve_path(base_dir, cfg.get('bracket_alphabet_path')))
    if alpha:
        result['alphabet'] = alpha
    return result
def setup_logging(mode, level='INFO'):
    """Configure root logging so each record is tagged with the service *mode*.

    Args:
        mode: Service mode name, inserted verbatim into every log line.
        level: Logging level name, case-insensitive (e.g. ``'info'``).

    Raises:
        AttributeError: if *level* is not a level name defined by ``logging``.
    """
    numeric_level = getattr(logging, level.upper())
    line_format = '[%(asctime)s] [{}] %(levelname)s: %(message)s'.format(mode)
    logging.basicConfig(
        level=numeric_level,
        format=line_format,
        datefmt='%Y-%m-%d %H:%M:%S',
    )
def _build_arg_parser():
    """Create the command-line parser shared by all service modes."""
    parser = argparse.ArgumentParser(
        description='STARRY ML prediction service',
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog=__doc__
    )
    parser.add_argument(
        '-m', '--mode',
        type=str,
        required=True,
        choices=list(SERVICE_MAP.keys()),
        help='Service mode to run'
    )
    parser.add_argument(
        '-w', '--weights',
        type=str,
        required=True,
        help='Path to model weights file (TorchScript .pt or SavedModel directory)'
    )
    parser.add_argument(
        '-p', '--port',
        type=int,
        default=None,
        help='ZeroMQ server port (default: mode-specific)'
    )
    parser.add_argument(
        '-dv', '--device',
        type=str,
        default='cuda',
        help='Device to use: cuda or cpu (default: cuda)'
    )
    parser.add_argument(
        '--config',
        type=str,
        default=None,
        help='Path to service configuration YAML file'
    )
    parser.add_argument(
        '--log-level',
        type=str,
        default='INFO',
        choices=['DEBUG', 'INFO', 'WARNING', 'ERROR'],
        help='Logging level (default: INFO)'
    )
    # Service-specific arguments
    parser.add_argument(
        '--slicing-width',
        type=int,
        default=512,
        help='Slicing width for mask/semantic/gauge services'
    )
    parser.add_argument(
        '--labels',
        type=str,
        nargs='+',
        default=None,
        help='Semantic labels (for semantic service)'
    )
    parser.add_argument(
        '--image-short-side',
        type=int,
        default=736,
        help='Image short side for loc service'
    )
    parser.add_argument(
        '--alphabet',
        type=str,
        default=None,
        help='Character alphabet file for OCR/brackets services'
    )
    return parser


def _is_yaml_path(path):
    """Return True if *path* names a YAML file (``.yaml``/``.yml``, any case)."""
    return path.lower().endswith(('.yaml', '.yml'))


def _build_service_kwargs(args, config):
    """Assemble the constructor keyword arguments for the selected service.

    Precedence: values derived from the command line (and from an OCR/brackets
    YAML manifest passed via -w) win; remaining *config* entries fill the gaps.
    """
    kwargs = {
        'model_path': args.weights,
        'device': args.device,
    }
    weights_is_yaml = _is_yaml_path(args.weights)

    # OCR/brackets accept a YAML manifest via -w; its resolved paths override
    # the defaults above (including model_path).
    if args.mode == 'ocr' and weights_is_yaml:
        logging.info('Resolving OCR config from: %s', args.weights)
        kwargs.update(resolve_ocr_config(args.weights))
    elif args.mode == 'brackets' and weights_is_yaml:
        logging.info('Resolving brackets config from: %s', args.weights)
        kwargs.update(resolve_brackets_config(args.weights))

    if args.mode in ('mask', 'semantic', 'gauge'):
        kwargs['slicing_width'] = args.slicing_width

    if args.mode == 'semantic':
        if args.labels:
            kwargs['labels'] = args.labels
        elif 'labels' in config:
            kwargs['labels'] = config['labels']

    if args.mode == 'loc':
        kwargs['image_short_side'] = args.image_short_side

    if args.mode in ('ocr', 'brackets') and not weights_is_yaml:
        if args.alphabet:
            with open(args.alphabet, 'r', encoding='utf-8') as f:
                kwargs['alphabet'] = f.readline().strip()
        elif 'alphabet' in config:
            kwargs['alphabet'] = config['alphabet']

    # Merge config: only keys not already set explicitly.
    for key, value in config.items():
        kwargs.setdefault(key, value)
    return kwargs


def main():
    """Parse the command line, build the requested service and serve it over ZeroMQ.

    Returns:
        Process exit status: 1 for an unknown mode, 0 once ``server.bind``
        returns.

    Raises:
        Exception: whatever the service constructor raises is logged and
            re-raised.
    """
    args = _build_arg_parser().parse_args()

    setup_logging(args.mode, args.log_level)

    # `or {}` guards against a None config (yaml.safe_load yields None for an
    # empty document), which would break the membership tests and .items().
    config = load_config(args.config) or {}

    # An explicit --port always wins (including 0); otherwise use the mode default.
    port = args.port if args.port is not None else DEFAULT_PORTS.get(args.mode, 12020)

    if args.mode not in SERVICE_MAP:
        # Defensive only: argparse `choices` already rejects unknown modes.
        logging.error('Unknown service mode: %s', args.mode)
        return 1
    ServiceClass = import_class(SERVICE_MAP[args.mode])

    service_kwargs = _build_service_kwargs(args, config)

    logging.info('Initializing %s service...', args.mode)
    # Log the resolved model path (differs from -w when -w is a YAML manifest).
    logging.info('Model path: %s', service_kwargs['model_path'])
    logging.info('Device: %s', args.device)
    try:
        service = ServiceClass(**service_kwargs)
    except Exception as e:
        logging.error('Failed to initialize service: %s', str(e))
        raise

    # Imported lazily, after the service is built.
    from common.zero_server import ZeroServer
    logging.info('Starting ZeroMQ server on port %d...', port)
    server = ZeroServer(service)
    server.bind(port)
    return 0
if __name__ == '__main__':
    # Propagate main()'s return value as the process exit status; a None or 0
    # return exits successfully, a non-zero return (e.g. 1) signals failure.
    raise SystemExit(main())