from typing import Tuple, List

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0) -> torch.Tensor:
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)] / dim))
    t = torch.arange(end)
    freqs = torch.outer(t, freqs)
    freqs_cis = torch.complex(torch.cos(freqs), torch.sin(freqs))
    return freqs_cis

def apply_rotary_emb(
    x: torch.Tensor,
    freqs_cis: torch.Tensor,
) -> torch.Tensor:
    x_ = torch.view_as_complex(x.float().reshape(*x.shape[:3], -1, 2))
    x_ = x_ * freqs_cis[..., None, :]
    x_ = torch.view_as_real(x_).reshape(x.shape)
    return x_.type_as(x)

def get_timestep_embedding(
    timestep: torch.Tensor,
    embed_size: int,
) -> torch.Tensor:
    assert embed_size % 2 == 0
    half = embed_size // 2
    freqs = 1000 * torch.exp(
        -torch.log(torch.tensor(10000.0)) *
        torch.arange(start=0, end=half, dtype=torch.float32) / half
    ).to(timestep.device)
    args = timestep[..., None] * freqs[None]
    embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
    return embedding.to(timestep.dtype)
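
# A quick, self-contained shape check for the three helpers above. Illustrative only:
# the dims, lengths, and batch size here are arbitrary and not taken from any real config.
def _sanity_check_embedding_helpers() -> None:
    freqs_cis = precompute_freqs_cis(dim=64, end=128)    # complex64, shape (128, 32)
    x = torch.randn(2, 128, 8, 64)                       # (batch, seq, heads, head_dim)
    assert apply_rotary_emb(x, freqs_cis).shape == x.shape
    t_emb = get_timestep_embedding(torch.rand(2), embed_size=256)
    assert t_emb.shape == (2, 256)                       # [cos | sin] halves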

class LowRankAdaLN(nn.Module):
    # Adaptive RMS-norm modulation: the shared conditioning embedding is split into
    # shift/scale/gate, each refined by a low-rank (down-project -> up-project) residual
    # branch; returns the modulated activations and a tanh gate.
    def __init__(
        self,
        model_size: int,
        rank: int,
        eps: float
    ):
        super().__init__()
        self.eps = eps
        self.shift_down = nn.Linear(model_size, rank, bias=False)
        self.scale_down = nn.Linear(model_size, rank, bias=False)
        self.gate_down = nn.Linear(model_size, rank, bias=False)
        self.shift_up = nn.Linear(rank, model_size, bias=True)
        self.scale_up = nn.Linear(rank, model_size, bias=True)
        self.gate_up = nn.Linear(rank, model_size, bias=True)

    def forward(
        self,
        x: torch.Tensor,
        cond_embed: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        shift, scale, gate = cond_embed.chunk(3, dim=-1)
        shift = self.shift_up(self.shift_down(F.silu(shift))) + shift
        scale = self.scale_up(self.scale_down(F.silu(scale))) + scale
        gate = self.gate_up(self.gate_down(F.silu(gate))) + gate
        x_dtype = x.dtype
        x = x.float()
        x = x * torch.rsqrt(torch.pow(x, 2).mean(dim=-1, keepdim=True) + self.eps)
        x = x * (scale + 1) + shift
        gate = torch.tanh(gate)
        return x.to(x_dtype), gate

class RMSNorm(nn.Module):  # could also just use torch rmsnorm
    def __init__(
        self,
        model_size: int | Tuple[int, int],
        eps: float
    ):
        super().__init__()
        self.eps = eps
        if isinstance(model_size, int):
            model_size = (model_size, )
        self.weight = nn.Parameter(torch.ones(model_size))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x_dtype = x.dtype
        x = x.float()
        x = x * torch.rsqrt(torch.pow(x, 2).mean(dim=-1, keepdim=True) + self.eps)
        x = x * self.weight
        return x.to(x_dtype)

class SelfAttention(nn.Module):
    def __init__(
        self,
        model_size: int,
        num_heads: int,
        is_causal: bool,
        norm_eps: float
    ):
        super().__init__()
        self.num_heads = num_heads
        self.is_causal = is_causal
        self.wq = nn.Linear(model_size, model_size, bias=False)
        self.wk = nn.Linear(model_size, model_size, bias=False)
        self.wv = nn.Linear(model_size, model_size, bias=False)
        self.wo = nn.Linear(model_size, model_size, bias=False)
        self.gate = nn.Linear(model_size, model_size, bias=False)
        assert model_size % num_heads == 0
        self.q_norm = RMSNorm((num_heads, model_size // num_heads), eps=norm_eps)
        self.k_norm = RMSNorm((num_heads, model_size // num_heads), eps=norm_eps)

    def forward(self, x: torch.Tensor, mask: torch.Tensor | None, freqs_cis: torch.Tensor) -> torch.Tensor:
        batch_size, seq_len = x.shape[:2]
        xq = self.wq(x).reshape(batch_size, seq_len, self.num_heads, -1)
        xk = self.wk(x).reshape(batch_size, seq_len, self.num_heads, -1)
        xv = self.wv(x).reshape(batch_size, seq_len, self.num_heads, -1)
        gate = self.gate(x)
        xq = self.q_norm(xq)
        xk = self.k_norm(xk)
        xq = apply_rotary_emb(xq, freqs_cis[:seq_len])
        xk = apply_rotary_emb(xk, freqs_cis[:seq_len])
        if mask is not None:
            assert mask.ndim == 2  # (b, s)
            mask = mask[:, None, None]
        output = F.scaled_dot_product_attention(
            query=xq.transpose(1, 2),
            key=xk.transpose(1, 2),
            value=xv.transpose(1, 2),
            attn_mask=mask,
            is_causal=self.is_causal
        ).transpose(1, 2)
        output = output.reshape(batch_size, seq_len, -1)
        output = output * torch.sigmoid(gate)
        output = self.wo(output)
        return output

class JointAttention(nn.Module):
    # Attention over the concatenation of the latent sequence's own keys/values and cross
    # keys/values projected from the text and speaker encoder states. Since the cross K/V do
    # not depend on x or t, they can be precomputed via get_kv_cache and reused across steps.
    def __init__(
        self,
        model_size: int,
        num_heads: int,
        text_model_size: int,
        speaker_model_size: int,
        speaker_patch_size: int,
        norm_eps: float
    ):
        super().__init__()
        self.speaker_patch_size = speaker_patch_size
        self.num_heads = num_heads
        self.wq = nn.Linear(model_size, model_size, bias=False)
        self.wk = nn.Linear(model_size, model_size, bias=False)
        self.wv = nn.Linear(model_size, model_size, bias=False)
        self.wk_text = nn.Linear(text_model_size, model_size, bias=False)
        self.wv_text = nn.Linear(text_model_size, model_size, bias=False)
        self.wk_speaker = nn.Linear(speaker_model_size, model_size, bias=False)
        self.wv_speaker = nn.Linear(speaker_model_size, model_size, bias=False)
        assert model_size % num_heads == 0
        self.q_norm = RMSNorm((num_heads, model_size // num_heads), eps=norm_eps)
        self.k_norm = RMSNorm((num_heads, model_size // num_heads), eps=norm_eps)
        self.gate = nn.Linear(model_size, model_size, bias=False)
        self.wo = nn.Linear(model_size, model_size, bias=False)

    def forward(
        self,
        x: torch.Tensor,
        text_state: torch.Tensor | None,
        text_mask: torch.Tensor,
        speaker_state: torch.Tensor | None,
        speaker_mask: torch.Tensor,
        freqs_cis: torch.Tensor,
        kv_cache: Tuple[torch.Tensor, torch.Tensor] | None = None,
    ) -> torch.Tensor:
        batch_size, seq_len = x.shape[:2]
        xq = self.wq(x).reshape(batch_size, seq_len, self.num_heads, -1)
        xk_self = self.wk(x).reshape(batch_size, seq_len, self.num_heads, -1)
        xv_self = self.wv(x).reshape(batch_size, seq_len, self.num_heads, -1)
        xq = self.q_norm(xq)
        xk_self = self.k_norm(xk_self)
        gate = self.gate(x)

        def _apply_rotary_half(y: torch.Tensor, fc: torch.Tensor) -> torch.Tensor:
            # rotary embeddings are applied to the first half of the heads only
            y1, y2 = y.chunk(2, dim=-2)
            y1 = apply_rotary_emb(y1, fc)
            return torch.cat([y1, y2], dim=-2)

        xq = _apply_rotary_half(xq, freqs_cis)
        xk_self = _apply_rotary_half(xk_self, freqs_cis)
        if kv_cache is None:
            xk_text = self.wk_text(text_state).reshape(batch_size, text_state.shape[1], self.num_heads, -1)
            xv_text = self.wv_text(text_state).reshape(batch_size, text_state.shape[1], self.num_heads, -1)
            xk_speaker = self.wk_speaker(speaker_state).reshape(batch_size, speaker_state.shape[1], self.num_heads, -1)
            xv_speaker = self.wv_speaker(speaker_state).reshape(batch_size, speaker_state.shape[1], self.num_heads, -1)
            xk_text = self.k_norm(xk_text)
            xk_speaker = self.k_norm(xk_speaker)
            xk = torch.cat([xk_self, xk_text, xk_speaker], dim=1)
            xv = torch.cat([xv_self, xv_text, xv_speaker], dim=1)
        else:
            xk_cross, xv_cross = kv_cache
            xk = torch.cat([xk_self, xk_cross], dim=1)
            xv = torch.cat([xv_self, xv_cross], dim=1)
        self_mask = torch.ones((batch_size, seq_len), dtype=torch.bool, device=x.device)
        mask = torch.cat([self_mask, text_mask, speaker_mask], dim=1)
        mask = mask[:, None, None]
        output = F.scaled_dot_product_attention(
            query=xq.transpose(1, 2),
            key=xk.transpose(1, 2),
            value=xv.transpose(1, 2),
            attn_mask=mask,
            is_causal=False
        ).transpose(1, 2)
        output = output.reshape(batch_size, seq_len, -1)
        output = output * torch.sigmoid(gate)
        output = self.wo(output)
        return output

    def get_kv_cache(
        self,
        text_state: torch.Tensor,
        speaker_state: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        batch_size = text_state.shape[0]
        xk_text = self.wk_text(text_state).reshape(batch_size, text_state.shape[1], self.num_heads, -1)
        xv_text = self.wv_text(text_state).reshape(batch_size, text_state.shape[1], self.num_heads, -1)
        xk_speaker = self.wk_speaker(speaker_state).reshape(batch_size, speaker_state.shape[1], self.num_heads, -1)
        xv_speaker = self.wv_speaker(speaker_state).reshape(batch_size, speaker_state.shape[1], self.num_heads, -1)
        xk = torch.cat([xk_text, xk_speaker], dim=1)
        xv = torch.cat([xv_text, xv_speaker], dim=1)
        xk = self.k_norm(xk)
        return xk, xv

class MLP(nn.Module):
    # SwiGLU-style feed-forward: w2(silu(w1(x)) * w3(x))
    def __init__(
        self,
        model_size: int,
        intermediate_size: int
    ):
        super().__init__()
        self.w1 = nn.Linear(model_size, intermediate_size, bias=False)
        self.w3 = nn.Linear(model_size, intermediate_size, bias=False)
        self.w2 = nn.Linear(intermediate_size, model_size, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.w2(F.silu(self.w1(x)) * self.w3(x))

class EncoderTransformerBlock(nn.Module):
    def __init__(
        self,
        model_size: int,
        num_heads: int,
        intermediate_size: int,
        is_causal: bool,
        norm_eps: float
    ):
        super().__init__()
        self.attention = SelfAttention(
            model_size=model_size,
            num_heads=num_heads,
            is_causal=is_causal,
            norm_eps=norm_eps
        )
        self.mlp = MLP(
            model_size=model_size,
            intermediate_size=intermediate_size
        )
        self.attention_norm = RMSNorm(model_size, norm_eps)
        self.mlp_norm = RMSNorm(model_size, norm_eps)

    def forward(self, x: torch.Tensor, mask: torch.Tensor | None, freqs_cis: torch.Tensor) -> torch.Tensor:
        x = x + self.attention(self.attention_norm(x), mask, freqs_cis)
        x = x + self.mlp(self.mlp_norm(x))
        return x

class TransformerBlock(nn.Module):
    def __init__(
        self,
        model_size: int,
        num_heads: int,
        intermediate_size: int,
        norm_eps: float,
        text_model_size: int,
        speaker_model_size: int,
        speaker_patch_size: int,
        adaln_rank: int,
    ):
        super().__init__()
        self.attention = JointAttention(
            model_size=model_size,
            num_heads=num_heads,
            text_model_size=text_model_size,
            speaker_model_size=speaker_model_size,
            speaker_patch_size=speaker_patch_size,
            norm_eps=norm_eps
        )
        self.mlp = MLP(
            model_size=model_size,
            intermediate_size=intermediate_size
        )
        self.attention_adaln = LowRankAdaLN(model_size=model_size, rank=adaln_rank, eps=norm_eps)
        self.mlp_adaln = LowRankAdaLN(model_size=model_size, rank=adaln_rank, eps=norm_eps)

    def forward(
        self,
        x: torch.Tensor,
        cond_embed: torch.Tensor,
        text_state: torch.Tensor | None,
        text_mask: torch.Tensor,
        speaker_state: torch.Tensor | None,
        speaker_mask: torch.Tensor,
        freqs_cis: torch.Tensor,
        kv_cache: Tuple[torch.Tensor, torch.Tensor] | None = None,
    ) -> torch.Tensor:
        x_norm, attention_gate = self.attention_adaln(x, cond_embed)
        x = x + attention_gate * self.attention(x_norm, text_state, text_mask, speaker_state, speaker_mask, freqs_cis, kv_cache)
        x_norm, mlp_gate = self.mlp_adaln(x, cond_embed)
        x = x + mlp_gate * self.mlp(x_norm)
        return x

    def get_kv_cache(
        self,
        text_state: torch.Tensor,
        speaker_state: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        return self.attention.get_kv_cache(text_state, speaker_state)

class TextEncoder(nn.Module):
    def __init__(
        self,
        vocab_size: int,
        model_size: int,
        num_layers: int,
        num_heads: int,
        intermediate_size: int,
        norm_eps: float,
        max_seq_len: int,
    ):
        super().__init__()
        self.text_embedding = nn.Embedding(vocab_size, model_size)
        self.blocks = nn.ModuleList()
        for i in range(num_layers):
            block = EncoderTransformerBlock(
                model_size=model_size,
                num_heads=num_heads,
                intermediate_size=intermediate_size,
                is_causal=False,
                norm_eps=norm_eps
            )
            self.blocks.append(block)
        self.head_dim = model_size // num_heads

    def forward(self, input_ids: torch.Tensor, mask: torch.Tensor | None = None) -> torch.Tensor:
        x = self.text_embedding(input_ids)
        freqs_cis = precompute_freqs_cis(self.head_dim, input_ids.shape[1]).to(x.device)  # see below about avoiding recomputation
        for block in self.blocks:
            x = block(x, mask, freqs_cis)
        return x

class SpeakerEncoder(nn.Module):
    def __init__(
        self,
        latent_size: int,
        patch_size: int,
        model_size: int,
        num_layers: int,
        num_heads: int,
        intermediate_size: int,
        norm_eps: float,
        max_patched_seq_len: int,
    ):
        super().__init__()
        self.patch_size = patch_size
        self.in_proj = nn.Linear(latent_size * patch_size, model_size, bias=True)
        self.blocks = nn.ModuleList()
        for i in range(num_layers):
            block = EncoderTransformerBlock(
                model_size=model_size,
                num_heads=num_heads,
                intermediate_size=intermediate_size,
                is_causal=True,
                norm_eps=norm_eps
            )
            self.blocks.append(block)
        self.head_dim = model_size // num_heads

    def forward(self, latent: torch.Tensor) -> torch.Tensor:
        # patchify: fold patch_size consecutive latent frames into one token of size latent_size * patch_size
        x = latent.reshape(*latent.shape[:-2], latent.shape[-2] // self.patch_size, latent.shape[-1] * self.patch_size)
        x = self.in_proj(x)
        x = x / 6.  # this helped with initial activation dynamics in early ablations, could also bake into in_proj
        freqs_cis = precompute_freqs_cis(self.head_dim, x.shape[1]).to(x.device)  # see below about avoiding recomputation
        for block in self.blocks:
            x = block(x, None, freqs_cis)
        return x

class EchoDiT(nn.Module):
    def __init__(
        self,
        latent_size: int,
        #
        model_size: int,
        num_layers: int,
        num_heads: int,
        intermediate_size: int,
        norm_eps: float,
        max_seq_len: int,
        #
        text_vocab_size: int,
        text_model_size: int,
        text_num_layers: int,
        text_num_heads: int,
        text_intermediate_size: int,
        text_max_seq_len: int,
        #
        speaker_patch_size: int,
        speaker_model_size: int,
        speaker_num_layers: int,
        speaker_num_heads: int,
        speaker_intermediate_size: int,
        speaker_max_patched_seq_len: int,
        #
        timestep_embed_size: int,
        adaln_rank: int,
    ):
        super().__init__()
        self.speaker_patch_size = speaker_patch_size
        self.timestep_embed_size = timestep_embed_size
        self.text_encoder = TextEncoder(
            vocab_size=text_vocab_size,
            model_size=text_model_size,
            num_layers=text_num_layers,
            num_heads=text_num_heads,
            intermediate_size=text_intermediate_size,
            norm_eps=norm_eps,
            max_seq_len=text_max_seq_len,
        )
        self.speaker_encoder = SpeakerEncoder(
            latent_size=latent_size,
            patch_size=speaker_patch_size,
            model_size=speaker_model_size,
            num_layers=speaker_num_layers,
            num_heads=speaker_num_heads,
            intermediate_size=speaker_intermediate_size,
            norm_eps=norm_eps,
            max_patched_seq_len=speaker_max_patched_seq_len,
        )
        self.text_norm = RMSNorm(text_model_size, norm_eps)
        self.speaker_norm = RMSNorm(speaker_model_size, norm_eps)
        self.cond_module = nn.Sequential(
            nn.Linear(timestep_embed_size, model_size, bias=False),
            nn.SiLU(),
            nn.Linear(model_size, model_size, bias=False),
            nn.SiLU(),
            nn.Linear(model_size, model_size * 3, bias=False),
        )
        self.in_proj = nn.Linear(latent_size, model_size, bias=True)
        self.blocks = nn.ModuleList()
        for i in range(num_layers):
            block = TransformerBlock(
                model_size=model_size,
                num_heads=num_heads,
                intermediate_size=intermediate_size,
                norm_eps=norm_eps,
                text_model_size=text_model_size,
                speaker_model_size=speaker_model_size,
                speaker_patch_size=speaker_patch_size,
                adaln_rank=adaln_rank,
            )
            self.blocks.append(block)
        self.out_norm = RMSNorm(model_size, norm_eps)
        self.out_proj = nn.Linear(model_size, latent_size, bias=True)
        self.head_dim = model_size // num_heads

    def forward(
        self,
        x: torch.Tensor,
        t: torch.Tensor,
        text_input_ids: torch.Tensor,
        text_mask: torch.Tensor | None,
        speaker_latent: torch.Tensor,
        speaker_mask: torch.Tensor | None,
        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]] | None = None,
    ) -> torch.Tensor:
        """
        x: (b, s, d)
        t: (b,)
        text_input_ids: (b, s_t)  # not used when kv_cache is provided
        text_mask: (b, s_t)
        speaker_latent: (b, s_r, d)  # not used when kv_cache is provided
        speaker_mask: (b, s_r)
        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]]
        returns: (b, s, d)
        """
        freqs_cis = precompute_freqs_cis(self.head_dim, x.shape[1]).to(x.device)
        # can't register as buffer because we'd like it to stay in fp32; however, could optionally pass in to avoid recomputing
        if kv_cache is None:
            text_state = self.text_encoder(text_input_ids, text_mask)
            text_state = self.text_norm(text_state)
            speaker_state = self.speaker_encoder(speaker_latent)
            speaker_state = self.speaker_norm(speaker_state)
        else:
            # cross-attention keys/values already live in kv_cache, so the encoders are skipped
            text_state, speaker_state = None, None
        speaker_mask = speaker_mask[..., ::self.speaker_patch_size]
        cond_embed = self.cond_module(get_timestep_embedding(t, self.timestep_embed_size))
        assert cond_embed.ndim == 2
        cond_embed = cond_embed[:, None]
        x = self.in_proj(x)
        for i, block in enumerate(self.blocks):
            x = block(
                x=x,
                cond_embed=cond_embed,
                text_state=text_state,
                text_mask=text_mask,
                speaker_state=speaker_state,
                speaker_mask=speaker_mask,
                freqs_cis=freqs_cis,
                kv_cache=kv_cache[i] if kv_cache is not None else None,
            )
        x = self.out_norm(x)
        x = self.out_proj(x)
        return x.float()

    def get_kv_cache(
        self,
        speaker_latent: torch.Tensor,
        speaker_mask: torch.Tensor,
        text_input_ids: torch.Tensor,
        text_mask: torch.Tensor,
    ) -> List[Tuple[torch.Tensor, torch.Tensor]]:
        speaker_state = self.speaker_encoder(speaker_latent)
        speaker_state = self.speaker_norm(speaker_state)
        text_state = self.text_encoder(text_input_ids, text_mask)
        text_state = self.text_norm(text_state)
        return [self.blocks[i].get_kv_cache(text_state, speaker_state) for i in range(len(self.blocks))]

    def get_kv_cache_from_precomputed_speaker_state(
        self,
        speaker_state: torch.Tensor,
        speaker_mask: torch.Tensor,
        text_input_ids: torch.Tensor,
        text_mask: torch.Tensor,
    ) -> List[Tuple[torch.Tensor, torch.Tensor]]:
        # here, speaker state is already computed from the speaker latent encoder transformer
        text_state = self.text_encoder(text_input_ids, text_mask)
        text_state = self.text_norm(text_state)
        return [self.blocks[i].get_kv_cache(text_state, speaker_state) for i in range(len(self.blocks))]

    def device(self) -> torch.device:
        return next(self.parameters()).device

    def dtype(self) -> torch.dtype:
        return next(self.parameters()).dtype
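
# A minimal usage sketch (not part of the original file). Every hyperparameter below is an
# illustrative placeholder, not the real EchoDiT configuration; the sketch only demonstrates
# the call pattern the code supports: build the cross-attention kv cache once from text and
# speaker conditioning, then reuse it across denoising steps while only x and t change.
def _example_usage() -> None:
    model = EchoDiT(
        latent_size=64,
        model_size=512, num_layers=4, num_heads=8, intermediate_size=1536,
        norm_eps=1e-5, max_seq_len=1024,
        text_vocab_size=256, text_model_size=256, text_num_layers=2, text_num_heads=4,
        text_intermediate_size=768, text_max_seq_len=512,
        speaker_patch_size=4, speaker_model_size=256, speaker_num_layers=2,
        speaker_num_heads=4, speaker_intermediate_size=768, speaker_max_patched_seq_len=256,
        timestep_embed_size=256, adaln_rank=32,
    )
    b, s, s_t, s_r = 2, 128, 32, 64  # batch, latent frames, text tokens, speaker latent frames
    x = torch.randn(b, s, 64)
    t = torch.rand(b)
    text_ids = torch.randint(0, 256, (b, s_t))
    text_mask = torch.ones(b, s_t, dtype=torch.bool)
    speaker_latent = torch.randn(b, s_r, 64)
    speaker_mask = torch.ones(b, s_r, dtype=torch.bool)
    # cross K/V depend only on the conditioning, so compute them once and reuse per step
    kv_cache = model.get_kv_cache(speaker_latent, speaker_mask, text_ids, text_mask)
    v = model(x, t, text_ids, text_mask, speaker_latent, speaker_mask, kv_cache=kv_cache)
    assert v.shape == x.shape


if __name__ == "__main__":
    _example_usage()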