#!/usr/bin/env python3
"""
Export all SAM 2.1 model sizes to ONNX format.
Supports: tiny, small, base-plus, and large models.
"""
import os
import sys
import subprocess
import shutil

import torch
import torch.nn as nn
import onnx
import onnxruntime as ort
from huggingface_hub import snapshot_download

# Ensure repository root (which contains the local 'sam2' package) is on sys.path
_REPO_ROOT = os.path.dirname(os.path.abspath(__file__))
if _REPO_ROOT not in sys.path:
    sys.path.insert(0, _REPO_ROOT)

from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor
# Model configurations
MODEL_CONFIGS = {
    'tiny': {
        'hf_id': 'facebook/sam2.1-hiera-tiny',
        'config_file': 'configs/sam2.1/sam2.1_hiera_t.yaml',
        'checkpoint_name': 'sam2.1_hiera_tiny.pt',
        'bb_feat_sizes': [(256, 256), (128, 128), (64, 64)]
    },
    'small': {
        'hf_id': 'facebook/sam2.1-hiera-small',
        'config_file': 'configs/sam2.1/sam2.1_hiera_s.yaml',
        'checkpoint_name': 'sam2.1_hiera_small.pt',
        'bb_feat_sizes': [(256, 256), (128, 128), (64, 64)]
    },
    'base_plus': {
        'hf_id': 'facebook/sam2.1-hiera-base-plus',
        'config_file': 'configs/sam2.1/sam2.1_hiera_b+.yaml',
        'checkpoint_name': 'sam2.1_hiera_base_plus.pt',
        'bb_feat_sizes': [(256, 256), (128, 128), (64, 64)]
    },
    'large': {
        'hf_id': 'facebook/sam2.1-hiera-large',
        'config_file': 'configs/sam2.1/sam2.1_hiera_l.yaml',
        'checkpoint_name': 'sam2.1_hiera_large.pt',
        'bb_feat_sizes': [(256, 256), (128, 128), (64, 64)]
    }
}
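# Note: 'bb_feat_sizes' lists the spatial sizes of the three backbone feature levels for the
# fixed 1024x1024 input (strides 4, 8, and 16), which is why the values are identical for
# every model size; only the checkpoint and config differ between entries.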
def model_local_dir_from_size(model_size: str) -> str:
    """Return the local download directory for a given model size."""
    return f"./sam2.1-hiera-{model_size.replace('_', '-')}-downloaded"


def cleanup_downloaded_files_for_model(model_size: str) -> None:
    """Delete the downloaded files for a model size after successful export/tests.

    Safety checks ensure we only remove the expected snapshot directory.
    """
    local_dir = model_local_dir_from_size(model_size)
    try:
        # Safety: ensure directory exists and name matches expected pattern
        base = os.path.basename(os.path.normpath(local_dir))
        if os.path.isdir(local_dir) and base.startswith("sam2.1-hiera-") and base.endswith("-downloaded"):
            shutil.rmtree(local_dir)
            print(f"🧹 Cleaned up downloaded files at: {local_dir}")
        else:
            print(f"⚠️ Skipping cleanup; unexpected directory path: {local_dir}")
    except Exception as e:
        print(f"⚠️ Failed to clean up {local_dir}: {e}")
class SAM2CompleteModel(nn.Module):
    """Complete SAM2 model wrapper for ONNX export."""

    def __init__(self, sam2_model, bb_feat_sizes):
        super().__init__()
        self.sam2_model = sam2_model
        self.image_encoder = sam2_model.image_encoder
        self.prompt_encoder = sam2_model.sam_prompt_encoder
        self.mask_decoder = sam2_model.sam_mask_decoder
        self.no_mem_embed = sam2_model.no_mem_embed
        self.directly_add_no_mem_embed = sam2_model.directly_add_no_mem_embed
        self.bb_feat_sizes = bb_feat_sizes
        # Precompute image_pe as a buffer for constant folding optimization
        with torch.no_grad():
            self.register_buffer(
                "image_pe_const",
                self.prompt_encoder.get_dense_pe()
            )

    def forward(self, image, point_coords, point_labels):
        """
        Complete SAM2 forward pass.

        Args:
            image: [1, 3, 1024, 1024] - Input image
            point_coords: [1, N, 2] - Point coordinates in pixels
            point_labels: [1, N] - Point labels (1=positive, 0=negative)

        Returns:
            masks: [1, 3, 1024, 1024] - Predicted masks
            iou_predictions: [1, 3] - IoU predictions
        """
        # 1. Image encoding
        backbone_out = self.sam2_model.forward_image(image)
        _, vision_feats, _, _ = self.sam2_model._prepare_backbone_features(backbone_out)

        # Add no_mem_embed if needed
        if self.directly_add_no_mem_embed:
            vision_feats[-1] = vision_feats[-1] + self.no_mem_embed

        # Process features
        feats = []
        for feat, feat_size in zip(vision_feats[::-1], self.bb_feat_sizes[::-1]):
            feat_reshaped = feat.permute(1, 2, 0).reshape(1, -1, feat_size[0], feat_size[1])
            feats.append(feat_reshaped)
        feats = feats[::-1]

        image_embeddings = feats[-1]
        high_res_features = feats[:-1]

        # 2. Prompt encoding
        points = (point_coords, point_labels)
        sparse_embeddings, dense_embeddings = self.prompt_encoder(
            points=points, boxes=None, masks=None
        )

        # 3. Mask decoding
        low_res_masks, iou_predictions, _, _ = self.mask_decoder(
            image_embeddings=image_embeddings,
            image_pe=self.image_pe_const,
            sparse_prompt_embeddings=sparse_embeddings,
            dense_prompt_embeddings=dense_embeddings,
            multimask_output=True,
            repeat_image=False,
            high_res_features=high_res_features,
        )

        # 4. Upscale masks
        masks = torch.nn.functional.interpolate(
            low_res_masks, size=(1024, 1024), mode='bilinear', align_corners=False
        )
        return masks, iou_predictions
def download_model(model_size):
    """Download model from Hugging Face Hub."""
    config = MODEL_CONFIGS[model_size]
    local_dir = model_local_dir_from_size(model_size)

    print(f"Downloading {model_size} model from {config['hf_id']}...")
    if os.path.exists(local_dir):
        print(f"✅ Model directory already exists: {local_dir}")
        return local_dir

    try:
        snapshot_download(
            repo_id=config['hf_id'],
            local_dir=local_dir,
            local_dir_use_symlinks=False,
            resume_download=True
        )
        print(f"✅ Model downloaded to: {local_dir}")
        return local_dir
    except Exception as e:
        print(f"❌ Failed to download {model_size} model: {e}")
        return None
def load_sam2_model(model_size):
    """Load SAM2 model of specified size."""
    config = MODEL_CONFIGS[model_size]
    local_dir = download_model(model_size)
    if not local_dir:
        raise RuntimeError(f"Failed to download {model_size} model")

    config_file = config['config_file']
    ckpt_path = os.path.join(local_dir, config['checkpoint_name'])
    if not os.path.exists(ckpt_path):
        raise FileNotFoundError(f"Checkpoint not found: {ckpt_path}")

    print(f"Loading {model_size} model...")
    sam2_model = build_sam2(
        config_file=config_file,
        ckpt_path=ckpt_path,
        device="cpu",
        mode="eval"
    )
    print(f"✅ {model_size} model loaded successfully")
    return sam2_model, config['bb_feat_sizes']


def create_test_inputs():
    """Create test inputs for the model."""
    image = torch.randn(1, 3, 1024, 1024)
    point_coords = torch.tensor([[[512.0, 512.0]]], dtype=torch.float32)
    point_labels = torch.tensor([[1]], dtype=torch.float32)
    return image, point_coords, point_labels
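# Illustrative sketch only (not used by the export flow): turning a real image into the
# [1, 3, 1024, 1024] tensor the wrapper expects. The constants below assume SAM2's default
# preprocessing (resize to 1024x1024, scale to [0, 1], ImageNet mean/std normalization);
# verify against sam2.utils.transforms.SAM2Transforms before relying on them.
def _example_preprocess_image(pil_image):
    import numpy as np  # local import; only needed for this illustrative helper
    img = np.asarray(pil_image.convert("RGB").resize((1024, 1024)), dtype=np.float32) / 255.0
    mean = np.array([0.485, 0.456, 0.406], dtype=np.float32)  # assumed ImageNet mean
    std = np.array([0.229, 0.224, 0.225], dtype=np.float32)   # assumed ImageNet std
    img = (img - mean) / std
    # HWC -> CHW and add a batch dimension
    return torch.from_numpy(img.transpose(2, 0, 1)).unsqueeze(0)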
def test_model_wrapper(sam2_model, bb_feat_sizes, model_size):
    """Test the model wrapper before ONNX export."""
    print(f"\nTesting {model_size} model wrapper...")
    wrapper = SAM2CompleteModel(sam2_model, bb_feat_sizes)
    wrapper.eval()

    image, point_coords, point_labels = create_test_inputs()
    with torch.no_grad():
        masks, iou_predictions = wrapper(image, point_coords, point_labels)

    print(f"✅ {model_size} model wrapper test successful")
    print(f"  - Masks shape: {masks.shape}")
    print(f"  - IoU predictions shape: {iou_predictions.shape}")
    return wrapper
def slim_onnx_model_with_onnxslim(input_path: str, image_shape=(1, 3, 1024, 1024), num_points=1) -> bool:
    """Slim an ONNX model in-place using onnxslim via uvx.

    Returns True if slimming succeeded and replaced the original file.
    """
    try:
        # Build command; include onnxruntime so model_check can run
        slim_path = input_path + ".slim.onnx"
        model_check_inputs = [
            f"image:{','.join(map(str, image_shape))}",
            f"point_coords:1,{num_points},2",
            f"point_labels:1,{num_points}",
        ]
        cmd = [
            "uvx", "--with", "onnxruntime", "onnxslim",
            input_path, slim_path,
            "--model-check",
            "--model-check-inputs",
            *model_check_inputs,
        ]
        print(f"Running ONNXSlim: {' '.join(cmd)}")
        res = subprocess.run(cmd, capture_output=True, text=True)
        if res.returncode != 0:
            print("ONNXSlim failed; keeping original model.")
            if res.stderr:
                print(res.stderr[:1000])
            return False
        if not os.path.exists(slim_path):
            print("ONNXSlim did not produce output; keeping original model.")
            return False

        # Verify and replace original
        try:
            onnx_model = onnx.load(slim_path)
            onnx.checker.check_model(onnx_model)
        except Exception as e:
            print(f"Slimmed model failed ONNX checker: {e}; keeping original.")
            try:
                os.remove(slim_path)
            except Exception:
                pass
            return False

        # Replace original file atomically
        orig_size = os.path.getsize(input_path)
        slim_size = os.path.getsize(slim_path)
        os.replace(slim_path, input_path)
        print(f"✅ Replaced original ONNX with slimmed model. Size: {orig_size / (1024**2):.2f} MB -> {slim_size / (1024**2):.2f} MB")
        return True
    except FileNotFoundError as e:
        print(f"ONNXSlim or uvx not found: {e}. Skipping slimming.")
    except Exception as e:
        print(f"Unexpected error during ONNXSlim: {e}. Skipping slimming.")
    return False
def export_model_to_onnx(sam2_model, bb_feat_sizes, model_size):
    """Export SAM2 model to ONNX format."""
    output_path = f"sam2_{model_size}.onnx"
    print(f"\nExporting {model_size} model to ONNX...")

    wrapper = SAM2CompleteModel(sam2_model, bb_feat_sizes)
    wrapper.eval()
    image, point_coords, point_labels = create_test_inputs()

    try:
        torch.onnx.export(
            wrapper,
            (image, point_coords, point_labels),
            output_path,
            export_params=True,
            opset_version=17,
            do_constant_folding=True,
            input_names=['image', 'point_coords', 'point_labels'],
            output_names=['masks', 'iou_predictions'],
            dynamic_axes={
                'image': {0: 'batch_size'},
                'point_coords': {0: 'batch_size', 1: 'num_points'},
                'point_labels': {0: 'batch_size', 1: 'num_points'},
                'masks': {0: 'batch_size'},
                'iou_predictions': {0: 'batch_size'}
            },
            training=torch.onnx.TrainingMode.EVAL,
            keep_initializers_as_inputs=False,
            verbose=False
        )
        print(f"✅ {model_size} model exported to: {output_path}")

        # Verify the exported model
        onnx_model = onnx.load(output_path)
        onnx.checker.check_model(onnx_model)
        print("✅ ONNX model verification passed")

        # Get model info
        file_size = os.path.getsize(output_path)
        print(f"✅ ONNX model size: {file_size / (1024**2):.2f} MB")

        # Try to slim the ONNX model in-place with onnxslim
        slimmed = slim_onnx_model_with_onnxslim(output_path, image_shape=(1, 3, 1024, 1024), num_points=1)
        if slimmed:
            # Recompute size after slimming
            file_size = os.path.getsize(output_path)
            print(f"✅ Slimmed ONNX model size: {file_size / (1024**2):.2f} MB")
        else:
            print("⚠️ Skipping slimming or slimming failed; using original ONNX model.")

        return output_path, file_size
    except Exception as e:
        print(f"❌ Error exporting {model_size} to ONNX: {e}")
        raise
def test_onnx_model(onnx_path, original_model, bb_feat_sizes, model_size):
    """Test the ONNX model and compare with original."""
    print(f"\nTesting {model_size} ONNX model...")
    try:
        # Load ONNX model with CPU-optimized session options
        sess_options = ort.SessionOptions()
        sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
        sess_options.enable_mem_pattern = True
        sess_options.enable_cpu_mem_arena = True
        try:
            sess_options.intra_op_num_threads = max(1, (os.cpu_count() or 1) // 2)
        except Exception:
            pass
        sess_options.inter_op_num_threads = 1
        providers = [("CPUExecutionProvider", {"use_arena": True})]
        ort_session = ort.InferenceSession(onnx_path, sess_options, providers=providers)

        image, point_coords, point_labels = create_test_inputs()

        # Run ONNX inference
        ort_inputs = {
            'image': image.numpy(),
            'point_coords': point_coords.numpy(),
            'point_labels': point_labels.numpy()
        }
        onnx_outputs = ort_session.run(None, ort_inputs)
        onnx_masks, onnx_iou = onnx_outputs

        # Compare with original model
        wrapper = SAM2CompleteModel(original_model, bb_feat_sizes)
        wrapper.eval()
        with torch.no_grad():
            torch_masks, torch_iou = wrapper(image, point_coords, point_labels)
            torch_masks = torch_masks.numpy()
            torch_iou = torch_iou.numpy()

        # Calculate differences
        mask_max_diff = abs(onnx_masks - torch_masks).max()
        iou_max_diff = abs(onnx_iou - torch_iou).max()

        print(f"✅ {model_size} ONNX inference successful")
        print(f"  - Masks max difference: {mask_max_diff:.6f}")
        print(f"  - IoU max difference: {iou_max_diff:.6f}")

        tolerance = 1e-3
        success = mask_max_diff < tolerance and iou_max_diff < tolerance
        if success:
            print(f"✅ Numerical accuracy within tolerance ({tolerance})")
        else:
            print(f"⚠️ Some differences exceed tolerance ({tolerance})")
        return success
    except Exception as e:
        print(f"❌ Error testing {model_size} ONNX model: {e}")
        return False
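# Illustrative sketch only (not part of the export flow): consuming an exported file such as
# sam2_tiny.onnx from plain onnxruntime/numpy, without PyTorch. Preprocessing of a real image
# is assumed to follow _example_preprocess_image above; random data stands in for it here.
def _example_run_exported_onnx(onnx_path="sam2_tiny.onnx"):
    import numpy as np  # local import; only needed for this illustrative helper
    session = ort.InferenceSession(onnx_path, providers=["CPUExecutionProvider"])
    image = np.random.randn(1, 3, 1024, 1024).astype(np.float32)   # stand-in for a preprocessed image
    point_coords = np.array([[[512.0, 512.0]]], dtype=np.float32)  # one click at the image center
    point_labels = np.array([[1.0]], dtype=np.float32)             # 1 = positive (foreground) point
    masks, iou_predictions = session.run(
        None,
        {"image": image, "point_coords": point_coords, "point_labels": point_labels},
    )
    # masks holds [1, 3, 1024, 1024] logits; pick the candidate with the highest predicted IoU
    # and threshold at 0 to obtain a binary mask (SAM's usual convention).
    best = int(np.argmax(iou_predictions[0]))
    return masks[0, best] > 0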
def export_all_models():
    """Export all SAM2.1 model sizes to ONNX."""
    print("=== SAM 2.1 All Models ONNX Export ===\n")
    results = {}

    for model_size in MODEL_CONFIGS.keys():
        try:
            print(f"\n{'='*50}")
            print(f"Processing {model_size.upper()} model")
            print(f"{'='*50}")

            # Load model
            sam2_model, bb_feat_sizes = load_sam2_model(model_size)

            # Test wrapper
            wrapper = test_model_wrapper(sam2_model, bb_feat_sizes, model_size)

            # Export to ONNX
            onnx_path, file_size = export_model_to_onnx(sam2_model, bb_feat_sizes, model_size)

            # Test ONNX model
            success = test_onnx_model(onnx_path, sam2_model, bb_feat_sizes, model_size)

            # Cleanup downloaded files only if export + test succeeded
            if success:
                cleanup_downloaded_files_for_model(model_size)
            else:
                print(f"⚠️ Skipping cleanup for {model_size}; export/test not fully successful.")

            results[model_size] = {
                'onnx_path': onnx_path,
                'file_size_mb': file_size / (1024**2),
                'success': success
            }
            print(f"✅ {model_size} model export completed!")
        except Exception as e:
            print(f"❌ Failed to export {model_size} model: {e}")
            results[model_size] = {
                'error': str(e),
                'success': False
            }

    # Print summary
    print(f"\n{'='*60}")
    print("EXPORT SUMMARY")
    print(f"{'='*60}")
    for model_size, result in results.items():
        if result['success']:
            print(f"✅ {model_size:12} - {result['onnx_path']:20} ({result['file_size_mb']:.1f} MB)")
        else:
            print(f"❌ {model_size:12} - FAILED")
    return results
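# Sketch: to export a single size instead of all four, the individual steps can be called
# directly, e.g.
#     model, feat_sizes = load_sam2_model('tiny')
#     export_model_to_onnx(model, feat_sizes, 'tiny')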
if __name__ == "__main__":
    export_all_models()