Update moondream.py
Browse filesfix: dimension mismatch for batched point generation
- moondream.py +15 -14
moondream.py
CHANGED
|
@@ -22,6 +22,10 @@ from .region import (
|
|
| 22 |
from .layers import QuantizedLinear
|
| 23 |
from .lora import variant_state_dict
|
| 24 |
from .utils import remove_outlier_points
|
|
|
|
|
|
|
|
|
|
|
|
|
| 25 |
|
| 26 |
ImageEncodingSettings = TypedDict(
|
| 27 |
"ImageEncodingSettings",
|
|
@@ -851,9 +855,6 @@ class MoondreamModel(nn.Module):
|
|
| 851 |
Build detect prompts for many labels, pad to same length, prefill once as a batch,
|
| 852 |
then return (last_hidden per row, next_token per row, pos per row).
|
| 853 |
"""
|
| 854 |
-
import torch
|
| 855 |
-
from .text import text_encoder, lm_head
|
| 856 |
-
|
| 857 |
tpl = self.config.tokenizer.templates["detect"]
|
| 858 |
if tpl is None:
|
| 859 |
raise NotImplementedError("Model does not support object detection (no detect template).")
|
|
@@ -873,7 +874,6 @@ class MoondreamModel(nn.Module):
|
|
| 873 |
|
| 874 |
# Embed & prefill once
|
| 875 |
prompt_emb = text_encoder(prompt_ids, self.text) # (B, T, C)
|
| 876 |
-
import torch
|
| 877 |
torch._dynamo.mark_dynamic(prompt_emb, 1) # allow variable T
|
| 878 |
|
| 879 |
attn_mask = self.attn_mask
|
|
@@ -905,8 +905,6 @@ class MoondreamModel(nn.Module):
|
|
| 905 |
for all rows in the batch simultaneously.
|
| 906 |
Returns: list-of-lists of dicts, length B.
|
| 907 |
"""
|
| 908 |
-
import torch
|
| 909 |
-
from .region import decode_coordinate, encode_coordinate, decode_size, encode_size
|
| 910 |
|
| 911 |
B = hidden.size(0)
|
| 912 |
device = self.device
|
|
@@ -930,8 +928,10 @@ class MoondreamModel(nn.Module):
|
|
| 930 |
if x_logits.dim() == 3:
|
| 931 |
x_logits = x_logits.squeeze(1) # (B, 1024)
|
| 932 |
x_bin = x_logits.argmax(dim=-1).to(torch.float32) # (B,)
|
| 933 |
-
x_center = x_bin / float(x_logits.size(-1))
|
| 934 |
-
|
|
|
|
|
|
|
| 935 |
|
| 936 |
# step: decode to get hidden for y
|
| 937 |
for i in range(B):
|
|
@@ -945,8 +945,10 @@ class MoondreamModel(nn.Module):
|
|
| 945 |
if y_logits.dim() == 3:
|
| 946 |
y_logits = y_logits.squeeze(1) # (B, 1024)
|
| 947 |
y_bin = y_logits.argmax(dim=-1).to(torch.float32)
|
| 948 |
-
y_center = y_bin / float(y_logits.size(-1))
|
| 949 |
-
|
|
|
|
|
|
|
| 950 |
|
| 951 |
# step: decode to get hidden for size (or eos)
|
| 952 |
for i in range(B):
|
|
@@ -964,7 +966,8 @@ class MoondreamModel(nn.Module):
|
|
| 964 |
# Convert from log-scale bin to size in [0,1]
|
| 965 |
w = torch.pow(2.0, (w_bin / 1023.0) * 10.0 - 10.0)
|
| 966 |
h = torch.pow(2.0, (h_bin / 1023.0) * 10.0 - 10.0)
|
| 967 |
-
|
|
|
|
| 968 |
|
| 969 |
# Commit boxes for alive rows
|
| 970 |
for i in range(B):
|
|
@@ -1015,8 +1018,7 @@ class MoondreamModel(nn.Module):
|
|
| 1015 |
Returns:
|
| 1016 |
{"objects": {label: [box_dict, ...]}}
|
| 1017 |
"""
|
| 1018 |
-
|
| 1019 |
-
from typing import Optional, List, Union
|
| 1020 |
|
| 1021 |
if self.config.tokenizer.templates["detect"] is None:
|
| 1022 |
raise NotImplementedError("Model does not support object detection.")
|
|
@@ -1030,7 +1032,6 @@ class MoondreamModel(nn.Module):
|
|
| 1030 |
# Optional LoRA variant (same as detect())
|
| 1031 |
lora = None
|
| 1032 |
if "variant" in settings:
|
| 1033 |
-
from .lora import variant_state_dict
|
| 1034 |
lora = variant_state_dict(settings["variant"], device=self.device)
|
| 1035 |
|
| 1036 |
# Prefill all prompts at once
|
|
|
|
| 22 |
from .layers import QuantizedLinear
|
| 23 |
from .lora import variant_state_dict
|
| 24 |
from .utils import remove_outlier_points
|
| 25 |
+
from .region import decode_coordinate, encode_coordinate, decode_size, encode_size
|
| 26 |
+
from .text import text_encoder, lm_head
|
| 27 |
+
from typing import Optional, List, Union
|
| 28 |
+
from .lora import variant_state_dict
|
| 29 |
|
| 30 |
ImageEncodingSettings = TypedDict(
|
| 31 |
"ImageEncodingSettings",
|
|
|
|
| 855 |
Build detect prompts for many labels, pad to same length, prefill once as a batch,
|
| 856 |
then return (last_hidden per row, next_token per row, pos per row).
|
| 857 |
"""
|
|
|
|
|
|
|
|
|
|
| 858 |
tpl = self.config.tokenizer.templates["detect"]
|
| 859 |
if tpl is None:
|
| 860 |
raise NotImplementedError("Model does not support object detection (no detect template).")
|
|
|
|
| 874 |
|
| 875 |
# Embed & prefill once
|
| 876 |
prompt_emb = text_encoder(prompt_ids, self.text) # (B, T, C)
|
|
|
|
| 877 |
torch._dynamo.mark_dynamic(prompt_emb, 1) # allow variable T
|
| 878 |
|
| 879 |
attn_mask = self.attn_mask
|
|
|
|
| 905 |
for all rows in the batch simultaneously.
|
| 906 |
Returns: list-of-lists of dicts, length B.
|
| 907 |
"""
|
|
|
|
|
|
|
| 908 |
|
| 909 |
B = hidden.size(0)
|
| 910 |
device = self.device
|
|
|
|
| 928 |
if x_logits.dim() == 3:
|
| 929 |
x_logits = x_logits.squeeze(1) # (B, 1024)
|
| 930 |
x_bin = x_logits.argmax(dim=-1).to(torch.float32) # (B,)
|
| 931 |
+
x_center = x_bin / float(x_logits.size(-1)) # (B,)
|
| 932 |
+
x_input = x_center.to(dtype=x_logits.dtype).unsqueeze(-1) # (B, 1) ✅
|
| 933 |
+
x_emb = encode_coordinate(x_input, self.region).unsqueeze(1) # (B,1,C)
|
| 934 |
+
|
| 935 |
|
| 936 |
# step: decode to get hidden for y
|
| 937 |
for i in range(B):
|
|
|
|
| 945 |
if y_logits.dim() == 3:
|
| 946 |
y_logits = y_logits.squeeze(1) # (B, 1024)
|
| 947 |
y_bin = y_logits.argmax(dim=-1).to(torch.float32)
|
| 948 |
+
y_center = y_bin / float(y_logits.size(-1)) # (B,)
|
| 949 |
+
y_input = y_center.to(dtype=y_logits.dtype).unsqueeze(-1) # (B, 1) ✅
|
| 950 |
+
y_emb = encode_coordinate(y_input, self.region).unsqueeze(1)
|
| 951 |
+
|
| 952 |
|
| 953 |
# step: decode to get hidden for size (or eos)
|
| 954 |
for i in range(B):
|
|
|
|
| 966 |
# Convert from log-scale bin to size in [0,1]
|
| 967 |
w = torch.pow(2.0, (w_bin / 1023.0) * 10.0 - 10.0)
|
| 968 |
h = torch.pow(2.0, (h_bin / 1023.0) * 10.0 - 10.0)
|
| 969 |
+
size_input = torch.stack([w, h], dim=1).to(dtype=w_logits.dtype) # (B, 2) ✅
|
| 970 |
+
size_emb = encode_size(size_input, self.region).unsqueeze(1)
|
| 971 |
|
| 972 |
# Commit boxes for alive rows
|
| 973 |
for i in range(B):
|
|
|
|
| 1018 |
Returns:
|
| 1019 |
{"objects": {label: [box_dict, ...]}}
|
| 1020 |
"""
|
| 1021 |
+
|
|
|
|
| 1022 |
|
| 1023 |
if self.config.tokenizer.templates["detect"] is None:
|
| 1024 |
raise NotImplementedError("Model does not support object detection.")
|
|
|
|
| 1032 |
# Optional LoRA variant (same as detect())
|
| 1033 |
lora = None
|
| 1034 |
if "variant" in settings:
|
|
|
|
| 1035 |
lora = variant_state_dict(settings["variant"], device=self.device)
|
| 1036 |
|
| 1037 |
# Prefill all prompts at once
|