Update moondream.py
fix: kv cache mismatch and decode size.

moondream.py  +114 -66
@@ -884,36 +884,53 @@ class MoondreamModel(nn.Module):
         return {"points": objects}
 
+    # moondream.py
+    def _norm_size_logits(self, size_ret: torch.Tensor | tuple, B: int):
         """
         Accepts any of:
-          • (w_logits, h_logits)
+          • tuple/list: (w_logits, h_logits)
+          • Tensor (..., 2, C)   # from batch-safe region.decode_size
+          • Tensor (B, 2*C)      # fallback
+          • Tensor (2, C) when B == 1
         Returns (w_logits, h_logits) each shaped (B, C).
         """
         if isinstance(size_ret, (tuple, list)):
             w_logits, h_logits = size_ret
         else:
             t = size_ret
+            # if we got (..., 2, C), squeeze a single seq dim if present
+            if t.dim() >= 3 and t.shape[-2] == 2:
+                # bring to (B, 2, C)
+                while t.dim() > 3:
+                    t = t.squeeze(1)
+                if t.dim() != 3 or t.shape[0] not in (1, B):
+                    raise RuntimeError(f"Unexpected batched size logits shape {tuple(size_ret.shape)}")
+                # expand B if needed
+                if t.shape[0] == 1 and B > 1:
+                    t = t.expand(B, -1, -1).contiguous()
                 w_logits, h_logits = t[:, 0, :], t[:, 1, :]
             elif t.dim() == 2:
+                # (2, C) when B == 1, or (B, 2*C)
+                if t.shape[0] == 2 and B == 1:
                     w_logits, h_logits = t[0].unsqueeze(0), t[1].unsqueeze(0)
+                else:
+                    C2 = t.shape[1]
+                    if C2 % 2 != 0:
+                        raise RuntimeError(f"Cannot split last dim {C2} into (w,h)")
+                    C = C2 // 2
                     w_logits, h_logits = t[:, :C], t[:, C:]
             else:
                 raise RuntimeError(f"Unexpected decode_size shape {tuple(t.shape)}")
+
+        # final sanity: make sure they're (B, C)
         if w_logits.dim() == 3: w_logits = w_logits.squeeze(1)
         if h_logits.dim() == 3: h_logits = h_logits.squeeze(1)
+        if w_logits.shape[0] != B or h_logits.shape[0] != B:
+            raise RuntimeError(f"Batched size logits mismatch: got {w_logits.shape[0]} vs B={B}")
         return w_logits.contiguous(), h_logits.contiguous()
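Since `_norm_size_logits` never touches `self`, the accepted shapes are easy to exercise in isolation. A minimal sketch, assuming only `torch` and that the module imports as `moondream`; `C = 1024` is illustrative, matching the coordinate/size bin count used further down:

    import torch
    from moondream import MoondreamModel  # assumed import path

    norm = MoondreamModel._norm_size_logits  # unbound; `self` is unused
    B, C = 4, 1024

    for size_ret in (
        (torch.randn(B, C), torch.randn(B, C)),  # tuple of (w, h) logits
        torch.randn(B, 1, 2, C),                 # (..., 2, C) with a seq dim
        torch.randn(B, 2 * C),                   # flat (B, 2*C) fallback
        torch.randn(1, 2, C),                    # batch-1 logits, expanded to B
    ):
        w, h = norm(None, size_ret, B)
        assert w.shape == (B, C) and h.shape == (B, C)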
 
+
     def _load_encoded_image_batched(self, encoded_image, batch_size: int):
         for b, (k, v) in zip(self.text.blocks, encoded_image.caches):
             T = k.size(2)
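The "kv cache mismatch" half of the fix lives in `_load_encoded_image_batched`, whose body the hunk cuts off after `T = k.size(2)`. For orientation only, a hedged sketch of what replicating a single-image prefix cache across a batch typically looks like; the `kv_cache.k_cache`/`v_cache` attribute names and the (1, H, T, D) layout are assumptions, not this file's actual fields:

    def load_encoded_image_batched(blocks, encoded_image, batch_size: int):
        # Sketch: copy the 1-image KV prefix into each row of the batch cache.
        for b, (k, v) in zip(blocks, encoded_image.caches):
            T = k.size(2)  # prefix tokens already cached; k is (1, H, T, D) here
            b.kv_cache.k_cache[:batch_size, :, :T] = k.expand(batch_size, -1, -1, -1)
            b.kv_cache.v_cache[:batch_size, :, :T] = v.expand(batch_size, -1, -1, -1)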
@@ -969,26 +986,35 @@ class MoondreamModel(nn.Module):
 
 
     def _generate_points_batched(
+        self,
+        hidden,                        # (B,1,C)
+        next_token,                    # (B,1) (unused for greedy)
+        pos,                           # int (start position in cache)
+        include_size: bool = True,
+        max_objects: int = 50,
+        lora=None,
+        use_soft_argmax: bool = True,  # reduces jitter/hallucinations
+    ):
         B = hidden.size(0)
         device = self.device
         out = [[] for _ in range(B)]
         eos_id = self.config.tokenizer.eos_id
         max_ctx = self.config.text.max_context
 
+        # 4-D mask: (B, 1, q_len=1, kv_len)
         mask = torch.zeros(B, 1, 1, max_ctx, device=device, dtype=torch.bool)
+        p0 = int(pos)
+        if p0 > 0:
+            mask[:, :, :, :p0] = True
+        # per-row position ids (B,1)
+        pos_ids = torch.full((B, 1), p0, device=device, dtype=torch.long)
 
+        # helper: (B, bins) -> (B,) in [0,1]
+        def _argmax01(logits: torch.Tensor) -> torch.Tensor:
             if use_soft_argmax:
+                probs = torch.softmax(logits, dim=-1)
+                bins = torch.arange(probs.size(-1), device=logits.device, dtype=torch.float32)
+                return (probs * bins).sum(dim=-1) / float(probs.size(-1) - 1)
             idx = logits.argmax(dim=-1).to(torch.float32)
             return idx / float(logits.size(-1) - 1)
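`_argmax01` turns a distribution over bins into a coordinate in [0, 1]. Hard argmax snaps to the single best bin; the soft variant returns the probability-weighted mean of bin indices, so near-ties resolve to an intermediate coordinate instead of flickering between adjacent bins. A standalone numeric check (all names local to this snippet):

    import torch

    def argmax01(logits: torch.Tensor, soft: bool) -> torch.Tensor:
        if soft:
            probs = torch.softmax(logits, dim=-1)
            bins = torch.arange(probs.size(-1), dtype=torch.float32)
            return (probs * bins).sum(dim=-1) / float(probs.size(-1) - 1)
        return logits.argmax(dim=-1).to(torch.float32) / float(logits.size(-1) - 1)

    logits = torch.zeros(1, 5)
    logits[0, 2] = logits[0, 3] = 4.0      # near-tie between bins 2 and 3
    print(argmax01(logits, soft=False))    # tensor([0.5000])  -> snaps to bin 2
    print(argmax01(logits, soft=True))     # ≈ tensor([0.6194]) -> lands between bins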
@@ -997,38 +1023,41 @@ class MoondreamModel(nn.Module):
 
         with torch.inference_mode():
             while alive.any() and (counts < max_objects).any():
+                # ---------------- x ----------------
+                x_logits = decode_coordinate(hidden, self.region)     # (B,1,1024) or (B,1024)
+                if x_logits.dim() == 3:
+                    x_logits = x_logits.squeeze(1)                    # -> (B,1024)
+                x_center = _argmax01(x_logits)                        # (B,)
+                x_emb = encode_coordinate(
+                    x_center.to(dtype=x_logits.dtype).unsqueeze(-1),  # (B,1)
+                    self.region
+                ).unsqueeze(1)                                        # (B,1,C)
 
+                # advance one token for ALIVE rows only
+                step_col = int(pos_ids[0, 0].item())
+                mask[alive, :, :, step_col] = True
+                logits, hidden = self._decode_one_tok(x_emb, mask, pos_ids, lora)
+                pos_ids[alive, 0] += 1
 
+                # ---------------- y ----------------
                 y_logits = decode_coordinate(hidden, self.region)
+                if y_logits.dim() == 3:
+                    y_logits = y_logits.squeeze(1)                    # (B,1024)
+                y_center = _argmax01(y_logits)                        # (B,)
+                y_emb = encode_coordinate(
+                    y_center.to(dtype=y_logits.dtype).unsqueeze(-1),
+                    self.region
+                ).unsqueeze(1)                                        # (B,1,C)
 
+                step_col = int(pos_ids[0, 0].item())
+                mask[alive, :, :, step_col] = True
+                logits, hidden = self._decode_one_tok(y_emb, mask, pos_ids, lora)
+                pos_ids[alive, 0] += 1
 
                 if include_size:
+                    # ------------- size (w,h) -------------
                     size_ret = decode_size(hidden, self.region)
-                    if isinstance(size_ret, (tuple, list)):
-                        w_logits, h_logits = size_ret
-                    else:
-                        t = size_ret
-                        if t.dim() == 4:       # (B,1,2,C)
-                            t = t.squeeze(1)   # → (B,2,C)
-                        if t.dim() != 3 or t.size(1) != 2:
-                            raise RuntimeError(f"Unexpected decode_size shape {tuple(t.shape)}")
-                        w_logits, h_logits = t[:, 0, :], t[:, 1, :]
+                    w_logits, h_logits = self._norm_size_logits(size_ret, B)  # each (B,C)
 
                     if use_soft_argmax:
                         bins = torch.arange(w_logits.size(-1), device=device, dtype=torch.float32)
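Every region token decodes against the KV cache through the shared (B, 1, 1, kv_len) boolean mask: right before each `_decode_one_tok` call, one more cache column is switched on, but only for rows that are still alive. Note that `step_col` is read from `pos_ids[0, 0]`, so all rows share one column clock; rows that finished early simply stop getting new mask bits. A toy version of that bookkeeping, independent of the model (all names local to the snippet):

    import torch

    B, max_ctx, p0 = 3, 8, 2                   # batch, cache size, prefix length
    mask = torch.zeros(B, 1, 1, max_ctx, dtype=torch.bool)
    mask[:, :, :, :p0] = True                  # image/prompt prefix is visible
    pos_ids = torch.full((B, 1), p0, dtype=torch.long)
    alive = torch.tensor([True, True, False])  # row 2 already finished

    for _ in range(2):                         # two decode steps
        step_col = int(pos_ids[0, 0].item())
        mask[alive, :, :, step_col] = True     # only alive rows see the new slot
        pos_ids[alive, 0] += 1

    print(mask[:, 0, 0].sum(-1))               # tensor([4, 4, 2])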
@@ -1038,14 +1067,18 @@ class MoondreamModel(nn.Module):
                         w_bin = w_logits.argmax(dim=-1).to(torch.float32)
                         h_bin = h_logits.argmax(dim=-1).to(torch.float32)
 
+                    # inverse log-scale mapping used by MD2
+                    w = torch.pow(2.0, (w_bin / 1023.0) * 10.0 - 10.0)  # (B,)
+                    h = torch.pow(2.0, (h_bin / 1023.0) * 10.0 - 10.0)  # (B,)
 
+                    size_emb = encode_size(
+                        torch.stack([w, h], dim=1).to(dtype=w_logits.dtype),  # (B,2)
+                        self.region
+                    ).unsqueeze(1)                                            # (B,1,C)
 
+                    # record boxes only for ALIVE rows
+                    alive_idx = alive.nonzero(as_tuple=False).view(-1)
+                    for i in alive_idx.tolist():
                         xl = (x_center[i] - w[i] / 2).item()
                         xr = (x_center[i] + w[i] / 2).item()
                         yt = (y_center[i] - h[i] / 2).item()
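The 1024 size bins sit on a log₂ scale: bin 0 decodes to 2⁻¹⁰ ≈ 0.001 of the image side and bin 1023 to 2⁰ = 1.0. The encode direction is just the algebraic inverse of the decode expression above (whether the real encoder rounds to integer bins is an assumption here); a quick round-trip check:

    import torch

    def bin_to_size(b: torch.Tensor) -> torch.Tensor:
        # decode: bin in [0, 1023] -> normalized size in [2**-10, 1]
        return torch.pow(2.0, (b / 1023.0) * 10.0 - 10.0)

    def size_to_bin(s: torch.Tensor) -> torch.Tensor:
        # encode: inverse of the above; training-side quantization presumably
        # rounds this to the nearest integer bin
        return (torch.log2(s) + 10.0) / 10.0 * 1023.0

    bins = torch.tensor([0.0, 511.5, 1023.0])
    sizes = bin_to_size(bins)          # tensor([0.0010, 0.0312, 1.0000])
    print(size_to_bin(sizes))          # ≈ recovers [0, 511.5, 1023]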
@@ -1057,26 +1090,41 @@ class MoondreamModel(nn.Module):
                             "y_max": max(0.0, min(1.0, yb)),
                         })
 
+                    step_col = int(pos_ids[0, 0].item())
+                    mask[alive, :, :, step_col] = True
+                    logits, hidden = self._decode_one_tok(size_emb, mask, pos_ids, lora)
+                    pos_ids[alive, 0] += 1
+                    next_tok = logits.argmax(dim=-1)
+                    if next_tok.dim() == 3:    # (B,1,1) possible
+                        next_tok = next_tok.squeeze(-1).squeeze(-1)
+                    elif next_tok.dim() == 2:  # (B,1)
+                        next_tok = next_tok.squeeze(1)
                 else:
+                    # point mode
+                    alive_idx = alive.nonzero(as_tuple=False).view(-1)
+                    for i in alive_idx.tolist():
+                        out[i].append({"x": x_center[i].item(), "y": y_center[i].item()})
+                    step_col = int(pos_ids[0, 0].item())
+                    mask[alive, :, :, step_col] = True
+                    logits, hidden = self._decode_one_tok(y_emb, mask, pos_ids, lora)
+                    pos_ids[alive, 0] += 1
+                    next_tok = logits.argmax(dim=-1)
+                    if next_tok.dim() == 3:
+                        next_tok = next_tok.squeeze(-1).squeeze(-1)
+                    elif next_tok.dim() == 2:
+                        next_tok = next_tok.squeeze(1)
 
+                # we added one object/point for all ALIVE rows this iteration
+                counts[alive] += 1
+
+                # stop rows that hit eos OR reached max_objects
+                finished_now = (next_tok == eos_id) | (counts >= max_objects)
                 alive &= ~finished_now
 
         return out
 
 
+
     def detect_multi(self, image, objects, settings=None):
         if self.config.tokenizer.templates["detect"] is None:
             raise NotImplementedError("Model does not support object detection.")
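The per-row early-stop logic at the bottom of the loop is easy to get wrong, so here is a toy rehearsal of just that bookkeeping (illustrative values; in the real model `eos_id` comes from `self.config.tokenizer.eos_id`):

    import torch

    # Toy version of the loop's stopping logic: rows stop on EOS or max_objects.
    eos_id, max_objects = 50256, 3             # illustrative values
    alive = torch.tensor([True, True, True])
    counts = torch.zeros(3, dtype=torch.long)

    fake_next_toks = [                         # pretend argmax outputs per step
        torch.tensor([1, eos_id, 2]),          # row 1 emits EOS immediately
        torch.tensor([3, 0, 4]),
        torch.tensor([5, 0, 6]),               # rows 0 and 2 hit max_objects
    ]
    for next_tok in fake_next_toks:
        if not alive.any():
            break
        counts[alive] += 1
        finished_now = (next_tok == eos_id) | (counts >= max_objects)
        alive &= ~finished_now

    print(counts.tolist(), alive.tolist())     # [3, 1, 3] [False, False, False]

Dead rows keep their stale `next_tok`, but `alive &= ~finished_now` can only clear bits, so a row that stopped once stays stopped.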