HV-Khurdula commited on
Commit
2c36d34
·
verified ·
1 Parent(s): 1ca98b0

Update moondream.py

Browse files

fix: sdpa dimension mismatch.

Files changed (1) hide show
  1. moondream.py +49 -49
moondream.py CHANGED
@@ -900,17 +900,18 @@ class MoondreamModel(nn.Module):
900
  return last_hidden, next_token, pos_vec # (B,1,C), (B,1), (B,)
901
 
902
  def _generate_points_batched(
903
- self,
904
- hidden: torch.Tensor, # (B, 1, C) last hidden per row from prefill
905
- next_token: torch.Tensor, # (B, 1) not used directly (kept for parity)
906
- pos_vec: torch.Tensor, # (B,) next write pos per row after prefill
907
- include_size: bool = True,
908
- max_objects: int = 50,
909
- lora=None,
910
- ):
911
  """
912
  Batched decode loop for multi-label detection.
913
- - Uses a *shared* scalar position id per step (q_len = 1), as expected by RoPE.
 
914
  - Maintains a per-row attention mask and 'alive' flags.
915
  - Feeds coord encoders with (B,1) tensors; size encoder with (B,2).
916
  Returns: list-of-lists of dicts, length B.
@@ -920,35 +921,35 @@ class MoondreamModel(nn.Module):
920
  out = [[] for _ in range(B)]
921
  eos_id = self.config.tokenizer.eos_id
922
 
923
- # --- Shared write position (scalar) consistent with RoPE q_len=1 ---
924
- # We align rows by padding; using the maximum ensures all KV rows can decode in lockstep.
925
- pos = int(pos_vec.max().item())
926
-
927
- # Per-row attention mask (1 = visible). Mark everything up to 'pos' as visible.
928
  max_ctx = self.config.text.max_context
929
  mask = torch.zeros(B, 1, max_ctx, device=device, dtype=torch.bool)
930
- mask[:, :, :pos] = 1
 
 
 
 
931
 
932
- alive = torch.ones(B, dtype=torch.bool, device=device)
933
  counts = torch.zeros(B, dtype=torch.int32, device=device)
934
 
935
  with torch.inference_mode():
936
  while alive.any() and (counts < max_objects).any():
937
  # --- x coordinate ---
938
- x_logits = decode_coordinate(hidden, self.region) # (B,1,1024) or (B,1024)
939
  if x_logits.dim() == 3:
940
- x_logits = x_logits.squeeze(1) # (B,1024)
941
- x_bin = x_logits.argmax(dim=-1).to(torch.float32) # (B,)
942
- x_center = x_bin / float(x_logits.size(-1)) # (B,)
943
- x_input = x_center.to(dtype=x_logits.dtype).unsqueeze(-1) # (B,1)
944
- x_emb = encode_coordinate(x_input, self.region).unsqueeze(1) # (B,1,C)
945
 
946
- # step: decode hidden for y (advance shared pos)
947
- mask[:, :, pos] = 1
948
  logits, hidden = self._decode_one_tok(
949
  x_emb,
950
- mask,
951
- torch.tensor([pos], device=device, dtype=torch.long), # length-1 (q_len=1)
952
  lora,
953
  )
954
  pos += 1
@@ -956,17 +957,16 @@ class MoondreamModel(nn.Module):
956
  # --- y coordinate ---
957
  y_logits = decode_coordinate(hidden, self.region)
958
  if y_logits.dim() == 3:
959
- y_logits = y_logits.squeeze(1)
960
- y_bin = y_logits.argmax(dim=-1).to(torch.float32)
961
- y_center = y_bin / float(y_logits.size(-1)) # (B,)
962
- y_input = y_center.to(dtype=y_logits.dtype).unsqueeze(-1) # (B,1)
963
- y_emb = encode_coordinate(y_input, self.region).unsqueeze(1) # (B,1,C)
964
 
965
- # step: decode hidden for size / eos (advance shared pos)
966
- mask[:, :, pos] = 1
967
  logits, hidden = self._decode_one_tok(
968
  y_emb,
969
- mask,
970
  torch.tensor([pos], device=device, dtype=torch.long),
971
  lora,
972
  )
@@ -974,17 +974,17 @@ class MoondreamModel(nn.Module):
974
 
975
  if include_size:
976
  # --- size (batched) ---
977
- size_logits = decode_size(hidden, self.region) # ([B,1,1024],[B,1,1024])
978
- w_logits, h_logits = size_logits[0].squeeze(1), size_logits[1].squeeze(1) # (B,1024)
979
  w_bin = w_logits.argmax(dim=-1).to(torch.float32)
980
  h_bin = h_logits.argmax(dim=-1).to(torch.float32)
981
- # Convert log-scale bins -> sizes in [0,1]
982
- w = torch.pow(2.0, (w_bin / 1023.0) * 10.0 - 10.0) # (B,)
983
- h = torch.pow(2.0, (h_bin / 1023.0) * 10.0 - 10.0) # (B,)
984
  size_input = torch.stack([w, h], dim=1).to(dtype=w_logits.dtype) # (B,2)
985
- size_emb = encode_size(size_input, self.region).unsqueeze(1) # (B,1,C)
986
 
987
- # Record boxes for alive rows
988
  for i in range(B):
989
  if not alive[i]:
990
  continue
@@ -995,32 +995,31 @@ class MoondreamModel(nn.Module):
995
  "y_max": (y_center[i] + h[i] / 2).item(),
996
  })
997
 
998
- # step: decode "next token" to decide continuation
999
- mask[:, :, pos] = 1
1000
  logits, hidden = self._decode_one_tok(
1001
  size_emb,
1002
- mask,
1003
  torch.tensor([pos], device=device, dtype=torch.long),
1004
  lora,
1005
  )
1006
  pos += 1
1007
- next_tok = logits.argmax(dim=-1).squeeze(-1) # (B,)
1008
  else:
1009
  # Points mode (no size)
1010
  for i in range(B):
1011
  if alive[i]:
1012
  out[i].append({"x": x_center[i].item(), "y": y_center[i].item()})
1013
- mask[:, :, pos] = 1
1014
  logits, hidden = self._decode_one_tok(
1015
  y_emb,
1016
- mask,
1017
  torch.tensor([pos], device=device, dtype=torch.long),
1018
  lora,
1019
  )
1020
  pos += 1
1021
- next_tok = logits.argmax(dim=-1).squeeze(-1) # (B,)
1022
 
1023
- # Update finished/alive bookkeeping
1024
  finished_now = (next_tok == eos_id) | (counts >= max_objects - 1)
1025
  counts = counts + (~finished_now & alive).to(counts.dtype)
1026
  alive &= ~finished_now
@@ -1028,6 +1027,7 @@ class MoondreamModel(nn.Module):
1028
  return out
1029
 
1030
 
 
1031
  def detect_multi(self, image, objects, settings=None):
1032
  """
