# SPDX-FileCopyrightText: 2025 Jordan Darefsky
# SPDX-License-Identifier: Apache-2.0
#
# This file contains portions adapted from:
#   • Descript Audio Codec (DAC) — MIT License (full text appended below)
#   • Fish-Speech S1 DAC Autoencoder — reference implementation (Apache-2.0 / CC-BY-NC),
#     rewritten here in a single-file Torch module for interoperability and transparency.
#
# OVERALL LICENSE (this file): Apache-2.0, except where explicitly marked:
#   # SPDX-License-Identifier: MIT
# Keep these notices and the embedded MIT text if you redistribute this file.
#
# NOTE (style/provenance):
# Code in this module has been largely copy-and-pasted from the Fish-S1-DAC and DAC
# repositories, then refactored with help from ChatGPT/Claude (these models also helped
# with licensing). It therefore differs stylistically from the rest of the codebase
# (internal consistency is not guaranteed) and is likely messier than a from-scratch
# implementation would have been.
from __future__ import annotations

import math
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union

import numpy as np
import torch
from torch import Tensor, nn
from torch.nn import functional as F
from torch.nn.utils.parametrizations import weight_norm
from torch.nn.utils.parametrize import remove_parametrizations

from einops import rearrange

# --------------------------------------------------------------------
# Shared helpers
# --------------------------------------------------------------------
def find_multiple(n: int, k: int) -> int:
    return n if n % k == 0 else n + k - (n % k)


def unpad1d(x: Tensor, paddings: Tuple[int, int]) -> Tensor:
    """Remove padding from x, properly handling zero padding. Only for 1d!"""
    padding_left, padding_right = paddings
    assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right)
    assert (padding_left + padding_right) <= x.shape[-1]
    end = x.shape[-1] - padding_right
    return x[..., padding_left:end]


def get_extra_padding_for_conv1d(
    x: Tensor, kernel_size: int, stride: int, padding_total: int = 0
) -> int:
    """Enough right padding so that strided conv frames evenly cover the length."""
    length = x.shape[-1]
    n_frames = (length - kernel_size + padding_total) / stride + 1
    ideal_length = (math.ceil(n_frames) - 1) * stride + (kernel_size - padding_total)
    return ideal_length - length

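
# Worked example (illustrative): with length=10, kernel_size=4, stride=2 and
# padding_total=0, n_frames = (10 - 4) / 2 + 1 = 4.0, so the input frames evenly
# and no extra padding is needed; with length=11 instead, n_frames = 4.5 and
# ideal_length = (ceil(4.5) - 1) * 2 + 4 = 12, so one extra right sample is padded.
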
def pad1d(
    x: Tensor,
    paddings: Tuple[int, int],
    mode: str = "constant",
    value: float = 0.0,
) -> Tensor:
    """
    Reflect-safe 1D pad: if reflect-padding would underflow on small inputs,
    insert a temporary right zero-pad before reflecting.
    """
    length = x.shape[-1]
    padding_left, padding_right = paddings
    assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right)
    if mode == "reflect":
        max_pad = max(padding_left, padding_right)
        extra_pad = 0
        if length <= max_pad:
            extra_pad = max_pad - length + 1
            x = F.pad(x, (0, extra_pad))
        padded = F.pad(x, (padding_left, padding_right), mode, value)
        end = padded.shape[-1] - extra_pad
        return padded[..., :end]
    else:
        return F.pad(x, (padding_left, padding_right), mode, value)

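
# Illustrative sketch: reflect-padding a length-2 signal by 3 would normally fail
# (PyTorch requires the pad to be smaller than the input), so pad1d zero-extends
# first and trims the temporary samples afterwards:
#
#   x = torch.randn(1, 1, 2)
#   y = pad1d(x, (0, 3), mode="reflect")  # works; y.shape[-1] == 5
#   # F.pad(x, (0, 3), mode="reflect") would raise a RuntimeError here.
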
# --------------------------------------------------------------------
# DAC Layers (adapted) — MIT
# Original: https://github.com/descriptinc/descript-audio-codec/blob/main/dac/nn/layers.py
# SPDX-License-Identifier: MIT
# --------------------------------------------------------------------
def WNConv1d(*args, **kwargs):
    return weight_norm(nn.Conv1d(*args, **kwargs))


def WNConvTranspose1d(*args, **kwargs):
    return weight_norm(nn.ConvTranspose1d(*args, **kwargs))


def snake(x: Tensor, alpha: Tensor) -> Tensor:
    # Snake activation: x + (1/alpha) * sin^2(alpha * x), applied channel-wise.
    shape = x.shape
    x = x.reshape(shape[0], shape[1], -1)
    x = x + (alpha + 1e-9).reciprocal() * torch.sin(alpha * x).pow(2)
    x = x.reshape(shape)
    return x


class Snake1d(nn.Module):
    def __init__(self, channels: int):
        super().__init__()
        self.alpha = nn.Parameter(torch.ones(1, channels, 1))

    def forward(self, x: Tensor) -> Tensor:
        return snake(x, self.alpha)

# --------------------------------------------------------------------
# DAC Vector Quantize (adapted) — MIT
# Original: https://github.com/descriptinc/descript-audio-codec/blob/main/dac/nn/quantize.py
# SPDX-License-Identifier: MIT
# --------------------------------------------------------------------
class VectorQuantize(nn.Module):
    """
    VQ with factorized, l2-normalized codes (ViT-VQGAN style).
    I/O in (B, D, T).
    """

    def __init__(self, input_dim: int, codebook_size: int, codebook_dim: int):
        super().__init__()
        self.codebook_size = codebook_size
        self.codebook_dim = codebook_dim
        self.in_proj = WNConv1d(input_dim, codebook_dim, kernel_size=1)
        self.out_proj = WNConv1d(codebook_dim, input_dim, kernel_size=1)
        self.codebook = nn.Embedding(codebook_size, codebook_dim)

    def forward(self, z: Tensor):
        z_e = self.in_proj(z)  # (B, D, T)
        z_q, indices = self.decode_latents(z_e)
        commitment_loss = F.mse_loss(z_e, z_q.detach(), reduction="none").mean([1, 2])
        codebook_loss = F.mse_loss(z_q, z_e.detach(), reduction="none").mean([1, 2])
        z_q = z_e + (z_q - z_e).detach()  # straight-through estimator
        z_q = self.out_proj(z_q)
        return z_q, commitment_loss, codebook_loss, indices, z_e

    def embed_code(self, embed_id: Tensor) -> Tensor:
        return F.embedding(embed_id, self.codebook.weight)

    def decode_code(self, embed_id: Tensor) -> Tensor:
        return self.embed_code(embed_id).transpose(1, 2)

    def decode_latents(self, latents: Tensor) -> Tuple[Tensor, Tensor]:
        encodings = rearrange(latents, "b d t -> (b t) d")
        codebook = self.codebook.weight
        encodings = F.normalize(encodings)
        codebook = F.normalize(codebook)
        dist = (
            encodings.pow(2).sum(1, keepdim=True)
            - 2 * encodings @ codebook.t()
            + codebook.pow(2).sum(1, keepdim=True).t()
        )
        indices = rearrange((-dist).max(1)[1], "(b t) -> b t", b=latents.size(0))
        z_q = self.decode_code(indices)
        return z_q, indices

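
# Illustrative sketch of the VQ interface (shapes only, random weights):
#
#   vq = VectorQuantize(input_dim=512, codebook_size=1024, codebook_dim=8)
#   z = torch.randn(2, 512, 50)                       # (B, D, T)
#   z_q, commit, codebook, indices, z_e = vq(z)
#   # z_q: (2, 512, 50); indices: (2, 50); z_e: (2, 8, 50)
#
# Because encodings and codebook entries are l2-normalized before the distance
# computation, the nearest-neighbour search is equivalent to maximum cosine
# similarity.
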
class ResidualVectorQuantize(nn.Module):
    """SoundStream-style residual VQ stack."""

    def __init__(
        self,
        input_dim: int = 512,
        n_codebooks: int = 9,
        codebook_size: int = 1024,
        codebook_dim: Union[int, List[int]] = 8,
        quantizer_dropout: float = 0.0,
    ):
        super().__init__()
        if isinstance(codebook_dim, int):
            codebook_dim = [codebook_dim for _ in range(n_codebooks)]
        self.n_codebooks = n_codebooks
        self.codebook_dim = codebook_dim
        self.codebook_size = codebook_size
        self.quantizers = nn.ModuleList([
            VectorQuantize(input_dim, codebook_size, codebook_dim[i])
            for i in range(n_codebooks)
        ])
        self.quantizer_dropout = quantizer_dropout

    def forward(self, z: Tensor, n_quantizers: Optional[int] = None):
        z_q = 0
        residual = z
        commitment_loss = 0
        codebook_loss = 0
        codebook_indices = []
        latents = []
        if n_quantizers is None:
            n_quantizers = self.n_codebooks
        if self.training:
            n_quantizers = torch.ones((z.shape[0],)) * self.n_codebooks + 1
            dropout = torch.randint(1, self.n_codebooks + 1, (z.shape[0],))
            n_dropout = int(z.shape[0] * self.quantizer_dropout)
            n_quantizers[:n_dropout] = dropout[:n_dropout]
            n_quantizers = n_quantizers.to(z.device)
        for i, quantizer in enumerate(self.quantizers):
            if not self.training and i >= n_quantizers:
                break
            z_q_i, commit_i, codebk_i, indices_i, z_e_i = quantizer(residual)
            mask = (torch.full((z.shape[0],), fill_value=i, device=z.device) < n_quantizers)
            z_q = z_q + z_q_i * mask[:, None, None]
            residual = residual - z_q_i
            commitment_loss += (commit_i * mask).mean()
            codebook_loss += (codebk_i * mask).mean()
            codebook_indices.append(indices_i)
            latents.append(z_e_i)
        codes = torch.stack(codebook_indices, dim=1)
        latents = torch.cat(latents, dim=1)
        return z_q, codes, latents, commitment_loss, codebook_loss

    def from_codes(self, codes: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
        z_q = 0.0
        z_p = []
        n_codebooks = codes.shape[1]
        for i in range(n_codebooks):
            z_p_i = self.quantizers[i].decode_code(codes[:, i, :])
            z_p.append(z_p_i)
            z_q_i = self.quantizers[i].out_proj(z_p_i)
            z_q = z_q + z_q_i
        return z_q, torch.cat(z_p, dim=1), codes

    def from_latents(self, latents: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
        z_q = 0
        z_p = []
        codes = []
        dims = np.cumsum([0] + [q.codebook_dim for q in self.quantizers])
        n_codebooks = np.where(dims <= latents.shape[1])[0].max(axis=0, keepdims=True)[0]
        for i in range(n_codebooks):
            j, k = dims[i], dims[i + 1]
            z_p_i, codes_i = self.quantizers[i].decode_latents(latents[:, j:k, :])
            z_p.append(z_p_i)
            codes.append(codes_i)
            z_q_i = self.quantizers[i].out_proj(z_p_i)
            z_q = z_q + z_q_i
        return z_q, torch.cat(z_p, dim=1), torch.stack(codes, dim=1)

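
# Illustrative sketch (shapes only, random weights): each stage quantizes the
# residual left by the previous ones, so reconstructions refine as codebooks
# are added.
#
#   rvq = ResidualVectorQuantize(input_dim=512, n_codebooks=9, codebook_dim=8)
#   z = torch.randn(2, 512, 50)
#   z_q, codes, latents, commit, cb = rvq(z)
#   # z_q: (2, 512, 50); codes: (2, 9, 50); latents: (2, 72, 50)
#   z_q2, z_p, _ = rvq.from_codes(codes)              # decode from indices alone
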
# --------------------------------------------------------------------
# S1 DAC RVQ
# --------------------------------------------------------------------
@dataclass
class VQResult:
    z: Tensor
    codes: Tensor
    latents: Tensor
    codebook_loss: Tensor
    commitment_loss: Tensor
    semantic_distill_z: Optional[Tensor] = None

class CausalConvNet(nn.Module):
    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        dilation=1,
        stride=1,
        groups=1,
        padding=None,  # accepted for API compatibility; causal padding is computed internally
    ):
        super().__init__()
        self.conv = nn.Conv1d(
            in_channels, out_channels, kernel_size,
            stride=stride, dilation=dilation, groups=groups,
        )
        self.stride = stride
        self.kernel_size = (kernel_size - 1) * dilation + 1  # effective receptive field
        self.dilation = dilation
        self.padding = self.kernel_size - self.stride

    def forward(self, x: Tensor) -> Tensor:
        pad = self.padding
        extra = get_extra_padding_for_conv1d(x, self.kernel_size, self.stride, pad)
        x = pad1d(x, (pad, extra), mode="constant", value=0)
        return self.conv(x).contiguous()

    def weight_norm(self, name="weight", dim=0):
        self.conv = weight_norm(self.conv, name=name, dim=dim)
        return self

    def remove_weight_norm(self):
        self.conv = remove_parametrizations(self.conv, "weight")
        return self


class CausalTransConvNet(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, dilation=1, stride=1, padding=None):
        super().__init__()
        self.conv = nn.ConvTranspose1d(
            in_channels, out_channels, kernel_size,
            stride=stride, dilation=dilation,
        )
        self.stride = stride
        self.kernel_size = kernel_size

    def forward(self, x: Tensor) -> Tensor:
        x = self.conv(x)
        pad = self.kernel_size - self.stride
        padding_right = math.ceil(pad)
        padding_left = pad - padding_right
        x = unpad1d(x, (padding_left, padding_right))
        return x.contiguous()

    def weight_norm(self, name="weight", dim=0):
        self.conv = weight_norm(self.conv, name=name, dim=dim)
        return self

    def remove_weight_norm(self):
        self.conv = remove_parametrizations(self.conv, "weight")
        return self


def CausalWNConv1d(*args, **kwargs):
    return CausalConvNet(*args, **kwargs).weight_norm()


def CausalWNConvTranspose1d(*args, **kwargs):
    return CausalTransConvNet(*args, **kwargs).weight_norm()

class ConvNeXtBlock(nn.Module):
    r"""ConvNeXt block (1D).

    DwConv -> permute to (N, L, C) -> LayerNorm -> Linear -> GELU -> Linear
    -> permute back to (N, C, L), with a residual connection.
    """

    def __init__(
        self,
        dim: int,
        layer_scale_init_value: float = 1e-6,
        mlp_ratio: float = 4.0,
        kernel_size: int = 7,
        dilation: int = 1,
    ):
        super().__init__()
        convnet_type = CausalConvNet
        self.dwconv = convnet_type(
            dim, dim, kernel_size=kernel_size,
            groups=dim, dilation=dilation,
        )  # depthwise conv
        self.norm = nn.LayerNorm(dim, eps=1e-6)
        self.pwconv1 = nn.Linear(dim, int(mlp_ratio * dim))
        self.act = nn.GELU()
        self.pwconv2 = nn.Linear(int(mlp_ratio * dim), dim)
        self.gamma = (
            nn.Parameter(layer_scale_init_value * torch.ones(dim), requires_grad=True)
            if layer_scale_init_value > 0 else None
        )

    def forward(self, x: Tensor, apply_residual: bool = True) -> Tensor:
        inp = x
        x = self.dwconv(x)
        x = x.permute(0, 2, 1)  # (N, C, L) -> (N, L, C)
        x = self.norm(x)
        x = self.pwconv1(x)
        x = self.act(x)
        x = self.pwconv2(x)
        if self.gamma is not None:
            x = self.gamma * x
        x = x.permute(0, 2, 1)  # (N, L, C) -> (N, C, L)
        if apply_residual:
            x = inp + x
        return x

class DownsampleResidualVectorQuantize(nn.Module):
    def __init__(
        self,
        input_dim: int = 1024,
        n_codebooks: int = 9,
        codebook_dim: int = 8,
        quantizer_dropout: float = 0.5,
        codebook_size: int = 1024,
        semantic_codebook_size: int = 4096,
        downsample_factor: Tuple[int, ...] = (2, 2),
        downsample_dims: Optional[Tuple[int, ...]] = None,
        pre_module: Optional[nn.Module] = None,
        post_module: Optional[nn.Module] = None,
        semantic_predictor_module: Optional[nn.Module] = None,
    ):
        super().__init__()
        if downsample_dims is None:
            downsample_dims = tuple(input_dim for _ in range(len(downsample_factor)))
        all_dims = (input_dim,) + tuple(downsample_dims)
        self.semantic_quantizer = ResidualVectorQuantize(
            input_dim=input_dim,
            n_codebooks=1,
            codebook_size=semantic_codebook_size,
            codebook_dim=codebook_dim,
            quantizer_dropout=0.0,
        )
        self.quantizer = ResidualVectorQuantize(
            input_dim=input_dim,
            n_codebooks=n_codebooks,
            codebook_size=codebook_size,
            codebook_dim=codebook_dim,
            quantizer_dropout=quantizer_dropout,
        )
        convnet_type = CausalConvNet
        transconvnet_type = CausalTransConvNet
        self.downsample = nn.Sequential(
            *[
                nn.Sequential(
                    convnet_type(all_dims[idx], all_dims[idx + 1], kernel_size=factor, stride=factor),
                    ConvNeXtBlock(dim=all_dims[idx + 1]),
                )
                for idx, factor in enumerate(downsample_factor)
            ]
        )
        self.upsample = nn.Sequential(
            *[
                nn.Sequential(
                    transconvnet_type(all_dims[idx + 1], all_dims[idx], kernel_size=factor, stride=factor),
                    ConvNeXtBlock(dim=all_dims[idx]),
                )
                for idx, factor in reversed(list(enumerate(downsample_factor)))
            ]
        )
        # Init conv/linear weights first so externally supplied modules keep their own init.
        self.apply(self._init_weights)
        self.pre_module = pre_module if pre_module is not None else nn.Identity()
        self.post_module = post_module if post_module is not None else nn.Identity()
        self.semantic_predictor_module = (
            semantic_predictor_module if semantic_predictor_module is not None else nn.Identity()
        )

    @staticmethod
    def _init_weights(m):
        if isinstance(m, (nn.Conv1d, nn.Linear)):
            nn.init.trunc_normal_(m.weight, std=0.02)
            if getattr(m, "bias", None) is not None:
                nn.init.constant_(m.bias, 0)

    def forward(self, z: Tensor, n_quantizers: Optional[int] = None, semantic_len: Optional[Tensor] = None, **kwargs):
        # z: (B, D, T)
        original_shape = z.shape
        if semantic_len is None:
            semantic_len = torch.LongTensor([z.shape[-1]])
        z = self.downsample(z)
        z = self.pre_module(z)  # channels-first in/out, matching the conv stack
        semantic_z, semantic_codes, semantic_latents, semantic_commitment_loss, semantic_codebook_loss = \
            self.semantic_quantizer(z)
        residual_z = z - semantic_z
        residual_z, codes, latents, commitment_loss, codebook_loss = self.quantizer(residual_z, n_quantizers=n_quantizers)
        z = semantic_z + residual_z
        commitment_loss = commitment_loss + semantic_commitment_loss
        codebook_loss = codebook_loss + semantic_codebook_loss
        codes = torch.cat([semantic_codes, codes], dim=1)
        latents = torch.cat([semantic_latents, latents], dim=1)
        z = self.post_module(z)
        z = self.upsample(z)
        # Pad or crop z to match the original time dimension.
        diff = original_shape[-1] - z.shape[-1]
        right = 0
        left = abs(diff) - right
        if diff > 0:
            z = F.pad(z, (left, right))
        elif diff < 0:
            z = z[..., left:]
        return VQResult(
            z=z, codes=codes, latents=latents,
            commitment_loss=commitment_loss, codebook_loss=codebook_loss,
        )

    def decode(self, indices: Tensor) -> Tensor:
        new_indices = torch.zeros_like(indices)
        new_indices[:, 0] = torch.clamp(indices[:, 0], max=self.semantic_quantizer.codebook_size - 1)
        new_indices[:, 1:] = torch.clamp(indices[:, 1:], max=self.quantizer.codebook_size - 1)
        z_q_semantic = self.semantic_quantizer.from_codes(new_indices[:, :1])[0]
        z_q_residual = self.quantizer.from_codes(new_indices[:, 1:])[0]
        z_q = z_q_semantic + z_q_residual
        z_q = self.post_module(z_q)
        z_q = self.upsample(z_q)
        return z_q

# --------------------------------------------------------------------
# Transformer stack
# --------------------------------------------------------------------
@dataclass
class ModelArgs:
    block_size: int = 2048
    n_layer: int = 8
    n_head: int = 8
    dim: int = 512
    intermediate_size: int = 1536
    n_local_heads: int = -1
    head_dim: int = 64
    rope_base: float = 10000
    norm_eps: float = 1e-5
    dropout_rate: float = 0.1
    attn_dropout_rate: float = 0.1
    channels_first: bool = True  # to be compatible with conv1d input/output
    pos_embed_type: str = "rope"  # "rope" or "conformer"
    max_relative_position: int = 128

    def __post_init__(self):
        if self.n_local_heads == -1:
            self.n_local_heads = self.n_head
        if self.intermediate_size is None:
            hidden_dim = 4 * self.dim
            n_hidden = int(2 * hidden_dim / 3)
            self.intermediate_size = find_multiple(n_hidden, 256)
        assert self.pos_embed_type in ["rope", "conformer"]

class KVCache(nn.Module):
    def __init__(self, max_batch_size, max_seq_length, n_heads, head_dim, dtype=torch.bfloat16):
        super().__init__()
        cache_shape = (max_batch_size, n_heads, max_seq_length, head_dim)
        self.register_buffer("k_cache", torch.zeros(cache_shape, dtype=dtype))
        self.register_buffer("v_cache", torch.zeros(cache_shape, dtype=dtype))

    def update(self, input_pos: Tensor, k_val: Tensor, v_val: Tensor):
        # input_pos: [S], k_val: [B, H, S, D]
        assert input_pos.shape[0] == k_val.shape[2]
        k_out = self.k_cache
        v_out = self.v_cache
        k_out[:, :, input_pos] = k_val
        v_out[:, :, input_pos] = v_val
        return (
            k_out[:, :, : input_pos.max() + 1, :],
            v_out[:, :, : input_pos.max() + 1, :],
        )

    def clear_cache(self, prompt_len: int):
        self.k_cache[:, :, prompt_len:, :].fill_(0)
        self.v_cache[:, :, prompt_len:, :].fill_(0)

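
# Illustrative decode-time flow (a sketch; see setup_caches below): caches are
# preallocated once per layer, then each step writes its new key/value slice
# and attends over everything cached so far:
#
#   model.setup_caches(max_batch_size=1, max_seq_length=2048)
#   # inside Attention.forward: k, v = self.kv_cache.update(input_pos, k, v)
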
class Transformer(nn.Module):
    def __init__(self, config: ModelArgs) -> None:
        super().__init__()
        self.config = config
        self.layers = nn.ModuleList(TransformerBlock(config) for _ in range(config.n_layer))
        self.norm = RMSNorm(config.dim, eps=config.norm_eps)
        if config.pos_embed_type == "rope":
            freqs_cis = precompute_freqs_cis(self.config.block_size, self.config.head_dim, self.config.rope_base)
            self.register_buffer("freqs_cis", freqs_cis)
        else:
            self.register_buffer("freqs_cis", None)
        causal_mask = torch.tril(torch.ones(self.config.block_size, self.config.block_size, dtype=torch.bool))
        self.register_buffer("causal_mask", causal_mask)
        self.max_batch_size = -1
        self.max_seq_length = -1
        self.use_kv_cache = False

    def setup_caches(self, max_batch_size, max_seq_length):
        head_dim = self.config.dim // self.config.n_head
        max_seq_length = find_multiple(max_seq_length, 8)
        self.max_seq_length = max_seq_length
        self.max_batch_size = max_batch_size
        dtype = self.norm.weight.dtype
        device = self.norm.weight.device
        for b in self.layers:
            b.attention.kv_cache = KVCache(
                max_batch_size, max_seq_length, self.config.n_local_heads, head_dim, dtype
            ).to(device)
        self.use_kv_cache = True

    def forward(self, x: Tensor, input_pos: Optional[Tensor] = None, mask: Optional[Tensor] = None) -> Tensor:
        if self.config.pos_embed_type == "rope":
            assert self.freqs_cis is not None
            freqs_cis = self.freqs_cis[input_pos]
        else:
            freqs_cis = None
        if mask is None:
            if not self.training and self.use_kv_cache:
                mask = self.causal_mask[None, None, input_pos]
                mask = mask[..., : input_pos.max() + 1]
            else:
                mask = self.causal_mask[None, None, input_pos]
                mask = mask[..., input_pos]
        for layer in self.layers:
            x = layer(x, input_pos, freqs_cis, mask)
        x = self.norm(x)
        return x

class TransformerBlock(nn.Module):
    def __init__(self, config: ModelArgs) -> None:
        super().__init__()
        self.attention = Attention(config)
        self.feed_forward = FeedForward(config)
        self.ffn_norm = RMSNorm(config.dim, eps=config.norm_eps)
        self.attention_norm = RMSNorm(config.dim, eps=config.norm_eps)
        self.attention_layer_scale = LayerScale(config.dim, inplace=True)
        self.ffn_layer_scale = LayerScale(config.dim, inplace=True)

    def forward(self, x: Tensor, input_pos: Tensor, freqs_cis: Tensor, mask: Tensor) -> Tensor:
        h = x + self.attention_layer_scale(
            self.attention(self.attention_norm(x), freqs_cis, mask, input_pos)
        )
        out = h + self.ffn_layer_scale(self.feed_forward(self.ffn_norm(h)))
        return out

class Attention(nn.Module):
    def __init__(self, config: ModelArgs):
        super().__init__()
        assert config.dim % config.n_head == 0
        total_head_dim = (config.n_head + 2 * config.n_local_heads) * config.head_dim
        self.wqkv = nn.Linear(config.dim, total_head_dim, bias=False)
        self.wo = nn.Linear(config.head_dim * config.n_head, config.dim, bias=False)
        self.kv_cache = None
        self.n_head = config.n_head
        self.head_dim = config.head_dim
        self.n_local_heads = config.n_local_heads
        self.dim = config.dim
        self.attn_dropout_rate = config.attn_dropout_rate
        self.pos_embed_type = config.pos_embed_type
        if self.pos_embed_type == "conformer":
            self.max_relative_position = config.max_relative_position
            num_pos_embeddings = 2 * config.max_relative_position + 1
            self.rel_pos_embeddings = nn.Parameter(torch.zeros(num_pos_embeddings, self.head_dim))
            nn.init.normal_(self.rel_pos_embeddings, mean=0.0, std=0.02)

    def _compute_conformer_pos_scores(self, q: Tensor, seqlen: int) -> Tensor:
        positions = torch.arange(seqlen, device=q.device)
        relative_positions = positions.unsqueeze(1) - positions.unsqueeze(0)  # [S, S]
        relative_positions = torch.clamp(
            relative_positions + self.max_relative_position,
            0, 2 * self.max_relative_position,
        )
        rel_embeddings = self.rel_pos_embeddings[relative_positions]  # [S, S, D]
        q = q.transpose(1, 2)  # [B, S, H, D]
        rel_logits = torch.matmul(q, rel_embeddings.transpose(-2, -1))  # [B, S, H, S]
        rel_logits = rel_logits.transpose(1, 2)  # [B, H, S, S]
        return rel_logits

    def forward(self, x: Tensor, freqs_cis: Tensor, mask: Tensor, input_pos: Optional[Tensor] = None) -> Tensor:
        bsz, seqlen, _ = x.shape
        q_size = self.n_head * self.head_dim
        kv_size = self.n_local_heads * self.head_dim
        q, k, v = self.wqkv(x).split([q_size, kv_size, kv_size], dim=-1)
        context_seqlen = seqlen
        q = q.view(bsz, seqlen, self.n_head, self.head_dim)
        k = k.view(bsz, context_seqlen, self.n_local_heads, self.head_dim)
        v = v.view(bsz, context_seqlen, self.n_local_heads, self.head_dim)
        if self.pos_embed_type == "rope":
            q = apply_rotary_emb(q, freqs_cis)
            k = apply_rotary_emb(k, freqs_cis)
        q, k, v = map(lambda t: t.transpose(1, 2), (q, k, v))
        if self.kv_cache is not None:
            k, v = self.kv_cache.update(input_pos, k, v)
        k = k.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
        v = v.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
        if self.pos_embed_type == "conformer":
            scale = 1.0 / math.sqrt(self.head_dim)
            scores = torch.matmul(q, k.transpose(-2, -1)) * scale
            rel_scores = self._compute_conformer_pos_scores(q, seqlen)
            scores = scores + rel_scores
            if mask is not None:
                scores = scores.masked_fill(~mask, float("-inf"))
            attn = F.softmax(scores, dim=-1)
            if self.attn_dropout_rate > 0 and self.training:
                attn = F.dropout(attn, p=self.attn_dropout_rate)
            y = torch.matmul(attn, v)
        else:
            y = F.scaled_dot_product_attention(
                q, k, v,
                dropout_p=self.attn_dropout_rate if self.training else 0.0,
                attn_mask=mask,
            )
        y = y.transpose(1, 2).contiguous().view(bsz, seqlen, self.head_dim * self.n_head)
        y = self.wo(y)
        return y

class FeedForward(nn.Module):
    def __init__(self, config: ModelArgs) -> None:
        super().__init__()
        self.w1 = nn.Linear(config.dim, config.intermediate_size, bias=False)
        self.w3 = nn.Linear(config.dim, config.intermediate_size, bias=False)
        self.w2 = nn.Linear(config.intermediate_size, config.dim, bias=False)
        self.dropout = nn.Dropout(config.dropout_rate)

    def forward(self, x: Tensor) -> Tensor:
        return self.w2(self.dropout(F.silu(self.w1(x)) * self.w3(x)))


class RMSNorm(nn.Module):
    def __init__(self, dim: int, eps: float = 1e-5):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        return x * torch.rsqrt(torch.mean(x * x, dim=-1, keepdim=True) + self.eps)

    def forward(self, x: Tensor) -> Tensor:
        output = self._norm(x.float()).type_as(x)
        return output * self.weight


class LayerScale(nn.Module):
    def __init__(self, dim: int, init_values: Union[float, Tensor] = 1e-2, inplace: bool = False) -> None:
        super().__init__()
        self.inplace = inplace
        self.gamma = nn.Parameter(init_values * torch.ones(dim))

    def forward(self, x: Tensor) -> Tensor:
        return x.mul_(self.gamma) if self.inplace else x * self.gamma

class WindowLimitedTransformer(Transformer):
    """Transformer with window-limited causal attention."""

    def __init__(
        self,
        config: ModelArgs,
        input_dim: int = 512,
        window_size: Optional[int] = None,
        causal: bool = True,
        look_ahead_conv: Optional[nn.Module] = None,
    ):
        super().__init__(config)
        self.window_size = window_size
        self.causal = causal
        self.channels_first = config.channels_first
        self.look_ahead_conv = look_ahead_conv if look_ahead_conv is not None else nn.Identity()
        self.input_proj = nn.Linear(input_dim, config.dim) if input_dim != config.dim else nn.Identity()
        self.output_proj = nn.Linear(config.dim, input_dim) if input_dim != config.dim else nn.Identity()

    def make_window_limited_mask(self, max_length: int, x_lens: Optional[Tensor] = None) -> Tensor:
        if self.causal:
            mask = torch.tril(torch.ones(max_length, max_length))
            row_indices = torch.arange(max_length).view(-1, 1)
            window_size = self.window_size or max_length
            valid_range = (row_indices - window_size + 1).clamp(min=0)
            column_indices = torch.arange(max_length)
            mask = (column_indices >= valid_range) & mask.bool()
        else:
            raise NotImplementedError
        mask = mask.bool()[None, None]
        return mask

    def make_mask(self, max_length: int, x_lens: Optional[Tensor] = None) -> Tensor:
        if self.causal:
            mask = torch.tril(torch.ones(max_length, max_length))
        else:
            mask = torch.ones(max_length, max_length)
        mask = mask.bool()[None, None]  # (1, 1, T, T)
        if x_lens is not None:
            # Disallow attention to padded positions beyond each sequence's length.
            valid = (
                torch.arange(max_length, device=x_lens.device)[None, :] < x_lens[:, None]
            )  # (B, T)
            mask = mask.to(x_lens.device) & valid[:, None, None, :]
        return mask

    def forward(self, x: Tensor, x_lens: Optional[Tensor] = None) -> Tensor:
        if self.channels_first:
            x = x.transpose(1, 2)
        x = self.input_proj(x)
        x = self.look_ahead_conv(x)
        input_pos = torch.arange(x.shape[1], device=x.device)
        max_length = x.shape[1]
        if self.window_size is not None:
            mask = self.make_window_limited_mask(max_length, x_lens)
        else:
            mask = self.make_mask(max_length, x_lens)
        mask = mask.to(x.device)
        x = super().forward(x, input_pos, mask)
        x = self.output_proj(x)
        if self.channels_first:
            x = x.transpose(1, 2)
        return x

def precompute_freqs_cis(
    seq_len: int, n_elem: int, base: int = 10000, dtype: torch.dtype = torch.bfloat16
) -> Tensor:
    freqs = 1.0 / (base ** (torch.arange(0, n_elem, 2)[: (n_elem // 2)].float() / n_elem))
    t = torch.arange(seq_len, device=freqs.device)
    freqs = torch.outer(t, freqs)
    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
    cache = torch.stack([freqs_cis.real, freqs_cis.imag], dim=-1)
    return cache.to(dtype=dtype)


def apply_rotary_emb(x: Tensor, freqs_cis: Tensor) -> Tensor:
    xshaped = x.float().reshape(*x.shape[:-1], -1, 2)
    freqs_cis = freqs_cis.view(1, xshaped.size(1), 1, xshaped.size(3), 2)
    x_out2 = torch.stack(
        [
            xshaped[..., 0] * freqs_cis[..., 0] - xshaped[..., 1] * freqs_cis[..., 1],
            xshaped[..., 1] * freqs_cis[..., 0] + xshaped[..., 0] * freqs_cis[..., 1],
        ],
        -1,
    )
    x_out2 = x_out2.flatten(3)
    return x_out2.type_as(x)

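
# Rotary embeddings treat each consecutive channel pair as a complex number and
# rotate it by a position-dependent angle: (x0 + i*x1) * (cos t + i*sin t). The
# (real, imag) cache built by precompute_freqs_cis supplies cos/sin per position
# and frequency; the two stacked products in apply_rotary_emb are exactly that
# complex multiplication written out.
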
def init_weights(m):
    if isinstance(m, nn.Conv1d):
        nn.init.trunc_normal_(m.weight, std=0.02)
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)

# --------------------------------------------------------------------
# Top-level AE
# --------------------------------------------------------------------
class EncoderBlock(nn.Module):
    def __init__(
        self,
        dim: int = 16,
        stride: int = 1,
        causal: bool = False,
        n_t_layer: int = 0,
        transformer_general_config=None,
    ):
        super().__init__()
        conv_class = CausalWNConv1d if causal else WNConv1d
        transformer_module = (
            nn.Identity()
            if n_t_layer == 0
            else WindowLimitedTransformer(
                causal=causal,
                input_dim=dim,
                window_size=512,
                config=transformer_general_config(
                    n_layer=n_t_layer,
                    n_head=dim // 64,
                    dim=dim,
                    intermediate_size=dim * 3,
                ),
            )
        )
        self.block = nn.Sequential(
            # three multi-receptive-field residual units
            ResidualUnit(dim // 2, dilation=1, causal=causal),
            ResidualUnit(dim // 2, dilation=3, causal=causal),
            ResidualUnit(dim // 2, dilation=9, causal=causal),
            Snake1d(dim // 2),
            conv_class(dim // 2, dim, kernel_size=2 * stride, stride=stride, padding=math.ceil(stride / 2)),
            transformer_module,
        )

    def forward(self, x: Tensor) -> Tensor:
        return self.block(x)


class ResidualUnit(nn.Module):
    def __init__(self, dim: int = 16, dilation: int = 1, causal: bool = False):
        super().__init__()
        conv_class = CausalWNConv1d if causal else WNConv1d
        pad = ((7 - 1) * dilation) // 2
        self.block = nn.Sequential(
            Snake1d(dim),
            conv_class(dim, dim, kernel_size=7, dilation=dilation, padding=pad),
            Snake1d(dim),
            conv_class(dim, dim, kernel_size=1),
        )
        self.causal = causal

    def forward(self, x: Tensor) -> Tensor:
        y = self.block(x)
        pad = x.shape[-1] - y.shape[-1]
        if pad > 0:
            if self.causal:
                x = x[..., :-pad]
            else:
                x = x[..., pad // 2 : -pad // 2]
        return x + y

class Encoder(nn.Module):
    def __init__(
        self,
        d_model: int = 64,
        strides: List[int] = [2, 4, 8, 8],
        d_latent: int = 64,
        n_transformer_layers: List[int] = [0, 0, 4, 4],
        transformer_general_config: Optional[ModelArgs] = None,
        causal: bool = False,
    ):
        super().__init__()
        conv_class = CausalWNConv1d if causal else WNConv1d
        layers: List[nn.Module] = [conv_class(1, d_model, kernel_size=7, padding=3)]
        for stride, n_t_layer in zip(strides, n_transformer_layers):
            d_model *= 2
            layers.append(
                EncoderBlock(
                    d_model, stride=stride, causal=causal,
                    n_t_layer=n_t_layer, transformer_general_config=transformer_general_config,
                )
            )
        layers += [Snake1d(d_model), conv_class(d_model, d_latent, kernel_size=3, padding=1)]
        self.block = nn.Sequential(*layers)
        self.enc_dim = d_model

    def forward(self, x: Tensor) -> Tensor:
        return self.block(x)

class DecoderBlock(nn.Module):
    def __init__(
        self,
        input_dim: int = 16,
        output_dim: int = 8,
        stride: int = 1,
        causal: bool = False,
        n_t_layer: int = 0,
        transformer_general_config=None,
    ):
        super().__init__()
        conv_trans_class = CausalWNConvTranspose1d if causal else WNConvTranspose1d
        transformer_module = (
            nn.Identity()
            if n_t_layer == 0
            else WindowLimitedTransformer(
                causal=causal,
                input_dim=input_dim,
                window_size=None,
                config=transformer_general_config(
                    n_layer=n_t_layer,
                    n_head=input_dim // 64,
                    dim=input_dim,
                    intermediate_size=input_dim * 3,
                ),
            )
        )
        self.block = nn.Sequential(
            transformer_module,  # applied before upsampling, mirroring EncoderBlock
            Snake1d(input_dim),
            conv_trans_class(input_dim, output_dim, kernel_size=2 * stride, stride=stride, padding=math.ceil(stride / 2)),
            ResidualUnit(output_dim, dilation=1, causal=causal),
            ResidualUnit(output_dim, dilation=3, causal=causal),
            ResidualUnit(output_dim, dilation=9, causal=causal),
        )

    def forward(self, x: Tensor) -> Tensor:
        return self.block(x)


class Decoder(nn.Module):
    def __init__(
        self,
        input_channel: int,
        channels: int,
        rates: List[int],
        d_out: int = 1,
        causal: bool = False,
        n_transformer_layers: List[int] = [0, 0, 0, 0],
        transformer_general_config=None,
    ):
        super().__init__()
        conv_class = CausalWNConv1d if causal else WNConv1d
        layers: List[nn.Module] = [conv_class(input_channel, channels, kernel_size=7, padding=3)]
        for i, (stride, n_t_layer) in enumerate(zip(rates, n_transformer_layers)):
            input_dim = channels // 2**i
            output_dim = channels // 2 ** (i + 1)
            layers.append(
                DecoderBlock(
                    input_dim, output_dim, stride, causal=causal,
                    n_t_layer=n_t_layer, transformer_general_config=transformer_general_config,
                )
            )
        layers += [Snake1d(output_dim), conv_class(output_dim, d_out, kernel_size=7, padding=3), nn.Tanh()]
        self.model = nn.Sequential(*layers)

    def forward(self, x: Tensor) -> Tensor:
        return self.model(x)

class DAC(nn.Module):
    def __init__(
        self,
        encoder_dim: int = 64,
        encoder_rates: List[int] = [2, 4, 8, 8],
        latent_dim: Optional[int] = None,
        decoder_dim: int = 1536,
        decoder_rates: List[int] = [8, 8, 4, 2],
        quantizer: Optional[nn.Module] = None,
        sample_rate: int = 44100,
        causal: bool = True,
        encoder_transformer_layers: List[int] = [0, 0, 0, 0],
        decoder_transformer_layers: List[int] = [0, 0, 0, 0],
        transformer_general_config=None,
    ):
        super().__init__()
        self.encoder_dim = encoder_dim
        self.encoder_rates = encoder_rates
        self.decoder_dim = decoder_dim
        self.decoder_rates = decoder_rates
        self.sample_rate = sample_rate
        if latent_dim is None:
            latent_dim = encoder_dim * (2 ** len(encoder_rates))
        self.latent_dim = latent_dim
        self.hop_length = int(np.prod(encoder_rates))
        self.encoder = Encoder(
            encoder_dim, encoder_rates, latent_dim, causal=causal,
            n_transformer_layers=encoder_transformer_layers,
            transformer_general_config=transformer_general_config,
        )
        self.quantizer = quantizer
        self.decoder = Decoder(
            latent_dim, decoder_dim, decoder_rates, causal=causal,
            n_transformer_layers=decoder_transformer_layers,
            transformer_general_config=transformer_general_config,
        )
        self.apply(init_weights)
        self.delay = self.get_delay()
        self.frame_length = self.hop_length * 4  # hop_length x RVQ downsample factor (2 * 2 in build_ae)

    def get_output_length(self, input_length: int) -> int:
        length = input_length
        for stride in self.encoder_rates:
            length = math.ceil(length / stride)
        return length

    def get_delay(self) -> int:
        l_out = self.get_output_length(0)
        L = l_out
        layers = [layer for layer in self.modules() if isinstance(layer, (nn.Conv1d, nn.ConvTranspose1d))]
        for layer in reversed(layers):
            d = layer.dilation[0]
            k = layer.kernel_size[0]
            s = layer.stride[0]
            if isinstance(layer, nn.ConvTranspose1d):
                L = ((L - d * (k - 1) - 1) / s) + 1
            elif isinstance(layer, nn.Conv1d):
                L = (L - 1) * s + d * (k - 1) + 1
            L = math.ceil(L)
        l_in = L
        return (l_in - l_out) // 2

    def preprocess(self, audio_data: Tensor, sample_rate: Optional[int]) -> Tensor:
        if sample_rate is None:
            sample_rate = self.sample_rate
        assert sample_rate == self.sample_rate
        length = audio_data.shape[-1]
        right_pad = math.ceil(length / self.hop_length) * self.hop_length - length
        audio_data = F.pad(audio_data, (0, right_pad))
        return audio_data

    def encode(
        self,
        audio_data: Tensor,
        audio_lengths: Optional[Tensor] = None,
        n_quantizers: Optional[int] = None,
        **kwargs,
    ):
        """Encode audio to quantized code indices."""
        if audio_data.ndim == 2:
            audio_data = audio_data.unsqueeze(1)
        length = audio_data.shape[-1]
        right_pad = math.ceil(length / self.frame_length) * self.frame_length - length
        audio_data = F.pad(audio_data, (0, right_pad))
        if audio_lengths is None:
            audio_lengths = torch.LongTensor([length + right_pad]).to(audio_data.device)
        z = self.encoder(audio_data)
        vq_results = self.quantizer(z, n_quantizers, **kwargs)
        indices = vq_results.codes
        indices_lens = torch.ceil(audio_lengths / self.frame_length).long()
        return indices, indices_lens

    def decode(self, indices: Tensor, feature_lengths: Tensor):
        """Decode code indices to audio."""
        if indices.ndim == 2:
            indices = indices[None]
        z = self.quantizer.decode(indices)
        audio_lengths = feature_lengths * self.frame_length
        return self.decoder(z), audio_lengths

    def encode_to_codes(self, audio: Tensor, audio_lengths: Optional[Tensor] = None, n_quantizers: Optional[int] = None, **kw):
        return self.encode(audio, audio_lengths, n_quantizers, **kw)

    def decode_codes(self, indices: Tensor, feature_lengths: Tensor):
        return self.decode(indices, feature_lengths)

    def encode_zq(self, audio_data: Tensor) -> Tensor:
        indices, _ = self.encode(audio_data)
        new_indices = torch.zeros_like(indices)
        new_indices[:, 0] = torch.clamp(indices[:, 0], max=self.quantizer.semantic_quantizer.codebook_size - 1)
        new_indices[:, 1:] = torch.clamp(indices[:, 1:], max=self.quantizer.quantizer.codebook_size - 1)
        z_q_semantic = self.quantizer.semantic_quantizer.from_codes(new_indices[:, :1])[0]
        z_q_residual = self.quantizer.quantizer.from_codes(new_indices[:, 1:])[0]
        z_q = z_q_semantic + z_q_residual
        return z_q

    def decode_zq(self, z_q: Tensor) -> Tensor:
        z_q = self.quantizer.post_module(z_q)
        z_q = self.quantizer.upsample(z_q)
        return self.decoder(z_q)

    def device(self) -> torch.device:
        return next(self.parameters()).device

    def dtype(self) -> torch.dtype:
        return next(self.parameters()).dtype

# --------------------------------------------------------------------
# Build helpers
# --------------------------------------------------------------------
def build_ae(**cfg) -> DAC:
    """Factory used by external loaders."""
    # Shared transformer config for the RVQ pre/post modules
    q_config = ModelArgs(
        block_size=4096, n_layer=8, n_head=16, dim=1024,
        intermediate_size=3072, head_dim=64, norm_eps=1e-5,
        dropout_rate=0.1, attn_dropout_rate=0.1, channels_first=True,
    )

    def make_transformer():
        return WindowLimitedTransformer(
            causal=True, window_size=128, input_dim=1024, config=q_config
        )

    quantizer = DownsampleResidualVectorQuantize(
        input_dim=1024, n_codebooks=9, codebook_size=1024, codebook_dim=8,
        quantizer_dropout=0.5, downsample_factor=(2, 2),
        semantic_codebook_size=4096,
        pre_module=make_transformer(),
        post_module=make_transformer(),
    )

    def transformer_general_config(**kw):
        return ModelArgs(
            block_size=kw.get("block_size", 16384),
            n_layer=kw.get("n_layer", 8),
            n_head=kw.get("n_head", 8),
            dim=kw.get("dim", 512),
            intermediate_size=kw.get("intermediate_size", 1536),
            n_local_heads=kw.get("n_local_heads", -1),
            head_dim=kw.get("head_dim", 64),
            rope_base=kw.get("rope_base", 10000),
            norm_eps=kw.get("norm_eps", 1e-5),
            dropout_rate=kw.get("dropout_rate", 0.1),
            attn_dropout_rate=kw.get("attn_dropout_rate", 0.1),
            channels_first=kw.get("channels_first", True),
        )

    dac = DAC(
        encoder_dim=64, encoder_rates=[2, 4, 8, 8], latent_dim=1024,
        decoder_dim=1536, decoder_rates=[8, 8, 4, 2],
        quantizer=quantizer, sample_rate=44100, causal=True,
        encoder_transformer_layers=[0, 0, 0, 4],
        decoder_transformer_layers=[4, 0, 0, 0],
        transformer_general_config=transformer_general_config,
    )
    return dac

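
# Illustrative end-to-end sketch (random weights here, so the output is only
# meaningful after loading trained parameters):
#
#   dac = build_ae().eval()
#   audio = torch.randn(1, 1, 44100)                  # ~1 s at 44.1 kHz
#   with torch.no_grad():
#       codes, code_lens = dac.encode(audio)          # codes: (B, 1 + 9, T_codes)
#       recon, audio_lens = dac.decode(codes, code_lens)
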
__all__ = [
    "DAC",
    "build_ae",
    "VectorQuantize",
    "ResidualVectorQuantize",
    "DownsampleResidualVectorQuantize",
]

# ----- BEGIN DAC MIT LICENSE -----
# MIT License
# Copyright (c) 2023-present, Descript
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
# ----- END DAC MIT LICENSE -----