#!/usr/bin/env python3
"""
Export PyTorch models to TorchScript format.

This script should be run in the original deep-starry environment where the
model definitions are available.

Usage:
    cd /home/camus/work/deep-starry
    python /path/to/export_torchscript.py --mode layout --config configs/your-config.yaml --output layout.pt

Modes:
    layout   - LayoutPredictor (ScoreResidue model)
    mask     - MaskPredictor (ScoreWidgetsMask model)
    semantic - SemanticPredictor (ScoreWidgets model)
    gauge    - GaugePredictor (ScoreRegression model)

Options:
    --device       torch device used for loading and tracing (default: cuda)
    --input-shape  override the example input shape used for tracing,
                   given as four ints: N C H W

Environment:
    DEEP_STARRY_PATH  location of the deep-starry checkout
                      (default: /home/camus/work/deep-starry)
"""

import argparse
import torch
import numpy as np
import os
import sys

# Add deep-starry to path so that the `starry` package can be imported.
DEEP_STARRY_PATH = os.environ.get('DEEP_STARRY_PATH', '/home/camus/work/deep-starry')
if DEEP_STARRY_PATH not in sys.path:
    sys.path.insert(0, DEEP_STARRY_PATH)

# Example-input shapes (batch, channel, height, width) used when tracing each
# model. torch.jit.trace records the operations executed for this single input,
# so the shape should be representative of what the model sees at inference.
DEFAULT_INPUT_SHAPES = {
    # Layout model expects grayscale (single-channel) input.
    'layout': (1, 1, 600, 800),
    # Mask model usually runs on 256-high x 512-wide slices.
    'mask': (1, 1, 256, 512),
    'semantic': (1, 1, 256, 512),
    'gauge': (1, 1, 256, 512),
}


def load_model(config_path, device='cuda'):
    """Build the model described by a deep-starry configuration and load its weights.

    Args:
        config_path: configuration location passed to
            ``Configuration.createOrLoad``.
        device: torch device the checkpoint is mapped to and the model moved to.

    Returns:
        The model, moved to ``device`` and switched to eval mode.

    If the configuration's ``weights.chkpt`` does not exist, the model keeps
    its freshly initialised weights and a warning is printed to stderr.
    """
    from starry.utils.config import Configuration
    from starry.utils.model_factory import loadModel

    config = Configuration.createOrLoad(config_path, volatile=True)
    model = loadModel(config['model'])

    checkpoint_path = config.localPath('weights.chkpt')
    if os.path.exists(checkpoint_path):
        # NOTE: torch.load unpickles the file; only use trusted checkpoints.
        checkpoint = torch.load(checkpoint_path, map_location=device)
        model.load_state_dict(checkpoint['model'])
    else:
        # Exporting untrained weights is almost never intended, so make it
        # visible instead of silently producing a random-weight model.
        print(f'WARNING: checkpoint not found at {checkpoint_path}; '
              'exporting randomly initialized weights.', file=sys.stderr)

    model.to(device)
    model.eval()
    return model


def trace_and_save(model, example_shape, output_path, device='cuda'):
    """Trace ``model`` with a random example input and save it as TorchScript.

    Args:
        model: module to trace (expected to already be in eval mode on ``device``).
        example_shape: shape of the random example input tensor.
        output_path: destination file; missing parent directories are created.
        device: device on which the example input is allocated.

    Returns:
        The traced ``torch.jit.ScriptModule``.
    """
    example_input = torch.randn(tuple(example_shape), device=device)

    with torch.no_grad():
        traced = torch.jit.trace(model, example_input)

    os.makedirs(os.path.dirname(os.path.abspath(output_path)), exist_ok=True)
    traced.save(output_path)
    return traced


def _export(mode, label, config_path, output_path, device, example_shape):
    """Shared load -> trace -> save pipeline behind every export_* function."""
    model = load_model(config_path, device)
    shape = DEFAULT_INPUT_SHAPES[mode] if example_shape is None else tuple(example_shape)
    trace_and_save(model, shape, output_path, device)
    print(f'{label} model exported to: {output_path}')


def export_layout(config_path, output_path, device='cuda', example_shape=None):
    """Export layout model to TorchScript.

    ``example_shape`` overrides ``DEFAULT_INPUT_SHAPES['layout']`` when given.
    """
    _export('layout', 'Layout', config_path, output_path, device, example_shape)


def export_mask(config_path, output_path, device='cuda', example_shape=None):
    """Export mask model to TorchScript.

    ``example_shape`` overrides ``DEFAULT_INPUT_SHAPES['mask']`` when given.
    """
    _export('mask', 'Mask', config_path, output_path, device, example_shape)


def export_semantic(config_path, output_path, device='cuda', example_shape=None):
    """Export semantic model to TorchScript.

    ``example_shape`` overrides ``DEFAULT_INPUT_SHAPES['semantic']`` when given.
    """
    _export('semantic', 'Semantic', config_path, output_path, device, example_shape)


def export_gauge(config_path, output_path, device='cuda', example_shape=None):
    """Export gauge model to TorchScript.

    ``example_shape`` overrides ``DEFAULT_INPUT_SHAPES['gauge']`` when given.
    """
    _export('gauge', 'Gauge', config_path, output_path, device, example_shape)


EXPORTERS = {
    'layout': export_layout,
    'mask': export_mask,
    'semantic': export_semantic,
    'gauge': export_gauge,
}


def main():
    """Parse command-line arguments and run the selected exporter."""
    parser = argparse.ArgumentParser(description='Export PyTorch models to TorchScript')
    parser.add_argument('--mode', type=str, required=True, choices=list(EXPORTERS.keys()),
                        help='Model type to export')
    parser.add_argument('--config', type=str, required=True,
                        help='Path to model configuration directory')
    parser.add_argument('--output', type=str, required=True,
                        help='Output TorchScript file path')
    parser.add_argument('--device', type=str, default='cuda',
                        help='Device to use for export')
    parser.add_argument('--input-shape', type=int, nargs=4, default=None,
                        metavar=('N', 'C', 'H', 'W'),
                        help='Example input shape for tracing (default depends on --mode)')
    args = parser.parse_args()

    if args.input_shape is not None and any(dim <= 0 for dim in args.input_shape):
        parser.error('--input-shape dimensions must be positive integers')
    # Fail early with a clear message instead of an error from torch.load/model.to.
    if args.device.startswith('cuda') and not torch.cuda.is_available():
        parser.error(f'device {args.device!r} requested but CUDA is not available; '
                     'try --device cpu')

    exporter = EXPORTERS[args.mode]
    exporter(args.config, args.output, args.device, example_shape=args.input_shape)


if __name__ == '__main__':
    main()