HV-Khurdula committed · verified · Commit cdfd7db · 1 Parent(s): a2fbee5

Update moondream.py


fix: dimension mismatch for batched point generation

Files changed (1):
  1. moondream.py (+15 -14)
moondream.py CHANGED

@@ -22,6 +22,10 @@ from .region import (
 from .layers import QuantizedLinear
 from .lora import variant_state_dict
 from .utils import remove_outlier_points
+from .region import decode_coordinate, encode_coordinate, decode_size, encode_size
+from .text import text_encoder, lm_head
+from typing import Optional, List, Union
+from .lora import variant_state_dict
 
 ImageEncodingSettings = TypedDict(
     "ImageEncodingSettings",
@@ -851,9 +855,6 @@ class MoondreamModel(nn.Module):
         Build detect prompts for many labels, pad to same length, prefill once as a batch,
         then return (last_hidden per row, next_token per row, pos per row).
         """
-        import torch
-        from .text import text_encoder, lm_head
-
         tpl = self.config.tokenizer.templates["detect"]
         if tpl is None:
             raise NotImplementedError("Model does not support object detection (no detect template).")
@@ -873,7 +874,6 @@
 
         # Embed & prefill once
         prompt_emb = text_encoder(prompt_ids, self.text)  # (B, T, C)
-        import torch
         torch._dynamo.mark_dynamic(prompt_emb, 1)  # allow variable T
 
         attn_mask = self.attn_mask
@@ -905,8 +905,6 @@
         for all rows in the batch simultaneously.
         Returns: list-of-lists of dicts, length B.
         """
-        import torch
-        from .region import decode_coordinate, encode_coordinate, decode_size, encode_size
 
         B = hidden.size(0)
         device = self.device
@@ -930,8 +928,10 @@
         if x_logits.dim() == 3:
             x_logits = x_logits.squeeze(1)  # (B, 1024)
         x_bin = x_logits.argmax(dim=-1).to(torch.float32)  # (B,)
-        x_center = x_bin / float(x_logits.size(-1))  # normalize to [0,1]
-        x_emb = encode_coordinate(x_center.to(dtype=x_logits.dtype), self.region).unsqueeze(1)  # (B,1,C)
+        x_center = x_bin / float(x_logits.size(-1))  # (B,)
+        x_input = x_center.to(dtype=x_logits.dtype).unsqueeze(-1)  # (B, 1)
+        x_emb = encode_coordinate(x_input, self.region).unsqueeze(1)  # (B,1,C)
+
 
         # step: decode to get hidden for y
         for i in range(B):
@@ -945,8 +945,10 @@
         if y_logits.dim() == 3:
             y_logits = y_logits.squeeze(1)  # (B, 1024)
         y_bin = y_logits.argmax(dim=-1).to(torch.float32)
-        y_center = y_bin / float(y_logits.size(-1))
-        y_emb = encode_coordinate(y_center.to(dtype=y_logits.dtype), self.region).unsqueeze(1)
+        y_center = y_bin / float(y_logits.size(-1))  # (B,)
+        y_input = y_center.to(dtype=y_logits.dtype).unsqueeze(-1)  # (B, 1) ✅
+        y_emb = encode_coordinate(y_input, self.region).unsqueeze(1)
+
 
         # step: decode to get hidden for size (or eos)
         for i in range(B):
@@ -964,7 +966,8 @@
         # Convert from log-scale bin to size in [0,1]
         w = torch.pow(2.0, (w_bin / 1023.0) * 10.0 - 10.0)
         h = torch.pow(2.0, (h_bin / 1023.0) * 10.0 - 10.0)
-        size_emb = encode_size(torch.stack([w, h], dim=0), self.region).transpose(0,1).unsqueeze(1)  # (B,1,C)
+        size_input = torch.stack([w, h], dim=1).to(dtype=w_logits.dtype)  # (B, 2)
+        size_emb = encode_size(size_input, self.region).unsqueeze(1)
 
         # Commit boxes for alive rows
         for i in range(B):
@@ -1015,8 +1018,7 @@
         Returns:
             {"objects": {label: [box_dict, ...]}}
         """
-        import torch
-        from typing import Optional, List, Union
+
 
         if self.config.tokenizer.templates["detect"] is None:
             raise NotImplementedError("Model does not support object detection.")
@@ -1030,7 +1032,6 @@
         # Optional LoRA variant (same as detect())
         lora = None
         if "variant" in settings:
-            from .lora import variant_state_dict
             lora = variant_state_dict(settings["variant"], device=self.device)
 
         # Prefill all prompts at once
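The heart of the fix is the trailing feature dimension on the coordinate inputs: the old code handed encode_coordinate a flat (B,) vector, while a per-row scalar encoder needs (B, 1). Below is a minimal sketch of that failure mode; the nn.Linear(1, C) is a hypothetical stand-in for moondream's actual coordinate encoder, not a reproduction of it:

    import torch
    import torch.nn as nn

    B, C = 4, 8
    coord_encoder = nn.Linear(1, C)  # stand-in: expects a trailing feature dim of size 1

    x_bin = torch.randint(0, 1024, (B,)).to(torch.float32)  # argmax'd bins, one per row
    x_center = x_bin / 1024.0                                # (B,) -- no feature dim

    try:
        coord_encoder(x_center)      # pre-fix shape: last dim is B, not 1
    except RuntimeError as err:
        print("dimension mismatch:", err)

    x_input = x_center.unsqueeze(-1)                         # (B, 1) -- the fix
    x_emb = coord_encoder(x_input).unsqueeze(1)              # (B, 1, C)
    print(x_emb.shape)               # torch.Size([4, 1, 8])

The same unsqueeze(-1) is applied to the y coordinate, which is why the x and y hunks are symmetric.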
 
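The size hunk enforces the same batch-first discipline. Stacking w and h on dim=0 yields (2, B), which the old code then had to transpose back; stacking on dim=1 produces (B, 2) directly. A quick shape check in plain PyTorch, independent of any moondream code:

    import torch

    B = 4
    w = torch.rand(B)  # one width per batch row
    h = torch.rand(B)  # one height per batch row

    print(torch.stack([w, h], dim=0).shape)  # torch.Size([2, 4]) -- batch axis second (pre-fix)
    print(torch.stack([w, h], dim=1).shape)  # torch.Size([4, 2]) -- batch axis first (post-fix)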
 
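For reference, the unchanged log-scale conversion feeding that stack maps bin 0 to 2**-10 ≈ 0.001 and bin 1023 to 2**0 = 1.0, so equal bin steps correspond to equal size ratios. Checking the endpoints:

    import torch

    w_bin = torch.tensor([0.0, 511.5, 1023.0])
    w = torch.pow(2.0, (w_bin / 1023.0) * 10.0 - 10.0)
    print(w)  # ~[0.0010, 0.0312, 1.0000]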