Update moondream.py
fix: batched generation

moondream.py  +59 -59  CHANGED
@@ -943,7 +943,14 @@ class MoondreamModel(nn.Module):
             b.kv_cache.v_cache[:, :, :T, :] = v.expand(batch_size, -1, -1, -1)


-    def _prefill_prompt_batched(
+    def _prefill_prompt_batched(
+        self,
+        labels,
+        pos: int,
+        lora=None,
+        temperature: float = 0.0,
+        top_p: float = 0.0,
+    ):
         tpl = self.config.tokenizer.templates["detect"]
         if tpl is None:
             raise NotImplementedError("Model does not support object detection.")
@@ -953,43 +960,49 @@ class MoondreamModel(nn.Module):
             ids = tpl["prefix"] + self.tokenizer.encode(" " + lab).ids + tpl["suffix"]
             t = torch.tensor(ids, device=self.device, dtype=torch.long)
             rows.append(t); lens.append(t.numel())
+
         B, T = len(rows), max(lens)
         eos = self.config.tokenizer.eos_id

+        # Pad with EOS in the tensor, but we will still start generation per-row at its own length
         prompt_ids = torch.full((B, T), eos, device=self.device, dtype=torch.long)
         for i, ids in enumerate(rows):
             prompt_ids[i, : ids.numel()] = ids

-        prompt_emb = text_encoder(prompt_ids, self.text)
+        prompt_emb = text_encoder(prompt_ids, self.text)  # (B,T,C)
         torch._dynamo.mark_dynamic(prompt_emb, 1)

-        base = self.attn_mask[:, :, pos:pos+T, :]
-        mask = base.expand(B, -1, -1, -1).contiguous()
+        base = self.attn_mask[:, :, pos : pos + T, :]  # (1,1,T,K)
+        mask = base.expand(B, -1, -1, -1).contiguous()  # (B,1,T,K)

         pos_ids = torch.arange(pos, pos + T, device=self.device, dtype=torch.long)  # (T,)
-        hidden_BTC = self._prefill(prompt_emb, mask, pos_ids, lora)
-        logits_BTV = lm_head(hidden_BTC, self.text)
+        hidden_BTC = self._prefill(prompt_emb, mask, pos_ids, lora)  # (B,T,C)
+        logits_BTV = lm_head(hidden_BTC, self.text)  # (B,T,V)

+        # Gather last real token per row
+        idx = (torch.tensor(lens, device=self.device) - 1).clamp_min(0)  # (B,)
         last_hidden = hidden_BTC[torch.arange(B, device=self.device), idx][:, None, :]  # (B,1,C)
         last_logits = logits_BTV[torch.arange(B, device=self.device), idx]  # (B,V)

         if temperature == 0.0:
-            next_token = last_logits.argmax(dim=-1, keepdim=True)
+            next_token = last_logits.argmax(dim=-1, keepdim=True)  # (B,1)
         else:
             probs = torch.softmax(last_logits / temperature, dim=-1)
             probs = self._apply_top_p(probs, top_p)
-            next_token = torch.multinomial(probs, num_samples=1)
+            next_token = torch.multinomial(probs, num_samples=1)  # (B,1)
+
+        # Per-row next positions (don't force them all to pos+T)
+        pos_vec = (pos + torch.tensor(lens, device=self.device, dtype=torch.long))  # (B,)

+        return last_hidden, next_token, pos_vec
+

     def _generate_points_batched(
         self,
-        hidden,
-        next_token,
+        hidden,      # (B,1,C)
+        next_token,  # (B,1) (unused for greedy)
+        pos_vec,     # (B,) next-free position per row
         include_size: bool = True,
         max_objects: int = 50,
         lora=None,
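A minimal standalone sketch of the padding-and-gather pattern used in `_prefill_prompt_batched` above, assuming only `torch`; the toy shapes and values here are illustrative, not taken from the model:

import torch

B, T, C = 3, 5, 8                              # batch, padded length, hidden size (toy values)
lens = [5, 3, 4]                               # real prompt length of each row
hidden_BTC = torch.randn(B, T, C)              # stand-in for the prefill output

# Last real token per row: length - 1, clamped so an empty row cannot go negative.
idx = (torch.tensor(lens) - 1).clamp_min(0)            # (B,)
last_hidden = hidden_BTC[torch.arange(B), idx]          # (B, C)

# Each row's next free KV-cache slot starts right after its own prompt, not at pos + T.
pos = 10                                                # illustrative prefix length
pos_vec = pos + torch.tensor(lens)                      # (B,) -> tensor([15, 13, 14])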
@@ -999,18 +1012,17 @@ class MoondreamModel(nn.Module):
         device = self.device
         out = [[] for _ in range(B)]
         eos_id = self.config.tokenizer.eos_id
+        coord_id = self.config.tokenizer.coord_id
         max_ctx = self.config.text.max_context

+        # Build per-row masks/positions
         mask = torch.zeros(B, 1, 1, max_ctx, device=device, dtype=torch.bool)
-        # per-row position ids (B,1)
-        pos_ids = torch.full((B, 1), p0, device=device, dtype=torch.long)
-        # helper: (B, bins) -> (B,) in [0,1]
+        pos_ids = pos_vec.clone().view(B, 1)  # (B,1)
+        for i in range(B):
+            p0 = int(pos_ids[i, 0].item())
+            if p0 > 0:
+                mask[i, 0, 0, :p0] = True

         def _argmax01(logits: torch.Tensor) -> torch.Tensor:
             if use_soft_argmax:
                 probs = torch.softmax(logits, dim=-1)
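For reference, a small sketch of the per-row attention-mask setup that `_generate_points_batched` builds above, assuming only `torch`; a vectorized equivalent of the Python loop is included for comparison:

import torch

B, max_ctx = 3, 16
pos_vec = torch.tensor([5, 7, 6])                       # next free position per row (toy values)

# Loop form, as in the diff: each row may attend to its own prefix only.
mask = torch.zeros(B, 1, 1, max_ctx, dtype=torch.bool)
for i in range(B):
    mask[i, 0, 0, : int(pos_vec[i])] = True

# Loop-free equivalent: compare column indices against each row's prefix length.
cols = torch.arange(max_ctx)
mask_vec = (cols[None, :] < pos_vec[:, None]).view(B, 1, 1, max_ctx)
assert torch.equal(mask, mask_vec)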
@@ -1019,14 +1031,11 @@ class MoondreamModel(nn.Module):
             idx = logits.argmax(dim=-1).to(torch.float32)
             return idx / float(logits.size(-1) - 1)

-        # advance-one-step for a subset of rows (alive only)
         def _advance_rows(row_mask: torch.Tensor):
             idx = row_mask.nonzero(as_tuple=False).flatten()
-            # set each row's next KV column true
             for i in idx.tolist():
                 col = int(pos_ids[i, 0].item())
                 mask[i, 0, 0, col] = True
-            # decoder step (all rows run, but only alive rows' pos_ids move)
             return idx

         alive = torch.ones(B, dtype=torch.bool, device=device)
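The `_argmax01` helper normalizes a bin distribution to a coordinate in [0, 1]. The hard branch is fully visible above; the soft branch is only partially visible in this hunk, so the probability-weighted version below is an assumption about its shape rather than a copy of it:

import torch

logits = torch.randn(2, 1024)                           # (B, bins), toy input

# Hard: index of the best bin, rescaled to [0, 1].
hard = logits.argmax(dim=-1).float() / (logits.size(-1) - 1)

# Soft (assumed): expectation of the bin position under softmax, also in [0, 1].
probs = torch.softmax(logits, dim=-1)
bins = torch.arange(logits.size(-1), dtype=torch.float32)
soft = (probs * bins).sum(dim=-1) / (logits.size(-1) - 1)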
@@ -1034,39 +1043,29 @@ class MoondreamModel(nn.Module):

         with torch.inference_mode():
             while alive.any() and (counts < max_objects).any():
-                x_center = _argmax01(x_logits)  # (B,)
-                x_emb = encode_coordinate(
-                    x_center.to(dtype=x_logits.dtype).unsqueeze(-1),  # (B,1)
-                    self.region
-                ).unsqueeze(1)  # (B,1,C)
+                # -------- x --------
+                x_logits = decode_coordinate(hidden, self.region)  # (B,1,1024) or (B,1024)
+                if x_logits.dim() == 3: x_logits = x_logits.squeeze(1)
+                x_center = _argmax01(x_logits)  # (B,)
+                x_emb = encode_coordinate(x_center.to(dtype=x_logits.dtype).unsqueeze(-1), self.region).unsqueeze(1)

                 idx = _advance_rows(alive)
                 logits, hidden = self._decode_one_tok(x_emb, mask, pos_ids, lora)
                 pos_ids[idx, 0] += 1

+                # -------- y --------
                 y_logits = decode_coordinate(hidden, self.region)
-                if y_logits.dim() == 3:
-                    y_logits = y_logits.squeeze(1)
+                if y_logits.dim() == 3: y_logits = y_logits.squeeze(1)
                 y_center = _argmax01(y_logits)
-                y_emb = encode_coordinate(
-                    y_center.to(dtype=y_logits.dtype).unsqueeze(-1),
-                    self.region
-                ).unsqueeze(1)
+                y_emb = encode_coordinate(y_center.to(dtype=y_logits.dtype).unsqueeze(-1), self.region).unsqueeze(1)

                 idx = _advance_rows(alive)
                 logits, hidden = self._decode_one_tok(y_emb, mask, pos_ids, lora)
                 pos_ids[idx, 0] += 1

                 if include_size:
-                    # ------------- size (w,h) -------------
                     size_ret = decode_size(hidden, self.region)
-                    w_logits, h_logits = self._norm_size_logits(size_ret, B)
+                    w_logits, h_logits = self._norm_size_logits(size_ret, B)  # (B,C)

                     if use_soft_argmax:
                         bins = torch.arange(w_logits.size(-1), device=device, dtype=torch.float32)
@@ -1076,16 +1075,12 @@ class MoondreamModel(nn.Module):
                         w_bin = w_logits.argmax(dim=-1).to(torch.float32)
                         h_bin = h_logits.argmax(dim=-1).to(torch.float32)

-                    h = torch.pow(2.0, (h_bin / 1023.0) * 10.0 - 10.0)  # (B,)
+                    w = torch.pow(2.0, (w_bin / 1023.0) * 10.0 - 10.0)
+                    h = torch.pow(2.0, (h_bin / 1023.0) * 10.0 - 10.0)

-                    size_emb = encode_size(
-                        torch.stack([w, h], dim=1).to(dtype=w_logits.dtype),  # (B,2)
-                        self.region
-                    ).unsqueeze(1)  # (B,1,C)
+                    size_emb = encode_size(torch.stack([w, h], dim=1).to(dtype=w_logits.dtype), self.region).unsqueeze(1)

-                    # record boxes only for
+                    # record boxes only for rows still alive
                     for i in alive.nonzero(as_tuple=False).flatten().tolist():
                         xl = (x_center[i] - w[i] / 2).item()
                         xr = (x_center[i] + w[i] / 2).item()
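The size decode above maps a 1024-way bin index to a normalized edge length on a log scale, size = 2^((bin / 1023) * 10 - 10), so bin 0 gives roughly 0.001 and bin 1023 gives exactly 1.0. A quick numeric check, assuming plain `torch` tensors:

import torch

bins = torch.tensor([0.0, 511.0, 1023.0])               # example argmax bin indices
size = torch.pow(2.0, (bins / 1023.0) * 10.0 - 10.0)
# size is approximately [0.0010, 0.0311, 1.0000]: fine resolution for small boxes, coarse for large ones.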
@@ -1103,7 +1098,6 @@ class MoondreamModel(nn.Module):
                     pos_ids[idx, 0] += 1
                     next_tok = logits.argmax(dim=-1)
                 else:
-                    # point mode
                     for i in alive.nonzero(as_tuple=False).flatten().tolist():
                         out[i].append({"x": x_center[i].item(), "y": y_center[i].item()})
                     idx = _advance_rows(alive)
@@ -1111,16 +1105,22 @@ class MoondreamModel(nn.Module):
                     pos_ids[idx, 0] += 1
                     next_tok = logits.argmax(dim=-1)

-                # normalize next_tok to
+                # normalize next_tok to (B,)
                 while next_tok.dim() > 1:
                     next_tok = next_tok.squeeze(-1)

+                # we added exactly one object/point to all alive rows
                 counts[alive] += 1
+
+                # GRAMMAR STOP: only continue if the model asks to start another coord;
+                # otherwise stop row (covers EOS or any non-coord token).
+                continue_mask = (next_tok == coord_id)
+                finished_now = (~continue_mask) | (counts >= max_objects)
                 alive &= ~finished_now

         return out

+
     def detect_multi(self, image, objects, settings=None):
         if self.config.tokenizer.templates["detect"] is None:
             raise NotImplementedError("Model does not support object detection.")
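A toy trace of the stop bookkeeping introduced above, assuming only `torch`: a row stays alive only while the model emits the coordinate-start token and its object count is still under the cap (the token ids here are made up):

import torch

coord_id, max_objects = 7, 50                           # illustrative values
next_tok = torch.tensor([7, 7, 2])                      # rows 0-1 ask for another coord, row 2 does not
counts = torch.tensor([1, 50, 3])

continue_mask = next_tok == coord_id
finished_now = (~continue_mask) | (counts >= max_objects)
alive = torch.ones(3, dtype=torch.bool) & ~finished_now
# alive -> tensor([ True, False, False]): row 1 hit the cap, row 2 stopped emitting coordinates.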
@@ -1132,17 +1132,17 @@ class MoondreamModel(nn.Module):

         lora = variant_state_dict(settings["variant"], device=self.device) if "variant" in settings else None

-        last_hidden, next_token,
+        last_hidden, next_token, pos_vec = self._prefill_prompt_batched(
             objects, enc.pos, lora=lora, temperature=0.0, top_p=0.0
         )
+
         det_lists = self._generate_points_batched(
-            last_hidden, next_token,
+            last_hidden, next_token, pos_vec,
             include_size=True,
             max_objects=settings.get("max_objects", 50),
             lora=lora,
         )
+
         res = {}
         for lab, lst in zip(objects, det_lists):
             for d in lst:
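For context, a hypothetical usage of the batched path. Loading the model and the exact return shape of `detect_multi` are outside this diff, so the snippet below is a sketch under those assumptions: it presumes `model` is an already-initialized MoondreamModel and that the result maps each label to its detections, as the `res` dict above suggests.

from PIL import Image

# `model` is assumed to be an already-loaded MoondreamModel; loading is not shown in this diff.
image = Image.open("street.jpg")                        # any test image
labels = ["car", "person", "traffic light"]

# One batched prefill + one shared decode loop for all labels.
result = model.detect_multi(image, labels, settings={"max_objects": 20})
for label, detections in result.items():
    print(label, detections)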