HV-Khurdula committed on
Commit
11ca2e9
·
verified ·
1 Parent(s): 0841a6c

Update moondream.py

Browse files

fix: update KV cache to support batch and single prompt inference.

Files changed (1) hide show
  1. moondream.py +47 -12
moondream.py CHANGED
@@ -77,22 +77,57 @@ class KVCache(nn.Module):
77
  "v_cache", torch.zeros(*cache_shape, device=device, dtype=dtype)
78
  )
79
 
 
80
  def update(self, pos_ids, k, v):
81
- kout, vout = self.k_cache, self.v_cache
82
- # pos_ids: scalar (int or 0-D) OR LongTensor[B]
83
- if not torch.is_tensor(pos_ids) or pos_ids.ndim == 0:
84
- # singleton batch
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
85
  kout[:, :, pos_ids, :] = k
86
  vout[:, :, pos_ids, :] = v
87
- else:
88
- # batched: write each row into its own position
89
- B = k.size(0)
90
- # Safe, explicit per-row scatter (B is usually small)
91
  for i in range(B):
92
- pi = int(pos_ids[i].item())
93
- kout[i, :, pi, :] = k[i]
94
- vout[i, :, pi, :] = v[i]
95
- return kout, vout
 
 
 
 
 
 
 
 
 
 
 
 
 
 
96
 
97
 
98
 
 
77
  "v_cache", torch.zeros(*cache_shape, device=device, dtype=dtype)
78
  )
79
 
80
+ # in class KVCache
81
  def update(self, pos_ids, k, v):
82
+ """
83
+ Supports both:
84
+ - PREFILL: pos_ids shape == (q_len,), k/v shape == (B, H, q_len, D)
85
+ - STEP-DECODE (batched): pos_ids shape == (B,), k/v shape == (B, H, 1, D)
86
+ - STEP-DECODE (single): scalar pos_ids, k/v shape == (1, H, 1, D)
87
+ """
88
+ kout, vout = self.k_cache, self.v_cache # (Bcache, H, T, D)
89
+ B, H, Q, D = k.shape
90
+
91
+ # Case A: PREFILL — a vector of all time indices
92
+ if torch.is_tensor(pos_ids) and pos_ids.ndim == 1 and pos_ids.numel() == Q and Q > 1:
93
+ # broadcast batch dimension into cache if needed
94
+ if kout.size(0) != B:
95
+ # grow/shrink the first dim to match B (this happens after you cloned
96
+ # image caches to B rows for batched prefill)
97
+ new_k = kout.new_zeros((B,) + tuple(kout.shape[1:]))
98
+ new_v = vout.new_zeros((B,) + tuple(vout.shape[1:]))
99
+ # copy row 0 as base (image prefix) into all rows
100
+ new_k[:] = kout[0]
101
+ new_v[:] = vout[0]
102
+ self.k_cache = kout = new_k
103
+ self.v_cache = vout = new_v
104
+
105
+ # write the whole segment for all rows at once
106
  kout[:, :, pos_ids, :] = k
107
  vout[:, :, pos_ids, :] = v
108
+ return kout, vout
109
+
110
+ # Case B: STEP-DECODE (batched) — one position per row, q_len == 1
111
+ if torch.is_tensor(pos_ids) and pos_ids.ndim == 1 and pos_ids.numel() == B and Q == 1:
112
  for i in range(B):
113
+ pi = int(pos_ids[i])
114
+ kout[i, :, pi, :] = k[i, :, 0, :]
115
+ vout[i, :, pi, :] = v[i, :, 0, :]
116
+ return kout, vout
117
+
118
+ # Case C: STEP-DECODE (single) — scalar pos, B==1, q_len==1
119
+ if (not torch.is_tensor(pos_ids)) or pos_ids.ndim == 0:
120
+ pi = int(pos_ids)
121
+ kout[:B, :, pi, :] = k[:, :, 0, :]
122
+ vout[:B, :, pi, :] = v[:, :, 0, :]
123
+ return kout, vout
124
+
125
+ # Fallback: shape combo we didn't expect
126
+ raise RuntimeError(
127
+ f"KVCache.update: unsupported shapes pos_ids={tuple(pos_ids.shape) if torch.is_tensor(pos_ids) else '()'}, "
128
+ f"k={tuple(k.shape)}, v={tuple(v.shape)}"
129
+ )
130
+
131
 
132
 
133