thecollabagepatch commited on
Commit
c29a250
·
1 Parent(s): 3d79c33

sometimes a claude yolo

Browse files
Files changed (1) hide show
  1. jam_worker.py +167 -226
jam_worker.py CHANGED
@@ -1,4 +1,4 @@
1
- # jam_worker.py - Bar-locked spool rewrite
2
  from __future__ import annotations
3
 
4
  import os
@@ -20,7 +20,6 @@ from utils import (
20
  )
21
 
22
  def _dbg_rms_dbfs(x: np.ndarray) -> float:
23
-
24
  if x.ndim == 2:
25
  x = x.mean(axis=1)
26
  r = float(np.sqrt(np.mean(x * x) + 1e-12))
@@ -28,7 +27,6 @@ def _dbg_rms_dbfs(x: np.ndarray) -> float:
28
 
29
  def _dbg_rms_dbfs_model(x: np.ndarray) -> float:
30
  # x is model-rate, shape [S,C] or [S]
31
-
32
  if x.ndim == 2:
33
  x = x.mean(axis=1)
34
  r = float(np.sqrt(np.mean(x * x) + 1e-12))
@@ -37,6 +35,19 @@ def _dbg_rms_dbfs_model(x: np.ndarray) -> float:
37
  def _dbg_shape(x):
38
  return tuple(x.shape) if hasattr(x, "shape") else ("-",)
39
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40
  # -----------------------------
41
  # Data classes
42
  # -----------------------------
@@ -55,7 +66,7 @@ class JamParams:
55
  guidance_weight: float = 1.1
56
  temperature: float = 1.1
57
  topk: int = 40
58
- style_ramp_seconds: float = 8.0 # 0 => instant (current behavior), try 6.0–10.0 for gentle glides
59
 
60
 
61
  @dataclass
@@ -110,8 +121,6 @@ class JamWorker(threading.Thread):
110
  self.mrt.temperature = float(self.params.temperature)
111
  self.mrt.topk = int(self.params.topk)
112
 
113
-
114
-
115
  # codec/setup
116
  self._codec_fps = float(self.mrt.codec.frame_rate)
117
  JamWorker.FRAMES_PER_SECOND = self._codec_fps
@@ -137,8 +146,9 @@ class JamWorker(threading.Thread):
137
  self._spool = np.zeros((0, 2), dtype=np.float32) # (S,2) target SR
138
  self._spool_written = 0 # absolute frames written into spool
139
 
140
- self._pending_tail_model = None # type: Optional[np.ndarray] # last tail at model SR
141
- self._pending_tail_target_len = 0 # number of target-SR samples last tail contributed
 
142
 
143
  # bar clock: start with offset 0; if you have a downbeat estimator, set base later
144
  self._bar_clock = BarClock(self.params.target_sr, self.params.bpm, self.params.beats_per_bar, base_offset_samples=0)
@@ -163,6 +173,47 @@ class JamWorker(threading.Thread):
163
  # Prepare initial context from combined loop (best musical alignment)
164
  if self.params.combined_loop is not None:
165
  self._install_context_from_loop(self.params.combined_loop)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
166
 
167
  # ---------- lifecycle ----------
168
 
@@ -248,13 +299,7 @@ class JamWorker(threading.Thread):
248
  return toks
249
 
250
  def _encode_exact_context_tokens(self, loop: au.Waveform) -> np.ndarray:
251
- """Build *exactly* context_length_frames worth of tokens (e.g., 250 @ 25fps),
252
- while ensuring the *end* of the audio lands on a bar boundary.
253
- Strategy: take the largest integer number of bars <= ctx_seconds as the tail,
254
- then left-fill from just before that tail (wrapping if needed) to reach exactly
255
- ctx_seconds; finally, pad/trim to exact samples and, as a last resort, pad/trim
256
- tokens to the expected frame count.
257
- """
258
  wav = loop.as_stereo().resample(self._model_sr)
259
  data = wav.samples.astype(np.float32, copy=False)
260
  if data.ndim == 1:
@@ -289,8 +334,14 @@ class JamWorker(threading.Thread):
289
 
290
  # final snap to *exact* ctx samples
291
  if ctx.shape[0] < ctx_samps:
292
- pad = np.zeros((ctx_samps - ctx.shape[0], ctx.shape[1]), dtype=np.float32)
293
- ctx = np.concatenate([pad, ctx], axis=0)
 
 
 
 
 
 
294
  elif ctx.shape[0] > ctx_samps:
