| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | """PyTorch MossAudioTokenizer model.""" |
| |
|
| | from __future__ import annotations |
| |
|
| | import copy |
| | import math |
| | from contextlib import ExitStack, contextmanager |
| | from dataclasses import dataclass |
| | from typing import cast |
| |
|
| | import torch |
| | import torch.nn as nn |
| | import torch.nn.functional as F |
| |
|
| | from transformers.modeling_utils import PreTrainedAudioTokenizerBase |
| | from transformers.utils import ModelOutput, auto_docstring, logging |
| | from .configuration_moss_audio_tokenizer import MossAudioTokenizerConfig |
| |
|
| |
|
| | logger = logging.get_logger(__name__) |
| |
|
| |
|
| | |
| | |
| | |
| |
|
| |
|
@dataclass
@auto_docstring
class MossAudioTokenizerEncoderOutput(ModelOutput):
    r"""
    audio_codes (`torch.LongTensor` of shape `(num_quantizers, batch_size, sequence_length)`, *optional*):
        Discrete audio codes computed using the encoder and quantizer.
    audio_codes_lengths (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
        Valid lengths for each sample's audio codes.
    encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, hidden_size, sequence_length)`, *optional*):
        Hidden states from the encoder before quantization.
    """

    # All fields default to None so the output can be built field-by-field;
    # `ModelOutput` exposes them both as attributes and by key.
    audio_codes: torch.Tensor | None = None
    audio_codes_lengths: torch.Tensor | None = None
    encoder_hidden_states: torch.Tensor | None = None
| |
|
| |
|
@dataclass
@auto_docstring
class MossAudioTokenizerDecoderOutput(ModelOutput):
    r"""
    audio (`torch.FloatTensor` of shape `(batch_size, channels, sequence_length)`, *optional*):
        Decoded audio waveform.
    audio_lengths (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
        Valid lengths for each sample's audio.
    """

    # Both fields optional: lengths are only meaningful for padded batches.
    audio: torch.Tensor | None = None
    audio_lengths: torch.Tensor | None = None
| |
|
| |
|
@dataclass
@auto_docstring
class MossAudioTokenizerOutput(ModelOutput):
    r"""
    audio (`torch.FloatTensor` of shape `(batch_size, channels, sequence_length)`, *optional*):
        Decoded audio waveform.
    audio_lengths (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
        Valid lengths for each sample's audio.
    audio_codes (`torch.LongTensor` of shape `(num_quantizers, batch_size, sequence_length)`, *optional*):
        Discrete audio codes computed using the encoder and quantizer.
    audio_codes_lengths (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
        Valid lengths for each sample's audio codes.
    """

    # Combined encode+decode output: waveform fields plus the code fields.
    audio: torch.Tensor | None = None
    audio_lengths: torch.Tensor | None = None
    audio_codes: torch.Tensor | None = None
    audio_codes_lengths: torch.Tensor | None = None
| |
|
| |
|
| | |
| | |
| | |
| |
|
| |
|
@dataclass
class StreamingState:
    """Base state for streaming modules.

    Tracks, per batch element, whether the stream is currently "executing"
    via `exec_mask`. Subclasses add module-specific buffers (offsets, KV
    caches) and extend `reset`. Instances are context managers so they can
    be registered on an `ExitStack` (see `StreamingModule._start_streaming`).
    """

    batch_size: int
    device: torch.device

    def __post_init__(self):
        # Every batch entry starts out active.
        self.exec_mask = torch.ones(self.batch_size, dtype=torch.bool, device=self.device)

    def set_exec_mask(self, exec_mask: torch.Tensor):
        # In-place copy keeps the same tensor object alive (CUDA-graph friendly).
        self.exec_mask[:] = exec_mask

    def reset(self, reset_mask: torch.Tensor) -> None:
        # Entries selected by `reset_mask` become active again; the others
        # keep their current execution state.
        self.exec_mask[:] = torch.where(reset_mask, torch.ones_like(self.exec_mask), self.exec_mask)

    def __enter__(self):

        return self

    def __exit__(self, exc_type, exc_value, traceback) -> None:
        # Nothing to release; context-manager protocol is needed only so the
        # state can be pushed onto an ExitStack.
        pass
| |
|
| |
|
class StreamingModule(nn.Module):
    """Base class for streaming components.

    Subclasses override `_init_streaming_state` to allocate per-stream
    buffers. `streaming(batch_size)` returns an `ExitStack` that puts every
    (non-detached) streaming descendant into streaming mode and tears it all
    down when the stack exits.
    """

    def __init__(self) -> None:
        super().__init__()
        # Current per-stream state; None means "not streaming".
        self._streaming_state: StreamingState | None = None
        # When True, this module opts out of a parent's streaming traversal.
        self._streaming_detached: bool = False
        # Lazily-built cache of (qualified name, module) streaming descendants.
        self._cached_children: list[tuple[str, StreamingModule]] | None = None

    @property
    def is_streaming(self):
        return self._streaming_state is not None

    def _apply_named_streaming(self, fn):
        # Depth-first traversal over streaming descendants; the result is
        # cached after the first walk so repeated start/stop calls avoid
        # re-traversing the module tree.
        def _handle_module(prefix: str, module: nn.Module):
            if isinstance(module, StreamingModule):
                # Detached modules are skipped unless they are `self`
                # (prefix == "" only for the root of the traversal).
                if module._streaming_detached and prefix != "":
                    return
                if self._cached_children is None:
                    raise RuntimeError("Internal error: _cached_children should be initialized before traversal.")
                self._cached_children.append((prefix, module))
            for name, child in module.named_children():
                new_prefix = f"{prefix}.{name}" if prefix else name
                _handle_module(new_prefix, child)

        if self._cached_children is None:
            self._cached_children = []
            _handle_module("", self)
        for name, child in self._cached_children:
            fn(name, child)

    def _start_streaming(self, batch_size: int, exit_stack: ExitStack):
        def _start_streaming_fn(name: str, module: StreamingModule):
            if module._streaming_state is not None:
                raise RuntimeError(f"{name} is already streaming!")
            state = module._init_streaming_state(batch_size)
            # States are context managers; register them for cleanup on exit.
            exit_stack.enter_context(state)
            module._streaming_state = state

        self._apply_named_streaming(_start_streaming_fn)

    def _stop_streaming(self) -> None:
        def _stop_streaming_fn(name: str, module: StreamingModule):
            module._streaming_state = None

        self._apply_named_streaming(_stop_streaming_fn)

    def _init_streaming_state(self, batch_size: int) -> StreamingState:
        # Default state carries only the exec mask. Assumes the module owns at
        # least one parameter to infer the device from.
        device = next(iter(self.parameters())).device
        return StreamingState(batch_size, device)

    def streaming(self, batch_size: int) -> ExitStack:
        """Context manager to enter streaming mode."""
        exit_stack = ExitStack()
        self._start_streaming(batch_size, exit_stack)
        # Guarantee streaming state is cleared even if the body raises.
        exit_stack.callback(self._stop_streaming)
        return exit_stack
| |
|
| |
|
class StreamingContainer(StreamingModule):
    """A plain grouping module whose streaming children are managed together."""
| |
|
| |
|
| | |
| | |
| | |
| |
|
| |
|
| | class MossAudioTokenizerRMSNorm(nn.Module): |
| | """Root Mean Square Layer Normalization.""" |
| |
|
| | def __init__( |
| | self, |
| | dim: int, |
| | eps: float = 1e-5, |
| | dtype: torch.dtype | None = None, |
| | device=None, |
| | ): |
| | super().__init__() |
| | self.eps = eps |
| | self.dtype = dtype |
| | self.alpha = nn.Parameter(torch.full((1, 1, dim), 1.0, requires_grad=True, device=device, dtype=dtype)) |
| |
|
| | def forward(self, x: torch.Tensor): |
| | x_dtype = x.dtype |
| | if self.dtype is not None: |
| | x = x.to(self.dtype) |
| | var = self.eps + torch.mean(x**2, dim=2, keepdim=True) |
| | y = (x * (self.alpha.to(var) * torch.rsqrt(var))).to(x_dtype) |
| | return y |
| |
|
| |
|
class MossAudioTokenizerLayerScale(nn.Module):
    """Per-channel learned residual scaling (Touvron et al., 2021)."""

    def __init__(
        self,
        channels: int,
        init: float = 1e-4,
        channel_last: bool = True,
        device=None,
        dtype=None,
    ):
        super().__init__()
        self.channel_last = channel_last
        # One learned scalar per channel, initialised to a small value.
        self.scale = nn.Parameter(torch.full((channels,), init, requires_grad=True, device=device, dtype=dtype))

    def forward(self, x: torch.Tensor):
        # Broadcast over (B, T, C) when channel-last, else over (B, C, T).
        if not self.channel_last:
            return self.scale[:, None] * x
        return self.scale * x
| |
|
| |
|
def create_norm_fn(norm_type: str, dim: int, **kwargs) -> nn.Module:
    """Create normalization module."""
    if norm_type == "layer_norm":
        return nn.LayerNorm(dim, eps=1e-5, **kwargs)
    if norm_type == "rms_norm":
        return MossAudioTokenizerRMSNorm(dim, eps=1e-5, **kwargs)
    if norm_type == "rms_norm_f32":
        # The f32 variant always runs its statistics in float32.
        kwargs.pop("dtype", None)
        return MossAudioTokenizerRMSNorm(dim, eps=1e-8, dtype=torch.float, **kwargs)
    raise ValueError(f"Unknown norm type: {norm_type}")
