HV-Khurdula committed on
Commit
fc77069
·
verified ·
1 Parent(s): 11ca2e9

Update moondream.py

Browse files
Files changed (1) hide show
  1. moondream.py +36 -32
moondream.py CHANGED
@@ -77,60 +77,64 @@ class KVCache(nn.Module):
77
  "v_cache", torch.zeros(*cache_shape, device=device, dtype=dtype)
78
  )
79
 
80
- # in class KVCache
81
  def update(self, pos_ids, k, v):
82
  """
83
  Supports both:
84
- - PREFILL: pos_ids shape == (q_len,), k/v shape == (B, H, q_len, D)
85
- - STEP-DECODE (batched): pos_ids shape == (B,), k/v shape == (B, H, 1, D)
86
- - STEP-DECODE (single): scalar pos_ids, k/v shape == (1, H, 1, D)
87
  """
88
- kout, vout = self.k_cache, self.v_cache # (Bcache, H, T, D)
89
- B, H, Q, D = k.shape
90
 
91
- # Case A: PREFILL — a vector of all time indices
92
- if torch.is_tensor(pos_ids) and pos_ids.ndim == 1 and pos_ids.numel() == Q and Q > 1:
93
- # broadcast batch dimension into cache if needed
94
- if kout.size(0) != B:
95
- # grow/shrink the first dim to match B (this happens after you cloned
96
- # image caches to B rows for batched prefill)
97
- new_k = kout.new_zeros((B,) + tuple(kout.shape[1:]))
98
- new_v = vout.new_zeros((B,) + tuple(vout.shape[1:]))
99
- # copy row 0 as base (image prefix) into all rows
100
- new_k[:] = kout[0]
101
- new_v[:] = vout[0]
102
- self.k_cache = kout = new_k
103
- self.v_cache = vout = new_v
104
-
105
- # write the whole segment for all rows at once
106
  kout[:, :, pos_ids, :] = k
107
  vout[:, :, pos_ids, :] = v
108
  return kout, vout
109
 
110
- # Case B: STEP-DECODE (batched) — one position per row, q_len == 1
111
- if torch.is_tensor(pos_ids) and pos_ids.ndim == 1 and pos_ids.numel() == B and Q == 1:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
112
  for i in range(B):
113
- pi = int(pos_ids[i])
114
  kout[i, :, pi, :] = k[i, :, 0, :]
115
  vout[i, :, pi, :] = v[i, :, 0, :]
116
  return kout, vout
117
 
118
- # Case C: STEP-DECODE (single) — scalar pos, B==1, q_len==1
119
- if (not torch.is_tensor(pos_ids)) or pos_ids.ndim == 0:
120
- pi = int(pos_ids)
121
- kout[:B, :, pi, :] = k[:, :, 0, :]
122
- vout[:B, :, pi, :] = v[:, :, 0, :]
123
  return kout, vout
124
 
125
- # Fallback: shape combo we didn't expect
126
  raise RuntimeError(
127
- f"KVCache.update: unsupported shapes pos_ids={tuple(pos_ids.shape) if torch.is_tensor(pos_ids) else '()'}, "
128
- f"k={tuple(k.shape)}, v={tuple(v.shape)}"
129
  )
130
 
131
 
132
 
133
 
 
134
  class MoondreamModel(nn.Module):
135
 
136
  def __init__(
 
77
  "v_cache", torch.zeros(*cache_shape, device=device, dtype=dtype)
78
  )
79
 
80
+ # --- replace the whole method in KVCache ---
81
  def update(self, pos_ids, k, v):
82
  """
83
  Supports both:
84
+ Prefill: k,v = (B, n_kv_heads, q_len, d), pos_ids = (q_len,)
85
+ 1-step: k,v = (B, n_kv_heads, 1, d), pos_ids = (B,)
86
+ Writes into caches shaped (B, n_kv_heads, T_max, d).
87
  """
88
+ kout, vout = self.k_cache, self.v_cache
 
89
 
90
+ if not torch.is_tensor(pos_ids):
91
+ # Scalar position for singleton batch (legacy)
 
 
 
 
 
 
 
 
 
 
 
 
 
92
  kout[:, :, pos_ids, :] = k
93
  vout[:, :, pos_ids, :] = v
94
  return kout, vout
95
 
96
+ # Normalize dtype
97
+ pos_ids = pos_ids.to(dtype=torch.long, device=k.device)
98
+
99
+ # Shapes
100
+ if k.dim() != 4 or v.dim() != 4:
101
+ raise RuntimeError(f"KV update expects k,v 4D. Got k={tuple(k.shape)} v={tuple(v.shape)}")
102
+ B, Hkv, q_len, D = k.shape
103
+
104
+ # Ensure cache batch matches B
105
+ if kout.size(0) != B:
106
+ raise RuntimeError(f"KV cache batch mismatch: cache.B={kout.size(0)} vs k.B={B}")
107
+
108
+ # Case A: PREFILL — per-row write of a whole span of positions
109
+ if pos_ids.dim() == 1 and pos_ids.numel() == q_len:
110
+ for i in range(B):
111
+ kout[i, :, pos_ids, :] = k[i] # (Hkv, q_len, D)
112
+ vout[i, :, pos_ids, :] = v[i]
113
+ return kout, vout
114
+
115
+ # Case B: STEP DECODE — one new position per row (q_len must be 1)
116
+ if pos_ids.dim() == 1 and pos_ids.numel() == B and q_len == 1:
117
  for i in range(B):
118
+ pi = int(pos_ids[i].item())
119
  kout[i, :, pi, :] = k[i, :, 0, :]
120
  vout[i, :, pi, :] = v[i, :, 0, :]
121
  return kout, vout
122
 
123
+ # Optional legacy: scalar pos for everyone
124
+ if pos_ids.dim() == 0 and q_len == 1:
125
+ pi = int(pos_ids.item())
126
+ kout[:, :, pi, :] = k[:, :, 0, :]
127
+ vout[:, :, pi, :] = v[:, :, 0, :]
128
  return kout, vout
129
 
 
130
  raise RuntimeError(
131
+ f"Unsupported KV update combo: k={tuple(k.shape)}, pos_ids={tuple(pos_ids.shape)}"
 
132
  )
133
 
134
 
135
 
136
 
137
+
138
  class MoondreamModel(nn.Module):
139
 
140
  def __init__(