from typing import List
from enum import Enum

import torch

from model import EchoDiT


# helper
def _get_uncond_text_input_ids_and_mask(batch_size: int, max_length: int, device: str | None = None) -> tuple[torch.Tensor, torch.Tensor]:
    # returns zeros for text input ids, and (True, False, False, ...) for text mask
    text_input_ids_uncond = torch.zeros((batch_size, max_length), dtype=torch.int32)
    text_mask_uncond = torch.zeros((batch_size, max_length), dtype=torch.bool)
    text_mask_uncond[:, 0] = True
    if device is not None:
        text_input_ids_uncond = text_input_ids_uncond.to(device)
        text_mask_uncond = text_mask_uncond.to(device)
    return text_input_ids_uncond, text_mask_uncond
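
# Illustrative note (not in the original module): for batch_size=2, max_length=4
# the helper returns
#   ids:  [[0, 0, 0, 0], [0, 0, 0, 0]]            (int32)
#   mask: [[True, False, False, False],
#          [True, False, False, False]]           (bool)
# i.e. the unconditional text state is a single token with id 0 (BOS, per the
# comment in the alternating sampler below).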


# simple sampler, kept for reference; prefer the configurable samplers below
def sample_euler_cfg_simple(
    model: EchoDiT,
    speaker_latent: torch.Tensor,
    speaker_mask: torch.Tensor,
    text_input_ids: torch.Tensor,
    text_mask: torch.Tensor,
    rng_seed: int,
    num_steps: int,
    cfg_scale: float,
) -> torch.Tensor:
    device, dtype = model.device, model.dtype
    batch_size = text_input_ids.shape[0]
    torch.manual_seed(rng_seed)
    t_schedule = torch.linspace(1., 0., num_steps + 1, device=device)
    text_input_ids_uncond, text_mask_uncond = _get_uncond_text_input_ids_and_mask(text_input_ids.shape[0], text_input_ids.shape[1], device=device)
    speaker_latent_uncond, speaker_mask_uncond = torch.zeros_like(speaker_latent), torch.zeros_like(speaker_mask)
    # conditional and unconditional batch elements are stacked along dim 0
    full_text_input_ids = torch.cat([text_input_ids, text_input_ids_uncond], dim=0)
    full_text_mask = torch.cat([text_mask, text_mask_uncond], dim=0)
    full_speaker_latent = torch.cat([speaker_latent, speaker_latent_uncond], dim=0)
    full_speaker_mask = torch.cat([speaker_mask, speaker_mask_uncond], dim=0)
    kv_cache = model.get_kv_cache(
        speaker_latent=full_speaker_latent.to(dtype),
        speaker_mask=full_speaker_mask,
        text_input_ids=full_text_input_ids,
        text_mask=full_text_mask,
    )
    x_t = torch.randn((batch_size, 640, 80), device=device, dtype=torch.float32)
    for i in range(num_steps):
        t, t_next = t_schedule[i], t_schedule[i + 1]
        v_cond, v_uncond = model(
            x=torch.cat([x_t, x_t], dim=0).to(dtype),
            t=(torch.ones((batch_size * 2,), device=device) * t).to(dtype),
            text_input_ids=None,
            text_mask=full_text_mask,
            speaker_latent=None,
            speaker_mask=full_speaker_mask,
            kv_cache=kv_cache,
        ).float().chunk(2, dim=0)
        v_pred = v_cond + cfg_scale * (v_cond - v_uncond)
        # note: x_0_pred is x_t - v_pred * t
        x_t = x_t + v_pred * (t_next - t)
    return x_t
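
# Math note (derived from the code above, under the flow-matching convention
# implied by the x_0 comment: x_t = (1 - t) * x_0 + t * noise, so v = noise - x_0):
#   - the update x_t + v_pred * (t_next - t) is one Euler step of the ODE
#     dx/dt = v, integrated from t = 1 (pure noise) down to t = 0 (data);
#   - v_pred = v_cond + cfg_scale * (v_cond - v_uncond) is the standard
#     classifier-free-guidance extrapolation away from the unconditional branch.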


######


def _temporal_score_rescale(v_pred: torch.Tensor, x_t: torch.Tensor, t: float, rescale_k: float, rescale_sigma: float) -> torch.Tensor:
    if t < 1:
        snr = (1 - t) ** 2 / (t ** 2)
        ratio = (snr * rescale_sigma ** 2 + 1) / (snr * rescale_sigma ** 2 / rescale_k + 1)
        return 1 / (1 - t) * (ratio * ((1 - t) * v_pred + x_t) - x_t)
    return v_pred
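
# Derivation note: with x_t = (1 - t) * x_0 + t * noise, the implied noise
# prediction is noise_pred = x_t + (1 - t) * v_pred, and snr = ((1 - t) / t) ** 2.
# The function above scales noise_pred by
#   ratio = (snr * sigma^2 + 1) / (snr * sigma^2 / k + 1),
# which tends to 1 at low SNR (t -> 1) and to k at high SNR (t -> 0), with sigma
# setting the crossover, then maps back via v = (noise_pred - x_t) / (1 - t).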


def _get_first_n_kv_cache(kv_cache: List[List[torch.Tensor]], n: int) -> List[List[torch.Tensor]]:
    return [[kv_cache[i][0][:n], kv_cache[i][1][:n]] for i in range(len(kv_cache))]


def _multiply_speaker_kv_cache(
    kv_cache: List[List[torch.Tensor]],
    scale: float,
    text_length: int,
    max_layers: int = 24,
) -> None:
    # multiplies the speaker kv cache by scale, in place
    # speaker keys start after text keys (at position text_length)
    for i in range(min(max_layers, len(kv_cache))):
        for j in range(len(kv_cache[i])):
            kv_cache[i][j][:, text_length:] *= scale
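
# Note: _multiply_speaker_kv_cache mutates the cache tensors in place, and
# _get_first_n_kv_cache returns slices, which in PyTorch are views of the same
# storage. Scaling kv_cache_full therefore also scales the conditional-only
# cache sliced from it; the samplers below undo the scaling by multiplying by
# 1 / scale once t drops below speaker_k_min_t.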


def sample_euler_cfg(
    model: EchoDiT,
    speaker_latent: torch.Tensor,
    speaker_mask: torch.Tensor,
    text_input_ids: torch.Tensor,
    text_mask: torch.Tensor,
    rng_seed: int,
    num_steps: int,
    cfg_scale: float,
    cfg_min_t: float,
    cfg_max_t: float,
    truncation_factor: float | None,
    rescale_k: float | None,
    rescale_sigma: float | None,
    speaker_k_scale: float | None,
    speaker_k_max_layers: int | None,
    speaker_k_min_t: float | None,
    block_size: int | None = None,
) -> torch.Tensor:
    if block_size is None:
        block_size = 640
    torch.manual_seed(rng_seed)
    INIT_SCALE = 0.999
    device, dtype = model.device, model.dtype
    batch_size = text_input_ids.shape[0]
    t_schedule = torch.linspace(1., 0., num_steps + 1, device=device) * INIT_SCALE
    text_input_ids_uncond, text_mask_uncond = _get_uncond_text_input_ids_and_mask(text_input_ids.shape[0], text_input_ids.shape[1], device=device)
    speaker_latent_uncond, speaker_mask_uncond = torch.zeros_like(speaker_latent), torch.zeros_like(speaker_mask)
    full_text_input_ids = torch.cat([text_input_ids, text_input_ids_uncond], dim=0)
    full_text_mask = torch.cat([text_mask, text_mask_uncond], dim=0)
    full_speaker_latent = torch.cat([speaker_latent, speaker_latent_uncond], dim=0)
    full_speaker_mask = torch.cat([speaker_mask, speaker_mask_uncond], dim=0)
    kv_cache_full = model.get_kv_cache(
        speaker_latent=full_speaker_latent.to(dtype),
        speaker_mask=full_speaker_mask,
        text_input_ids=full_text_input_ids,
        text_mask=full_text_mask,
    )  # could be made faster by not recomputing the cache for the unconditional batch elements
    kv_cache = _get_first_n_kv_cache(kv_cache_full, batch_size)
    if speaker_k_scale is not None:
        _multiply_speaker_kv_cache(kv_cache_full, speaker_k_scale, text_input_ids.shape[-1], speaker_k_max_layers)
    x_t = torch.randn((batch_size, block_size, 80), device=device, dtype=torch.float32)
    if truncation_factor is not None:
        x_t = x_t * truncation_factor
    for i in range(num_steps):
        t, t_next = t_schedule[i], t_schedule[i + 1]
        has_cfg = ((t >= cfg_min_t) & (t <= cfg_max_t)).item()
        if has_cfg:
            v_cond, v_uncond = model(
                x=torch.cat([x_t, x_t], dim=0).to(dtype),
                t=(torch.ones((batch_size * 2,), device=device) * t).to(dtype),
                text_input_ids=None,
                text_mask=full_text_mask,
                speaker_latent=None,
                speaker_mask=full_speaker_mask,
                kv_cache=kv_cache_full,
            ).float().chunk(2, dim=0)
            v_pred = v_cond + cfg_scale * (v_cond - v_uncond)
        else:
            v_pred = model(
                x=x_t.to(dtype),
                t=(torch.ones((batch_size,), device=device) * t).to(dtype),
                text_input_ids=None,
                text_mask=text_mask,
                speaker_latent=None,
                speaker_mask=speaker_mask,
                kv_cache=kv_cache,
            ).float()
        if rescale_k is not None and rescale_sigma is not None:
            v_pred = _temporal_score_rescale(v_pred, x_t, t, rescale_k, rescale_sigma)
        if speaker_k_scale is not None and t_next < speaker_k_min_t and t >= speaker_k_min_t:
            _multiply_speaker_kv_cache(kv_cache_full, 1. / speaker_k_scale, text_input_ids.shape[-1], speaker_k_max_layers)
        x_t = x_t + v_pred * (t_next - t)
    return x_t
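
# Knobs in sample_euler_cfg beyond the simple sampler (as read from the code):
#   - INIT_SCALE = 0.999 shrinks the schedule so the model is never evaluated at
#     exactly t = 1;
#   - cfg_min_t / cfg_max_t restrict the two-pass guidance to a window of t, with
#     single conditional passes outside it;
#   - truncation_factor scales the initial noise (truncation-style sampling);
#   - speaker_k_scale rescales the speaker portion of the cached keys/values
#     (positions after the text tokens) in the first speaker_k_max_layers layers
#     while t >= speaker_k_min_t, and is undone afterwards.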


def sample_euler_cfg_independent_guidances(
    model: EchoDiT,
    speaker_latent: torch.Tensor,
    speaker_mask: torch.Tensor,
    text_input_ids: torch.Tensor,
    text_mask: torch.Tensor,
    rng_seed: int,
    num_steps: int,
    cfg_scale_text: float,
    cfg_scale_speaker: float,
    cfg_min_t: float,
    cfg_max_t: float,
    truncation_factor: float | None,
    rescale_k: float | None,
    rescale_sigma: float | None,
    speaker_k_scale: float | None,
    speaker_k_max_layers: int | None,
    speaker_k_min_t: float | None,
    block_size: int | None = None,
) -> torch.Tensor:
    if block_size is None:
        block_size = 640
    torch.manual_seed(rng_seed)
    INIT_SCALE = 0.999
    device, dtype = model.device, model.dtype
    batch_size = text_input_ids.shape[0]
    t_schedule = torch.linspace(1., 0., num_steps + 1, device=device) * INIT_SCALE
    text_input_ids_uncond, text_mask_uncond = _get_uncond_text_input_ids_and_mask(text_input_ids.shape[0], text_input_ids.shape[1], device=device)
    speaker_latent_uncond, speaker_mask_uncond = torch.zeros_like(speaker_latent), torch.zeros_like(speaker_mask)
    full_text_input_ids = torch.cat([text_input_ids, text_input_ids_uncond, text_input_ids], dim=0)
    full_text_mask = torch.cat([text_mask, text_mask_uncond, text_mask], dim=0)
    full_speaker_latent = torch.cat([speaker_latent, speaker_latent, speaker_latent_uncond], dim=0)
    full_speaker_mask = torch.cat([speaker_mask, speaker_mask, speaker_mask_uncond], dim=0)
    kv_cache_full = model.get_kv_cache(
        speaker_latent=full_speaker_latent.to(dtype),
        speaker_mask=full_speaker_mask,
        text_input_ids=full_text_input_ids,
        text_mask=full_text_mask,
    )
    kv_cache = _get_first_n_kv_cache(kv_cache_full, batch_size)
    if speaker_k_scale is not None:
        _multiply_speaker_kv_cache(kv_cache_full, speaker_k_scale, text_input_ids.shape[-1], speaker_k_max_layers)
    x_t = torch.randn((batch_size, block_size, 80), device=device, dtype=torch.float32)
    if truncation_factor is not None:
        x_t = x_t * truncation_factor
    for i in range(num_steps):
        t, t_next = t_schedule[i], t_schedule[i + 1]
        has_cfg = ((t >= cfg_min_t) & (t <= cfg_max_t)).item()
        if has_cfg:
            v_cond, v_uncond_text, v_uncond_speaker = model(
                x=torch.cat([x_t, x_t, x_t], dim=0).to(dtype),
                t=(torch.ones((batch_size * 3,), device=device) * t).to(dtype),
                text_input_ids=None,
                text_mask=full_text_mask,
                speaker_latent=None,
                speaker_mask=full_speaker_mask,
                kv_cache=kv_cache_full,
            ).float().chunk(3, dim=0)
            v_pred = v_cond + cfg_scale_text * (v_cond - v_uncond_text) + cfg_scale_speaker * (v_cond - v_uncond_speaker)
        else:
            v_pred = model(
                x=x_t.to(dtype),
                t=(torch.ones((batch_size,), device=device) * t).to(dtype),
                text_input_ids=None,
                text_mask=text_mask,
                speaker_latent=None,
                speaker_mask=speaker_mask,
                kv_cache=kv_cache,
            ).float()
        if rescale_k is not None and rescale_sigma is not None:
            v_pred = _temporal_score_rescale(v_pred, x_t, t, rescale_k, rescale_sigma)
        if speaker_k_scale is not None and t_next < speaker_k_min_t and t >= speaker_k_min_t:
            _multiply_speaker_kv_cache(kv_cache_full, 1. / speaker_k_scale, text_input_ids.shape[-1], speaker_k_max_layers)
        x_t = x_t + v_pred * (t_next - t)
    return x_t
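
# sample_euler_cfg_independent_guidances stacks three batch elements per example,
# laid out as [conditional, text-unconditional, speaker-unconditional], and
# combines the two guidance directions independently:
#   v_pred = v_cond + w_text * (v_cond - v_uncond_text)
#                   + w_speaker * (v_cond - v_uncond_speaker)
# at the cost of three forward passes per guided step instead of two.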


def sample_euler_cfg_alternating_guidances(
    model: EchoDiT,
    speaker_latent: torch.Tensor,
    speaker_mask: torch.Tensor,
    text_input_ids: torch.Tensor,
    text_mask: torch.Tensor,
    rng_seed: int,
    num_steps: int,
    cfg_scale_text: float,
    cfg_scale_speaker: float,
    cfg_min_t: float,
    cfg_max_t: float,
    truncation_factor: float | None,
    rescale_k: float | None,
    rescale_sigma: float | None,
    speaker_k_scale: float | None,
    speaker_k_max_layers: int | None,
    speaker_k_min_t: float | None,
    block_size: int | None = None,
) -> torch.Tensor:
    if block_size is None:
        block_size = 640
    torch.manual_seed(rng_seed)
    INIT_SCALE = 0.999
    device, dtype = model.device, model.dtype
    batch_size = text_input_ids.shape[0]
    t_schedule = torch.linspace(1., 0., num_steps + 1, device=device) * INIT_SCALE
    text_input_ids_uncond, text_mask_uncond = _get_uncond_text_input_ids_and_mask(text_input_ids.shape[0], text_input_ids.shape[1], device=device)
    # TODO: this / the below is technically incorrect, as it assumes a causal text
    # encoder (which is not the case). If the text encoder were causal, using an
    # uncond text mask on cond text inputs would give an uncond state due to BOS = 0.
    # In practice it might not make much of a difference.
    # All other samplers were changed to use correct unconditional caches.
    speaker_latent_uncond, speaker_mask_uncond = torch.zeros_like(speaker_latent), torch.zeros_like(speaker_mask)
    full_text_input_ids = torch.cat([text_input_ids, text_input_ids], dim=0)
    full_text_mask = torch.cat([text_mask, text_mask_uncond], dim=0)
    full_speaker_latent = torch.cat([speaker_latent, speaker_latent_uncond], dim=0)
    full_speaker_mask = torch.cat([speaker_mask, speaker_mask_uncond], dim=0)
    kv_cache_full = model.get_kv_cache(
        speaker_latent=full_speaker_latent.to(dtype),
        speaker_mask=full_speaker_mask,
        text_input_ids=full_text_input_ids,
        text_mask=full_text_mask,
    )
    kv_cache = _get_first_n_kv_cache(kv_cache_full, batch_size)
    if speaker_k_scale is not None:
        _multiply_speaker_kv_cache(kv_cache_full, speaker_k_scale, text_input_ids.shape[-1], speaker_k_max_layers)
    x_t = torch.randn((batch_size, block_size, 80), device=device, dtype=torch.float32)
    if truncation_factor is not None:
        x_t = x_t * truncation_factor
    for i in range(num_steps):
        t, t_next = t_schedule[i], t_schedule[i + 1]
        has_cfg = ((t >= cfg_min_t) & (t <= cfg_max_t)).item()
        if has_cfg:
            v_cond, v_uncond = model(
                x=torch.cat([x_t, x_t], dim=0).to(dtype),
                t=(torch.ones((batch_size * 2,), device=device) * t).to(dtype),
                text_input_ids=None,
                text_mask=torch.cat([text_mask, text_mask_uncond if i % 2 == 0 else text_mask], dim=0),
                speaker_latent=None,
                speaker_mask=torch.cat([speaker_mask, speaker_mask if i % 2 == 0 else speaker_mask_uncond], dim=0),
                kv_cache=kv_cache_full,
            ).float().chunk(2, dim=0)
            v_pred = v_cond + (cfg_scale_text if i % 2 == 0 else cfg_scale_speaker) * (v_cond - v_uncond)
        else:
            v_pred = model(
                x=x_t.to(dtype),
                t=(torch.ones((batch_size,), device=device) * t).to(dtype),
                text_input_ids=None,
                text_mask=text_mask,
                speaker_latent=None,
                speaker_mask=speaker_mask,
                kv_cache=kv_cache,
            ).float()
        if rescale_k is not None and rescale_sigma is not None:
            v_pred = _temporal_score_rescale(v_pred, x_t, t, rescale_k, rescale_sigma)
        if speaker_k_scale is not None and t_next < speaker_k_min_t and t >= speaker_k_min_t:
            _multiply_speaker_kv_cache(kv_cache_full, 1. / speaker_k_scale, text_input_ids.shape[-1], speaker_k_max_layers)
        x_t = x_t + v_pred * (t_next - t)
    return x_t
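
# sample_euler_cfg_alternating_guidances keeps the two-pass cost of plain CFG by
# alternating which condition is dropped: even steps guide against the
# text-unconditional prediction with cfg_scale_text, odd steps against the
# speaker-unconditional prediction with cfg_scale_speaker.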


def sample_euler_apg_independent_guidances(
    model: EchoDiT,
    speaker_latent: torch.Tensor,
    speaker_mask: torch.Tensor,
    text_input_ids: torch.Tensor,
    text_mask: torch.Tensor,
    rng_seed: int,
    num_steps: int,
    cfg_scale_text: float,
    cfg_scale_speaker: float,
    cfg_min_t: float,
    cfg_max_t: float,
    truncation_factor: float | None,
    rescale_k: float | None,
    rescale_sigma: float | None,
    apg_eta_text: float,
    apg_eta_speaker: float,
    apg_momentum_text: float | None,
    apg_momentum_speaker: float | None,
    apg_norm_text: float | None,
    apg_norm_speaker: float | None,
    speaker_k_scale: float | None,
    speaker_k_max_layers: int | None,
    speaker_k_min_t: float | None,
    block_size: int | None = None,
) -> torch.Tensor:
    if block_size is None:
        block_size = 640
    if apg_momentum_text is None:
        apg_momentum_text = 0.0
    if apg_momentum_speaker is None:
        apg_momentum_speaker = 0.0
    torch.manual_seed(rng_seed)
    INIT_SCALE = 0.999
    device, dtype = model.device, model.dtype
    batch_size = text_input_ids.shape[0]
    t_schedule = torch.linspace(1., 0., num_steps + 1, device=device) * INIT_SCALE
    text_input_ids_uncond, text_mask_uncond = _get_uncond_text_input_ids_and_mask(text_input_ids.shape[0], text_input_ids.shape[1], device=device)
    speaker_latent_uncond, speaker_mask_uncond = torch.zeros_like(speaker_latent), torch.zeros_like(speaker_mask)
    full_text_input_ids = torch.cat([text_input_ids, text_input_ids_uncond, text_input_ids], dim=0)
    full_text_mask = torch.cat([text_mask, text_mask_uncond, text_mask], dim=0)
    full_speaker_latent = torch.cat([speaker_latent, speaker_latent, speaker_latent_uncond], dim=0)
    full_speaker_mask = torch.cat([speaker_mask, speaker_mask, speaker_mask_uncond], dim=0)
    kv_cache_full = model.get_kv_cache(
        speaker_latent=full_speaker_latent.to(dtype),
        speaker_mask=full_speaker_mask,
        text_input_ids=full_text_input_ids,
        text_mask=full_text_mask,
    )
    kv_cache = _get_first_n_kv_cache(kv_cache_full, batch_size)
    if speaker_k_scale is not None:
        _multiply_speaker_kv_cache(kv_cache_full, speaker_k_scale, text_input_ids.shape[-1], speaker_k_max_layers)
    x_t = torch.randn((batch_size, block_size, 80), device=device, dtype=torch.float32)
    if truncation_factor is not None:
        x_t = x_t * truncation_factor
    # momentum buffers for the text and speaker guidance directions
    buf_text = torch.zeros_like(x_t)
    buf_speaker = torch.zeros_like(x_t)
    for i in range(num_steps):
        t, t_next = t_schedule[i], t_schedule[i + 1]
        has_cfg = ((t >= cfg_min_t) & (t <= cfg_max_t)).item()
        if has_cfg:
            v_cond, v_uncond_text, v_uncond_speaker = model(
                x=torch.cat([x_t, x_t, x_t], dim=0).to(dtype),
                t=(torch.ones((batch_size * 3,), device=device) * t).to(dtype),
                text_input_ids=None,
                text_mask=full_text_mask,
                speaker_latent=None,
                speaker_mask=full_speaker_mask,
                kv_cache=kv_cache_full,
            ).float().chunk(3, dim=0)
            x0_cond = x_t - t * v_cond
            x0_uncond_text = x_t - t * v_uncond_text
            x0_uncond_speaker = x_t - t * v_uncond_speaker
            diff_text = x0_cond - x0_uncond_text
            diff_speaker = x0_cond - x0_uncond_speaker
            buf_text = diff_text + apg_momentum_text * buf_text
            diff_text = buf_text
            buf_speaker = diff_speaker + apg_momentum_speaker * buf_speaker
            diff_speaker = buf_speaker
            if apg_norm_text is not None:
                nt = torch.sqrt((diff_text * diff_text).sum(dim=tuple(range(1, diff_text.dim())), keepdim=True) + 1e-12)
                s = torch.minimum(torch.ones_like(nt), torch.as_tensor(apg_norm_text, device=device, dtype=diff_text.dtype) / nt)
                diff_text = diff_text * s
            if apg_norm_speaker is not None:
                ns = torch.sqrt((diff_speaker * diff_speaker).sum(dim=tuple(range(1, diff_speaker.dim())), keepdim=True) + 1e-12)
                s = torch.minimum(torch.ones_like(ns), torch.as_tensor(apg_norm_speaker, device=device, dtype=diff_speaker.dtype) / ns)
                diff_speaker = diff_speaker * s
            c_norm = torch.sqrt((x0_cond * x0_cond).sum(dim=tuple(range(1, x0_cond.dim())), keepdim=True) + 1e-12)
            c_hat = x0_cond / c_norm
            par_text = (diff_text * c_hat).sum(dim=tuple(range(1, diff_text.dim())), keepdim=True) * c_hat
            ort_text = diff_text - par_text
            upd_text = ort_text + apg_eta_text * par_text
            par_speaker = (diff_speaker * c_hat).sum(dim=tuple(range(1, diff_speaker.dim())), keepdim=True) * c_hat
            ort_speaker = diff_speaker - par_speaker
            upd_speaker = ort_speaker + apg_eta_speaker * par_speaker
            x0_pred = x0_cond + cfg_scale_text * upd_text + cfg_scale_speaker * upd_speaker
            v_pred = (x_t - x0_pred) / t
        else:
            v_pred = model(
                x=x_t.to(dtype),
                t=(torch.ones((batch_size,), device=device) * t).to(dtype),
                text_input_ids=None,
                text_mask=text_mask,
                speaker_latent=None,
                speaker_mask=speaker_mask,
                kv_cache=kv_cache,
            ).float()
        if rescale_k is not None and rescale_sigma is not None:
            v_pred = _temporal_score_rescale(v_pred, x_t, t, rescale_k, rescale_sigma)
        if speaker_k_scale is not None and t_next < speaker_k_min_t and t >= speaker_k_min_t:
            _multiply_speaker_kv_cache(kv_cache_full, 1. / speaker_k_scale, text_input_ids.shape[-1], speaker_k_max_layers)
        x_t = x_t + v_pred * (t_next - t)
    return x_t
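
# APG note (as implemented above): guidance is applied in x_0 space. Each
# conditional/unconditional difference is optionally accumulated with momentum
# and optionally clipped to a maximum norm, then decomposed into components
# parallel and orthogonal to the unit-normalized conditional prediction c_hat;
# apg_eta_* scales the parallel component (eta = 0 keeps only the orthogonal
# part, eta = 1 recovers the plain CFG direction for that term). The guided
# x0_pred is converted back to a velocity via v_pred = (x_t - x0_pred) / t.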


# router
class GuidanceMode(Enum):
    INDEPENDENT = "independent"
    APG = "apg"
    JOINT = "joint"
    ALTERNATING = "alternating"


def sample_euler_cfg_any(
    model: EchoDiT,
    speaker_latent: torch.Tensor,
    speaker_mask: torch.Tensor,
    text_input_ids: torch.Tensor,
    text_mask: torch.Tensor,
    rng_seed: int,
    guidance_mode: GuidanceMode,
    num_steps: int,
    cfg_scale_text: float,
    cfg_scale_speaker: float | None,
    cfg_min_t: float,
    cfg_max_t: float,
    truncation_factor: float | None,
    rescale_k: float | None,
    rescale_sigma: float | None,
    speaker_k_scale: float | None,
    speaker_k_min_t: float | None,
    speaker_k_max_layers: int | None,
    apg_eta_text: float | None,
    apg_eta_speaker: float | None,
    apg_momentum_text: float | None,
    apg_momentum_speaker: float | None,
    apg_norm_text: float | None,
    apg_norm_speaker: float | None,
    block_size: int | None = None,
) -> torch.Tensor:
    if guidance_mode == GuidanceMode.INDEPENDENT:
        assert cfg_scale_speaker is not None, "cfg_scale_speaker must be provided for independent guidances"
        return sample_euler_cfg_independent_guidances(
            model=model,
            speaker_latent=speaker_latent,
            speaker_mask=speaker_mask,
            text_input_ids=text_input_ids,
            text_mask=text_mask,
            rng_seed=rng_seed,
            num_steps=num_steps,
            cfg_scale_text=cfg_scale_text,
            cfg_scale_speaker=cfg_scale_speaker,
            cfg_min_t=cfg_min_t,
            cfg_max_t=cfg_max_t,
            truncation_factor=truncation_factor,
            rescale_k=rescale_k,
            rescale_sigma=rescale_sigma,
            speaker_k_scale=speaker_k_scale,
            speaker_k_max_layers=speaker_k_max_layers,
            speaker_k_min_t=speaker_k_min_t,
            block_size=block_size,
        )
    elif guidance_mode == GuidanceMode.APG:
        assert cfg_scale_speaker is not None, "cfg_scale_speaker must be provided for APG"
        assert apg_eta_text is not None, "apg_eta_text must be provided for APG"
        assert apg_eta_speaker is not None, "apg_eta_speaker must be provided for APG"
        return sample_euler_apg_independent_guidances(
            model=model,
            speaker_latent=speaker_latent,
            speaker_mask=speaker_mask,
            text_input_ids=text_input_ids,
            text_mask=text_mask,
            rng_seed=rng_seed,
            num_steps=num_steps,
            cfg_scale_text=cfg_scale_text,
            cfg_scale_speaker=cfg_scale_speaker,
            cfg_min_t=cfg_min_t,
            cfg_max_t=cfg_max_t,
            truncation_factor=truncation_factor,
            rescale_k=rescale_k,
            rescale_sigma=rescale_sigma,
            apg_eta_text=apg_eta_text,
            apg_eta_speaker=apg_eta_speaker,
            apg_momentum_text=apg_momentum_text,
            apg_momentum_speaker=apg_momentum_speaker,
            apg_norm_text=apg_norm_text,
            apg_norm_speaker=apg_norm_speaker,
            speaker_k_scale=speaker_k_scale,
            speaker_k_max_layers=speaker_k_max_layers,
            speaker_k_min_t=speaker_k_min_t,
            block_size=block_size,
        )
    elif guidance_mode == GuidanceMode.JOINT:
        assert cfg_scale_text == cfg_scale_speaker or cfg_scale_speaker is None, "cfg_scale_speaker must be None or equal to cfg_scale_text for joint guidance"
        return sample_euler_cfg(
            model=model,
            speaker_latent=speaker_latent,
            speaker_mask=speaker_mask,
            text_input_ids=text_input_ids,
            text_mask=text_mask,
            rng_seed=rng_seed,
            num_steps=num_steps,
            cfg_scale=cfg_scale_text,
            cfg_min_t=cfg_min_t,
            cfg_max_t=cfg_max_t,
            truncation_factor=truncation_factor,
            rescale_k=rescale_k,
            rescale_sigma=rescale_sigma,
            speaker_k_scale=speaker_k_scale,
            speaker_k_max_layers=speaker_k_max_layers,
            speaker_k_min_t=speaker_k_min_t,
            block_size=block_size,
        )
    elif guidance_mode == GuidanceMode.ALTERNATING:
        assert cfg_scale_speaker is not None, "cfg_scale_speaker must be provided for alternating guidances"
        return sample_euler_cfg_alternating_guidances(
            model=model,
            speaker_latent=speaker_latent,
            speaker_mask=speaker_mask,
            text_input_ids=text_input_ids,
            text_mask=text_mask,
            rng_seed=rng_seed,
            num_steps=num_steps,
            cfg_scale_text=cfg_scale_text,
            cfg_scale_speaker=cfg_scale_speaker,
            cfg_min_t=cfg_min_t,
            cfg_max_t=cfg_max_t,
            truncation_factor=truncation_factor,
            rescale_k=rescale_k,
            rescale_sigma=rescale_sigma,
            speaker_k_scale=speaker_k_scale,
            speaker_k_max_layers=speaker_k_max_layers,
            speaker_k_min_t=speaker_k_min_t,
            block_size=block_size,
        )
    else:
        raise ValueError(f"Unknown guidance mode: {guidance_mode}")
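
# Example usage (a minimal sketch; the loading call and tensor shapes below are
# assumptions for illustration, not part of this module):
#
# model = EchoDiT.from_pretrained("...")          # hypothetical constructor
# latents = sample_euler_cfg_any(
#     model=model,
#     speaker_latent=speaker_latent,              # e.g. (B, S_spk, 80) reference latents
#     speaker_mask=speaker_mask,                  # (B, S_spk) bool
#     text_input_ids=text_input_ids,              # (B, S_txt) int32 token ids
#     text_mask=text_mask,                        # (B, S_txt) bool
#     rng_seed=0,
#     guidance_mode=GuidanceMode.JOINT,
#     num_steps=32,
#     cfg_scale_text=2.0,
#     cfg_scale_speaker=None,                     # must be None or equal for JOINT
#     cfg_min_t=0.0,
#     cfg_max_t=1.0,
#     truncation_factor=None,
#     rescale_k=None,
#     rescale_sigma=None,
#     speaker_k_scale=None,
#     speaker_k_min_t=None,
#     speaker_k_max_layers=None,
#     apg_eta_text=None,
#     apg_eta_speaker=None,
#     apg_momentum_text=None,
#     apg_momentum_speaker=None,
#     apg_norm_text=None,
#     apg_norm_speaker=None,
# )  # -> (B, 640, 80) latent block with the default block_size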