import argparse
import os

import onnx
import onnx_graphsurgeon as gs
import torch
from onnx.external_data_helper import convert_model_to_external_data
from onnxconverter_common import float16
from optimum.onnx.graph_transformations import check_and_save_model
from PIL import Image
from torch.onnx._globals import GLOBALS
from transformers import ColPaliForRetrieval, ColPaliProcessor


def export_model(
    model_id: str,
    output_dir: str,
    device: str,
    fp16: bool = False,
    export_type: str = "both",
):
    """Export ColPaliForRetrieval to ONNX (vision tower, text tower, or both)."""
    os.makedirs(output_dir, exist_ok=True)

    model = (
        ColPaliForRetrieval.from_pretrained(
            model_id,
            torch_dtype=torch.float16 if fp16 else torch.float32,
            device_map="auto",
        )
        .to(device)
        .eval()
    )
    processor = ColPaliProcessor.from_pretrained(model_id)

    # Save the config and processor next to the ONNX files so the export is self-contained.
    model.config.save_pretrained(output_dir)
    processor.save_pretrained(output_dir)

    _orig_forward = model.forward

    # Dummy inputs: a tiny white image plus the token ids/mask the processor builds for it.
    dummy_img = Image.new("RGB", (32, 32), color="white")
    vision_pt = processor(images=[dummy_img], return_tensors="pt").to(device)
    pv, ids, msk = (
        vision_pt["pixel_values"],
        vision_pt["input_ids"],
        vision_pt["attention_mask"],
    )
    # Placeholder tensors for the modality that is not being traced.
    fake_ids = torch.zeros((pv.size(0), 1), device=device, dtype=torch.long)
    fake_mask = torch.zeros_like(fake_ids)
    fake_pv = torch.zeros_like(pv)

    out_paths = {}

    # ---- vision model ----
    if export_type in ("vision", "both"):

        def vision_forward(
            self, pixel_values=None, input_ids=None, attention_mask=None, **kw
        ):
            # Route only pixel_values through the original forward; text inputs are dropped.
            return _orig_forward(
                pixel_values=pixel_values,
                input_ids=None,
                attention_mask=None,
                **kw,
            ).embeddings

        model.forward = vision_forward.__get__(model, model.__class__)
        vision_onnx = os.path.join(output_dir, "model_vision.onnx")

        # Disable the exporter's own shape inference (private torch API);
        # shapes are inferred explicitly after export instead.
        GLOBALS.onnx_shape_inference = False
        # Recent torch versions spill initializers to external data automatically
        # when the graph exceeds the 2 GB protobuf limit; check_and_save_model
        # below consolidates them into model_vision.onnx_data.
        torch.onnx.export(
            model,
            (pv, fake_ids, fake_mask),
            vision_onnx,
            export_params=True,
            opset_version=14,
            do_constant_folding=True,
            input_names=["pixel_values", "input_ids", "attention_mask"],
            output_names=["embeddings"],
            dynamic_axes={
                "pixel_values": {0: "batch_size"},
                "embeddings": {0: "batch_size", 1: "seq_len"},
            },
        )
        print("✅ Exported VISION ONNX to", vision_onnx)

        # Fix shapes & external-data references (infer_shapes_path writes in place).
        onnx.shape_inference.infer_shapes_path(vision_onnx)
        m = onnx.load(vision_onnx, load_external_data=True)
        check_and_save_model(m, vision_onnx)
        print("   (shape-inferred + external-data fixed)")
        out_paths["vision"] = vision_onnx

    # ---- text model ----
    if export_type in ("text", "both"):

        def text_forward(
            self, pixel_values=None, input_ids=None, attention_mask=None, **kw
        ):
            # Route only the text inputs through the original forward; pixels are dropped.
            return _orig_forward(
                pixel_values=None,
                input_ids=input_ids,
                attention_mask=attention_mask,
                **kw,
            ).embeddings

        model.forward = text_forward.__get__(model, model.__class__)
        text_onnx = os.path.join(output_dir, "model_text.onnx")

        torch.onnx.export(
            model,
            (fake_pv, ids, msk),
            text_onnx,
            export_params=True,
            opset_version=14,
            do_constant_folding=True,
            input_names=["pixel_values", "input_ids", "attention_mask"],
            output_names=["embeddings"],
            dynamic_axes={
                "input_ids": {0: "batch_size", 1: "seq_len"},
                "attention_mask": {0: "batch_size", 1: "seq_len"},
                "embeddings": {0: "batch_size", 1: "seq_len"},
            },
        )
        print("✅ Exported TEXT ONNX to", text_onnx)

        onnx.shape_inference.infer_shapes_path(text_onnx)
        m = onnx.load(text_onnx, load_external_data=True)
        check_and_save_model(m, text_onnx)
        print("   (shape-inferred + external-data fixed)")
        out_paths["text"] = text_onnx

    print("🎉 Done exporting model(s):", out_paths)
    return out_paths


def quantize_fp16_and_externalize(
    input_path,
    output_path,
    external_data_filename="model.onnx_data",
    op_block_list=None,
):
    """
    Quantize an ONNX model from FP32 to FP16:
      1) Load the FP32 ONNX (+ its .onnx_data)
      2) Cast weight tensors to FP16
      3) Topologically sort / clean up the graph
      4) Copy opset_import from the original model
      5) Mark ALL tensors for external data
      6) Save the new ONNX + .onnx_data
    """
    # Load twice: `model` is mutated by the FP16 conversion,
    # `orig` keeps the original opset imports for step 4.
    orig = onnx.load(input_path, load_external_data=True)
    model = onnx.load(input_path, load_external_data=True)

    # Shape inference cannot run on protobufs at or over the 2 GB limit.
    disable_si = model.ByteSize() >= onnx.checker.MAXIMUM_PROTOBUF

    # Keep numerically sensitive ops in FP32.
    blocked = set(float16.DEFAULT_OP_BLOCK_LIST)
    if op_block_list:
        blocked.update(op_block_list)
    blocked.update(["LayerNormalization", "Softmax", "Div"])

    model_fp16 = float16.convert_float_to_float16(
        model,
        max_finite_val=65504.0,
        keep_io_types=True,
        disable_shape_infer=disable_si,
        op_block_list=blocked,
    )

    # Re-sort the graph; graphsurgeon drops opset info, so restore it from the original.
    graph = gs.import_onnx(model_fp16)
    graph.toposort()
    model_fp16 = gs.export_onnx(graph)
    model_fp16.ClearField("opset_import")
    model_fp16.opset_import.extend(orig.opset_import)

    convert_model_to_external_data(
        model_fp16,
        all_tensors_to_one_file=True,
        location=external_data_filename,
        size_threshold=0,
    )

    # check_and_save_model requires at least one opset import.
    if not model_fp16.opset_import:
        model_fp16.opset_import.extend(
            [onnx.helper.make_opsetid("", 14)]  # default domain, opset 14
        )

    # Save with shape inference + final checks.
    check_and_save_model(model_fp16, output_path)

    print("✅ FP16 model quantized and saved:")
    print(f"   ONNX: {output_path}")
    print(
        f"   DATA: {os.path.join(os.path.dirname(output_path), external_data_filename)}"
    )
    return True


def main():
    parser = argparse.ArgumentParser(
        description="Convert a ColPali model to ONNX and optionally quantize to FP16"
    )
    parser.add_argument(
        "--model-id", default="vidore/colpali-v1.3-hf", help="HuggingFace model ID"
    )
    parser.add_argument("--output-dir", default=None, help="Output directory")
    parser.add_argument(
        "--quantize", action="store_true", help="Apply FP16 quantization after export"
    )
    parser.add_argument(
        "--export-type",
        choices=["vision", "text", "both"],
        default="both",
        help="Which ONNX model(s) to export",
    )
    parser.add_argument("--device", default=None, help="Device for model (cuda/cpu)")
    args = parser.parse_args()

    if args.device is None:
        args.device = "cuda" if torch.cuda.is_available() else "cpu"
    if args.output_dir is None:
        args.output_dir = os.path.join("output", args.model_id.replace("/", "_"))

    out_paths = export_model(
        args.model_id,
        args.output_dir,
        args.device,
        fp16=False,
        export_type=args.export_type,
    )

    # Quantize whichever models were exported.
    if args.quantize:
        print("Starting FP16 quantization")
        for key, path in out_paths.items():
            binname = os.path.basename(path).replace(".onnx", ".onnx_data")
            quantize_fp16_and_externalize(path, path, binname)


if __name__ == "__main__":
    main()
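
# Usage sketch (illustrative; the script filename below is an assumption, not
# taken from the source):
#
#   python export_colpali_onnx.py \
#       --model-id vidore/colpali-v1.3-hf \
#       --export-type both \
#       --quantize
#
# With the defaults above, this writes model_vision.onnx and model_text.onnx
# (plus their .onnx_data sidecar files) to output/vidore_colpali-v1.3-hf,
# then rewrites both in place as FP16.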