# SPDX-FileCopyrightText: 2025 Jordan Darefsky # SPDX-License-Identifier: Apache-2.0 # # This file contains portions adapted from: # • Descript Audio Codec (DAC) — MIT License (full text appended below) # • Fish-Speech S1 DAC Autoencoder — reference implementation (Apache-2.0 / CC-BY-NC), # rewritten here in a single-file Torch module for interoperability and transparency. # # OVERALL LICENSE (this file): Apache-2.0, except where explicitly marked: # # SPDX-License-Identifier: MIT # Keep these notices and the embedded MIT text if you redistribute this file. # NOTE (style/provenance): # Code in this module has been largely copy-and-pasted from the Fish-S1-DAC and DAC repositories, # and refactored with help from ChatGPT/Claude (these models also helped with licensing). # Thus, it stylistically differs from the rest of the codebase (I'm not even sure about internal consistency) # and is likely much messier than it would have been had it been written from scratch. from __future__ import annotations import math from dataclasses import dataclass from typing import List, Optional, Tuple, Union import numpy as np import torch from torch import Tensor, nn from torch.nn import functional as F from torch.nn.utils.parametrizations import weight_norm from torch.nn.utils.parametrize import remove_parametrizations from einops import rearrange # -------------------------------------------------------------------- # Shared helpers # -------------------------------------------------------------------- def find_multiple(n: int, k: int) -> int: return n if n % k == 0 else n + k - (n % k) def unpad1d(x: Tensor, paddings: Tuple[int, int]) -> Tensor: """Remove padding from x, handling properly zero padding. Only for 1d!""" padding_left, padding_right = paddings assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right) assert (padding_left + padding_right) <= x.shape[-1] end = x.shape[-1] - padding_right return x[..., padding_left:end] def get_extra_padding_for_conv1d( x: Tensor, kernel_size: int, stride: int, padding_total: int = 0 ) -> int: """See pad_for_conv1d; enough right pad so striding evenly covers length.""" length = x.shape[-1] n_frames = (length - kernel_size + padding_total) / stride + 1 ideal_length = (math.ceil(n_frames) - 1) * stride + (kernel_size - padding_total) return ideal_length - length def pad1d( x: Tensor, paddings: Tuple[int, int], mode: str = "zeros", value: float = 0.0, ) -> Tensor: """ Reflect‑safe 1D pad: if reflect would underflow on small inputs, insert temporary right zero-pad before reflecting. 
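
    Example (illustrative only): reflect-padding a length-2 signal by 3 would
    normally raise inside ``F.pad``; here the input is zero-extended first and
    the temporary samples are trimmed off afterwards:

        y = pad1d(torch.zeros(1, 1, 2), (0, 3), mode="reflect")
        assert y.shape[-1] == 5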
""" length = x.shape[-1] padding_left, padding_right = paddings assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right) if mode == "reflect": max_pad = max(padding_left, padding_right) extra_pad = 0 if length <= max_pad: extra_pad = max_pad - length + 1 x = F.pad(x, (0, extra_pad)) padded = F.pad(x, (padding_left, padding_right), mode, value) end = padded.shape[-1] - extra_pad return padded[..., :end] else: return F.pad(x, (padding_left, padding_right), mode, value) # -------------------------------------------------------------------- # DAC Layers (adapted) — MIT # Original: https://github.com/descriptinc/descript-audio-codec/blob/main/dac/nn/layers.py # SPDX-License-Identifier: MIT # -------------------------------------------------------------------- def WNConv1d(*args, **kwargs): return weight_norm(nn.Conv1d(*args, **kwargs)) def WNConvTranspose1d(*args, **kwargs): return weight_norm(nn.ConvTranspose1d(*args, **kwargs)) @torch.jit.script def snake(x: Tensor, alpha: Tensor) -> Tensor: shape = x.shape x = x.reshape(shape[0], shape[1], -1) x = x + (alpha + 1e-9).reciprocal() * torch.sin(alpha * x).pow(2) x = x.reshape(shape) return x class Snake1d(nn.Module): def __init__(self, channels: int): super().__init__() self.alpha = nn.Parameter(torch.ones(1, channels, 1)) def forward(self, x: Tensor) -> Tensor: return snake(x, self.alpha) # -------------------------------------------------------------------- # DAC Vector Quantize (adapted) — MIT # Original: https://github.com/descriptinc/descript-audio-codec/blob/main/dac/nn/quantize.py # SPDX-License-Identifier: MIT # -------------------------------------------------------------------- class VectorQuantize(nn.Module): """ VQ with factorized, l2-normalized codes (ViT‑VQGAN style). I/O in (B, D, T). 
""" def __init__(self, input_dim: int, codebook_size: int, codebook_dim: int): super().__init__() self.codebook_size = codebook_size self.codebook_dim = codebook_dim self.in_proj = WNConv1d(input_dim, codebook_dim, kernel_size=1) self.out_proj = WNConv1d(codebook_dim, input_dim, kernel_size=1) self.codebook = nn.Embedding(codebook_size, codebook_dim) def forward(self, z: Tensor): z_e = self.in_proj(z) # (B, D, T) z_q, indices = self.decode_latents(z_e) commitment_loss = F.mse_loss(z_e, z_q.detach(), reduction="none").mean([1, 2]) codebook_loss = F.mse_loss(z_q, z_e.detach(), reduction="none").mean([1, 2]) z_q = z_e + (z_q - z_e).detach() # straight‑through z_q = self.out_proj(z_q) return z_q, commitment_loss, codebook_loss, indices, z_e def embed_code(self, embed_id: Tensor) -> Tensor: return F.embedding(embed_id, self.codebook.weight) def decode_code(self, embed_id: Tensor) -> Tensor: return self.embed_code(embed_id).transpose(1, 2) def decode_latents(self, latents: Tensor) -> Tuple[Tensor, Tensor]: encodings = rearrange(latents, "b d t -> (b t) d") codebook = self.codebook.weight encodings = F.normalize(encodings) codebook = F.normalize(codebook) dist = ( encodings.pow(2).sum(1, keepdim=True) - 2 * encodings @ codebook.t() + codebook.pow(2).sum(1, keepdim=True).t() ) indices = rearrange((-dist).max(1)[1], "(b t) -> b t", b=latents.size(0)) z_q = self.decode_code(indices) return z_q, indices class ResidualVectorQuantize(nn.Module): """SoundStream-style residual VQ stack.""" def __init__( self, input_dim: int = 512, n_codebooks: int = 9, codebook_size: int = 1024, codebook_dim: Union[int, List[int]] = 8, quantizer_dropout: float = 0.0, ): super().__init__() if isinstance(codebook_dim, int): codebook_dim = [codebook_dim for _ in range(n_codebooks)] self.n_codebooks = n_codebooks self.codebook_dim = codebook_dim self.codebook_size = codebook_size self.quantizers = nn.ModuleList([ VectorQuantize(input_dim, codebook_size, codebook_dim[i]) for i in range(n_codebooks) ]) self.quantizer_dropout = quantizer_dropout def forward(self, z: Tensor, n_quantizers: Optional[int] = None): z_q = 0 residual = z commitment_loss = 0 codebook_loss = 0 codebook_indices = [] latents = [] if n_quantizers is None: n_quantizers = self.n_codebooks if self.training: n_quantizers = torch.ones((z.shape[0],)) * self.n_codebooks + 1 dropout = torch.randint(1, self.n_codebooks + 1, (z.shape[0],)) n_dropout = int(z.shape[0] * self.quantizer_dropout) n_quantizers[:n_dropout] = dropout[:n_dropout] n_quantizers = n_quantizers.to(z.device) for i, quantizer in enumerate(self.quantizers): if self.training is False and i >= n_quantizers: break z_q_i, commit_i, codebk_i, indices_i, z_e_i = quantizer(residual) mask = (torch.full((z.shape[0],), fill_value=i, device=z.device) < n_quantizers) z_q = z_q + z_q_i * mask[:, None, None] residual = residual - z_q_i commitment_loss += (commit_i * mask).mean() codebook_loss += (codebk_i * mask).mean() codebook_indices.append(indices_i) latents.append(z_e_i) codes = torch.stack(codebook_indices, dim=1) latents = torch.cat(latents, dim=1) return z_q, codes, latents, commitment_loss, codebook_loss def from_codes(self, codes: Tensor) -> Tuple[Tensor, Tensor, Tensor]: z_q = 0.0 z_p = [] n_codebooks = codes.shape[1] for i in range(n_codebooks): z_p_i = self.quantizers[i].decode_code(codes[:, i, :]) z_p.append(z_p_i) z_q_i = self.quantizers[i].out_proj(z_p_i) z_q = z_q + z_q_i return z_q, torch.cat(z_p, dim=1), codes def from_latents(self, latents: Tensor) -> Tuple[Tensor, Tensor, Tensor]: z_q = 0 
        z_p = []
        codes = []
        dims = np.cumsum([0] + [q.codebook_dim for q in self.quantizers])
        n_codebooks = np.where(dims <= latents.shape[1])[0].max(axis=0, keepdims=True)[0]
        for i in range(n_codebooks):
            j, k = dims[i], dims[i + 1]
            z_p_i, codes_i = self.quantizers[i].decode_latents(latents[:, j:k, :])
            z_p.append(z_p_i)
            codes.append(codes_i)
            z_q_i = self.quantizers[i].out_proj(z_p_i)
            z_q = z_q + z_q_i
        return z_q, torch.cat(z_p, dim=1), torch.stack(codes, dim=1)


# --------------------------------------------------------------------
# S1 DAC rvq
# --------------------------------------------------------------------

@dataclass
class VQResult:
    z: Tensor
    codes: Tensor
    latents: Tensor
    codebook_loss: Tensor
    commitment_loss: Tensor
    semantic_distill_z: Optional[Tensor] = None


class CausalConvNet(nn.Module):
    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        dilation=1,
        stride=1,
        groups=1,
        padding=None,
    ):
        super().__init__()
        self.conv = nn.Conv1d(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            dilation=dilation,
            groups=groups,
        )
        self.stride = stride
        self.kernel_size = (kernel_size - 1) * dilation + 1
        self.dilation = dilation
        self.padding = self.kernel_size - self.stride

    def forward(self, x: Tensor) -> Tensor:
        pad = self.padding
        extra = get_extra_padding_for_conv1d(x, self.kernel_size, self.stride, pad)
        x = pad1d(x, (pad, extra), mode="constant", value=0)
        return self.conv(x).contiguous()

    def weight_norm(self, name="weight", dim=0):
        self.conv = weight_norm(self.conv, name=name, dim=dim)
        return self

    def remove_weight_norm(self):
        # remove_parametrizations requires the parametrized tensor name.
        self.conv = remove_parametrizations(self.conv, "weight")
        return self


class CausalTransConvNet(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, dilation=1, stride=1, padding=None):
        super().__init__()
        self.conv = nn.ConvTranspose1d(
            in_channels, out_channels, kernel_size, stride=stride, dilation=dilation
        )
        self.stride = stride
        self.kernel_size = kernel_size

    def forward(self, x: Tensor) -> Tensor:
        x = self.conv(x)
        pad = self.kernel_size - self.stride
        padding_right = math.ceil(pad)
        padding_left = pad - padding_right
        x = unpad1d(x, (padding_left, padding_right))
        return x.contiguous()

    def weight_norm(self, name="weight", dim=0):
        self.conv = weight_norm(self.conv, name=name, dim=dim)
        return self

    def remove_weight_norm(self):
        # remove_parametrizations requires the parametrized tensor name.
        self.conv = remove_parametrizations(self.conv, "weight")
        return self


def CausalWNConv1d(*args, **kwargs):
    return CausalConvNet(*args, **kwargs).weight_norm()


def CausalWNConvTranspose1d(*args, **kwargs):
    return CausalTransConvNet(*args, **kwargs).weight_norm()


class ConvNeXtBlock(nn.Module):
    r"""ConvNeXt Block (1D).
DwConv -> (N, C, L) → (N, L, C) -> LN -> Linear -> GELU -> Linear -> (N, C, L) with residual """ def __init__( self, dim: int, layer_scale_init_value: float = 1e-6, mlp_ratio: float = 4.0, kernel_size: int = 7, dilation: int = 1, ): super().__init__() convnet_type = CausalConvNet self.dwconv = convnet_type( dim, dim, kernel_size=kernel_size, groups=dim, dilation=dilation, ) # depthwise conv self.norm = nn.LayerNorm(dim, eps=1e-6) self.pwconv1 = nn.Linear(dim, int(mlp_ratio * dim)) self.act = nn.GELU() self.pwconv2 = nn.Linear(int(mlp_ratio * dim), dim) self.gamma = ( nn.Parameter(layer_scale_init_value * torch.ones((dim)), requires_grad=True) if layer_scale_init_value > 0 else None ) def forward(self, x: Tensor, apply_residual: bool = True) -> Tensor: inp = x x = self.dwconv(x) x = x.permute(0, 2, 1) # (N, C, L) -> (N, L, C) x = self.norm(x) x = self.pwconv1(x) x = self.act(x) x = self.pwconv2(x) if self.gamma is not None: x = self.gamma * x x = x.permute(0, 2, 1) # (N, L, C) -> (N, C, L) if apply_residual: x = inp + x return x class DownsampleResidualVectorQuantize(nn.Module): def __init__( self, input_dim: int = 1024, n_codebooks: int = 9, codebook_dim: int = 8, quantizer_dropout: float = 0.5, codebook_size: int = 1024, semantic_codebook_size: int = 4096, downsample_factor: Tuple[int, ...] = (2, 2), downsample_dims: Optional[Tuple[int, ...]] = None, pre_module: Optional[nn.Module] = None, post_module: Optional[nn.Module] = None, semantic_predictor_module: Optional[nn.Module] = None, ): super().__init__() if downsample_dims is None: downsample_dims = tuple(input_dim for _ in range(len(downsample_factor))) all_dims = (input_dim,) + tuple(downsample_dims) self.semantic_quantizer = ResidualVectorQuantize( input_dim=input_dim, n_codebooks=1, codebook_size=semantic_codebook_size, codebook_dim=codebook_dim, quantizer_dropout=0.0, ) self.quantizer = ResidualVectorQuantize( input_dim=input_dim, n_codebooks=n_codebooks, codebook_size=codebook_size, codebook_dim=codebook_dim, quantizer_dropout=quantizer_dropout, ) convnet_type = CausalConvNet transconvnet_type = CausalTransConvNet self.downsample = nn.Sequential( *[ nn.Sequential( convnet_type(all_dims[idx], all_dims[idx + 1], kernel_size=factor, stride=factor), ConvNeXtBlock(dim=all_dims[idx + 1]), ) for idx, factor in enumerate(downsample_factor) ] ) self.upsample = nn.Sequential( *[ nn.Sequential( transconvnet_type(all_dims[idx + 1], all_dims[idx], kernel_size=factor, stride=factor), ConvNeXtBlock(dim=all_dims[idx]), ) for idx, factor in reversed(list(enumerate(downsample_factor))) ] ) self.apply(self._init_weights) self.pre_module = pre_module if pre_module is not None else nn.Identity() self.post_module = post_module if post_module is not None else nn.Identity() self.semantic_predictor_module = ( semantic_predictor_module if semantic_predictor_module is not None else nn.Identity() ) @staticmethod def _init_weights(m): if isinstance(m, (nn.Conv1d, nn.Linear)): nn.init.trunc_normal_(m.weight, std=0.02) if getattr(m, "bias", None) is not None: nn.init.constant_(m.bias, 0) def forward(self, z: Tensor, n_quantizers: Optional[int] = None, semantic_len: Optional[Tensor] = None, **kwargs): # z: (B, D, T) original_shape = z.shape if semantic_len is None: semantic_len = torch.LongTensor([z.shape[-1]]) z = self.downsample(z) z = self.pre_module(z) # (B, D, T) or (B, T, D) depending on module; original uses channels-first in/out semantic_z, semantic_codes, semantic_latents, semantic_commitment_loss, semantic_codebook_loss = \ self.semantic_quantizer(z) 
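        # Two-stage quantization: one "semantic" codebook quantizes z first,
        # then the residual (z - semantic_z) goes through the main RVQ stack;
        # codes are concatenated as [semantic, residual...] along dim 1 below.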
residual_z = z - semantic_z residual_z, codes, latents, commitment_loss, codebook_loss = self.quantizer(residual_z, n_quantizers=n_quantizers) z = semantic_z + residual_z commitment_loss = commitment_loss + semantic_commitment_loss codebook_loss = codebook_loss + semantic_codebook_loss codes = torch.cat([semantic_codes, codes], dim=1) latents = torch.cat([semantic_latents, latents], dim=1) z = self.post_module(z) z = self.upsample(z) # Pad or crop z to match original shape (time dimension) diff = original_shape[-1] - z.shape[-1] right = 0 left = abs(diff) - right if diff > 0: z = F.pad(z, (left, right)) elif diff < 0: z = z[..., left:] return VQResult( z=z, codes=codes, latents=latents, commitment_loss=commitment_loss, codebook_loss=codebook_loss, ) def decode(self, indices: Tensor) -> Tensor: new_indices = torch.zeros_like(indices) new_indices[:, 0] = torch.clamp(indices[:, 0], max=self.semantic_quantizer.codebook_size - 1) new_indices[:, 1:] = torch.clamp(indices[:, 1:], max=self.quantizer.codebook_size - 1) z_q_semantic = self.semantic_quantizer.from_codes(new_indices[:, :1])[0] z_q_residual = self.quantizer.from_codes(new_indices[:, 1:])[0] z_q = z_q_semantic + z_q_residual z_q = self.post_module(z_q) z_q = self.upsample(z_q) return z_q # -------------------------------------------------------------------- # Transformer stack # -------------------------------------------------------------------- @dataclass class ModelArgs: block_size: int = 2048 n_layer: int = 8 n_head: int = 8 dim: int = 512 intermediate_size: int = 1536 n_local_heads: int = -1 head_dim: int = 64 rope_base: float = 10000 norm_eps: float = 1e-5 dropout_rate: float = 0.1 attn_dropout_rate: float = 0.1 channels_first: bool = True # to be compatible with conv1d input/output pos_embed_type: str = "rope" # "rope" or "conformer" max_relative_position: int = 128 def __post_init__(self): if self.n_local_heads == -1: self.n_local_heads = self.n_head if self.intermediate_size is None: hidden_dim = 4 * self.dim n_hidden = int(2 * hidden_dim / 3) self.intermediate_size = find_multiple(n_hidden, 256) assert self.pos_embed_type in ["rope", "conformer"] class KVCache(nn.Module): def __init__(self, max_batch_size, max_seq_length, n_heads, head_dim, dtype=torch.bfloat16): super().__init__() cache_shape = (max_batch_size, n_heads, max_seq_length, head_dim) self.register_buffer("k_cache", torch.zeros(cache_shape, dtype=dtype)) self.register_buffer("v_cache", torch.zeros(cache_shape, dtype=dtype)) def update(self, input_pos: Tensor, k_val: Tensor, v_val: Tensor): # input_pos: [S], k_val: [B, H, S, D] assert input_pos.shape[0] == k_val.shape[2] k_out = self.k_cache v_out = self.v_cache k_out[:, :, input_pos] = k_val v_out[:, :, input_pos] = v_val return ( k_out[:, :, : input_pos.max() + 1, :], v_out[:, :, : input_pos.max() + 1, :], ) def clear_cache(self, prompt_len: int): self.k_cache[:, :, prompt_len:, :].fill_(0) self.v_cache[:, :, prompt_len:, :].fill_(0) class Transformer(nn.Module): def __init__(self, config: ModelArgs) -> None: super().__init__() self.config = config self.layers = nn.ModuleList(TransformerBlock(config) for _ in range(config.n_layer)) self.norm = RMSNorm(config.dim, eps=config.norm_eps) if config.pos_embed_type == "rope": freqs_cis = precompute_freqs_cis(self.config.block_size, self.config.head_dim, self.config.rope_base) self.register_buffer("freqs_cis", freqs_cis) else: self.register_buffer("freqs_cis", None) causal_mask = torch.tril(torch.ones(self.config.block_size, self.config.block_size, dtype=torch.bool)) 
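        # Full (block_size, block_size) lower-triangular mask; forward() slices
        # its rows by input_pos and, when the KV cache is active, keeps columns
        # only up to the largest cached position.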
self.register_buffer("causal_mask", causal_mask) self.max_batch_size = -1 self.max_seq_length = -1 self.use_kv_cache = False def setup_caches(self, max_batch_size, max_seq_length): head_dim = self.config.dim // self.config.n_head max_seq_length = find_multiple(max_seq_length, 8) self.max_seq_length = max_seq_length self.max_batch_size = max_batch_size dtype = self.norm.weight.dtype device = self.norm.weight.device for b in self.layers: b.attention.kv_cache = KVCache( max_batch_size, max_seq_length, self.config.n_local_heads, head_dim, dtype ).to(device) self.use_kv_cache = True def forward(self, x: Tensor, input_pos: Optional[Tensor] = None, mask: Optional[Tensor] = None) -> Tensor: if self.config.pos_embed_type == "rope": assert self.freqs_cis is not None freqs_cis = self.freqs_cis[input_pos] else: freqs_cis = None if mask is None: if not self.training and self.use_kv_cache: mask = self.causal_mask[None, None, input_pos] mask = mask[..., : input_pos.max() + 1] else: mask = self.causal_mask[None, None, input_pos] mask = mask[..., input_pos] for layer in self.layers: x = layer(x, input_pos, freqs_cis, mask) x = self.norm(x) return x class TransformerBlock(nn.Module): def __init__(self, config: ModelArgs) -> None: super().__init__() self.attention = Attention(config) self.feed_forward = FeedForward(config) self.ffn_norm = RMSNorm(config.dim, eps=config.norm_eps) self.attention_norm = RMSNorm(config.dim, eps=config.norm_eps) self.attention_layer_scale = LayerScale(config.dim, inplace=True) self.ffn_layer_scale = LayerScale(config.dim, inplace=True) def forward(self, x: Tensor, input_pos: Tensor, freqs_cis: Tensor, mask: Tensor) -> Tensor: h = x + self.attention_layer_scale( self.attention(self.attention_norm(x), freqs_cis, mask, input_pos) ) out = h + self.ffn_layer_scale(self.feed_forward(self.ffn_norm(h))) return out class Attention(nn.Module): def __init__(self, config: ModelArgs): super().__init__() assert config.dim % config.n_head == 0 total_head_dim = (config.n_head + 2 * config.n_local_heads) * config.head_dim self.wqkv = nn.Linear(config.dim, total_head_dim, bias=False) self.wo = nn.Linear(config.head_dim * config.n_head, config.dim, bias=False) self.kv_cache = None self.n_head = config.n_head self.head_dim = config.head_dim self.n_local_heads = config.n_local_heads self.dim = config.dim self.attn_dropout_rate = config.attn_dropout_rate self.pos_embed_type = config.pos_embed_type if self.pos_embed_type == "conformer": self.max_relative_position = config.max_relative_position num_pos_embeddings = 2 * config.max_relative_position + 1 self.rel_pos_embeddings = nn.Parameter(torch.zeros(num_pos_embeddings, self.head_dim)) nn.init.normal_(self.rel_pos_embeddings, mean=0.0, std=0.02) def _compute_conformer_pos_scores(self, q: Tensor, seqlen: int) -> Tensor: positions = torch.arange(seqlen, device=q.device) relative_positions = positions.unsqueeze(1) - positions.unsqueeze(0) # [S, S] relative_positions = torch.clamp(relative_positions + self.max_relative_position, 0, 2 * self.max_relative_position) rel_embeddings = self.rel_pos_embeddings[relative_positions] # [S, S, D] q = q.transpose(1, 2) # [B, S, H, D] rel_logits = torch.matmul(q, rel_embeddings.transpose(-2, -1)) # [B, S, H, S] rel_logits = rel_logits.transpose(1, 2) # [B, H, S, S] return rel_logits def forward(self, x: Tensor, freqs_cis: Tensor, mask: Tensor, input_pos: Optional[Tensor] = None) -> Tensor: bsz, seqlen, _ = x.shape kv_size = self.n_local_heads * self.head_dim q, k, v = self.wqkv(x).split([kv_size, kv_size, kv_size], 
dim=-1) context_seqlen = seqlen q = q.view(bsz, seqlen, self.n_head, self.head_dim) k = k.view(bsz, context_seqlen, self.n_local_heads, self.head_dim) v = v.view(bsz, context_seqlen, self.n_local_heads, self.head_dim) if self.pos_embed_type == "rope": q = apply_rotary_emb(q, freqs_cis) k = apply_rotary_emb(k, freqs_cis) q, k, v = map(lambda t: t.transpose(1, 2), (q, k, v)) if self.kv_cache is not None: k, v = self.kv_cache.update(input_pos, k, v) k = k.repeat_interleave(self.n_head // self.n_local_heads, dim=1) v = v.repeat_interleave(self.n_head // self.n_local_heads, dim=1) if self.pos_embed_type == "conformer": scale = 1.0 / math.sqrt(self.head_dim) scores = torch.matmul(q, k.transpose(-2, -1)) * scale rel_scores = self._compute_conformer_pos_scores(q, seqlen) scores = scores + rel_scores if mask is not None: scores = scores.masked_fill(~mask, float("-inf")) attn = F.softmax(scores, dim=-1) if self.attn_dropout_rate > 0 and self.training: attn = F.dropout(attn, p=self.attn_dropout_rate) y = torch.matmul(attn, v) else: y = F.scaled_dot_product_attention( q, k, v, dropout_p=self.attn_dropout_rate if self.training else 0.0, attn_mask=mask, ) y = y.transpose(1, 2).contiguous().view(bsz, seqlen, self.head_dim * self.n_head) y = self.wo(y) return y class FeedForward(nn.Module): def __init__(self, config: ModelArgs) -> None: super().__init__() self.w1 = nn.Linear(config.dim, config.intermediate_size, bias=False) self.w3 = nn.Linear(config.dim, config.intermediate_size, bias=False) self.w2 = nn.Linear(config.intermediate_size, config.dim, bias=False) self.dropout = nn.Dropout(config.dropout_rate) def forward(self, x: Tensor) -> Tensor: return self.w2(self.dropout(F.silu(self.w1(x)) * self.w3(x))) class RMSNorm(nn.Module): def __init__(self, dim: int, eps: float = 1e-5): super().__init__() self.eps = eps self.weight = nn.Parameter(torch.ones(dim)) def _norm(self, x): return x * torch.rsqrt(torch.mean(x * x, dim=-1, keepdim=True) + self.eps) def forward(self, x: Tensor) -> Tensor: output = self._norm(x.float()).type_as(x) return output * self.weight class LayerScale(nn.Module): def __init__(self, dim: int, init_values: Union[float, Tensor] = 1e-2, inplace: bool = False) -> None: super().__init__() self.inplace = inplace self.gamma = nn.Parameter(init_values * torch.ones(dim)) def forward(self, x: Tensor) -> Tensor: return x.mul_(self.gamma) if self.inplace else x * self.gamma class WindowLimitedTransformer(Transformer): """Transformer with window-limited causal attention.""" def __init__( self, config: ModelArgs, input_dim: int = 512, window_size: Optional[int] = None, causal: bool = True, look_ahead_conv: Optional[nn.Module] = None, ): super().__init__(config) self.window_size = window_size self.causal = causal self.channels_first = config.channels_first self.look_ahead_conv = look_ahead_conv if look_ahead_conv is not None else nn.Identity() self.input_proj = nn.Linear(input_dim, config.dim) if input_dim != config.dim else nn.Identity() self.output_proj = nn.Linear(config.dim, input_dim) if input_dim != config.dim else nn.Identity() def make_window_limited_mask(self, max_length: int, x_lens: Optional[Tensor] = None) -> Tensor: if self.causal: mask = torch.tril(torch.ones(max_length, max_length)) row_indices = torch.arange(max_length).view(-1, 1) window_size = self.window_size or max_length valid_range = (row_indices - window_size + 1).clamp(min=0) column_indices = torch.arange(max_length) mask = (column_indices >= valid_range) & mask.bool() else: raise NotImplementedError mask = 
mask.bool()[None, None] return mask def make_mask(self, max_length: int, x_lens: Optional[Tensor] = None) -> Tensor: if self.causal: mask = torch.tril(torch.ones(max_length, max_length)) else: mask = torch.ones(max_length, max_length) mask = mask.bool()[None, None] for i, x_len in enumerate(x_lens): mask[:x_len, i] = 0 mask = mask.bool()[None, None] return mask def forward(self, x: Tensor, x_lens: Optional[Tensor] = None) -> Tensor: if self.channels_first: x = x.transpose(1, 2) x = self.input_proj(x) x = self.look_ahead_conv(x) input_pos = torch.arange(x.shape[1], device=x.device) max_length = x.shape[1] if self.window_size is not None: mask = self.make_window_limited_mask(max_length, x_lens) else: mask = self.make_mask(max_length, x_lens) mask = mask.to(x.device) x = super().forward(x, input_pos, mask) x = self.output_proj(x) if self.channels_first: x = x.transpose(1, 2) return x def precompute_freqs_cis( seq_len: int, n_elem: int, base: int = 10000, dtype: torch.dtype = torch.bfloat16 ) -> Tensor: freqs = 1.0 / (base ** (torch.arange(0, n_elem, 2)[: (n_elem // 2)].float() / n_elem)) t = torch.arange(seq_len, device=freqs.device) freqs = torch.outer(t, freqs) freqs_cis = torch.polar(torch.ones_like(freqs), freqs) cache = torch.stack([freqs_cis.real, freqs_cis.imag], dim=-1) return cache.to(dtype=dtype) def apply_rotary_emb(x: Tensor, freqs_cis: Tensor) -> Tensor: xshaped = x.float().reshape(*x.shape[:-1], -1, 2) freqs_cis = freqs_cis.view(1, xshaped.size(1), 1, xshaped.size(3), 2) x_out2 = torch.stack( [ xshaped[..., 0] * freqs_cis[..., 0] - xshaped[..., 1] * freqs_cis[..., 1], xshaped[..., 1] * freqs_cis[..., 0] + xshaped[..., 0] * freqs_cis[..., 1], ], -1, ) x_out2 = x_out2.flatten(3) return x_out2.type_as(x) def init_weights(m): if isinstance(m, nn.Conv1d): nn.init.trunc_normal_(m.weight, std=0.02) nn.init.constant_(m.bias, 0) # -------------------------------------------------------------------- # Top-level AE # -------------------------------------------------------------------- class EncoderBlock(nn.Module): def __init__( self, dim: int = 16, stride: int = 1, causal: bool = False, n_t_layer: int = 0, transformer_general_config=None, ): super().__init__() conv_class = CausalWNConv1d if causal else WNConv1d transformer_module = ( nn.Identity() if n_t_layer == 0 else WindowLimitedTransformer( causal=causal, input_dim=dim, window_size=512, config=transformer_general_config( n_layer=n_t_layer, n_head=dim // 64, dim=dim, intermediate_size=dim * 3, ), ) ) self.block = nn.Sequential( # three multi‑receptive‑field residual units ResidualUnit(dim // 2, dilation=1, causal=causal), ResidualUnit(dim // 2, dilation=3, causal=causal), ResidualUnit(dim // 2, dilation=9, causal=causal), Snake1d(dim // 2), conv_class(dim // 2, dim, kernel_size=2 * stride, stride=stride, padding=math.ceil(stride / 2)), transformer_module, ) def forward(self, x: Tensor) -> Tensor: return self.block(x) class ResidualUnit(nn.Module): def __init__(self, dim: int = 16, dilation: int = 1, causal: bool = False): super().__init__() conv_class = CausalWNConv1d if causal else WNConv1d pad = ((7 - 1) * dilation) // 2 self.block = nn.Sequential( Snake1d(dim), conv_class(dim, dim, kernel_size=7, dilation=dilation, padding=pad), Snake1d(dim), conv_class(dim, dim, kernel_size=1), ) self.causal = causal def forward(self, x: Tensor) -> Tensor: y = self.block(x) pad = x.shape[-1] - y.shape[-1] if pad > 0: if self.causal: x = x[..., :-pad] else: x = x[..., pad // 2 : -pad // 2] return x + y class Encoder(nn.Module): def __init__( self, 
d_model: int = 64, strides: List[int] = [2, 4, 8, 8], d_latent: int = 64, n_transformer_layers: List[int] = [0, 0, 4, 4], transformer_general_config: Optional[ModelArgs] = None, causal: bool = False, ): super().__init__() conv_class = CausalWNConv1d if causal else WNConv1d layers: List[nn.Module] = [conv_class(1, d_model, kernel_size=7, padding=3)] for stride, n_t_layer in zip(strides, n_transformer_layers): d_model *= 2 layers.append( EncoderBlock( d_model, stride=stride, causal=causal, n_t_layer=n_t_layer, transformer_general_config=transformer_general_config, ) ) layers += [Snake1d(d_model), conv_class(d_model, d_latent, kernel_size=3, padding=1)] self.block = nn.Sequential(*layers) self.enc_dim = d_model def forward(self, x: Tensor) -> Tensor: return self.block(x) class DecoderBlock(nn.Module): def __init__( self, input_dim: int = 16, output_dim: int = 8, stride: int = 1, causal: bool = False, n_t_layer: int = 0, transformer_general_config=None, ): super().__init__() conv_trans_class = CausalWNConvTranspose1d if causal else WNConvTranspose1d transformer_module = ( nn.Identity() if n_t_layer == 0 else WindowLimitedTransformer( causal=causal, input_dim=input_dim, window_size=None, config=transformer_general_config( n_layer=n_t_layer, n_head=input_dim // 64, dim=input_dim, intermediate_size=input_dim * 3, ), ) ) self.block = nn.Sequential( Snake1d(input_dim), conv_trans_class(input_dim, output_dim, kernel_size=2 * stride, stride=stride, padding=math.ceil(stride / 2)), ResidualUnit(output_dim, dilation=1, causal=causal), ResidualUnit(output_dim, dilation=3, causal=causal), ResidualUnit(output_dim, dilation=9, causal=causal), ) def forward(self, x: Tensor) -> Tensor: return self.block(x) class Decoder(nn.Module): def __init__( self, input_channel: int, channels: int, rates: List[int], d_out: int = 1, causal: bool = False, n_transformer_layers: List[int] = [0, 0, 0, 0], transformer_general_config=None, ): super().__init__() conv_class = CausalWNConv1d if causal else WNConv1d layers: List[nn.Module] = [conv_class(input_channel, channels, kernel_size=7, padding=3)] for i, (stride, n_t_layer) in enumerate(zip(rates, n_transformer_layers)): input_dim = channels // 2**i output_dim = channels // 2 ** (i + 1) layers.append( DecoderBlock( input_dim, output_dim, stride, causal=causal, n_t_layer=n_t_layer, transformer_general_config=transformer_general_config, ) ) layers += [Snake1d(output_dim), conv_class(output_dim, d_out, kernel_size=7, padding=3), nn.Tanh()] self.model = nn.Sequential(*layers) def forward(self, x: Tensor) -> Tensor: return self.model(x) class DAC(nn.Module): def __init__( self, encoder_dim: int = 64, encoder_rates: List[int] = [2, 4, 8, 8], latent_dim: Optional[int] = None, decoder_dim: int = 1536, decoder_rates: List[int] = [8, 8, 4, 2], quantizer: Optional[nn.Module] = None, sample_rate: int = 44100, causal: bool = True, encoder_transformer_layers: List[int] = [0, 0, 0, 0], decoder_transformer_layers: List[int] = [0, 0, 0, 0], transformer_general_config=None, ): super().__init__() self.encoder_dim = encoder_dim self.encoder_rates = encoder_rates self.decoder_dim = decoder_dim self.decoder_rates = decoder_rates self.sample_rate = sample_rate if latent_dim is None: latent_dim = encoder_dim * (2 ** len(encoder_rates)) self.latent_dim = latent_dim self.hop_length = int(np.prod(encoder_rates)) self.encoder = Encoder( encoder_dim, encoder_rates, latent_dim, causal=causal, n_transformer_layers=encoder_transformer_layers, transformer_general_config=transformer_general_config, ) 
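        # The quantizer is injected by the caller; build_ae() below wires in a
        # DownsampleResidualVectorQuantize whose pre/post modules are
        # window-limited transformers over the downsampled latent sequence.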
self.quantizer = quantizer self.decoder = Decoder( latent_dim, decoder_dim, decoder_rates, causal=causal, n_transformer_layers=decoder_transformer_layers, transformer_general_config=transformer_general_config, ) self.sample_rate = sample_rate self.apply(init_weights) self.delay = self.get_delay() self.frame_length = self.hop_length * 4 def get_output_length(self, input_length: int) -> int: length = input_length for stride in self.encoder_rates: length = math.ceil(length / stride) return length def get_delay(self) -> int: l_out = self.get_output_length(0) L = l_out layers = [layer for layer in self.modules() if isinstance(layer, (nn.Conv1d, nn.ConvTranspose1d))] for layer in reversed(layers): d = layer.dilation[0] k = layer.kernel_size[0] s = layer.stride[0] if isinstance(layer, nn.ConvTranspose1d): L = ((L - d * (k - 1) - 1) / s) + 1 elif isinstance(layer, nn.Conv1d): L = (L - 1) * s + d * (k - 1) + 1 L = math.ceil(L) l_in = L return (l_in - l_out) // 2 def preprocess(self, audio_data: Tensor, sample_rate: Optional[int]) -> Tensor: if sample_rate is None: sample_rate = self.sample_rate assert sample_rate == self.sample_rate length = audio_data.shape[-1] right_pad = math.ceil(length / self.hop_length) * self.hop_length - length audio_data = F.pad(audio_data, (0, right_pad)) return audio_data def encode( self, audio_data: Tensor, audio_lengths: Optional[Tensor] = None, n_quantizers: Optional[int] = None, **kwargs, ): """Encode audio to quantized code indices.""" if audio_data.ndim == 2: audio_data = audio_data.unsqueeze(1) length = audio_data.shape[-1] right_pad = math.ceil(length / self.frame_length) * self.frame_length - length audio_data = F.pad(audio_data, (0, right_pad)) if audio_lengths is None: audio_lengths = torch.LongTensor([length + right_pad]).to(audio_data.device) z = self.encoder(audio_data) vq_results = self.quantizer(z, n_quantizers, **kwargs) indices = vq_results.codes indices_lens = torch.ceil(audio_lengths / self.frame_length).long() return indices, indices_lens def decode(self, indices: Tensor, feature_lengths: Tensor): """Decode code indices to audio.""" if indices.ndim == 2: indices = indices[None] z = self.quantizer.decode(indices) audio_lengths = feature_lengths * self.frame_length return self.decoder(z), audio_lengths def encode_to_codes(self, audio: Tensor, audio_lengths: Optional[Tensor] = None, n_quantizers: Optional[int] = None, **kw): return self.encode(audio, audio_lengths, n_quantizers, **kw) def decode_codes(self, indices: Tensor, feature_lengths: Tensor): return self.decode(indices, feature_lengths) @torch.no_grad() def encode_zq(self, audio_data: Tensor) -> Tensor: indices, _ = self.encode(audio_data) new_indices = torch.zeros_like(indices) new_indices[:, 0] = torch.clamp(indices[:, 0], max=self.quantizer.semantic_quantizer.codebook_size - 1) new_indices[:, 1:] = torch.clamp(indices[:, 1:], max=self.quantizer.quantizer.codebook_size - 1) z_q_semantic = self.quantizer.semantic_quantizer.from_codes(new_indices[:, :1])[0] z_q_residual = self.quantizer.quantizer.from_codes(new_indices[:, 1:])[0] z_q = z_q_semantic + z_q_residual return z_q @torch.no_grad() def decode_zq(self, z_q: Tensor) -> Tensor: z_q = self.quantizer.post_module(z_q) z_q = self.quantizer.upsample(z_q) return self.decoder(z_q) @property def device(self) -> torch.device: return next(self.parameters()).device @property def dtype(self) -> torch.dtype: return next(self.parameters()).dtype # -------------------------------------------------------------------- # Build helpers # 
-------------------------------------------------------------------- def build_ae(**cfg) -> DAC: """ Factory used by external loaders """ # Shared transformer config for the RVQ pre/post modules q_config = ModelArgs( block_size=4096, n_layer=8, n_head=16, dim=1024, intermediate_size=3072, head_dim=64, norm_eps=1e-5, dropout_rate=0.1, attn_dropout_rate=0.1, channels_first=True ) def make_transformer(): return WindowLimitedTransformer( causal=True, window_size=128, input_dim=1024, config=q_config ) quantizer = DownsampleResidualVectorQuantize( input_dim=1024, n_codebooks=9, codebook_size=1024, codebook_dim=8, quantizer_dropout=0.5, downsample_factor=(2, 2), semantic_codebook_size=4096, pre_module=make_transformer(), post_module=make_transformer(), ) def transformer_general_config(**kw): return ModelArgs( block_size=kw.get("block_size", 16384), n_layer=kw.get("n_layer", 8), n_head=kw.get("n_head", 8), dim=kw.get("dim", 512), intermediate_size=kw.get("intermediate_size", 1536), n_local_heads=kw.get("n_local_heads", -1), head_dim=kw.get("head_dim", 64), rope_base=kw.get("rope_base", 10000), norm_eps=kw.get("norm_eps", 1e-5), dropout_rate=kw.get("dropout_rate", 0.1), attn_dropout_rate=kw.get("attn_dropout_rate", 0.1), channels_first=kw.get("channels_first", True), ) dac = DAC( encoder_dim=64, encoder_rates=[2, 4, 8, 8], latent_dim=1024, decoder_dim=1536, decoder_rates=[8, 8, 4, 2], quantizer=quantizer, sample_rate=44100, causal=True, encoder_transformer_layers=[0, 0, 0, 4], decoder_transformer_layers=[4, 0, 0, 0], transformer_general_config=transformer_general_config, ) return dac __all__ = [ "DAC", "build_ae", "VectorQuantize", "ResidualVectorQuantize", "DownsampleResidualVectorQuantize", ] # ----- BEGIN DAC MIT LICENSE ----- # MIT License # Copyright (c) 2023-present, Descript # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal # in the Software without restriction, including without limitation the rights # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell # copies of the Software, and to permit persons to whom the Software is # furnished to do so, subject to the following conditions: # # The above copyright notice and this permission notice shall be included in all # copies or substantial portions of the Software. # # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. # ----- END DAC MIT LICENSE -----
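

# --------------------------------------------------------------------
# Usage sketch (added for illustration, not part of the adapted sources):
# round-trips one second of noise through build_ae() with untrained
# weights, purely to show the encode/decode API and the expected shapes
# (44.1 kHz, hop_length 512, frame_length 2048, 1 semantic + 9 residual
# codebooks). Output audio is noise since nothing is trained.
# --------------------------------------------------------------------
if __name__ == "__main__":
    model = build_ae().eval()
    audio = torch.randn(1, 1, 44100)  # (B, 1, T) mono at 44.1 kHz
    with torch.no_grad():
        codes, code_lens = model.encode(audio)        # codes: (1, 10, ceil(T / 2048))
        recon, recon_lens = model.decode(codes, code_lens)
    print("codes:", tuple(codes.shape), "recon:", tuple(recon.shape), int(recon_lens[0]))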