HV-Khurdula committed on
Commit 4d9e33f · verified · 1 Parent(s): e80a71a

Update moondream.py


feat: update KV caching and add batching support, from encoding to prefill to decode.
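Annotation: a minimal usage sketch of the batched path this commit describes. `model`, the image path, and the label list are placeholders (loading the model is out of scope here); only `detect_multi(image, objects, settings=None)` and its `{"objects": ...}` return shape are taken from the diff below.

from PIL import Image

img = Image.open("street.jpg")                       # placeholder path
result = model.detect_multi(img, ["car", "person"])  # one batched pass over all labels
for label, boxes in result["objects"].items():
    print(label, len(boxes))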

Files changed (1)
  1. moondream.py +74 -91
moondream.py CHANGED
@@ -64,26 +64,18 @@ class EncodedImage:
     pos: int
     caches: List[Tuple[torch.Tensor, torch.Tensor]]
 
-# -------------------- KVCache --------------------
 class KVCache(nn.Module):
     def __init__(self, n_heads, n_kv_heads, max_context, dim, device, dtype):
         super().__init__()
         head_dim = dim // n_heads
-        cache_shape = (1, n_kv_heads, max_context, head_dim)
-        self.register_buffer("k_cache", torch.zeros(*cache_shape, device=device, dtype=dtype))
-        self.register_buffer("v_cache", torch.zeros(*cache_shape, device=device, dtype=dtype))
+        shape = (1, n_kv_heads, max_context, head_dim)
+        self.register_buffer("k_cache", torch.zeros(*shape, device=device, dtype=dtype))
+        self.register_buffer("v_cache", torch.zeros(*shape, device=device, dtype=dtype))
 
     def update(self, pos_ids, k, v):
-        """
-        Supports:
-          • Prefill: k,v = (B, n_kv_heads, q_len, d), pos_ids = (q_len,)
-          • Step: k,v = (B, n_kv_heads, 1, d), pos_ids = (B,) or (B,1)
-          • Legacy: k,v = (1, n_kv_heads, q_len, d), pos_ids = scalar
-        Writes into self.k_cache/self.v_cache shaped (B, n_kv_heads, T_max, d).
-        """
-        kout, vout = self.kv_cache if hasattr(self, "kv_cache") else (self.k_cache, self.v_cache)
+        # k,v: (B, n_kv_heads, q_len, head_dim)
+        kout, vout = self.k_cache, self.v_cache
 
-        # normalize pos_ids
         if not torch.is_tensor(pos_ids):
             pos_ids = torch.tensor(pos_ids, device=k.device, dtype=torch.long)
         else:
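Annotation: the caches above are registered as buffers rather than parameters, so they follow the module through .to()/dtype casts but stay out of parameters() and optimizers. A standalone sketch with toy sizes, not the repo's class:

import torch
import torch.nn as nn

class TinyCache(nn.Module):
    def __init__(self, n_kv_heads=2, max_context=8, head_dim=4):
        super().__init__()
        shape = (1, n_kv_heads, max_context, head_dim)
        self.register_buffer("k_cache", torch.zeros(*shape))
        self.register_buffer("v_cache", torch.zeros(*shape))

c = TinyCache()
print(sum(p.numel() for p in c.parameters()))  # 0 -- buffers are not parameters
print(c.k_cache.shape)                         # torch.Size([1, 2, 8, 4])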
@@ -91,26 +83,25 @@ class KVCache(nn.Module):
 
         if k.dim() != 4 or v.dim() != 4:
             raise RuntimeError(f"KV update expects k,v 4D. Got k={tuple(k.shape)} v={tuple(v.shape)}")
-
         B, Hkv, q_len, D = k.shape
 
-        # match cache batch to B (expand-from-1 allowed)
-        if self.k_cache.size(0) != B:
-            if self.k_cache.size(0) == 1:
-                self.k_cache = self.k_cache.expand(B, -1, -1, -1).clone()
-                self.v_cache = self.v_cache.expand(B, -1, -1, -1).clone()
+        # expand caches from B=1 -> B if needed
+        if kout.size(0) != B:
+            if kout.size(0) == 1:
+                self.k_cache = kout.expand(B, -1, -1, -1).clone()
+                self.v_cache = vout.expand(B, -1, -1, -1).clone()
                 kout, vout = self.k_cache, self.v_cache
             else:
-                raise RuntimeError(f"KV cache batch mismatch: cache.B={self.k_cache.size(0)} vs k.B={B}")
+                raise RuntimeError(f"KV cache batch mismatch: cache.B={kout.size(0)} vs k.B={B}")
 
-        # Case A: prefill vector of length q_len (same for all rows)
+        # prefill: pos_ids = (q_len,)
         if pos_ids.dim() == 1 and pos_ids.numel() == q_len:
             for i in range(B):
                 kout[i, :, pos_ids, :] = k[i]
                 vout[i, :, pos_ids, :] = v[i]
             return kout, vout
 
-        # Case B: single step — q_len==1 with per-row positions
+        # one step: q_len==1 & pos_ids per row
         if q_len == 1 and pos_ids.numel() == B:
             pos_ids = pos_ids.view(B)
             for i in range(B):
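Annotation: the expand(B, -1, -1, -1).clone() above matters because expand alone returns a stride-0 view in which every batch row aliases one buffer; clone materializes independent per-row storage before the per-row writes. A toy demonstration:

import torch

base = torch.zeros(1, 2, 4, 3)               # (B=1, Hkv, T, D)
alias = base.expand(2, -1, -1, -1)           # stride-0 view over the batch dim
alias[0, :, 0, :] = 1.0                      # writes the shared storage...
print(alias[1, :, 0, :].sum().item())        # 6.0 -- "row 1" changed too

safe = torch.zeros(1, 2, 4, 3).expand(2, -1, -1, -1).clone()
safe[0, :, 0, :] = 1.0
print(safe[1, :, 0, :].sum().item())         # 0.0 -- rows are independent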
@@ -119,7 +110,7 @@ class KVCache(nn.Module):
                 vout[i, :, pi, :] = v[i, :, 0, :]
             return kout, vout
 
-        # Case C: scalar & q_len==1
+        # scalar for everyone & q_len==1
         if pos_ids.dim() == 0 and q_len == 1:
             pi = int(pos_ids.item())
             kout[:, :, pi, :] = k[:, :, 0, :]
@@ -129,11 +120,6 @@ class KVCache(nn.Module):
         raise RuntimeError(f"Unsupported KV update combo: k={tuple(k.shape)}, pos_ids={tuple(pos_ids.shape)}")
 
 
-
-
-
-
-
 class MoondreamModel(nn.Module):
 
     def __init__(
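Annotation: the three accepted (pos_ids, k, v) combinations, replayed on a free-standing copy of the dispatch logic above. Illustration only; the real method is KVCache.update in this diff.

import torch

def update(k_cache, v_cache, pos_ids, k, v):
    B, Hkv, q_len, D = k.shape
    if k_cache.size(0) != B:                     # grow a B=1 cache on demand
        k_cache = k_cache.expand(B, -1, -1, -1).clone()
        v_cache = v_cache.expand(B, -1, -1, -1).clone()
    if pos_ids.dim() == 1 and pos_ids.numel() == q_len:      # prefill
        for i in range(B):
            k_cache[i, :, pos_ids, :] = k[i]
            v_cache[i, :, pos_ids, :] = v[i]
    elif q_len == 1 and pos_ids.numel() == B:                # per-row step
        for i in range(B):
            pi = int(pos_ids.view(B)[i])
            k_cache[i, :, pi, :] = k[i, :, 0, :]
            v_cache[i, :, pi, :] = v[i, :, 0, :]
    elif pos_ids.dim() == 0 and q_len == 1:                  # scalar step
        pi = int(pos_ids.item())
        k_cache[:, :, pi, :] = k[:, :, 0, :]
        v_cache[:, :, pi, :] = v[:, :, 0, :]
    return k_cache, v_cache

kc = torch.zeros(1, 2, 16, 4); vc = kc.clone()
k = torch.randn(3, 2, 5, 4); v = torch.randn(3, 2, 5, 4)
kc, vc = update(kc, vc, torch.arange(5), k, v)                # prefill rows 0..4
step = torch.randn(3, 2, 1, 4)
kc, vc = update(kc, vc, torch.tensor([5, 5, 5]), step, step)  # per-row step
print(kc.shape)  # torch.Size([3, 2, 16, 4]) -- cache expanded to B=3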
@@ -570,29 +556,29 @@ class MoondreamModel(nn.Module):
         return generator(next_token, pos)
 
     def encode_image(self, image, settings=None) -> EncodedImage:
-        # always start from B=1 to avoid leftovers from batched runs
-        self._setup_caches()  # recreates caches with B=1
+        # start clean: recreate caches as B=1 every time
+        self._setup_caches()
 
         if isinstance(image, EncodedImage):
             return image
         if not isinstance(image, Image.Image):
             raise ValueError("image must be a PIL Image or EncodedImage")
 
-        # hard-trim to B=1 if external code changed it
+        # hard-trim to B=1 in case something changed it
         for blk in self.text.blocks:
             if blk.kv_cache.k_cache.size(0) != 1:
                 blk.kv_cache.k_cache = blk.kv_cache.k_cache[:1].contiguous()
                 blk.kv_cache.v_cache = blk.kv_cache.v_cache[:1].contiguous()
 
-        lora = variant_state_dict(settings["variant"], device=self.device) \
-            if settings and "variant" in settings else None
+        lora = variant_state_dict(settings["variant"], device=self.device) if settings and "variant" in settings else None
 
         with torch.inference_mode():
-            img_emb = self._run_vision_encoder(image)  # (T_img,C)
-            bos_emb = text_encoder(torch.tensor([[self.config.tokenizer.bos_id]], device=self.device), self.text)  # (1,1,C)
-            inputs_embeds = torch.cat([bos_emb, img_emb[None]], dim=1)  # (1,T0,C)
+            img_emb = self._run_vision_encoder(image)  # (T_img, C)
+            bos = torch.tensor([[self.config.tokenizer.bos_id]], device=self.device)
+            bos_emb = text_encoder(bos, self.text)  # (1,1,C)
+            inputs_embeds = torch.cat([bos_emb, img_emb[None]], dim=1)  # (1,T0,C)
 
-            mask = self.attn_mask[:, :, :inputs_embeds.size(1), :]  # (1,1,T0,K)
+            mask = self.attn_mask[:, :, :inputs_embeds.size(1), :]  # (1,1,T0,K)
             pos_ids = torch.arange(inputs_embeds.size(1), device=self.device, dtype=torch.long)  # (T0,)
             self._prefill(inputs_embeds, mask, pos_ids, lora)
 
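Annotation: shape arithmetic for the prefill input that encode_image builds, with dummy tensors standing in for the real encoders; T_img and C here are made-up numbers, not the model's actual dimensions.

import torch

C, T_img = 2048, 729                       # hypothetical embed dim / image token count
bos_emb = torch.zeros(1, 1, C)             # stands in for text_encoder([[bos_id]])
img_emb = torch.zeros(T_img, C)            # stands in for the vision encoder output
inputs_embeds = torch.cat([bos_emb, img_emb[None]], dim=1)
print(inputs_embeds.shape)                 # torch.Size([1, 730, 2048]) -> (1,T0,C)
pos_ids = torch.arange(inputs_embeds.size(1))   # prefill positions 0..T0-1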
 
@@ -608,7 +594,6 @@ class MoondreamModel(nn.Module):
 
 
 
-
     def query(
         self,
         image: Optional[Union[Image.Image, EncodedImage]] = None,
@@ -941,8 +926,7 @@ class MoondreamModel(nn.Module):
             b.kv_cache.v_cache[:, :, :T, :] = v.expand(batch_size, -1, -1, -1)
 
 
-    def _prefill_prompt_batched(self, labels, pos: int, lora=None,
-                                temperature: float = 0.0, top_p: float = 0.0):
+    def _prefill_prompt_batched(self, labels, pos: int, lora=None, temperature: float = 0.0, top_p: float = 0.0):
         tpl = self.config.tokenizer.templates["detect"]
         if tpl is None:
             raise NotImplementedError("Model does not support object detection.")
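Annotation: how the right-padded (B,T) prompt batch and the lens bookkeeping used below fit together, with placeholder token ids (the real rows come from the detect template):

import torch

rows = [torch.tensor([7, 8, 9]), torch.tensor([7, 8])]   # per-label token ids
lens = [r.numel() for r in rows]
B, T = len(rows), max(lens)
prompt_ids = torch.zeros(B, T, dtype=torch.long)
for i, ids in enumerate(rows):
    prompt_ids[i, : ids.numel()] = ids                    # right-padded
print(prompt_ids)    # tensor([[7, 8, 9], [7, 8, 0]])
print(lens)          # [3, 2] -- later indexes each row's last real token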
@@ -959,62 +943,52 @@ class MoondreamModel(nn.Module):
         for i, ids in enumerate(rows):
             prompt_ids[i, : ids.numel()] = ids
 
-        prompt_emb = text_encoder(prompt_ids, self.text)  # (B,T,C)
-        torch._dynamo.mark_dynamic(prompt_emb, 1)  # allow variable T
+        prompt_emb = text_encoder(prompt_ids, self.text)  # (B,T,C)
+        torch._dynamo.mark_dynamic(prompt_emb, 1)
 
-        base = self.attn_mask[:, :, pos:pos+T, :]  # (1,1,T,K)
-        mask = base.expand(B, -1, -1, -1).contiguous()  # (B,1,T,K)
+        base = self.attn_mask[:, :, pos:pos+T, :]  # (1,1,T,K)
+        mask = base.expand(B, -1, -1, -1).contiguous()  # (B,1,T,K)
 
         pos_ids = torch.arange(pos, pos + T, device=self.device, dtype=torch.long)  # (T,)
-        hidden_BTC = self._prefill(prompt_emb, mask, pos_ids, lora)  # (B,T,C)
-        logits_BTV = lm_head(hidden_BTC, self.text)  # (B,T,V)
+        hidden_BTC = self._prefill(prompt_emb, mask, pos_ids, lora)  # (B,T,C)
+        logits_BTV = lm_head(hidden_BTC, self.text)  # (B,T,V)
 
-        idx = (torch.tensor(lens, device=self.device) - 1).clamp_min(0)  # (B,)
+        idx = (torch.tensor(lens, device=self.device) - 1).clamp_min(0)  # (B,)
         last_hidden = hidden_BTC[torch.arange(B, device=self.device), idx][:, None, :]  # (B,1,C)
         last_logits = logits_BTV[torch.arange(B, device=self.device), idx]  # (B,V)
 
         if temperature == 0.0:
-            next_token = last_logits.argmax(dim=-1, keepdim=True)  # (B,1)
+            next_token = last_logits.argmax(dim=-1, keepdim=True)  # (B,1)
         else:
             probs = torch.softmax(last_logits / temperature, dim=-1)
             probs = self._apply_top_p(probs, top_p)
-            next_token = torch.multinomial(probs, num_samples=1)  # (B,1)
+            next_token = torch.multinomial(probs, num_samples=1)  # (B,1)
 
-        # shared next-free position in cache (safe upper bound)
-        pos_end = int(pos + T)
+        pos_end = int(pos + T)  # shared next-free position
         return last_hidden, next_token, pos_end
 
 
-
-
     def _generate_points_batched(
-        self,
-        hidden,  # (B,1,C)
-        next_token,  # (B,1) (unused in greedy)
-        pos,  # int
-        include_size: bool = True,
-        max_objects: int = 50,
-        lora=None,
-        use_soft_argmax: bool = True,  # reduces jitter
-    ):
+        self, hidden, next_token, pos, include_size: bool = True,
+        max_objects: int = 50, lora=None, use_soft_argmax: bool = False):
         B = hidden.size(0)
         device = self.device
         out = [[] for _ in range(B)]
         eos_id = self.config.tokenizer.eos_id
         max_ctx = self.config.text.max_context
 
-        # 4-D mask: (B, 1, q_len=1, kv_len)
+        # 4-D mask: (B,1,1,kv_len); advance with a 1-D position vector
         mask = torch.zeros(B, 1, 1, max_ctx, device=device, dtype=torch.bool)
         if int(pos) > 0:
             mask[:, :, :, :int(pos)] = True
-        pos_ids = torch.full((B, 1), int(pos), device=device, dtype=torch.long)
+        pos_id_vec = torch.full((1,), int(pos), device=device, dtype=torch.long)
 
-        # helper: (B, bins) -> (B,) in [0,1]
-        def _argmax01(logits):
+        def _center01(logits):
+            # logits: (B, bins) → (B,) in [0,1]
             if use_soft_argmax:
-                probs = torch.softmax(logits, dim=-1)
-                bins = torch.arange(probs.size(-1), device=logits.device, dtype=torch.float32)
-                return (probs * bins).sum(dim=-1) / float(probs.size(-1) - 1)
+                p = torch.softmax(logits, dim=-1)
+                bins = torch.arange(p.size(-1), device=logits.device, dtype=torch.float32)
+                return (p * bins).sum(dim=-1) / float(p.size(-1) - 1)
             idx = logits.argmax(dim=-1).to(torch.float32)
             return idx / float(logits.size(-1) - 1)
 
 
@@ -1023,29 +997,38 @@ class MoondreamModel(nn.Module):
 
         with torch.inference_mode():
             while alive.any() and (counts < max_objects).any():
-                # ---- x
+                # x
                 x_logits = decode_coordinate(hidden, self.region)  # (B,1,1024) or (B,1024)
                 if x_logits.dim() == 3: x_logits = x_logits.squeeze(1)
-                x_center = _argmax01(x_logits)  # (B,)
+                x_center = _center01(x_logits)
                 x_emb = encode_coordinate(x_center.to(dtype=x_logits.dtype).unsqueeze(-1), self.region).unsqueeze(1)
 
-                mask[alive, :, :, pos_ids[0,0]] = True
-                logits, hidden = self._decode_one_tok(x_emb, mask, pos_ids, lora)
-                pos_ids[alive, 0] += 1
+                mask[:, :, :, pos_id_vec] = True
+                logits, hidden = self._decode_one_tok(x_emb, mask, pos_id_vec, lora)
+                pos_id_vec += 1
 
-                # ---- y
+                # y
                 y_logits = decode_coordinate(hidden, self.region)
                 if y_logits.dim() == 3: y_logits = y_logits.squeeze(1)
-                y_center = _argmax01(y_logits)
+                y_center = _center01(y_logits)
                 y_emb = encode_coordinate(y_center.to(dtype=y_logits.dtype).unsqueeze(-1), self.region).unsqueeze(1)
 
-                mask[alive, :, :, pos_ids[0,0]] = True
-                logits, hidden = self._decode_one_tok(y_emb, mask, pos_ids, lora)
-                pos_ids[alive, 0] += 1
+                mask[:, :, :, pos_id_vec] = True
+                logits, hidden = self._decode_one_tok(y_emb, mask, pos_id_vec, lora)
+                pos_id_vec += 1
 
                 if include_size:
                     size_ret = decode_size(hidden, self.region)
-                    w_logits, h_logits = self._norm_size_logits(size_ret, B)
+                    # Robust parse: accept (w,h) tuple OR Tensor (B,2,C)/(B,1,2,C)
+                    if isinstance(size_ret, (tuple, list)):
+                        w_logits, h_logits = size_ret
+                    else:
+                        t = size_ret
+                        if t.dim() == 4:  # (B,1,2,C)
+                            t = t.squeeze(1)  # → (B,2,C)
+                        if t.dim() != 3 or t.size(1) != 2:
+                            raise RuntimeError(f"Unexpected decode_size shape {tuple(t.shape)}")
+                        w_logits, h_logits = t[:, 0, :], t[:, 1, :]
 
                     if use_soft_argmax:
                         bins = torch.arange(w_logits.size(-1), device=device, dtype=torch.float32)
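Annotation: the tolerant decode_size parsing added above, exercised standalone against both return conventions it accepts, with dummy logits:

import torch

def split_size(size_ret):
    if isinstance(size_ret, (tuple, list)):
        return size_ret
    t = size_ret
    if t.dim() == 4:                       # (B,1,2,C) -> (B,2,C)
        t = t.squeeze(1)
    return t[:, 0, :], t[:, 1, :]

packed = torch.randn(2, 1, 2, 1024)        # (B,1,2,C) convention
w_logits, h_logits = split_size(packed)
print(w_logits.shape, h_logits.shape)      # torch.Size([2, 1024]) twice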
@@ -1055,13 +1038,12 @@ class MoondreamModel(nn.Module):
                         w_bin = w_logits.argmax(dim=-1).to(torch.float32)
                         h_bin = h_logits.argmax(dim=-1).to(torch.float32)
 
-                    # inverse log scale used by md2
+                    # inverse log-scale mapping used by md2
                     w = torch.pow(2.0, (w_bin / 1023.0) * 10.0 - 10.0)
                     h = torch.pow(2.0, (h_bin / 1023.0) * 10.0 - 10.0)
 
                     size_emb = encode_size(torch.stack([w, h], dim=1).to(dtype=w_logits.dtype), self.region).unsqueeze(1)
 
-                    # write boxes only for alive rows
                     for i in range(B):
                         if not alive[i]: continue
                         xl = (x_center[i] - w[i] / 2).item()
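Annotation: the bin-to-size mapping decoded above sends bin 0 to 2**-10 (about 0.1% of the image side) and bin 1023 to 1.0; a quick roundtrip through the algebraic inverse confirms it:

import torch

w_bin = torch.tensor([0.0, 511.5, 1023.0])
w = torch.pow(2.0, (w_bin / 1023.0) * 10.0 - 10.0)
print(w)                                        # ~[0.00098, 0.03125, 1.0]
back = (torch.log2(w) + 10.0) / 10.0 * 1023.0   # inverse mapping
print(back)                                     # recovers 0, 511.5, 1023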
@@ -1075,17 +1057,17 @@ class MoondreamModel(nn.Module):
                             "y_max": max(0.0, min(1.0, yb)),
                         })
 
-                    mask[alive, :, :, pos_ids[0,0]] = True
-                    logits, hidden = self._decode_one_tok(size_emb, mask, pos_ids, lora)
-                    pos_ids[alive, 0] += 1
-                    next_tok = logits.argmax(dim=-1).squeeze(-1)  # (B,)
+                    mask[:, :, :, pos_id_vec] = True
+                    logits, hidden = self._decode_one_tok(size_emb, mask, pos_id_vec, lora)
+                    pos_id_vec += 1
+                    next_tok = logits.argmax(dim=-1).squeeze(-1)
                 else:
                     for i in range(B):
                         if alive[i]:
                             out[i].append({"x": x_center[i].item(), "y": y_center[i].item()})
-                    mask[alive, :, :, pos_ids[0,0]] = True
-                    logits, hidden = self._decode_one_tok(y_emb, mask, pos_ids, lora)
-                    pos_ids[alive, 0] += 1
+                    mask[:, :, :, pos_id_vec] = True
+                    logits, hidden = self._decode_one_tok(y_emb, mask, pos_id_vec, lora)
+                    pos_id_vec += 1
                     next_tok = logits.argmax(dim=-1).squeeze(-1)
 
                 finished_now = (next_tok == eos_id) | (counts >= max_objects - 1)
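Annotation: the per-step bookkeeping used in the loop above, in isolation: open the current column of the (B,1,1,K) mask, decode one token at that position, then advance the shared one-element position vector. The _decode_one_tok call is stubbed out here.

import torch

B, K = 2, 16
mask = torch.zeros(B, 1, 1, K, dtype=torch.bool)
pos_id_vec = torch.full((1,), 5, dtype=torch.long)   # start after prefill

for _ in range(3):                                   # e.g. x, y, size tokens
    mask[:, :, :, pos_id_vec] = True                 # make the position visible
    # logits, hidden = self._decode_one_tok(emb, mask, pos_id_vec, lora)
    pos_id_vec += 1

print(mask[0, 0, 0].nonzero().flatten())             # tensor([5, 6, 7])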
@@ -1094,6 +1076,7 @@ class MoondreamModel(nn.Module):
 
         return out
 
+
     def detect_multi(self, image, objects, settings=None):
         if self.config.tokenizer.templates["detect"] is None:
             raise NotImplementedError("Model does not support object detection.")
@@ -1122,8 +1105,7 @@ class MoondreamModel(nn.Module):
                 d["label"] = lab
             res[lab] = lst
 
-        # restore B=1 so the next encode_image() starts clean
-        self._reset_kv_caches(1)
+        self._reset_kv_caches(1)  # restore B=1
         return {"objects": res}
 
 
@@ -1131,6 +1113,7 @@ class MoondreamModel(nn.Module):
 
 
 
+
     def _detect_gaze(
         self,
         image: EncodedImage,
 