# echo-tts-preview / samplers.py
from typing import List
from enum import Enum
import torch
from model import EchoDiT
# helper: build the unconditional (null-text) inputs used for classifier-free guidance
def _get_uncond_text_input_ids_and_mask(batch_size: int, max_length: int, device: str | None = None) -> tuple[torch.Tensor, torch.Tensor]:
# returns zeros for text input ids, and (True, False, False, ... ) for text mask
text_input_ids_uncond = torch.zeros((batch_size, max_length), dtype=torch.int32)
text_mask_uncond = torch.zeros((batch_size, max_length), dtype=torch.bool)
text_mask_uncond[:, 0] = True
if device is not None:
text_input_ids_uncond = text_input_ids_uncond.to(device)
text_mask_uncond = text_mask_uncond.to(device)
return text_input_ids_uncond, text_mask_uncond
# SIMPLE SAMPLER, KEPT FOR REFERENCE; PREFER THE SAMPLERS BELOW FOR ACTUAL USE
@torch.inference_mode()
def sample_euler_cfg_simple(
model: EchoDiT,
speaker_latent: torch.Tensor,
speaker_mask: torch.Tensor,
text_input_ids: torch.Tensor,
text_mask: torch.Tensor,
rng_seed: int,
num_steps: int,
cfg_scale: float,
) -> torch.Tensor:
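    """Minimal Euler sampler with joint classifier-free guidance.

    Runs the conditional and unconditional branches as one doubled batch, combines them as
    v_pred = v_cond + cfg_scale * (v_cond - v_uncond), and Euler-integrates from t = 1 (noise)
    to t = 0 on a uniform schedule. Returns a latent of shape (batch_size, 640, 80).
    """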
device, dtype = model.device, model.dtype
batch_size = text_input_ids.shape[0]
torch.manual_seed(rng_seed)
t_schedule = torch.linspace(1., 0., num_steps + 1, device=device)
text_input_ids_uncond, text_mask_uncond = _get_uncond_text_input_ids_and_mask(text_input_ids.shape[0], text_input_ids.shape[1], device=device)
speaker_latent_uncond, speaker_mask_uncond = torch.zeros_like(speaker_latent), torch.zeros_like(speaker_mask)
full_text_input_ids = torch.cat([text_input_ids, text_input_ids_uncond], dim=0)
full_text_mask = torch.cat([text_mask, text_mask_uncond], dim=0)
full_speaker_latent = torch.cat([speaker_latent, speaker_latent_uncond], dim=0)
full_speaker_mask = torch.cat([speaker_mask, speaker_mask_uncond], dim=0)
kv_cache = model.get_kv_cache(
speaker_latent=full_speaker_latent.to(dtype),
speaker_mask=full_speaker_mask,
text_input_ids=full_text_input_ids,
text_mask=full_text_mask,
)
x_t = torch.randn((batch_size, 640, 80), device=device, dtype=torch.float32)
for i in range(num_steps):
t, t_next = t_schedule[i], t_schedule[i+1]
v_cond, v_uncond = model(
x=torch.cat([x_t, x_t], dim=0).to(dtype),
t=(torch.ones((batch_size * 2,), device=device) * t).to(dtype),
text_input_ids=None,
text_mask=full_text_mask,
speaker_latent=None,
speaker_mask=full_speaker_mask,
kv_cache=kv_cache,
).float().chunk(2, dim=0)
v_pred = v_cond + cfg_scale * (v_cond - v_uncond)
# note: x_0_pred is x_t - v_pred * t
x_t = x_t + v_pred * (t_next - t)
return x_t
######
def _temporal_score_rescale(v_pred: torch.Tensor, x_t: torch.Tensor, t: float, rescale_k: float, rescale_sigma: float) -> torch.Tensor:
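    # Rescales the implied noise prediction: with x0 = x_t - t * v (see the note in the simple
    # sampler), the quantity (1 - t) * v_pred + x_t is the implied noise, which is scaled by an
    # SNR-dependent ratio and mapped back to a velocity. No-op at t >= 1 (pure noise, SNR = 0).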
if t < 1:
snr = (1 - t) ** 2 / (t ** 2)
ratio = (snr * rescale_sigma ** 2 + 1) / (snr * rescale_sigma ** 2 / rescale_k + 1)
return 1 / (1 - t) * (ratio * ((1 - t) * v_pred + x_t) - x_t)
return v_pred
def _get_first_n_kv_cache(kv_cache: List[List[torch.Tensor]], n: int) -> List[List[torch.Tensor]]:
return [[kv_cache[i][0][:n], kv_cache[i][1][:n]] for i in range(len(kv_cache))]
def _multiply_speaker_kv_cache(
    kv_cache: List[List[torch.Tensor]],
    scale: float,
    text_length: int,
    max_layers: int = 24,
) -> None:
    # Scales the speaker portion of the cached keys and values in place (no return value).
    # Speaker entries start after the text entries, i.e. at position text_length.
    for i in range(min(max_layers, len(kv_cache))):
        for j in range(len(kv_cache[i])):
            kv_cache[i][j][:, text_length:] *= scale
@torch.inference_mode()
def sample_euler_cfg(
model: EchoDiT,
speaker_latent: torch.Tensor,
speaker_mask: torch.Tensor,
text_input_ids: torch.Tensor,
text_mask: torch.Tensor,
rng_seed: int,
num_steps: int,
cfg_scale: float,
cfg_min_t: float,
cfg_max_t: float,
truncation_factor: float | None,
rescale_k: float | None,
rescale_sigma: float | None,
speaker_k_scale: float | None,
speaker_k_max_layers: int | None,
speaker_k_min_t: float | None,
block_size: int | None = None,
) -> torch.Tensor:
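    """Euler sampler with joint CFG and optional extras.

    Relative to `sample_euler_cfg_simple`: guidance is only applied for t in
    [cfg_min_t, cfg_max_t] (a single conditional pass is used outside that window),
    with optional initial-noise truncation, temporal score rescaling, and scaling of
    the speaker KV cache while t >= speaker_k_min_t.
    """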
if block_size is None:
block_size = 640
torch.manual_seed(rng_seed)
INIT_SCALE = 0.999
device, dtype = model.device, model.dtype
batch_size = text_input_ids.shape[0]
t_schedule = torch.linspace(1., 0., num_steps + 1, device=device) * INIT_SCALE
text_input_ids_uncond, text_mask_uncond = _get_uncond_text_input_ids_and_mask(text_input_ids.shape[0], text_input_ids.shape[1], device=device)
speaker_latent_uncond, speaker_mask_uncond = torch.zeros_like(speaker_latent), torch.zeros_like(speaker_mask)
full_text_input_ids = torch.cat([text_input_ids, text_input_ids_uncond], dim=0)
full_text_mask = torch.cat([text_mask, text_mask_uncond], dim=0)
full_speaker_latent = torch.cat([speaker_latent, speaker_latent_uncond], dim=0)
full_speaker_mask = torch.cat([speaker_mask, speaker_mask_uncond], dim=0)
kv_cache_full = model.get_kv_cache(
speaker_latent=full_speaker_latent.to(dtype),
speaker_mask=full_speaker_mask,
text_input_ids=full_text_input_ids,
text_mask=full_text_mask,
    )  # could be made faster by not computing the full cache / not recomputing it for the unconditional batch elements
kv_cache = _get_first_n_kv_cache(kv_cache_full, batch_size)
if speaker_k_scale is not None:
_multiply_speaker_kv_cache(kv_cache_full, speaker_k_scale, text_input_ids.shape[-1], speaker_k_max_layers)
x_t = torch.randn((batch_size, block_size, 80), device=device, dtype=torch.float32)
if truncation_factor is not None:
x_t = x_t * truncation_factor
for i in range(num_steps):
t, t_next = t_schedule[i], t_schedule[i+1]
has_cfg = ((t >= cfg_min_t) * (t <= cfg_max_t)).item()
if has_cfg:
v_cond, v_uncond = model(
x=torch.cat([x_t, x_t], dim=0).to(dtype),
t=(torch.ones((batch_size * 2,), device=device) * t).to(dtype),
text_input_ids=None,
text_mask=full_text_mask,
speaker_latent=None,
speaker_mask=full_speaker_mask,
kv_cache=kv_cache_full,
).float().chunk(2, dim=0)
v_pred = v_cond + cfg_scale * (v_cond - v_uncond)
else:
v_pred = model(
x=x_t.to(dtype),
t=(torch.ones((batch_size,), device=device) * t).to(dtype),
text_input_ids=None,
text_mask=text_mask,
speaker_latent=None,
speaker_mask=speaker_mask,
kv_cache=kv_cache,
).float()
if rescale_k is not None and rescale_sigma is not None:
v_pred = _temporal_score_rescale(v_pred, x_t, t, rescale_k, rescale_sigma)
if speaker_k_scale is not None and t_next < speaker_k_min_t and t >= speaker_k_min_t:
_multiply_speaker_kv_cache(kv_cache_full, 1. / speaker_k_scale, text_input_ids.shape[-1], speaker_k_max_layers)
x_t = x_t + v_pred * (t_next - t)
return x_t
@torch.inference_mode()
def sample_euler_cfg_independent_guidances(
model: EchoDiT,
speaker_latent: torch.Tensor,
speaker_mask: torch.Tensor,
text_input_ids: torch.Tensor,
text_mask: torch.Tensor,
rng_seed: int,
num_steps: int,
cfg_scale_text: float,
cfg_scale_speaker: float,
cfg_min_t: float,
cfg_max_t: float,
truncation_factor: float | None,
rescale_k: float | None,
rescale_sigma: float | None,
speaker_k_scale: float | None,
speaker_k_max_layers: int | None,
speaker_k_min_t: float | None,
block_size: int | None = None,
) -> torch.Tensor:
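    """Euler sampler with independent text and speaker CFG.

    Within the CFG window it runs three branches per step (fully conditional, text dropped,
    speaker dropped) and combines them as
    v_cond + cfg_scale_text * (v_cond - v_uncond_text) + cfg_scale_speaker * (v_cond - v_uncond_speaker).
    """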
if block_size is None:
block_size = 640
torch.manual_seed(rng_seed)
INIT_SCALE = 0.999
device, dtype = model.device, model.dtype
batch_size = text_input_ids.shape[0]
t_schedule = torch.linspace(1., 0., num_steps + 1, device=device) * INIT_SCALE
text_input_ids_uncond, text_mask_uncond = _get_uncond_text_input_ids_and_mask(text_input_ids.shape[0], text_input_ids.shape[1], device=device)
speaker_latent_uncond, speaker_mask_uncond = torch.zeros_like(speaker_latent), torch.zeros_like(speaker_mask)
full_text_input_ids = torch.cat([text_input_ids, text_input_ids_uncond, text_input_ids], dim=0)
full_text_mask = torch.cat([text_mask, text_mask_uncond, text_mask], dim=0)
full_speaker_latent = torch.cat([speaker_latent, speaker_latent, speaker_latent_uncond], dim=0)
full_speaker_mask = torch.cat([speaker_mask, speaker_mask, speaker_mask_uncond], dim=0)
kv_cache_full = model.get_kv_cache(
speaker_latent=full_speaker_latent.to(dtype),
speaker_mask=full_speaker_mask,
text_input_ids=full_text_input_ids,
text_mask=full_text_mask,
)
kv_cache = _get_first_n_kv_cache(kv_cache_full, batch_size)
if speaker_k_scale is not None:
_multiply_speaker_kv_cache(kv_cache_full, speaker_k_scale, text_input_ids.shape[-1], speaker_k_max_layers)
x_t = torch.randn((batch_size, block_size, 80), device=device, dtype=torch.float32)
if truncation_factor is not None:
x_t = x_t * truncation_factor
for i in range(num_steps):
t, t_next = t_schedule[i], t_schedule[i+1]
has_cfg = ((t >= cfg_min_t) * (t <= cfg_max_t)).item()
if has_cfg:
v_cond, v_uncond_text, v_uncond_speaker = model(
x=torch.cat([x_t, x_t, x_t], dim=0).to(dtype),
t=(torch.ones((batch_size * 3,), device=device) * t).to(dtype),
text_input_ids=None,
text_mask=full_text_mask,
speaker_latent=None,
speaker_mask=full_speaker_mask,
kv_cache=kv_cache_full,
).float().chunk(3, dim=0)
v_pred = v_cond + cfg_scale_text * (v_cond - v_uncond_text) + cfg_scale_speaker * (v_cond - v_uncond_speaker)
else:
v_pred = model(
x=x_t.to(dtype),
t=(torch.ones((batch_size,), device=device) * t).to(dtype),
text_input_ids=None,
text_mask=text_mask,
speaker_latent=None,
speaker_mask=speaker_mask,
kv_cache=kv_cache,
).float()
if rescale_k is not None and rescale_sigma is not None:
v_pred = _temporal_score_rescale(v_pred, x_t, t, rescale_k, rescale_sigma)
if speaker_k_scale is not None and t_next < speaker_k_min_t and t >= speaker_k_min_t:
_multiply_speaker_kv_cache(kv_cache_full, 1. / speaker_k_scale, text_input_ids.shape[-1], speaker_k_max_layers)
x_t = x_t + v_pred * (t_next - t)
return x_t
@torch.inference_mode()
def sample_euler_cfg_alternating_guidances(
model: EchoDiT,
speaker_latent: torch.Tensor,
speaker_mask: torch.Tensor,
text_input_ids: torch.Tensor,
text_mask: torch.Tensor,
rng_seed: int,
num_steps: int,
cfg_scale_text: float,
cfg_scale_speaker: float,
cfg_min_t: float,
cfg_max_t: float,
truncation_factor: float | None,
rescale_k: float | None,
rescale_sigma: float | None,
speaker_k_scale: float | None,
speaker_k_max_layers: int | None,
speaker_k_min_t: float | None,
block_size: int | None = None,
) -> torch.Tensor:
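    """Euler sampler that alternates text and speaker CFG across steps.

    Even-indexed steps apply text guidance (the unconditional branch masks out the text),
    odd-indexed steps apply speaker guidance (the unconditional branch masks out the speaker),
    each with its own scale.
    """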
if block_size is None:
block_size = 640
torch.manual_seed(rng_seed)
INIT_SCALE = 0.999
device, dtype = model.device, model.dtype
batch_size = text_input_ids.shape[0]
t_schedule = torch.linspace(1., 0., num_steps + 1, device=device) * INIT_SCALE
text_input_ids_uncond, text_mask_uncond = _get_uncond_text_input_ids_and_mask(text_input_ids.shape[0], text_input_ids.shape[1], device=device)
    # TODO: the cache built below is technically incorrect for the unconditional text branch, because it
    # reuses the conditional text input ids with an unconditional mask. That only yields a true
    # unconditional state if the text encoder were causal (with BOS = 0), which it is not.
    # In practice this may not make much of a difference; the other samplers were changed to
    # build correct unconditional caches.
speaker_latent_uncond, speaker_mask_uncond = torch.zeros_like(speaker_latent), torch.zeros_like(speaker_mask)
full_text_input_ids = torch.cat([text_input_ids, text_input_ids], dim=0)
full_text_mask = torch.cat([text_mask, text_mask_uncond], dim=0)
full_speaker_latent = torch.cat([speaker_latent, speaker_latent_uncond], dim=0)
full_speaker_mask = torch.cat([speaker_mask, speaker_mask_uncond], dim=0)
kv_cache_full = model.get_kv_cache(
speaker_latent=full_speaker_latent.to(dtype),
speaker_mask=full_speaker_mask,
text_input_ids=full_text_input_ids,
text_mask=full_text_mask,
)
kv_cache = _get_first_n_kv_cache(kv_cache_full, batch_size)
if speaker_k_scale is not None:
_multiply_speaker_kv_cache(kv_cache_full, speaker_k_scale, text_input_ids.shape[-1], speaker_k_max_layers)
x_t = torch.randn((batch_size, block_size, 80), device=device, dtype=torch.float32)
if truncation_factor is not None:
x_t = x_t * truncation_factor
for i in range(num_steps):
t, t_next = t_schedule[i], t_schedule[i+1]
has_cfg = ((t >= cfg_min_t) * (t <= cfg_max_t)).item()
if has_cfg:
v_cond, v_uncond = model(
x=torch.cat([x_t, x_t], dim=0).to(dtype),
t=(torch.ones((batch_size * 2,), device=device) * t).to(dtype),
text_input_ids=None,
text_mask=torch.cat([text_mask, text_mask_uncond if i % 2 == 0 else text_mask], dim=0),
speaker_latent=None,
speaker_mask=torch.cat([speaker_mask, speaker_mask if i % 2 == 0 else speaker_mask_uncond], dim=0),
kv_cache=kv_cache_full,
).float().chunk(2, dim=0)
v_pred = v_cond + (cfg_scale_text if i % 2 == 0 else cfg_scale_speaker) * (v_cond - v_uncond)
else:
v_pred = model(
x=x_t.to(dtype),
t=(torch.ones((batch_size,), device=device) * t).to(dtype),
text_input_ids=None,
text_mask=text_mask,
speaker_latent=None,
speaker_mask=speaker_mask,
kv_cache=kv_cache,
).float()
if rescale_k is not None and rescale_sigma is not None:
v_pred = _temporal_score_rescale(v_pred, x_t, t, rescale_k, rescale_sigma)
if speaker_k_scale is not None and t_next < speaker_k_min_t and t >= speaker_k_min_t:
_multiply_speaker_kv_cache(kv_cache_full, 1. / speaker_k_scale, text_input_ids.shape[-1], speaker_k_max_layers)
x_t = x_t + v_pred * (t_next - t)
return x_t
@torch.inference_mode()
def sample_euler_apg_independent_guidances(
model: EchoDiT,
speaker_latent: torch.Tensor,
speaker_mask: torch.Tensor,
text_input_ids: torch.Tensor,
text_mask: torch.Tensor,
rng_seed: int,
num_steps: int,
cfg_scale_text: float,
cfg_scale_speaker: float,
cfg_min_t: float,
cfg_max_t: float,
truncation_factor: float | None,
rescale_k: float | None,
rescale_sigma: float | None,
apg_eta_text: float,
apg_eta_speaker: float,
apg_momentum_text: float | None,
apg_momentum_speaker: float | None,
apg_norm_text: float | None,
apg_norm_speaker: float | None,
speaker_k_scale: float | None,
speaker_k_max_layers: int | None,
speaker_k_min_t: float | None,
block_size: int | None = None,
) -> torch.Tensor:
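    """Euler sampler with adaptive projected guidance (APG), applied independently to text and speaker.

    Guidance differences are computed in x0 space, optionally given momentum (apg_momentum_*) and a
    norm cap (apg_norm_*), then split into components parallel and orthogonal to the conditional x0
    prediction; the parallel part is scaled by apg_eta_* before the usual CFG combination.
    """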
if block_size is None:
block_size = 640
if apg_momentum_text is None:
apg_momentum_text = 0.0
if apg_momentum_speaker is None:
apg_momentum_speaker = 0.0
torch.manual_seed(rng_seed)
INIT_SCALE = 0.999
device, dtype = model.device, model.dtype
batch_size = text_input_ids.shape[0]
t_schedule = torch.linspace(1., 0., num_steps + 1, device=device) * INIT_SCALE
text_input_ids_uncond, text_mask_uncond = _get_uncond_text_input_ids_and_mask(text_input_ids.shape[0], text_input_ids.shape[1], device=device)
speaker_latent_uncond, speaker_mask_uncond = torch.zeros_like(speaker_latent), torch.zeros_like(speaker_mask)
full_text_input_ids = torch.cat([text_input_ids, text_input_ids_uncond, text_input_ids], dim=0)
full_text_mask = torch.cat([text_mask, text_mask_uncond, text_mask], dim=0)
full_speaker_latent = torch.cat([speaker_latent, speaker_latent, speaker_latent_uncond], dim=0)
full_speaker_mask = torch.cat([speaker_mask, speaker_mask, speaker_mask_uncond], dim=0)
kv_cache_full = model.get_kv_cache(
speaker_latent=full_speaker_latent.to(dtype),
speaker_mask=full_speaker_mask,
text_input_ids=full_text_input_ids,
text_mask=full_text_mask,
)
kv_cache = _get_first_n_kv_cache(kv_cache_full, batch_size)
if speaker_k_scale is not None:
_multiply_speaker_kv_cache(kv_cache_full, speaker_k_scale, text_input_ids.shape[-1], speaker_k_max_layers)
x_t = torch.randn((batch_size, block_size, 80), device=device, dtype=torch.float32)
if truncation_factor is not None:
x_t = x_t * truncation_factor
buf_text = torch.zeros_like(x_t)
buf_speaker = torch.zeros_like(x_t)
for i in range(num_steps):
t, t_next = t_schedule[i], t_schedule[i+1]
has_cfg = ((t >= cfg_min_t) * (t <= cfg_max_t)).item()
if has_cfg:
v_cond, v_uncond_text, v_uncond_speaker = model(
x=torch.cat([x_t, x_t, x_t], dim=0).to(dtype),
t=(torch.ones((batch_size * 3,), device=device) * t).to(dtype),
text_input_ids=None,
text_mask=full_text_mask,
speaker_latent=None,
speaker_mask=full_speaker_mask,
kv_cache=kv_cache_full,
).float().chunk(3, dim=0)
x0_cond = x_t - t * v_cond
x0_uncond_text = x_t - t * v_uncond_text
x0_uncond_speaker = x_t - t * v_uncond_speaker
diff_text = x0_cond - x0_uncond_text
diff_speaker = x0_cond - x0_uncond_speaker
buf_text = diff_text + apg_momentum_text * buf_text
diff_text = buf_text
buf_speaker = diff_speaker + apg_momentum_speaker * buf_speaker
diff_speaker = buf_speaker
if apg_norm_text is not None:
nt = torch.sqrt((diff_text * diff_text).sum(dim=tuple(range(1, diff_text.dim())), keepdim=True) + 1e-12)
s = torch.minimum(torch.ones_like(nt), (torch.as_tensor(apg_norm_text, device=device, dtype=diff_text.dtype) / nt))
diff_text = diff_text * s
if apg_norm_speaker is not None:
ns = torch.sqrt((diff_speaker * diff_speaker).sum(dim=tuple(range(1, diff_speaker.dim())), keepdim=True) + 1e-12)
s = torch.minimum(torch.ones_like(ns), (torch.as_tensor(apg_norm_speaker, device=device, dtype=diff_speaker.dtype) / ns))
diff_speaker = diff_speaker * s
c_norm = torch.sqrt((x0_cond * x0_cond).sum(dim=tuple(range(1, x0_cond.dim())), keepdim=True) + 1e-12)
c_hat = x0_cond / c_norm
par_text = (diff_text * c_hat).sum(dim=tuple(range(1, diff_text.dim())), keepdim=True) * c_hat
ort_text = diff_text - par_text
upd_text = ort_text + apg_eta_text * par_text
par_speaker = (diff_speaker * c_hat).sum(dim=tuple(range(1, diff_speaker.dim())), keepdim=True) * c_hat
ort_speaker = diff_speaker - par_speaker
upd_speaker = ort_speaker + apg_eta_speaker * par_speaker
x0_pred = x0_cond + cfg_scale_text * upd_text + cfg_scale_speaker * upd_speaker
v_pred = (x_t - x0_pred) / t
else:
v_pred = model(
x=x_t.to(dtype),
t=(torch.ones((batch_size,), device=device) * t).to(dtype),
text_input_ids=None,
text_mask=text_mask,
speaker_latent=None,
speaker_mask=speaker_mask,
kv_cache=kv_cache,
).float()
if rescale_k is not None and rescale_sigma is not None:
v_pred = _temporal_score_rescale(v_pred, x_t, t, rescale_k, rescale_sigma)
if speaker_k_scale is not None and t_next < speaker_k_min_t and t >= speaker_k_min_t:
_multiply_speaker_kv_cache(kv_cache_full, 1. / speaker_k_scale, text_input_ids.shape[-1], speaker_k_max_layers)
x_t = x_t + v_pred * (t_next - t)
return x_t
# router
class GuidanceMode(Enum):
INDEPENDENT = "independent"
APG = "apg"
JOINT = "joint"
ALTERNATING = "alternating"
def sample_euler_cfg_any(
model: EchoDiT,
speaker_latent: torch.Tensor,
speaker_mask: torch.Tensor,
text_input_ids: torch.Tensor,
text_mask: torch.Tensor,
rng_seed: int,
guidance_mode: GuidanceMode,
num_steps: int,
cfg_scale_text: float,
cfg_scale_speaker: float | None,
cfg_min_t: float,
cfg_max_t: float,
truncation_factor: float | None,
rescale_k: float | None,
rescale_sigma: float | None,
speaker_k_scale: float | None,
speaker_k_min_t: float | None,
speaker_k_max_layers: int | None,
apg_eta_text: float | None,
apg_eta_speaker: float | None,
apg_momentum_text: float | None,
apg_momentum_speaker: float | None,
apg_norm_text: float | None,
apg_norm_speaker: float | None,
block_size: int | None = None,
) -> torch.Tensor:
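    """Dispatches to one of the samplers above based on `guidance_mode`.

    JOINT uses a single shared CFG scale (cfg_scale_text); INDEPENDENT, APG, and ALTERNATING
    require cfg_scale_speaker, and APG additionally requires apg_eta_text / apg_eta_speaker.
    """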
if guidance_mode == GuidanceMode.INDEPENDENT:
assert cfg_scale_speaker is not None, "cfg_scale_speaker must be provided for independent guidances"
return sample_euler_cfg_independent_guidances(
model=model,
speaker_latent=speaker_latent,
speaker_mask=speaker_mask,
text_input_ids=text_input_ids,
text_mask=text_mask,
rng_seed=rng_seed,
num_steps=num_steps,
cfg_scale_text=cfg_scale_text,
cfg_scale_speaker=cfg_scale_speaker,
cfg_min_t=cfg_min_t,
cfg_max_t=cfg_max_t,
truncation_factor=truncation_factor,
rescale_k=rescale_k,
rescale_sigma=rescale_sigma,
speaker_k_scale=speaker_k_scale,
speaker_k_max_layers=speaker_k_max_layers,
speaker_k_min_t=speaker_k_min_t,
block_size=block_size,
)
elif guidance_mode == GuidanceMode.APG:
assert cfg_scale_speaker is not None, "cfg_scale_speaker must be provided for APG"
assert apg_eta_text is not None, "apg_eta_text must be provided for APG"
assert apg_eta_speaker is not None, "apg_eta_speaker must be provided for APG"
return sample_euler_apg_independent_guidances(
model=model,
speaker_latent=speaker_latent,
speaker_mask=speaker_mask,
text_input_ids=text_input_ids,
text_mask=text_mask,
rng_seed=rng_seed,
num_steps=num_steps,
cfg_scale_text=cfg_scale_text,
cfg_scale_speaker=cfg_scale_speaker,
cfg_min_t=cfg_min_t,
cfg_max_t=cfg_max_t,
truncation_factor=truncation_factor,
rescale_k=rescale_k,
rescale_sigma=rescale_sigma,
apg_eta_text=apg_eta_text,
apg_eta_speaker=apg_eta_speaker,
apg_momentum_text=apg_momentum_text,
apg_momentum_speaker=apg_momentum_speaker,
apg_norm_text=apg_norm_text,
apg_norm_speaker=apg_norm_speaker,
speaker_k_scale=speaker_k_scale,
speaker_k_max_layers=speaker_k_max_layers,
speaker_k_min_t=speaker_k_min_t,
block_size=block_size,
)
elif guidance_mode == GuidanceMode.JOINT:
assert cfg_scale_text == cfg_scale_speaker or cfg_scale_speaker is None, "cfg_scale_text and cfg_scale_speaker must be the same or cfg_scale_speaker must be None"
return sample_euler_cfg(
model=model,
speaker_latent=speaker_latent,
speaker_mask=speaker_mask,
text_input_ids=text_input_ids,
text_mask=text_mask,
rng_seed=rng_seed,
num_steps=num_steps,
cfg_scale=cfg_scale_text,
cfg_min_t=cfg_min_t,
cfg_max_t=cfg_max_t,
truncation_factor=truncation_factor,
rescale_k=rescale_k,
rescale_sigma=rescale_sigma,
speaker_k_scale=speaker_k_scale,
speaker_k_max_layers=speaker_k_max_layers,
speaker_k_min_t=speaker_k_min_t,
block_size=block_size,
)
elif guidance_mode == GuidanceMode.ALTERNATING:
assert cfg_scale_speaker is not None, "cfg_scale_speaker must be provided for alternating guidances"
return sample_euler_cfg_alternating_guidances(
model=model,
speaker_latent=speaker_latent,
speaker_mask=speaker_mask,
text_input_ids=text_input_ids,
text_mask=text_mask,
rng_seed=rng_seed,
num_steps=num_steps,
cfg_scale_text=cfg_scale_text,
cfg_scale_speaker=cfg_scale_speaker,
cfg_min_t=cfg_min_t,
cfg_max_t=cfg_max_t,
truncation_factor=truncation_factor,
rescale_k=rescale_k,
rescale_sigma=rescale_sigma,
speaker_k_scale=speaker_k_scale,
speaker_k_max_layers=speaker_k_max_layers,
speaker_k_min_t=speaker_k_min_t,
block_size=block_size,
)
else:
raise ValueError(f"Unknown guidance mode: {guidance_mode}")
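# Usage sketch (illustrative, not part of the module): one way to call the router with joint guidance.
# The EchoDiT construction and the tensor shapes / hyperparameter values below are assumptions made
# for illustration; see the repo's inference code for the actual entry point.
#
#   model = EchoDiT(...).to("cuda").eval()                                    # hypothetical setup
#   speaker_latent = torch.zeros(1, 128, 80, device="cuda")                   # assumed (batch, frames, channels)
#   speaker_mask = torch.ones(1, 128, dtype=torch.bool, device="cuda")
#   text_input_ids = torch.zeros(1, 256, dtype=torch.int32, device="cuda")    # tokenized text (assumed length 256)
#   text_mask = torch.ones(1, 256, dtype=torch.bool, device="cuda")
#   latent = sample_euler_cfg_any(
#       model=model,
#       speaker_latent=speaker_latent, speaker_mask=speaker_mask,
#       text_input_ids=text_input_ids, text_mask=text_mask,
#       rng_seed=0, guidance_mode=GuidanceMode.JOINT, num_steps=32,
#       cfg_scale_text=3.0, cfg_scale_speaker=None, cfg_min_t=0.0, cfg_max_t=1.0,
#       truncation_factor=None, rescale_k=None, rescale_sigma=None,
#       speaker_k_scale=None, speaker_k_min_t=None, speaker_k_max_layers=None,
#       apg_eta_text=None, apg_eta_speaker=None, apg_momentum_text=None,
#       apg_momentum_speaker=None, apg_norm_text=None, apg_norm_speaker=None,
#   )  # -> denoised latent of shape (1, 640, 80)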