1033
  Parallel multi-label detection.
 
900
  return last_hidden, next_token, pos_vec # (B,1,C), (B,1), (B,)
901
 
902
  def _generate_points_batched(
903
+ self,
904
+ hidden: torch.Tensor, # (B, 1, C) last hidden per row from prefill
905
+ next_token: torch.Tensor, # (B, 1) unused here; kept for parity
906
+ pos_vec: torch.Tensor, # (B,) next write pos per row after prefill
907
+ include_size: bool = True,
908
+ max_objects: int = 50,
909
+ lora=None
910
+ ):
911
  """
912
  Batched decode loop for multi-label detection.
913
+
914
+ - Uses a shared scalar position id per step (q_len = 1), as expected by RoPE.
915
  - Maintains a per-row attention mask and 'alive' flags.
916
  - Feeds coord encoders with (B,1) tensors; size encoder with (B,2).
917
  Returns: list-of-lists of dicts, length B.
 
921
  out = [[] for _ in range(B)]
922
  eos_id = self.config.tokenizer.eos_id
923
 
924
+ # Per-row initial visibility up to each row's individual prefill pos
 
 
 
 
925
  max_ctx = self.config.text.max_context
926
  mask = torch.zeros(B, 1, max_ctx, device=device, dtype=torch.bool)
927
+ for i in range(B):
928
+ mask[i, :, : int(pos_vec[i].item())] = 1
929
+
930
+ # Shared write index so RoPE sees a scalar q_len=1 position id
931
+ pos = int(pos_vec.max().item())
932
 
933
+ alive = torch.ones(B, dtype=torch.bool, device=device)
934
  counts = torch.zeros(B, dtype=torch.int32, device=device)
935
 
936
  with torch.inference_mode():
937
  while alive.any() and (counts < max_objects).any():
938
  # --- x coordinate ---
939
+ x_logits = decode_coordinate(hidden, self.region) # (B,1,1024) or (B,1024)
940
  if x_logits.dim() == 3:
941
+ x_logits = x_logits.squeeze(1) # (B,1024)
942
+ x_bin = x_logits.argmax(dim=-1).to(torch.float32) # (B,)
943
+ x_center = x_bin / float(x_logits.size(-1)) # (B,)
944
+ x_input = x_center.to(dtype=x_logits.dtype).unsqueeze(-1) # (B,1)
945
+ x_emb = encode_coordinate(x_input, self.region).unsqueeze(1) # (B,1,C)
946
 
947
+ # Advance visibility at shared 'pos' and decode (q_len=1)
948
+ mask[alive, :, pos] = 1
949
  logits, hidden = self._decode_one_tok(
950
  x_emb,
951
+ mask.unsqueeze(2), # (B,1,1,max_ctx)
952
+ torch.tensor([pos], device=device, dtype=torch.long),
953
  lora,
954
  )
955
  pos += 1
 
957
  # --- y coordinate ---
958
  y_logits = decode_coordinate(hidden, self.region)
959
  if y_logits.dim() == 3:
960
+ y_logits = y_logits.squeeze(1) # (B,1024)
961
+ y_bin = y_logits.argmax(dim=-1).to(torch.float32)
962
+ y_center = y_bin / float(y_logits.size(-1)) # (B,)
963
+ y_input = y_center.to(dtype=y_logits.dtype).unsqueeze(-1) # (B,1)
964
+ y_emb = encode_coordinate(y_input, self.region).unsqueeze(1) # (B,1,C)
965
 
966
+ mask[alive, :, pos] = 1
 
967
  logits, hidden = self._decode_one_tok(
968
  y_emb,
969
+ mask.unsqueeze(2), # (B,1,1,max_ctx)
970
  torch.tensor([pos], device=device, dtype=torch.long),
971
  lora,
972
  )
 
974
 
975
  if include_size:
976
  # --- size (batched) ---
977
+ size_logits = decode_size(hidden, self.region)
978
+ w_logits, h_logits = size_logits[0].squeeze(1), size_logits[1].squeeze(1)
979
  w_bin = w_logits.argmax(dim=-1).to(torch.float32)
980
  h_bin = h_logits.argmax(dim=-1).to(torch.float32)
981
+ # log-scale bin actual size in [0,1]
982
+ w = torch.pow(2.0, (w_bin / 1023.0) * 10.0 - 10.0)
983
+ h = torch.pow(2.0, (h_bin / 1023.0) * 10.0 - 10.0)
984
  size_input = torch.stack([w, h], dim=1).to(dtype=w_logits.dtype) # (B,2)
985
+ size_emb = encode_size(size_input, self.region).unsqueeze(1) # (B,1,C)
986
 
987
+ # Commit boxes for alive rows
988
  for i in range(B):
989
  if not alive[i]:
990
  continue
 
995
  "y_max": (y_center[i] + h[i] / 2).item(),
996
  })
997
 
998
+ mask[alive, :, pos] = 1
 
999
  logits, hidden = self._decode_one_tok(
1000
  size_emb,
1001
+ mask.unsqueeze(2), # (B,1,1,max_ctx) ✅
1002
  torch.tensor([pos], device=device, dtype=torch.long),
1003
  lora,
1004
  )
1005
  pos += 1
1006
+ next_tok = logits.argmax(dim=-1).squeeze(-1) # (B,)
1007
  else:
1008
  # Points mode (no size)
1009
  for i in range(B):
1010
  if alive[i]:
1011
  out[i].append({"x": x_center[i].item(), "y": y_center[i].item()})
1012
+ mask[alive, :, pos] = 1
1013
  logits, hidden = self._decode_one_tok(
1014
  y_emb,
1015
+ mask.unsqueeze(2), # (B,1,1,max_ctx) ✅
1016
  torch.tensor([pos], device=device, dtype=torch.long),
1017
  lora,
1018
  )
1019
  pos += 1
1020
+ next_tok = logits.argmax(dim=-1).squeeze(-1)
1021
 
1022
+ # Finish rows that emitted EOS or hit object cap
1023
  finished_now = (next_tok == eos_id) | (counts >= max_objects - 1)
1024
  counts = counts + (~finished_now & alive).to(counts.dtype)
1025
  alive &= ~finished_now
 
1027
  return out
1028
 
1029
 
1030
+
1031
  def detect_multi(self, image, objects, settings=None):
1032
  """
1033
  Parallel multi-label detection.