Spaces:
Running
Running
File size: 4,842 Bytes
2b7aae2 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 | #!/usr/bin/env python3
"""
Export PyTorch models to TorchScript format.
This script should be run in the original deep-starry environment
where the model definitions are available.
Usage:
cd /home/camus/work/deep-starry
python /path/to/export_torchscript.py --mode layout --config configs/your-config.yaml --output layout.pt
Modes:
layout - LayoutPredictor (ScoreResidue model)
mask - MaskPredictor (ScoreWidgetsMask model)
semantic - SemanticPredictor (ScoreWidgets model)
gauge - GaugePredictor (ScoreRegression model)
"""
import argparse
import torch
import numpy as np
import os
import sys
# Location of the deep-starry checkout that provides the `starry` package
# imported by the exporters below. It defaults to the original author's
# checkout; set the DEEP_STARRY_PATH environment variable to point elsewhere.
# An empty value falls back to the default rather than adding '' (the current
# working directory) to sys.path.
DEEP_STARRY_PATH = os.environ.get('DEEP_STARRY_PATH') or '/home/camus/work/deep-starry'
if DEEP_STARRY_PATH not in sys.path:
    # Prepend so modules from this checkout win over any same-named modules
    # found later on sys.path.
    sys.path.insert(0, DEEP_STARRY_PATH)
def export_layout(config_path, output_path, device='cuda', *,
                  input_shape=(1, 1, 600, 800)):
    """Export the layout model to a TorchScript file.

    Builds the model from the configuration at ``config_path``. If the file
    named by ``config.localPath('weights.chkpt')`` exists, the state dict
    stored under its ``'model'`` key is loaded into the model. The model is
    then traced with a random example input and saved to ``output_path``.

    Args:
        config_path: Path handed to ``Configuration.createOrLoad``.
        output_path: Destination file for the traced TorchScript module.
        device: Torch device used to load the weights and run the trace.
        input_shape: Shape of the random example tensor used for tracing,
            as (batch, channel, height, width). Defaults to (1, 1, 600, 800).
    """
    # deep-starry modules: importable only when the deep-starry checkout is
    # on sys.path.
    from starry.utils.config import Configuration
    from starry.utils.model_factory import loadModel

    config = Configuration.createOrLoad(config_path, volatile=True)
    model = loadModel(config['model'])

    checkpoint_path = config.localPath('weights.chkpt')
    if os.path.exists(checkpoint_path):
        checkpoint = torch.load(checkpoint_path, map_location=device)
        model.load_state_dict(checkpoint['model'])
    else:
        # Still export (best effort), but do not silently ship a model with
        # freshly initialised weights.
        print(f'WARNING: checkpoint not found at {checkpoint_path}; '
              'exporting layout model with untrained weights.',
              file=sys.stderr)

    model.to(device)
    model.eval()

    # Example input (batch, channel, height, width). The layout model expects
    # grayscale input, hence the single channel in the default shape.
    example_input = torch.randn(*input_shape, device=device)

    with torch.no_grad():
        traced = torch.jit.trace(model, example_input)

    traced.save(output_path)
    print(f'Layout model exported to: {output_path}')
def export_mask(config_path, output_path, device='cuda', *,
                input_shape=(1, 1, 256, 512)):
    """Export the mask model to a TorchScript file.

    Builds the model from the configuration at ``config_path``. If the file
    named by ``config.localPath('weights.chkpt')`` exists, the state dict
    stored under its ``'model'`` key is loaded into the model. The model is
    then traced with a random example input and saved to ``output_path``.

    Args:
        config_path: Path handed to ``Configuration.createOrLoad``.
        output_path: Destination file for the traced TorchScript module.
        device: Torch device used to load the weights and run the trace.
        input_shape: Shape of the random example tensor used for tracing,
            as (batch, channel, height, width). Defaults to (1, 1, 256, 512).
    """
    # deep-starry modules: importable only when the deep-starry checkout is
    # on sys.path.
    from starry.utils.config import Configuration
    from starry.utils.model_factory import loadModel

    config = Configuration.createOrLoad(config_path, volatile=True)
    model = loadModel(config['model'])

    checkpoint_path = config.localPath('weights.chkpt')
    if os.path.exists(checkpoint_path):
        checkpoint = torch.load(checkpoint_path, map_location=device)
        model.load_state_dict(checkpoint['model'])
    else:
        # Still export (best effort), but do not silently ship a model with
        # freshly initialised weights.
        print(f'WARNING: checkpoint not found at {checkpoint_path}; '
              'exporting mask model with untrained weights.',
              file=sys.stderr)

    model.to(device)
    model.eval()

    # Example input (batch, channel, height, width). Mask inputs are usually
    # slices 512 wide by 256 high, matching the default shape.
    example_input = torch.randn(*input_shape, device=device)

    with torch.no_grad():
        traced = torch.jit.trace(model, example_input)

    traced.save(output_path)
    print(f'Mask model exported to: {output_path}')
