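"""Export a HuggingFace ColPali checkpoint (ColPaliForRetrieval) to ONNX.

Produces model_vision.onnx and/or model_text.onnx (with external .onnx_data
weight files) and can optionally downcast the exported graphs to FP16.

Example invocation (the script filename here is illustrative):

    python convert_colpali_onnx.py --model-id vidore/colpali-v1.3-hf \
        --export-type both --quantize
"""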
import os
import torch
import onnx
import argparse
from PIL import Image
from torch.onnx._globals import GLOBALS
from transformers import ColPaliForRetrieval, ColPaliProcessor
from optimum.onnx.graph_transformations import check_and_save_model
import onnx_graphsurgeon as gs
from onnxconverter_common import float16
from onnx.external_data_helper import convert_model_to_external_data


def export_model(
model_id: str,
output_dir: str,
device: str,
fp16: bool = False,
export_type: str = "both",
):
"""Export ColPaliForRetrieval to ONNX vision/text/both"""
os.makedirs(output_dir, exist_ok=True)
    # load on a single device; device_map="auto" would dispatch the model via
    # accelerate hooks and conflict with the explicit .to(device) below
    model = (
        ColPaliForRetrieval.from_pretrained(
            model_id,
            torch_dtype=torch.float16 if fp16 else torch.float32,
        )
        .to(device)
        .eval()
    )
processor = ColPaliProcessor.from_pretrained(model_id)
model.config.save_pretrained(output_dir)
processor.save_pretrained(output_dir)
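    # keep a handle on the original forward so each modality-specific wrapper
    # below can delegate to it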
_orig_forward = model.forward
    # dummy inputs: a tiny white image drives the vision trace; zero-filled
    # tensors stand in for whichever modality a given export ignores
dummy_img = Image.new("RGB", (32, 32), color="white")
vision_pt = processor(images=[dummy_img], return_tensors="pt").to(device)
pv, ids, msk = (
vision_pt["pixel_values"],
vision_pt["input_ids"],
vision_pt["attention_mask"],
)
fake_ids = torch.zeros((pv.size(0), 1), device=device, dtype=torch.long)
fake_mask = torch.zeros_like(fake_ids, device=device)
fake_pv = torch.zeros_like(pv)
    out_paths = {}
    # disable the exporter's built-in ONNX shape inference (torch-private
    # flag) once, up front, so it applies to vision-only, text-only, and
    # combined exports; shapes are re-inferred after export with
    # onnx.shape_inference instead
    GLOBALS.onnx_shape_inference = False
# vision model
if export_type in ("vision", "both"):
def vision_forward(
self, pixel_values=None, input_ids=None, attention_mask=None, **kw
):
return _orig_forward(
pixel_values=pixel_values,
input_ids=None,
attention_mask=None,
**kw,
).embeddings
model.forward = vision_forward.__get__(model, model.__class__)
vision_onnx = os.path.join(output_dir, "model_vision.onnx")
        # NOTE: torch.onnx.export has no external-data kwargs; recent torch
        # releases fall back to external data automatically past the 2 GB
        # protobuf limit, and check_and_save_model below re-serializes the
        # weights next to the .onnx file
        torch.onnx.export(
            model,
            (pv, fake_ids, fake_mask),
            vision_onnx,
            export_params=True,
            opset_version=14,
            do_constant_folding=True,
            input_names=["pixel_values", "input_ids", "attention_mask"],
            output_names=["embeddings"],
            dynamic_axes={
                "pixel_values": {0: "batch_size"},
                "embeddings": {0: "batch_size", 1: "seq_len"},
            },
        )
print("✅ Exported VISION ONNX to", vision_onnx)
        # fix shapes & external refs: infer_shapes_path returns None and
        # writes the inferred model back to the same file; reload it and let
        # check_and_save_model re-serialize with valid external-data refs
        onnx.shape_inference.infer_shapes_path(vision_onnx)
        m = onnx.load(vision_onnx, load_external_data=True)
        check_and_save_model(m, vision_onnx)
        print("   (shape-inferred + external-data fixed)")
out_paths["vision"] = vision_onnx
# text model
if export_type in ("text", "both"):
def text_forward(
self, pixel_values=None, input_ids=None, attention_mask=None, **kw
):
return _orig_forward(
pixel_values=None,
input_ids=input_ids,
attention_mask=attention_mask,
**kw,
).embeddings
model.forward = text_forward.__get__(model, model.__class__)
text_onnx = os.path.join(output_dir, "model_text.onnx")
        # same export settings as the vision graph (see NOTE above)
        torch.onnx.export(
            model,
            (fake_pv, ids, msk),
            text_onnx,
            export_params=True,
            opset_version=14,
            do_constant_folding=True,
            input_names=["pixel_values", "input_ids", "attention_mask"],
            output_names=["embeddings"],
            dynamic_axes={
                "input_ids": {0: "batch_size", 1: "seq_len"},
                "attention_mask": {0: "batch_size", 1: "seq_len"},
                "embeddings": {0: "batch_size", 1: "seq_len"},
            },
        )
print("✅ Exported TEXT ONNX to", text_onnx)
        # as above: infer shapes in place, reload, re-serialize
        onnx.shape_inference.infer_shapes_path(text_onnx)
        m = onnx.load(text_onnx, load_external_data=True)
        check_and_save_model(m, text_onnx)
        print("   (shape-inferred + external-data fixed)")
out_paths["text"] = text_onnx
print("🎉 Done exporting model(s):", out_paths)
return out_paths
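

# Sanity-check sketch (assumes `onnxruntime` is installed; illustrative, not
# part of the export flow): run the exported vision graph once. The feed
# variables below are placeholders you would build yourself.
#
#   import onnxruntime as ort
#   sess = ort.InferenceSession("output/vidore_colpali-v1.3-hf/model_vision.onnx")
#   feeds = {
#       "pixel_values": pixel_values,        # float32 (batch, C, H, W) ndarray
#       "input_ids": input_ids,              # zero placeholders, int64
#       "attention_mask": attention_mask,    # zero placeholders, int64
#   }
#   (embeddings,) = sess.run(["embeddings"], feeds)  # (batch, seq_len, dim)
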
def quantize_fp16_and_externalize(
input_path,
output_path,
external_data_filename="model.onnx_data",
op_block_list=None,
):
"""
Quantize an ONNX model from FP32 to FP16
1) Load FP32 ONNX (+ its .onnx_data)
2) Cast weight tensors to FP16
3) Topo-sort / clean up
4) Copy opset_import from original model
5) Mark ALL tensors for external data
6) Save the new ONNX + .onnx_data
"""
orig = onnx.load(input_path, load_external_data=True)
model = onnx.load(input_path, load_external_data=True)
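    # skip the converter's shape inference when the proto exceeds the 2 GB
    # protobuf limit, past which inference cannot run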
disable_si = model.ByteSize() >= onnx.checker.MAXIMUM_PROTOBUF
blocked = set(float16.DEFAULT_OP_BLOCK_LIST)
if op_block_list:
blocked.update(op_block_list)
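    # additionally keep numerically sensitive ops in FP32: FP16 saturates at
    # 65504 and loses precision in softmax/normalization-style reductions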
blocked.update(["LayerNormalization", "Softmax", "Div"])
model_fp16 = float16.convert_float_to_float16(
model,
max_finite_val=65504.0,
keep_io_types=True,
disable_shape_infer=disable_si,
op_block_list=blocked,
)
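    # round-trip through graphsurgeon purely to topologically sort the nodes;
    # gs.export_onnx rebuilds the proto, so the opset imports are restored
    # from the original model just below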
graph = gs.import_onnx(model_fp16)
graph.toposort()
model_fp16 = gs.export_onnx(graph)
model_fp16.ClearField("opset_import")
model_fp16.opset_import.extend(orig.opset_import)
convert_model_to_external_data(
model_fp16,
all_tensors_to_one_file=True,
location=external_data_filename,
size_threshold=0,
)
# required for check_and_save_model
if not model_fp16.opset_import:
model_fp16.opset_import.extend(
[
onnx.helper.make_opsetid("", 14), # Default domain with opset 14
]
)
# Save with shape-infer + final checks
check_and_save_model(model_fp16, output_path)
print("✅ FP16 model quantized and saved:")
print(f" ONNX: {output_path}")
print(
f" DATA: {os.path.join(os.path.dirname(output_path), external_data_filename)}"
)
return True


def main():
    parser = argparse.ArgumentParser(
        description="Convert a ColPali model to ONNX, with optional FP16 quantization"
    )
parser.add_argument(
"--model-id", default="vidore/colpali-v1.3-hf", help="HuggingFace model ID"
)
parser.add_argument("--output-dir", default=None, help="Output directory")
parser.add_argument(
"--quantize", action="store_true", help="Apply FP16 quantization after export"
)
parser.add_argument(
"--export-type",
choices=["vision", "text", "both"],
default="both",
help="Which ONNX to export",
)
parser.add_argument("--device", default=None, help="Device for model (cuda/cpu)")
args = parser.parse_args()
if args.device is None:
args.device = "cuda" if torch.cuda.is_available() else "cpu"
if args.output_dir is None:
args.output_dir = os.path.join("output", args.model_id.replace("/", "_"))
out_paths = export_model(
args.model_id,
args.output_dir,
args.device,
fp16=False,
export_type=args.export_type,
)
# quantize whichever were exported
if args.quantize:
print("Starting FP16 quantization")
for key, path in out_paths.items():
binname = os.path.basename(path).replace(".onnx", ".onnx_data")
quantize_fp16_and_externalize(path, path, binname)


if __name__ == "__main__":
main()