# starry/backend/python-services/scripts/export_torchscript.py
# Last commit 2b7aae2 by k-l-lambda:
#   feat: add Python ML services (CPU mode) with model download
#!/usr/bin/env python3
"""
Export PyTorch models to TorchScript format.
This script should be run in the original deep-starry environment
where the model definitions are available.
Usage:
cd /home/camus/work/deep-starry
python /path/to/export_torchscript.py --mode layout --config configs/your-config.yaml --output layout.pt
Modes:
layout - LayoutPredictor (ScoreResidue model)
mask - MaskPredictor (ScoreWidgetsMask model)
semantic - SemanticPredictor (ScoreWidgets model)
gauge - GaugePredictor (ScoreRegression model)
"""
import argparse
import torch
import numpy as np
import os
import sys
# Make the deep-starry sources importable. The export code imports
# ``starry.*`` inside functions, so the path only needs to be on sys.path
# before an exporter runs. The location can be overridden with the
# DEEP_STARRY_PATH environment variable. The default matches the checkout
# used in the usage example of the module docstring.
DEEP_STARRY_PATH = os.environ.get('DEEP_STARRY_PATH') or '/home/camus/work/deep-starry'
if DEEP_STARRY_PATH not in sys.path:
    sys.path.insert(0, DEEP_STARRY_PATH)
def export_layout(config_path, output_path, device='cuda'):
"""Export layout model to TorchScript."""
from starry.utils.config import Configuration
from starry.utils.model_factory import loadModel
config = Configuration.createOrLoad(config_path, volatile=True)
model = loadModel(config['model'])
checkpoint_path = config.localPath('weights.chkpt')
if os.path.exists(checkpoint_path):
checkpoint = torch.load(checkpoint_path, map_location=device)
model.load_state_dict(checkpoint['model'])
model.to(device)
model.eval()
# Create example input: (batch, channel, height, width)
# Layout model expects grayscale input
example_input = torch.randn(1, 1, 600, 800).to(device)
# Trace the model
with torch.no_grad():
traced = torch.jit.trace(model, example_input)
# Save
traced.save(output_path)
print(f'Layout model exported to: {output_path}')
def export_mask(config_path, output_path, device='cuda'):
"""Export mask model to TorchScript."""
from starry.utils.config import Configuration
from starry.utils.model_factory import loadModel
config = Configuration.createOrLoad(config_path, volatile=True)
model = loadModel(config['model'])
checkpoint_path = config.localPath('weights.chkpt')
if os.path.exists(checkpoint_path):
checkpoint = torch.load(checkpoint_path, map_location=device)
model.load_state_dict(checkpoint['model'])
model.to(device)
model.eval()
# Mask model input: (batch, channel, height, width)
# Usually 512x256 slices
example_input = torch.randn(1, 1, 256, 512).to(device)
with torch.no_grad():
traced = torch.jit.trace(model, example_input)
traced.save(output_path)
print(f'Mask model exported to: {output_path}')
def export_semantic(config_path, output_path, device='cuda'):
"""Export semantic model to TorchScript."""
from starry.utils.config import Configuration
from starry.utils.model_factory import loadModel
config = Configuration.createOrLoad(config_path, volatile=True)
model = loadModel(config['model'])
checkpoint_path = config.localPath('weights.chkpt')
if os.path.exists(checkpoint_path):
checkpoint = torch.load(checkpoint_path, map_location=device)
model.load_state_dict(checkpoint['model'])
model.to(device)
model.eval()
# Semantic model input: (batch, channel, height, width)
example_input = torch.randn(1, 1, 256, 512).to(device)
with torch.no_grad():
traced = torch.jit.trace(model, example_input)
traced.save(output_path)
print(f'Semantic model exported to: {output_path}')
def export_gauge(config_path, output_path, device='cuda'):
"""Export gauge model to TorchScript."""
from starry.utils.config import Configuration
from starry.utils.model_factory import loadModel
config = Configuration.createOrLoad(config_path, volatile=True)
model = loadModel(config['model'])
checkpoint_path = config.localPath('weights.chkpt')
if os.path.exists(checkpoint_path):
checkpoint = torch.load(checkpoint_path, map_location=device)
model.load_state_dict(checkpoint['model'])
model.to(device)
model.eval()
# Gauge model input
example_input = torch.randn(1, 1, 256, 512).to(device)
with torch.no_grad():
traced = torch.jit.trace(model, example_input)
traced.save(output_path)
print(f'Gauge model exported to: {output_path}')
# Dispatch table mapping each ``--mode`` choice to its export function.
# Insertion order also fixes the order of the choices listed by ``--help``.
EXPORTERS = dict(
    layout=export_layout,
    mask=export_mask,
    semantic=export_semantic,
    gauge=export_gauge,
)
def main():
    """Command-line entry point: parse the arguments and run the chosen exporter."""
    parser = argparse.ArgumentParser(description='Export PyTorch models to TorchScript')

    # Every option is a plain string; only --mode is restricted to a fixed set.
    argument_specs = (
        ('--mode', dict(required=True, choices=list(EXPORTERS), help='Model type to export')),
        ('--config', dict(required=True, help='Path to model configuration directory')),
        ('--output', dict(required=True, help='Output TorchScript file path')),
        ('--device', dict(default='cuda', help='Device to use for export')),
    )
    for flag, options in argument_specs:
        parser.add_argument(flag, type=str, **options)

    cli = parser.parse_args()
    run_export = EXPORTERS[cli.mode]
    run_export(cli.config, cli.output, cli.device)


# Run the CLI only when executed as a script, not when imported.
if __name__ == '__main__':
    main()