HV-Khurdula commited on
Commit
9a7633c
·
verified ·
1 Parent(s): aeef384

Update moondream.py

Browse files

fix: _generate_points_batched to advance the mask per_row & never the batch to row_0

Files changed (1) hide show
  1. moondream.py +31 -35
moondream.py CHANGED
@@ -993,7 +993,7 @@ class MoondreamModel(nn.Module):
993
  include_size: bool = True,
994
  max_objects: int = 50,
995
  lora=None,
996
- use_soft_argmax: bool = True, # reduces jitter/hallucinations
997
  ):
998
  B = hidden.size(0)
999
  device = self.device
@@ -1006,6 +1006,7 @@ class MoondreamModel(nn.Module):
1006
  p0 = int(pos)
1007
  if p0 > 0:
1008
  mask[:, :, :, :p0] = True
 
1009
  # per-row position ids (B,1)
1010
  pos_ids = torch.full((B, 1), p0, device=device, dtype=torch.long)
1011
 
@@ -1018,41 +1019,49 @@ class MoondreamModel(nn.Module):
1018
  idx = logits.argmax(dim=-1).to(torch.float32)
1019
  return idx / float(logits.size(-1) - 1)
1020
 
 
 
 
 
 
 
 
 
 
 
1021
  alive = torch.ones(B, dtype=torch.bool, device=device)
1022
  counts = torch.zeros(B, dtype=torch.int32, device=device)
1023
 
1024
  with torch.inference_mode():
1025
  while alive.any() and (counts < max_objects).any():
 
1026
  # ---------------- x ----------------
1027
  x_logits = decode_coordinate(hidden, self.region) # (B,1,1024) or (B,1024)
1028
  if x_logits.dim() == 3:
1029
  x_logits = x_logits.squeeze(1) # -> (B,1024)
1030
  x_center = _argmax01(x_logits) # (B,)
1031
  x_emb = encode_coordinate(
1032
- x_center.to(dtype=x_logits.dtype).unsqueeze(-1), # (B,1)
1033
  self.region
1034
  ).unsqueeze(1) # (B,1,C)
1035
 
1036
- # advance one token for ALIVE rows only
1037
- step_col = int(pos_ids[0, 0].item())
1038
- mask[alive, :, :, step_col] = True
1039
  logits, hidden = self._decode_one_tok(x_emb, mask, pos_ids, lora)
1040
- pos_ids[alive, 0] += 1
1041
 
1042
  # ---------------- y ----------------
1043
  y_logits = decode_coordinate(hidden, self.region)
1044
  if y_logits.dim() == 3:
1045
- y_logits = y_logits.squeeze(1) # (B,1024)
1046
- y_center = _argmax01(y_logits) # (B,)
1047
  y_emb = encode_coordinate(
1048
  y_center.to(dtype=y_logits.dtype).unsqueeze(-1),
1049
  self.region
1050
- ).unsqueeze(1) # (B,1,C)
1051
 
1052
- step_col = int(pos_ids[0, 0].item())
1053
- mask[alive, :, :, step_col] = True
1054
  logits, hidden = self._decode_one_tok(y_emb, mask, pos_ids, lora)
1055
- pos_ids[alive, 0] += 1
1056
 
1057
  if include_size:
1058
  # ------------- size (w,h) -------------
@@ -1077,8 +1086,7 @@ class MoondreamModel(nn.Module):
1077
  ).unsqueeze(1) # (B,1,C)
1078
 
1079
  # record boxes only for ALIVE rows
1080
- alive_idx = alive.nonzero(as_tuple=False).view(-1)
1081
- for i in alive_idx.tolist():
1082
  xl = (x_center[i] - w[i] / 2).item()
1083
  xr = (x_center[i] + w[i] / 2).item()
1084
  yt = (y_center[i] - h[i] / 2).item()
@@ -1090,41 +1098,29 @@ class MoondreamModel(nn.Module):
1090
  "y_max": max(0.0, min(1.0, yb)),
1091
  })
1092
 
1093
- step_col = int(pos_ids[0, 0].item())
1094
- mask[alive, :, :, step_col] = True
1095
  logits, hidden = self._decode_one_tok(size_emb, mask, pos_ids, lora)
1096
- pos_ids[alive, 0] += 1
1097
  next_tok = logits.argmax(dim=-1)
1098
- if next_tok.dim() == 3: # (B,1,1) possible
1099
- next_tok = next_tok.squeeze(-1).squeeze(-1)
1100
- elif next_tok.dim() == 2: # (B,1)
1101
- next_tok = next_tok.squeeze(1)
1102
  else:
1103
  # point mode
1104
- alive_idx = alive.nonzero(as_tuple=False).view(-1)
1105
- for i in alive_idx.tolist():
1106
  out[i].append({"x": x_center[i].item(), "y": y_center[i].item()})
1107
- step_col = int(pos_ids[0, 0].item())
1108
- mask[alive, :, :, step_col] = True
1109
  logits, hidden = self._decode_one_tok(y_emb, mask, pos_ids, lora)
1110
- pos_ids[alive, 0] += 1
1111
  next_tok = logits.argmax(dim=-1)
1112
- if next_tok.dim() == 3:
1113
- next_tok = next_tok.squeeze(-1).squeeze(-1)
1114
- elif next_tok.dim() == 2:
1115
- next_tok = next_tok.squeeze(1)
1116
 