295
  ctx = ctx[-ctx_samps:]
296
 
@@ -301,79 +352,20 @@ class JamWorker(threading.Thread):
301
 
302
  # Force expected (F,D) at *return time*
303
  tokens = self._coerce_tokens(tokens)
 
 
 
 
 
304
  return tokens
305
 
306
- def _encode_exact_context_tokens(self, loop: au.Waveform) -> np.ndarray:
307
- """Build *exactly* context_length_frames worth of tokens (e.g., 250 @ 25fps),
308
- while ensuring the *end* of the audio lands on a bar boundary.
309
- Strategy: take the largest integer number of bars <= ctx_seconds as the tail,
310
- then left-fill from just before that tail (wrapping if needed) to reach exactly
311
- ctx_seconds; finally, pad/trim to exact samples and, as a last resort, pad/trim
312
- tokens to the expected frame count.
313
- """
314
- wav = loop.as_stereo().resample(self._model_sr)
315
- data = wav.samples.astype(np.float32, copy=False)
316
- if data.ndim == 1:
317
- data = data[:, None]
318
-
319
- spb = self._bar_clock.seconds_per_bar()
320
- ctx_sec = float(self._ctx_seconds)
321
- sr = int(self._model_sr)
322
-
323
- # bars that fit fully inside ctx_sec (at least 1)
324
- bars_fit = max(1, int(ctx_sec // spb))
325
- tail_len_samps = int(round(bars_fit * spb * sr))
326
-
327
- # ensure we have enough source by tiling
328
- need = int(round(ctx_sec * sr)) + tail_len_samps
329
- if data.shape[0] == 0:
330
- data = np.zeros((1, 2), dtype=np.float32)
331
- reps = int(np.ceil(need / float(data.shape[0])))
332
- tiled = np.tile(data, (reps, 1))
333
-
334
- end = tiled.shape[0]
335
- tail = tiled[end - tail_len_samps:end]
336
-
337
- # left-fill to reach exact ctx samples (keeps end-of-bar alignment)
338
- ctx_samps = int(round(ctx_sec * sr))
339
- pad_len = ctx_samps - tail.shape[0]
340
- if pad_len > 0:
341
- pre = tiled[end - tail_len_samps - pad_len:end - tail_len_samps]
342
- ctx = np.concatenate([pre, tail], axis=0)
343
- else:
344
- ctx = tail[-ctx_samps:]
345
-
346
- # final snap to *exact* ctx samples
347
- if ctx.shape[0] < ctx_samps:
348
- pad = np.zeros((ctx_samps - ctx.shape[0], ctx.shape[1]), dtype=np.float32)
349
- ctx = np.concatenate([pad, ctx], axis=0)
350
- elif ctx.shape[0] > ctx_samps:
351
- ctx = ctx[-ctx_samps:]
352
-
353
- exact = au.Waveform(ctx, sr)
354
- tokens_full = self.mrt.codec.encode(exact).astype(np.int32)
355
- depth = int(self.mrt.config.decoder_codec_rvq_depth)
356
- tokens = tokens_full[:, :depth]
357
-
358
- # Last defense: force expected frame count
359
- frames = tokens.shape[0]
360
- exp = int(self._ctx_frames)
361
- if frames < exp:
362
- # repeat last frame
363
- pad = np.repeat(tokens[-1:, :], exp - frames, axis=0)
364
- tokens = np.concatenate([pad, tokens], axis=0)
365
- elif frames > exp:
366
- tokens = tokens[-exp:, :]
367
- return tokens
368
-
369
-
370
  def _install_context_from_loop(self, loop: au.Waveform):
371
  # Build exact-length, bar-locked context tokens
372
  context_tokens = self._encode_exact_context_tokens(loop)
373
  s = self.mrt.init_state()
374
  s.context_tokens = context_tokens
375
  self.state = s
376
- self._original_context_tokens = np.copy(context_tokens)
377
 
378
  def reseed_from_waveform(self, wav: au.Waveform):
379
  """Immediate reseed: replace context from provided wave (bar-locked, exact length)."""
@@ -383,14 +375,11 @@ class JamWorker(threading.Thread):
383
  s.context_tokens = context_tokens
384
  self.state = s
385
  self._model_stream = None # drop model-domain continuity so next chunk starts cleanly
386
- self._original_context_tokens = np.copy(context_tokens)
 
387
 
388
  def reseed_splice(self, recent_wav: au.Waveform, anchor_bars: float):
389
- """Queue a *seamless* reseed by token splicing instead of full restart.
390
- We compute a fresh, bar-locked context token tensor of exact length
391
- (e.g., 250 frames), then splice only the *tail* corresponding to
392
- `anchor_bars` so generation continues smoothly without resetting state.
393
- """
394
  new_ctx = self._encode_exact_context_tokens(recent_wav) # coerce to (F,D)
395
  F, D = self._expected_token_shape()
396
 
@@ -419,44 +408,20 @@ class JamWorker(threading.Thread):
419
  "tokens": spliced,
420
  "debug": {"F": F, "D": D, "splice_frames": splice_frames, "frames_per_bar": frames_per_bar}
421
  }
422
-
423
 
424
-
425
- def reseed_from_waveform(self, wav: au.Waveform):
426
- """Immediate reseed: replace context from provided wave (bar-aligned tail)."""
427
- wav = wav.as_stereo().resample(self._model_sr)
428
- tail = take_bar_aligned_tail(wav, self.params.bpm, self.params.beats_per_bar, self._ctx_seconds)
429
- tokens_full = self.mrt.codec.encode(tail).astype(np.int32)
430
- depth = int(self.mrt.config.decoder_codec_rvq_depth)
431
- context_tokens = tokens_full[:, :depth]
432
-
433
- s = self.mrt.init_state()
434
- s.context_tokens = context_tokens
435
- self.state = s
436
- # reset model stream so next generate starts cleanly
437
- self._model_stream = None
438
-
439
- # optional loudness match will be applied per-chunk on emission
440
-
441
- # also remember this as new "original"
442
- self._original_context_tokens = np.copy(context_tokens)
443
-
444
- # ---------- core streaming helpers ----------
445
 
446
  def _append_model_chunk_and_spool(self, wav: au.Waveform) -> None:
447
  """
448
- Conservative boundary fix:
449
- - Emit body+tail immediately (target SR), unchanged from your original behavior.
450
- - On *next* call, compute the mixed overlap (prev tail ⨉ cos + new head ⨉ sin),
451
- resample it, and overwrite the last `_pending_tail_target_len` samples in the
452
- target-SR spool with that mixed overlap. Then emit THIS chunk's body+tail and
453
- remember THIS chunk's tail length at target SR for the next correction.
454
-
455
- This keeps external timing and bar alignment identical, but removes the audible
456
- fade-to-zero at chunk ends.
457
  """
458
-
459
- # ---- unpack model-rate samples ----
460
  s = wav.samples.astype(np.float32, copy=False)
461
  if s.ndim == 1:
462
  s = s[:, None]
@@ -464,119 +429,90 @@ class JamWorker(threading.Thread):
464
  if n_samps == 0:
465
  return
466
 
467
- # crossfade length in model samples
 
 
 
468
  try:
469
  xfade_s = float(self.mrt.config.crossfade_length)
470
  except Exception:
471
  xfade_s = 0.0
472
  xfade_n = int(round(max(0.0, xfade_s) * float(self._model_sr)))
473
 
474
- # helper: resample to target SR via your streaming resampler
 
 
475
  def to_target(y: np.ndarray) -> np.ndarray:
476
  return y if self._rs is None else self._rs.process(y, final=False)
477
 
478
- # ------------------------------------------
479
- # (A) If we have a pending model tail, fix the last emitted tail at target SR
480
- # ------------------------------------------
481
- if self._pending_tail_model is not None and self._pending_tail_model.shape[0] == xfade_n and xfade_n > 0 and n_samps >= xfade_n:
482
- head = s[:xfade_n, :]
483
-
484
- print(f"[model] head len={head.shape[0]} rms={_dbg_rms_dbfs_model(head):+.1f} dBFS")
485
-
486
- t = np.linspace(0.0, np.pi/2.0, xfade_n, endpoint=False, dtype=np.float32)[:, None]
487
- cosw = np.cos(t, dtype=np.float32)
488
- sinw = np.sin(t, dtype=np.float32)
489
- mixed_model = (self._pending_tail_model * cosw) + (head * sinw) # [xfade_n, C] at model SR
490
-
491
- y_mixed = to_target(mixed_model.astype(np.float32))
492
- Lcorr = int(y_mixed.shape[0]) # exact target-SR samples to write
493
-
494
- # DEBUG: corrected overlap RMS (what we intend to hear at the boundary)
495
- if y_mixed.size:
496
- print(f"[append] mixedOverlap len={y_mixed.shape[0]} rms={_dbg_rms_dbfs(y_mixed):+.1f} dBFS")
497
-
498
- # Overwrite the last `_pending_tail_target_len` samples of the spool with `y_mixed`.
499
- # Use the *smaller* of the two lengths to be safe.
500
- Lpop = min(self._pending_tail_target_len, self._spool.shape[0], Lcorr)
501
- if Lpop > 0 and self._spool.size:
502
- # Trim last Lpop samples
503
- self._spool = self._spool[:-Lpop, :]
504
- self._spool_written -= Lpop
505
- # Append corrected overlap (trim/pad to Lpop to avoid drift)
506
- if Lcorr != Lpop:
507
- if Lcorr > Lpop:
508
- y_m = y_mixed[-Lpop:, :]
509
- else:
510
- pad = np.zeros((Lpop - Lcorr, y_mixed.shape[1]), dtype=np.float32)
511
- y_m = np.concatenate([y_mixed, pad], axis=0)
512
- else:
513
- y_m = y_mixed
514
- self._spool = np.concatenate([self._spool, y_m], axis=0) if self._spool.size else y_m
515
- self._spool_written += y_m.shape[0]
516
-
517
- # For internal continuity, update _model_stream like before
518
- if self._model_stream is None or self._model_stream.shape[0] < xfade_n:
519
- self._model_stream = s[xfade_n:].copy()
520
- else:
521
- self._model_stream = np.concatenate([self._model_stream[:-xfade_n], mixed_model, s[xfade_n:]], axis=0)
522
- else:
523
- # First-ever call or too-short to mix: maintain _model_stream minimally
524
- if xfade_n > 0 and n_samps > xfade_n:
525
- self._model_stream = s[xfade_n:].copy() if self._model_stream is None else np.concatenate([self._model_stream, s[xfade_n:]], axis=0)
526
- else:
527
- self._model_stream = s.copy() if self._model_stream is None else np.concatenate([self._model_stream, s], axis=0)
528
-
529
- # ------------------------------------------
530
- # (B) Emit THIS chunk's body and tail (same external behavior)
531
- # ------------------------------------------
532
- if xfade_n > 0 and n_samps >= (2 * xfade_n):
533
- body = s[xfade_n:-xfade_n, :]
534
- print(f"[model] body len={body.shape[0]} rms={_dbg_rms_dbfs_model(body):+.1f} dBFS")
535
- if body.size:
536
- y_body = to_target(body.astype(np.float32))
537
- if y_body.size:
538
- # DEBUG: body RMS we are actually appending
539
- print(f"[append] body len={y_body.shape[0]} rms={_dbg_rms_dbfs(y_body):+.1f} dBFS")
540
- self._spool = np.concatenate([self._spool, y_body], axis=0) if self._spool.size else y_body
541
- self._spool_written += y_body.shape[0]
542
  else:
543
- # If chunk too short for head+tail split, treat all (minus preroll) as body
544
- if xfade_n > 0 and n_samps > xfade_n:
545
- body = s[xfade_n:, :]
546
- print(f"[model] body(S) len={body.shape[0]} rms={_dbg_rms_dbfs_model(body):+.1f} dBFS")
547
- y_body = to_target(body.astype(np.float32))
548
- if y_body.size:
549
- # DEBUG: body RMS in short-chunk path
550
- print(f"[append] body(len=short) len={y_body.shape[0]} rms={_dbg_rms_dbfs(y_body):+.1f} dBFS")
551
- self._spool = np.concatenate([self._spool, y_body], axis=0) if self._spool.size else y_body
552
- self._spool_written += y_body.shape[0]
553
- # No tail to remember this round
554
- self._pending_tail_model = None
555
- self._pending_tail_target_len = 0
556
- return
557
-
558
- # Tail (always remember how many TARGET samples we append)
 
 
 
 
 
 
 
 
559
  if xfade_n > 0 and n_samps >= xfade_n:
560
- tail = s[-xfade_n:, :]
561
- print(f"[model] tail len={tail.shape[0]} rms={_dbg_rms_dbfs_model(tail):+.1f} dBFS")
562
- y_tail = to_target(tail.astype(np.float32))
563
- Ltail = int(y_tail.shape[0])
564
- if Ltail:
565
- # DEBUG: tail RMS we are appending now (to be corrected next call)
566
- print(f"[append] tail len={y_tail.shape[0]} rms={_dbg_rms_dbfs(y_tail):+.1f} dBFS")
567
- self._spool = np.concatenate([self._spool, y_tail], axis=0) if self._spool.size else y_tail
568
- self._spool_written += Ltail
569
- self._pending_tail_model = tail.copy()
570
- self._pending_tail_target_len = Ltail
571
- else:
572
- # Nothing appended (resampler returned nothing yet) — keep model tail but mark zero target len
573
- self._pending_tail_model = tail.copy()
574
- self._pending_tail_target_len = 0
575
  else:
576
- self._pending_tail_model = None
577
- self._pending_tail_target_len = 0
578
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
579
 
 
 
 
 
580
 
581
  def _should_generate_next_chunk(self) -> bool:
582
  # Allow running ahead relative to whichever is larger: last *consumed*
@@ -613,6 +549,7 @@ class JamWorker(threading.Thread):
613
  "guidance_weight": float(self.params.guidance_weight),
614
  "temperature": float(self.params.temperature),
615
  "topk": int(self.params.topk),
 
616
  }
