HV-Khurdula committed · verified
Commit e80a71a · 1 parent: 40df63a

Update moondream.py

Files changed (1): moondream.py (+149 −129)
moondream.py CHANGED
@@ -76,48 +76,54 @@ class KVCache(nn.Module):
     def update(self, pos_ids, k, v):
         """
         Supports:
-          • Prefill: k,v = (B, n_kv, q_len, d), pos_ids = (q_len,)
-            1-step : k,v = (B, n_kv, 1, d),     pos_ids = (B,) or scalar
+          • Prefill: k,v = (B, n_kv_heads, q_len, d), pos_ids = (q_len,)
+            Step:    k,v = (B, n_kv_heads, 1, d),     pos_ids = (B,) or (B,1)
+          • Legacy:  k,v = (1, n_kv_heads, q_len, d), pos_ids = scalar
+        Writes into self.k_cache/self.v_cache shaped (B, n_kv_heads, T_max, d).
         """
-        kout, vout = self.k_cache, self.v_cache
+        kout, vout = self.kv_cache if hasattr(self, "kv_cache") else (self.k_cache, self.v_cache)
 
+        # normalize pos_ids
         if not torch.is_tensor(pos_ids):
             pos_ids = torch.tensor(pos_ids, device=k.device, dtype=torch.long)
         else:
-            pos_ids = pos_ids.to(k.device, dtype=torch.long)
+            pos_ids = pos_ids.to(device=k.device, dtype=torch.long)
 
         if k.dim() != 4 or v.dim() != 4:
             raise RuntimeError(f"KV update expects k,v 4D. Got k={tuple(k.shape)} v={tuple(v.shape)}")
+
         B, Hkv, q_len, D = k.shape
 
-        # Expand caches’ batch dim if needed
-        if kout.size(0) != B:
-            if kout.size(0) == 1:
-                self.k_cache = kout.expand(B, -1, -1, -1).clone()
-                self.v_cache = vout.expand(B, -1, -1, -1).clone()
+        # match cache batch to B (expand-from-1 allowed)
+        if self.k_cache.size(0) != B:
+            if self.k_cache.size(0) == 1:
+                self.k_cache = self.k_cache.expand(B, -1, -1, -1).clone()
+                self.v_cache = self.v_cache.expand(B, -1, -1, -1).clone()
                 kout, vout = self.k_cache, self.v_cache
             else:
-                raise RuntimeError(f"KV cache batch mismatch: cache.B={kout.size(0)} vs k.B={B}")
+                raise RuntimeError(f"KV cache batch mismatch: cache.B={self.k_cache.size(0)} vs k.B={B}")
 
-        # Prefill (vector of positions shared across the batch)
+        # Case A: prefill — vector of length q_len (same for all rows)
         if pos_ids.dim() == 1 and pos_ids.numel() == q_len:
             for i in range(B):
                 kout[i, :, pos_ids, :] = k[i]
                 vout[i, :, pos_ids, :] = v[i]
             return kout, vout
 
-        # 1-step with per-row positions
-        if q_len == 1 and pos_ids.numel() in {1, B}:
-            if pos_ids.numel() == 1:
-                pi = int(pos_ids.item())
-                kout[:, :, pi, :] = k[:, :, 0, :]
-                vout[:, :, pi, :] = v[:, :, 0, :]
-            else:
-                pos_ids = pos_ids.view(B)
-                for i in range(B):
-                    pi = int(pos_ids[i].item())
-                    kout[i, :, pi, :] = k[i, :, 0, :]
-                    vout[i, :, pi, :] = v[i, :, 0, :]
+        # Case B: single step — q_len==1 with per-row positions
+        if q_len == 1 and pos_ids.numel() == B:
+            pos_ids = pos_ids.view(B)
+            for i in range(B):
+                pi = int(pos_ids[i].item())
+                kout[i, :, pi, :] = k[i, :, 0, :]
+                vout[i, :, pi, :] = v[i, :, 0, :]
+            return kout, vout
+
+        # Case C: scalar & q_len==1
+        if pos_ids.dim() == 0 and q_len == 1:
+            pi = int(pos_ids.item())
+            kout[:, :, pi, :] = k[:, :, 0, :]
+            vout[:, :, pi, :] = v[:, :, 0, :]
             return kout, vout
 
         raise RuntimeError(f"Unsupported KV update combo: k={tuple(k.shape)}, pos_ids={tuple(pos_ids.shape)}")
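For reference, the three accepted `(k, v, pos_ids)` combinations can be sanity-checked with plain tensors; this is a minimal sketch of the write patterns above (toy sizes, not the model's real dimensions):

```python
import torch

# Toy cache: B=2 rows, 4 kv-heads, 16 slots, head dim 8.
B, H, T_max, D = 2, 4, 16, 8
k_cache = torch.zeros(B, H, T_max, D)

# Case A (prefill): a shared position vector of length q_len.
q_len = 5
k_new = torch.randn(B, H, q_len, D)
pos = torch.arange(q_len)                      # (q_len,)
for i in range(B):
    k_cache[i, :, pos, :] = k_new[i]

# Case B (step): one token per row, each row at its own position.
k_step = torch.randn(B, H, 1, D)
row_pos = torch.tensor([5, 7])                 # (B,)
for i in range(B):
    k_cache[i, :, int(row_pos[i]), :] = k_step[i, :, 0, :]

# Case C (legacy scalar): the same position for every row.
pi = 9
k_cache[:, :, pi, :] = k_step[:, :, 0, :]
```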
@@ -211,6 +217,7 @@ class MoondreamModel(nn.Module):
 
 
 
+
 
 
     def _setup_caches(self):
@@ -562,47 +569,46 @@ class MoondreamModel(nn.Module):
 
         return generator(next_token, pos)
 
-    def encode_image(
-        self,
-        image: Union[Image.Image, EncodedImage],
-        settings: Optional[ImageEncodingSettings] = None,
-    ) -> EncodedImage:
-        # Always start from single-row caches; avoids leftovers from batched runs
-        self._setup_caches()                     # re-create caches (B=1)
-        for blk in self.text.blocks:             # make absolutely sure batch dim == 1
-            if blk.kv_cache.k_cache.size(0) != 1:
-                blk.kv_cache.k_cache = blk.kv_cache.k_cache[:1].contiguous()
-                blk.kv_cache.v_cache = blk.kv_cache.v_cache[:1].contiguous()
+    def encode_image(self, image, settings=None) -> EncodedImage:
+        # always start from B=1 to avoid leftovers from batched runs
+        self._setup_caches()                     # recreates caches with B=1
 
         if isinstance(image, EncodedImage):
             return image
         if not isinstance(image, Image.Image):
             raise ValueError("image must be a PIL Image or EncodedImage")
 
-        lora = (variant_state_dict(settings["variant"], device=self.device)
-                if settings is not None and "variant" in settings else None)
+        # hard-trim to B=1 if external code changed it
+        for blk in self.text.blocks:
+            if blk.kv_cache.k_cache.size(0) != 1:
+                blk.kv_cache.k_cache = blk.kv_cache.k_cache[:1].contiguous()
+                blk.kv_cache.v_cache = blk.kv_cache.v_cache[:1].contiguous()
+
+        lora = variant_state_dict(settings["variant"], device=self.device) \
+            if settings and "variant" in settings else None
 
         with torch.inference_mode():
-            img_emb = self._run_vision_encoder(image)
-            bos_emb = text_encoder(torch.tensor([[self.config.tokenizer.bos_id]], device=self.device), self.text)
-            inputs_embeds = torch.cat([bos_emb, img_emb[None]], dim=1)
-            mask = self.attn_mask[:, :, :inputs_embeds.size(1), :]
-            pos_ids = torch.arange(inputs_embeds.size(1), dtype=torch.long, device=self.device)
+            img_emb = self._run_vision_encoder(image)                    # (T_img,C)
+            bos_emb = text_encoder(torch.tensor([[self.config.tokenizer.bos_id]], device=self.device), self.text)  # (1,1,C)
+            inputs_embeds = torch.cat([bos_emb, img_emb[None]], dim=1)   # (1,T0,C)
+
+            mask = self.attn_mask[:, :, :inputs_embeds.size(1), :]       # (1,1,T0,K)
+            pos_ids = torch.arange(inputs_embeds.size(1), device=self.device, dtype=torch.long)  # (T0,)
             self._prefill(inputs_embeds, mask, pos_ids, lora)
 
+        T0 = inputs_embeds.size(1)
         return EncodedImage(
-            pos=inputs_embeds.size(1),
+            pos=T0,
             caches=[
-                (
-                    b.kv_cache.k_cache[:, :, :inputs_embeds.size(1), :].clone(),
-                    b.kv_cache.v_cache[:, :, :inputs_embeds.size(1), :].clone(),
-                )
+                (b.kv_cache.k_cache[:, :, :T0, :].clone(),
+                 b.kv_cache.v_cache[:, :, :T0, :].clone())
                 for b in self.text.blocks
             ],
         )
 
 
 
+
     def query(
         self,
         image: Optional[Union[Image.Image, EncodedImage]] = None,
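The snapshot returned here is what makes image reuse cheap: the vision encoder runs once, and later calls can start from the cached (k, v) prefix. A hedged usage sketch (assumes a constructed `MoondreamModel` as elsewhere in this file; the image path is illustrative):

```python
from PIL import Image

model = ...  # a loaded MoondreamModel; construction is outside this diff
img = Image.open("example.jpg")   # illustrative path

enc = model.encode_image(img)     # B=1 prefill of BOS + vision tokens
# enc.pos is the prefix length T0; enc.caches holds each block's
# (k, v) slice up to T0, cloned so later decoding can't clobber it.

same = model.encode_image(enc)    # an EncodedImage passes through unchanged
assert same is enc
```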
@@ -893,8 +899,36 @@ class MoondreamModel(nn.Module):
 
         return {"points": objects}
 
+    def _norm_size_logits(self, size_ret, B: int):
+        """
+        Accepts any of:
+          • (w_logits, h_logits)
+          • Tensor (B,2,C) or (B,1,2,C) or (1,2,C) or (2,C) (B==1)
+        Returns (w_logits, h_logits) each shaped (B, C).
+        """
+        if isinstance(size_ret, (tuple, list)):
+            w_logits, h_logits = size_ret
+        else:
+            t = size_ret
+            # squeeze all singleton dims except batch & vocab
+            while t.dim() > 3:
+                t = t.squeeze(1)
+            if t.dim() == 3:                        # (B,2,C)
+                w_logits, h_logits = t[:, 0, :], t[:, 1, :]
+            elif t.dim() == 2:
+                if t.size(0) == 2 and B == 1:       # (2,C) with B==1
+                    w_logits, h_logits = t[0].unsqueeze(0), t[1].unsqueeze(0)
+                else:                               # (B,2C) fallback
+                    C2 = t.size(1); C = C2 // 2
+                    w_logits, h_logits = t[:, :C], t[:, C:]
+            else:
+                raise RuntimeError(f"Unexpected decode_size shape {tuple(t.shape)}")
+        # final squeeze if needed
+        if w_logits.dim() == 3: w_logits = w_logits.squeeze(1)
+        if h_logits.dim() == 3: h_logits = h_logits.squeeze(1)
+        return w_logits.contiguous(), h_logits.contiguous()
+
 
-    # -------------------- batched helpers --------------------
     def _load_encoded_image_batched(self, encoded_image, batch_size: int):
         for b, (k, v) in zip(self.text.blocks, encoded_image.caches):
             T = k.size(2)
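`_norm_size_logits` (added above) exists because `decode_size` may hand back a tuple or a stacked tensor with assorted singleton dims. A standalone sketch that mirrors its squeeze-and-split logic on the shapes the docstring lists:

```python
import torch

B, C = 3, 1024
cases = [
    (torch.randn(B, C), torch.randn(B, C)),  # already a (w, h) tuple
    torch.randn(B, 2, C),                    # stacked (B,2,C)
    torch.randn(B, 1, 2, C),                 # extra singleton dim (B,1,2,C)
]
for size_ret in cases:
    if isinstance(size_ret, tuple):
        w, h = size_ret
    else:
        t = size_ret
        while t.dim() > 3:        # drop singleton dims, as in the helper
            t = t.squeeze(1)
        w, h = t[:, 0, :], t[:, 1, :]
    assert w.shape == (B, C) and h.shape == (B, C)
```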
@@ -906,60 +940,62 @@
             b.kv_cache.k_cache[:, :, :T, :] = k.expand(batch_size, -1, -1, -1)
             b.kv_cache.v_cache[:, :, :T, :] = v.expand(batch_size, -1, -1, -1)
 
-    def _prefill_prompt_batched(self, labels, pos: int, lora=None, temperature: float = 0.0, top_p: float = 0.0):
+
+    def _prefill_prompt_batched(self, labels, pos: int, lora=None,
+                                temperature: float = 0.0, top_p: float = 0.0):
         tpl = self.config.tokenizer.templates["detect"]
         if tpl is None:
-            raise NotImplementedError("Model does not support object detection (no detect template).")
+            raise NotImplementedError("Model does not support object detection.")
 
         rows, lens = [], []
         for lab in labels:
             ids = tpl["prefix"] + self.tokenizer.encode(" " + lab).ids + tpl["suffix"]
             t = torch.tensor(ids, device=self.device, dtype=torch.long)
             rows.append(t); lens.append(t.numel())
-        B = len(rows); T = max(lens)
+        B, T = len(rows), max(lens)
         eos = self.config.tokenizer.eos_id
 
         prompt_ids = torch.full((B, T), eos, device=self.device, dtype=torch.long)
         for i, ids in enumerate(rows):
             prompt_ids[i, : ids.numel()] = ids
 
         prompt_emb = text_encoder(prompt_ids, self.text)            # (B,T,C)
-        torch._dynamo.mark_dynamic(prompt_emb, 1)
+        torch._dynamo.mark_dynamic(prompt_emb, 1)                   # allow variable T
 
-        # mask: (B,1,T,kv_len)
         base = self.attn_mask[:, :, pos:pos+T, :]                   # (1,1,T,K)
         mask = base.expand(B, -1, -1, -1).contiguous()              # (B,1,T,K)
 
-        # IMPORTANT: for prefill pass a 1-D vector of length T (matches upstream)
         pos_ids = torch.arange(pos, pos + T, device=self.device, dtype=torch.long)  # (T,)
         hidden_BTC = self._prefill(prompt_emb, mask, pos_ids, lora) # (B,T,C)
-        logits_BTV = lm_head(hidden_BTC, self.text)
+        logits_BTV = lm_head(hidden_BTC, self.text)                 # (B,T,V)
 
-        idx = (torch.tensor(lens, device=self.device) - 1).clamp_min(0)
+        idx = (torch.tensor(lens, device=self.device) - 1).clamp_min(0)  # (B,)
         last_hidden = hidden_BTC[torch.arange(B, device=self.device), idx][:, None, :]  # (B,1,C)
         last_logits = logits_BTV[torch.arange(B, device=self.device), idx]              # (B,V)
 
         if temperature == 0.0:
             next_token = last_logits.argmax(dim=-1, keepdim=True)   # (B,1)
         else:
             probs = torch.softmax(last_logits / temperature, dim=-1)
             probs = self._apply_top_p(probs, top_p)
             next_token = torch.multinomial(probs, num_samples=1)    # (B,1)
 
-        # shared scalar end position
-        return last_hidden, next_token, int(pos + T)
+        # shared next-free position in cache (safe upper bound)
+        pos_end = int(pos + T)
+        return last_hidden, next_token, pos_end
+
 
 
 
     def _generate_points_batched(
         self,
         hidden,             # (B,1,C)
-        next_token,         # (B,1) (ignored in greedy)
+        next_token,         # (B,1) (unused in greedy)
         pos,                # int
         include_size: bool = True,
         max_objects: int = 50,
         lora=None,
-        use_soft_argmax: bool = False,  # default OFF to match upstream numerics
+        use_soft_argmax: bool = True,   # reduces jitter
     ):
         B = hidden.size(0)
         device = self.device
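The core trick in `_prefill_prompt_batched` is EOS right-padding plus a gather at each row's last real token, so one prefill serves prompts of different lengths. An isolated sketch of that indexing:

```python
import torch

B, T, V = 3, 7, 11
lens = torch.tensor([7, 4, 6])          # true prompt lengths per row
logits_BTV = torch.randn(B, T, V)       # stand-in for lm_head output

# Read each row at its last real token, not at the padded tail.
idx = (lens - 1).clamp_min(0)                          # (B,)
last_logits = logits_BTV[torch.arange(B), idx]         # (B, V)
next_token = last_logits.argmax(dim=-1, keepdim=True)  # (B, 1), greedy
assert next_token.shape == (B, 1)
```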
@@ -967,78 +1003,67 @@ class MoondreamModel(nn.Module):
         eos_id = self.config.tokenizer.eos_id
         max_ctx = self.config.text.max_context
 
-        # 4-D mask: (B,1,1,kv_len)
+        # 4-D mask: (B, 1, q_len=1, kv_len)
         mask = torch.zeros(B, 1, 1, max_ctx, device=device, dtype=torch.bool)
-        if pos > 0:
-            mask[:, :, :, :pos] = True
+        if int(pos) > 0:
+            mask[:, :, :, :int(pos)] = True
+        pos_ids = torch.full((B, 1), int(pos), device=device, dtype=torch.long)
 
-        # rotary & KV path are happiest with a 1-D scalar position vector (like upstream)
-        pos_id_vec = torch.tensor([pos], device=device, dtype=torch.long)  # (1,)
+        # helper: (B, bins) -> (B,) in [0,1]
+        def _argmax01(logits):
+            if use_soft_argmax:
+                probs = torch.softmax(logits, dim=-1)
+                bins = torch.arange(probs.size(-1), device=logits.device, dtype=torch.float32)
+                return (probs * bins).sum(dim=-1) / float(probs.size(-1) - 1)
+            idx = logits.argmax(dim=-1).to(torch.float32)
+            return idx / float(logits.size(-1) - 1)
 
         alive = torch.ones(B, dtype=torch.bool, device=device)
         counts = torch.zeros(B, dtype=torch.int32, device=device)
 
-        def _center01(logits_2d):
-            # logits_2d: (B, bins)
-            if logits_2d.dim() == 3:  # (B,1,bins) -> (B,bins)
-                logits_2d = logits_2d.squeeze(1)
-            bins = logits_2d.size(-1)
-            if use_soft_argmax:
-                p = torch.softmax(logits_2d, dim=-1)
-                idx = (p * torch.arange(bins, device=logits_2d.device, dtype=torch.float32)).sum(dim=-1)
-                return idx / float(bins)  # match upstream scale
-            else:
-                return logits_2d.argmax(dim=-1).to(torch.float32) / float(bins)
-
         with torch.inference_mode():
             while alive.any() and (counts < max_objects).any():
-                # --- x ---
+                # ---- x
                 x_logits = decode_coordinate(hidden, self.region)  # (B,1,1024) or (B,1024)
-                x_center = _center01(x_logits)                     # (B,) in [0,1]
-                x_emb = encode_coordinate(x_center.to(x_logits.dtype).unsqueeze(-1), self.region).unsqueeze(1)  # (B,1,C)
+                if x_logits.dim() == 3: x_logits = x_logits.squeeze(1)
+                x_center = _argmax01(x_logits)                     # (B,)
+                x_emb = encode_coordinate(x_center.to(dtype=x_logits.dtype).unsqueeze(-1), self.region).unsqueeze(1)
 
-                mask[alive, :, :, pos] = True
-                _, hidden = self._decode_one_tok(x_emb, mask, pos_id_vec, lora)
-                pos += 1
-                pos_id_vec[0] = pos
+                mask[alive, :, :, pos_ids[0,0]] = True
+                logits, hidden = self._decode_one_tok(x_emb, mask, pos_ids, lora)
+                pos_ids[alive, 0] += 1
 
-                # --- y ---
+                # ---- y
                 y_logits = decode_coordinate(hidden, self.region)
-                y_center = _center01(y_logits)
-                y_emb = encode_coordinate(y_center.to(y_logits.dtype).unsqueeze(-1), self.region).unsqueeze(1)
+                if y_logits.dim() == 3: y_logits = y_logits.squeeze(1)
+                y_center = _argmax01(y_logits)
+                y_emb = encode_coordinate(y_center.to(dtype=y_logits.dtype).unsqueeze(-1), self.region).unsqueeze(1)
 
-                mask[alive, :, :, pos] = True
-                _, hidden = self._decode_one_tok(y_emb, mask, pos_id_vec, lora)
-                pos += 1
-                pos_id_vec[0] = pos
+                mask[alive, :, :, pos_ids[0,0]] = True
+                logits, hidden = self._decode_one_tok(y_emb, mask, pos_ids, lora)
+                pos_ids[alive, 0] += 1
 
                 if include_size:
-                    # --- size ---
                     size_ret = decode_size(hidden, self.region)
-                    # Works for tuple or stacked tensor
-                    if isinstance(size_ret, (tuple, list)):
-                        w_logits, h_logits = size_ret
+                    w_logits, h_logits = self._norm_size_logits(size_ret, B)
+
+                    if use_soft_argmax:
+                        bins = torch.arange(w_logits.size(-1), device=device, dtype=torch.float32)
+                        w_bin = (torch.softmax(w_logits, dim=-1) * bins).sum(dim=-1)
+                        h_bin = (torch.softmax(h_logits, dim=-1) * bins).sum(dim=-1)
                     else:
-                        # expected shapes: (B,2,1024) or (B,1,2,1024)
-                        if size_ret.dim() == 3:   # (B,2,1024)
-                            w_logits, h_logits = size_ret[:, 0], size_ret[:, 1]
-                        else:                     # (B,1,2,1024)
-                            w_logits, h_logits = size_ret[:, 0, 0], size_ret[:, 0, 1]
-                    if w_logits.dim() == 3: w_logits = w_logits.squeeze(1)
-                    if h_logits.dim() == 3: h_logits = h_logits.squeeze(1)
-
-                    # bins -> size via the same inverse log2 scale as upstream
-                    w_bin = w_logits.argmax(dim=-1).to(torch.float32)
-                    h_bin = h_logits.argmax(dim=-1).to(torch.float32)
+                        w_bin = w_logits.argmax(dim=-1).to(torch.float32)
+                        h_bin = h_logits.argmax(dim=-1).to(torch.float32)
+
+                    # inverse log scale used by md2
                     w = torch.pow(2.0, (w_bin / 1023.0) * 10.0 - 10.0)
                     h = torch.pow(2.0, (h_bin / 1023.0) * 10.0 - 10.0)
 
-                    size_emb = encode_size(torch.stack([w, h], dim=1).to(w_logits.dtype), self.region).unsqueeze(1)
+                    size_emb = encode_size(torch.stack([w, h], dim=1).to(dtype=w_logits.dtype), self.region).unsqueeze(1)
 
-                    # record boxes (clamped)
+                    # write boxes only for alive rows
                    for i in range(B):
-                        if not alive[i]:
-                            continue
+                        if not alive[i]: continue
                        xl = (x_center[i] - w[i] / 2).item()
                        xr = (x_center[i] + w[i] / 2).item()
                        yt = (y_center[i] - h[i] / 2).item()
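Two numeric details in this hunk are easy to miss: `_argmax01` now divides by `bins - 1` (the removed `_center01` divided by `bins`, so 1.0 was never reachable), and size bins invert a log2 scale spanning 2^-10 up to 2^0 of the image. A worked sketch:

```python
import torch

logits = torch.randn(2, 1024)

# Hard argmax: bin index -> [0, 1]; bin 0 -> 0.0, bin 1023 -> 1.0.
hard = logits.argmax(dim=-1).float() / (logits.size(-1) - 1)

# Soft argmax: probability-weighted bin index, smoother across steps.
bins = torch.arange(logits.size(-1), dtype=torch.float32)
soft = (torch.softmax(logits, dim=-1) * bins).sum(dim=-1) / (logits.size(-1) - 1)

# Size decode: bin 0 -> 2**-10 (about 0.001), bin 1023 -> 2**0 = 1.0.
w_bin = torch.tensor([0.0, 511.5, 1023.0])
w = torch.pow(2.0, (w_bin / 1023.0) * 10.0 - 10.0)
```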
@@ -1050,19 +1075,17 @@
                             "y_max": max(0.0, min(1.0, yb)),
                         })
 
-                    mask[alive, :, :, pos] = True
-                    logits, hidden = self._decode_one_tok(size_emb, mask, pos_id_vec, lora)
-                    pos += 1
-                    pos_id_vec[0] = pos
+                    mask[alive, :, :, pos_ids[0,0]] = True
+                    logits, hidden = self._decode_one_tok(size_emb, mask, pos_ids, lora)
+                    pos_ids[alive, 0] += 1
                     next_tok = logits.argmax(dim=-1).squeeze(-1)  # (B,)
                 else:
                     for i in range(B):
                         if alive[i]:
                             out[i].append({"x": x_center[i].item(), "y": y_center[i].item()})
-                    mask[alive, :, :, pos] = True
-                    logits, hidden = self._decode_one_tok(y_emb, mask, pos_id_vec, lora)
-                    pos += 1
-                    pos_id_vec[0] = pos
+                    mask[alive, :, :, pos_ids[0,0]] = True
+                    logits, hidden = self._decode_one_tok(y_emb, mask, pos_ids, lora)
+                    pos_ids[alive, 0] += 1
                     next_tok = logits.argmax(dim=-1).squeeze(-1)
 
                 finished_now = (next_tok == eos_id) | (counts >= max_objects - 1)
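Each row stops independently: `finished_now` combines an EOS hit with the per-row object budget. A minimal sketch of that bookkeeping (the exact update order of `alive` and `counts` falls in lines this diff does not show, so that part is an assumption):

```python
import torch

B, eos_id, max_objects = 4, 50256, 50
alive = torch.ones(B, dtype=torch.bool)
counts = torch.zeros(B, dtype=torch.int32)

next_tok = torch.tensor([eos_id, 7, eos_id, 9])   # pretend decode step
finished_now = (next_tok == eos_id) | (counts >= max_objects - 1)

counts += alive.int()       # assumed: live rows just emitted one object
alive &= ~finished_now      # assumed: rows that hit EOS leave the loop
```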
@@ -1071,10 +1094,6 @@
 
         return out
 
-
-
-
-
     def detect_multi(self, image, objects, settings=None):
         if self.config.tokenizer.templates["detect"] is None:
             raise NotImplementedError("Model does not support object detection.")
@@ -1103,7 +1122,7 @@
             d["label"] = lab
         res[lab] = lst
 
-        # make subsequent single-image calls stable
+        # restore B=1 so the next encode_image() starts clean
         self._reset_kv_caches(1)
         return {"objects": res}
 
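Putting the batched pieces together, the intended call shape for this commit looks like the sketch below (hedged: model loading and the label list are illustrative; only `encode_image`, `detect_multi`, and the `{"objects": ...}` return come from this file):

```python
from PIL import Image

model = ...  # a loaded MoondreamModel
enc = model.encode_image(Image.open("street.jpg"))   # one B=1 prefill

# One batched pass over several labels; caches are reset to B=1 after,
# so later single-image calls start clean.
result = model.detect_multi(enc, ["car", "person", "traffic light"])
for label, boxes in result["objects"].items():
    print(label, len(boxes))
```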
@@ -1111,6 +1130,7 @@
 
 
 
+
     def _detect_gaze(
         self,
         image: EncodedImage,
 