HV-Khurdula committed · verified
Commit a1d497d · 1 Parent(s): 75468c3

Update moondream.py


fix:
1. Rerun crash in encode_image(...): after a batched call, the KV caches are still shaped (B, ...). On the next encode, Q has B=1 while K/V broadcast to B, so attention returns B rows and the reshape to (1, q_len, d_model) fails. Forcing a fresh B=1 cache avoids this. (Moondream's attention expects the Q/K/V batch dimensions to match when reshaping back to (bsz, q_len, d_model).)
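
For reference, a minimal shape-only sketch of this failure mode (generic attention math, not Moondream's actual attention code; B, H, T, D are made-up sizes):

```python
import torch

B, H, T, D = 4, 8, 16, 32
k = torch.randn(B, H, T, D)   # leftover K cache from a batched detect call
v = torch.randn(B, H, T, D)   # leftover V cache
q = torch.randn(1, H, 1, D)   # next encode_image decodes with batch size 1

scores = (q @ k.transpose(-1, -2)).softmax(dim=-1)  # broadcasts to (B, H, 1, T)
out = scores @ v                                    # (B, H, 1, D): B rows, not 1
try:
    out.transpose(1, 2).reshape(1, 1, H * D)        # (bsz, q_len, d_model)
except RuntimeError as e:
    print(e)  # shape '[1, 1, 256]' is invalid for input of size 1024
```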

2. decode_size IndexError on .squeeze(1): upstream decode_size returns mlp(...).view(2, -1), which flattens the batch/time dims, so the result is not always a (B, 1, 1024) pair. We reshape it back to (2, B, -1) when needed, so both variants work.
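
A small sketch of the normalization this patch applies in _generate_points_batched (the split_size_logits helper and the bin count are illustrative, not part of the repo):

```python
import torch

def split_size_logits(size_logits, B):
    if isinstance(size_logits, (tuple, list)):        # (w, h) pair variant
        w, h = size_logits
        if w.dim() == 3:                              # (B, 1, bins)
            w, h = w.squeeze(1), h.squeeze(1)
    else:                                             # flattened mlp(...).view(2, -1)
        w, h = size_logits.reshape(2, B, -1).unbind(0)
    return w, h                                       # each (B, bins)

flat = torch.randn(2, 3 * 1024)                       # what upstream returns for B=3
w, h = split_size_logits(flat, B=3)
print(w.shape, h.shape)                               # torch.Size([3, 1024]) twice
```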

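Taken together, the calling pattern that used to crash and is now safe looks roughly like this (model, img_a, and img_b are placeholders for a loaded MoondreamModel and two PIL images):

```python
enc = model.encode_image(img_a)                              # B=1 caches
model.detect_multi(enc, ["cat", "dog", "car"], settings={})  # caches cloned to B=3
model.encode_image(img_b)                                    # crashed before this fix;
                                                             # now _setup_caches() and
                                                             # _reset_kv_caches(1) keep B=1
```
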
Files changed (1)
  1. moondream.py +83 -79
moondream.py CHANGED
@@ -155,17 +155,33 @@ class MoondreamModel(nn.Module):
         if setup_caches:
             self._setup_caches()
 
-    def _setup_caches(self):
+
+    def _reset_kv_caches(self, batch_size: int = 1):
+        """
+        Recreate KV caches with the requested batch size so subsequent calls
+        (e.g., encode_image) start from a consistent shape.
+        """
         c = self.config.text
-        for b in self.text.blocks:
-            b.kv_cache = KVCache(
-                c.n_heads,
-                c.n_kv_heads,
-                c.max_context,
-                c.dim,
-                device=self.device,
-                dtype=self.vision.pos_emb.dtype,
-            )
+        head_dim = c.dim // c.n_heads
+        for blk in self.text.blocks:
+            device = blk.kv_cache.k_cache.device
+            dtype = blk.kv_cache.k_cache.dtype
+            shape = (batch_size, c.n_kv_heads, c.max_context, head_dim)
+            blk.kv_cache.k_cache = torch.zeros(shape, device=device, dtype=dtype)
+            blk.kv_cache.v_cache = torch.zeros(shape, device=device, dtype=dtype)
+
+
+    def _setup_caches(self):
+        c = self.config.text
+        for b in self.text.blocks:
+            b.kv_cache = KVCache(
+                c.n_heads,
+                c.n_kv_heads,
+                c.max_context,
+                c.dim,
+                device=self.device,
+                dtype=self.vision.pos_emb.dtype,
+            )
 
     @property
     def device(self):
@@ -238,30 +254,30 @@ class MoondreamModel(nn.Module):
         image: Union[Image.Image, EncodedImage],
         settings: Optional[ImageEncodingSettings] = None,
     ) -> EncodedImage:
+        # Always start from single-row caches; avoids leftovers from batched runs.
+        self._setup_caches()
+
         if isinstance(image, EncodedImage):
             return image
         elif not isinstance(image, Image.Image):
             raise ValueError("image must be a PIL Image or EncodedImage")
-
+
         lora = (
            variant_state_dict(settings["variant"], device=self.device)
            if settings is not None and "variant" in settings
            else None
        )
-
-        # Run through text model in addition to the vision encoder, to minimize
-        # re-computation if multiple queries are performed on this image.
+
         with torch.inference_mode():
             img_emb = self._run_vision_encoder(image)
             bos_emb = text_encoder(
-                torch.tensor([[self.config.tokenizer.bos_id]], device=self.device),
-                self.text,
+                torch.tensor([[self.config.tokenizer.bos_id]], device=self.device), self.text
             )
             inputs_embeds = torch.cat([bos_emb, img_emb[None]], dim=1)
             mask = self.attn_mask[:, :, 0 : inputs_embeds.size(1), :]
             pos_ids = torch.arange(inputs_embeds.size(1), dtype=torch.long)
             self._prefill(inputs_embeds, mask, pos_ids, lora)
-
+
             return EncodedImage(
                 pos=inputs_embeds.size(1),
                 caches=[
@@ -273,6 +289,7 @@ class MoondreamModel(nn.Module):
                 ],
             )
 
+
     def _apply_top_p(self, probs: torch.Tensor, top_p: float):
         probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
         probs_sum = torch.cumsum(probs_sort, dim=-1)
@@ -835,7 +852,6 @@ class MoondreamModel(nn.Module):
         return {"points": objects}
 
 
-    # === BEGIN: Batched multi-label detection additions ===
     def _load_encoded_image_batched(self, encoded_image, batch_size: int):
         """
         Clone single-image KV caches into a batch-B cache so we can decode B labels in parallel.
@@ -860,10 +876,6 @@ class MoondreamModel(nn.Module):
         temperature: float = 0.0,
         top_p: float = 0.0,
     ):
-        """
-        Build detect prompts for many labels, pad to the same length, prefill once as a batch.
-        Returns (last_hidden per row, next_token per row, shared_pos_end scalar).
-        """
         tpl = self.config.tokenizer.templates["detect"]
         if tpl is None:
             raise NotImplementedError("Model does not support object detection (no detect template).")
@@ -877,24 +889,21 @@ class MoondreamModel(nn.Module):
         T = max(lens)
         eos = self.config.tokenizer.eos_id
 
-        # Pad to T with eos, so we can prefill with a single shared position range
         prompt_ids = torch.full((B, T), eos, device=self.device, dtype=torch.long)
         for i, ids in enumerate(rows):
             prompt_ids[i, : ids.numel()] = ids
 
-        # Embed & prefill once
-        prompt_emb = text_encoder(prompt_ids, self.text)  # (B, T, C)
-        torch._dynamo.mark_dynamic(prompt_emb, 1)  # allow variable T
+        prompt_emb = text_encoder(prompt_ids, self.text)  # (B, T, C)
+        torch._dynamo.mark_dynamic(prompt_emb, 1)
 
-        # 4-D mask form makes head broadcasting unambiguous later
+        # 4-D mask is broadcastable to (B, n_heads, T, K)
         attn = self.attn_mask
-        mask = attn[:, :, pos : pos + T, :].expand(B, -1, -1, -1).contiguous()  # (B,1,T,K)
+        mask = attn[:, :, pos : pos + T, :].expand(B, -1, -1, -1).contiguous()
         pos_ids = torch.arange(pos, pos + T, device=self.device, dtype=torch.long)
 
         hidden_BTC = self._prefill(prompt_emb, mask, pos_ids, lora)  # (B, T, C)
         logits_BTV = lm_head(hidden_BTC, self.text)  # (B, T, V)
 
-        # Take the last *real* token per row
         idx = (torch.tensor(lens, device=self.device, dtype=torch.long) - 1).clamp_min(0)
         last_hidden = hidden_BTC[torch.arange(B, device=self.device), idx][:, None, :]  # (B,1,C)
         last_logits = logits_BTV[torch.arange(B, device=self.device), idx]  # (B,V)
@@ -906,34 +915,27 @@ class MoondreamModel(nn.Module):
         probs = self._apply_top_p(probs, top_p)
         next_token = torch.multinomial(probs, num_samples=1)  # (B,1)
 
-        # Shared "next decode position" for all rows (we prefilled up to pos+T-1)
         pos_end = pos + T
         return last_hidden, next_token, pos_end  # (B,1,C), (B,1), int
 
 
+
     def _generate_points_batched(
         self,
-        hidden,      # (B,1,C) last hidden after prefill (per label row)
-        next_token,  # (B,1)  (kept for parity; not used when temperature=0)
-        pos: int,    # shared scalar next position for all rows
+        hidden,      # (B,1,C)
+        next_token,  # (B,1)
+        pos: int,    # shared scalar next position
         include_size: bool = True,
         max_objects: int = 50,
         lora=None,
     ):
-        """
-        Vectorized version of _generate_points() that decodes x -> y -> size -> next-token
-        for all rows in the batch simultaneously. Returns list-of-lists of dicts (len B).
-        Batch-safe: uses 4-D masks and avoids region.decode_size() (which flattens batch).
-        """
-        import torch
-
         B = hidden.size(0)
         device = self.device
         out = [[] for _ in range(B)]
         eos_id = self.config.tokenizer.eos_id
         max_ctx = self.config.text.max_context
 
-        # 4-D mask: (B, 1, q_len=1, kv_len), True means "visible" to match model's convention
+        # 4-D mask: (B, 1, q_len=1, kv_len)
         mask = torch.zeros(B, 1, 1, max_ctx, device=device, dtype=torch.bool)
         if pos > 0:
             mask[:, :, :, :pos] = True
@@ -944,53 +946,55 @@ class MoondreamModel(nn.Module):
 
         with torch.inference_mode():
             while alive.any() and (counts < max_objects).any():
-                # --- x coordinate (batched) ---
+                # --- x coordinate ---
                 x_logits = decode_coordinate(hidden, self.region)  # (B,1,1024) or (B,1024)
                 if x_logits.dim() == 3:
-                    x_logits = x_logits.squeeze(1)  # (B,1024)
-                x_bin = x_logits.argmax(dim=-1).to(torch.float32)
+                    x_logits = x_logits.squeeze(1)
+                x_bin = x_logits.argmax(dim=-1).to(torch.float32)
                 x_center = x_bin / float(x_logits.size(-1))  # (B,)
-                x_in = x_center.to(dtype=x_logits.dtype).unsqueeze(-1)  # (B,1)
-                x_emb = encode_coordinate(x_in, self.region).unsqueeze(1)  # (B,1,C)
+                x_in = x_center.to(dtype=x_logits.dtype).unsqueeze(-1)  # (B,1)
+                x_emb = encode_coordinate(x_in, self.region).unsqueeze(1)  # (B,1,C)
 
-                # advance one token
-                mask[:, :, :, pos] = True
+                mask[:, :, :, pos_id[0].item()] = True
                 logits, hidden = self._decode_one_tok(x_emb, mask, pos_id, lora)
-                pos += 1
-                pos_id[0] = pos
+                pos_id += 1
 
-                # --- y coordinate (batched) ---
-                y_logits = decode_coordinate(hidden, self.region)  # (B,1,1024) or (B,1024)
+                # --- y coordinate ---
+                y_logits = decode_coordinate(hidden, self.region)
                 if y_logits.dim() == 3:
                     y_logits = y_logits.squeeze(1)
-                y_bin = y_logits.argmax(dim=-1).to(torch.float32)
-                y_center = y_bin / float(y_logits.size(-1))
-                y_in = y_center.to(dtype=y_logits.dtype).unsqueeze(-1)  # (B,1)
-                y_emb = encode_coordinate(y_in, self.region).unsqueeze(1)
+                y_bin = y_logits.argmax(dim=-1).to(torch.float32)
+                y_center = y_bin / float(y_logits.size(-1))  # (B,)
+                y_in = y_center.to(dtype=y_logits.dtype).unsqueeze(-1)
+                y_emb = encode_coordinate(y_in, self.region).unsqueeze(1)
 
-                mask[:, :, :, pos] = True
+                mask[:, :, :, pos_id[0].item()] = True
                 logits, hidden = self._decode_one_tok(y_emb, mask, pos_id, lora)
-                pos += 1
-                pos_id[0] = pos
+                pos_id += 1
 
                 if include_size:
-                    # ---- size (batched, *without* region.decode_size which flattens batch) ----
-                    # size_out_dim is 2*1024 (W then H). mlp() preserves (B,1).
-                    size_logits = mlp(hidden, self.region["size_decoder"]).squeeze(1)  # (B, 2048)
-                    half = size_logits.size(-1) // 2
-                    w_logits, h_logits = size_logits[:, :half], size_logits[:, half:]  # (B,1024),(B,1024)
+                    size_logits = decode_size(hidden, self.region)
+                    # Support both tuple-of-tensors and flattened (2, -1) forms
+                    if isinstance(size_logits, (tuple, list)):
+                        w_logits = size_logits[0]
+                        h_logits = size_logits[1]
+                        if w_logits.dim() == 3:  # (B,1,1024)
+                            w_logits = w_logits.squeeze(1)
+                            h_logits = h_logits.squeeze(1)
+                    else:
+                        # size_logits shape: (2, B * size_bins) — reshape it back.
+                        size_logits = size_logits.reshape(2, B, -1)
+                        w_logits, h_logits = size_logits[0], size_logits[1]  # (B, size_bins)
 
                     w_bin = w_logits.argmax(dim=-1).to(torch.float32)
                     h_bin = h_logits.argmax(dim=-1).to(torch.float32)
-
-                    # inverse log-scale mapping used by the repo
+                    # inverse of log-scale mapping used by Moondream
                     w = torch.pow(2.0, (w_bin / 1023.0) * 10.0 - 10.0)
                     h = torch.pow(2.0, (h_bin / 1023.0) * 10.0 - 10.0)
 
-                    size_in = torch.stack([w, h], dim=1).to(dtype=w_logits.dtype)  # (B,2)
+                    size_in = torch.stack([w, h], dim=1).to(dtype=w_logits.dtype)  # (B,2)
                     size_emb = encode_size(size_in, self.region).unsqueeze(1)  # (B,1,C)
 
-                    # commit boxes
                     for i in range(B):
                         if alive[i]:
                             out[i].append({
@@ -1000,21 +1004,18 @@ class MoondreamModel(nn.Module):
                                 "y_max": (y_center[i] + h[i] / 2).item(),
                             })
 
-                    # decide continuation
-                    mask[:, :, :, pos] = True
+                    mask[:, :, :, pos_id[0].item()] = True
                     logits, hidden = self._decode_one_tok(size_emb, mask, pos_id, lora)
-                    pos += 1
-                    pos_id[0] = pos
+                    pos_id += 1
                     next_tok = logits.argmax(dim=-1).squeeze(-1)  # (B,)
                 else:
-                    # points mode
                     for i in range(B):
                         if alive[i]:
                             out[i].append({"x": x_center[i].item(), "y": y_center[i].item()})
-                    mask[:, :, :, pos] = True
+
+                    mask[:, :, :, pos_id[0].item()] = True
                     logits, hidden = self._decode_one_tok(y_emb, mask, pos_id, lora)
-                    pos += 1
-                    pos_id[0] = pos
+                    pos_id += 1
                     next_tok = logits.argmax(dim=-1).squeeze(-1)
 
                 finished_now = (next_tok == eos_id) | (counts >= max_objects - 1)
@@ -1024,6 +1025,7 @@ class MoondreamModel(nn.Module):
         return out
 
 
+
     def detect_multi(self, image, objects, settings=None):
         """
         Parallel multi-label detection.
@@ -1043,17 +1045,14 @@ class MoondreamModel(nn.Module):
         B = len(objects)
         self._load_encoded_image_batched(image, B)
 
-        # Optional LoRA variant
         lora = None
         if "variant" in settings:
             lora = variant_state_dict(settings["variant"], device=self.device)
 
-        # Prefill all prompts as a batch; shared next position
         last_hidden, next_token, pos_end = self._prefill_prompt_batched(
             objects, image.pos, lora=lora, temperature=0.0, top_p=0.0
         )
 
-        # Batched decode loop
        max_objects = settings.get("max_objects", 50)
        det_lists = self._generate_points_batched(
            last_hidden, next_token, pos_end,
@@ -1066,9 +1065,14 @@ class MoondreamModel(nn.Module):
             for d in lst:
                 d["label"] = lab
             res[lab] = lst
+
+        # IMPORTANT: restore caches to B=1 so future calls (e.g., encode_image) are safe.
+        self._reset_kv_caches(1)
+
         return {"objects": res}
 
 
+
     def _detect_gaze(
         self,
         image: EncodedImage,