from typing import Tuple, List import torch import torch.nn as nn import torch.optim as optim import torch.nn.functional as F def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0) -> torch.Tensor: freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)] / dim)) t = torch.arange(end) freqs = torch.outer(t, freqs) freqs_cis = torch.complex(torch.cos(freqs), torch.sin(freqs)) return freqs_cis def apply_rotary_emb( x: torch.Tensor, freqs_cis: torch.Tensor, ) -> torch.Tensor: x_ = torch.view_as_complex(x.float().reshape(*x.shape[:3], -1, 2)) x_ = x_ * freqs_cis[..., None, :] x_ = torch.view_as_real(x_).reshape(x.shape) return x_.type_as(x) def get_timestep_embedding( timestep: torch.Tensor, embed_size: int, ) -> torch.Tensor: assert embed_size % 2 == 0 half = embed_size // 2 freqs = 1000 * torch.exp( -torch.log(torch.tensor(10000.0)) * torch.arange(start=0, end=half, dtype=torch.float32) / half ).to(timestep.device) args = timestep[..., None] * freqs[None] embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) return embedding.to(timestep.dtype) class LowRankAdaLN(nn.Module): def __init__( self, model_size: int, rank: int, eps: float ): super().__init__() self.eps = eps self.shift_down = nn.Linear(model_size, rank, bias=False) self.scale_down = nn.Linear(model_size, rank, bias=False) self.gate_down = nn.Linear(model_size, rank, bias=False) self.shift_up = nn.Linear(rank, model_size, bias=True) self.scale_up = nn.Linear(rank, model_size, bias=True) self.gate_up = nn.Linear(rank, model_size, bias=True) def forward( self, x: torch.Tensor, cond_embed: torch.Tensor, ) -> Tuple[torch.Tensor, torch.Tensor]: shift, scale, gate = cond_embed.chunk(3, dim=-1) shift = self.shift_up(self.shift_down(F.silu(shift))) + shift scale = self.scale_up(self.scale_down(F.silu(scale))) + scale gate = self.gate_up(self.gate_down(F.silu(gate))) + gate x_dtype = x.dtype x = x.float() x = x * torch.rsqrt(torch.pow(x.float(), 2).mean(dim=-1, keepdim=True) + self.eps) x = x * (scale + 1) + shift gate = torch.tanh(gate) return x.to(x_dtype), gate class RMSNorm(nn.Module): # could also just use torch rmsnorm def __init__( self, model_size: int | Tuple[int, int], eps: float ): super().__init__() self.eps = eps if isinstance(model_size, int): model_size = (model_size, ) self.weight = nn.Parameter(torch.ones(model_size)) def forward(self, x: torch.Tensor) -> torch.Tensor: x_dtype = x.dtype x = x.float() x = x * torch.rsqrt(torch.pow(x.float(), 2).mean(dim=-1, keepdim=True) + self.eps) x = x * self.weight return x.to(x_dtype) class SelfAttention(nn.Module): def __init__( self, model_size: int, num_heads: int, is_causal: bool, norm_eps: float ): super().__init__() self.num_heads = num_heads self.is_causal = is_causal self.wq = nn.Linear(model_size, model_size, bias=False) self.wk = nn.Linear(model_size, model_size, bias=False) self.wv = nn.Linear(model_size, model_size, bias=False) self.wo = nn.Linear(model_size, model_size, bias=False) self.gate = nn.Linear(model_size, model_size, bias=False) assert model_size % num_heads == 0 self.q_norm = RMSNorm((num_heads, model_size // num_heads), eps=norm_eps) self.k_norm = RMSNorm((num_heads, model_size // num_heads), eps=norm_eps) def forward(self, x: torch.Tensor, mask: torch.Tensor | None, freqs_cis: torch.Tensor) -> torch.Tensor: batch_size, seq_len = x.shape[:2] xq = self.wq(x).reshape(batch_size, seq_len, self.num_heads, -1) xk = self.wk(x).reshape(batch_size, seq_len, self.num_heads, -1) xv = self.wv(x).reshape(batch_size, seq_len, self.num_heads, -1) gate = self.gate(x) xq = self.q_norm(xq) xk = self.k_norm(xk) xq = apply_rotary_emb(xq, freqs_cis[:seq_len]) xk = apply_rotary_emb(xk, freqs_cis[:seq_len]) if mask is not None: assert mask.ndim == 2 # (b, s) mask = mask[:, None, None] output = F.scaled_dot_product_attention( query=xq.transpose(1, 2), key=xk.transpose(1, 2), value=xv.transpose(1, 2), attn_mask=mask, is_causal=self.is_causal ).transpose(1, 2) output = output.reshape(batch_size, seq_len, -1) output = output * torch.sigmoid(gate) output = self.wo(output) return output class JointAttention(nn.Module): def __init__( self, model_size: int, num_heads: int, text_model_size: int, speaker_model_size: int, speaker_patch_size: int, norm_eps: float ): super().__init__() self.speaker_patch_size = speaker_patch_size self.num_heads = num_heads self.wq = nn.Linear(model_size, model_size, bias=False) self.wk = nn.Linear(model_size, model_size, bias=False) self.wv = nn.Linear(model_size, model_size, bias=False) self.wk_text = nn.Linear(text_model_size, model_size, bias=False) self.wv_text = nn.Linear(text_model_size, model_size, bias=False) self.wk_speaker = nn.Linear(speaker_model_size, model_size, bias=False) self.wv_speaker = nn.Linear(speaker_model_size, model_size, bias=False) assert model_size % num_heads == 0 self.q_norm = RMSNorm((num_heads, model_size // num_heads), eps=norm_eps) self.k_norm = RMSNorm((num_heads, model_size // num_heads), eps=norm_eps) self.gate = nn.Linear(model_size, model_size, bias=False) self.wo = nn.Linear(model_size, model_size, bias=False) def forward( self, x: torch.Tensor, text_state: torch.Tensor | None, text_mask: torch.Tensor, speaker_state: torch.Tensor | None, speaker_mask: torch.Tensor, freqs_cis: torch.Tensor, kv_cache: Tuple[torch.Tensor, torch.Tensor] | None = None, ) -> torch.Tensor: batch_size, seq_len = x.shape[:2] xq = self.wq(x).reshape(batch_size, seq_len, self.num_heads, -1) xk_self = self.wk(x).reshape(batch_size, seq_len, self.num_heads, -1) xv_self = self.wv(x).reshape(batch_size, seq_len, self.num_heads, -1) xq = self.q_norm(xq) xk_self = self.k_norm(xk_self) gate = self.gate(x) def _apply_rotary_half(y: torch.Tensor, fc: torch.Tensor) -> torch.Tensor: y1, y2 = y.chunk(2, dim=-2) y1 = apply_rotary_emb(y1, fc) return torch.cat([y1, y2], dim=-2) xq = _apply_rotary_half(xq, freqs_cis) xk_self = _apply_rotary_half(xk_self, freqs_cis) if kv_cache is None: xk_text = self.wk_text(text_state).reshape(batch_size, text_state.shape[1], self.num_heads, -1) xv_text = self.wv_text(text_state).reshape(batch_size, text_state.shape[1], self.num_heads, -1) xk_speaker = self.wk_speaker(speaker_state).reshape(batch_size, speaker_state.shape[1], self.num_heads, -1) xv_speaker = self.wv_speaker(speaker_state).reshape(batch_size, speaker_state.shape[1], self.num_heads, -1) xk_text = self.k_norm(xk_text) xk_speaker = self.k_norm(xk_speaker) xk = torch.cat([xk_self, xk_text, xk_speaker], dim=1) xv = torch.cat([xv_self, xv_text, xv_speaker], dim=1) else: xk_cross, xv_cross = kv_cache xk = torch.cat([xk_self, xk_cross], dim=1) xv = torch.cat([xv_self, xv_cross], dim=1) self_mask = torch.ones((batch_size, seq_len), dtype=torch.bool, device=x.device) mask = torch.cat([self_mask, text_mask, speaker_mask], dim=1) mask = mask[:, None, None] output = F.scaled_dot_product_attention( query=xq.transpose(1, 2), key=xk.transpose(1, 2), value=xv.transpose(1, 2), attn_mask=mask, is_causal=False ).transpose(1, 2) output = output.reshape(batch_size, seq_len, -1) output = output * torch.sigmoid(gate) output = self.wo(output) return output def get_kv_cache( self, text_state: torch.Tensor, speaker_state: torch.Tensor, ) -> Tuple[torch.Tensor, torch.Tensor]: batch_size = text_state.shape[0] xk_text = self.wk_text(text_state).reshape(batch_size, text_state.shape[1], self.num_heads, -1) xv_text = self.wv_text(text_state).reshape(batch_size, text_state.shape[1], self.num_heads, -1) xk_speaker = self.wk_speaker(speaker_state).reshape(batch_size, speaker_state.shape[1], self.num_heads, -1) xv_speaker = self.wv_speaker(speaker_state).reshape(batch_size, speaker_state.shape[1], self.num_heads, -1) xk = torch.cat([xk_text, xk_speaker], dim=1) xv = torch.cat([xv_text, xv_speaker], dim=1) xk = self.k_norm(xk) return xk, xv class MLP(nn.Module): def __init__( self, model_size: int, intermediate_size: int ): super().__init__() self.w1 = nn.Linear(model_size, intermediate_size, bias=False) self.w3 = nn.Linear(model_size, intermediate_size, bias=False) self.w2 = nn.Linear(intermediate_size, model_size, bias=False) def forward(self, x: torch.Tensor) -> torch.Tensor: return self.w2(F.silu(self.w1(x)) * self.w3(x)) class EncoderTransformerBlock(nn.Module): def __init__( self, model_size: int, num_heads: int, intermediate_size: int, is_causal: bool, norm_eps: float ): super().__init__() self.attention = SelfAttention( model_size=model_size, num_heads=num_heads, is_causal=is_causal, norm_eps=norm_eps ) self.mlp = MLP( model_size=model_size, intermediate_size=intermediate_size ) self.attention_norm = RMSNorm(model_size, norm_eps) self.mlp_norm = RMSNorm(model_size, norm_eps) def forward(self, x: torch.Tensor, mask: torch.Tensor | None, freqs_cis: torch.Tensor) -> torch.Tensor: x = x + self.attention(self.attention_norm(x), mask, freqs_cis) x = x + self.mlp(self.mlp_norm(x)) return x class TransformerBlock(nn.Module): def __init__( self, model_size: int, num_heads: int, intermediate_size: int, norm_eps: float, text_model_size: int, speaker_model_size: int, speaker_patch_size: int, adaln_rank: int, ): super().__init__() self.attention = JointAttention( model_size=model_size, num_heads=num_heads, text_model_size=text_model_size, speaker_model_size=speaker_model_size, speaker_patch_size=speaker_patch_size, norm_eps=norm_eps ) self.mlp = MLP( model_size=model_size, intermediate_size=intermediate_size ) self.attention_adaln = LowRankAdaLN(model_size=model_size, rank=adaln_rank, eps=norm_eps) self.mlp_adaln = LowRankAdaLN(model_size=model_size, rank=adaln_rank, eps=norm_eps) def forward( self, x: torch.Tensor, cond_embed: torch.Tensor, text_state: torch.Tensor | None, text_mask: torch.Tensor, speaker_state: torch.Tensor | None, speaker_mask: torch.Tensor, freqs_cis: torch.Tensor, kv_cache: Tuple[torch.Tensor, torch.Tensor] | None = None, ) -> torch.Tensor: x_norm, attention_gate = self.attention_adaln(x, cond_embed) x = x + attention_gate * self.attention(x_norm, text_state, text_mask, speaker_state, speaker_mask, freqs_cis, kv_cache) x_norm, mlp_gate = self.mlp_adaln(x, cond_embed) x = x + mlp_gate * self.mlp(x_norm) return x def get_kv_cache( self, text_state: torch.Tensor, speaker_state: torch.Tensor, ) -> Tuple[torch.Tensor, torch.Tensor]: return self.attention.get_kv_cache(text_state, speaker_state) class TextEncoder(nn.Module): def __init__( self, vocab_size: int, model_size: int, num_layers: int, num_heads: int, intermediate_size: int, norm_eps: float, max_seq_len: int, ): super().__init__() self.text_embedding = nn.Embedding(vocab_size, model_size) self.blocks = nn.ModuleList() for i in range(num_layers): block = EncoderTransformerBlock( model_size=model_size, num_heads=num_heads, intermediate_size=intermediate_size, is_causal=False, norm_eps=norm_eps ) self.blocks.append(block) self.head_dim = model_size // num_heads def forward(self, input_ids: torch.Tensor, mask: torch.Tensor | None = None) -> torch.Tensor: x = self.text_embedding(input_ids) freqs_cis = precompute_freqs_cis(self.head_dim, input_ids.shape[1]).to(x.device) # see below about avoiding recomputation for block in self.blocks: x = block(x, mask, freqs_cis) return x class SpeakerEncoder(nn.Module): def __init__( self, latent_size: int, patch_size: int, model_size: int, num_layers: int, num_heads: int, intermediate_size: int, norm_eps: float, max_patched_seq_len: int, ): super().__init__() self.patch_size = patch_size self.in_proj = nn.Linear(latent_size * patch_size, model_size, bias=True) self.blocks = nn.ModuleList() for i in range(num_layers): block = EncoderTransformerBlock( model_size=model_size, num_heads=num_heads, intermediate_size=intermediate_size, is_causal=True, norm_eps=norm_eps ) self.blocks.append(block) self.head_dim = model_size // num_heads def forward(self, latent: torch.Tensor) -> torch.Tensor: x = latent.reshape(*latent.shape[:-2], latent.shape[-2] // self.patch_size, latent.shape[-1] * self.patch_size) x = self.in_proj(x) x = x / 6. # this helped with initial activation dynamics in early ablations, could also bake into in_proj freqs_cis = precompute_freqs_cis(self.head_dim, x.shape[1]).to(x.device) # see below about avoiding recomputation for block in self.blocks: x = block(x, None, freqs_cis) return x class EchoDiT(nn.Module): def __init__( self, latent_size: int, # model_size: int, num_layers: int, num_heads: int, intermediate_size: int, norm_eps: float, max_seq_len: int, # text_vocab_size: int, text_model_size: int, text_num_layers: int, text_num_heads: int, text_intermediate_size: int, text_max_seq_len: int, # speaker_patch_size: int, speaker_model_size: int, speaker_num_layers: int, speaker_num_heads: int, speaker_intermediate_size: int, speaker_max_patched_seq_len: int, # timestep_embed_size: int, adaln_rank: int, ): super().__init__() self.speaker_patch_size = speaker_patch_size self.timestep_embed_size = timestep_embed_size self.text_encoder = TextEncoder( vocab_size=text_vocab_size, model_size=text_model_size, num_layers=text_num_layers, num_heads=text_num_heads, intermediate_size=text_intermediate_size, norm_eps=norm_eps, max_seq_len=text_max_seq_len, ) self.speaker_encoder = SpeakerEncoder( latent_size=latent_size, patch_size=speaker_patch_size, model_size=speaker_model_size, num_layers=speaker_num_layers, num_heads=speaker_num_heads, intermediate_size=speaker_intermediate_size, norm_eps=norm_eps, max_patched_seq_len=speaker_max_patched_seq_len, ) self.text_norm = RMSNorm(text_model_size, norm_eps) self.speaker_norm = RMSNorm(speaker_model_size, norm_eps) self.cond_module = nn.Sequential( nn.Linear(timestep_embed_size, model_size, bias=False), nn.SiLU(), nn.Linear(model_size, model_size, bias=False), nn.SiLU(), nn.Linear(model_size, model_size * 3, bias=False), ) self.in_proj = nn.Linear(latent_size, model_size, bias=True) self.blocks = nn.ModuleList() for i in range(num_layers): block = TransformerBlock( model_size=model_size, num_heads=num_heads, intermediate_size=intermediate_size, norm_eps=norm_eps, text_model_size=text_model_size, speaker_model_size=speaker_model_size, speaker_patch_size=speaker_patch_size, adaln_rank=adaln_rank, ) self.blocks.append(block) self.out_norm = RMSNorm(model_size, norm_eps) self.out_proj = nn.Linear(model_size, latent_size, bias=True) self.head_dim = model_size // num_heads def forward( self, x: torch.Tensor, t: torch.Tensor, text_input_ids: torch.Tensor, text_mask: torch.Tensor | None, speaker_latent: torch.Tensor, speaker_mask: torch.Tensor | None, kv_cache: List[Tuple[torch.Tensor, torch.Tensor]] | None = None, ) -> torch.Tensor: """ x: (b, s, d) t: (b,) text_input_ids: (b, s_t) # not used when kv_cache is provided text_mask: (b, s_t) speaker_latent: (b, s_r, d) # not used when kv_cache is provided speaker_mask: (b, s_r) kv_cache: List[Tuple[torch.Tensor, torch.Tensor]] returns: (b, s, d) """ freqs_cis = precompute_freqs_cis(self.head_dim, x.shape[1]).to(x.device) # can't register as buffer because we'd like it to stay in fp32; however, could optionally pass in to avoid recomputing if kv_cache is None and speaker_state is None: text_state = self.text_encoder(text_input_ids, text_mask) text_state = self.text_norm(text_state) speaker_state = self.speaker_encoder(speaker_latent) speaker_state = self.speaker_norm(speaker_state) else: text_state, speaker_state = None, None speaker_mask = speaker_mask[..., ::self.speaker_patch_size] cond_embed = self.cond_module(get_timestep_embedding(t, self.timestep_embed_size)) assert cond_embed.ndim == 2 cond_embed = cond_embed[:, None] x = self.in_proj(x) for i, block in enumerate(self.blocks): x = block( x=x, cond_embed=cond_embed, text_state=text_state, text_mask=text_mask, speaker_state=speaker_state, speaker_mask=speaker_mask, freqs_cis=freqs_cis, kv_cache=kv_cache[i] if kv_cache is not None else None, ) x = self.out_norm(x) x = self.out_proj(x) return x.float() def get_kv_cache( self, speaker_latent: torch.Tensor, speaker_mask: torch.Tensor, text_input_ids: torch.Tensor, text_mask: torch.Tensor, ) -> List[Tuple[torch.Tensor, torch.Tensor]]: speaker_state = self.speaker_encoder(speaker_latent) speaker_state = self.speaker_norm(speaker_state) text_state = self.text_encoder(text_input_ids, text_mask) text_state = self.text_norm(text_state) return [self.blocks[i].get_kv_cache(text_state, speaker_state) for i in range(len(self.blocks))] def get_kv_cache_from_precomputed_speaker_state( self, speaker_state: torch.Tensor, speaker_mask: torch.Tensor, text_input_ids: torch.Tensor, text_mask: torch.Tensor, ) -> List[Tuple[torch.Tensor, torch.Tensor]]: # here, speaker state is already computed from the speaker latent encoder transformer text_state = self.text_encoder(text_input_ids, text_mask) text_state = self.text_norm(text_state) return [self.blocks[i].get_kv_cache(text_state, speaker_state) for i in range(len(self.blocks))] @property def device(self) -> torch.device: return next(self.parameters()).device @property def dtype(self) -> torch.dtype: return next(self.parameters()).dtype