HV-Khurdula committed · Commit aeef384 · verified · 1 Parent(s): c0e9503

Update moondream.py

fix: kv cache mismatch and decode size.
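The kv-cache mismatch came from advancing a single shared position vector and unmasking a cache slot for every row, so rows that had already emitted EOS kept consuming cache positions and drifted out of sync with the attention mask. The patch below tracks per-row position ids and only advances rows that are still alive. A minimal, standalone sketch of that bookkeeping (shapes and values are illustrative only, not the model's actual decode API):

    import torch

    # Per-row KV-cache bookkeeping for batched greedy decoding.
    B, max_ctx, p0 = 3, 16, 4
    mask = torch.zeros(B, 1, 1, max_ctx, dtype=torch.bool)
    mask[:, :, :, :p0] = True                  # the shared prefix is visible to all rows
    pos_ids = torch.full((B, 1), p0, dtype=torch.long)
    alive = torch.tensor([True, True, False])  # row 2 already emitted EOS

    step_col = int(pos_ids[0, 0].item())       # alive rows all sit at the same column
    mask[alive, :, :, step_col] = True         # expose the new slot for alive rows only
    pos_ids[alive, 0] += 1                     # ...and advance only those rows

    print(pos_ids.squeeze(1).tolist())         # [5, 5, 4] -- the dead row stays put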

Files changed (1)
  1. moondream.py +114 -66
moondream.py CHANGED
@@ -884,36 +884,53 @@ class MoondreamModel(nn.Module):
 
         return {"points": objects}
 
-    def _norm_size_logits(self, size_ret, B: int):
+    # moondream.py
+    def _norm_size_logits(self, size_ret: torch.Tensor | tuple, B: int):
         """
         Accepts any of:
-          • (w_logits, h_logits)
-          • Tensor (B,2,C) or (B,1,2,C) or (1,2,C) or (2,C) (B==1)
+          • tuple/list: (w_logits, h_logits)
+          • Tensor (..., 2, C)  # from batch-safe region.decode_size
+          • Tensor (B, 2*C)     # fallback
+          • Tensor (2, C)       # when B == 1
         Returns (w_logits, h_logits) each shaped (B, C).
         """
         if isinstance(size_ret, (tuple, list)):
             w_logits, h_logits = size_ret
         else:
             t = size_ret
-            # squeeze all singleton dims except batch & vocab
-            while t.dim() > 3:
-                t = t.squeeze(1)
-            if t.dim() == 3:  # (B,2,C)
+            # if we got (..., 2, C), squeeze a single seq dim if present
+            if t.dim() >= 3 and t.shape[-2] == 2:
+                # bring to (B, 2, C)
+                while t.dim() > 3:
+                    t = t.squeeze(1)
+                if t.dim() != 3 or t.shape[0] not in (1, B):
+                    raise RuntimeError(f"Unexpected batched size logits shape {tuple(size_ret.shape)}")
+                # expand B if needed
+                if t.shape[0] == 1 and B > 1:
+                    t = t.expand(B, -1, -1).contiguous()
                 w_logits, h_logits = t[:, 0, :], t[:, 1, :]
             elif t.dim() == 2:
-                if t.size(0) == 2 and B == 1:  # (2,C) with B==1
+                # (2, C) (B==1) or (B, 2*C)
+                if t.shape[0] == 2 and B == 1:
                     w_logits, h_logits = t[0].unsqueeze(0), t[1].unsqueeze(0)
-                else:  # (B,2C) fallback
-                    C2 = t.size(1); C = C2 // 2
+                else:
+                    C2 = t.shape[1]
+                    if C2 % 2 != 0:
+                        raise RuntimeError(f"Cannot split last dim {C2} into (w,h)")
+                    C = C2 // 2
                     w_logits, h_logits = t[:, :C], t[:, C:]
             else:
                 raise RuntimeError(f"Unexpected decode_size shape {tuple(t.shape)}")
-        # final squeeze if needed
+
+        # final sanity: make sure they're (B, C)
         if w_logits.dim() == 3: w_logits = w_logits.squeeze(1)
         if h_logits.dim() == 3: h_logits = h_logits.squeeze(1)
+        if w_logits.shape[0] != B or h_logits.shape[0] != B:
+            raise RuntimeError(f"Batched size logits mismatch: got {w_logits.shape[0]} vs B={B}")
         return w_logits.contiguous(), h_logits.contiguous()
 
 
+
     def _load_encoded_image_batched(self, encoded_image, batch_size: int):
         for b, (k, v) in zip(self.text.blocks, encoded_image.caches):
             T = k.size(2)
@@ -969,26 +986,35 @@ class MoondreamModel(nn.Module):
 
 
     def _generate_points_batched(
-        self, hidden, next_token, pos, include_size: bool = True,
-        max_objects: int = 50, lora=None, use_soft_argmax: bool = False):
+        self,
+        hidden,                        # (B,1,C)
+        next_token,                    # (B,1) (unused for greedy)
+        pos,                           # int (start position in cache)
+        include_size: bool = True,
+        max_objects: int = 50,
+        lora=None,
+        use_soft_argmax: bool = True,  # reduces jitter/hallucinations
+    ):
        B = hidden.size(0)
        device = self.device
        out = [[] for _ in range(B)]
        eos_id = self.config.tokenizer.eos_id
        max_ctx = self.config.text.max_context
 
-        # 4-D mask: (B,1,1,kv_len); advance with a 1-D position vector
+        # 4-D mask: (B, 1, q_len=1, kv_len)
         mask = torch.zeros(B, 1, 1, max_ctx, device=device, dtype=torch.bool)
-        if int(pos) > 0:
-            mask[:, :, :, :int(pos)] = True
-        pos_id_vec = torch.full((1,), int(pos), device=device, dtype=torch.long)
+        p0 = int(pos)
+        if p0 > 0:
+            mask[:, :, :, :p0] = True
+        # per-row position ids (B,1)
+        pos_ids = torch.full((B, 1), p0, device=device, dtype=torch.long)
 
-        def _center01(logits):
-            # logits: (B, bins) -> (B,) in [0,1]
+        # helper: (B, bins) -> (B,) in [0,1]
+        def _argmax01(logits: torch.Tensor) -> torch.Tensor:
             if use_soft_argmax:
-                p = torch.softmax(logits, dim=-1)
-                bins = torch.arange(p.size(-1), device=logits.device, dtype=torch.float32)
-                return (p * bins).sum(dim=-1) / float(p.size(-1) - 1)
+                probs = torch.softmax(logits, dim=-1)
+                bins = torch.arange(probs.size(-1), device=logits.device, dtype=torch.float32)
+                return (probs * bins).sum(dim=-1) / float(probs.size(-1) - 1)
             idx = logits.argmax(dim=-1).to(torch.float32)
             return idx / float(logits.size(-1) - 1)
 
@@ -997,38 +1023,41 @@ class MoondreamModel(nn.Module):
 
         with torch.inference_mode():
             while alive.any() and (counts < max_objects).any():
-                # x
-                x_logits = decode_coordinate(hidden, self.region)  # (B,1,1024) or (B,1024)
-                if x_logits.dim() == 3: x_logits = x_logits.squeeze(1)
-                x_center = _center01(x_logits)
-                x_emb = encode_coordinate(x_center.to(dtype=x_logits.dtype).unsqueeze(-1), self.region).unsqueeze(1)
+                # ---------------- x ----------------
+                x_logits = decode_coordinate(hidden, self.region)  # (B,1,1024) or (B,1024)
+                if x_logits.dim() == 3:
+                    x_logits = x_logits.squeeze(1)  # -> (B,1024)
+                x_center = _argmax01(x_logits)  # (B,)
+                x_emb = encode_coordinate(
+                    x_center.to(dtype=x_logits.dtype).unsqueeze(-1),  # (B,1)
+                    self.region
+                ).unsqueeze(1)  # (B,1,C)
 
-                mask[:, :, :, pos_id_vec] = True
-                logits, hidden = self._decode_one_tok(x_emb, mask, pos_id_vec, lora)
-                pos_id_vec += 1
+                # advance one token for ALIVE rows only
+                step_col = int(pos_ids[0, 0].item())
+                mask[alive, :, :, step_col] = True
+                logits, hidden = self._decode_one_tok(x_emb, mask, pos_ids, lora)
+                pos_ids[alive, 0] += 1
 
-                # y
+                # ---------------- y ----------------
                 y_logits = decode_coordinate(hidden, self.region)
-                if y_logits.dim() == 3: y_logits = y_logits.squeeze(1)
-                y_center = _center01(y_logits)
-                y_emb = encode_coordinate(y_center.to(dtype=y_logits.dtype).unsqueeze(-1), self.region).unsqueeze(1)
+                if y_logits.dim() == 3:
+                    y_logits = y_logits.squeeze(1)  # (B,1024)
+                y_center = _argmax01(y_logits)  # (B,)
+                y_emb = encode_coordinate(
+                    y_center.to(dtype=y_logits.dtype).unsqueeze(-1),
+                    self.region
+                ).unsqueeze(1)  # (B,1,C)
 
-                mask[:, :, :, pos_id_vec] = True
-                logits, hidden = self._decode_one_tok(y_emb, mask, pos_id_vec, lora)
-                pos_id_vec += 1
+                step_col = int(pos_ids[0, 0].item())
+                mask[alive, :, :, step_col] = True
+                logits, hidden = self._decode_one_tok(y_emb, mask, pos_ids, lora)
+                pos_ids[alive, 0] += 1
 
                 if include_size:
+                    # ------------- size (w,h) -------------
                     size_ret = decode_size(hidden, self.region)
-                    # Robust parse: accept (w,h) tuple OR Tensor (B,2,C)/(B,1,2,C)
-                    if isinstance(size_ret, (tuple, list)):
-                        w_logits, h_logits = size_ret
-                    else:
-                        t = size_ret
-                        if t.dim() == 4:  # (B,1,2,C)
-                            t = t.squeeze(1)  # → (B,2,C)
-                        if t.dim() != 3 or t.size(1) != 2:
-                            raise RuntimeError(f"Unexpected decode_size shape {tuple(t.shape)}")
-                        w_logits, h_logits = t[:, 0, :], t[:, 1, :]
+                    w_logits, h_logits = self._norm_size_logits(size_ret, B)  # each (B,C)
 
                     if use_soft_argmax:
                         bins = torch.arange(w_logits.size(-1), device=device, dtype=torch.float32)
@@ -1038,14 +1067,18 @@ class MoondreamModel(nn.Module):
                         w_bin = w_logits.argmax(dim=-1).to(torch.float32)
                         h_bin = h_logits.argmax(dim=-1).to(torch.float32)
 
-                    # inverse log-scale mapping used by md2
-                    w = torch.pow(2.0, (w_bin / 1023.0) * 10.0 - 10.0)
-                    h = torch.pow(2.0, (h_bin / 1023.0) * 10.0 - 10.0)
+                    # inverse log-scale mapping used by MD2
+                    w = torch.pow(2.0, (w_bin / 1023.0) * 10.0 - 10.0)  # (B,)
+                    h = torch.pow(2.0, (h_bin / 1023.0) * 10.0 - 10.0)  # (B,)
 
-                    size_emb = encode_size(torch.stack([w, h], dim=1).to(dtype=w_logits.dtype), self.region).unsqueeze(1)
+                    size_emb = encode_size(
+                        torch.stack([w, h], dim=1).to(dtype=w_logits.dtype),  # (B,2)
+                        self.region
+                    ).unsqueeze(1)  # (B,1,C)
 
-                    for i in range(B):
-                        if not alive[i]: continue
+                    # record boxes only for ALIVE rows
+                    alive_idx = alive.nonzero(as_tuple=False).view(-1)
+                    for i in alive_idx.tolist():
                         xl = (x_center[i] - w[i] / 2).item()
                         xr = (x_center[i] + w[i] / 2).item()
                         yt = (y_center[i] - h[i] / 2).item()
@@ -1057,26 +1090,41 @@ class MoondreamModel(nn.Module):
                             "y_max": max(0.0, min(1.0, yb)),
                         })
 
-                    mask[:, :, :, pos_id_vec] = True
-                    logits, hidden = self._decode_one_tok(size_emb, mask, pos_id_vec, lora)
-                    pos_id_vec += 1
-                    next_tok = logits.argmax(dim=-1).squeeze(-1)
+                    step_col = int(pos_ids[0, 0].item())
+                    mask[alive, :, :, step_col] = True
+                    logits, hidden = self._decode_one_tok(size_emb, mask, pos_ids, lora)
+                    pos_ids[alive, 0] += 1
+                    next_tok = logits.argmax(dim=-1)
+                    if next_tok.dim() == 3:  # (B,1,1) possible
+                        next_tok = next_tok.squeeze(-1).squeeze(-1)
+                    elif next_tok.dim() == 2:  # (B,1)
+                        next_tok = next_tok.squeeze(1)
                 else:
-                    for i in range(B):
-                        if alive[i]:
-                            out[i].append({"x": x_center[i].item(), "y": y_center[i].item()})
-                    mask[:, :, :, pos_id_vec] = True
-                    logits, hidden = self._decode_one_tok(y_emb, mask, pos_id_vec, lora)
-                    pos_id_vec += 1
-                    next_tok = logits.argmax(dim=-1).squeeze(-1)
+                    # point mode
+                    alive_idx = alive.nonzero(as_tuple=False).view(-1)
+                    for i in alive_idx.tolist():
+                        out[i].append({"x": x_center[i].item(), "y": y_center[i].item()})
+                    step_col = int(pos_ids[0, 0].item())
+                    mask[alive, :, :, step_col] = True
+                    logits, hidden = self._decode_one_tok(y_emb, mask, pos_ids, lora)
+                    pos_ids[alive, 0] += 1
+                    next_tok = logits.argmax(dim=-1)
+                    if next_tok.dim() == 3:
+                        next_tok = next_tok.squeeze(-1).squeeze(-1)
+                    elif next_tok.dim() == 2:
+                        next_tok = next_tok.squeeze(1)
 
-                finished_now = (next_tok == eos_id) | (counts >= max_objects - 1)
-                counts = counts + ((~finished_now) & alive).to(counts.dtype)
+                # we added one object/point for all ALIVE rows this iteration
+                counts[alive] += 1
+
+                # stop rows that hit eos OR reached max_objects
+                finished_now = (next_tok == eos_id) | (counts >= max_objects)
                 alive &= ~finished_now
 
         return out
 
 
+
     def detect_multi(self, image, objects, settings=None):
         if self.config.tokenizer.templates["detect"] is None:
             raise NotImplementedError("Model does not support object detection.")
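
For reference, a standalone shape harness for the normalization the new _norm_size_logits performs. The free function below is a compact mirror of the patched method written for this note (it is not the model's code), and the random tensors stand in for real decode_size output with C = 1024 bins:

    import torch

    def norm_size_logits(size_ret, B):
        # Compact mirror of the patched method, for shape checking only.
        if isinstance(size_ret, (tuple, list)):
            w, h = size_ret
        else:
            t = size_ret
            if t.dim() >= 3 and t.shape[-2] == 2:
                while t.dim() > 3:                  # (B,1,2,C) -> (B,2,C)
                    t = t.squeeze(1)
                if t.shape[0] == 1 and B > 1:       # (1,2,C) -> (B,2,C)
                    t = t.expand(B, -1, -1).contiguous()
                w, h = t[:, 0, :], t[:, 1, :]
            elif t.dim() == 2:
                if t.shape[0] == 2 and B == 1:      # (2,C) with a single row
                    w, h = t[0].unsqueeze(0), t[1].unsqueeze(0)
                else:                               # (B, 2*C) fallback, split in half
                    C = t.shape[1] // 2
                    w, h = t[:, :C], t[:, C:]
            else:
                raise RuntimeError(f"Unexpected decode_size shape {tuple(t.shape)}")
        return w.contiguous(), h.contiguous()

    B, C = 4, 1024
    for ret in [
        (torch.randn(B, C), torch.randn(B, C)),  # tuple form
        torch.randn(B, 2, C),                    # (B, 2, C)
        torch.randn(B, 1, 2, C),                 # (B, 1, 2, C)
        torch.randn(1, 2, C),                    # (1, 2, C), expanded to B
        torch.randn(B, 2 * C),                   # (B, 2*C)
    ]:
        w, h = norm_size_logits(ret, B)
        assert w.shape == (B, C) and h.shape == (B, C)
    print("all decode_size variants normalize to (B, C)")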
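
The use_soft_argmax default flip (False to True) is easy to motivate: a hard argmax over 1024 bins quantizes coordinates to 1/1023 steps and snaps to one side when probability mass is split across neighboring bins, while the softmax-weighted expectation interpolates between them. A self-contained comparison (the two-peak logits are contrived for illustration):

    import torch

    bins = 1024
    logits = torch.zeros(bins)
    logits[511] = 5.0   # probability mass split between two adjacent bins,
    logits[512] = 5.0   # i.e. the true center lies between them

    hard = logits.argmax(-1).float() / (bins - 1)   # snaps to bin 511
    probs = torch.softmax(logits, dim=-1)
    idx = torch.arange(bins, dtype=torch.float32)
    soft = (probs * idx).sum(-1) / (bins - 1)       # lands between 511 and 512

    print(f"hard={hard.item():.6f}  soft={soft.item():.6f}")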
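
Finally, the inverse log-scale size mapping 2 ** ((bin / 1023) * 10 - 10) decodes a bin index back to a normalized extent in [2**-10, 1], so small boxes get most of the resolution. The forward quantizer below is inferred from that inverse (an assumption, since the diff only shows the decode direction):

    import torch

    def bin_to_size(b: torch.Tensor) -> torch.Tensor:
        # inverse mapping from the patch: bin 0 -> 2**-10, bin 1023 -> 1.0
        return torch.pow(2.0, (b / 1023.0) * 10.0 - 10.0)

    def size_to_bin(s: torch.Tensor) -> torch.Tensor:
        # forward quantizer implied by the inverse (assumption, not in the diff)
        return ((torch.log2(s) + 10.0) / 10.0 * 1023.0).round().clamp(0, 1023)

    sizes = torch.tensor([2 ** -10, 0.05, 0.25, 1.0])
    b = size_to_bin(sizes)
    print(b.tolist())                 # [0.0, 581.0, 818.0, 1023.0]
    print(bin_to_size(b).tolist())    # round-trips to roughly the input sizes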