#!/usr/bin/env python3
"""
ProtT5 Encoder ONNX Conversion Script
Converts ProtT5 encoder-only models to ONNX format for optimized inference.
Usage:
python convert.py --model_name Rostlab/prot_t5_xl_half_uniref50-enc --output_dir ./prot_t5_onnx
Requirements:
pip install torch transformers sentencepiece onnx
"""
import argparse
import logging
from pathlib import Path
from typing import Dict

import torch
from transformers import T5EncoderModel, T5Tokenizer
# Setup logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class ProtT5EncoderConverter:
"""
Convert ProtT5 encoder-only models to ONNX format
"""
def __init__(
self,
model_name: str,
max_sequence_length: int = 1024,
fp16: bool = True
):
"""
Initialize ProtT5 encoder converter
Args:
model_name: Hugging Face model identifier
            max_sequence_length: Maximum protein sequence length (kept for reference; the export itself uses dynamic axes)
fp16: Use half precision (float16)
"""
self.model_name = model_name
self.max_sequence_length = max_sequence_length
self.fp16 = fp16
logger.info(f"Initializing converter for {model_name}")
# Load tokenizer and model
self.tokenizer = T5Tokenizer.from_pretrained(model_name, do_lower_case=False)
self.model = T5EncoderModel.from_pretrained(model_name)
self.model.eval()
# Convert to half precision if requested
if self.fp16:
self.model = self.model.half()
def prepare_dummy_inputs(self, batch_size: int = 1, sequence_length: int = 10) -> Dict[str, torch.Tensor]:
"""
Prepare dummy inputs for ONNX export tracing
Note: These are minimal inputs required by torch.onnx.export() to trace
the model's execution graph. They don't need to be realistic data.
Args:
batch_size: Number of sequences in batch
sequence_length: Length of each sequence
Returns:
Dictionary of input tensors with correct shapes/types
"""
# Create dummy inputs with appropriate shape
input_ids = torch.randint(
0, self.tokenizer.vocab_size,
(batch_size, sequence_length),
dtype=torch.long
)
attention_mask = torch.ones(
(batch_size, sequence_length),
dtype=torch.long
)
return {
"input_ids": input_ids,
"attention_mask": attention_mask
}
def export_encoder_onnx(self, output_path: str) -> str:
"""
Export encoder model to ONNX
Args:
output_path: Path to save ONNX model
Returns:
Path to exported ONNX model
"""
logger.info(f"Exporting encoder to ONNX: {output_path}")
# Prepare dummy inputs for export
dummy_inputs = self.prepare_dummy_inputs()
# Export to ONNX
torch.onnx.export(
self.model,
(dummy_inputs['input_ids'], dummy_inputs['attention_mask']),
output_path,
export_params=True,
opset_version=14,
do_constant_folding=True,
input_names=['input_ids', 'attention_mask'],
output_names=['last_hidden_state'],
dynamic_axes={
'input_ids': {0: 'batch_size', 1: 'sequence_length'},
'attention_mask': {0: 'batch_size', 1: 'sequence_length'},
'last_hidden_state': {0: 'batch_size', 1: 'sequence_length'}
}
)
logger.info(f"ONNX export completed: {output_path}")
return output_path
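    def verify_onnx(self, onnx_path: str) -> str:
        """
        Optional structural check of the exported graph.

        A minimal sketch using onnx.checker (the onnx package is already a
        listed requirement). Passing the path lets the checker handle models
        stored with external data; note this validates graph structure only,
        not numerical parity with the PyTorch model.
        Args:
            onnx_path: Path to exported ONNX model
        Returns:
            The verified model path
        """
        import onnx  # local import: only needed when verification is requested

        onnx.checker.check_model(onnx_path)
        logger.info(f"ONNX graph check passed: {onnx_path}")
        return onnx_path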
def save_tokenizer(self, output_dir: str):
"""
Save tokenizer to output directory
Args:
output_dir: Directory to save tokenizer files
"""
logger.info(f"Saving tokenizer to {output_dir}")
self.tokenizer.save_pretrained(output_dir)
def convert(self, output_dir: str) -> Dict[str, str]:
"""
Convert model and save all components
Args:
output_dir: Directory to save converted model
Returns:
Dictionary with paths to saved files
"""
# Create output directory
output_path = Path(output_dir)
output_path.mkdir(parents=True, exist_ok=True)
logger.info(f"Converting {self.model_name} to ONNX in {output_dir}")
# Export ONNX model
onnx_path = output_path / "model.onnx"
self.export_encoder_onnx(str(onnx_path))
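        # Optional: structurally validate the export (uncomment to run the
        # onnx.checker pass sketched in verify_onnx above)
        # self.verify_onnx(str(onnx_path))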
# Save tokenizer
self.save_tokenizer(output_dir)
return {
"onnx_model": str(onnx_path),
"tokenizer_dir": output_dir
}
def main():
"""Main conversion function"""
parser = argparse.ArgumentParser(description="Convert ProtT5 encoder to ONNX")
parser.add_argument(
"--model_name",
default="Rostlab/prot_t5_xl_half_uniref50-enc",
help="Hugging Face model name"
)
parser.add_argument(
"--output_dir",
default="./prot_t5_onnx",
help="Output directory for converted model"
)
parser.add_argument(
"--max_sequence_length",
type=int,
default=1024,
help="Maximum sequence length"
)
    parser.add_argument(
        "--fp16",
        dest="fp16",
        action="store_true",
        default=True,
        help="Use half precision (default: True)"
    )
    parser.add_argument(
        "--no_fp16",
        dest="fp16",
        action="store_false",
        help="Disable half precision"
    )
    args = parser.parse_args()
    # Both flags write to the same destination, so --no_fp16 cleanly
    # overrides the default instead of being reconciled after parsing
    fp16 = args.fp16
# Initialize converter
converter = ProtT5EncoderConverter(
model_name=args.model_name,
max_sequence_length=args.max_sequence_length,
fp16=fp16
)
# Convert model
result = converter.convert(args.output_dir)
# Print results
print("\n" + "="*60)
print("PROTТ5 ONNX CONVERSION COMPLETED")
print("="*60)
print(f"Model: {args.model_name}")
print(f"ONNX Model: {result['onnx_model']}")
print(f"Tokenizer: {result['tokenizer_dir']}")
print(f"Half Precision: {fp16}")
print("="*60)
if __name__ == "__main__":
main()
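
# Example downstream usage (not executed here). A minimal sketch assuming
# onnxruntime is installed (pip install onnxruntime); ProtT5 tokenizers expect
# space-separated residues, with rare amino acids (U, Z, O, B) mapped to X:
#
#   import re
#   import onnxruntime as ort
#   from transformers import T5Tokenizer
#
#   tokenizer = T5Tokenizer.from_pretrained("./prot_t5_onnx", do_lower_case=False)
#   session = ort.InferenceSession("./prot_t5_onnx/model.onnx")
#   seq = "MSEQWENCE"  # toy sequence for illustration
#   seq = " ".join(re.sub(r"[UZOB]", "X", seq))
#   enc = tokenizer(seq, return_tensors="np")
#   (hidden,) = session.run(
#       None,
#       {"input_ids": enc["input_ids"], "attention_mask": enc["attention_mask"]},
#   )
#   # hidden: (batch, sequence_length, hidden_dim) per-residue embeddings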