HV-Khurdula committed
Commit 01b09b7 · verified · 1 Parent(s): 36f4434

Update moondream.py

Files changed (1):
  1. moondream.py (+88 −57)
moondream.py CHANGED
@@ -77,29 +77,28 @@ class KVCache(nn.Module):
             "v_cache", torch.zeros(*cache_shape, device=device, dtype=dtype)
         )
 
+    # In class KVCache, REPLACE the whole update() with this:
     def update(self, pos_ids, k, v):
         """
         Supports:
         • Prefill: k,v = (B, n_kv_heads, q_len, d), pos_ids = (q_len,)
-        • 1-step: k,v = (B, n_kv_heads, 1, d), pos_ids = (B,) or (B,1)
-        • Legacy: k,v = (1, n_kv_heads, q_len, d), pos_ids = scalar int
+        • 1-step: k,v = (B, n_kv_heads, 1, d), pos_ids = (B,1) or (B,)
+        • Legacy: k,v = (1, n_kv_heads, q_len, d), pos_ids = scalar
         Writes into self.k_cache/self.v_cache shaped (B, n_kv_heads, T_max, d).
         """
         kout, vout = self.k_cache, self.v_cache
 
+        # Normalize pos_ids
         if not torch.is_tensor(pos_ids):
-            # Scalar legacy path
-            kout[:, :, pos_ids, :] = k
-            vout[:, :, pos_ids, :] = v
-            return kout, vout
-
-        pos_ids = pos_ids.to(dtype=torch.long, device=k.device)
+            pos_ids = torch.tensor(pos_ids, device=k.device, dtype=torch.long)
+        else:
+            pos_ids = pos_ids.to(device=k.device, dtype=torch.long)
 
         if k.dim() != 4 or v.dim() != 4:
             raise RuntimeError(f"KV update expects k,v 4D. Got k={tuple(k.shape)} v={tuple(v.shape)}")
         B, Hkv, q_len, D = k.shape
 
-        # Make sure cache batch matches B (expand-from-1 is ok, otherwise error)
+        # Ensure cache batch matches B (expand-from-1 allowed)
         if kout.size(0) != B:
             if kout.size(0) == 1:
                 self.k_cache = kout.expand(B, -1, -1, -1).clone()
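
The new normalization folds the old scalar-only legacy branch into a single tensor path. A minimal standalone sketch of the same rule (the `normalize_pos_ids` helper name is ours, for illustration):

import torch

def normalize_pos_ids(pos_ids, device):
    # Same rule as the diff: non-tensors (int, list) become long tensors,
    # tensors are cast/moved; the case dispatch then branches on dim()/numel().
    if not torch.is_tensor(pos_ids):
        return torch.tensor(pos_ids, device=device, dtype=torch.long)
    return pos_ids.to(device=device, dtype=torch.long)

print(normalize_pos_ids(5, "cpu").dim())                         # 0 -> legacy scalar, Case C
print(normalize_pos_ids(torch.arange(7), "cpu").shape)           # (7,) -> prefill, Case A
print(normalize_pos_ids(torch.tensor([[3], [9]]), "cpu").shape)  # (2,1) -> 1-step, Case B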
@@ -108,23 +107,23 @@ class KVCache(nn.Module):
             else:
                 raise RuntimeError(f"KV cache batch mismatch: cache.B={kout.size(0)} vs k.B={B}")
 
-        # Case A: PREFILL — pos_ids indexes a contiguous range per row
+        # Case A: PREFILL — vector of length q_len (same for all B rows)
         if pos_ids.dim() == 1 and pos_ids.numel() == q_len:
             for i in range(B):
-                kout[i, :, pos_ids, :] = k[i]  # (Hkv, q_len, D)
+                kout[i, :, pos_ids, :] = k[i]  # (Hkv, q_len, D)
                 vout[i, :, pos_ids, :] = v[i]
             return kout, vout
 
-        # Case B: STEP — q_len == 1 and one position per row
+        # Case B: 1-STEP — q_len == 1 with (B,) or (B,1) per-row positions
         if q_len == 1 and pos_ids.numel() == B:
-            pos_ids_flat = pos_ids.view(-1)  # handle (B,1) or (B,)
+            pos_ids = pos_ids.view(B)
             for i in range(B):
-                pi = int(pos_ids_flat[i].item())
+                pi = int(pos_ids[i].item())
                 kout[i, :, pi, :] = k[i, :, 0, :]
                 vout[i, :, pi, :] = v[i, :, 0, :]
             return kout, vout
 
-        # Case C: scalar for everyone
+        # Case C: scalar for everyone & q_len == 1
         if pos_ids.dim() == 0 and q_len == 1:
             pi = int(pos_ids.item())
             kout[:, :, pi, :] = k[:, :, 0, :]
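
How the three accepted shapes map onto cache writes, as a pure-tensor sketch (B=2 rows, 4 KV heads, head dim 8, T_max=16 are illustrative; the `k_cache` tensor stands in for self.k_cache):

import torch

B, Hkv, T_max, D = 2, 4, 16, 8
k_cache = torch.zeros(B, Hkv, T_max, D)

# Case A (prefill): one position vector of length q_len, shared by every row.
k = torch.randn(B, Hkv, 5, D)
pos = torch.arange(5)
for i in range(B):
    k_cache[i, :, pos, :] = k[i]          # (Hkv, 5, D) slab per row

# Case B (1-step): q_len == 1, one position per row — rows may differ.
k1 = torch.randn(B, Hkv, 1, D)
step_pos = torch.tensor([[5], [7]]).view(B)
for i in range(B):
    k_cache[i, :, int(step_pos[i]), :] = k1[i, :, 0, :]

# Case C (legacy scalar): every row writes the same slot.
k_cache[:, :, 5, :] = k1[:, :, 0, :]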
@@ -140,6 +139,7 @@ class KVCache(nn.Module):
 
 
 
+
 class MoondreamModel(nn.Module):
 
     def __init__(
@@ -589,10 +589,12 @@ class MoondreamModel(nn.Module):
         elif not isinstance(image, Image.Image):
             raise ValueError("image must be a PIL Image or EncodedImage")
 
+        # At the VERY TOP of encode_image(), right after the type checks:
         for blk in self.text.blocks:
             if blk.kv_cache.k_cache.size(0) != 1:
                 blk.kv_cache.k_cache = blk.kv_cache.k_cache[:1].contiguous()
                 blk.kv_cache.v_cache = blk.kv_cache.v_cache[:1].contiguous()
+
 
         lora = (
             variant_state_dict(settings["variant"], device=self.device)
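
This reset is the counterpart of update()'s expand-from-1 path: encode_image works at batch 1, and a later batched call re-expands the cache. The clone() after expand() matters because expand() only creates a broadcast view; a small sketch:

import torch

base = torch.zeros(1, 4, 16, 8)               # B=1 cache, as encode_image expects
view = base.expand(3, -1, -1, -1)             # broadcast view — rows share storage
cache = view.clone()                          # materialize 3 independent rows
cache[1, :, 0, :] = 1.0                       # per-row write is now safe
print(base.abs().sum().item())                # 0.0 — the B=1 original is untouched
print(cache[:1].contiguous().shape)           # (1, 4, 16, 8), as in the reset above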
@@ -971,107 +973,135 @@ class MoondreamModel(nn.Module):
         # CRITICAL: per-row next position
         pos_vec = torch.tensor(lens, device=self.device, dtype=torch.long) + pos  # (B,)
 
-        return last_hidden, next_token, pos_vec
+        # At the end of _prefill_prompt_batched(), return a Python int:
+        pos_end = int((pos + T))
+        return last_hidden, next_token, pos_end
+
 
-    # In class MoondreamModel, replace the whole method:
     def _generate_points_batched(
         self,
         hidden,  # (B,1,C)
-        next_token,  # (B,1) (not used when temperature=0, but ok)
-        pos: int,  # shared scalar next position
+        next_token,  # (B,1) (unused in greedy, but OK)
+        pos,  # int or Tensor; normalized below
         include_size: bool = True,
         max_objects: int = 50,
         lora=None,
+        use_soft_argmax: bool = True,  # NEW: reduces jitter/hallucinations
     ):
-        """
-        Vectorized version of _generate_points() that decodes x -> y -> size -> next-token
-        for all rows in the batch simultaneously. Returns list-of-lists of dicts, len B.
-        """
         B = hidden.size(0)
         device = self.device
         out = [[] for _ in range(B)]
         eos_id = self.config.tokenizer.eos_id
         max_ctx = self.config.text.max_context
 
+        # Normalize pos to a scalar int (supports int, (1,), (B,), (B,1))
+        if torch.is_tensor(pos):
+            pos = int(pos.max().item())  # safe upper bound; we manage per-row with pos_ids/alive
+
         # 4-D mask: (B, 1, q_len=1, kv_len)
         mask = torch.zeros(B, 1, 1, max_ctx, device=device, dtype=torch.bool)
         if pos > 0:
             mask[:, :, :, :pos] = True
-        # IMPORTANT: position_ids must be (B, 1) for rotary; KVCache.update accepts (B,1) too
+
+        # position_ids must be (B,1)
         pos_ids = torch.full((B, 1), pos, device=device, dtype=torch.long)
 
         alive = torch.ones(B, dtype=torch.bool, device=device)
         counts = torch.zeros(B, dtype=torch.int32, device=device)
 
+        # helpers ---------------------------------------------------------
+        def _argmax01(logits):
+            # logits: (B, bins)
+            if use_soft_argmax:
+                probs = torch.softmax(logits, dim=-1)
+                bins = torch.arange(probs.size(-1), device=logits.device, dtype=torch.float32)
+                idx = (probs * bins).sum(dim=-1) / (probs.size(-1) - 1)
+                return idx  # 0..1
+            else:
+                idx = logits.argmax(dim=-1).to(torch.float32)
+                return idx / float(logits.size(-1) - 1)
+
         with torch.inference_mode():
             while alive.any() and (counts < max_objects).any():
-                # --- x coordinate ---
-                x_logits = decode_coordinate(hidden, self.region)  # (B,1,1024) or (B,1024)
+                # --- x ---------------------------------------------------
+                x_logits = decode_coordinate(hidden, self.region)  # (B,1,1024) or (B,1024)
                 if x_logits.dim() == 3:
                     x_logits = x_logits.squeeze(1)
-                x_center = x_logits.argmax(dim=-1).to(torch.float32) / float(x_logits.size(-1))  # (B,)
-                x_in = x_center.to(dtype=x_logits.dtype).unsqueeze(-1)  # (B,1)
-                x_emb = encode_coordinate(x_in, self.region).unsqueeze(1)  # (B,1,C)
+                x_center = _argmax01(x_logits)  # (B,) in [0,1]
+                x_in = x_center.to(dtype=x_logits.dtype).unsqueeze(-1)  # (B,1)
+                x_emb = encode_coordinate(x_in, self.region).unsqueeze(1)  # (B,1,C)
 
-                # advance attention one step
-                mask[:, :, :, pos] = True
+                # advance attention one step FOR ALIVE ROWS ONLY
+                mask[alive, :, :, pos] = True
                 logits, hidden = self._decode_one_tok(x_emb, mask, pos_ids, lora)
-                pos += 1
-                pos_ids[:, 0] = pos
+                pos_ids[alive, 0] += 1
+                pos += 1  # scalar next free slot
 
-                # --- y coordinate ---
+                # --- y ---------------------------------------------------
                 y_logits = decode_coordinate(hidden, self.region)
                 if y_logits.dim() == 3:
                     y_logits = y_logits.squeeze(1)
-                y_center = y_logits.argmax(dim=-1).to(torch.float32) / float(y_logits.size(-1))
-                y_in = y_center.to(dtype=y_logits.dtype).unsqueeze(-1)  # (B,1)
+                y_center = _argmax01(y_logits)  # (B,)
+                y_in = y_center.to(dtype=y_logits.dtype).unsqueeze(-1)
                 y_emb = encode_coordinate(y_in, self.region).unsqueeze(1)
 
-                mask[:, :, :, pos] = True
+                mask[alive, :, :, pos] = True
                 logits, hidden = self._decode_one_tok(y_emb, mask, pos_ids, lora)
+                pos_ids[alive, 0] += 1
                 pos += 1
-                pos_ids[:, 0] = pos
 
                 if include_size:
-                    # --- size ---
-                    size_logits = decode_size(hidden, self.region)  # tuple/list [w_logits, h_logits]
-                    # Support both (B,1,1024) and (B,1024)
+                    # --- size --------------------------------------------
+                    size_logits = decode_size(hidden, self.region)
                     w_logits = size_logits[0].squeeze(1)
                     h_logits = size_logits[1].squeeze(1)
-                    w_bin = w_logits.argmax(dim=-1).to(torch.float32)
-                    h_bin = h_logits.argmax(dim=-1).to(torch.float32)
+                    if use_soft_argmax:
+                        # convert expected-bin -> size (same mapping as paper/code)
+                        w_bin = (torch.softmax(w_logits, dim=-1) * torch.arange(w_logits.size(-1), device=device)).sum(dim=-1)
+                        h_bin = (torch.softmax(h_logits, dim=-1) * torch.arange(h_logits.size(-1), device=device)).sum(dim=-1)
+                    else:
+                        w_bin = w_logits.argmax(dim=-1).to(torch.float32)
+                        h_bin = h_logits.argmax(dim=-1).to(torch.float32)
                     w = torch.pow(2.0, (w_bin / 1023.0) * 10.0 - 10.0)
                     h = torch.pow(2.0, (h_bin / 1023.0) * 10.0 - 10.0)
+
                     size_in = torch.stack([w, h], dim=1).to(dtype=w_logits.dtype)  # (B,2)
-                    size_emb = encode_size(size_in, self.region).unsqueeze(1)  # (B,1,C)
+                    size_emb = encode_size(size_in, self.region).unsqueeze(1)  # (B,1,C)
 
-                    # record boxes
+                    # record boxes only for alive rows
                     for i in range(B):
-                        if alive[i]:
-                            out[i].append({
-                                "x_min": (x_center[i] - w[i] / 2).item(),
-                                "y_min": (y_center[i] - h[i] / 2).item(),
-                                "x_max": (x_center[i] + w[i] / 2).item(),
-                                "y_max": (y_center[i] + h[i] / 2).item(),
-                            })
+                        if not alive[i]:
+                            continue
+                        xl = (x_center[i] - w[i] / 2).item()
+                        xr = (x_center[i] + w[i] / 2).item()
+                        yt = (y_center[i] - h[i] / 2).item()
+                        yb = (y_center[i] + h[i] / 2).item()
+                        # clamp for safety
+                        out[i].append({
+                            "x_min": max(0.0, min(1.0, xl)),
+                            "y_min": max(0.0, min(1.0, yt)),
+                            "x_max": max(0.0, min(1.0, xr)),
+                            "y_max": max(0.0, min(1.0, yb)),
+                        })
 
-                    mask[:, :, :, pos] = True
+                    mask[alive, :, :, pos] = True
                     logits, hidden = self._decode_one_tok(size_emb, mask, pos_ids, lora)
+                    pos_ids[alive, 0] += 1
                     pos += 1
-                    pos_ids[:, 0] = pos
                     next_tok = logits.argmax(dim=-1).squeeze(-1)  # (B,)
                 else:
                     for i in range(B):
                         if alive[i]:
                             out[i].append({"x": x_center[i].item(), "y": y_center[i].item()})
-                    mask[:, :, :, pos] = True
+                    mask[alive, :, :, pos] = True
                     logits, hidden = self._decode_one_tok(y_emb, mask, pos_ids, lora)
+                    pos_ids[alive, 0] += 1
                     pos += 1
-                    pos_ids[:, 0] = pos
                     next_tok = logits.argmax(dim=-1).squeeze(-1)
 
+                # stop only rows that hit eos (or reached max objects)
                 finished_now = (next_tok == eos_id) | (counts >= max_objects - 1)
-                counts = counts + (~finished_now & alive).to(counts.dtype)
+                counts = counts + ((~finished_now) & alive).to(counts.dtype)
                 alive &= ~finished_now
 
         return out
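
Standalone arithmetic behind the new _argmax01 helper and the log2 size mapping (1024 bins, per the shapes noted in the diff). Soft-argmax returns the probability-weighted mean bin instead of the single best bin, so the output varies smoothly with the logits:

import torch

logits = torch.randn(2, 1024)                                   # (B, bins) coordinate head

hard = logits.argmax(dim=-1).float() / (logits.size(-1) - 1)    # hard pick in [0, 1]

probs = torch.softmax(logits, dim=-1)                           # soft pick: expected bin
bins = torch.arange(logits.size(-1), dtype=torch.float32)
soft = (probs * bins).sum(dim=-1) / (logits.size(-1) - 1)

# Size mapping: bins are spaced log-uniformly over roughly 0.1%..100% of the image side.
w_bin = torch.tensor([0.0, 511.5, 1023.0])
print(torch.pow(2.0, (w_bin / 1023.0) * 10.0 - 10.0))           # ~[0.001, 0.031, 1.0]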
@@ -1080,6 +1110,7 @@ class MoondreamModel(nn.Module):
 
 
 
+
     def detect_multi(self, image, objects, settings=None):
         """
         Parallel multi-label detection.
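
The hunk above only shows the opening of the new detect_multi() entry point. A hypothetical call, assuming it mirrors the existing detect() API (an image plus a list of labels, optional settings); the return structure is not shown in this diff:

from PIL import Image

# Hypothetical usage; `model` is a loaded MoondreamModel, labels are examples.
img = Image.open("street.jpg")
results = model.detect_multi(img, ["car", "person", "traffic light"])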
 