617
  chunk = JamChunk(index=self.idx, audio_base64=audio_b64, metadata=meta)
618
 
@@ -637,6 +574,7 @@ class JamWorker(threading.Thread):
637
  # inplace update (no reset)
638
  self.state.context_tokens = spliced
639
  self._pending_token_splice = None
 
640
  except Exception:
641
  # fallback: full reseed using spliced tokens
642
  new_state = self.mrt.init_state()
@@ -644,6 +582,7 @@ class JamWorker(threading.Thread):
644
  self.state = new_state
645
  self._model_stream = None
646
  self._pending_token_splice = None
 
647
  elif self._pending_reseed is not None:
648
  ctx = self._coerce_tokens(self._pending_reseed["ctx"])
649
  new_state = self.mrt.init_state()
@@ -651,6 +590,7 @@ class JamWorker(threading.Thread):
651
  self.state = new_state
652
  self._model_stream = None
653
  self._pending_reseed = None
 
654
 
655
  # ---------- main loop ----------
656
 
@@ -687,9 +627,10 @@ class JamWorker(threading.Thread):
687
  self._emit_ready()
688
 
689
  # finalize resampler (flush) — not strictly necessary here
690
- tail = self._rs.process(np.zeros((0,2), np.float32), final=True)
691
- if tail.size:
692
- self._spool = np.concatenate([self._spool, tail], axis=0)
693
- self._spool_written += tail.shape[0]
 
