HV-Khurdula committed
Commit 40df63a · verified · 1 Parent(s): b00c890

Update moondream.py


fix: corrupted kv cache
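
Context for the fix: in the batched decode path, every row must write its new K/V pair into its own cache slot; a shared or mis-shaped position index sends all rows to the same slot and corrupts earlier entries. A minimal sketch of the per-row scatter the reworked KVCache.update performs (shapes are illustrative, not from this file):

    import torch

    B, Hkv, T_max, d = 2, 4, 16, 8
    k_cache = torch.zeros(B, Hkv, T_max, d)
    k_step = torch.randn(B, Hkv, 1, d)    # one new token per row
    pos_ids = torch.tensor([3, 7])        # each row writes to its own slot

    for i in range(B):                    # same loop shape as KVCache.update below
        k_cache[i, :, int(pos_ids[i]), :] = k_step[i, :, 0, :]

    assert k_cache[0, :, 3].abs().sum() > 0 and k_cache[1, :, 7].abs().sum() > 0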

Files changed (1)
  1. moondream.py +150 -79
moondream.py CHANGED
@@ -64,33 +64,33 @@ class EncodedImage:
     pos: int
     caches: List[Tuple[torch.Tensor, torch.Tensor]]
 
+# -------------------- KVCache --------------------
 class KVCache(nn.Module):
     def __init__(self, n_heads, n_kv_heads, max_context, dim, device, dtype):
         super().__init__()
-        cache_shape = (1, n_kv_heads, max_context, dim // n_heads)
+        head_dim = dim // n_heads
+        cache_shape = (1, n_kv_heads, max_context, head_dim)
         self.register_buffer("k_cache", torch.zeros(*cache_shape, device=device, dtype=dtype))
         self.register_buffer("v_cache", torch.zeros(*cache_shape, device=device, dtype=dtype))
 
     def update(self, pos_ids, k, v):
         """
         Supports:
-          • Prefill: k,v = (B, n_kv_heads, q_len, d), pos_ids = (q_len,)
-          • 1-step:  k,v = (B, n_kv_heads, 1, d),     pos_ids = (B,1) or (B,)
-          • Legacy:  k,v = (B, n_kv_heads, 1, d),     pos_ids = scalar
-        Writes into self.k_cache/self.v_cache shaped (B, n_kv_heads, T_max, d).
+          • Prefill: k,v = (B, n_kv, q_len, d), pos_ids = (q_len,)
+          • 1-step : k,v = (B, n_kv, 1, d),     pos_ids = (B,) or scalar
         """
+        kout, vout = self.k_cache, self.v_cache
+
         if not torch.is_tensor(pos_ids):
            pos_ids = torch.tensor(pos_ids, device=k.device, dtype=torch.long)
         else:
-            pos_ids = pos_ids.to(device=k.device, dtype=torch.long)
-
+            pos_ids = pos_ids.to(k.device, dtype=torch.long)
+
         if k.dim() != 4 or v.dim() != 4:
             raise RuntimeError(f"KV update expects k,v 4D. Got k={tuple(k.shape)} v={tuple(v.shape)}")
-
         B, Hkv, q_len, D = k.shape
-        kout, vout = self.k_cache, self.v_cache
-
-        # Expand caches from B=1 lazily if needed
+
+        # Expand caches’ batch dim if needed
        if kout.size(0) != B:
            if kout.size(0) == 1:
                self.k_cache = kout.expand(B, -1, -1, -1).clone()
@@ -98,30 +98,28 @@ class KVCache(nn.Module):
                kout, vout = self.k_cache, self.v_cache
            else:
                raise RuntimeError(f"KV cache batch mismatch: cache.B={kout.size(0)} vs k.B={B}")
-
-        # Case A: prefill (same positions for every row)
+
+        # Prefill (vector of positions shared across the batch)
         if pos_ids.dim() == 1 and pos_ids.numel() == q_len:
             for i in range(B):
                 kout[i, :, pos_ids, :] = k[i]
                 vout[i, :, pos_ids, :] = v[i]
             return kout, vout
-
-        # Case B: single step with per-row position (B,) or (B,1)
-        if q_len == 1 and pos_ids.numel() == B:
-            pos_ids = pos_ids.view(B)
-            for i in range(B):
-                pi = int(pos_ids[i].item())
-                kout[i, :, pi, :] = k[i, :, 0, :]
-                vout[i, :, pi, :] = v[i, :, 0, :]
-            return kout, vout
-
-        # Case C: scalar position for everyone
-        if q_len == 1 and pos_ids.dim() == 0:
-            pi = int(pos_ids.item())
-            kout[:, :, pi, :] = k[:, :, 0, :]
-            vout[:, :, pi, :] = v[:, :, 0, :]
+
+        # 1-step with per-row positions
+        if q_len == 1 and pos_ids.numel() in {1, B}:
+            if pos_ids.numel() == 1:
+                pi = int(pos_ids.item())
+                kout[:, :, pi, :] = k[:, :, 0, :]
+                vout[:, :, pi, :] = v[:, :, 0, :]
+            else:
+                pos_ids = pos_ids.view(B)
+                for i in range(B):
+                    pi = int(pos_ids[i].item())
+                    kout[i, :, pi, :] = k[i, :, 0, :]
+                    vout[i, :, pi, :] = v[i, :, 0, :]
             return kout, vout
-
+
         raise RuntimeError(f"Unsupported KV update combo: k={tuple(k.shape)}, pos_ids={tuple(pos_ids.shape)}")
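
For reference, the call patterns the new update() accepts, as a minimal sketch (sizes are illustrative; KVCache is the class defined above):

    import torch

    cache = KVCache(n_heads=8, n_kv_heads=4, max_context=32, dim=64,
                    device="cpu", dtype=torch.float32)
    B, Hkv, d = 2, 4, 8

    # Prefill: one position vector shared by every row
    k = torch.randn(1, Hkv, 5, d); v = torch.randn_like(k)
    cache.update(torch.arange(5), k, v)

    # 1-step with per-row positions
    k1 = torch.randn(B, Hkv, 1, d); v1 = torch.randn_like(k1)
    cache.update(torch.tensor([5, 9]), k1, v1)   # row 0 -> slot 5, row 1 -> slot 9

    # 1-step with a scalar position shared by all rows
    cache.update(torch.tensor(10), k1, v1)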
 
@@ -129,6 +127,7 @@ class KVCache(nn.Module):
 
 
 
+
 class MoondreamModel(nn.Module):
 
     def __init__(
@@ -211,6 +210,7 @@ class MoondreamModel(nn.Module):
         blk.kv_cache.v_cache = torch.zeros(shape, device=device, dtype=dtype)
 
 
+
 
 
     def _setup_caches(self):
@@ -567,13 +567,13 @@ class MoondreamModel(nn.Module):
         image: Union[Image.Image, EncodedImage],
         settings: Optional[ImageEncodingSettings] = None,
     ) -> EncodedImage:
-        # Top of encode_image(), just after type checks:
-        self._setup_caches()  # re-create caches
-        for blk in self.text.blocks:  # force B=1 for encode
+        # Always start from single-row caches; avoids leftovers from batched runs
+        self._setup_caches()  # re-create caches (B=1)
+        for blk in self.text.blocks:  # make absolutely sure batch dim == 1
             if blk.kv_cache.k_cache.size(0) != 1:
                 blk.kv_cache.k_cache = blk.kv_cache.k_cache[:1].contiguous()
                 blk.kv_cache.v_cache = blk.kv_cache.v_cache[:1].contiguous()
-
+
         if isinstance(image, EncodedImage):
             return image
         if not isinstance(image, Image.Image):
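
_reset_kv_caches(1), called at the end of the batched detect path further down, is not part of this diff; a plausible shape for it, assuming it mirrors the slicing above (hypothetical sketch, not the file's actual method body):

    def _reset_kv_caches(self, batch_size: int = 1):
        # Hypothetical: trim every block's cache so its batch dim == batch_size.
        for blk in self.text.blocks:
            if blk.kv_cache.k_cache.size(0) != batch_size:
                blk.kv_cache.k_cache = blk.kv_cache.k_cache[:batch_size].contiguous()
                blk.kv_cache.v_cache = blk.kv_cache.v_cache[:batch_size].contiguous()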
@@ -602,6 +602,7 @@ class MoondreamModel(nn.Module):
         )
 
 
+
     def query(
         self,
         image: Optional[Union[Image.Image, EncodedImage]] = None,
@@ -893,6 +894,7 @@ class MoondreamModel(nn.Module):
         return {"points": objects}
 
 
+    # -------------------- batched helpers --------------------
     def _load_encoded_image_batched(self, encoded_image, batch_size: int):
         for b, (k, v) in zip(self.text.blocks, encoded_image.caches):
             T = k.size(2)
@@ -904,11 +906,7 @@ class MoondreamModel(nn.Module):
             b.kv_cache.k_cache[:, :, :T, :] = k.expand(batch_size, -1, -1, -1)
             b.kv_cache.v_cache[:, :, :T, :] = v.expand(batch_size, -1, -1, -1)
 
-
-
-
-    def _prefill_prompt_batched(self, labels, pos: int, lora=None,
-                                temperature: float = 0.0, top_p: float = 0.0):
+    def _prefill_prompt_batched(self, labels, pos: int, lora=None, temperature: float = 0.0, top_p: float = 0.0):
         tpl = self.config.tokenizer.templates["detect"]
         if tpl is None:
             raise NotImplementedError("Model does not support object detection (no detect template).")
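
_load_encoded_image_batched leans on expand(), which broadcasts the B=1 image cache as a zero-copy view; the copy only happens when that view is assigned into the preallocated cache slice. A standalone illustration (shapes illustrative):

    import torch

    k = torch.randn(1, 4, 10, 8)         # cached keys for one image (B=1)
    batched = k.expand(3, -1, -1, -1)    # view only: stride 0 on the batch dim
    assert batched.data_ptr() == k.data_ptr()

    dest = torch.zeros(3, 4, 32, 8)      # per-row KV cache (max_context = 32)
    dest[:, :, :10, :] = batched         # materializes one copy per batch row
    assert torch.equal(dest[0, :, :10], dest[2, :, :10])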
@@ -925,39 +923,43 @@ class MoondreamModel(nn.Module):
         for i, ids in enumerate(rows):
             prompt_ids[i, : ids.numel()] = ids
 
         prompt_emb = text_encoder(prompt_ids, self.text)  # (B,T,C)
         torch._dynamo.mark_dynamic(prompt_emb, 1)
 
+        # mask: (B,1,T,kv_len)
         base = self.attn_mask[:, :, pos:pos+T, :]       # (1,1,T,K)
         mask = base.expand(B, -1, -1, -1).contiguous()  # (B,1,T,K)
-        pos_ids = torch.arange(pos, pos + T, device=self.device, dtype=torch.long)  # (T,)
 
-        hidden_BTC = self._prefill(prompt_emb, mask, pos_ids, lora)  # (B,T,C)
+        # IMPORTANT: for prefill pass a 1-D vector of length T (matches upstream)
+        pos_ids = torch.arange(pos, pos + T, device=self.device, dtype=torch.long)  # (T,)
+        hidden_BTC = self._prefill(prompt_emb, mask, pos_ids, lora)  # (B,T,C)
         logits_BTV = lm_head(hidden_BTC, self.text)
 
-        idx = (torch.tensor(lens, device=self.device) - 1).clamp_min(0)  # (B,)
+        idx = (torch.tensor(lens, device=self.device) - 1).clamp_min(0)
         last_hidden = hidden_BTC[torch.arange(B, device=self.device), idx][:, None, :]  # (B,1,C)
         last_logits = logits_BTV[torch.arange(B, device=self.device), idx]  # (B,V)
 
         if temperature == 0.0:
             next_token = last_logits.argmax(dim=-1, keepdim=True)  # (B,1)
         else:
             probs = torch.softmax(last_logits / temperature, dim=-1)
             probs = self._apply_top_p(probs, top_p)
             next_token = torch.multinomial(probs, num_samples=1)  # (B,1)
 
+        # shared scalar end position
         return last_hidden, next_token, int(pos + T)
 
 
+
     def _generate_points_batched(
         self,
         hidden,        # (B,1,C)
-        next_token,    # (B,1) (not used with greedy coords; kept for API)
-        pos,           # int, next free KV slot
+        next_token,    # (B,1) (ignored in greedy)
+        pos,           # int
         include_size: bool = True,
         max_objects: int = 50,
         lora=None,
-        use_soft_argmax: bool = True,
+        use_soft_argmax: bool = False,  # default OFF to match upstream numerics
     ):
         B = hidden.size(0)
         device = self.device
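
Because prompts are right-padded to a common length T, each row's last real token sits at lens[i] - 1; the arange/idx fancy-indexing above gathers exactly those positions. A small numeric check (values illustrative):

    import torch

    B, T, C = 3, 5, 7
    hidden_BTC = torch.randn(B, T, C)
    lens = [5, 2, 4]                                # unpadded prompt lengths
    idx = (torch.tensor(lens) - 1).clamp_min(0)     # tensor([4, 1, 3])
    last_hidden = hidden_BTC[torch.arange(B), idx]  # (B, C)

    assert torch.equal(last_hidden[1], hidden_BTC[1, 1])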
@@ -965,40 +967,110 @@ class MoondreamModel(nn.Module):
         eos_id = self.config.tokenizer.eos_id
         max_ctx = self.config.text.max_context
 
-        # mask & position ids
+        # 4-D mask: (B,1,1,kv_len)
         mask = torch.zeros(B, 1, 1, max_ctx, device=device, dtype=torch.bool)
         if pos > 0:
             mask[:, :, :, :pos] = True
-        pos_ids = torch.full((B, 1), pos, device=device, dtype=torch.long)
+
+        # rotary & KV path are happiest with a 1-D scalar position vector (like upstream)
+        pos_id_vec = torch.tensor([pos], device=device, dtype=torch.long)  # (1,)
 
         alive = torch.ones(B, dtype=torch.bool, device=device)
         counts = torch.zeros(B, dtype=torch.int32, device=device)
 
-        def _argmax01(logits_2d):
+        def _center01(logits_2d):
             # logits_2d: (B, bins)
+            if logits_2d.dim() == 3:  # (B,1,bins) -> (B,bins)
+                logits_2d = logits_2d.squeeze(1)
+            bins = logits_2d.size(-1)
             if use_soft_argmax:
-                probs = torch.softmax(logits_2d, dim=-1)
-                bins = torch.arange(probs.size(-1), device=logits_2d.device, dtype=torch.float32)
-                val = (probs * bins).sum(dim=-1) / (probs.size(-1) - 1)
-                return val  # in [0,1]
-            idx = logits_2d.argmax(dim=-1).to(torch.float32)
-            return idx / float(logits_2d.size(-1) - 1)
+                p = torch.softmax(logits_2d, dim=-1)
+                idx = (p * torch.arange(bins, device=logits_2d.device, dtype=torch.float32)).sum(dim=-1)
+                return idx / float(bins)  # match upstream scale
+            else:
+                return logits_2d.argmax(dim=-1).to(torch.float32) / float(bins)
 
         with torch.inference_mode():
             while alive.any() and (counts < max_objects).any():
-                # x
-                x_logits = decode_coordinate(hidden, self.region)  # (B,1,1024) or (B,1024)
-                if x_logits.dim() == 3: x_logits = x_logits.squeeze(1)
-                x_center = _argmax01(x_logits)  # (B,)
-                x_emb = encode_coordinate(x_center.to(dtype=x_logits.dtype).unsqueeze(-1),
-                                          self.region).unsqueeze(1)  # (B,1,C)
+                # --- x ---
+                x_logits = decode_coordinate(hidden, self.region)  # (B,1,1024) or (B,1024)
+                x_center = _center01(x_logits)  # (B,) in [0,1]
+                x_emb = encode_coordinate(x_center.to(x_logits.dtype).unsqueeze(-1), self.region).unsqueeze(1)  # (B,1,C)
+
                 mask[alive, :, :, pos] = True
-                _, hidden = self._decode_one_tok(x_emb, mask, pos_ids, lora)
-                pos_ids[alive, 0] += 1; pos += 1
+                _, hidden = self._decode_one_tok(x_emb, mask, pos_id_vec, lora)
+                pos += 1
+                pos_id_vec[0] = pos
 
-                # y
+                # --- y ---
                 y_logits = decode_coordinate(hidden, self.region)
-                if y_logits.dim
+                y_center = _center01(y_logits)
+                y_emb = encode_coordinate(y_center.to(y_logits.dtype).unsqueeze(-1), self.region).unsqueeze(1)
+
+                mask[alive, :, :, pos] = True
+                _, hidden = self._decode_one_tok(y_emb, mask, pos_id_vec, lora)
+                pos += 1
+                pos_id_vec[0] = pos
+
+                if include_size:
+                    # --- size ---
+                    size_ret = decode_size(hidden, self.region)
+                    # Works for tuple or stacked tensor
+                    if isinstance(size_ret, (tuple, list)):
+                        w_logits, h_logits = size_ret
+                    else:
+                        # expected shapes: (B,2,1024) or (B,1,2,1024)
+                        if size_ret.dim() == 3:  # (B,2,1024)
+                            w_logits, h_logits = size_ret[:, 0], size_ret[:, 1]
+                        else:  # (B,1,2,1024)
+                            w_logits, h_logits = size_ret[:, 0, 0], size_ret[:, 0, 1]
+                    if w_logits.dim() == 3: w_logits = w_logits.squeeze(1)
+                    if h_logits.dim() == 3: h_logits = h_logits.squeeze(1)
+
+                    # bins -> size via the same inverse log2 scale as upstream
+                    w_bin = w_logits.argmax(dim=-1).to(torch.float32)
+                    h_bin = h_logits.argmax(dim=-1).to(torch.float32)
+                    w = torch.pow(2.0, (w_bin / 1023.0) * 10.0 - 10.0)
+                    h = torch.pow(2.0, (h_bin / 1023.0) * 10.0 - 10.0)
+
+                    size_emb = encode_size(torch.stack([w, h], dim=1).to(w_logits.dtype), self.region).unsqueeze(1)
+
+                    # record boxes (clamped)
+                    for i in range(B):
+                        if not alive[i]:
+                            continue
+                        xl = (x_center[i] - w[i] / 2).item()
+                        xr = (x_center[i] + w[i] / 2).item()
+                        yt = (y_center[i] - h[i] / 2).item()
+                        yb = (y_center[i] + h[i] / 2).item()
+                        out[i].append({
+                            "x_min": max(0.0, min(1.0, xl)),
+                            "y_min": max(0.0, min(1.0, yt)),
+                            "x_max": max(0.0, min(1.0, xr)),
+                            "y_max": max(0.0, min(1.0, yb)),
+                        })
+
+                    mask[alive, :, :, pos] = True
+                    logits, hidden = self._decode_one_tok(size_emb, mask, pos_id_vec, lora)
+                    pos += 1
+                    pos_id_vec[0] = pos
+                    next_tok = logits.argmax(dim=-1).squeeze(-1)  # (B,)
+                else:
+                    for i in range(B):
+                        if alive[i]:
+                            out[i].append({"x": x_center[i].item(), "y": y_center[i].item()})
+                    mask[alive, :, :, pos] = True
+                    logits, hidden = self._decode_one_tok(y_emb, mask, pos_id_vec, lora)
+                    pos += 1
+                    pos_id_vec[0] = pos
+                    next_tok = logits.argmax(dim=-1).squeeze(-1)
+
+                finished_now = (next_tok == eos_id) | (counts >= max_objects - 1)
+                counts = counts + ((~finished_now) & alive).to(counts.dtype)
+                alive &= ~finished_now
+
+        return out
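
Two quantization scales appear in this hunk: coordinate logits map a bin index to [0, 1) by dividing by the bin count, while sizes invert a log2 encoding, w = 2**((bin / 1023) * 10 - 10), so bin 0 decodes to 2**-10 (≈ 0.001 of the image extent) and bin 1023 to 2**0 = 1.0 (the full extent). A quick check (bin values illustrative):

    import torch

    coord_bin = torch.tensor([512.0])
    print(coord_bin / 1024)                             # tensor([0.5000]) -> image center

    w_bin = torch.tensor([0.0, 511.5, 1023.0])
    w = torch.pow(2.0, (w_bin / 1023.0) * 10.0 - 10.0)
    print(w)                                            # ~[0.00098, 0.03125, 1.0]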
 
@@ -1008,38 +1080,37 @@ class MoondreamModel(nn.Module):
             raise NotImplementedError("Model does not support object detection.")
         settings = settings or {}
 
-        # Encode once; reuse caches for B rows
-        image = self.encode_image(image, settings)
+        enc = self.encode_image(image, settings)
         B = len(objects)
-        self._load_encoded_image_batched(image, B)
+        self._load_encoded_image_batched(enc, B)
 
-        lora = None
-        if "variant" in settings:
-            lora = variant_state_dict(settings["variant"], device=self.device)
+        lora = variant_state_dict(settings["variant"], device=self.device) if "variant" in settings else None
 
         last_hidden, next_token, pos_end = self._prefill_prompt_batched(
-            objects, image.pos, lora=lora, temperature=0.0, top_p=0.0
+            objects, enc.pos, lora=lora, temperature=0.0, top_p=0.0
         )
 
-        max_objects = settings.get("max_objects", 50)
         det_lists = self._generate_points_batched(
             last_hidden, next_token, pos_end,
-            include_size=True, max_objects=max_objects, lora=lora
+            include_size=True,
+            max_objects=settings.get("max_objects", 50),
+            lora=lora,
         )
 
-        # Map back to labels and tag
         res = {}
         for lab, lst in zip(objects, det_lists):
             for d in lst:
                 d["label"] = lab
             res[lab] = lst
 
-        # IMPORTANT: restore caches to B=1 so future calls are safe
+        # make subsequent single-image calls stable
         self._reset_kv_caches(1)
         return {"objects": res}
 
 
 
+
+
     def _detect_gaze(
         self,
         image: EncodedImage,
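
The public entry point wrapping this batched path sits above the first line of this hunk, so its exact name and signature are an assumption here; if it is exposed as detect(image, objects, settings), a call would look roughly like:

    from PIL import Image

    # Hypothetical usage -- method name/signature not shown in this diff.
    result = model.detect(Image.open("photo.jpg"), ["cat", "dog"],
                          settings={"max_objects": 10})
    for label, boxes in result["objects"].items():
        for box in boxes:
            print(label, box["x_min"], box["y_min"], box["x_max"], box["y_max"])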
 