#!/usr/bin/env python3
"""
Export PyTorch models to TorchScript format.

This script should be run in the original deep-starry environment
where the model definitions are available.

Usage:
    cd /home/camus/work/deep-starry
    python /path/to/export_torchscript.py --mode layout --config configs/your-config.yaml --output layout.pt

Modes:
    layout   - LayoutPredictor (ScoreResidue model)
    mask     - MaskPredictor (ScoreWidgetsMask model)
    semantic - SemanticPredictor (ScoreWidgets model)
    gauge    - GaugePredictor (ScoreRegression model)
"""

import argparse
import torch
import numpy as np
import os
import sys

# Make the deep-starry package importable no matter which directory this
# script is launched from.
DEEP_STARRY_PATH = '/home/camus/work/deep-starry'
if DEEP_STARRY_PATH not in sys.path:
    sys.path.insert(0, DEEP_STARRY_PATH)
def _export_model(config_path, output_path, device, input_hw, label):
    """Load a starry model, trace it with an example input, and save TorchScript.

    Shared implementation behind the four public ``export_*`` entry points,
    which previously duplicated this routine verbatim.

    Args:
        config_path: Path to the model configuration directory.
        output_path: Destination path for the traced TorchScript module.
        device: Torch device string to load weights onto and trace with.
        input_hw: ``(height, width)`` of the grayscale example input; all
            models here take a ``(1, 1, H, W)`` tensor.
        label: Human-readable model name used in log messages.
    """
    # Imported lazily: these only resolve inside the deep-starry environment,
    # and deferring keeps this module importable elsewhere.
    from starry.utils.config import Configuration
    from starry.utils.model_factory import loadModel

    config = Configuration.createOrLoad(config_path, volatile=True)
    model = loadModel(config['model'])

    checkpoint_path = config.localPath('weights.chkpt')
    if os.path.exists(checkpoint_path):
        checkpoint = torch.load(checkpoint_path, map_location=device)
        model.load_state_dict(checkpoint['model'])
    else:
        # Previously this case was silently ignored, so the script would
        # trace and save a randomly-initialized model. Keep going (best
        # effort, same as before) but make the problem visible.
        print(f'WARNING: checkpoint not found at {checkpoint_path}; '
              f'exporting {label} model with untrained weights.',
              file=sys.stderr)

    model.to(device)
    model.eval()

    # Example input shape: (batch, channel, height, width), grayscale.
    height, width = input_hw
    example_input = torch.randn(1, 1, height, width).to(device)

    # Trace rather than script: these models are called with a single tensor.
    with torch.no_grad():
        traced = torch.jit.trace(model, example_input)

    traced.save(output_path)
    print(f'{label} model exported to: {output_path}')


def export_layout(config_path, output_path, device='cuda'):
    """Export the layout model (ScoreResidue) to TorchScript."""
    # Layout model is traced at full-page resolution.
    _export_model(config_path, output_path, device, (600, 800), 'Layout')


def export_mask(config_path, output_path, device='cuda'):
    """Export the mask model (ScoreWidgetsMask) to TorchScript."""
    # Mask model usually consumes 512x256 slices.
    _export_model(config_path, output_path, device, (256, 512), 'Mask')


def export_semantic(config_path, output_path, device='cuda'):
    """Export the semantic model (ScoreWidgets) to TorchScript."""
    _export_model(config_path, output_path, device, (256, 512), 'Semantic')


def export_gauge(config_path, output_path, device='cuda'):
    """Export the gauge model (ScoreRegression) to TorchScript."""
    _export_model(config_path, output_path, device, (256, 512), 'Gauge')
# Dispatch table mapping the --mode flag to its export routine.
EXPORTERS = dict(
    layout=export_layout,
    mask=export_mask,
    semantic=export_semantic,
    gauge=export_gauge,
)
def main():
    """Parse command-line options and run the selected exporter."""
    arg_parser = argparse.ArgumentParser(description='Export PyTorch models to TorchScript')
    arg_parser.add_argument('--mode', type=str, required=True, choices=list(EXPORTERS.keys()),
                            help='Model type to export')
    arg_parser.add_argument('--config', type=str, required=True,
                            help='Path to model configuration directory')
    arg_parser.add_argument('--output', type=str, required=True,
                            help='Output TorchScript file path')
    arg_parser.add_argument('--device', type=str, default='cuda',
                            help='Device to use for export')
    options = arg_parser.parse_args()

    # Look up the exporter for the requested mode and invoke it directly.
    EXPORTERS[options.mode](options.config, options.output, options.device)


if __name__ == '__main__':
    main()