RuntimeError: shape '[-1, 2, 3]' is invalid on specific Single Slice (Depth=1) volumes

#7 opened by Arvind69

Hi BAAI team,

I am encountering a crash when running inference on a specific single-slice volume (Depth=1) with use_zoom=True.

While other single-slice inputs work fine, this particular case triggers a shape mismatch in the prompt encoder.

Error:

RuntimeError: shape '[-1, 2, 3]' is invalid for input of size 4

(Occurs at coords = boxes.reshape(-1, 2, 3) inside PromptEncoder._embed_boxes)
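For context, here is a minimal sketch of the arithmetic behind the message (standalone PyTorch, not SegVol code): _embed_boxes expects each box to carry two 3D corners, i.e. 6 values, so a 4-value (2D-style) box cannot be viewed as (-1, 2, 3):

import torch

# A 3D box has two (z, y, x) corners -> 6 values, which reshapes cleanly.
box_3d = torch.tensor([[0., 10., 20., 0., 40., 60.]])  # shape (1, 6)
print(box_3d.reshape(-1, 2, 3).shape)  # torch.Size([1, 2, 3])

# A 2D-style box has only two (y, x) corners -> 4 values, so the same
# reshape fails with exactly the error reported here.
box_2d = torch.tensor([[10., 20., 40., 60.]])  # shape (1, 4)
box_2d.reshape(-1, 2, 3)  # RuntimeError: shape '[-1, 2, 3]' is invalid for input of size 4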

Minimal Reproducible Code: Here is a snippet using a single-slice case that reproduces the crash:

from transformers import AutoModel, AutoTokenizer
import torch
import os

# 1. Setup
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# 2. Load Model
clip_tokenizer = AutoTokenizer.from_pretrained("BAAI/SegVol", local_files_only=True)
model = AutoModel.from_pretrained("BAAI/SegVol", trust_remote_code=True, test_mode=True, local_files_only=True)
model.model.text_encoder.tokenizer = clip_tokenizer
model.eval().to(device)

# 3. Path Configuration
ct_path = os.path.expanduser("~/scratch/M3D-Seg/M3D_Seg/0017/0017/kidney_case13/image.npy")
gt_path = os.path.expanduser("~/scratch/M3D-Seg/M3D_Seg/0017/0017/kidney_case13/mask_(3, 497, 497, 1).npz")
organ_name = "kidney"

# 4. Prepare Data
ct_npy, gt_npy = model.processor.load_uniseg_case(ct_path, gt_path)
data_item = model.processor.zoom_transform(ct_npy, gt_npy)

# Add batch dims and move to device
data_item['image'] = data_item['image'].unsqueeze(0).to(device)
data_item['zoom_out_image'] = data_item['zoom_out_image'].unsqueeze(0).to(device)
data_item['zoom_out_label'] = data_item['zoom_out_label'].unsqueeze(0).to(device)

# 5. Generate Prompts
single_vol = data_item['zoom_out_label'][0][0] # Isolate kidney channel
bbox_prompt, bbox_map = model.processor.bbox_prompt_b(single_vol, device=device)

# 6. Trigger RuntimeError
print(f"Input Image Shape: {data_item['image'].shape}") # (1, 1, 1, 426, 497) -> Depth=1
print("Running inference...")

model.forward_test(
    image=data_item['image'].to(torch.float32),
    zoomed_image=data_item['zoom_out_image'],
    bbox_prompt_group=[bbox_prompt, bbox_map],
    text_prompt=[organ_name],
    use_zoom=True
)

Output:

Input Image Shape: torch.Size([1, 1, 1, 426, 497])
Running inference...
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[18], line 36
     33 print(f"Input Image Shape: {data_item['image'].shape}") # (1, 1, 1, 426, 497) -> Depth=1
     34 print("Running inference...")
---> 36 model.forward_test(
     37     image=data_item['image'].to(torch.float32),
     38     zoomed_image=data_item['zoom_out_image'],
     39     bbox_prompt_group=[bbox_prompt, bbox_map],
     40     text_prompt=[organ_name],
     41     use_zoom=True
     42 )

File ~/.cache/huggingface/modules/transformers_modules/BAAI/SegVol/516701d5bef430604a13f5db7ce896149d075663/model_segvol_single.py:108, in SegVolModel.forward_test(self, image, zoomed_image, text_prompt, bbox_prompt_group, point_prompt_group, use_zoom)
    106 ## inference
    107 with torch.no_grad():
--> 108     logits_single_cropped = sliding_window_inference(
    109             image_single_cropped.to(device), prompt_reflection,
    110             self.config.spatial_size, 1, self.model, 0.5,
    111             text=text_prompt,
    112             use_box=bbox_prompt is not None,
    113             use_point=point_prompt is not None,
    114         )
    115     logits_single_cropped = logits_single_cropped.cpu().squeeze()
    116 logits_global_single[:, :, min_d:max_d+1, min_h:max_h+1, min_w:max_w+1] = logits_single_cropped

File ~/.cache/huggingface/modules/transformers_modules/BAAI/SegVol/516701d5bef430604a13f5db7ce896149d075663/model_segvol_single.py:927, in sliding_window_inference(inputs, prompt_reflection, roi_size, sw_batch_size, predictor, overlap, mode, sigma_scale, padding_mode, cval, sw_device, device, progress, roi_weight_map, *args, **kwargs)
    925         pseudo_label = torch.cat([global_preds[win_slice] for win_slice in unravel_slice]).to(sw_device)
    926         boxes = generate_box(pseudo_label.squeeze()).unsqueeze(0).float().to(device)
--> 927 seg_prob_out = predictor(window_data, text, boxes, points)  # batched patch segmentation
    928 #############
    929 # convert seg_prob_out to tuple seg_prob_tuple, this does not allocate new memory.
    930 seg_prob_tuple: Tuple[torch.Tensor, ...]

File ~/.local/share/mamba/envs/mmm/lib/python3.11/site-packages/torch/nn/modules/module.py:1739, in Module._wrapped_call_impl(self, *args, **kwargs)
   1737     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1738 else:
-> 1739     return self._call_impl(*args, **kwargs)

File ~/.local/share/mamba/envs/mmm/lib/python3.11/site-packages/torch/nn/modules/module.py:1750, in Module._call_impl(self, *args, **kwargs)
   1745 # If we don't have any hooks, we want to skip the rest of the logic in
   1746 # this function, and just call forward.
   1747 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1748         or _global_backward_pre_hooks or _global_backward_hooks
   1749         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1750     return forward_call(*args, **kwargs)
   1752 result = None
   1753 called_always_called_hooks = set()

File ~/.cache/huggingface/modules/transformers_modules/BAAI/SegVol/516701d5bef430604a13f5db7ce896149d075663/model_segvol_single.py:450, in SegVol.forward(self, image, text, boxes, points, **kwargs)
    448 # test mode
    449 if self.test_mode:
--> 450     return self.forward_decoder(image_embedding, img_shape, text, boxes, points)
    452 # train mode
    453 ## sl
    454 sl_loss = self.supervised_forward(image, image_embedding, img_shape, kwargs['train_organs'], kwargs['train_labels'])

File ~/.cache/huggingface/modules/transformers_modules/BAAI/SegVol/516701d5bef430604a13f5db7ce896149d075663/model_segvol_single.py:469, in SegVol.forward_decoder(self, image_embedding, img_shape, text, boxes, points)
    467     else:
    468         text_embedding = None
--> 469 sparse_embeddings, dense_embeddings = self.prompt_encoder(
    470     points=points,
    471     boxes=boxes,
    472     masks=None,
    473     text_embedding=text_embedding,
    474 )
    476 dense_pe = self.prompt_encoder.get_dense_pe()
    477 low_res_masks, _ = self.mask_decoder(
    478     image_embeddings=image_embedding,
    479     text_embedding = text_embedding,
   (...)    483     multimask_output=False,
    484   )

File ~/.local/share/mamba/envs/mmm/lib/python3.11/site-packages/torch/nn/modules/module.py:1739, in Module._wrapped_call_impl(self, *args, **kwargs)
   1737     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1738 else:
-> 1739     return self._call_impl(*args, **kwargs)

File ~/.local/share/mamba/envs/mmm/lib/python3.11/site-packages/torch/nn/modules/module.py:1750, in Module._call_impl(self, *args, **kwargs)
   1745 # If we don't have any hooks, we want to skip the rest of the logic in
   1746 # this function, and just call forward.
   1747 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1748         or _global_backward_pre_hooks or _global_backward_hooks
   1749         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1750     return forward_call(*args, **kwargs)
   1752 result = None
   1753 called_always_called_hooks = set()

File ~/.cache/huggingface/modules/transformers_modules/BAAI/SegVol/516701d5bef430604a13f5db7ce896149d075663/model_segvol_single.py:1458, in PromptEncoder.forward(self, points, boxes, masks, text_embedding)
   1455     sparse_embeddings = torch.cat([sparse_embeddings, point_embeddings], dim=1)
   1457 if boxes is not None:
-> 1458     box_embeddings = self._embed_boxes(boxes)
   1459     sparse_embeddings = torch.cat([sparse_embeddings, box_embeddings], dim=1)
   1461 if text_embedding is not None:

File ~/.cache/huggingface/modules/transformers_modules/BAAI/SegVol/516701d5bef430604a13f5db7ce896149d075663/model_segvol_single.py:1406, in PromptEncoder._embed_boxes(self, boxes)
   1404 """Embeds box prompts."""
   1405 boxes = boxes + 0.5  # Shift to center of pixel
-> 1406 coords = boxes.reshape(-1, 2, 3)
   1407 corner_embedding = self.pe_layer.forward_with_coords(coords, self.input_image_size)
   1408 corner_embedding[:, 0, :] += self.point_embeddings[2].weight

RuntimeError: shape '[-1, 2, 3]' is invalid for input of size 4
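If I am reading the trace correctly, the failing box is not the bbox_prompt I pass in but the one built inside sliding_window_inference (line 926) via generate_box(pseudo_label.squeeze()). My assumption is that squeeze() on a Depth=1 pseudo-label also drops the depth axis, so generate_box ends up working on a 2D array and returns a 4-value box. A quick standalone check of the squeeze behaviour:

import torch

# Hypothetical illustration of the suspected cause (not SegVol code):
# squeeze() removes every size-1 dimension, including the Depth=1 axis.
pseudo_label = torch.zeros(1, 1, 1, 426, 497)  # (B, C, D=1, H, W)
print(pseudo_label.squeeze().shape)  # torch.Size([426, 497]) -> the depth axis is gone

# A box computed from a 2D array would have 2 corners x 2 coords = 4 values,
# matching the "input of size 4" rejected by reshape(-1, 2, 3).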

Workaround: Currently I am bypassing this by manually repeating the slice to increase the depth (e.g., image.repeat(1, 1, 3, 1, 1)). This runs without throwing an error, but no segmentation mask is produced (all logits are negative).
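For completeness, this is roughly what the workaround looks like (a sketch only; data_item comes from the snippet above, and I only repeat the tensor I showed there):

# Rough sketch of the current workaround (data_item comes from the snippet above):
# tile the singleton depth axis of the full-resolution image to 3 slices.
image = data_item['image']                    # torch.Size([1, 1, 1, 426, 497])
image_repeated = image.repeat(1, 1, 3, 1, 1)  # torch.Size([1, 1, 3, 426, 497])

# forward_test then runs without raising, but every logit comes back negative,
# so no voxel is segmented.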

Is there a recommended native fix for this?

Thanks!