694
  # one last emit attempt
695
- self._emit_ready()
 
1
+ # jam_worker.py - Updated with robust silence handling
2
  from __future__ import annotations
3
 
4
  import os
 
20
  )
21
 
22
  def _dbg_rms_dbfs(x: np.ndarray) -> float:
 
23
  if x.ndim == 2:
24
  x = x.mean(axis=1)
25
  r = float(np.sqrt(np.mean(x * x) + 1e-12))
 
27
 
28
  def _dbg_rms_dbfs_model(x: np.ndarray) -> float:
29
  # x is model-rate, shape [S,C] or [S]
 
30
  if x.ndim == 2:
31
  x = x.mean(axis=1)
32
  r = float(np.sqrt(np.mean(x * x) + 1e-12))
 
35
  def _dbg_shape(x):
36
  return tuple(x.shape) if hasattr(x, "shape") else ("-",)
37
 
38
+ def _is_silent(audio: np.ndarray, threshold_db: float = -60.0) -> bool:
39
+ """Check if audio is effectively silent."""
40
+ if audio.size == 0:
41
+ return True
42
+ if audio.ndim == 2:
43
+ audio = audio.mean(axis=1)
44
+ rms = float(np.sqrt(np.mean(audio**2)))
45
+ return 20.0 * np.log10(max(rms, 1e-12)) < threshold_db
46
+
47
+ def _has_energy(audio: np.ndarray, threshold_db: float = -40.0) -> bool:
48
+ """Check if audio has significant energy (stricter than just non-silent)."""
49
+ return not _is_silent(audio, threshold_db)
50
+
51
  # -----------------------------
