from enum import Enum
from typing import List

import torch

from model import EchoDiT


# helper
def _get_uncond_text_input_ids_and_mask(batch_size: int, max_length: int, device: str | None = None) -> tuple[torch.Tensor, torch.Tensor]:
    # returns zeros for text input ids, and (True, False, False, ...) for text mask
    text_input_ids_uncond = torch.zeros((batch_size, max_length), dtype=torch.int32)
    text_mask_uncond = torch.zeros((batch_size, max_length), dtype=torch.bool)
    text_mask_uncond[:, 0] = True
    if device is not None:
        text_input_ids_uncond = text_input_ids_uncond.to(device)
        text_mask_uncond = text_mask_uncond.to(device)
    return text_input_ids_uncond, text_mask_uncond


# SIMPLE SAMPLER FOR REFERENCE, SHOULD PROBABLY AVOID
@torch.inference_mode()
def sample_euler_cfg_simple(
    model: EchoDiT,
    speaker_latent: torch.Tensor,
    speaker_mask: torch.Tensor,
    text_input_ids: torch.Tensor,
    text_mask: torch.Tensor,
    rng_seed: int,
    num_steps: int,
    cfg_scale: float,
) -> torch.Tensor:
    device, dtype = model.device, model.dtype
    batch_size = text_input_ids.shape[0]
    torch.manual_seed(rng_seed)
    t_schedule = torch.linspace(1., 0., num_steps + 1, device=device)

    text_input_ids_uncond, text_mask_uncond = _get_uncond_text_input_ids_and_mask(
        text_input_ids.shape[0], text_input_ids.shape[1], device=device
    )
    speaker_latent_uncond, speaker_mask_uncond = torch.zeros_like(speaker_latent), torch.zeros_like(speaker_mask)

    full_text_input_ids = torch.cat([text_input_ids, text_input_ids_uncond], dim=0)
    full_text_mask = torch.cat([text_mask, text_mask_uncond], dim=0)
    full_speaker_latent = torch.cat([speaker_latent, speaker_latent_uncond], dim=0)
    full_speaker_mask = torch.cat([speaker_mask, speaker_mask_uncond], dim=0)

    kv_cache = model.get_kv_cache(
        speaker_latent=full_speaker_latent.to(dtype),
        speaker_mask=full_speaker_mask,
        text_input_ids=full_text_input_ids,
        text_mask=full_text_mask,
    )

    x_t = torch.randn((batch_size, 640, 80), device=device, dtype=torch.float32)
    for i in range(num_steps):
        t, t_next = t_schedule[i], t_schedule[i + 1]
        v_cond, v_uncond = model(
            x=torch.cat([x_t, x_t], dim=0).to(dtype),
            t=(torch.ones((batch_size * 2,), device=device) * t).to(dtype),
            text_input_ids=None,
            text_mask=full_text_mask,
            speaker_latent=None,
            speaker_mask=full_speaker_mask,
            kv_cache=kv_cache,
        ).float().chunk(2, dim=0)
        v_pred = v_cond + cfg_scale * (v_cond - v_uncond)
        # note: x_0_pred is x_t - v_pred * t
        x_t = x_t + v_pred * (t_next - t)
    return x_t


######

def _temporal_score_rescale(
    v_pred: torch.Tensor,
    x_t: torch.Tensor,
    t: float | torch.Tensor,
    rescale_k: float,
    rescale_sigma: float,
) -> torch.Tensor:
    if t < 1:
        snr = (1 - t) ** 2 / (t ** 2)
        ratio = (snr * rescale_sigma ** 2 + 1) / (snr * rescale_sigma ** 2 / rescale_k + 1)
        return 1 / (1 - t) * (ratio * ((1 - t) * v_pred + x_t) - x_t)
    return v_pred


def _get_first_n_kv_cache(kv_cache: List[List[torch.Tensor]], n: int) -> List[List[torch.Tensor]]:
    # note: slicing returns views, so the result shares storage with the input cache
    return [[kv_cache[i][0][:n], kv_cache[i][1][:n]] for i in range(len(kv_cache))]


def _multiply_speaker_kv_cache(
    kv_cache: List[List[torch.Tensor]],
    scale: float,
    text_length: int,
    max_layers: int | None = 24,
) -> None:
    # multiplies the speaker kv cache by scale, in place
    # speaker keys start after text keys (at position text_length)
    # max_layers=None scales every layer
    num_layers = len(kv_cache) if max_layers is None else min(max_layers, len(kv_cache))
    for i in range(num_layers):
        for j in range(len(kv_cache[i])):
            kv_cache[i][j][:, text_length:] *= scale
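
# Illustration only (not used by the samplers below): these samplers appear to follow the
# rectified-flow convention x_t = (1 - t) * x_0 + t * noise with the model predicting the
# velocity v = noise - x_0, which is why the clean-sample estimate is x_0 = x_t - t * v and
# one Euler step is x_t <- x_t + v * (t_next - t), with t running from 1 down to 0.
# A minimal sanity check of that identity, assuming this convention:
#
#   x_0, noise = torch.randn(2, 4), torch.randn(2, 4)
#   t = 0.3
#   x_t = (1 - t) * x_0 + t * noise
#   v = noise - x_0
#   assert torch.allclose(x_t - t * v, x_0)
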
@torch.inference_mode()
def sample_euler_cfg(
    model: EchoDiT,
    speaker_latent: torch.Tensor,
    speaker_mask: torch.Tensor,
    text_input_ids: torch.Tensor,
    text_mask: torch.Tensor,
    rng_seed: int,
    num_steps: int,
    cfg_scale: float,
    cfg_min_t: float,
    cfg_max_t: float,
    truncation_factor: float | None,
    rescale_k: float | None,
    rescale_sigma: float | None,
    speaker_k_scale: float | None,
    speaker_k_max_layers: int | None,
    speaker_k_min_t: float | None,
    block_size: int | None = None,
) -> torch.Tensor:
    if block_size is None:
        block_size = 640
    torch.manual_seed(rng_seed)
    INIT_SCALE = 0.999
    device, dtype = model.device, model.dtype
    batch_size = text_input_ids.shape[0]
    # start the schedule slightly below t=1
    t_schedule = torch.linspace(1., 0., num_steps + 1, device=device) * INIT_SCALE

    text_input_ids_uncond, text_mask_uncond = _get_uncond_text_input_ids_and_mask(
        text_input_ids.shape[0], text_input_ids.shape[1], device=device
    )
    speaker_latent_uncond, speaker_mask_uncond = torch.zeros_like(speaker_latent), torch.zeros_like(speaker_mask)

    full_text_input_ids = torch.cat([text_input_ids, text_input_ids_uncond], dim=0)
    full_text_mask = torch.cat([text_mask, text_mask_uncond], dim=0)
    full_speaker_latent = torch.cat([speaker_latent, speaker_latent_uncond], dim=0)
    full_speaker_mask = torch.cat([speaker_mask, speaker_mask_uncond], dim=0)

    kv_cache_full = model.get_kv_cache(
        speaker_latent=full_speaker_latent.to(dtype),
        speaker_mask=full_speaker_mask,
        text_input_ids=full_text_input_ids,
        text_mask=full_text_mask,
    )
    # could make faster by not computing fully / recomputing for unconditional batch elements
    # note: this slice is a view into kv_cache_full, so the speaker-k scaling below affects both
    kv_cache = _get_first_n_kv_cache(kv_cache_full, batch_size)
    if speaker_k_scale is not None:
        _multiply_speaker_kv_cache(kv_cache_full, speaker_k_scale, text_input_ids.shape[-1], speaker_k_max_layers)

    x_t = torch.randn((batch_size, block_size, 80), device=device, dtype=torch.float32)
    if truncation_factor is not None:
        x_t = x_t * truncation_factor

    for i in range(num_steps):
        t, t_next = t_schedule[i], t_schedule[i + 1]
        has_cfg = ((t >= cfg_min_t) * (t <= cfg_max_t)).item()
        if has_cfg:
            v_cond, v_uncond = model(
                x=torch.cat([x_t, x_t], dim=0).to(dtype),
                t=(torch.ones((batch_size * 2,), device=device) * t).to(dtype),
                text_input_ids=None,
                text_mask=full_text_mask,
                speaker_latent=None,
                speaker_mask=full_speaker_mask,
                kv_cache=kv_cache_full,
            ).float().chunk(2, dim=0)
            v_pred = v_cond + cfg_scale * (v_cond - v_uncond)
        else:
            v_pred = model(
                x=x_t.to(dtype),
                t=(torch.ones((batch_size,), device=device) * t).to(dtype),
                text_input_ids=None,
                text_mask=text_mask,
                speaker_latent=None,
                speaker_mask=speaker_mask,
                kv_cache=kv_cache,
            ).float()
        if rescale_k is not None and rescale_sigma is not None:
            v_pred = _temporal_score_rescale(v_pred, x_t, t, rescale_k, rescale_sigma)
        if speaker_k_scale is not None and t_next < speaker_k_min_t and t >= speaker_k_min_t:
            # undo the speaker-k scaling once t drops below speaker_k_min_t
            _multiply_speaker_kv_cache(kv_cache_full, 1. / speaker_k_scale, text_input_ids.shape[-1], speaker_k_max_layers)
        x_t = x_t + v_pred * (t_next - t)
    return x_t
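
# Note (illustrative): inside the [cfg_min_t, cfg_max_t] window the update
#   v_pred = v_cond + cfg_scale * (v_cond - v_uncond)
# is ordinary classifier-free guidance, equivalently (1 + cfg_scale) * v_cond - cfg_scale * v_uncond,
# so cfg_scale == 0 falls back to the purely conditional prediction; outside the window only the
# conditional slice of the KV cache is evaluated, which roughly halves the per-step cost there.
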
@torch.inference_mode()
def sample_euler_cfg_independent_guidances(
    model: EchoDiT,
    speaker_latent: torch.Tensor,
    speaker_mask: torch.Tensor,
    text_input_ids: torch.Tensor,
    text_mask: torch.Tensor,
    rng_seed: int,
    num_steps: int,
    cfg_scale_text: float,
    cfg_scale_speaker: float,
    cfg_min_t: float,
    cfg_max_t: float,
    truncation_factor: float | None,
    rescale_k: float | None,
    rescale_sigma: float | None,
    speaker_k_scale: float | None,
    speaker_k_max_layers: int | None,
    speaker_k_min_t: float | None,
    block_size: int | None = None,
) -> torch.Tensor:
    if block_size is None:
        block_size = 640
    torch.manual_seed(rng_seed)
    INIT_SCALE = 0.999
    device, dtype = model.device, model.dtype
    batch_size = text_input_ids.shape[0]
    t_schedule = torch.linspace(1., 0., num_steps + 1, device=device) * INIT_SCALE

    text_input_ids_uncond, text_mask_uncond = _get_uncond_text_input_ids_and_mask(
        text_input_ids.shape[0], text_input_ids.shape[1], device=device
    )
    speaker_latent_uncond, speaker_mask_uncond = torch.zeros_like(speaker_latent), torch.zeros_like(speaker_mask)

    # batch layout: [conditional, unconditional text, unconditional speaker]
    full_text_input_ids = torch.cat([text_input_ids, text_input_ids_uncond, text_input_ids], dim=0)
    full_text_mask = torch.cat([text_mask, text_mask_uncond, text_mask], dim=0)
    full_speaker_latent = torch.cat([speaker_latent, speaker_latent, speaker_latent_uncond], dim=0)
    full_speaker_mask = torch.cat([speaker_mask, speaker_mask, speaker_mask_uncond], dim=0)

    kv_cache_full = model.get_kv_cache(
        speaker_latent=full_speaker_latent.to(dtype),
        speaker_mask=full_speaker_mask,
        text_input_ids=full_text_input_ids,
        text_mask=full_text_mask,
    )
    kv_cache = _get_first_n_kv_cache(kv_cache_full, batch_size)
    if speaker_k_scale is not None:
        _multiply_speaker_kv_cache(kv_cache_full, speaker_k_scale, text_input_ids.shape[-1], speaker_k_max_layers)

    x_t = torch.randn((batch_size, block_size, 80), device=device, dtype=torch.float32)
    if truncation_factor is not None:
        x_t = x_t * truncation_factor

    for i in range(num_steps):
        t, t_next = t_schedule[i], t_schedule[i + 1]
        has_cfg = ((t >= cfg_min_t) * (t <= cfg_max_t)).item()
        if has_cfg:
            v_cond, v_uncond_text, v_uncond_speaker = model(
                x=torch.cat([x_t, x_t, x_t], dim=0).to(dtype),
                t=(torch.ones((batch_size * 3,), device=device) * t).to(dtype),
                text_input_ids=None,
                text_mask=full_text_mask,
                speaker_latent=None,
                speaker_mask=full_speaker_mask,
                kv_cache=kv_cache_full,
            ).float().chunk(3, dim=0)
            v_pred = v_cond + cfg_scale_text * (v_cond - v_uncond_text) + cfg_scale_speaker * (v_cond - v_uncond_speaker)
        else:
            v_pred = model(
                x=x_t.to(dtype),
                t=(torch.ones((batch_size,), device=device) * t).to(dtype),
                text_input_ids=None,
                text_mask=text_mask,
                speaker_latent=None,
                speaker_mask=speaker_mask,
                kv_cache=kv_cache,
            ).float()
        if rescale_k is not None and rescale_sigma is not None:
            v_pred = _temporal_score_rescale(v_pred, x_t, t, rescale_k, rescale_sigma)
        if speaker_k_scale is not None and t_next < speaker_k_min_t and t >= speaker_k_min_t:
            _multiply_speaker_kv_cache(kv_cache_full, 1. / speaker_k_scale, text_input_ids.shape[-1], speaker_k_max_layers)
        x_t = x_t + v_pred * (t_next - t)
    return x_t
@torch.inference_mode()
def sample_euler_cfg_alternating_guidances(
    model: EchoDiT,
    speaker_latent: torch.Tensor,
    speaker_mask: torch.Tensor,
    text_input_ids: torch.Tensor,
    text_mask: torch.Tensor,
    rng_seed: int,
    num_steps: int,
    cfg_scale_text: float,
    cfg_scale_speaker: float,
    cfg_min_t: float,
    cfg_max_t: float,
    truncation_factor: float | None,
    rescale_k: float | None,
    rescale_sigma: float | None,
    speaker_k_scale: float | None,
    speaker_k_max_layers: int | None,
    speaker_k_min_t: float | None,
    block_size: int | None = None,
) -> torch.Tensor:
    if block_size is None:
        block_size = 640
    torch.manual_seed(rng_seed)
    INIT_SCALE = 0.999
    device, dtype = model.device, model.dtype
    batch_size = text_input_ids.shape[0]
    t_schedule = torch.linspace(1., 0., num_steps + 1, device=device) * INIT_SCALE

    text_input_ids_uncond, text_mask_uncond = _get_uncond_text_input_ids_and_mask(
        text_input_ids.shape[0], text_input_ids.shape[1], device=device
    )
    # TODO: THIS / THE BELOW IS TECHNICALLY INCORRECT, AS IT ASSUMES A CAUSAL TEXT ENCODER (which is not the case).
    # IF THE TEXT ENCODER WERE CAUSAL, USING AN UNCOND TEXT MASK ON COND TEXT INPUTS WOULD GIVE AN UNCOND STATE DUE TO BOS=0.
    # HOWEVER, IT MIGHT NOT MAKE MUCH OF A DIFFERENCE.
    # ALL OTHER SAMPLERS HAVE BEEN CHANGED TO USE CORRECT UNCONDITIONAL CACHES.
    speaker_latent_uncond, speaker_mask_uncond = torch.zeros_like(speaker_latent), torch.zeros_like(speaker_mask)

    # batch layout: [conditional, second half used as the unconditional branch, with the dropped condition selected per step below]
    full_text_input_ids = torch.cat([text_input_ids, text_input_ids], dim=0)
    full_text_mask = torch.cat([text_mask, text_mask_uncond], dim=0)
    full_speaker_latent = torch.cat([speaker_latent, speaker_latent_uncond], dim=0)
    full_speaker_mask = torch.cat([speaker_mask, speaker_mask_uncond], dim=0)

    kv_cache_full = model.get_kv_cache(
        speaker_latent=full_speaker_latent.to(dtype),
        speaker_mask=full_speaker_mask,
        text_input_ids=full_text_input_ids,
        text_mask=full_text_mask,
    )
    kv_cache = _get_first_n_kv_cache(kv_cache_full, batch_size)
    if speaker_k_scale is not None:
        _multiply_speaker_kv_cache(kv_cache_full, speaker_k_scale, text_input_ids.shape[-1], speaker_k_max_layers)

    x_t = torch.randn((batch_size, block_size, 80), device=device, dtype=torch.float32)
    if truncation_factor is not None:
        x_t = x_t * truncation_factor

    for i in range(num_steps):
        t, t_next = t_schedule[i], t_schedule[i + 1]
        has_cfg = ((t >= cfg_min_t) * (t <= cfg_max_t)).item()
        if has_cfg:
            # even steps drop the text condition (text guidance), odd steps drop the speaker condition (speaker guidance)
            v_cond, v_uncond = model(
                x=torch.cat([x_t, x_t], dim=0).to(dtype),
                t=(torch.ones((batch_size * 2,), device=device) * t).to(dtype),
                text_input_ids=None,
                text_mask=torch.cat([text_mask, text_mask_uncond if i % 2 == 0 else text_mask], dim=0),
                speaker_latent=None,
                speaker_mask=torch.cat([speaker_mask, speaker_mask if i % 2 == 0 else speaker_mask_uncond], dim=0),
                kv_cache=kv_cache_full,
            ).float().chunk(2, dim=0)
            v_pred = v_cond + (cfg_scale_text if i % 2 == 0 else cfg_scale_speaker) * (v_cond - v_uncond)
        else:
            v_pred = model(
                x=x_t.to(dtype),
                t=(torch.ones((batch_size,), device=device) * t).to(dtype),
                text_input_ids=None,
                text_mask=text_mask,
                speaker_latent=None,
                speaker_mask=speaker_mask,
                kv_cache=kv_cache,
            ).float()
        if rescale_k is not None and rescale_sigma is not None:
            v_pred = _temporal_score_rescale(v_pred, x_t, t, rescale_k, rescale_sigma)
        if speaker_k_scale is not None and t_next < speaker_k_min_t and t >= speaker_k_min_t:
            _multiply_speaker_kv_cache(kv_cache_full, 1. / speaker_k_scale, text_input_ids.shape[-1], speaker_k_max_layers)
        x_t = x_t + v_pred * (t_next - t)
    return x_t
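
# APG note (illustrative only): the sampler below forms the usual CFG difference in x0 space,
# d = x0_cond - x0_uncond, optionally applies momentum and a norm clamp to d, and then splits d
# into a component parallel to the (normalized) conditional prediction and an orthogonal remainder:
#
#   par = <d, x0_cond / ||x0_cond||> * x0_cond / ||x0_cond||
#   upd = (d - par) + eta * par
#   x0_guided = x0_cond + cfg_scale * upd
#
# With eta = 1, momentum = 0 and no norm clamp this reduces to standard CFG in x0 space; eta < 1
# down-weights the parallel component, the part of the guidance that mostly rescales x0_cond and is
# commonly associated with over-saturation at high guidance scales. Text and speaker guidance each
# get their own difference, eta, momentum, norm clamp, and scale.
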
@torch.inference_mode()
def sample_euler_apg_independent_guidances(
    model: EchoDiT,
    speaker_latent: torch.Tensor,
    speaker_mask: torch.Tensor,
    text_input_ids: torch.Tensor,
    text_mask: torch.Tensor,
    rng_seed: int,
    num_steps: int,
    cfg_scale_text: float,
    cfg_scale_speaker: float,
    cfg_min_t: float,
    cfg_max_t: float,
    truncation_factor: float | None,
    rescale_k: float | None,
    rescale_sigma: float | None,
    apg_eta_text: float,
    apg_eta_speaker: float,
    apg_momentum_text: float | None,
    apg_momentum_speaker: float | None,
    apg_norm_text: float | None,
    apg_norm_speaker: float | None,
    speaker_k_scale: float | None,
    speaker_k_max_layers: int | None,
    speaker_k_min_t: float | None,
    block_size: int | None = None,
) -> torch.Tensor:
    if block_size is None:
        block_size = 640
    if apg_momentum_text is None:
        apg_momentum_text = 0.0
    if apg_momentum_speaker is None:
        apg_momentum_speaker = 0.0
    torch.manual_seed(rng_seed)
    INIT_SCALE = 0.999
    device, dtype = model.device, model.dtype
    batch_size = text_input_ids.shape[0]
    t_schedule = torch.linspace(1., 0., num_steps + 1, device=device) * INIT_SCALE

    text_input_ids_uncond, text_mask_uncond = _get_uncond_text_input_ids_and_mask(
        text_input_ids.shape[0], text_input_ids.shape[1], device=device
    )
    speaker_latent_uncond, speaker_mask_uncond = torch.zeros_like(speaker_latent), torch.zeros_like(speaker_mask)

    # batch layout: [conditional, unconditional text, unconditional speaker]
    full_text_input_ids = torch.cat([text_input_ids, text_input_ids_uncond, text_input_ids], dim=0)
    full_text_mask = torch.cat([text_mask, text_mask_uncond, text_mask], dim=0)
    full_speaker_latent = torch.cat([speaker_latent, speaker_latent, speaker_latent_uncond], dim=0)
    full_speaker_mask = torch.cat([speaker_mask, speaker_mask, speaker_mask_uncond], dim=0)

    kv_cache_full = model.get_kv_cache(
        speaker_latent=full_speaker_latent.to(dtype),
        speaker_mask=full_speaker_mask,
        text_input_ids=full_text_input_ids,
        text_mask=full_text_mask,
    )
    kv_cache = _get_first_n_kv_cache(kv_cache_full, batch_size)
    if speaker_k_scale is not None:
        _multiply_speaker_kv_cache(kv_cache_full, speaker_k_scale, text_input_ids.shape[-1], speaker_k_max_layers)

    x_t = torch.randn((batch_size, block_size, 80), device=device, dtype=torch.float32)
    if truncation_factor is not None:
        x_t = x_t * truncation_factor

    # momentum buffers for the guidance differences
    buf_text = torch.zeros_like(x_t)
    buf_speaker = torch.zeros_like(x_t)

    for i in range(num_steps):
        t, t_next = t_schedule[i], t_schedule[i + 1]
        has_cfg = ((t >= cfg_min_t) * (t <= cfg_max_t)).item()
        if has_cfg:
            v_cond, v_uncond_text, v_uncond_speaker = model(
                x=torch.cat([x_t, x_t, x_t], dim=0).to(dtype),
                t=(torch.ones((batch_size * 3,), device=device) * t).to(dtype),
                text_input_ids=None,
                text_mask=full_text_mask,
                speaker_latent=None,
                speaker_mask=full_speaker_mask,
                kv_cache=kv_cache_full,
            ).float().chunk(3, dim=0)
            # guidance is applied in x0 space
            x0_cond = x_t - t * v_cond
            x0_uncond_text = x_t - t * v_uncond_text
            x0_uncond_speaker = x_t - t * v_uncond_speaker
            diff_text = x0_cond - x0_uncond_text
            diff_speaker = x0_cond - x0_uncond_speaker
            # momentum on the guidance differences
            buf_text = diff_text + apg_momentum_text * buf_text
            diff_text = buf_text
            buf_speaker = diff_speaker + apg_momentum_speaker * buf_speaker
            diff_speaker = buf_speaker
            # optional norm clamping of the guidance differences
            if apg_norm_text is not None:
                nt = torch.sqrt((diff_text * diff_text).sum(dim=tuple(range(1, diff_text.dim())), keepdim=True) + 1e-12)
                s = torch.minimum(torch.ones_like(nt), (torch.as_tensor(apg_norm_text, device=device, dtype=diff_text.dtype) / nt))
                diff_text = diff_text * s
            if apg_norm_speaker is not None:
                ns = torch.sqrt((diff_speaker * diff_speaker).sum(dim=tuple(range(1, diff_speaker.dim())), keepdim=True) + 1e-12)
                s = torch.minimum(torch.ones_like(ns), (torch.as_tensor(apg_norm_speaker, device=device, dtype=diff_speaker.dtype) / ns))
                diff_speaker = diff_speaker * s
            # project each difference onto / against the direction of the conditional x0 prediction;
            # eta rescales the parallel component (eta = 1 with momentum 0 and no clamp recovers plain CFG in x0 space)
            c_norm = torch.sqrt((x0_cond * x0_cond).sum(dim=tuple(range(1, x0_cond.dim())), keepdim=True) + 1e-12)
            c_hat = x0_cond / c_norm
            par_text = (diff_text * c_hat).sum(dim=tuple(range(1, diff_text.dim())), keepdim=True) * c_hat
            ort_text = diff_text - par_text
            upd_text = ort_text + apg_eta_text * par_text
            par_speaker = (diff_speaker * c_hat).sum(dim=tuple(range(1, diff_speaker.dim())), keepdim=True) * c_hat
            ort_speaker = diff_speaker - par_speaker
            upd_speaker = ort_speaker + apg_eta_speaker * par_speaker
            x0_pred = x0_cond + cfg_scale_text * upd_text + cfg_scale_speaker * upd_speaker
            v_pred = (x_t - x0_pred) / t
        else:
            v_pred = model(
                x=x_t.to(dtype),
                t=(torch.ones((batch_size,), device=device) * t).to(dtype),
                text_input_ids=None,
                text_mask=text_mask,
                speaker_latent=None,
                speaker_mask=speaker_mask,
                kv_cache=kv_cache,
            ).float()
        if rescale_k is not None and rescale_sigma is not None:
            v_pred = _temporal_score_rescale(v_pred, x_t, t, rescale_k, rescale_sigma)
        if speaker_k_scale is not None and t_next < speaker_k_min_t and t >= speaker_k_min_t:
            _multiply_speaker_kv_cache(kv_cache_full, 1. / speaker_k_scale, text_input_ids.shape[-1], speaker_k_max_layers)
        x_t = x_t + v_pred * (t_next - t)
    return x_t
# router
class GuidanceMode(Enum):
    INDEPENDENT = "independent"
    APG = "apg"
    JOINT = "joint"
    ALTERNATING = "alternating"


def sample_euler_cfg_any(
    model: EchoDiT,
    speaker_latent: torch.Tensor,
    speaker_mask: torch.Tensor,
    text_input_ids: torch.Tensor,
    text_mask: torch.Tensor,
    rng_seed: int,
    guidance_mode: GuidanceMode,
    num_steps: int,
    cfg_scale_text: float,
    cfg_scale_speaker: float | None,
    cfg_min_t: float,
    cfg_max_t: float,
    truncation_factor: float | None,
    rescale_k: float | None,
    rescale_sigma: float | None,
    speaker_k_scale: float | None,
    speaker_k_min_t: float | None,
    speaker_k_max_layers: int | None,
    apg_eta_text: float | None,
    apg_eta_speaker: float | None,
    apg_momentum_text: float | None,
    apg_momentum_speaker: float | None,
    apg_norm_text: float | None,
    apg_norm_speaker: float | None,
    block_size: int | None = None,
) -> torch.Tensor:
    if guidance_mode == GuidanceMode.INDEPENDENT:
        assert cfg_scale_speaker is not None, "cfg_scale_speaker must be provided for independent guidances"
        return sample_euler_cfg_independent_guidances(
            model=model,
            speaker_latent=speaker_latent,
            speaker_mask=speaker_mask,
            text_input_ids=text_input_ids,
            text_mask=text_mask,
            rng_seed=rng_seed,
            num_steps=num_steps,
            cfg_scale_text=cfg_scale_text,
            cfg_scale_speaker=cfg_scale_speaker,
            cfg_min_t=cfg_min_t,
            cfg_max_t=cfg_max_t,
            truncation_factor=truncation_factor,
            rescale_k=rescale_k,
            rescale_sigma=rescale_sigma,
            speaker_k_scale=speaker_k_scale,
            speaker_k_max_layers=speaker_k_max_layers,
            speaker_k_min_t=speaker_k_min_t,
            block_size=block_size,
        )
    elif guidance_mode == GuidanceMode.APG:
        assert cfg_scale_speaker is not None, "cfg_scale_speaker must be provided for APG"
        assert apg_eta_text is not None, "apg_eta_text must be provided for APG"
        assert apg_eta_speaker is not None, "apg_eta_speaker must be provided for APG"
        return sample_euler_apg_independent_guidances(
            model=model,
            speaker_latent=speaker_latent,
            speaker_mask=speaker_mask,
            text_input_ids=text_input_ids,
            text_mask=text_mask,
            rng_seed=rng_seed,
            num_steps=num_steps,
            cfg_scale_text=cfg_scale_text,
            cfg_scale_speaker=cfg_scale_speaker,
            cfg_min_t=cfg_min_t,
            cfg_max_t=cfg_max_t,
            truncation_factor=truncation_factor,
            rescale_k=rescale_k,
            rescale_sigma=rescale_sigma,
            apg_eta_text=apg_eta_text,
            apg_eta_speaker=apg_eta_speaker,
            apg_momentum_text=apg_momentum_text,
            apg_momentum_speaker=apg_momentum_speaker,
            apg_norm_text=apg_norm_text,
            apg_norm_speaker=apg_norm_speaker,
            speaker_k_scale=speaker_k_scale,
            speaker_k_max_layers=speaker_k_max_layers,
            speaker_k_min_t=speaker_k_min_t,
            block_size=block_size,
        )
    elif guidance_mode == GuidanceMode.JOINT:
        assert cfg_scale_speaker is None or cfg_scale_text == cfg_scale_speaker, \
            "cfg_scale_text and cfg_scale_speaker must be the same or cfg_scale_speaker must be None"
        return sample_euler_cfg(
            model=model,
            speaker_latent=speaker_latent,
            speaker_mask=speaker_mask,
            text_input_ids=text_input_ids,
            text_mask=text_mask,
            rng_seed=rng_seed,
            num_steps=num_steps,
            cfg_scale=cfg_scale_text,
            cfg_min_t=cfg_min_t,
            cfg_max_t=cfg_max_t,
            truncation_factor=truncation_factor,
            rescale_k=rescale_k,
            rescale_sigma=rescale_sigma,
            speaker_k_scale=speaker_k_scale,
            speaker_k_max_layers=speaker_k_max_layers,
            speaker_k_min_t=speaker_k_min_t,
            block_size=block_size,
        )
    elif guidance_mode == GuidanceMode.ALTERNATING:
        assert cfg_scale_speaker is not None, "cfg_scale_speaker must be provided for alternating guidances"
        return sample_euler_cfg_alternating_guidances(
            model=model,
            speaker_latent=speaker_latent,
            speaker_mask=speaker_mask,
            text_input_ids=text_input_ids,
            text_mask=text_mask,
            rng_seed=rng_seed,
            num_steps=num_steps,
            cfg_scale_text=cfg_scale_text,
            cfg_scale_speaker=cfg_scale_speaker,
            cfg_min_t=cfg_min_t,
            cfg_max_t=cfg_max_t,
            truncation_factor=truncation_factor,
            rescale_k=rescale_k,
            rescale_sigma=rescale_sigma,
            speaker_k_scale=speaker_k_scale,
            speaker_k_max_layers=speaker_k_max_layers,
            speaker_k_min_t=speaker_k_min_t,
            block_size=block_size,
        )
    else:
        raise ValueError(f"Unknown guidance mode: {guidance_mode}")
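
# Hypothetical usage sketch (illustration only): assumes an already-constructed `model: EchoDiT`
# plus speaker/text conditioning tensors prepared elsewhere; the shapes and values below are
# assumptions for illustration, not requirements of this module.
#
#   latent = sample_euler_cfg_any(
#       model=model,
#       speaker_latent=speaker_latent,        # e.g. (batch, speaker_frames, latent_dim)
#       speaker_mask=speaker_mask,            # bool, (batch, speaker_frames)
#       text_input_ids=text_input_ids,        # int, (batch, text_len)
#       text_mask=text_mask,                  # bool, (batch, text_len)
#       rng_seed=0,
#       guidance_mode=GuidanceMode.INDEPENDENT,
#       num_steps=32,
#       cfg_scale_text=2.0,
#       cfg_scale_speaker=1.0,
#       cfg_min_t=0.0,
#       cfg_max_t=1.0,
#       truncation_factor=None,
#       rescale_k=None,
#       rescale_sigma=None,
#       speaker_k_scale=None,
#       speaker_k_min_t=None,
#       speaker_k_max_layers=None,
#       apg_eta_text=None,
#       apg_eta_speaker=None,
#       apg_momentum_text=None,
#       apg_momentum_speaker=None,
#       apg_norm_text=None,
#       apg_norm_speaker=None,
#   )   # -> (batch, block_size or 640, 80) latent, to be decoded downstream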