| |
|
| |
|
| | |
| | |
| | |
| |
|
| |
|
def apply_rope(
    q: torch.Tensor,
    k: torch.Tensor,
    offset: torch.Tensor,
    max_period: float = 10_000,
    time_before_heads: bool = False,
):
    """Apply rotary position embedding."""
    if time_before_heads:
        B, T, H, D = q.shape
    else:
        B, H, T, D = q.shape
    if k.shape != q.shape:
        raise ValueError(f"Expected k.shape == q.shape, got k={tuple(k.shape)} q={tuple(q.shape)}")
    if D <= 0 or (D % 2) != 0:
        raise ValueError(f"RoPE requires an even last dimension, got D={D}")

    # Inverse frequency per dimension pair, and absolute time per batch entry.
    half_idx = torch.arange(D // 2, device=q.device, dtype=torch.float32)
    inv_freq = torch.exp(half_idx * (-math.log(max_period) * 2 / D))
    times = offset.float().view(-1, 1) + torch.arange(T, device=q.device, dtype=torch.float32)
    times = times.view(B, -1, 1, 1) if time_before_heads else times.view(B, 1, -1, 1)

    lead_dims = q.shape[:-1]
    q = q.view(*lead_dims, D // 2, 2)
    k = k.view(*lead_dims, D // 2, 2)

    # Treat consecutive pairs as complex (real, imag) and rotate in float32.
    q_re, q_im = q[..., 0].float(), q[..., 1].float()
    k_re, k_im = k[..., 0].float(), k[..., 1].float()

    cos = torch.cos(inv_freq * times)
    sin = torch.sin(inv_freq * times)

    out_dtype = q.dtype
    q_out = torch.stack(
        [(q_re * cos - q_im * sin).to(out_dtype), (q_re * sin + q_im * cos).to(out_dtype)], dim=-1
    )
    k_out = torch.stack(
        [(k_re * cos - k_im * sin).to(out_dtype), (k_re * sin + k_im * cos).to(out_dtype)], dim=-1
    )

    return q_out.view(*lead_dims, D), k_out.view(*lead_dims, D)
| |
|
| |
|
class MossAudioTokenizerRotaryEmbedding(nn.Module):
    """Rotary positional embedding (RoPE).

    Thin, stateless wrapper around `apply_rope` that stores the `max_period`
    hyper-parameter so a single instance can be shared across layers.
    """

    def __init__(self, max_period: float = 10000.0):
        super().__init__()
        self.max_period = max_period

    def forward(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        offset: torch.Tensor,
        time_before_heads: bool = False,
    ):
        # Returns the rotated (q, k) pair; see `apply_rope` for shape rules.
        return apply_rope(q, k, offset, self.max_period, time_before_heads)
| |
|
| |
|
| | |
| | |
| | |
| |
|
| |
|
class MossAudioTokenizerActivationGating(nn.Module):
    """Gated feed-forward layer: activation(gate) * value, then projection."""

    def __init__(self, dim: int, dim_feedforward: int, activation, **factory_kwargs):
        super().__init__()
        # Choose the gated hidden width so the parameter count roughly matches
        # a plain FFN of width `dim_feedforward`.
        if dim_feedforward == 4 * dim:
            hidden = (21 * dim) // 8
        else:
            hidden = (2 * dim_feedforward) // 3

        self.linear_in = nn.Linear(dim, 2 * hidden, bias=False, **factory_kwargs)
        self.linear_out = nn.Linear(hidden, dim, bias=False, **factory_kwargs)
        self.activation = activation

    def forward(self, x: torch.Tensor):
        batch, steps, _ = x.shape
        # Single projection produces gate and value halves side by side.
        projected = self.linear_in(x).view(batch, steps, 2, -1)
        gate, value = projected[..., 0, :], projected[..., 1, :]
        return self.linear_out(self.activation(gate) * value)
| |
|
| |
|
| | def _get_activation(name: str): |
| | if name in ["sigmoid", "tanh", "relu"]: |
| | return getattr(torch, name) |
| | elif name in ["leaky_relu", "elu", "gelu", "silu", "mish", "softsign"]: |
| | return getattr(F, name) |
| | elif name == "identity": |
| | return nn.Identity() |
| | else: |
| | raise ValueError(f"Unknown activation {name}") |
| |
|
| |
|
def make_gating(name: str, dim: int, dim_feedforward: int, **factory_kwargs) -> nn.Module:
    # Convenience wrapper: resolve the activation by name and build a gated FFN.
    return MossAudioTokenizerActivationGating(dim, dim_feedforward, _get_activation(name), **factory_kwargs)
| |
|
| |
|
| | |
| | |
| | |
| |
|
| |
|
def create_sin_embedding(
    positions: torch.Tensor,
    dim: int,
    max_period: float = 10000,
    dtype: torch.dtype = torch.float32,
) -> torch.Tensor:
    """Create sinusoidal positional embedding with shape [B, T, C]."""
    if dim % 2 != 0:
        raise ValueError(f"Sinusoidal embedding requires even dim, got dim={dim}")
    half_dim = dim // 2
    if half_dim <= 1:
        raise ValueError(f"Sinusoidal embedding requires dim >= 4, got dim={dim}")
    positions = positions.to(dtype)
    # Geometric progression of wavelengths from 1 up to `max_period`.
    channel_idx = torch.arange(half_dim, device=positions.device, dtype=dtype).view(1, 1, -1)
    period = torch.full([], max_period, device=positions.device, dtype=dtype)
    phase = positions / (period ** (channel_idx / (half_dim - 1)))
    return torch.cat([torch.cos(phase), torch.sin(phase)], dim=-1)
| |
|
| |
|
| | |
| | |
| | |
| |
|
| |
|
class KVCacheResult:
    """Keys/values plus per-slot absolute positions; unpacks like a 3-tuple."""

    __slots__ = ("keys", "values", "positions")

    def __init__(self, keys: torch.Tensor, values: torch.Tensor, positions: torch.Tensor):
        self.keys = keys
        self.values = values
        self.positions = positions

    def __iter__(self):
        """Allow unpacking as (keys, values, positions)."""
        yield self.keys
        yield self.values
        yield self.positions

    @staticmethod
    def from_kv(keys: torch.Tensor, values: torch.Tensor) -> KVCacheResult:
        # Un-cached case: positions are simply 0..T-1 for every batch entry.
        batch, _, timesteps, _ = keys.shape
        positions = torch.arange(timesteps, device=keys.device, dtype=torch.long)
        return KVCacheResult(keys, values, positions.expand(batch, -1))
| |
|
| |
|
class RingKVCache:
    """Efficient streaming KVCache compatible with CUDA Graph.

    Keys and values live in a fixed-size ring buffer of length `capacity`;
    writes wrap around, and `complete` reconstructs each slot's absolute
    position so the attention mask can tell valid entries from stale or
    never-written ones.
    """

    def __init__(
        self,
        batch_size: int,
        num_heads: int,
        dim_per_head: int,
        capacity: int,
        respect_exec_mask: bool = True,
        device: torch.device = torch.device("cuda"),
        dtype: torch.dtype = torch.bfloat16,
    ):
        self.capacity = capacity
        # cache[0] holds keys, cache[1] holds values.
        self.cache = torch.zeros(
            (2, batch_size, num_heads, capacity, dim_per_head),
            device=device,
            dtype=dtype,
        )
        self.respect_exec_mask = respect_exec_mask
        if self.respect_exec_mask:
            # Per-sample write offsets: samples may advance independently.
            self.end_offset = torch.zeros(batch_size, device=device, dtype=torch.long)
        else:
            # One shared offset for the whole batch.
            self.end_offset = torch.zeros(1, device=device, dtype=torch.long)

    def reset(self, reset_mask: torch.Tensor) -> None:
        # Zero the offsets of the selected streams; buffer contents stay in
        # place and simply become unreachable via the position bookkeeping.
        self.end_offset[:] = torch.where(reset_mask, torch.zeros_like(self.end_offset), self.end_offset)

    def complete(self, k: torch.Tensor, v: torch.Tensor, exec_mask: torch.Tensor) -> KVCacheResult:
        """Append k/v of shape (B, H, T, D) and return the full cache with positions."""
        B, H, T, D = k.shape
        if T <= 0:
            raise ValueError(f"Expected T > 0, got T={T}")

        # Ring-buffer slot indices for the T incoming steps, per sample.
        indexes = torch.arange(T, device=self.end_offset.device, dtype=self.end_offset.dtype)
        indexes = indexes + self.end_offset.view(-1, 1)
        indexes = indexes % self.capacity

        if self.respect_exec_mask:
            # Per-sample slots: scatter along the time dimension.
            this_indexes = indexes.view(B, 1, T, 1).expand(-1, H, T, D)
            self.cache[0].scatter_(2, this_indexes, k)
            self.cache[1].scatter_(2, this_indexes, v)
        else:
            # Shared offset: identical slots apply to every sample.
            self.cache[0].index_copy_(2, indexes[0], k)
            self.cache[1].index_copy_(2, indexes[0], v)

        keys = self.cache[0]
        values = self.cache[1]

        # Recover each slot's absolute position from the offset of the most
        # recently written entry.
        indexes = torch.arange(self.capacity, device=self.end_offset.device, dtype=torch.long)
        last_offset = self.end_offset.view(-1, 1) + T - 1
        end_index = last_offset % self.capacity
        delta = indexes - end_index

        # Slots at or before the write head belong to the current wrap; slots
        # after it were written one wrap (capacity steps) earlier.
        positions = torch.where(
            delta <= 0,
            last_offset + delta,
            last_offset + delta - self.capacity,
        )

        if self.respect_exec_mask:
            # Only advance streams that actually executed this step.
            self.end_offset[:] = torch.where(exec_mask, self.end_offset + T, self.end_offset)
        else:
            self.end_offset.add_(T)

        # Mark never-written slots with position -1 so the mask can drop them.
        invalid = indexes >= self.end_offset.view(-1, 1)
        positions = torch.where(invalid, torch.full_like(positions, -1), positions)

        return KVCacheResult(keys, values, positions)
| |
|
| |
|
| | |
| | |
| | |
| |
|
| |
|
@dataclass
class MHAState(StreamingState):
    """Streaming state for multi-head attention."""

    # Ring buffer of past keys/values; None disables caching.
    kv_cache: RingKVCache | None
    # Per-sample absolute time offset, device tensor of shape (batch_size,).
    offset: torch.Tensor
    # Scalar offset mirrored on CPU, used to pick per-step weights.
    offset_cpu: int

    def reset(self, reset_mask: torch.Tensor):
        super().reset(reset_mask)
        self.offset[:] = torch.where(reset_mask, torch.zeros_like(self.offset), self.offset)
        if self.kv_cache is not None:
            self.kv_cache.reset(reset_mask)
        # NOTE(review): the CPU offset is reset unconditionally even for a
        # partial reset mask — presumably callers reset all streams together.
        self.offset_cpu = 0
| |
|
| |
|
| | def apply_weights_per_step( |
| | modules: nn.ModuleList, |
| | schedule: list[int] | None, |
| | x: torch.Tensor, |
| | offset: int | None, |
| | ) -> torch.Tensor: |
| | """Apply different weights for each time step.""" |
| | if len(modules) == 1: |
| | return modules[0](x) |
| |
|
| | if offset is None: |
| | raise ValueError("offset must be provided when using per-step weights (len(modules) > 1).") |
| | ys = [] |
| | B, T, C = x.shape |
| | for t in range(T): |
| | module_index = t + offset |
| | if schedule is not None: |
| | if module_index >= len(schedule) or module_index < 0: |
| | raise ValueError( |
| | f"weights_per_step_schedule is too short for module_index={module_index} (len={len(schedule)})." |
| | ) |
| | module_index = schedule[module_index] |
| | if module_index >= len(modules) or module_index < 0: |
| | raise ValueError(f"module_index={module_index} out of range for len(modules)={len(modules)}.") |
| | y = modules[module_index](x[:, t : t + 1]) |
| | ys.append(y) |
| | return torch.cat(ys, 1) |
| |
|
| |
|
class MossAudioTokenizerMultiheadAttention(StreamingModule):
    """Multi-head attention with streaming support.

    Supports causal masking with an optional bounded context window, rotary
    position embeddings, and optional per-time-step projection weights
    (`weights_per_step` / `weights_per_step_schedule`).
    """

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        causal: bool = False,
        context: int | None = None,
        rope: MossAudioTokenizerRotaryEmbedding | None = None,
        weights_per_step: int = 0,
        weights_per_step_schedule: list[int] | None = None,
        device=None,
        dtype=None,
    ):
        super().__init__()
        factory_kwargs = {"device": device, "dtype": dtype}

        self.embed_dim = embed_dim
        self.causal = causal
        self.context = context
        self.rope = rope
        self.num_heads = num_heads
        self.weights_per_step = weights_per_step
        self.weights_per_step_schedule = weights_per_step_schedule

        # Fused q/k/v projection output size.
        out_dim = 3 * embed_dim
        # Number of distinct weight sets (1 unless per-step weights are used).
        mult = 1
        if weights_per_step:
            mult = max(weights_per_step_schedule) + 1 if weights_per_step_schedule else weights_per_step
        self.mult = mult

        self.out_projs = nn.ModuleList(
            [nn.Linear(embed_dim, embed_dim, bias=False, **factory_kwargs) for _ in range(mult)]
        )
        self.in_projs = nn.ModuleList(
            [nn.Linear(embed_dim, out_dim, bias=False, **factory_kwargs) for _ in range(mult)]
        )

        # Remap legacy fused checkpoint weights into the per-step module lists.
        self._register_load_state_dict_pre_hook(self._load_hook, with_module=True)

    @staticmethod
    def _load_hook(module, state_dict, prefix, *_):
        # Split legacy fused projection weights (and their "_scb" companions)
        # into `mult` per-step entries, then drop the fused key.
        mappings = {
            "in_proj_weight": "in_projs.{i}.weight",
            "in_proj.weight": "in_projs.{i}.weight",
            "out_proj.weight": "out_projs.{i}.weight",
        }
        mult = module.mult
        for suffix in ["", "_scb"]:
            for source, target in mappings.items():
                this_source = prefix + source + suffix
                if this_source in state_dict:
                    weight = state_dict[this_source]
                    _, *OD = weight.shape
                    weight = weight.view(mult, -1, *OD)
                    for i in range(mult):
                        state_dict[prefix + target.format(i=i) + suffix] = weight[i]
                    state_dict.pop(this_source)

    def _init_streaming_state(self, batch_size: int) -> MHAState:
        in_proj = cast(nn.Linear, self.in_projs[0])
        device = cast(torch.device, in_proj.weight.device)
        dtype = cast(torch.dtype, in_proj.weight.dtype)

        dim_per_head = self.embed_dim // self.num_heads
        # Cache capacity: the context window when bounded, otherwise the
        # per-step horizon (if per-step weights are used) or a fixed default.
        if self.context is None:
            capacity = self.weights_per_step if self.weights_per_step else 1024
        else:
            capacity = self.context

        kv_cache = RingKVCache(
            batch_size,
            self.num_heads,
            dim_per_head,
            capacity,
            respect_exec_mask=not self.weights_per_step,
            device=cast(torch.device, device),
            dtype=cast(torch.dtype, dtype),
        )
        return MHAState(
            batch_size,
            cast(torch.device, device),
            kv_cache,
            offset=torch.zeros(batch_size, device=cast(torch.device, device), dtype=torch.long),
            offset_cpu=0,
        )

    def _complete_kv(self, k, v) -> KVCacheResult:
        # Outside streaming (or without a cache) just wrap the fresh k/v.
        state = cast(MHAState | None, self._streaming_state)
        if state is None:
            return KVCacheResult.from_kv(k, v)
        if state.kv_cache is None:
            return KVCacheResult.from_kv(k, v)
        return state.kv_cache.complete(k, v, state.exec_mask)

    def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor):
        # NOTE(review): `key`/`value` arguments are ignored — q, k and v are
        # all derived from `query` via the fused in-projection (self-attention).
        state = cast(MHAState | None, self._streaming_state)
        B, T = query.shape[:2]

        if state is None:
            offset = torch.zeros(B, device=query.device, dtype=torch.long)
            offset_cpu = 0
        else:
            offset = state.offset
            offset_cpu = state.offset_cpu

        # Fused projection -> (3, B, H, T, D) -> q, k, v.
        projected = apply_weights_per_step(self.in_projs, self.weights_per_step_schedule, query, offset_cpu)
        dim_per_head = self.embed_dim // self.num_heads
        projected = projected.reshape(B, T, 3, self.num_heads, dim_per_head).permute(2, 0, 3, 1, 4)
        q, k, v = projected[0], projected[1], projected[2]

        if self.rope:
            q, k = self.rope(q, k, offset, time_before_heads=False)

        # Merge with cached keys/values; pos_k < 0 marks invalid cache slots.
        k, v, pos_k = self._complete_kv(k, v)
        pos_k = pos_k[:, None]

        if self.causal:
            # Attend only to valid positions at or before the query position,
            # optionally restricted to a sliding window of `context` steps.
            pos_q = offset.view(-1, 1, 1) + torch.arange(T, device=q.device, dtype=torch.long).view(-1, 1)
            delta = pos_q - pos_k
            attn_bias = (pos_k >= 0) & (delta >= 0)
            if self.context is not None:
                attn_bias = attn_bias & (delta < self.context)
            attn_bias = attn_bias[:, None]
        else:
            attn_bias = None

        x = F.scaled_dot_product_attention(q, k, v, attn_bias, dropout_p=0.0)
        x = x.transpose(1, 2).reshape(B, T, self.embed_dim)
        x = apply_weights_per_step(self.out_projs, self.weights_per_step_schedule, x, offset_cpu)

        if state is not None:
            # Advance offsets only for streams that executed this step.
            state.offset[:] = torch.where(state.exec_mask, state.offset + T, state.offset)
            state.offset_cpu += T
        return x
| |
|
| |
|
| | |
| | |
| | |
| |
|
| |
|
@dataclass
class LayerState(StreamingState):
    """Streaming state for a transformer layer."""

    # Number of time steps already processed; used to select per-step weights.
    offset_cpu: int = 0

    def reset(self, reset_mask: torch.Tensor):
        super().reset(reset_mask)
        # NOTE(review): reset unconditionally even for a partial mask —
        # presumably all streams of a layer are reset together.
        self.offset_cpu = 0
| |
|
| |
|
class MossAudioTokenizerTransformerLayer(StreamingModule):
    """Transformer layer with streaming support.

    Pre-norm self-attention followed by a feed-forward block. The FFN is
    either a plain linear1/activation/linear2 stack or a gated FFN, optionally
    with distinct weights per time step.
    """

    def __init__(
        self,
        d_model: int,
        num_heads: int,
        dim_feedforward: int = 2048,
        causal: bool = False,
        context: int | None = None,
        rope: MossAudioTokenizerRotaryEmbedding | None = None,
        norm: str = "layer_norm",
        layer_scale: float | None = None,
        gating: str = "none",
        weights_per_step: int = 0,
        weights_per_step_schedule: list[int] | None = None,
        activation=F.gelu,
        device=None,
        dtype=None,
    ):
        super().__init__()
        factory_kwargs = {"device": device, "dtype": dtype}

        self.self_attn = MossAudioTokenizerMultiheadAttention(
            embed_dim=d_model,
            num_heads=num_heads,
            causal=causal,
            context=context,
            rope=rope,
            weights_per_step=weights_per_step,
            weights_per_step_schedule=weights_per_step_schedule,
            **factory_kwargs,
        )
        self.norm1 = create_norm_fn(norm, d_model, **factory_kwargs)
        self.norm2 = create_norm_fn(norm, d_model, **factory_kwargs)

        self.weights_per_step = weights_per_step
        self.weights_per_step_schedule = weights_per_step_schedule
        # Exactly one FFN variant is populated below: either `gating`, or the
        # (`linear1`, `linear2`) pair.
        self.gating: nn.Module | nn.ModuleList | None = None
        self.linear1: nn.Module | None = None
        self.linear2: nn.Module | None = None
        self.activation = activation

        # Number of distinct per-step weight sets for the gated FFN.
        num_weights = 1
        if weights_per_step:
            num_weights = max(weights_per_step_schedule) + 1 if weights_per_step_schedule else weights_per_step

        if gating == "none":
            # Plain FFN path.
            self.linear1 = nn.Linear(d_model, dim_feedforward, bias=False, **factory_kwargs)
            self.linear2 = nn.Linear(dim_feedforward, d_model, bias=False, **factory_kwargs)
        else:
            if weights_per_step:
                # One gated FFN per weight set; `dim_feedforward` may already
                # be a per-step list.
                dim_ff_list = [dim_feedforward] * num_weights if isinstance(dim_feedforward, int) else dim_feedforward
                self.gating = nn.ModuleList(
                    [make_gating(gating, d_model, dim, **factory_kwargs) for dim in dim_ff_list]
                )
            else:
                self.gating = make_gating(gating, d_model, dim_feedforward, **factory_kwargs)

        # Optional LayerScale on each residual branch.
        if layer_scale is None:
            self.layer_scale_1 = nn.Identity()
            self.layer_scale_2 = nn.Identity()
        else:
            self.layer_scale_1 = MossAudioTokenizerLayerScale(
                channels=d_model, init=layer_scale, channel_last=True, **cast(dict[str, object], factory_kwargs)
            )
            self.layer_scale_2 = MossAudioTokenizerLayerScale(
                channels=d_model, init=layer_scale, channel_last=True, **cast(dict[str, object], factory_kwargs)
            )

    def _init_streaming_state(self, batch_size: int) -> LayerState:
        device = next(iter(self.parameters())).device
        return LayerState(batch_size, device, offset_cpu=0)

    def _ff_block(self, x: torch.Tensor) -> torch.Tensor:
        # Feed-forward residual branch (pre-norm).
        state = self._streaming_state
        offset = state.offset_cpu if isinstance(state, LayerState) else 0

        x_orig = x
        x = self.norm2(x)

        if self.gating is None:
            assert self.linear1 is not None
            assert self.linear2 is not None
            update = self.linear2(self.activation(self.linear1(x)))
        else:
            if self.weights_per_step:
                assert isinstance(self.gating, nn.ModuleList)
                update = apply_weights_per_step(self.gating, self.weights_per_step_schedule, x, offset)
            else:
                update = self.gating(x)
        # `.to(update)` aligns dtypes when the norm computes in a wider dtype.
        return x_orig.to(update) + self.layer_scale_2(update)

    def _sa_block(self, x: torch.Tensor):
        # Self-attention residual branch (pre-norm).
        x_orig = x
        x = self.norm1(x)
        update = self.self_attn(x, x, x)
        return x_orig.to(update) + self.layer_scale_1(update)

    def forward(self, x: torch.Tensor):
        x = self._sa_block(x)
        x = self._ff_block(x)
        state = self._streaming_state
        if state is not None:
            assert isinstance(state, LayerState)
            # Track consumed time steps for per-step weight selection.
            state.offset_cpu += x.shape[1]
        return x
| |
|
| |
|
| | |
| | |
| | |
| |
|
| |
|
@dataclass
class TransformerState(StreamingState):
    """Streaming state for the transformer stack."""

    # Device tensor of shape (batch_size,): steps consumed per stream.
    offsets: torch.Tensor

    def reset(self, reset_mask: torch.Tensor):
        super().reset(reset_mask)
        self.offsets[:] = torch.where(reset_mask, torch.zeros_like(self.offsets), self.offsets)
| |
|
| |
|
class MossAudioTokenizerTransformer(StreamingModule):
    """Transformer with streaming/causal support.

    Positional information comes from sinusoidal embeddings ("sin"), rotary
    embeddings ("rope"), or both ("sin_rope").
    """

    def __init__(
        self,
        d_model: int,
        num_heads: int,
        num_layers: int,
        dim_feedforward: int = 2048,
        causal: bool = False,
        context: int | None = None,
        positional_embedding: str = "sin",
        max_period: float = 10_000,
        positional_scale: float = 1.0,
        device=None,
        dtype=None,
        **kwargs,
    ):
        super().__init__()
        if d_model % num_heads != 0:
            raise ValueError(f"d_model must be divisible by num_heads, got d_model={d_model}, num_heads={num_heads}")

        self.positional_embedding = positional_embedding
        self.max_period = max_period
        self.positional_scale = positional_scale

        # A single RoPE module is shared across every layer.
        self.rope: MossAudioTokenizerRotaryEmbedding | None = None
        if positional_embedding in {"rope", "sin_rope"}:
            self.rope = MossAudioTokenizerRotaryEmbedding(max_period=max_period)

        self.layers = nn.ModuleList()
        for _ in range(num_layers):
            self.layers.append(
                MossAudioTokenizerTransformerLayer(
                    d_model=d_model,
                    num_heads=num_heads,
                    dim_feedforward=dim_feedforward,
                    causal=causal,
                    context=context,
                    rope=self.rope,
                    device=device,
                    dtype=dtype,
                    **kwargs,
                )
            )

    def _init_streaming_state(self, batch_size: int) -> TransformerState:
        device = next(self.parameters()).device
        return TransformerState(
            batch_size,
            device,
            offsets=torch.zeros(batch_size, device=device, dtype=torch.long),
        )

    def forward(self, x: torch.Tensor, *args, **kwargs):
        """Run the layer stack on `x` of shape (B, T, C)."""
        B, T, C = x.shape
        state = self._streaming_state
        # Per-sample absolute offsets (zeros when not streaming).
        offsets = (
            torch.zeros(1, dtype=torch.long, device=x.device)
            if state is None
            else (
                state.offsets
                if isinstance(state, TransformerState)
                else torch.zeros(1, dtype=torch.long, device=x.device)
            )
        )

        if self.positional_embedding in {"sin", "sin_rope"}:
            # Additive sinusoidal embedding, shifted by the streaming position.
            positions = torch.arange(T, device=x.device).view(1, -1, 1)
            positions = positions + offsets.view(-1, 1, 1)
            pos_emb = create_sin_embedding(positions, C, max_period=self.max_period, dtype=x.dtype)
            x = x + self.positional_scale * pos_emb

        for layer in self.layers:
            x = layer(x, *args, **kwargs)

        if state is not None:
            assert isinstance(state, TransformerState)
            # Advance offsets only for streams that executed this step.
            state.offsets[:] = torch.where(state.exec_mask, state.offsets + T, state.offsets)
        return x
| |
|
| |
|
class MossAudioTokenizerProjectedTransformer(StreamingContainer):
    """Transformer with input/output projections.

    Projects from `input_dimension` to the transformer width `d_model` and
    back out to `output_dimension`, using identity projections whenever the
    dimensions already match. `forward` accepts the convolutional (B, C, T)
    layout and transposes around the transformer call.
    """

    def __init__(
        self,
        input_dimension: int,
        output_dimension: int,
        d_model: int,
        *,
        conv_layout: bool = False,
        module_type: str,
        **kwargs,
    ):
        super().__init__()
        self.module_type = module_type
        # This module never changes the frame rate.
        self.downsample_ratio: int = 1
        self.input_dimension = input_dimension
        self.output_dimension = output_dimension

        self.input_proj = (
            nn.Linear(input_dimension, d_model, bias=False) if d_model != input_dimension else nn.Identity()
        )
        self.transformer = MossAudioTokenizerTransformer(d_model=d_model, **kwargs)
        self.conv_layout = conv_layout
        self.output_proj = (
            nn.Linear(d_model, output_dimension, bias=False) if d_model != output_dimension else nn.Identity()
        )

    def forward(self, x, input_lengths, *args, **kwargs):
        # (B, C, T) -> (B, T, C) for the transformer, then back again.
        x = self.input_proj(x.transpose(1, 2))
        x = self.transformer(x, *args, **kwargs)
        x = self.output_proj(x).transpose(1, 2)
        # Sequence lengths pass through unchanged.
        return x, input_lengths
| |
|
| |
|
| | |
| | |
| | |
| |
|
| |
|
class MossAudioTokenizerPatchedPretransform(nn.Module):
    """Patching module for downsampling/upsampling."""

    def __init__(self, patch_size: int, is_downsample: bool, module_type: str, **kwargs):
        super().__init__()
        self.patch_size = patch_size
        self.downsample_ratio: int = patch_size
        self.is_downsample = is_downsample
        self.module_type = module_type

    def encode(self, x, input_lengths):
        # Fold `patch_size` consecutive frames into the channel dimension:
        # (B, D, T) -> (B, D * patch, T // patch).
        batch, channels, _ = x.shape
        patch = self.patch_size
        x = x.reshape(batch, channels, -1, patch).permute(0, 1, 3, 2).reshape(batch, channels * patch, -1)
        return x, input_lengths // patch

    def decode(self, x, input_lengths):
        # Inverse of `encode`: (B, D * patch, L) -> (B, D, L * patch).
        batch, packed, length = x.shape
        patch = self.patch_size
        channels = packed // patch
        x = x.reshape(batch, channels, patch, length).permute(0, 1, 3, 2).reshape(batch, channels, length * patch)
        return x, input_lengths * patch

    def forward(self, x, input_lengths):
        # Direction is fixed at construction time.
        if self.is_downsample:
            return self.encode(x, input_lengths)
        return self.decode(x, input_lengths)
| |
|
| |
|
| | |
| | |
| | |
| |
|
| |
|
| | def WNConv1d(*args, **kwargs): |
| | """Weight-normalized Conv1d.""" |
| | return nn.utils.parametrizations.weight_norm(nn.Conv1d(*args, **kwargs)) |
| |
|
| |
|
class MossAudioTokenizerVectorQuantize(nn.Module):
    """Single codebook vector quantization (inference only)."""

    def __init__(
        self,
        input_dim: int,
        codebook_size: int,
        codebook_dim: int,
        **kwargs,
    ):
        super().__init__()
        self.input_dim = input_dim
        self.codebook_size = codebook_size
        self.codebook_dim = codebook_dim

        # Project in/out of codebook space only when the dimensions differ.
        if input_dim != codebook_dim:
            self.in_proj = WNConv1d(input_dim, codebook_dim, kernel_size=1)
            self.out_proj = WNConv1d(codebook_dim, input_dim, kernel_size=1)
        else:
            self.in_proj = nn.Identity()
            self.out_proj = nn.Identity()

        self.codebook = nn.Embedding(codebook_size, codebook_dim)

    @torch.no_grad()
    def forward(self, z: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Args:
            z: Input tensor of shape (B, D, T)
        Returns:
            z_q: Quantized tensor of shape (B, D, T)
            indices: Code indices of shape (B, T)
            z_e: Encoded tensor before quantization
        """
        z = z.float()
        z_e = self.in_proj(z).float()

        # Flatten time into the batch dimension: (B*T, codebook_dim).
        flat = z_e.transpose(1, 2).reshape(-1, z_e.shape[1])

        # Squared L2 distance to every codebook entry computed via the
        # |x|^2 - 2 x.e + |e|^2 expansion (no materialized differences).
        embed = self.codebook.weight
        dist = (
            flat.pow(2).sum(1, keepdim=True)
            - 2 * flat @ embed.float().t()
            + embed.float().pow(2).sum(1, keepdim=True).t()
        )

        # Nearest codebook entry per time step, reshaped back to (B, T).
        indices = (-dist).max(1)[1].reshape(z.size(0), -1)

        z_q = self.out_proj(self.decode_code(indices)).float()
        return z_q, indices, z_e

    def decode_code(self, embed_id: torch.Tensor) -> torch.Tensor:
        """Look up code indices; returns float embeddings as (B, codebook_dim, T)."""
        return self.codebook(embed_id).transpose(1, 2).float()
| |
|
| |
|
class MossAudioTokenizerLFQ(nn.Module):
    """LFQ (inference-only) used by ResidualLFQ.

    Like plain VQ but the nearest-code search runs on L2-normalized latents
    and codes (cosine-similarity matching); the returned values are the
    un-normalized codebook entries.
    """

    def __init__(
        self,
        input_dim: int,
        codebook_size: int,
        codebook_dim: int,
        **kwargs,
    ):
        super().__init__()
        self.input_dim = input_dim
        self.codebook_size = codebook_size
        self.codebook_dim = codebook_dim

        # Dimension adapters are only needed when codebook space differs.
        if self.input_dim == self.codebook_dim:
            self.in_proj = nn.Identity()
            self.out_proj = nn.Identity()
        else:
            self.in_proj = WNConv1d(self.input_dim, self.codebook_dim, kernel_size=1)
            self.out_proj = WNConv1d(self.codebook_dim, self.input_dim, kernel_size=1)

        self.codebook = nn.Embedding(codebook_size, codebook_dim)

    @torch.no_grad()
    def forward(self, z: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Quantize z into codebook vectors."""
        z = z.float()
        z_e = self.in_proj(z).float()
        quantized, indices = self.decode_latents(z_e)
        # Straight-through estimator form; numerically a no-op at inference.
        quantized = (z_e + (quantized - z_e).detach()).float()
        quantized = self.out_proj(quantized).float()
        return quantized, indices, z_e

    def embed_code(self, embed_id: torch.Tensor) -> torch.Tensor:
        """Raw codebook lookup, (…,) indices -> (…, codebook_dim)."""
        return F.embedding(embed_id, self.codebook.weight)

    def decode_code_wo_out_proj(self, embed_id: torch.Tensor) -> torch.Tensor:
        """Lookup plus layout change to channel-first, without out-projection."""
        return self.embed_code(embed_id).transpose(1, 2)

    def decode_code(self, embed_id: torch.Tensor) -> torch.Tensor:
        """Decode indices to embeddings in the input dimension."""
        decoded = self.decode_code_wo_out_proj(embed_id).float()
        return self.out_proj(decoded).float()

    def decode_latents(self, latents: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Match training LFQ: L2-normalize then argmin squared distance."""
        flat = latents.transpose(1, 2).reshape(-1, latents.shape[1]).float()
        unit_latents = F.normalize(flat)
        unit_codes = F.normalize(self.codebook.weight.float())

        # On unit vectors this is 2 - 2*cos, so argmin == max cosine similarity.
        dist = (
            unit_latents.pow(2).sum(1, keepdim=True)
            - 2 * unit_latents @ unit_codes.t()
            + unit_codes.pow(2).sum(1, keepdim=True).t()
        )
        indices = dist.argmin(dim=1).reshape(latents.size(0), -1)
        quantized = self.decode_code_wo_out_proj(indices).float()
        return quantized, indices
| |
|
| |
|
class MossAudioTokenizerResidualVQ(nn.Module):
    """Residual Vector Quantization (inference only).

    Runs a cascade of `num_quantizers` single-codebook quantizers; each stage
    quantizes the residual left by the previous stages, and the decoded
    signal is the sum of all stage embeddings.
    """

    def __init__(
        self,
        input_dim: int = 1024,
        rvq_dim: int | None = None,
        output_dim: int | None = None,
        num_quantizers: int = 32,
        codebook_size: int = 1024,
        codebook_dim: int = 8,
        **kwargs,
    ):
        super().__init__()
        self.input_dim = input_dim
        self.rvq_dim = rvq_dim or input_dim
        self.output_dim = output_dim or input_dim
        self.num_quantizers = num_quantizers
        self.codebook_size = codebook_size
        self.codebook_dim = codebook_dim

        # 1x1 convs adapt the in/out dimensions only when they differ.
        self.input_proj = (
            WNConv1d(input_dim, self.rvq_dim, kernel_size=1) if input_dim != self.rvq_dim else nn.Identity()
        )
        self.output_proj = (
            WNConv1d(self.rvq_dim, self.output_dim, kernel_size=1)
            if self.rvq_dim != self.output_dim
            else nn.Identity()
        )

        self.quantizers = nn.ModuleList(
            [
                MossAudioTokenizerVectorQuantize(
                    input_dim=self.rvq_dim,
                    codebook_size=codebook_size,
                    codebook_dim=codebook_dim,
                    **kwargs,
                )
                for _ in range(num_quantizers)
            ]
        )

    @torch.no_grad()
    def forward(
        self,
        z: torch.Tensor,
        input_length: torch.Tensor,
        n_quantizers: int | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Args:
            z: Input tensor of shape (B, D, T)
            input_length: Valid lengths for each sample (B,)
            n_quantizers: Number of quantizers to use (defaults to all)
        Returns:
            quantized_out: Quantized output (B, D, T)
            all_indices: All code indices (N, B, T)
            output_length: Output lengths (B,)
        """
        z = self.input_proj(z)

        batch_size, _, max_time = z.shape
        # True for valid (unpadded) time steps; padded steps must not leak
        # into codes or residuals.
        mask = torch.arange(max_time, device=z.device).expand(batch_size, max_time) < input_length.unsqueeze(1)

        quantized_out = torch.zeros_like(z, dtype=torch.float32)
        residual = z.clone().float()
        all_indices = []

        n_quantizers = n_quantizers or self.num_quantizers

        for i, quantizer in enumerate(self.quantizers):
            if i >= n_quantizers:
                break

            masked_residual = residual * mask.unsqueeze(1)
            z_q_i, indices_i, _ = quantizer(masked_residual)

            # Accumulate the reconstruction and peel the quantized part off
            # the residual, only on valid time steps.
            update_mask = mask.unsqueeze(1)
            quantized_out = quantized_out + z_q_i * update_mask
            residual = residual - z_q_i * update_mask
            all_indices.append(indices_i)

        # Guard the empty case: torch.stack raises on an empty list, and the
        # sibling MossAudioTokenizerResidualLFQ.forward already guards this.
        all_indices = (
            torch.stack(all_indices)
            if all_indices
            else torch.empty(0, batch_size, max_time, device=z.device, dtype=torch.long)
        )
        quantized_out = self.output_proj(quantized_out)

        return quantized_out, all_indices, input_length

    def decode_codes(self, codes: torch.Tensor) -> torch.Tensor:
        """Decode stacked code indices (N, B, T) to embeddings (B, output_dim, T)."""
        nq, B, T = codes.shape
        emb = torch.zeros(B, self.rvq_dim, T, device=codes.device, dtype=torch.float32)

        # Residual decoding: the embedding is the sum of all stage lookups.
        for i, quantizer in enumerate(self.quantizers[:nq]):
            quantizer = cast(MossAudioTokenizerVectorQuantize, quantizer)
            quantized_i = quantizer.decode_code(codes[i])
            emb += quantized_i

        emb = self.output_proj(emb)
        return emb
| |
|
| |
|
class MossAudioTokenizerResidualLFQ(nn.Module):
    """Residual LFQ (inference only).

    Cascade of LFQ stages; each stage quantizes the residual left by the
    previous ones, and decoding sums the per-stage embeddings.
    """

    def __init__(
        self,
        input_dim: int = 1024,
        rvq_dim: int | None = None,
        output_dim: int | None = None,
        num_quantizers: int = 32,
        codebook_size: int = 1024,
        codebook_dim: int = 8,
        **kwargs,
    ):
        super().__init__()
        self.input_dim = input_dim
        self.rvq_dim = rvq_dim or input_dim
        self.output_dim = output_dim or input_dim
        self.num_quantizers = num_quantizers
        self.codebook_size = codebook_size
        self.codebook_dim = codebook_dim

        # Dimension adapters (1x1 convs) only where dimensions differ.
        if input_dim != self.rvq_dim:
            self.input_proj = WNConv1d(input_dim, self.rvq_dim, kernel_size=1)
        else:
            self.input_proj = nn.Identity()
        if self.rvq_dim != self.output_dim:
            self.output_proj = WNConv1d(self.rvq_dim, self.output_dim, kernel_size=1)
        else:
            self.output_proj = nn.Identity()

        stages = [
            MossAudioTokenizerLFQ(
                input_dim=self.rvq_dim,
                codebook_size=codebook_size,
                codebook_dim=codebook_dim,
                **kwargs,
            )
            for _ in range(num_quantizers)
        ]
        self.quantizers = nn.ModuleList(stages)

    @torch.no_grad()
    def forward(
        self,
        z: torch.Tensor,
        input_length: torch.Tensor,
        n_quantizers: int | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Inference quantization.

        Returns the quantized features (B, D, T), stacked stage indices
        (N, B, T), and the (unchanged) valid lengths (B,).
        """
        z = self.input_proj(z).float()

        batch, _, frames = z.shape
        # Valid-frame mask keeps padding out of codes and residual updates.
        valid = torch.arange(frames, device=z.device).expand(batch, frames) < input_length.unsqueeze(1)
        valid_3d = valid.unsqueeze(1)

        accum = torch.zeros_like(z, dtype=torch.float32)
        residual = z.clone().float()
        codes_per_stage: list[torch.Tensor] = []

        limit = n_quantizers or self.num_quantizers
        for stage, quantizer in enumerate(self.quantizers):
            if stage >= limit:
                break

            stage_q, stage_codes, _ = quantizer(residual * valid_3d)
            accum = accum + stage_q * valid_3d
            residual = residual - stage_q * valid_3d
            codes_per_stage.append(stage_codes)

        # torch.stack rejects an empty list, so guard the zero-stage case.
        if codes_per_stage:
            all_indices = torch.stack(codes_per_stage)
        else:
            all_indices = torch.empty(0, batch, frames, device=z.device, dtype=torch.long)
        return self.output_proj(accum), all_indices, input_length

    def decode_codes(self, codes: torch.Tensor) -> torch.Tensor:
        """Sum per-stage decoded embeddings and project to the output dim."""
        num_stages, batch, frames = codes.shape
        summed = torch.zeros(batch, self.rvq_dim, frames, device=codes.device, dtype=torch.float32)
        for stage, quantizer in enumerate(self.quantizers[:num_stages]):
            quantizer = cast(MossAudioTokenizerLFQ, quantizer)
            summed += quantizer.decode_code(codes[stage]).float()
        return self.output_proj(summed)
| |
|
| |
|
| | |
| | |
| | |
| |
|
| |
|
@auto_docstring
class MossAudioTokenizerPreTrainedModel(PreTrainedAudioTokenizerBase):
    """Base class for MossAudioTokenizer models."""

    # Configuration class paired with this model family.
    config_class = MossAudioTokenizerConfig
    # Empty prefix: the tokenizer itself is the top-level model.
    base_model_prefix = ""
    # Primary tensor argument name expected by `forward`.
    main_input_name = "input_values"
    input_modalities = "audio"
    supports_gradient_checkpointing = False
    # Module class names that must be kept on a single device when the model
    # is sharded (HF accelerate convention).
    _no_split_modules = [
        "MossAudioTokenizerTransformerLayer",
        "MossAudioTokenizerResidualVQ",
        "MossAudioTokenizerResidualLFQ",
    ]
| |
|
| |
|
@auto_docstring(
    custom_intro="""
    The MossAudioTokenizer neural audio codec model for audio tokenization and synthesis.
    """
)
class MossAudioTokenizerModel(MossAudioTokenizerPreTrainedModel):
    """
    MossAudioTokenizer model for audio tokenization and synthesis.

    This model can encode audio waveforms into discrete tokens and decode
    tokens back into audio waveforms.
    """

    def __init__(self, config: MossAudioTokenizerConfig):
        super().__init__(config)

        self.config = config
        # NOTE(review): touching `config.version` presumably fails fast when the
        # attribute is missing from the config — confirm before removing.
        _ = config.version
        self.sampling_rate = config.sampling_rate
        self.downsample_rate = config.downsample_rate
        self.causal_transformer_context_duration = config.causal_transformer_context_duration

        # Build the encoder stack. `current_frame_rate` tracks frames/second at
        # each stage so causal transformers receive a context window of
        # `causal_transformer_context_duration` seconds expressed in frames.
        current_frame_rate: float = float(self.sampling_rate)
        self.encoder = nn.ModuleList()

        for encoder_kwargs_i in config.encoder_kwargs:
            encoder_kwargs_i = dict(encoder_kwargs_i)
            if encoder_kwargs_i["module_type"] == "PatchedPretransform":
                self.encoder.append(MossAudioTokenizerPatchedPretransform(**encoder_kwargs_i, is_downsample=True))
            elif encoder_kwargs_i["module_type"] == "Transformer":
                self.encoder.append(
                    MossAudioTokenizerProjectedTransformer(
                        **encoder_kwargs_i,
                        context=int(current_frame_rate * self.causal_transformer_context_duration),
                    )
                )
            # Every appended module advertises its own downsample factor.
            current_frame_rate /= self.encoder[-1].downsample_ratio

        # Quantizer: residual VQ or residual LFQ depending on config.
        quantizer_kwargs = dict(config.quantizer_kwargs)
        quantizer_type = quantizer_kwargs.get("quantizer_type", getattr(config, "quantizer_type", "rvq"))
        if quantizer_type in {"rvq", "spec_rvq"}:
            self.quantizer = MossAudioTokenizerResidualVQ(**quantizer_kwargs)
        elif quantizer_type in {"rlfq", "random_prefix_rlfq"}:
            self.quantizer = MossAudioTokenizerResidualLFQ(**quantizer_kwargs)
        else:
            raise ValueError(f"Unsupported quantizer_type: {quantizer_type}")

        # Decoder stack mirrors the encoder; frame rate grows back toward the
        # sampling rate as upsampling modules are appended.
        decoder_kwargs_list = copy.deepcopy(config.decoder_kwargs)
        self.decoder = nn.ModuleList()

        for decoder_kwargs_i in decoder_kwargs_list:
            decoder_kwargs_i = dict(decoder_kwargs_i)
            if decoder_kwargs_i["module_type"] == "PatchedPretransform":
                self.decoder.append(MossAudioTokenizerPatchedPretransform(**decoder_kwargs_i, is_downsample=False))
            elif decoder_kwargs_i["module_type"] == "Transformer":
                self.decoder.append(
                    MossAudioTokenizerProjectedTransformer(
                        **decoder_kwargs_i,
                        context=int(current_frame_rate * self.causal_transformer_context_duration),
                    )
                )
            current_frame_rate *= self.decoder[-1].downsample_ratio

        self.post_init()

    def _start_streaming(self, batch_size: int):
        """Start streaming mode for all modules."""

        # Initialize per-module streaming state (e.g. KV caches) in place.
        def _start(module):
            if isinstance(module, StreamingModule):
                module._streaming_state = module._init_streaming_state(batch_size)

        self.apply(_start)

    def _stop_streaming(self):
        """Stop streaming mode for all modules."""

        # Drop any accumulated streaming state.
        def _stop(module):
            if isinstance(module, StreamingModule):
                module._streaming_state = None

        self.apply(_stop)

    @contextmanager
    def streaming(self, batch_size: int = 1):
        """Context manager for streaming mode."""
        self._start_streaming(batch_size)
        try:
            yield
        finally:
            self._stop_streaming()

    @torch.no_grad()
    def batch_encode(
        self, wav_list: list[torch.Tensor], num_quantizers: int | None = None
    ) -> MossAudioTokenizerEncoderOutput:
        """Batch encode a list of audio waveforms.

        Args:
            wav_list: List of audio tensors, each of shape `(num_samples,)`.
            num_quantizers: Number of quantizers to use. By default, all quantizers are used.

        Returns:
            [`MossAudioTokenizerEncoderOutput`] with `audio_codes` and `audio_codes_lengths`.
        """
        if len(wav_list) == 0:
            raise ValueError("`wav_list` must contain at least one waveform.")

        device = wav_list[0].device
        batch_size = len(wav_list)

        # Right-pad every waveform to the longest one and record true lengths.
        max_length = max(wav.shape[-1] for wav in wav_list)
        input_values = torch.zeros(batch_size, 1, max_length, device=device)
        input_lengths = torch.zeros(batch_size, device=device, dtype=torch.long)

        for i, wav in enumerate(wav_list):
            input_values[i, 0, : wav.shape[-1]] = wav
            input_lengths[i] = wav.shape[-1]

        return self._encode_frame(input_values, input_lengths, n_quantizers=num_quantizers)

    @torch.no_grad()
    def batch_decode(
        self, codes_list: list[torch.Tensor], num_quantizers: int | None = None
    ) -> MossAudioTokenizerDecoderOutput:
        """Batch decode a list of audio codes.

        Args:
            codes_list: List of audio code tensors, each of shape `(num_quantizers, codes_length)`.
            num_quantizers: If provided, decode only the first `num_quantizers` quantizers from each element in
                `codes_list`. If omitted, all elements in `codes_list` must have the same number of quantizers.

        Returns:
            [`MossAudioTokenizerDecoderOutput`] with `audio` and `audio_lengths`.
        """
        if len(codes_list) == 0:
            raise ValueError("`codes_list` must contain at least one code tensor.")

        batch_size = len(codes_list)
        device = codes_list[0].device
        nqs = [codes.shape[0] for codes in codes_list]
        if num_quantizers is None:
            # Without an explicit count, all samples must agree on their depth.
            num_quantizers = nqs[0]
            if any(nq != num_quantizers for nq in nqs):
                raise ValueError(
                    "All elements in `codes_list` must have the same number of quantizers when `num_quantizers` is None. "
                    "Pass `num_quantizers=...` to decode a common prefix."
                )
        else:
            # Decoding a prefix requires every sample to have at least that depth.
            min_nq = min(nqs)
            if min_nq < num_quantizers:
                raise ValueError(
                    "`num_quantizers` must be <= the number of quantizers for every element in `codes_list`. "
                    f"Got num_quantizers={num_quantizers}, min(codes.shape[0])={min_nq}."
                )
        max_length = max(codes.shape[-1] for codes in codes_list)

        # Right-pad codes to the longest sequence; padded positions are code 0
        # and are excluded via `audio_codes_lengths`.
        audio_codes = torch.zeros(num_quantizers, batch_size, max_length, device=device, dtype=torch.long)
        audio_codes_lengths = torch.zeros(batch_size, device=device, dtype=torch.long)

        for i, codes in enumerate(codes_list):
            codes = codes[:num_quantizers]
            audio_codes[:, i, : codes.shape[-1]] = codes
            audio_codes_lengths[i] = codes.shape[-1]

        return self._decode_frame(audio_codes, audio_codes_lengths)

    @torch.no_grad()
    def _encode_frame(
        self,
        input_values: torch.Tensor,
        input_lengths: torch.Tensor | None = None,
        n_quantizers: int | None = None,
    ) -> MossAudioTokenizerEncoderOutput:
        """Tokenize audio waveform into discrete tokens."""
        # Accept (B, T) input by inserting a channel axis.
        if input_values.dim() == 2:
            input_values = input_values.unsqueeze(1)

        B, _, T = input_values.shape
        device = input_values.device

        if input_lengths is None:
            input_lengths = torch.full((B,), T, device=device, dtype=torch.long)

        # Right-pad so the total downsample factor divides the signal length;
        # `input_lengths` keeps referring to the unpadded samples.
        if T % self.downsample_rate != 0:
            pad_length = self.downsample_rate - (T % self.downsample_rate)
            input_values = F.pad(input_values, (0, pad_length))

        # Run the encoder stack; each stage also transforms the lengths.
        e, e_lengths = input_values, input_lengths
        for encoder_module in self.encoder:
            e, e_lengths = encoder_module(e, e_lengths)

        # Quantize encoder features into (num_quantizers, B, T') indices.
        # `zq` (the dequantized features) is not needed here.
        quantizer = cast(MossAudioTokenizerResidualVQ | MossAudioTokenizerResidualLFQ, self.quantizer)
        zq, audio_codes, audio_codes_lengths = quantizer(e, e_lengths, n_quantizers)

        return MossAudioTokenizerEncoderOutput(
            audio_codes=audio_codes, audio_codes_lengths=audio_codes_lengths, encoder_hidden_states=e
        )

    @torch.no_grad()
    def _decode_frame(
        self,
        codes: torch.Tensor,
        codes_lengths: torch.Tensor | None = None,
    ) -> MossAudioTokenizerDecoderOutput:
        """Detokenize discrete tokens into audio waveform."""
        nq, B, T = codes.shape
        device = codes.device

        if codes_lengths is None:
            codes_lengths = torch.full((B,), T, device=device, dtype=torch.long)

        # Look up and sum quantizer embeddings, then run the decoder stack.
        quantizer = cast(MossAudioTokenizerResidualVQ | MossAudioTokenizerResidualLFQ, self.quantizer)
        zq = quantizer.decode_codes(codes)

        d, d_lengths = zq, codes_lengths
        for decoder_module in self.decoder:
            d, d_lengths = decoder_module(d, d_lengths)

        return MossAudioTokenizerDecoderOutput(audio=d, audio_lengths=d_lengths)

    def encode(
        self,
        input_values: torch.Tensor,
        padding_mask: torch.Tensor | None = None,
        num_quantizers: int | None = None,
        return_dict: bool | None = None,
        chunk_duration: float | None = None,
    ):
        """
        Encodes the input audio waveform into discrete codes.

        Args:
            input_values (`torch.Tensor` of shape `(batch_size, channels, sequence_length)`):
                Float values of the input audio waveform.
            padding_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
                Mask to indicate valid audio samples.
            num_quantizers (`int`, *optional*):
                Number of quantizers to use. By default, all quantizers are used.
            return_dict (`bool`, *optional*):
                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
            chunk_duration (`float`, *optional*):
                If provided, encode the input waveform in successive chunks of `chunk_duration` seconds while keeping a
                streaming KV cache for the causal transformers.

                `chunk_duration` must be <= `config.causal_transformer_context_duration`, and
                `chunk_duration * config.sampling_rate` must be divisible by `config.downsample_rate`.

        Returns:
            `MossAudioTokenizerEncoderOutput` or tuple containing audio codes and lengths.
        """
        return_dict = return_dict if return_dict is not None else self.config.return_dict

        # Accept (B, T) input by inserting a channel axis.
        if input_values.dim() == 2:
            input_values = input_values.unsqueeze(1)

        B, _, T = input_values.shape
        device = input_values.device

        if padding_mask is not None:
            input_lengths = padding_mask.sum(dim=-1).long()
        else:
            input_lengths = torch.full((B,), T, device=device, dtype=torch.long)

        if chunk_duration is None:
            # Non-streaming path: encode everything in one call.
            encoder_output = self._encode_frame(input_values, input_lengths, num_quantizers)
        else:
            # Streaming path: validate chunking parameters first.
            if chunk_duration <= 0:
                raise ValueError("`chunk_duration` must be > 0 when provided.")
            if chunk_duration > self.causal_transformer_context_duration:
                raise ValueError(
                    "`chunk_duration` must be <= `config.causal_transformer_context_duration` "
                    f"({self.causal_transformer_context_duration}), got {chunk_duration}."
                )
            if B != 1:
                raise ValueError("Streaming encode via `chunk_duration` currently only supports batch_size=1.")

            chunk_length = int(round(chunk_duration * self.sampling_rate))
            if chunk_length <= 0:
                raise ValueError("`chunk_duration` is too small and results in chunk_length <= 0.")
            if chunk_length % self.downsample_rate != 0:
                raise ValueError(
                    "`chunk_duration * config.sampling_rate` must be divisible by `config.downsample_rate`. "
                    f"Got chunk_length={chunk_length}, downsample_rate={self.downsample_rate}."
                )

            input_length = int(input_lengths[0].item())
            if input_length <= chunk_length:
                # Whole signal fits in one chunk; no streaming state needed.
                encoder_output = self._encode_frame(input_values[..., :input_length], input_lengths, num_quantizers)
            else:
                codes_chunks: list[torch.Tensor] = []
                hidden_chunks: list[torch.Tensor] = []

                # Enter streaming mode on every streaming-capable encoder
                # module so KV caches persist across chunk calls.
                with ExitStack() as exit_stack:
                    for encoder_module in self.encoder:
                        if isinstance(encoder_module, StreamingModule):
                            exit_stack.enter_context(encoder_module.streaming(batch_size=B))

                    for start_idx in range(0, input_length, chunk_length):
                        input_length_i = min(chunk_length, input_length - start_idx)
                        if input_length_i <= 0:
                            break

                        input_lengths_i = torch.tensor([input_length_i], device=device, dtype=torch.long)
                        input_values_i = input_values[..., start_idx : start_idx + input_length_i]
                        result_i = self._encode_frame(input_values_i, input_lengths_i, num_quantizers)

                        if result_i.audio_codes is None or result_i.audio_codes_lengths is None:
                            raise RuntimeError("Internal error: `_encode_frame` returned empty audio codes.")
                        if result_i.encoder_hidden_states is None:
                            raise RuntimeError("Internal error: `_encode_frame` returned empty encoder hidden states.")

                        # Keep only the valid prefix of each chunk's outputs.
                        codes_length_i = result_i.audio_codes_lengths
                        codes_chunks.append(result_i.audio_codes[:, :, : codes_length_i[0]])
                        hidden_chunks.append(result_i.encoder_hidden_states[:, :, : codes_length_i[0]])

                # Concatenate all valid chunk outputs along time.
                audio_codes = torch.cat(codes_chunks, dim=-1)
                encoder_hidden_states = torch.cat(hidden_chunks, dim=-1)
                audio_codes_lengths = torch.tensor([audio_codes.shape[-1]], device=device, dtype=torch.long)
                encoder_output = MossAudioTokenizerEncoderOutput(
                    audio_codes=audio_codes,
                    audio_codes_lengths=audio_codes_lengths,
                    encoder_hidden_states=encoder_hidden_states,
                )

        if not return_dict:
            assert encoder_output.audio_codes is not None
            assert encoder_output.audio_codes_lengths is not None
            return (
                cast(torch.Tensor, encoder_output.audio_codes),
                cast(torch.Tensor, encoder_output.audio_codes_lengths),
            )
        return encoder_output

    def decode(
        self,
        audio_codes: torch.Tensor,
        padding_mask: torch.Tensor | None = None,
        return_dict: bool | None = None,
        chunk_duration: float | None = None,
        num_quantizers: int | None = None,
    ):
        """
        Decodes the given codes into an output audio waveform.

        Args:
            audio_codes (`torch.LongTensor` of shape `(num_quantizers, batch_size, sequence_length)`):
                Discrete code embeddings computed using `model.encode`.
            padding_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
                Mask to indicate valid code positions.
            return_dict (`bool`, *optional*):
                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
            chunk_duration (`float`, *optional*):
                If provided, decode the input codes in successive chunks of `chunk_duration` seconds while keeping a
                streaming KV cache for the causal transformers.

                `chunk_duration` must be <= `config.causal_transformer_context_duration`, and
                `chunk_duration * config.sampling_rate` must be divisible by `config.downsample_rate`.
            num_quantizers (`int`, *optional*):
                Number of quantizers to use. By default, all quantizers in `audio_codes` are used.

        Returns:
            `MossAudioTokenizerDecoderOutput` or tuple containing decoded audio.
        """
        return_dict = return_dict if return_dict is not None else self.config.return_dict

        # Accept (num_quantizers, T) input by inserting a batch axis.
        if audio_codes.dim() == 2:
            audio_codes = audio_codes.unsqueeze(1)

        if num_quantizers is not None:
            if num_quantizers > audio_codes.shape[0]:
                raise ValueError(
                    f"`num_quantizers` ({num_quantizers}) must be <= audio_codes.shape[0] ({audio_codes.shape[0]})."
                )
            # Residual quantization: a prefix of codes is a coarser decode.
            audio_codes = audio_codes[:num_quantizers]

        _, B, T = audio_codes.shape
        device = audio_codes.device

        if padding_mask is not None:
            codes_lengths = padding_mask.sum(dim=-1).long()
        else:
            codes_lengths = torch.full((B,), T, device=device, dtype=torch.long)

        if chunk_duration is None:
            # Non-streaming path: decode everything in one call.
            decoder_output = self._decode_frame(audio_codes, codes_lengths)
        else:
            # Streaming path: validate chunking parameters first.
            if chunk_duration <= 0:
                raise ValueError("`chunk_duration` must be > 0 when provided.")
            if chunk_duration > self.causal_transformer_context_duration:
                raise ValueError(
                    "`chunk_duration` must be <= `config.causal_transformer_context_duration` "
                    f"({self.causal_transformer_context_duration}), got {chunk_duration}."
                )
            if B != 1:
                raise ValueError("Streaming decode via `chunk_duration` currently only supports batch_size=1.")

            chunk_length = int(round(chunk_duration * self.sampling_rate))
            if chunk_length <= 0:
                raise ValueError("`chunk_duration` is too small and results in chunk_length <= 0.")
            if chunk_length % self.downsample_rate != 0:
                raise ValueError(
                    "`chunk_duration * config.sampling_rate` must be divisible by `config.downsample_rate`. "
                    f"Got chunk_length={chunk_length}, downsample_rate={self.downsample_rate}."
                )

            # Chunk size expressed in code frames rather than audio samples.
            chunk_frame_length = chunk_length // self.downsample_rate
            codes_length = int(codes_lengths[0].item())
            if codes_length <= chunk_frame_length:
                decoder_output = self._decode_frame(audio_codes[..., :codes_length], codes_lengths)
            else:
                wav_chunks: list[torch.Tensor] = []
                # Enter streaming mode on every streaming-capable decoder
                # module so KV caches persist across chunk calls.
                with ExitStack() as exit_stack:
                    for decoder_module in self.decoder:
                        if isinstance(decoder_module, StreamingModule):
                            exit_stack.enter_context(decoder_module.streaming(batch_size=B))

                    for start_idx in range(0, codes_length, chunk_frame_length):
                        codes_length_i = min(chunk_frame_length, codes_length - start_idx)
                        if codes_length_i <= 0:
                            break

                        codes_lengths_i = torch.tensor([codes_length_i], device=device, dtype=torch.long)
                        codes_i = audio_codes[:, :, start_idx : start_idx + codes_length_i]
                        result_i = self._decode_frame(codes_i, codes_lengths_i)

                        if result_i.audio is None or result_i.audio_lengths is None:
                            raise RuntimeError("Internal error: `_decode_frame` returned empty audio.")

                        # Keep only the valid prefix of each chunk's audio.
                        wav_chunks.append(result_i.audio[:, :, : result_i.audio_lengths[0]])

                wav = torch.cat(wav_chunks, dim=-1)
                audio_lengths = torch.tensor([wav.shape[-1]], device=device, dtype=torch.long)
                decoder_output = MossAudioTokenizerDecoderOutput(audio=wav, audio_lengths=audio_lengths)

        if not return_dict:
            assert decoder_output.audio is not None
            return (cast(torch.Tensor, decoder_output.audio),)
        return decoder_output

    @auto_docstring
    def forward(
        self,
        input_values: torch.FloatTensor | None = None,
        padding_mask: torch.BoolTensor | None = None,
        audio_codes: torch.Tensor | None = None,
        num_quantizers: int | None = None,
        return_dict: bool | None = None,
    ) -> tuple[torch.Tensor | None, torch.Tensor | None, torch.Tensor | None] | MossAudioTokenizerOutput:
        r"""
        input_values (`torch.FloatTensor` of shape `(batch_size, channels, sequence_length)`, *optional*):
            Raw audio input converted to Float.
        padding_mask (`torch.BoolTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid computing on padding token indices. Mask values selected in `[0, 1]`:
            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.
        audio_codes (`torch.LongTensor` of shape `(num_quantizers, batch_size, sequence_length)`, *optional*):
            Discrete code embeddings computed using `model.encode`.
        num_quantizers (`int`, *optional*):
            Number of quantizers (codebooks) to use. By default, all quantizers are used.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.

        Examples:

        ```python
        >>> import torch
        >>> from transformers import MossAudioTokenizerModel

        >>> model = MossAudioTokenizerModel.from_pretrained("moss_audio_tokenizer-model")

        >>> # Create dummy audio input
        >>> audio = torch.randn(1, 1, 24000)  # 1 second of audio at 24kHz

        >>> outputs = model(input_values=audio)
        >>> audio_codes = outputs.audio_codes
        >>> audio_values = outputs.audio
        ```
        """
        return_dict = return_dict if return_dict is not None else self.config.return_dict

        output_audio_codes: torch.Tensor | None = None
        output_audio_codes_lengths: torch.Tensor | None = None
        output_audio: torch.Tensor | None = None
        output_audio_lengths: torch.Tensor | None = None
        decoded_from_encoded_codes = False

        # Encode when a waveform is supplied.
        if input_values is not None:
            encoder_output = self.encode(input_values, padding_mask, num_quantizers, return_dict=True)
            encoder_output = cast(MossAudioTokenizerEncoderOutput, encoder_output)
            output_audio_codes = encoder_output.audio_codes
            output_audio_codes_lengths = encoder_output.audio_codes_lengths

            # Round-trip: decode the codes just produced unless the caller
            # supplied explicit codes.
            if audio_codes is None:
                audio_codes = output_audio_codes
                decoded_from_encoded_codes = True

        # Decode when codes are available (supplied or freshly encoded).
        if audio_codes is not None:
            # For freshly encoded codes, reuse their exact lengths; the
            # waveform `padding_mask` does not apply to the code axis.
            if decoded_from_encoded_codes and output_audio_codes_lengths is not None:
                decoder_output = self._decode_frame(audio_codes, output_audio_codes_lengths)
            else:
                decoder_output = self.decode(
                    audio_codes,
                    padding_mask=padding_mask,
                    return_dict=True,
                    num_quantizers=num_quantizers,
                )
            decoder_output = cast(MossAudioTokenizerDecoderOutput, decoder_output)
            output_audio = decoder_output.audio
            output_audio_lengths = decoder_output.audio_lengths

        if not return_dict:
            return (output_audio_codes, output_audio, output_audio_lengths)

        return MossAudioTokenizerOutput(
            audio=output_audio,
            audio_lengths=output_audio_lengths,
            audio_codes=output_audio_codes,
            audio_codes_lengths=output_audio_codes_lengths,
        )
| |
|
| |
|
| | __all__ = ["MossAudioTokenizerModel", "MossAudioTokenizerPreTrainedModel"] |
| |
|