52
  # Data classes
53
  # -----------------------------
 
66
  guidance_weight: float = 1.1
67
  temperature: float = 1.1
68
  topk: int = 40
69
+ style_ramp_seconds: float = 8.0
70
 
71
 
72
  @dataclass
 
121
  self.mrt.temperature = float(self.params.temperature)
122
  self.mrt.topk = int(self.params.topk)
123
 
 
 
124
  # codec/setup
125
  self._codec_fps = float(self.mrt.codec.frame_rate)
126
  JamWorker.FRAMES_PER_SECOND = self._codec_fps
 
146
  self._spool = np.zeros((0, 2), dtype=np.float32) # (S,2) target SR
147
  self._spool_written = 0 # absolute frames written into spool
148
 
149
+ # Health monitoring
150
+ self._silence_streak = 0 # consecutive silent chunks
151
+ self._last_good_context_tokens = None # backup of last known good context
152
 
153
  # bar clock: start with offset 0; if you have a downbeat estimator, set base later
154
  self._bar_clock = BarClock(self.params.target_sr, self.params.bpm, self.params.beats_per_bar, base_offset_samples=0)
 
173
  # Prepare initial context from combined loop (best musical alignment)
174
  if self.params.combined_loop is not None:
175
  self._install_context_from_loop(self.params.combined_loop)
