""" |
|
|
ProtT5 Encoder ONNX Conversion Script |
|
|
|
|
|
Converts ProtT5 encoder-only models to ONNX format for optimized inference. |
|
|
|
|
|
Usage: |
|
|
python convert.py --model_name Rostlab/prot_t5_xl_half_uniref50-enc --output_dir ./prot_t5_onnx |
|
|
|
|
|
Requirements: |
|
|
pip install torch transformers onnx |
|
|
""" |
|
|
|
|
|
import argparse
import logging
from pathlib import Path
from typing import Dict

import torch
from transformers import T5EncoderModel, T5Tokenizer

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

class ProtT5EncoderConverter:
    """
    Convert ProtT5 encoder-only models to ONNX format
    """

    def __init__(
        self,
        model_name: str,
        max_sequence_length: int = 1024,
        fp16: bool = True
    ):
        """
        Initialize ProtT5 encoder converter

        Args:
            model_name: Hugging Face model identifier
            max_sequence_length: Maximum protein sequence length
            fp16: Use half precision (float16)
        """
        self.model_name = model_name
        self.max_sequence_length = max_sequence_length
        self.fp16 = fp16

        logger.info(f"Initializing converter for {model_name}")

        self.tokenizer = T5Tokenizer.from_pretrained(model_name, do_lower_case=False)
        self.model = T5EncoderModel.from_pretrained(model_name)
        self.model.eval()

        if self.fp16:
            self.model = self.model.half()
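            # Note (assumption, not enforced here): tracing a half-precision model
            # with torch.onnx.export is usually done on a GPU; on CPU some ops may
            # not support float16 and the export can fail. Pass --no_fp16 to keep
            # the model in float32 if that happens.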
    def prepare_dummy_inputs(self, batch_size: int = 1, sequence_length: int = 10) -> Dict[str, torch.Tensor]:
        """
        Prepare dummy inputs for ONNX export tracing

        Note: These are minimal inputs required by torch.onnx.export() to trace
        the model's execution graph. They don't need to be realistic data.

        Args:
            batch_size: Number of sequences in batch
            sequence_length: Length of each sequence

        Returns:
            Dictionary of input tensors with correct shapes/types
        """
        input_ids = torch.randint(
            0, self.tokenizer.vocab_size,
            (batch_size, sequence_length),
            dtype=torch.long
        )
        attention_mask = torch.ones(
            (batch_size, sequence_length),
            dtype=torch.long
        )

        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask
        }
    def export_encoder_onnx(self, output_path: str) -> str:
        """
        Export encoder model to ONNX

        Args:
            output_path: Path to save ONNX model

        Returns:
            Path to exported ONNX model
        """
        logger.info(f"Exporting encoder to ONNX: {output_path}")

        dummy_inputs = self.prepare_dummy_inputs()

        torch.onnx.export(
            self.model,
            (dummy_inputs['input_ids'], dummy_inputs['attention_mask']),
            output_path,
            export_params=True,
            opset_version=14,
            do_constant_folding=True,
            input_names=['input_ids', 'attention_mask'],
            output_names=['last_hidden_state'],
            dynamic_axes={
                'input_ids': {0: 'batch_size', 1: 'sequence_length'},
                'attention_mask': {0: 'batch_size', 1: 'sequence_length'},
                'last_hidden_state': {0: 'batch_size', 1: 'sequence_length'}
            }
        )

        logger.info(f"ONNX export completed: {output_path}")
        return output_path
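
    # Optional sanity check (illustrative sketch, not called by convert()):
    # the `onnx` package listed in the Requirements can verify the exported
    # graph. Passing the file path (rather than a loaded ModelProto) also
    # works for large models whose weights are stored as external data.
    def check_exported_model(self, onnx_path: str) -> None:
        """Run a structural check on the exported ONNX graph (optional)."""
        import onnx

        onnx.checker.check_model(onnx_path)
        logger.info(f"ONNX model check passed: {onnx_path}")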
    def save_tokenizer(self, output_dir: str):
        """
        Save tokenizer to output directory

        Args:
            output_dir: Directory to save tokenizer files
        """
        logger.info(f"Saving tokenizer to {output_dir}")
        self.tokenizer.save_pretrained(output_dir)
    def convert(self, output_dir: str) -> Dict[str, str]:
        """
        Convert model and save all components

        Args:
            output_dir: Directory to save converted model

        Returns:
            Dictionary with paths to saved files
        """
        output_path = Path(output_dir)
        output_path.mkdir(parents=True, exist_ok=True)

        logger.info(f"Converting {self.model_name} to ONNX in {output_dir}")

        onnx_path = output_path / "model.onnx"
        self.export_encoder_onnx(str(onnx_path))

        self.save_tokenizer(output_dir)

        return {
            "onnx_model": str(onnx_path),
            "tokenizer_dir": output_dir
        }
def main():
    """Main conversion function"""
    parser = argparse.ArgumentParser(description="Convert ProtT5 encoder to ONNX")
    parser.add_argument(
        "--model_name",
        default="Rostlab/prot_t5_xl_half_uniref50-enc",
        help="Hugging Face model name"
    )
    parser.add_argument(
        "--output_dir",
        default="./prot_t5_onnx",
        help="Output directory for converted model"
    )
    parser.add_argument(
        "--max_sequence_length",
        type=int,
        default=1024,
        help="Maximum sequence length"
    )
    parser.add_argument(
        "--fp16",
        action="store_true",
        default=True,
        help="Use half precision (default: True)"
    )
    parser.add_argument(
        "--no_fp16",
        action="store_true",
        help="Disable half precision"
    )

    args = parser.parse_args()

    # Half precision is enabled by default; --no_fp16 overrides it.
    fp16 = args.fp16 and not args.no_fp16

    converter = ProtT5EncoderConverter(
        model_name=args.model_name,
        max_sequence_length=args.max_sequence_length,
        fp16=fp16
    )

    result = converter.convert(args.output_dir)

    print("\n" + "="*60)
    print("PROTT5 ONNX CONVERSION COMPLETED")
    print("="*60)
    print(f"Model: {args.model_name}")
    print(f"ONNX Model: {result['onnx_model']}")
    print(f"Tokenizer: {result['tokenizer_dir']}")
    print(f"Half Precision: {fp16}")
    print("="*60)
if __name__ == "__main__":
    main()