Update moondream.py
fix: decode batched (B>1) labels
- moondream.py +96 -112
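For context, a minimal usage sketch of the API this commit fixes. The model handle and image path are illustrative only; `detect_multi` and the `max_objects` setting are defined in the diff below:

# Hypothetical driver code; assumes an already-constructed MoondreamModel instance `model`.
from PIL import Image

image = Image.open("street.jpg")             # illustrative path
labels = ["car", "person", "traffic light"]  # B > 1 labels, decoded as one batch

result = model.detect_multi(image, labels, settings={"max_objects": 25})
# Shape of the result, per the detect_multi docstring:
#   {"objects": {"car": [box_dict, ...], "person": [...], "traffic light": [...]}}
for label, boxes in result["objects"].items():
    print(label, len(boxes))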
moondream.py CHANGED
@@ -850,84 +850,89 @@ class MoondreamModel(nn.Module):
             b.kv_cache.k_cache[:, :, :T, :] = k.expand(batch_size, -1, -1, -1)
             b.kv_cache.v_cache[:, :, :T, :] = v.expand(batch_size, -1, -1, -1)

-    def _prefill_prompt_batched(
+    def _prefill_prompt_batched(
+        self,
+        labels,
+        pos: int,
+        lora=None,
+        temperature: float = 0.0,
+        top_p: float = 0.0,
+    ):
         """
-        Build detect prompts for many labels, pad to same length, prefill once as a batch
-
+        Build detect prompts for many labels, pad to the same length, prefill once as a batch.
+        Returns (last_hidden per row, next_token per row, shared_pos_end scalar).
         """
         tpl = self.config.tokenizer.templates["detect"]
         if tpl is None:
             raise NotImplementedError("Model does not support object detection (no detect template).")
-
+
         rows, lens = [], []
         for lab in labels:
             ids = tpl["prefix"] + self.tokenizer.encode(" " + lab).ids + tpl["suffix"]
             rows.append(torch.tensor(ids, device=self.device, dtype=torch.long))
             lens.append(len(ids))
-        B = len(rows)
+        B = len(rows)
+        T = max(lens)
         eos = self.config.tokenizer.eos_id
-
-        # Pad with eos so we can prefill
+
+        # Pad to T with eos, so we can prefill with a single shared position range
         prompt_ids = torch.full((B, T), eos, device=self.device, dtype=torch.long)
         for i, ids in enumerate(rows):
             prompt_ids[i, : ids.numel()] = ids
-
+
         # Embed & prefill once
-        prompt_emb = text_encoder(prompt_ids, self.text)
-        torch._dynamo.mark_dynamic(prompt_emb, 1)
-
-
-
+        prompt_emb = text_encoder(prompt_ids, self.text)  # (B, T, C)
+        torch._dynamo.mark_dynamic(prompt_emb, 1)  # allow variable T
+
+        # 4-D mask form makes head broadcasting unambiguous later
+        attn = self.attn_mask
+        mask = attn[:, :, pos : pos + T, :].expand(B, -1, -1, -1).contiguous()  # (B,1,T,K)
         pos_ids = torch.arange(pos, pos + T, device=self.device, dtype=torch.long)
-
+
         hidden_BTC = self._prefill(prompt_emb, mask, pos_ids, lora)  # (B, T, C)
         logits_BTV = lm_head(hidden_BTC, self.text)  # (B, T, V)
-
-        # Take the last *real* token per row
+
+        # Take the last *real* token per row
         idx = (torch.tensor(lens, device=self.device, dtype=torch.long) - 1).clamp_min(0)
-        last_hidden = hidden_BTC[torch.arange(B, device=self.device), idx][:, None, :]  # (B,
-        last_logits = logits_BTV[torch.arange(B, device=self.device), idx]  # (B,
-
+        last_hidden = hidden_BTC[torch.arange(B, device=self.device), idx][:, None, :]  # (B,1,C)
+        last_logits = logits_BTV[torch.arange(B, device=self.device), idx]  # (B,V)
+
         if temperature == 0.0:
-            next_token = last_logits.argmax(dim=-1, keepdim=True)  # (B,
+            next_token = last_logits.argmax(dim=-1, keepdim=True)  # (B,1)
         else:
             probs = torch.softmax(last_logits / temperature, dim=-1)
             probs = self._apply_top_p(probs, top_p)
-            next_token = torch.multinomial(probs, num_samples=1)  # (B,
-
-
-        return last_hidden, next_token, pos_vec  # (B,1,C), (B,1), (B,)
+            next_token = torch.multinomial(probs, num_samples=1)  # (B,1)
+
+        # Shared "next decode position" for all rows (we prefilled up to pos+T-1)
+        pos_end = pos + T
+        return last_hidden, next_token, pos_end  # (B,1,C), (B,1), int

     def _generate_points_batched(
         self,
-        hidden
-        next_token
-
+        hidden,      # (B,1,C)
+        next_token,  # (B,1)
+        pos: int,    # shared scalar next position
         include_size: bool = True,
         max_objects: int = 50,
-        lora=None
+        lora=None,
+    ):
         """
-
-
-        - Uses a shared scalar position id per step (q_len = 1), as expected by RoPE.
-        - Maintains a per-row attention mask and 'alive' flags.
-        - Feeds coord encoders with (B,1) tensors; size encoder with (B,2).
-        Returns: list-of-lists of dicts, length B.
+        Vectorized version of _generate_points() that decodes x -> y -> size -> next-token
+        for all rows in the batch simultaneously. Returns list-of-lists of dicts, len B.
         """
         B = hidden.size(0)
         device = self.device
         out = [[] for _ in range(B)]
         eos_id = self.config.tokenizer.eos_id
-
-        # Per-row initial visibility up to each row's individual prefill pos
         max_ctx = self.config.text.max_context
-        mask = torch.zeros(B, 1, max_ctx, device=device, dtype=torch.bool)
-        for i in range(B):
-            mask[i, :, : int(pos_vec[i].item())] = 1

-        #
-
+        # 4-D mask: (B, 1, q_len=1, kv_len)
+        mask = torch.zeros(B, 1, 1, max_ctx, device=device, dtype=torch.bool)
+        if pos > 0:
+            mask[:, :, :, :pos] = True
+        pos_id = torch.tensor([pos], device=device, dtype=torch.long)  # (1,)

         alive = torch.ones(B, dtype=torch.bool, device=device)
         counts = torch.zeros(B, dtype=torch.int32, device=device)
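The hunk above hinges on right-padding every prompt with `eos` to a shared length `T`, prefilling once, then gathering each row's last real token. A self-contained sketch of that gather with toy sizes (the names here are stand-ins, not code from this file):

import torch

lens = [3, 5, 2]                        # real token counts per row
B, T, C = len(lens), max(lens), 4       # pad every row to T
hidden_BTC = torch.randn(B, T, C)       # stand-in for the prefill output

# The same advanced indexing _prefill_prompt_batched uses:
idx = (torch.tensor(lens) - 1).clamp_min(0)      # last real position per row
last_hidden = hidden_BTC[torch.arange(B), idx]   # (B, C), ignores the eos padding
assert torch.equal(last_hidden[1], hidden_BTC[1, 4])  # row 1 ends at position 4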
@@ -935,90 +940,72 @@ class MoondreamModel(nn.Module):
         with torch.inference_mode():
             while alive.any() and (counts < max_objects).any():
                 # --- x coordinate ---
-                x_logits = decode_coordinate(hidden, self.region)
+                x_logits = decode_coordinate(hidden, self.region)  # (B,1,1024) or (B,1024)
                 if x_logits.dim() == 3:
-                    x_logits = x_logits.squeeze(1)
-                x_bin = x_logits.argmax(dim=-1).to(torch.float32)
-                x_center = x_bin / float(x_logits.size(-1))
-
-                x_emb = encode_coordinate(
+                    x_logits = x_logits.squeeze(1)
+                x_bin = x_logits.argmax(dim=-1).to(torch.float32)  # (B,)
+                x_center = x_bin / float(x_logits.size(-1))  # (B,)
+                x_in = x_center.to(dtype=x_logits.dtype).unsqueeze(-1)  # (B,1)
+                x_emb = encode_coordinate(x_in, self.region).unsqueeze(1)  # (B,1,C)

-                #
-                mask[
-                logits, hidden = self._decode_one_tok(
-                    x_emb,
-                    mask.unsqueeze(2),  # (B,1,1,max_ctx) ✅
-                    torch.tensor([pos], device=device, dtype=torch.long),
-                    lora,
-                )
+                # advance attention one step
+                mask[:, :, :, pos] = True
+                logits, hidden = self._decode_one_tok(x_emb, mask, pos_id, lora)
                 pos += 1
+                pos_id[0] = pos

                 # --- y coordinate ---
                 y_logits = decode_coordinate(hidden, self.region)
                 if y_logits.dim() == 3:
-                    y_logits = y_logits.squeeze(1)
+                    y_logits = y_logits.squeeze(1)
                 y_bin = y_logits.argmax(dim=-1).to(torch.float32)
-                y_center = y_bin / float(y_logits.size(-1))
-
-                y_emb = encode_coordinate(
+                y_center = y_bin / float(y_logits.size(-1))  # (B,)
+                y_in = y_center.to(dtype=y_logits.dtype).unsqueeze(-1)  # (B,1)
+                y_emb = encode_coordinate(y_in, self.region).unsqueeze(1)

-                mask[
-                logits, hidden = self._decode_one_tok(
-                    y_emb,
-                    mask.unsqueeze(2),  # (B,1,1,max_ctx) ✅
-                    torch.tensor([pos], device=device, dtype=torch.long),
-                    lora,
-                )
+                mask[:, :, :, pos] = True
+                logits, hidden = self._decode_one_tok(y_emb, mask, pos_id, lora)
                 pos += 1
+                pos_id[0] = pos

                 if include_size:
-                    # --- size
+                    # --- size ---
                     size_logits = decode_size(hidden, self.region)
                     w_logits, h_logits = size_logits[0].squeeze(1), size_logits[1].squeeze(1)
                     w_bin = w_logits.argmax(dim=-1).to(torch.float32)
                     h_bin = h_logits.argmax(dim=-1).to(torch.float32)
-                    #
+                    # bins -> size in [0,1] (inverse of log-scale mapping)
                     w = torch.pow(2.0, (w_bin / 1023.0) * 10.0 - 10.0)
                     h = torch.pow(2.0, (h_bin / 1023.0) * 10.0 - 10.0)
-
-                    size_emb = encode_size(
+                    size_in = torch.stack([w, h], dim=1).to(dtype=w_logits.dtype)  # (B,2)
+                    size_emb = encode_size(size_in, self.region).unsqueeze(1)  # (B,1,C)

-                    #
+                    # record boxes
                     for i in range(B):
-                        if
-
-
-
-
-
-
-                        })
+                        if alive[i]:
+                            out[i].append({
+                                "x_min": (x_center[i] - w[i] / 2).item(),
+                                "y_min": (y_center[i] - h[i] / 2).item(),
+                                "x_max": (x_center[i] + w[i] / 2).item(),
+                                "y_max": (y_center[i] + h[i] / 2).item(),
+                            })

-                    mask[
-                    logits, hidden = self._decode_one_tok(
-                        size_emb,
-                        mask.unsqueeze(2),  # (B,1,1,max_ctx) ✅
-                        torch.tensor([pos], device=device, dtype=torch.long),
-                        lora,
-                    )
+                    mask[:, :, :, pos] = True
+                    logits, hidden = self._decode_one_tok(size_emb, mask, pos_id, lora)
                     pos += 1
-
+                    pos_id[0] = pos
+                    next_tok = logits.argmax(dim=-1).squeeze(-1)  # (B,)
                 else:
-                    # Points mode (no size)
                     for i in range(B):
                         if alive[i]:
                             out[i].append({"x": x_center[i].item(), "y": y_center[i].item()})
-
-
-
-                        mask.unsqueeze(2),  # (B,1,1,max_ctx) ✅
-                        torch.tensor([pos], device=device, dtype=torch.long),
-                        lora,
-                    )
+
+                    mask[:, :, :, pos] = True
+                    logits, hidden = self._decode_one_tok(y_emb, mask, pos_id, lora)
                     pos += 1
-
+                    pos_id[0] = pos
+                    next_tok = logits.argmax(dim=-1).squeeze(-1)  # (B,)

-                # Finish rows that emitted EOS or hit object cap
                 finished_now = (next_tok == eos_id) | (counts >= max_objects - 1)
                 counts = counts + (~finished_now & alive).to(counts.dtype)
                 alive &= ~finished_now
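The size head above decodes 1024 log-spaced bins: `2 ** ((bin / 1023) * 10 - 10)` maps bin 0 to 2**-10 (about 0.001 of the image side) and bin 1023 to 1.0, so small boxes get most of the resolution. A round-trip sketch; the forward `size_to_bin` here is inferred from the decode formula, not taken from this file:

import torch

def bin_to_size(b: torch.Tensor) -> torch.Tensor:
    # decode used in the diff: bin in [0, 1023] -> size in [2**-10, 1]
    return torch.pow(2.0, (b / 1023.0) * 10.0 - 10.0)

def size_to_bin(s: torch.Tensor) -> torch.Tensor:
    # inferred inverse: size -> nearest bin
    return torch.round((torch.log2(s) + 10.0) / 10.0 * 1023.0)

bins = torch.tensor([0.0, 511.0, 1023.0])
sizes = bin_to_size(bins)                        # ~[0.00098, 0.0311, 1.0]
assert torch.allclose(size_to_bin(sizes), bins)  # round trip recovers the bins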
@@ -1026,8 +1013,6 @@ class MoondreamModel(nn.Module):
         return out


-
-
     def detect_multi(self, image, objects, settings=None):
         """
         Parallel multi-label detection.
@@ -1038,35 +1023,33 @@ class MoondreamModel(nn.Module):
         Returns:
             {"objects": {label: [box_dict, ...]}}
         """
-
-
         if self.config.tokenizer.templates["detect"] is None:
             raise NotImplementedError("Model does not support object detection.")
         settings = settings or {}
-
-        # Encode once; reuse caches
+
+        # Encode once; reuse caches for B rows
         image = self.encode_image(image, settings)
         B = len(objects)
         self._load_encoded_image_batched(image, B)
-
-        # Optional LoRA variant
+
+        # Optional LoRA variant
         lora = None
         if "variant" in settings:
             lora = variant_state_dict(settings["variant"], device=self.device)
-
-        # Prefill all prompts
-        last_hidden, next_token,
+
+        # Prefill all prompts as a batch; shared next position
+        last_hidden, next_token, pos_end = self._prefill_prompt_batched(
             objects, image.pos, lora=lora, temperature=0.0, top_p=0.0
         )
-
+
         # Batched decode loop
         max_objects = settings.get("max_objects", 50)
         det_lists = self._generate_points_batched(
-            last_hidden, next_token,
+            last_hidden, next_token, pos_end,
             include_size=True, max_objects=max_objects, lora=lora
         )
-
-        # Map back to labels and
+
+        # Map back to labels and tag
         res = {}
         for lab, lst in zip(objects, det_lists):
             for d in lst:
@@ -1074,6 +1057,7 @@ class MoondreamModel(nn.Module):
             res[lab] = lst
         return {"objects": res}

+
     def _detect_gaze(
         self,
         image: EncodedImage,
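One subtlety worth calling out in the decode loop: all rows step in lockstep, but termination is per-row. `finished_now` flags rows that emitted `eos` or hit the object cap, and `alive` keeps dead rows from accumulating further output. A toy sketch of just that bookkeeping (the token draws are made up):

import torch

eos_id, max_objects = 0, 3
alive = torch.ones(4, dtype=torch.bool)
counts = torch.zeros(4, dtype=torch.int32)

# Two fake decode steps for 4 rows; token 0 plays the role of eos.
for next_tok in (torch.tensor([5, 0, 7, 9]), torch.tensor([0, 0, 8, 0])):
    finished_now = (next_tok == eos_id) | (counts >= max_objects - 1)
    counts = counts + (~finished_now & alive).to(counts.dtype)  # count only live rows
    alive &= ~finished_now                                      # once dead, stays dead
print(alive.tolist(), counts.tolist())  # [False, False, True, False] [1, 0, 2, 1]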