176
+ # Save this as our "good" context backup
177
+ if hasattr(self.state, 'context_tokens') and self.state.context_tokens is not None:
178
+ self._last_good_context_tokens = np.copy(self.state.context_tokens)
179
+
180
+ # ---------- NEW: Health monitoring methods ----------
181
+
182
+ def _check_model_health(self, new_chunk: np.ndarray) -> bool:
183
+ """Check if the model output looks healthy."""
184
+ if _is_silent(new_chunk, threshold_db=-80.0):
185
+ self._silence_streak += 1
186
+ print(f"⚠️ Silent chunk detected (streak: {self._silence_streak})")
187
+ return False
188
+ else:
189
+ if self._silence_streak > 0:
190
+ print(f"✅ Audio resumed after {self._silence_streak} silent chunks")
191
+ self._silence_streak = 0
192
+ return True
193
+
194
+ def _recover_from_silence(self):
195
+ """Attempt to recover from silence by restoring last good context."""
196
+ print("🔧 Attempting recovery from silence...")
197
+
198
+ if self._last_good_context_tokens is not None:
199
+ # Restore last known good context
200
+ try:
201
+ new_state = self.mrt.init_state()
202
+ new_state.context_tokens = np.copy(self._last_good_context_tokens)
203
+ self.state = new_state
204
+ self._model_stream = None # Reset stream to start fresh
205
+ print(" Restored last good context")
206
+ except Exception as e:
207
+ print(f" Context restoration failed: {e}")
208
+
209
+ # If we have the original loop, rebuild context from it
210
+ elif self.params.combined_loop is not None:
211
+ try:
212
+ self._install_context_from_loop(self.params.combined_loop)
213
+ self._model_stream = None
214
+ print(" Rebuilt context from original loop")
215
+ except Exception as e:
216
+ print(f" Context rebuild failed: {e}")
217
 
218
  # ---------- lifecycle ----------
219
 
 
299
  return toks
300
 
301
  def _encode_exact_context_tokens(self, loop: au.Waveform) -> np.ndarray:
302
+ """Build *exactly* context_length_frames worth of tokens, ensuring bar alignment."""
 
 
 
 
 
 
303
  wav = loop.as_stereo().resample(self._model_sr)
304
  data = wav.samples.astype(np.float32, copy=False)
305
  if data.ndim == 1:
 
334
 
335
  # final snap to *exact* ctx samples
336
  if ctx.shape[0] < ctx_samps:
337
+ # Instead of zero padding, repeat the audio to fill
338
+ shortfall = ctx_samps - ctx.shape[0]
339
+ if ctx.shape[0] > 0:
340
+ fill = np.tile(ctx, (int(np.ceil(shortfall / ctx.shape[0])) + 1, 1))[:shortfall]
341
+ ctx = np.concatenate([fill, ctx], axis=0)
342
+ else:
343
+ print("⚠️ Zero-length context, using fallback")
344
+ ctx = np.zeros((ctx_samps, 2), dtype=np.float32)
345
  elif ctx.shape[0] > ctx_samps:
