Update moondream.py
fix:
1. Rerun crash in encode_image(...): after a batched detect call, the KV caches still have batch size B. On the next encode, Q has batch size 1 while K/V broadcast to B, so attention returns B rows and the reshape to (1, q_len, d_model) fails. Rebuilding a fresh B=1 cache before encoding avoids this (Moondream's attention expects the Q/K/V batch dimensions to match when reshaping back to (bsz, q_len, d_model)).
2. decode_size IndexError on .squeeze(1): upstream decode_size returns mlp(...).view(2, -1), which flattens the batch/time dims, so the result is not always a (B, 1, 1024) pair. We reshape it back to (2, B, -1) when needed, so the code works with both return shapes (see the sketch below).
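
A minimal sketch of the reshape in fix 2, with illustrative shapes (the batch size and the 1024-bin count here are assumptions for the example, not read from the upstream API):

    import torch

    # Hypothetical batched decode: B labels in flight, 1024 size bins per row.
    B, size_bins = 4, 1024

    # A flattened mlp(...).view(2, -1) result collapses the batch dimension:
    flat = torch.randn(2, B * size_bins)

    # Reshape it back so each row keeps its own width/height logits, as the patch does.
    size_logits = flat.reshape(2, B, -1)
    w_logits, h_logits = size_logits[0], size_logits[1]  # each (B, size_bins)
    assert w_logits.shape == (B, size_bins)
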
- moondream.py +83 -79
moondream.py CHANGED

@@ -155,17 +155,33 @@ class MoondreamModel(nn.Module):
         if setup_caches:
             self._setup_caches()
 
-    def _setup_caches(self):
+
+    def _reset_kv_caches(self, batch_size: int = 1):
+        """
+        Recreate KV caches with the requested batch size so subsequent calls
+        (e.g., encode_image) start from a consistent shape.
+        """
         c = self.config.text
-        for b in self.text.blocks:
-            b.kv_cache = KVCache(
-                c.n_heads,
-                c.n_kv_heads,
-                c.max_context,
-                c.dim,
-                device=self.device,
-                dtype=self.vision.pos_emb.dtype,
-            )
+        head_dim = c.dim // c.n_heads
+        for blk in self.text.blocks:
+            device = blk.kv_cache.k_cache.device
+            dtype = blk.kv_cache.k_cache.dtype
+            shape = (batch_size, c.n_kv_heads, c.max_context, head_dim)
+            blk.kv_cache.k_cache = torch.zeros(shape, device=device, dtype=dtype)
+            blk.kv_cache.v_cache = torch.zeros(shape, device=device, dtype=dtype)
+
+
+    def _setup_caches(self):
+        c = self.config.text
+        for b in self.text.blocks:
+            b.kv_cache = KVCache(
+                c.n_heads,
+                c.n_kv_heads,
+                c.max_context,
+                c.dim,
+                device=self.device,
+                dtype=self.vision.pos_emb.dtype,
+            )
 
     @property
     def device(self):

@@ -238,30 +254,30 @@ class MoondreamModel(nn.Module):
         image: Union[Image.Image, EncodedImage],
         settings: Optional[ImageEncodingSettings] = None,
     ) -> EncodedImage:
+        # Always start from single-row caches; avoids leftovers from batched runs.
+        self._setup_caches()
+
         if isinstance(image, EncodedImage):
             return image
         elif not isinstance(image, Image.Image):
             raise ValueError("image must be a PIL Image or EncodedImage")
-
+
         lora = (
             variant_state_dict(settings["variant"], device=self.device)
             if settings is not None and "variant" in settings
             else None
         )
-
-        # Run through text model in addition to the vision encoder, to minimize
-        # re-computation if multiple queries are performed on this image.
+
         with torch.inference_mode():
             img_emb = self._run_vision_encoder(image)
             bos_emb = text_encoder(
-                torch.tensor([[self.config.tokenizer.bos_id]], device=self.device),
-                self.text,
+                torch.tensor([[self.config.tokenizer.bos_id]], device=self.device), self.text
             )
             inputs_embeds = torch.cat([bos_emb, img_emb[None]], dim=1)
             mask = self.attn_mask[:, :, 0 : inputs_embeds.size(1), :]
             pos_ids = torch.arange(inputs_embeds.size(1), dtype=torch.long)
             self._prefill(inputs_embeds, mask, pos_ids, lora)
-
+
         return EncodedImage(
             pos=inputs_embeds.size(1),
             caches=[

@@ -273,6 +289,7 @@ class MoondreamModel(nn.Module):
             ],
         )
 
+
     def _apply_top_p(self, probs: torch.Tensor, top_p: float):
         probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
         probs_sum = torch.cumsum(probs_sort, dim=-1)

@@ -835,7 +852,6 @@ class MoondreamModel(nn.Module):
         return {"points": objects}
 
 
-    # === BEGIN: Batched multi-label detection additions ===
     def _load_encoded_image_batched(self, encoded_image, batch_size: int):
         """
         Clone single-image KV caches into a batch-B cache so we can decode B labels in parallel.

@@ -860,10 +876,6 @@ class MoondreamModel(nn.Module):
         temperature: float = 0.0,
         top_p: float = 0.0,
     ):
-        """
-        Build detect prompts for many labels, pad to the same length, prefill once as a batch.
-        Returns (last_hidden per row, next_token per row, shared_pos_end scalar).
-        """
         tpl = self.config.tokenizer.templates["detect"]
         if tpl is None:
             raise NotImplementedError("Model does not support object detection (no detect template).")

@@ -877,24 +889,21 @@ class MoondreamModel(nn.Module):
         T = max(lens)
         eos = self.config.tokenizer.eos_id
 
-        # Pad to T with eos, so we can prefill with a single shared position range
         prompt_ids = torch.full((B, T), eos, device=self.device, dtype=torch.long)
         for i, ids in enumerate(rows):
             prompt_ids[i, : ids.numel()] = ids
 
-        # …
-        …
-        torch._dynamo.mark_dynamic(prompt_emb, 1)  # allow variable T
+        prompt_emb = text_encoder(prompt_ids, self.text)  # (B, T, C)
+        torch._dynamo.mark_dynamic(prompt_emb, 1)
 
-        # 4-D mask
+        # 4-D mask is broadcastable to (B, n_heads, T, K)
         attn = self.attn_mask
-        mask = attn[:, :, pos : pos + T, :].expand(B, -1, -1, -1).contiguous()
+        mask = attn[:, :, pos : pos + T, :].expand(B, -1, -1, -1).contiguous()
         pos_ids = torch.arange(pos, pos + T, device=self.device, dtype=torch.long)
 
         hidden_BTC = self._prefill(prompt_emb, mask, pos_ids, lora)  # (B, T, C)
         logits_BTV = lm_head(hidden_BTC, self.text)  # (B, T, V)
 
-        # Take the last *real* token per row
         idx = (torch.tensor(lens, device=self.device, dtype=torch.long) - 1).clamp_min(0)
         last_hidden = hidden_BTC[torch.arange(B, device=self.device), idx][:, None, :]  # (B,1,C)
         last_logits = logits_BTV[torch.arange(B, device=self.device), idx]  # (B,V)

@@ -906,34 +915,27 @@ class MoondreamModel(nn.Module):
             probs = self._apply_top_p(probs, top_p)
             next_token = torch.multinomial(probs, num_samples=1)  # (B,1)
 
-        # Shared "next decode position" for all rows (we prefilled up to pos+T-1)
         pos_end = pos + T
         return last_hidden, next_token, pos_end  # (B,1,C), (B,1), int
 
 
+
     def _generate_points_batched(
         self,
-        hidden, # (B,1,C)
-        next_token, # (B,1)
-        pos: int, # shared scalar next position
+        hidden,  # (B,1,C)
+        next_token,  # (B,1)
+        pos: int,  # shared scalar next position
         include_size: bool = True,
         max_objects: int = 50,
         lora=None,
     ):
-        """
-        Vectorized version of _generate_points() that decodes x -> y -> size -> next-token
-        for all rows in the batch simultaneously. Returns list-of-lists of dicts (len B).
-        Batch-safe: uses 4-D masks and avoids region.decode_size() (which flattens batch).
-        """
-        import torch
-
         B = hidden.size(0)
         device = self.device
         out = [[] for _ in range(B)]
         eos_id = self.config.tokenizer.eos_id
         max_ctx = self.config.text.max_context
 
-        # 4-D mask: (B, 1, q_len=1, kv_len)
+        # 4-D mask: (B, 1, q_len=1, kv_len)
         mask = torch.zeros(B, 1, 1, max_ctx, device=device, dtype=torch.bool)
         if pos > 0:
             mask[:, :, :, :pos] = True

@@ -944,53 +946,55 @@ class MoondreamModel(nn.Module):
 
         with torch.inference_mode():
             while alive.any() and (counts < max_objects).any():
-                # --- x coordinate
+                # --- x coordinate ---
                 x_logits = decode_coordinate(hidden, self.region)  # (B,1,1024) or (B,1024)
                 if x_logits.dim() == 3:
-                    x_logits = x_logits.squeeze(1)
-                x_bin …
+                    x_logits = x_logits.squeeze(1)
+                x_bin = x_logits.argmax(dim=-1).to(torch.float32)
                 x_center = x_bin / float(x_logits.size(-1))  # (B,)
-                x_in …
-                x_emb …
+                x_in = x_center.to(dtype=x_logits.dtype).unsqueeze(-1)  # (B,1)
+                x_emb = encode_coordinate(x_in, self.region).unsqueeze(1)  # (B,1,C)
 
-
-                mask[:, :, :, pos] = True
+                mask[:, :, :, pos_id[0].item()] = True
                 logits, hidden = self._decode_one_tok(x_emb, mask, pos_id, lora)
-
-                pos_id[0] = pos
+                pos_id += 1
 
-                # --- y coordinate
-                y_logits = decode_coordinate(hidden, self.region)
+                # --- y coordinate ---
+                y_logits = decode_coordinate(hidden, self.region)
                 if y_logits.dim() == 3:
                     y_logits = y_logits.squeeze(1)
-                y_bin …
-                y_center = y_bin / float(y_logits.size(-1))
-                y_in …
-                y_emb …
+                y_bin = y_logits.argmax(dim=-1).to(torch.float32)
+                y_center = y_bin / float(y_logits.size(-1))  # (B,)
+                y_in = y_center.to(dtype=y_logits.dtype).unsqueeze(-1)
+                y_emb = encode_coordinate(y_in, self.region).unsqueeze(1)
 
-                mask[:, :, :, …
+                mask[:, :, :, pos_id[0].item()] = True
                 logits, hidden = self._decode_one_tok(y_emb, mask, pos_id, lora)
-
-                pos_id[0] = pos
+                pos_id += 1
 
                 if include_size:
-
-                    # …
-
-
-
+                    size_logits = decode_size(hidden, self.region)
+                    # Support both tuple-of-tensors and flattened (2, -1) forms
+                    if isinstance(size_logits, (tuple, list)):
+                        w_logits = size_logits[0]
+                        h_logits = size_logits[1]
+                        if w_logits.dim() == 3:  # (B,1,1024)
+                            w_logits = w_logits.squeeze(1)
+                            h_logits = h_logits.squeeze(1)
+                    else:
+                        # size_logits shape: (2, B * size_bins) — reshape it back.
+                        size_logits = size_logits.reshape(2, B, -1)
+                        w_logits, h_logits = size_logits[0], size_logits[1]  # (B, size_bins)
 
                     w_bin = w_logits.argmax(dim=-1).to(torch.float32)
                     h_bin = h_logits.argmax(dim=-1).to(torch.float32)
-
-                    # inverse log-scale mapping used by the repo
+                    # inverse of log-scale mapping used by Moondream
                     w = torch.pow(2.0, (w_bin / 1023.0) * 10.0 - 10.0)
                     h = torch.pow(2.0, (h_bin / 1023.0) * 10.0 - 10.0)
 
-                    size_in …
+                    size_in = torch.stack([w, h], dim=1).to(dtype=w_logits.dtype)  # (B,2)
                     size_emb = encode_size(size_in, self.region).unsqueeze(1)  # (B,1,C)
 
-                    # commit boxes
                     for i in range(B):
                         if alive[i]:
                             out[i].append({

@@ -1000,21 +1004,18 @@ class MoondreamModel(nn.Module):
                                 "y_max": (y_center[i] + h[i] / 2).item(),
                             })
 
-
-                    mask[:, :, :, pos] = True
+                    mask[:, :, :, pos_id[0].item()] = True
                     logits, hidden = self._decode_one_tok(size_emb, mask, pos_id, lora)
-
-                    pos_id[0] = pos
+                    pos_id += 1
                     next_tok = logits.argmax(dim=-1).squeeze(-1)  # (B,)
                 else:
-                    # points mode
                     for i in range(B):
                         if alive[i]:
                             out[i].append({"x": x_center[i].item(), "y": y_center[i].item()})
-
+
+                    mask[:, :, :, pos_id[0].item()] = True
                     logits, hidden = self._decode_one_tok(y_emb, mask, pos_id, lora)
-
-                    pos_id[0] = pos
+                    pos_id += 1
                     next_tok = logits.argmax(dim=-1).squeeze(-1)
 
                 finished_now = (next_tok == eos_id) | (counts >= max_objects - 1)

@@ -1024,6 +1025,7 @@ class MoondreamModel(nn.Module):
         return out
 
 
+
     def detect_multi(self, image, objects, settings=None):
         """
         Parallel multi-label detection.

@@ -1043,17 +1045,14 @@ class MoondreamModel(nn.Module):
         B = len(objects)
         self._load_encoded_image_batched(image, B)
 
-        # Optional LoRA variant
         lora = None
         if "variant" in settings:
             lora = variant_state_dict(settings["variant"], device=self.device)
 
-        # Prefill all prompts as a batch; shared next position
         last_hidden, next_token, pos_end = self._prefill_prompt_batched(
             objects, image.pos, lora=lora, temperature=0.0, top_p=0.0
         )
 
-        # Batched decode loop
         max_objects = settings.get("max_objects", 50)
         det_lists = self._generate_points_batched(
             last_hidden, next_token, pos_end,

@@ -1066,9 +1065,14 @@ class MoondreamModel(nn.Module):
             for d in lst:
                 d["label"] = lab
             res[lab] = lst
+
+        # IMPORTANT: restore caches to B=1 so future calls (e.g., encode_image) are safe.
+        self._reset_kv_caches(1)
+
         return {"objects": res}
 
 
+
     def _detect_gaze(
         self,
         image: EncodedImage,
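
For context, a hypothetical call into the new batched path; `model` is assumed to be an already-loaded MoondreamModel instance, and only names that appear in this diff (encode_image, detect_multi, the "max_objects" settings key) are used:

    from PIL import Image

    # model: a loaded MoondreamModel instance (loading omitted here).
    encoded = model.encode_image(Image.open("example.jpg"))  # rebuilds B=1 caches first

    result = model.detect_multi(
        encoded,                              # EncodedImage returned by encode_image
        ["car", "person", "traffic light"],   # one detect prompt per label, prefilled as a batch
        settings={"max_objects": 25},         # falls back to 50 if omitted
    )
    for label, boxes in result["objects"].items():
        print(label, len(boxes))
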