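"""Export ColPaliForRetrieval to ONNX and optionally quantize the result to FP16.

The vision and text towers are exported as separate graphs (model_vision.onnx,
model_text.onnx), each with its weights in a sibling .onnx_data file. Example
invocation (flags as defined in main(); the script name here is a placeholder):

    python export_colpali_onnx.py --model-id vidore/colpali-v1.3-hf --export-type both --quantize
"""
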
import os
import json
import torch
import onnx
import argparse
from PIL import Image
from torch.onnx._globals import GLOBALS
from transformers import ColPaliForRetrieval, ColPaliProcessor
from optimum.onnx.graph_transformations import check_and_save_model
import onnx_graphsurgeon as gs
from onnxconverter_common import float16
from onnx.external_data_helper import convert_model_to_external_data


def export_model(
    model_id: str,
    output_dir: str,
    device: str,
    fp16: bool = False,
    export_type: str = "both",
):
    """Export ColPaliForRetrieval to ONNX (vision tower, text tower, or both)."""
    os.makedirs(output_dir, exist_ok=True)

    model = (
        ColPaliForRetrieval.from_pretrained(
            model_id,
            torch_dtype=torch.float16 if fp16 else torch.float32,
            device_map="auto",
        )
        .to(device)
        .eval()
    )
    processor = ColPaliProcessor.from_pretrained(model_id)
    model.config.save_pretrained(output_dir)
    processor.save_pretrained(output_dir)

    # Keep a reference to the original forward; it is monkey-patched below so
    # that each exported graph exercises only one tower (vision or text).
    _orig_forward = model.forward

    # Dummy inputs for tracing. The "fake" tensors are zero-filled placeholders
    # for the inputs that a given tower does not actually use.
    dummy_img = Image.new("RGB", (32, 32), color="white")
    vision_pt = processor(images=[dummy_img], return_tensors="pt").to(device)
    pv, ids, msk = (
        vision_pt["pixel_values"],
        vision_pt["input_ids"],
        vision_pt["attention_mask"],
    )
    fake_ids = torch.zeros((pv.size(0), 1), device=device, dtype=torch.long)
    fake_mask = torch.zeros_like(fake_ids, device=device)
    fake_pv = torch.zeros_like(pv)

    out_paths = {}

    if export_type in ("vision", "both"):

        # Patch forward so the exported graph ignores the text inputs and
        # returns only the vision embeddings.
        def vision_forward(
            self, pixel_values=None, input_ids=None, attention_mask=None, **kw
        ):
            return _orig_forward(
                pixel_values=pixel_values,
                input_ids=None,
                attention_mask=None,
                **kw,
            ).embeddings

        model.forward = vision_forward.__get__(model, model.__class__)

        vision_onnx = os.path.join(output_dir, "model_vision.onnx")
        vision_bin = "model_vision.onnx_data"
        # Disable the exporter's ONNX shape inference; it is run explicitly below.
        GLOBALS.onnx_shape_inference = False
        torch.onnx.export(
            model,
            (pv, fake_ids, fake_mask),
            vision_onnx,
            export_params=True,
            opset_version=14,
            do_constant_folding=True,
            use_external_data_format=True,
            all_tensors_to_one_file=True,
            size_threshold=0,
            external_data_filename=vision_bin,
            input_names=["pixel_values", "input_ids", "attention_mask"],
            output_names=["embeddings"],
            dynamic_axes={
                "pixel_values": {0: "batch_size"},
                "embeddings": {0: "batch_size", 1: "seq_len"},
            },
        )
        print("✅ Exported VISION ONNX to", vision_onnx)

        # infer_shapes_path writes the shape-inferred model back to the same
        # path; reload it and let check_and_save_model fix up the external data.
        onnx.shape_inference.infer_shapes_path(vision_onnx)
        m = onnx.load(vision_onnx, load_external_data=True)
        check_and_save_model(m, vision_onnx)
        print("   (shape-inferred + external-data fixed)")

        out_paths["vision"] = vision_onnx

    if export_type in ("text", "both"):

        # Patch forward so the exported graph ignores the image input and
        # returns only the text embeddings.
        def text_forward(
            self, pixel_values=None, input_ids=None, attention_mask=None, **kw
        ):
            return _orig_forward(
                pixel_values=None,
                input_ids=input_ids,
                attention_mask=attention_mask,
                **kw,
            ).embeddings

        model.forward = text_forward.__get__(model, model.__class__)

        text_onnx = os.path.join(output_dir, "model_text.onnx")
        text_bin = "model_text.onnx_data"
        torch.onnx.export(
            model,
            (fake_pv, ids, msk),
            text_onnx,
            export_params=True,
            opset_version=14,
            do_constant_folding=True,
            use_external_data_format=True,
            all_tensors_to_one_file=True,
            size_threshold=0,
            external_data_filename=text_bin,
            input_names=["pixel_values", "input_ids", "attention_mask"],
            output_names=["embeddings"],
            dynamic_axes={
                "input_ids": {0: "batch_size", 1: "seq_len"},
                "attention_mask": {0: "batch_size", 1: "seq_len"},
                "embeddings": {0: "batch_size", 1: "seq_len"},
            },
        )
        print("✅ Exported TEXT ONNX to", text_onnx)

        onnx.shape_inference.infer_shapes_path(text_onnx)
        m = onnx.load(text_onnx, load_external_data=True)
        check_and_save_model(m, text_onnx)
        print("   (shape-inferred + external-data fixed)")

        out_paths["text"] = text_onnx

    print("🎉 Done exporting model(s):", out_paths)
    return out_paths
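

# Optional sanity check (a minimal sketch, not called anywhere in this script):
# load an exported graph with onnxruntime and confirm it yields embeddings.
# Assumes onnxruntime is installed; input/output names match those declared in
# export_model above, and inputs the exporter pruned are simply skipped.
def verify_onnx_export(onnx_path: str, model_id: str):
    import numpy as np
    import onnxruntime as ort

    processor = ColPaliProcessor.from_pretrained(model_id)
    session = ort.InferenceSession(onnx_path, providers=["CPUExecutionProvider"])

    dummy_img = Image.new("RGB", (32, 32), color="white")
    feats = processor(images=[dummy_img], return_tensors="np")
    candidates = {
        "pixel_values": feats["pixel_values"].astype(np.float32),
        "input_ids": np.asarray(feats["input_ids"], dtype=np.int64),
        "attention_mask": np.asarray(feats["attention_mask"], dtype=np.int64),
    }
    # Only feed the inputs the graph actually kept.
    feed = {i.name: candidates[i.name] for i in session.get_inputs() if i.name in candidates}
    (embeddings,) = session.run(["embeddings"], feed)
    print("embeddings shape:", embeddings.shape)
    return embeddings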


def quantize_fp16_and_externalize(
    input_path,
    output_path,
    external_data_filename="model.onnx_data",
    op_block_list=None,
):
    """
    Quantize an ONNX model from FP32 to FP16:
    1) Load the FP32 ONNX (+ its .onnx_data)
    2) Cast weight tensors to FP16
    3) Topologically sort / clean up the graph
    4) Copy opset_import from the original model
    5) Mark ALL tensors for external data
    6) Save the new ONNX + .onnx_data
    """
    # Load twice: `orig` is kept untouched as the source of opset_import,
    # `model` is converted in place.
    orig = onnx.load(input_path, load_external_data=True)
    model = onnx.load(input_path, load_external_data=True)

    # Shape inference inside the converter fails on protos above the 2 GB limit.
    disable_si = model.ByteSize() >= onnx.checker.MAXIMUM_PROTOBUF
    # Keep numerically sensitive ops in FP32.
    blocked = set(float16.DEFAULT_OP_BLOCK_LIST)
    if op_block_list:
        blocked.update(op_block_list)
    blocked.update(["LayerNormalization", "Softmax", "Div"])

    model_fp16 = float16.convert_float_to_float16(
        model,
        max_finite_val=65504.0,
        keep_io_types=True,
        disable_shape_infer=disable_si,
        op_block_list=blocked,
    )

    graph = gs.import_onnx(model_fp16)
    graph.toposort()
    model_fp16 = gs.export_onnx(graph)

    model_fp16.ClearField("opset_import")
    model_fp16.opset_import.extend(orig.opset_import)

    convert_model_to_external_data(
        model_fp16,
        all_tensors_to_one_file=True,
        location=external_data_filename,
        size_threshold=0,
    )

    # Fallback in case the original model carried no opset information.
    if not model_fp16.opset_import:
        model_fp16.opset_import.extend(
            [
                onnx.helper.make_opsetid("", 14),
            ]
        )

    check_and_save_model(model_fp16, output_path)

    print("✅ FP16 model quantized and saved:")
    print(f"   ONNX: {output_path}")
    print(
        f"   DATA: {os.path.join(os.path.dirname(output_path), external_data_filename)}"
    )
    return True
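

# Example usage (illustrative paths; this mirrors what main() does below):
#
#   quantize_fp16_and_externalize(
#       "output/vidore_colpali-v1.3-hf/model_vision.onnx",
#       "output/vidore_colpali-v1.3-hf/model_vision.onnx",
#       external_data_filename="model_vision.onnx_data",
#   )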


def main():
    parser = argparse.ArgumentParser(
        description="Convert a ColPali model to ONNX format, with optional FP16 quantization"
    )
    parser.add_argument(
        "--model-id", default="vidore/colpali-v1.3-hf", help="HuggingFace model ID"
    )
    parser.add_argument("--output-dir", default=None, help="Output directory")
    parser.add_argument(
        "--quantize", action="store_true", help="Apply FP16 quantization after export"
    )
    parser.add_argument(
        "--export-type",
        choices=["vision", "text", "both"],
        default="both",
        help="Which ONNX graph(s) to export",
    )
    parser.add_argument("--device", default=None, help="Device for the model (cuda/cpu)")
    args = parser.parse_args()

    if args.device is None:
        args.device = "cuda" if torch.cuda.is_available() else "cpu"
    if args.output_dir is None:
        args.output_dir = os.path.join("output", args.model_id.replace("/", "_"))

    # Export in FP32; the optional FP16 pass below rewrites the files in place.
    out_paths = export_model(
        args.model_id,
        args.output_dir,
        args.device,
        fp16=False,
        export_type=args.export_type,
    )

    if args.quantize:
        print("Starting FP16 quantization")
        for path in out_paths.values():
            binname = os.path.basename(path).replace(".onnx", ".onnx_data")
            quantize_fp16_and_externalize(path, path, binname)


if __name__ == "__main__":
    main()