def export_semantic(config_path, output_path, device='cuda', *,
                    input_shape=(1, 1, 256, 512)):
    """Export the semantic model to a TorchScript file.

    Builds the model from the configuration at ``config_path``. If the file
    named by ``config.localPath('weights.chkpt')`` exists, the state dict
    stored under its ``'model'`` key is loaded into the model. The model is
    then traced with a random example input and saved to ``output_path``.

    Args:
        config_path: Path handed to ``Configuration.createOrLoad``.
        output_path: Destination file for the traced TorchScript module.
        device: Torch device used to load the weights and run the trace.
        input_shape: Shape of the random example tensor used for tracing,
            as (batch, channel, height, width). Defaults to (1, 1, 256, 512).
    """
    # deep-starry modules: importable only when the deep-starry checkout is
    # on sys.path.
    from starry.utils.config import Configuration
    from starry.utils.model_factory import loadModel

    config = Configuration.createOrLoad(config_path, volatile=True)
    model = loadModel(config['model'])

    checkpoint_path = config.localPath('weights.chkpt')
    if os.path.exists(checkpoint_path):
        checkpoint = torch.load(checkpoint_path, map_location=device)
        model.load_state_dict(checkpoint['model'])
    else:
        # Still export (best effort), but do not silently ship a model with
        # freshly initialised weights.
        print(f'WARNING: checkpoint not found at {checkpoint_path}; '
              'exporting semantic model with untrained weights.',
              file=sys.stderr)

    model.to(device)
    model.eval()

    # Example input (batch, channel, height, width).
    example_input = torch.randn(*input_shape, device=device)

    with torch.no_grad():
        traced = torch.jit.trace(model, example_input)

    traced.save(output_path)
    print(f'Semantic model exported to: {output_path}')
def export_gauge(config_path, output_path, device='cuda', *,
                 input_shape=(1, 1, 256, 512)):
    """Export the gauge model to a TorchScript file.

    Builds the model from the configuration at ``config_path``. If the file
    named by ``config.localPath('weights.chkpt')`` exists, the state dict
    stored under its ``'model'`` key is loaded into the model. The model is
    then traced with a random example input and saved to ``output_path``.

    Args:
        config_path: Path handed to ``Configuration.createOrLoad``.
        output_path: Destination file for the traced TorchScript module.
        device: Torch device used to load the weights and run the trace.
        input_shape: Shape of the random example tensor used for tracing.
            Defaults to (1, 1, 256, 512).
    """
    # deep-starry modules: importable only when the deep-starry checkout is
    # on sys.path.
    from starry.utils.config import Configuration
    from starry.utils.model_factory import loadModel

    config = Configuration.createOrLoad(config_path, volatile=True)
    model = loadModel(config['model'])

    checkpoint_path = config.localPath('weights.chkpt')
    if os.path.exists(checkpoint_path):
        checkpoint = torch.load(checkpoint_path, map_location=device)
        model.load_state_dict(checkpoint['model'])
    else:
        # Still export (best effort), but do not silently ship a model with
        # freshly initialised weights.
        print(f'WARNING: checkpoint not found at {checkpoint_path}; '
              'exporting gauge model with untrained weights.',
              file=sys.stderr)

    model.to(device)
    model.eval()

    # Random example input used only to record the traced graph.
    example_input = torch.randn(*input_shape, device=device)

    with torch.no_grad():
        traced = torch.jit.trace(model, example_input)

    traced.save(output_path)
    print(f'Gauge model exported to: {output_path}')
# Registry mapping each accepted --mode value to its exporter function.
# Insertion order determines the order of the CLI choices.
EXPORTERS = dict(
    layout=export_layout,
    mask=export_mask,
    semantic=export_semantic,
    gauge=export_gauge,
)
def main():
    """Command-line entry point.

    Parses --mode, --config, --output and --device. It then calls the
    exporter registered in EXPORTERS for the chosen mode, passing the
    config path, output path and device.
    """
    parser = argparse.ArgumentParser(
        description='Export PyTorch models to TorchScript',
    )
    parser.add_argument(
        '--mode',
        type=str,
        required=True,
        choices=list(EXPORTERS),
        help='Model type to export',
    )
    parser.add_argument(
        '--config',
        type=str,
        required=True,
        help='Path to model configuration directory',
    )
    parser.add_argument(
        '--output',
        type=str,
        required=True,
        help='Output TorchScript file path',
    )
    parser.add_argument(
        '--device',
        type=str,
        default='cuda',
        help='Device to use for export',
    )

    cli = parser.parse_args()
    EXPORTERS[cli.mode](cli.config, cli.output, cli.device)


if __name__ == '__main__':
    main()
|