RuntimeError: shape '[-1, 2, 3]' is invalid on specific Single Slice (Depth=1) volumes
#7 · opened by Arvind69
Hi BAAI team,
I am encountering a crash when running inference on a specific single-slice volume (Depth=1) with use_zoom=True. Other single-slice inputs work fine, but this particular case triggers a shape mismatch in the prompt encoder.
Error:
RuntimeError: shape '[-1, 2, 3]' is invalid for input of size 4
(raised at coords = boxes.reshape(-1, 2, 3) inside PromptEncoder._embed_boxes)
Minimal reproducible code: the following snippet reproduces the crash on the single-slice case:
from transformers import AutoModel, AutoTokenizer
import torch
import os
# 1. Setup
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# 2. Load Model
clip_tokenizer = AutoTokenizer.from_pretrained("BAAI/SegVol", local_files_only=True)
model = AutoModel.from_pretrained("BAAI/SegVol", trust_remote_code=True, test_mode=True, local_files_only=True)
model.model.text_encoder.tokenizer = clip_tokenizer
model.eval().to(device)
# 3. Path Configuration
ct_path = os.path.expanduser("~/scratch/M3D-Seg/M3D_Seg/0017/0017/kidney_case13/image.npy")
gt_path = os.path.expanduser("~/scratch/M3D-Seg/M3D_Seg/0017/0017/kidney_case13/mask_(3, 497, 497, 1).npz")
organ_name = "kidney"
# 4. Prepare Data
ct_npy, gt_npy = model.processor.load_uniseg_case(ct_path, gt_path)
data_item = model.processor.zoom_transform(ct_npy, gt_npy)
# Add batch dims and move to device
data_item['image'] = data_item['image'].unsqueeze(0).to(device)
data_item['zoom_out_image'] = data_item['zoom_out_image'].unsqueeze(0).to(device)
data_item['zoom_out_label'] = data_item['zoom_out_label'].unsqueeze(0).to(device)
# 5. Generate Prompts
single_vol = data_item['zoom_out_label'][0][0] # Isolate kidney channel
bbox_prompt, bbox_map = model.processor.bbox_prompt_b(single_vol, device=device)
# 6. Trigger RuntimeError
print(f"Input Image Shape: {data_item['image'].shape}") # (1, 1, 1, 426, 497) -> Depth=1
print("Running inference...")
model.forward_test(
image=data_item['image'].to(torch.float32),
zoomed_image=data_item['zoom_out_image'],
bbox_prompt_group=[bbox_prompt, bbox_map],
text_prompt=[organ_name],
use_zoom=True
)
Output:
Input Image Shape: torch.Size([1, 1, 1, 426, 497])
Running inference...
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
Cell In[18], line 36
33 print(f"Input Image Shape: {data_item['image'].shape}") # (1, 1, 1, 426, 497) -> Depth=1
34 print("Running inference...")
---> 36 model.forward_test(
37 image=data_item['image'].to(torch.float32),
38 zoomed_image=data_item['zoom_out_image'],
39 bbox_prompt_group=[bbox_prompt, bbox_map],
40 text_prompt=[organ_name],
41 use_zoom=True
42 )
File ~/.cache/huggingface/modules/transformers_modules/BAAI/SegVol/516701d5bef430604a13f5db7ce896149d075663/model_segvol_single.py:108, in SegVolModel.forward_test(self, image, zoomed_image, text_prompt, bbox_prompt_group, point_prompt_group, use_zoom)
106 ## inference
107 with torch.no_grad():
--> 108 logits_single_cropped = sliding_window_inference(
109 image_single_cropped.to(device), prompt_reflection,
110 self.config.spatial_size, 1, self.model, 0.5,
111 text=text_prompt,
112 use_box=bbox_prompt is not None,
113 use_point=point_prompt is not None,
114 )
115 logits_single_cropped = logits_single_cropped.cpu().squeeze()
116 logits_global_single[:, :, min_d:max_d+1, min_h:max_h+1, min_w:max_w+1] = logits_single_cropped
File ~/.cache/huggingface/modules/transformers_modules/BAAI/SegVol/516701d5bef430604a13f5db7ce896149d075663/model_segvol_single.py:927, in sliding_window_inference(inputs, prompt_reflection, roi_size, sw_batch_size, predictor, overlap, mode, sigma_scale, padding_mode, cval, sw_device, device, progress, roi_weight_map, *args, **kwargs)
925 pseudo_label = torch.cat([global_preds[win_slice] for win_slice in unravel_slice]).to(sw_device)
926 boxes = generate_box(pseudo_label.squeeze()).unsqueeze(0).float().to(device)
--> 927 seg_prob_out = predictor(window_data, text, boxes, points) # batched patch segmentation
928 #############
929 # convert seg_prob_out to tuple seg_prob_tuple, this does not allocate new memory.
930 seg_prob_tuple: Tuple[torch.Tensor, ...]
File ~/.local/share/mamba/envs/mmm/lib/python3.11/site-packages/torch/nn/modules/module.py:1739, in Module._wrapped_call_impl(self, *args, **kwargs)
1737 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1738 else:
-> 1739 return self._call_impl(*args, **kwargs)
File ~/.local/share/mamba/envs/mmm/lib/python3.11/site-packages/torch/nn/modules/module.py:1750, in Module._call_impl(self, *args, **kwargs)
1745 # If we don't have any hooks, we want to skip the rest of the logic in
1746 # this function, and just call forward.
1747 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1748 or _global_backward_pre_hooks or _global_backward_hooks
1749 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1750 return forward_call(*args, **kwargs)
1752 result = None
1753 called_always_called_hooks = set()
File ~/.cache/huggingface/modules/transformers_modules/BAAI/SegVol/516701d5bef430604a13f5db7ce896149d075663/model_segvol_single.py:450, in SegVol.forward(self, image, text, boxes, points, **kwargs)
448 # test mode
449 if self.test_mode:
--> 450 return self.forward_decoder(image_embedding, img_shape, text, boxes, points)
452 # train mode
453 ## sl
454 sl_loss = self.supervised_forward(image, image_embedding, img_shape, kwargs['train_organs'], kwargs['train_labels'])
File ~/.cache/huggingface/modules/transformers_modules/BAAI/SegVol/516701d5bef430604a13f5db7ce896149d075663/model_segvol_single.py:469, in SegVol.forward_decoder(self, image_embedding, img_shape, text, boxes, points)
467 else:
468 text_embedding = None
--> 469 sparse_embeddings, dense_embeddings = self.prompt_encoder(
470 points=points,
471 boxes=boxes,
472 masks=None,
473 text_embedding=text_embedding,
474 )
476 dense_pe = self.prompt_encoder.get_dense_pe()
477 low_res_masks, _ = self.mask_decoder(
478 image_embeddings=image_embedding,
479 text_embedding = text_embedding,
(...) 483 multimask_output=False,
484 )
File ~/.local/share/mamba/envs/mmm/lib/python3.11/site-packages/torch/nn/modules/module.py:1739, in Module._wrapped_call_impl(self, *args, **kwargs)
1737 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1738 else:
-> 1739 return self._call_impl(*args, **kwargs)
File ~/.local/share/mamba/envs/mmm/lib/python3.11/site-packages/torch/nn/modules/module.py:1750, in Module._call_impl(self, *args, **kwargs)
1745 # If we don't have any hooks, we want to skip the rest of the logic in
1746 # this function, and just call forward.
1747 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1748 or _global_backward_pre_hooks or _global_backward_hooks
1749 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1750 return forward_call(*args, **kwargs)
1752 result = None
1753 called_always_called_hooks = set()
File ~/.cache/huggingface/modules/transformers_modules/BAAI/SegVol/516701d5bef430604a13f5db7ce896149d075663/model_segvol_single.py:1458, in PromptEncoder.forward(self, points, boxes, masks, text_embedding)
1455 sparse_embeddings = torch.cat([sparse_embeddings, point_embeddings], dim=1)
1457 if boxes is not None:
-> 1458 box_embeddings = self._embed_boxes(boxes)
1459 sparse_embeddings = torch.cat([sparse_embeddings, box_embeddings], dim=1)
1461 if text_embedding is not None:
File ~/.cache/huggingface/modules/transformers_modules/BAAI/SegVol/516701d5bef430604a13f5db7ce896149d075663/model_segvol_single.py:1406, in PromptEncoder._embed_boxes(self, boxes)
1404 """Embeds box prompts."""
1405 boxes = boxes + 0.5 # Shift to center of pixel
-> 1406 coords = boxes.reshape(-1, 2, 3)
1407 corner_embedding = self.pe_layer.forward_with_coords(coords, self.input_image_size)
1408 corner_embedding[:, 0, :] += self.point_embeddings[2].weight
RuntimeError: shape '[-1, 2, 3]' is invalid for input of size 4
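For what it's worth, my reading of the traceback is that pseudo_label.squeeze() at model_segvol_single.py line 926 drops the depth axis when Depth=1, so generate_box presumably returns a 2D box with 4 coordinates, which cannot be reshaped into (-1, 2, 3) (that requires a multiple of 6 values). A minimal sketch of this suspected failure mode in plain PyTorch (the box values are made up):

import torch

# Suspected failure mode (my guess, not verified against generate_box):
# squeeze() on a Depth=1 volume removes the depth axis entirely.
pseudo_label = torch.zeros(1, 1, 1, 426, 497)  # (B, C, D, H, W) with D=1
print(pseudo_label.squeeze().shape)            # torch.Size([426, 497]) -- 2D, depth is gone

# A 2D box carries 4 coordinates; _embed_boxes expects 3D boxes (6 values each).
box_2d = torch.tensor([10.0, 20.0, 100.0, 200.0])
box_2d.reshape(-1, 2, 3)  # RuntimeError: shape '[-1, 2, 3]' is invalid for input of size 4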
Workaround: currently I am bypassing this by manually repeating the slice to increase the depth (e.g., image.repeat(1, 1, 3, 1, 1), sketched below). This avoids the error, but no segmentation mask is produced (all logits are negative).
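Concretely, the repeat looks roughly like this (a sketch; I apply it after adding the batch dims, and whether the zoom-out tensors also need repeating is unclear to me):

# Sketch of the depth-repeat workaround: tile the single slice three times
# along the depth axis (dim=2) before calling forward_test.
data_item['image'] = data_item['image'].repeat(1, 1, 3, 1, 1)  # (1, 1, 3, 426, 497)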
Is there a recommended native fix for this?
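For reference, the kind of guard I imagine (untested, purely my guess) would restore the depth axis before the box is generated at model_segvol_single.py line 926:

# Untested sketch: keep the pseudo label 3D even when Depth == 1,
# so generate_box produces a 6-value 3D box instead of a 4-value 2D one.
pl = pseudo_label.squeeze()
if pl.dim() == 2:         # Depth=1: squeeze() removed the depth axis
    pl = pl.unsqueeze(0)  # restore (D, H, W) with D = 1
boxes = generate_box(pl).unsqueeze(0).float().to(device)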
Thanks!