HV-Khurdula committed
Commit f9b6e6a · verified · 1 Parent(s): 50407fb

Update moondream.py

fix: decode batched (B>1) labels
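
The core of this change: all label prompts are padded to a common length T, prefilled as one (B, T) batch, and each row's last real (non-padding) position is gathered before decoding starts. A minimal toy sketch of that gather step, with stand-in tensors rather than the model's own code:

    import torch

    # Toy stand-ins: three label prompts of different lengths, padded to T = 7.
    lens = [5, 3, 7]                      # per-row prompt lengths (example values)
    B, T, C = len(lens), max(lens), 4     # C is a toy hidden size
    hidden_BTC = torch.randn(B, T, C)     # stands in for the prefill output

    # Last *real* token index per row, ignoring the eos padding.
    idx = (torch.tensor(lens) - 1).clamp_min(0)
    last_hidden = hidden_BTC[torch.arange(B), idx][:, None, :]  # (B, 1, C)
    print(last_hidden.shape)              # torch.Size([3, 1, 4])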

Files changed (1)
  1. moondream.py +96 -112
moondream.py CHANGED
@@ -850,84 +850,89 @@ class MoondreamModel(nn.Module):
             b.kv_cache.k_cache[:, :, :T, :] = k.expand(batch_size, -1, -1, -1)
             b.kv_cache.v_cache[:, :, :T, :] = v.expand(batch_size, -1, -1, -1)
 
-    def _prefill_prompt_batched(self, labels, pos: int, lora=None, temperature: float = 0.0, top_p: float = 0.0):
+    def _prefill_prompt_batched(
+        self,
+        labels,
+        pos: int,
+        lora=None,
+        temperature: float = 0.0,
+        top_p: float = 0.0,
+    ):
         """
-        Build detect prompts for many labels, pad to same length, prefill once as a batch,
-        then return (last_hidden per row, next_token per row, pos per row).
+        Build detect prompts for many labels, pad to the same length, prefill once as a batch.
+        Returns (last_hidden per row, next_token per row, shared_pos_end scalar).
         """
         tpl = self.config.tokenizer.templates["detect"]
         if tpl is None:
             raise NotImplementedError("Model does not support object detection (no detect template).")
-
+
         rows, lens = [], []
         for lab in labels:
            ids = tpl["prefix"] + self.tokenizer.encode(" " + lab).ids + tpl["suffix"]
            rows.append(torch.tensor(ids, device=self.device, dtype=torch.long))
            lens.append(len(ids))
-        B = len(rows); T = max(lens)
+        B = len(rows)
+        T = max(lens)
         eos = self.config.tokenizer.eos_id
-
-        # Pad with eos so we can prefill as a single batch
+
+        # Pad to T with eos, so we can prefill with a single shared position range
         prompt_ids = torch.full((B, T), eos, device=self.device, dtype=torch.long)
         for i, ids in enumerate(rows):
            prompt_ids[i, : ids.numel()] = ids
-
+
         # Embed & prefill once
-        prompt_emb = text_encoder(prompt_ids, self.text) # (B, T, C)
-        torch._dynamo.mark_dynamic(prompt_emb, 1) # allow variable T
-
-        attn_mask = self.attn_mask
-        mask = attn_mask[:, :, pos : pos + T, :].expand(B, -1, -1, -1).contiguous()
+        prompt_emb = text_encoder(prompt_ids, self.text) # (B, T, C)
+        torch._dynamo.mark_dynamic(prompt_emb, 1) # allow variable T
+
+        # 4-D mask form makes head broadcasting unambiguous later
+        attn = self.attn_mask
+        mask = attn[:, :, pos : pos + T, :].expand(B, -1, -1, -1).contiguous() # (B,1,T,K)
         pos_ids = torch.arange(pos, pos + T, device=self.device, dtype=torch.long)
-
+
         hidden_BTC = self._prefill(prompt_emb, mask, pos_ids, lora) # (B, T, C)
         logits_BTV = lm_head(hidden_BTC, self.text) # (B, T, V)
-
-        # Take the last *real* token per row (ignore padding positions)
+
+        # Take the last *real* token per row
         idx = (torch.tensor(lens, device=self.device, dtype=torch.long) - 1).clamp_min(0)
-        last_hidden = hidden_BTC[torch.arange(B, device=self.device), idx][:, None, :] # (B, 1, C)
-        last_logits = logits_BTV[torch.arange(B, device=self.device), idx] # (B, V)
-
+        last_hidden = hidden_BTC[torch.arange(B, device=self.device), idx][:, None, :] # (B,1,C)
+        last_logits = logits_BTV[torch.arange(B, device=self.device), idx] # (B,V)
+
         if temperature == 0.0:
-            next_token = last_logits.argmax(dim=-1, keepdim=True) # (B, 1)
+            next_token = last_logits.argmax(dim=-1, keepdim=True) # (B,1)
         else:
            probs = torch.softmax(last_logits / temperature, dim=-1)
            probs = self._apply_top_p(probs, top_p)
-            next_token = torch.multinomial(probs, num_samples=1) # (B, 1)
-
-        pos_vec = torch.full((B,), pos + T, device=self.device, dtype=torch.long)
-
-        return last_hidden, next_token, pos_vec # (B,1,C), (B,1), (B,)
+            next_token = torch.multinomial(probs, num_samples=1) # (B,1)
+
+        # Shared "next decode position" for all rows (we prefilled up to pos+T-1)
+        pos_end = pos + T
+        return last_hidden, next_token, pos_end # (B,1,C), (B,1), int
 
     def _generate_points_batched(
         self,
-        hidden: torch.Tensor, # (B, 1, C) last hidden per row from prefill
-        next_token: torch.Tensor, # (B, 1) unused here; kept for parity
-        pos_vec: torch.Tensor, # (B,) next write pos per row after prefill
+        hidden, # (B,1,C)
+        next_token, # (B,1)
+        pos: int, # shared scalar next position
         include_size: bool = True,
         max_objects: int = 50,
-        lora=None):
+        lora=None,
+    ):
         """
-        Batched decode loop for multi-label detection.
-
-        - Uses a shared scalar position id per step (q_len = 1), as expected by RoPE.
-        - Maintains a per-row attention mask and 'alive' flags.
-        - Feeds coord encoders with (B,1) tensors; size encoder with (B,2).
-        Returns: list-of-lists of dicts, length B.
+        Vectorized version of _generate_points() that decodes x -> y -> size -> next-token
+        for all rows in the batch simultaneously. Returns list-of-lists of dicts, len B.
         """
         B = hidden.size(0)
         device = self.device
         out = [[] for _ in range(B)]
         eos_id = self.config.tokenizer.eos_id
-
-        # Per-row initial visibility up to each row's individual prefill pos
         max_ctx = self.config.text.max_context
-        mask = torch.zeros(B, 1, max_ctx, device=device, dtype=torch.bool)
-        for i in range(B):
-            mask[i, :, : int(pos_vec[i].item())] = 1
 
-        # Shared write index so RoPE sees a scalar q_len=1 position id
-        pos = int(pos_vec.max().item())
+        # 4-D mask: (B, 1, q_len=1, kv_len)
+        mask = torch.zeros(B, 1, 1, max_ctx, device=device, dtype=torch.bool)
+        if pos > 0:
+            mask[:, :, :, :pos] = True
+        pos_id = torch.tensor([pos], device=device, dtype=torch.long) # (1,)
 
         alive = torch.ones(B, dtype=torch.bool, device=device)
         counts = torch.zeros(B, dtype=torch.int32, device=device)
@@ -935,90 +940,72 @@ class MoondreamModel(nn.Module):
         with torch.inference_mode():
             while alive.any() and (counts < max_objects).any():
                 # --- x coordinate ---
-                x_logits = decode_coordinate(hidden, self.region) # (B,1,1024) or (B,1024)
+                x_logits = decode_coordinate(hidden, self.region) # (B,1,1024) or (B,1024)
                 if x_logits.dim() == 3:
-                    x_logits = x_logits.squeeze(1) # (B,1024)
-                x_bin = x_logits.argmax(dim=-1).to(torch.float32) # (B,)
-                x_center = x_bin / float(x_logits.size(-1)) # (B,)
-                x_input = x_center.to(dtype=x_logits.dtype).unsqueeze(-1) # (B,1)
-                x_emb = encode_coordinate(x_input, self.region).unsqueeze(1) # (B,1,C)
+                    x_logits = x_logits.squeeze(1)
+                x_bin = x_logits.argmax(dim=-1).to(torch.float32) # (B,)
+                x_center = x_bin / float(x_logits.size(-1)) # (B,)
+                x_in = x_center.to(dtype=x_logits.dtype).unsqueeze(-1) # (B,1)
+                x_emb = encode_coordinate(x_in, self.region).unsqueeze(1) # (B,1,C)
 
-                # Advance visibility at shared 'pos' and decode (q_len=1)
-                mask[alive, :, pos] = 1
-                logits, hidden = self._decode_one_tok(
-                    x_emb,
-                    mask.unsqueeze(2), # (B,1,1,max_ctx) ✅
-                    torch.tensor([pos], device=device, dtype=torch.long),
-                    lora,
-                )
+                # advance attention one step
+                mask[:, :, :, pos] = True
+                logits, hidden = self._decode_one_tok(x_emb, mask, pos_id, lora)
                 pos += 1
+                pos_id[0] = pos
 
                 # --- y coordinate ---
                 y_logits = decode_coordinate(hidden, self.region)
                 if y_logits.dim() == 3:
-                    y_logits = y_logits.squeeze(1) # (B,1024)
+                    y_logits = y_logits.squeeze(1)
                 y_bin = y_logits.argmax(dim=-1).to(torch.float32)
-                y_center = y_bin / float(y_logits.size(-1)) # (B,)
-                y_input = y_center.to(dtype=y_logits.dtype).unsqueeze(-1) # (B,1)
-                y_emb = encode_coordinate(y_input, self.region).unsqueeze(1) # (B,1,C)
+                y_center = y_bin / float(y_logits.size(-1)) # (B,)
+                y_in = y_center.to(dtype=y_logits.dtype).unsqueeze(-1) # (B,1)
+                y_emb = encode_coordinate(y_in, self.region).unsqueeze(1)
 
-                mask[alive, :, pos] = 1
-                logits, hidden = self._decode_one_tok(
-                    y_emb,
-                    mask.unsqueeze(2), # (B,1,1,max_ctx) ✅
-                    torch.tensor([pos], device=device, dtype=torch.long),
-                    lora,
-                )
+                mask[:, :, :, pos] = True
+                logits, hidden = self._decode_one_tok(y_emb, mask, pos_id, lora)
                 pos += 1
+                pos_id[0] = pos
 
                 if include_size:
-                    # --- size (batched) ---
+                    # --- size ---
                     size_logits = decode_size(hidden, self.region)
                     w_logits, h_logits = size_logits[0].squeeze(1), size_logits[1].squeeze(1)
                     w_bin = w_logits.argmax(dim=-1).to(torch.float32)
                     h_bin = h_logits.argmax(dim=-1).to(torch.float32)
-                    # log-scale bin → actual size in [0,1]
+                    # bins -> size in [0,1] (inverse of log-scale mapping)
                     w = torch.pow(2.0, (w_bin / 1023.0) * 10.0 - 10.0)
                     h = torch.pow(2.0, (h_bin / 1023.0) * 10.0 - 10.0)
-                    size_input = torch.stack([w, h], dim=1).to(dtype=w_logits.dtype) # (B,2)
-                    size_emb = encode_size(size_input, self.region).unsqueeze(1) # (B,1,C)
+                    size_in = torch.stack([w, h], dim=1).to(dtype=w_logits.dtype) # (B,2)
+                    size_emb = encode_size(size_in, self.region).unsqueeze(1) # (B,1,C)
 
-                    # Commit boxes for alive rows
+                    # record boxes
                     for i in range(B):
-                        if not alive[i]:
-                            continue
-                        out[i].append({
-                            "x_min": (x_center[i] - w[i] / 2).item(),
-                            "y_min": (y_center[i] - h[i] / 2).item(),
-                            "x_max": (x_center[i] + w[i] / 2).item(),
-                            "y_max": (y_center[i] + h[i] / 2).item(),
-                        })
+                        if alive[i]:
+                            out[i].append({
+                                "x_min": (x_center[i] - w[i] / 2).item(),
+                                "y_min": (y_center[i] - h[i] / 2).item(),
+                                "x_max": (x_center[i] + w[i] / 2).item(),
+                                "y_max": (y_center[i] + h[i] / 2).item(),
+                            })
 
-                    mask[alive, :, pos] = 1
-                    logits, hidden = self._decode_one_tok(
-                        size_emb,
-                        mask.unsqueeze(2), # (B,1,1,max_ctx) ✅
-                        torch.tensor([pos], device=device, dtype=torch.long),
-                        lora,
-                    )
+                    mask[:, :, :, pos] = True
+                    logits, hidden = self._decode_one_tok(size_emb, mask, pos_id, lora)
                     pos += 1
-                    next_tok = logits.argmax(dim=-1).squeeze(-1) # (B,)
+                    pos_id[0] = pos
+                    next_tok = logits.argmax(dim=-1).squeeze(-1) # (B,)
                 else:
-                    # Points mode (no size)
                     for i in range(B):
                         if alive[i]:
                             out[i].append({"x": x_center[i].item(), "y": y_center[i].item()})
-                    mask[alive, :, pos] = 1
-                    logits, hidden = self._decode_one_tok(
-                        y_emb,
-                        mask.unsqueeze(2), # (B,1,1,max_ctx) ✅
-                        torch.tensor([pos], device=device, dtype=torch.long),
-                        lora,
-                    )
+
+                    mask[:, :, :, pos] = True
+                    logits, hidden = self._decode_one_tok(y_emb, mask, pos_id, lora)
                     pos += 1
-                    next_tok = logits.argmax(dim=-1).squeeze(-1)
+                    pos_id[0] = pos
+                    next_tok = logits.argmax(dim=-1).squeeze(-1) # (B,)
 
-                # Finish rows that emitted EOS or hit object cap
                 finished_now = (next_tok == eos_id) | (counts >= max_objects - 1)
                 counts = counts + (~finished_now & alive).to(counts.dtype)
                 alive &= ~finished_now
@@ -1026,8 +1013,6 @@ class MoondreamModel(nn.Module):
         return out
 
 
-
-
     def detect_multi(self, image, objects, settings=None):
         """
         Parallel multi-label detection.
@@ -1038,35 +1023,33 @@ class MoondreamModel(nn.Module):
         Returns:
            {"objects": {label: [box_dict, ...]}}
         """
-
-
         if self.config.tokenizer.templates["detect"] is None:
            raise NotImplementedError("Model does not support object detection.")
         settings = settings or {}
-
-        # Encode once; reuse caches
+
+        # Encode once; reuse caches for B rows
         image = self.encode_image(image, settings)
         B = len(objects)
         self._load_encoded_image_batched(image, B)
-
-        # Optional LoRA variant (same as detect())
+
+        # Optional LoRA variant
         lora = None
         if "variant" in settings:
            lora = variant_state_dict(settings["variant"], device=self.device)
-
-        # Prefill all prompts at once
-        last_hidden, next_token, pos_vec = self._prefill_prompt_batched(
+
+        # Prefill all prompts as a batch; shared next position
+        last_hidden, next_token, pos_end = self._prefill_prompt_batched(
            objects, image.pos, lora=lora, temperature=0.0, top_p=0.0
         )
-
+
         # Batched decode loop
         max_objects = settings.get("max_objects", 50)
         det_lists = self._generate_points_batched(
-            last_hidden, next_token, pos_vec,
+            last_hidden, next_token, pos_end,
            include_size=True, max_objects=max_objects, lora=lora
         )
-
-        # Map back to labels and add "label" tags
+
+        # Map back to labels and tag
         res = {}
         for lab, lst in zip(objects, det_lists):
            for d in lst:
@@ -1074,6 +1057,7 @@ class MoondreamModel(nn.Module):
            res[lab] = lst
         return {"objects": res}
 
+
     def _detect_gaze(
         self,
         image: EncodedImage,
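
Note on the size decoding in _generate_points_batched above: a bin index in [0, 1023] maps to a width/height in (0, 1] on a log2 scale, and each box is assembled as the decoded center plus/minus half that size. A small self-contained check of the mapping (illustrative values only, not model outputs):

    import torch

    # Inverse of the log-scale size binning: bin 0 -> 2**-10, bin 1023 -> 1.0.
    w_bin = torch.tensor([0.0, 512.0, 1023.0])
    w = torch.pow(2.0, (w_bin / 1023.0) * 10.0 - 10.0)
    print(w)  # approximately [0.0010, 0.0313, 1.0000]

    # Box corners are the decoded center offset by half the decoded size.
    x_center, y_center = 0.5, 0.4
    width, height = 0.2, 0.1
    box = {
        "x_min": x_center - width / 2,
        "y_min": y_center - height / 2,
        "x_max": x_center + width / 2,
        "y_max": y_center + height / 2,
    }
    print(box)  # {'x_min': 0.4, 'y_min': 0.35, 'x_max': 0.6, 'y_max': 0.45}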