346
  ctx = ctx[-ctx_samps:]
347
 
 
352
 
353
  # Force expected (F,D) at *return time*
354
  tokens = self._coerce_tokens(tokens)
355
+
356
+ # Validate that we don't have a silent context
357
+ if _is_silent(ctx, threshold_db=-80.0):
358
+ print("⚠️ Generated silent context - this may cause issues")
359
+
360
  return tokens
361
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
362
  def _install_context_from_loop(self, loop: au.Waveform):
363
  # Build exact-length, bar-locked context tokens
364
  context_tokens = self._encode_exact_context_tokens(loop)
365
  s = self.mrt.init_state()
366
  s.context_tokens = context_tokens
367
  self.state = s
368
+ self._last_good_context_tokens = np.copy(context_tokens)
369
 
370
  def reseed_from_waveform(self, wav: au.Waveform):
371
  """Immediate reseed: replace context from provided wave (bar-locked, exact length)."""
 
375
  s.context_tokens = context_tokens
376
  self.state = s
377
  self._model_stream = None # drop model-domain continuity so next chunk starts cleanly
378
+ self._last_good_context_tokens = np.copy(context_tokens)
379
+ self._silence_streak = 0 # Reset health monitoring
380
 
381
  def reseed_splice(self, recent_wav: au.Waveform, anchor_bars: float):
382
+ """Queue a *seamless* reseed by token splicing instead of full restart."""
 
 
 
 
383
  new_ctx = self._encode_exact_context_tokens(recent_wav) # coerce to (F,D)
384
  F, D = self._expected_token_shape()
385
 
 
408
  "tokens": spliced,
409
  "debug": {"F": F, "D": D, "splice_frames": splice_frames, "frames_per_bar": frames_per_bar}
410
  }
 
411
 
412
+ # ---------- REWRITTEN: core streaming helpers ----------
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
413
 
414
  def _append_model_chunk_and_spool(self, wav: au.Waveform) -> None:
415
  """
416
+ REWRITTEN: Robust audio processing with silence detection and health monitoring.
417
+
418
+ Strategy:
419
+ 1. Validate input chunk for silence/issues
420
+ 2. Use simpler crossfading that handles silence gracefully
421
+ 3. Update model stream with health checks
422
+ 4. Convert to target SR and append to spool
 
 
423
  """
424
+ # Unpack model-rate samples
 
425
  s = wav.samples.astype(np.float32, copy=False)
426
  if s.ndim == 1:
427
  s = s[:, None]
 
429
  if n_samps == 0:
430
  return
431
 
432
+ # Health check on new chunk
433
+ is_healthy = self._check_model_health(s)
434
+
435
+ # Get crossfade params
436
  try:
437
  xfade_s = float(self.mrt.config.crossfade_length)
438
  except Exception:
439
  xfade_s = 0.0
440
  xfade_n = int(round(max(0.0, xfade_s) * float(self._model_sr)))
441
 
442
+ print(f"[model] chunk len={n_samps} rms={_dbg_rms_dbfs_model(s):+.1f} dBFS healthy={is_healthy}")
443
+
444
+ # Helper: resample to target SR
445
  def to_target(y: np.ndarray) -> np.ndarray:
446
  return y if self._rs is None else self._rs.process(y, final=False)
447
 
448
+ # --- SIMPLIFIED CROSSFADE LOGIC ---
449
+
450
+ if self._model_stream is None:
451
+ # First chunk - no crossfading needed
452
+ self._model_stream = s.copy()
453
+
454
+ elif xfade_n <= 0 or n_samps < xfade_n:
455
+ # No crossfade configured or chunk too short - simple append
456
+ self._model_stream = np.concatenate([self._model_stream, s], axis=0)
457
+
458
+ elif _is_silent(self._model_stream[-xfade_n:]) or _is_silent(s[:xfade_n]):
459
+ # One side is silent - don't crossfade, just append
460
+ print(f"[crossfade] Skipping crossfade due to silence")
461
+ self._model_stream = np.concatenate([self._model_stream, s], axis=0)
462
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
463
  else:
