Spaces:
Sleeping
Sleeping
| # Copyright (c) 2023 Amphion. | |
| # | |
| # This code is modified from https://github.com/ZhangXInFD/SpeechTokenizer/blob/main/speechtokenizer/model.py | |
| # Licensed under Apache License 2.0 | |
| from .modules.seanet import SEANetEncoder, SEANetDecoder | |
| from .modules.quantization import ResidualVectorQuantizer | |
| import torch.nn as nn | |
| from einops import rearrange | |
| import torch | |
| import numpy as np | |
class SpeechTokenizer(nn.Module):
    """SEANet encoder/decoder with a residual vector quantizer (RVQ) bottleneck.

    The encoder output is quantized by ``self.quantizer`` and reconstructed by
    ``self.decoder``. ``forward`` also returns one RVQ layer's quantized output
    projected by ``self.transform`` to ``semantic_dimension`` features.
    """

    def __init__(self, config):
        """
        Parameters
        ----------
        config : dict
            Model config (e.g. a parsed JSON config file). Keys read here:
            ``n_filters``, ``dimension``, ``strides``, ``lstm_layers``,
            ``bidirectional``, ``dilation_base``, ``residual_kernel_size``,
            ``n_residual_layers``, ``activation``, ``sample_rate``, ``n_q``,
            ``semantic_dimension`` and ``codebook_size``.
        """
        super().__init__()
        self.encoder = SEANetEncoder(
            n_filters=config.get("n_filters"),
            dimension=config.get("dimension"),
            ratios=config.get("strides"),
            lstm=config.get("lstm_layers"),
            bidirectional=config.get("bidirectional"),
            dilation_base=config.get("dilation_base"),
            residual_kernel_size=config.get("residual_kernel_size"),
            n_residual_layers=config.get("n_residual_layers"),
            activation=config.get("activation"),
        )
        self.sample_rate = config.get("sample_rate")
        self.n_q = config.get("n_q")
        # Product of the encoder strides; presumably the number of input
        # samples per code frame (assumes SEANet downsamples by each ratio).
        self.downsample_rate = np.prod(config.get("strides"))
        # Project quantized features to the semantic feature size; skip the
        # projection entirely when the sizes already match.
        if config.get("dimension") != config.get("semantic_dimension"):
            self.transform = nn.Linear(
                config.get("dimension"), config.get("semantic_dimension")
            )
        else:
            self.transform = nn.Identity()
        self.quantizer = ResidualVectorQuantizer(
            dimension=config.get("dimension"),
            n_q=config.get("n_q"),
            bins=config.get("codebook_size"),
        )
        self.decoder = SEANetDecoder(
            n_filters=config.get("n_filters"),
            dimension=config.get("dimension"),
            ratios=config.get("strides"),
            lstm=config.get("lstm_layers"),
            # The decoder is always built unidirectional, independent of
            # the config's "bidirectional" flag used for the encoder.
            bidirectional=False,
            dilation_base=config.get("dilation_base"),
            residual_kernel_size=config.get("residual_kernel_size"),
            n_residual_layers=config.get("n_residual_layers"),
            activation=config.get("activation"),
        )

    @classmethod
    def load_from_checkpoint(cls, config_path: str, ckpt_path: str):
        """Build a model from a JSON config file and load its weights.

        Parameters
        ----------
        config_path : str
            Path of the JSON model configuration file.
        ckpt_path : str
            Path of a checkpoint containing a state dict for this model.

        Returns
        -------
        model : SpeechTokenizer
            A ``cls`` instance with the checkpoint weights loaded (on CPU).
        """
        import json

        with open(config_path, encoding="utf-8") as f:
            cfg = json.load(f)
        model = cls(cfg)
        # NOTE(review): torch.load unpickles the file. Only load checkpoints
        # from trusted sources (or pass weights_only=True on torch >= 1.13).
        params = torch.load(ckpt_path, map_location="cpu")
        model.load_state_dict(params)
        return model

    def forward(self, x: torch.Tensor, n_q: int = None, layers: list = None):
        """
        Parameters
        ----------
        x : torch.Tensor
            Input wavs. Shape: (batch, channels, timesteps).
        n_q : int, optional
            Number of quantizers in RVQ used to encode. The default
            (None or 0) is ``self.n_q``.
        layers : list[int], optional
            Layers of RVQ that should return quantized results. The default
            is ``[0]``, i.e. the first layer.

        Returns
        -------
        o : torch.Tensor
            Output wavs. Shape: (batch, channels, timesteps).
        commit_loss : torch.Tensor
            Commitment loss from residual vector quantizers.
        feature : torch.Tensor
            Quantized output of the first layer listed in ``layers``, passed
            through ``self.transform``.
            Shape: (batch, timesteps, semantic_dimension).
        """
        n_q = n_q if n_q else self.n_q
        # Build a fresh list per call instead of sharing a mutable default.
        layers = [0] if layers is None else layers
        e = self.encoder(x)
        quantized, codes, commit_loss, quantized_list = self.quantizer(
            e, n_q=n_q, layers=layers
        )
        # quantized_list[0] is (batch, dimension, timesteps). Move channels
        # last so the Linear transform acts on the feature dimension.
        feature = rearrange(quantized_list[0], "b d t -> b t d")
        feature = self.transform(feature)
        o = self.decoder(quantized)
        return o, commit_loss, feature

    def forward_feature(self, x: torch.Tensor, layers: list = None):
        """
        Parameters
        ----------
        x : torch.Tensor
            Input wavs. Shape should be (batch, channels, timesteps).
        layers : list[int], optional
            Layers of RVQ that should return quantized results. The default
            (None or an empty list) is all ``self.n_q`` layers.

        Returns
        -------
        quantized_list : list[torch.Tensor]
            Quantized outputs of the requested layers.
        """
        e = self.encoder(x)
        layers = layers if layers else list(range(self.n_q))
        quantized, codes, commit_loss, quantized_list = self.quantizer(e, layers=layers)
        return quantized_list

    def encode(self, x: torch.Tensor, n_q: int = None, st: int = None):
        """
        Parameters
        ----------
        x : torch.Tensor
            Input wavs. Shape: (batch, channels, timesteps).
        n_q : int, optional
            Number of quantizers in RVQ used to encode. The default
            (None or 0) is ``self.n_q``.
        st : int, optional
            Start quantizer index in RVQ. The default is 0.

        Returns
        -------
        codes : torch.Tensor
            Output indices for each quantizer. Shape: (n_q, batch, timesteps)
        """
        e = self.encoder(x)
        st = 0 if st is None else st
        n_q = n_q if n_q else self.n_q
        codes = self.quantizer.encode(e, n_q=n_q, st=st)
        return codes

    def decode(self, codes: torch.Tensor, st: int = 0):
        """
        Parameters
        ----------
        codes : torch.Tensor
            Indices for each quantizer. Shape: (n_q, batch, timesteps).
        st : int, optional
            Start quantizer index in RVQ. The default is 0.

        Returns
        -------
        o : torch.Tensor
            Reconstructed wavs from codes. Shape: (batch, channels, timesteps)
        """
        quantized = self.quantizer.decode(codes, st=st)
        o = self.decoder(quantized)
        return o