1117
- # we added one object/point for all ALIVE rows this iteration
1118
- counts[alive] += 1
 
1119
 
1120
- # stop rows that hit eos OR reached max_objects
1121
  finished_now = (next_tok == eos_id) | (counts >= max_objects)
1122
  alive &= ~finished_now
1123
 
1124
  return out
1125
 
1126
-
1127
-
1128
  def detect_multi(self, image, objects, settings=None):
1129
  if self.config.tokenizer.templates["detect"] is None:
1130
  raise NotImplementedError("Model does not support object detection.")
 
993
  include_size: bool = True,
994
  max_objects: int = 50,
995
  lora=None,
996
+ use_soft_argmax: bool = True,
997
  ):
998
  B = hidden.size(0)
999
  device = self.device
 
1006
  p0 = int(pos)
1007
  if p0 > 0:
1008
  mask[:, :, :, :p0] = True
1009
+
1010
  # per-row position ids (B,1)
1011
  pos_ids = torch.full((B, 1), p0, device=device, dtype=torch.long)
1012
 
 
1019
  idx = logits.argmax(dim=-1).to(torch.float32)
1020
  return idx / float(logits.size(-1) - 1)
1021
 
1022
+ # advance-one-step for a subset of rows (alive only)
1023
+ def _advance_rows(row_mask: torch.Tensor):
1024
+ idx = row_mask.nonzero(as_tuple=False).flatten()
1025
+ # set each row's next KV column true
1026
+ for i in idx.tolist():
1027
+ col = int(pos_ids[i, 0].item())
1028
+ mask[i, 0, 0, col] = True
1029
+ # decoder step (all rows run, but only alive rows’ pos_ids move)
1030
+ return idx
1031
+
1032
  alive = torch.ones(B, dtype=torch.bool, device=device)
1033
  counts = torch.zeros(B, dtype=torch.int32, device=device)
1034
 
1035
  with torch.inference_mode():
1036
  while alive.any() and (counts < max_objects).any():
1037
+
1038
  # ---------------- x ----------------
1039
  x_logits = decode_coordinate(hidden, self.region) # (B,1,1024) or (B,1024)
1040
  if x_logits.dim() == 3:
1041
  x_logits = x_logits.squeeze(1) # -> (B,1024)
1042
  x_center = _argmax01(x_logits) # (B,)
1043
  x_emb = encode_coordinate(
1044
+ x_center.to(dtype=x_logits.dtype).unsqueeze(-1), # (B,1)
1045
  self.region
1046
  ).unsqueeze(1) # (B,1,C)
1047
 
1048
+ idx = _advance_rows(alive)
 
 
1049
  logits, hidden = self._decode_one_tok(x_emb, mask, pos_ids, lora)
1050
+ pos_ids[idx, 0] += 1
1051
 
1052
  # ---------------- y ----------------
1053
  y_logits = decode_coordinate(hidden, self.region)
1054
  if y_logits.dim() == 3:
1055
+ y_logits = y_logits.squeeze(1)
1056
+ y_center = _argmax01(y_logits)
1057
  y_emb = encode_coordinate(
1058
  y_center.to(dtype=y_logits.dtype).unsqueeze(-1),
1059
  self.region
1060
+ ).unsqueeze(1)
1061
 
1062
+ idx = _advance_rows(alive)
 
1063
  logits, hidden = self._decode_one_tok(y_emb, mask, pos_ids, lora)
1064
+ pos_ids[idx, 0] += 1
1065
 
1066
  if include_size:
1067
  # ------------- size (w,h) -------------
 
1086
  ).unsqueeze(1) # (B,1,C)
1087
 
1088
  # record boxes only for ALIVE rows
1089
+ for i in alive.nonzero(as_tuple=False).flatten().tolist():
 
1090
  xl = (x_center[i] - w[i] / 2).item()
1091
  xr = (x_center[i] + w[i] / 2).item()
1092
  yt = (y_center[i] - h[i] / 2).item()
 
1098
  "y_max": max(0.0, min(1.0, yb)),
1099
  })
1100
 
1101
+ idx = _advance_rows(alive)
 
1102
  logits, hidden = self._decode_one_tok(size_emb, mask, pos_ids, lora)
1103
+ pos_ids[idx, 0] += 1
1104
  next_tok = logits.argmax(dim=-1)
 
 
 
 
1105
  else:
1106
  # point mode
1107
+ for i in alive.nonzero(as_tuple=False).flatten().tolist():
 
1108
  out[i].append({"x": x_center[i].item(), "y": y_center[i].item()})
1109
+ idx = _advance_rows(alive)
 
1110
  logits, hidden = self._decode_one_tok(y_emb, mask, pos_ids, lora)
1111
+ pos_ids[idx, 0] += 1
1112
  next_tok = logits.argmax(dim=-1)
 
 
 
 
1113
 
1114
+ # normalize next_tok to shape (B,)
1115
+ while next_tok.dim() > 1:
1116
+ next_tok = next_tok.squeeze(-1)
1117
 
1118
+ counts[alive] += 1
1119
  finished_now = (next_tok == eos_id) | (counts >= max_objects)
1120
  alive &= ~finished_now
1121
 
1122
  return out
1123
 
 
 
1124
  def detect_multi(self, image, objects, settings=None):
1125
  if self.config.tokenizer.templates["detect"] is None:
1126
  raise NotImplementedError("Model does not support object detection.")