464
+ # Normal crossfade between non-silent audio
465
+ tail = self._model_stream[-xfade_n:]
466
+ head = s[:xfade_n]
467
+ body = s[xfade_n:] if n_samps > xfade_n else np.zeros((0, s.shape[1]), dtype=np.float32)
468
+
469
+ # Equal power crossfade
470
+ t = np.linspace(0.0, 1.0, xfade_n, dtype=np.float32)[:, None]
471
+ fade_out = np.cos(t * np.pi / 2.0)
472
+ fade_in = np.sin(t * np.pi / 2.0)
473
+
474
+ mixed = tail * fade_out + head * fade_in
475
+
476
+ print(f"[crossfade] tail rms={_dbg_rms_dbfs_model(tail):+.1f} head rms={_dbg_rms_dbfs_model(head):+.1f} mixed rms={_dbg_rms_dbfs_model(mixed):+.1f}")
477
+
478
+ # Update model stream: remove old tail, add mixed section, add body
479
+ self._model_stream = np.concatenate([
480
+ self._model_stream[:-xfade_n],
481
+ mixed,
482
+ body
483
+ ], axis=0)
484
+
485
+ # --- CONVERT AND APPEND TO SPOOL ---
486
+
487
+ # Take the new audio from this iteration (avoid reprocessing old audio)
488
  if xfade_n > 0 and n_samps >= xfade_n:
489
+ # Normal case: body after crossfade region
490
+ new_audio = s[xfade_n:] if n_samps > xfade_n else s
 
 
 
 
 
 
 
 
 
 
 
 
 
491
  else:
492
+ # Short chunk or no crossfade: use entire chunk
493
+ new_audio = s
494
+
495
+ if new_audio.shape[0] > 0:
496
+ target_audio = to_target(new_audio)
497
+ if target_audio.shape[0] > 0:
498
+ print(f"[append] body len={target_audio.shape[0]} rms={_dbg_rms_dbfs(target_audio):+.1f} dBFS")
499
+ self._spool = np.concatenate([self._spool, target_audio], axis=0) if self._spool.size else target_audio
500
+ self._spool_written += target_audio.shape[0]
501
+
502
+ # --- HEALTH MONITORING ---
503
+
504
+ if not is_healthy:
505
+ if self._silence_streak >= 3: # After 3 silent chunks, try to recover
506
+ self._recover_from_silence()
507
+ else:
508
+ # Save current context as "good" backup
509
+ if hasattr(self.state, 'context_tokens') and self.state.context_tokens is not None:
510
+ self._last_good_context_tokens = np.copy(self.state.context_tokens)
511
 
512
+ # Trim model stream to reasonable length (keep ~30 seconds)
513
+ max_model_samples = int(30.0 * self._model_sr)
514
+ if self._model_stream.shape[0] > max_model_samples:
515
+ self._model_stream = self._model_stream[-max_model_samples:]
516
 
517
  def _should_generate_next_chunk(self) -> bool:
518
  # Allow running ahead relative to whichever is larger: last *consumed*
 
549
  "guidance_weight": float(self.params.guidance_weight),
550
  "temperature": float(self.params.temperature),
551
  "topk": int(self.params.topk),
552
+ "silence_streak": self._silence_streak, # Add health info
553
  }
554
  chunk = JamChunk(index=self.idx, audio_base64=audio_b64, metadata=meta)
555
 
 
574
  # inplace update (no reset)
575
  self.state.context_tokens = spliced
576
  self._pending_token_splice = None
577
+ print("[reseed] Token splice applied")
578
  except Exception:
579
  # fallback: full reseed using spliced tokens
580
  new_state = self.mrt.init_state()
 
582
  self.state = new_state
583
  self._model_stream = None
584
  self._pending_token_splice = None
585
+ print("[reseed] Token splice fallback to full reset")
586
  elif self._pending_reseed is not None:
587
  ctx = self._coerce_tokens(self._pending_reseed["ctx"])
588
  new_state = self.mrt.init_state()
 
590
  self.state = new_state
591
  self._model_stream = None
592
  self._pending_reseed = None
593
+ print("[reseed] Full reseed applied")
594
 
595
  # ---------- main loop ----------
596
 
 
627
  self._emit_ready()
628
 
629
  # finalize resampler (flush) — not strictly necessary here
630
+ if self._rs is not None:
631
+ tail = self._rs.process(np.zeros((0,2), np.float32), final=True)
632
+ if tail.size:
633
+ self._spool = np.concatenate([self._spool, tail], axis=0)
634
+ self._spool_written += tail.shape[0]
635
  # one last emit attempt
636
+ self._emit_ready()