Update moondream.py
Browse filesfix: _generate_points_batched to advance the mask per_row & never the batch to row_0
- moondream.py +31 -35
moondream.py
CHANGED
|
@@ -993,7 +993,7 @@ class MoondreamModel(nn.Module):
|
|
| 993 |
include_size: bool = True,
|
| 994 |
max_objects: int = 50,
|
| 995 |
lora=None,
|
| 996 |
-
use_soft_argmax: bool = True,
|
| 997 |
):
|
| 998 |
B = hidden.size(0)
|
| 999 |
device = self.device
|
|
@@ -1006,6 +1006,7 @@ class MoondreamModel(nn.Module):
|
|
| 1006 |
p0 = int(pos)
|
| 1007 |
if p0 > 0:
|
| 1008 |
mask[:, :, :, :p0] = True
|
|
|
|
| 1009 |
# per-row position ids (B,1)
|
| 1010 |
pos_ids = torch.full((B, 1), p0, device=device, dtype=torch.long)
|
| 1011 |
|
|
@@ -1018,41 +1019,49 @@ class MoondreamModel(nn.Module):
|
|
| 1018 |
idx = logits.argmax(dim=-1).to(torch.float32)
|
| 1019 |
return idx / float(logits.size(-1) - 1)
|
| 1020 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1021 |
alive = torch.ones(B, dtype=torch.bool, device=device)
|
| 1022 |
counts = torch.zeros(B, dtype=torch.int32, device=device)
|
| 1023 |
|
| 1024 |
with torch.inference_mode():
|
| 1025 |
while alive.any() and (counts < max_objects).any():
|
|
|
|
| 1026 |
# ---------------- x ----------------
|
| 1027 |
x_logits = decode_coordinate(hidden, self.region) # (B,1,1024) or (B,1024)
|
| 1028 |
if x_logits.dim() == 3:
|
| 1029 |
x_logits = x_logits.squeeze(1) # -> (B,1024)
|
| 1030 |
x_center = _argmax01(x_logits) # (B,)
|
| 1031 |
x_emb = encode_coordinate(
|
| 1032 |
-
x_center.to(dtype=x_logits.dtype).unsqueeze(-1),
|
| 1033 |
self.region
|
| 1034 |
).unsqueeze(1) # (B,1,C)
|
| 1035 |
|
| 1036 |
-
|
| 1037 |
-
step_col = int(pos_ids[0, 0].item())
|
| 1038 |
-
mask[alive, :, :, step_col] = True
|
| 1039 |
logits, hidden = self._decode_one_tok(x_emb, mask, pos_ids, lora)
|
| 1040 |
-
pos_ids[
|
| 1041 |
|
| 1042 |
# ---------------- y ----------------
|
| 1043 |
y_logits = decode_coordinate(hidden, self.region)
|
| 1044 |
if y_logits.dim() == 3:
|
| 1045 |
-
y_logits = y_logits.squeeze(1)
|
| 1046 |
-
y_center = _argmax01(y_logits)
|
| 1047 |
y_emb = encode_coordinate(
|
| 1048 |
y_center.to(dtype=y_logits.dtype).unsqueeze(-1),
|
| 1049 |
self.region
|
| 1050 |
-
).unsqueeze(1)
|
| 1051 |
|
| 1052 |
-
|
| 1053 |
-
mask[alive, :, :, step_col] = True
|
| 1054 |
logits, hidden = self._decode_one_tok(y_emb, mask, pos_ids, lora)
|
| 1055 |
-
pos_ids[
|
| 1056 |
|
| 1057 |
if include_size:
|
| 1058 |
# ------------- size (w,h) -------------
|
|
@@ -1077,8 +1086,7 @@ class MoondreamModel(nn.Module):
|
|
| 1077 |
).unsqueeze(1) # (B,1,C)
|
| 1078 |
|
| 1079 |
# record boxes only for ALIVE rows
|
| 1080 |
-
|
| 1081 |
-
for i in alive_idx.tolist():
|
| 1082 |
xl = (x_center[i] - w[i] / 2).item()
|
| 1083 |
xr = (x_center[i] + w[i] / 2).item()
|
| 1084 |
yt = (y_center[i] - h[i] / 2).item()
|
|
@@ -1090,41 +1098,29 @@ class MoondreamModel(nn.Module):
|
|
| 1090 |
"y_max": max(0.0, min(1.0, yb)),
|
| 1091 |
})
|
| 1092 |
|
| 1093 |
-
|
| 1094 |
-
mask[alive, :, :, step_col] = True
|
| 1095 |
logits, hidden = self._decode_one_tok(size_emb, mask, pos_ids, lora)
|
| 1096 |
-
pos_ids[
|
| 1097 |
next_tok = logits.argmax(dim=-1)
|
| 1098 |
-
if next_tok.dim() == 3: # (B,1,1) possible
|
| 1099 |
-
next_tok = next_tok.squeeze(-1).squeeze(-1)
|
| 1100 |
-
elif next_tok.dim() == 2: # (B,1)
|
| 1101 |
-
next_tok = next_tok.squeeze(1)
|
| 1102 |
else:
|
| 1103 |
# point mode
|
| 1104 |
-
|
| 1105 |
-
for i in alive_idx.tolist():
|
| 1106 |
out[i].append({"x": x_center[i].item(), "y": y_center[i].item()})
|
| 1107 |
-
|
| 1108 |
-
mask[alive, :, :, step_col] = True
|
| 1109 |
logits, hidden = self._decode_one_tok(y_emb, mask, pos_ids, lora)
|
| 1110 |
-
pos_ids[
|
| 1111 |
next_tok = logits.argmax(dim=-1)
|
| 1112 |
-
if next_tok.dim() == 3:
|
| 1113 |
-
next_tok = next_tok.squeeze(-1).squeeze(-1)
|
| 1114 |
-
elif next_tok.dim() == 2:
|
| 1115 |
-
next_tok = next_tok.squeeze(1)
|
| 1116 |
|
| 1117 |
-
#
|
| 1118 |
-
|
|
|
|
| 1119 |
|
| 1120 |
-
|
| 1121 |
finished_now = (next_tok == eos_id) | (counts >= max_objects)
|
| 1122 |
alive &= ~finished_now
|
| 1123 |
|
| 1124 |
return out
|
| 1125 |
|
| 1126 |
-
|
| 1127 |
-
|
| 1128 |
def detect_multi(self, image, objects, settings=None):
|
| 1129 |
if self.config.tokenizer.templates["detect"] is None:
|
| 1130 |
raise NotImplementedError("Model does not support object detection.")
|
|
|
|
| 993 |
include_size: bool = True,
|
| 994 |
max_objects: int = 50,
|
| 995 |
lora=None,
|
| 996 |
+
use_soft_argmax: bool = True,
|
| 997 |
):
|
| 998 |
B = hidden.size(0)
|
| 999 |
device = self.device
|
|
|
|
| 1006 |
p0 = int(pos)
|
| 1007 |
if p0 > 0:
|
| 1008 |
mask[:, :, :, :p0] = True
|
| 1009 |
+
|
| 1010 |
# per-row position ids (B,1)
|
| 1011 |
pos_ids = torch.full((B, 1), p0, device=device, dtype=torch.long)
|
| 1012 |
|
|
|
|
| 1019 |
idx = logits.argmax(dim=-1).to(torch.float32)
|
| 1020 |
return idx / float(logits.size(-1) - 1)
|
| 1021 |
|
| 1022 |
+
# advance-one-step for a subset of rows (alive only)
|
| 1023 |
+
def _advance_rows(row_mask: torch.Tensor):
|
| 1024 |
+
idx = row_mask.nonzero(as_tuple=False).flatten()
|
| 1025 |
+
# set each row's next KV column true
|
| 1026 |
+
for i in idx.tolist():
|
| 1027 |
+
col = int(pos_ids[i, 0].item())
|
| 1028 |
+
mask[i, 0, 0, col] = True
|
| 1029 |
+
# decoder step (all rows run, but only alive rows’ pos_ids move)
|
| 1030 |
+
return idx
|
| 1031 |
+
|
| 1032 |
alive = torch.ones(B, dtype=torch.bool, device=device)
|
| 1033 |
counts = torch.zeros(B, dtype=torch.int32, device=device)
|
| 1034 |
|
| 1035 |
with torch.inference_mode():
|
| 1036 |
while alive.any() and (counts < max_objects).any():
|
| 1037 |
+
|
| 1038 |
# ---------------- x ----------------
|
| 1039 |
x_logits = decode_coordinate(hidden, self.region) # (B,1,1024) or (B,1024)
|
| 1040 |
if x_logits.dim() == 3:
|
| 1041 |
x_logits = x_logits.squeeze(1) # -> (B,1024)
|
| 1042 |
x_center = _argmax01(x_logits) # (B,)
|
| 1043 |
x_emb = encode_coordinate(
|
| 1044 |
+
x_center.to(dtype=x_logits.dtype).unsqueeze(-1), # (B,1)
|
| 1045 |
self.region
|
| 1046 |
).unsqueeze(1) # (B,1,C)
|
| 1047 |
|
| 1048 |
+
idx = _advance_rows(alive)
|
|
|
|
|
|
|
| 1049 |
logits, hidden = self._decode_one_tok(x_emb, mask, pos_ids, lora)
|
| 1050 |
+
pos_ids[idx, 0] += 1
|
| 1051 |
|
| 1052 |
# ---------------- y ----------------
|
| 1053 |
y_logits = decode_coordinate(hidden, self.region)
|
| 1054 |
if y_logits.dim() == 3:
|
| 1055 |
+
y_logits = y_logits.squeeze(1)
|
| 1056 |
+
y_center = _argmax01(y_logits)
|
| 1057 |
y_emb = encode_coordinate(
|
| 1058 |
y_center.to(dtype=y_logits.dtype).unsqueeze(-1),
|
| 1059 |
self.region
|
| 1060 |
+
).unsqueeze(1)
|
| 1061 |
|
| 1062 |
+
idx = _advance_rows(alive)
|
|
|
|
| 1063 |
logits, hidden = self._decode_one_tok(y_emb, mask, pos_ids, lora)
|
| 1064 |
+
pos_ids[idx, 0] += 1
|
| 1065 |
|
| 1066 |
if include_size:
|
| 1067 |
# ------------- size (w,h) -------------
|
|
|
|
| 1086 |
).unsqueeze(1) # (B,1,C)
|
| 1087 |
|
| 1088 |
# record boxes only for ALIVE rows
|
| 1089 |
+
for i in alive.nonzero(as_tuple=False).flatten().tolist():
|
|
|
|
| 1090 |
xl = (x_center[i] - w[i] / 2).item()
|
| 1091 |
xr = (x_center[i] + w[i] / 2).item()
|
| 1092 |
yt = (y_center[i] - h[i] / 2).item()
|
|
|
|
| 1098 |
"y_max": max(0.0, min(1.0, yb)),
|
| 1099 |
})
|
| 1100 |
|
| 1101 |
+
idx = _advance_rows(alive)
|
|
|
|
| 1102 |
logits, hidden = self._decode_one_tok(size_emb, mask, pos_ids, lora)
|
| 1103 |
+
pos_ids[idx, 0] += 1
|
| 1104 |
next_tok = logits.argmax(dim=-1)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1105 |
else:
|
| 1106 |
# point mode
|
| 1107 |
+
for i in alive.nonzero(as_tuple=False).flatten().tolist():
|
|
|
|
| 1108 |
out[i].append({"x": x_center[i].item(), "y": y_center[i].item()})
|
| 1109 |
+
idx = _advance_rows(alive)
|
|
|
|
| 1110 |
logits, hidden = self._decode_one_tok(y_emb, mask, pos_ids, lora)
|
| 1111 |
+
pos_ids[idx, 0] += 1
|
| 1112 |
next_tok = logits.argmax(dim=-1)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1113 |
|
| 1114 |
+
# normalize next_tok to shape (B,)
|
| 1115 |
+
while next_tok.dim() > 1:
|
| 1116 |
+
next_tok = next_tok.squeeze(-1)
|
| 1117 |
|
| 1118 |
+
counts[alive] += 1
|
| 1119 |
finished_now = (next_tok == eos_id) | (counts >= max_objects)
|
| 1120 |
alive &= ~finished_now
|
| 1121 |
|
| 1122 |
return out
|
| 1123 |
|
|
|
|
|
|
|
| 1124 |
def detect_multi(self, image, objects, settings=None):
|
| 1125 |
if self.config.tokenizer.templates["detect"] is None:
|
| 1126 |
raise NotImplementedError("Model does not support object detection.")
|