import random
from math import ceil
from typing import List

import torch
import torch.nn.functional as F
import torch.distributed as dist
from torch import nn
from torch.nn import Module
from torch.amp import autocast

from einx import get_at
from einops import rearrange, reduce, pack, unpack

from sparktts.modules.fsq.finite_scalar_quantization import FSQ

# helper functions


def exists(val):
    return val is not None


def first(l):
    return l[0]


def default(val, d):
    return val if exists(val) else d


def round_up_multiple(num, mult):
    return ceil(num / mult) * mult


# distributed helpers


def is_distributed():
    return dist.is_initialized() and dist.get_world_size() > 1


def get_maybe_sync_seed(device, max_size=10_000):
    # draw a random seed and, if running distributed, all-reduce it so every rank
    # agrees on the same value (used to keep quantize-dropout in sync across ranks)
    rand_int = torch.randint(0, max_size, (), device=device)

    if is_distributed():
        dist.all_reduce(rand_int)

    return rand_int.item()
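

# residual FSQ: each quantizer encodes the residual left over by the previous ones,
# with the i-th quantizer operating at scale (levels - 1) ** -i, so later layers
# refine progressively finer detail (Algorithm 1 of the SoundStream paper, adapted
# to finite scalar quantization)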
class ResidualFSQ(Module):
    """Follows Algorithm 1. in https://arxiv.org/pdf/2107.03312.pdf"""

    def __init__(
        self,
        *,
        levels: List[int],
        num_quantizers,
        dim=None,
        is_channel_first=False,
        quantize_dropout=False,
        quantize_dropout_cutoff_index=0,
        quantize_dropout_multiple_of=1,
        **kwargs,
    ):
        super().__init__()
        codebook_dim = len(levels)

        dim = default(dim, codebook_dim)

        requires_projection = codebook_dim != dim
        self.project_in = (
            nn.Linear(dim, codebook_dim) if requires_projection else nn.Identity()
        )
        self.project_out = (
            nn.Linear(codebook_dim, dim) if requires_projection else nn.Identity()
        )
        self.has_projections = requires_projection

        self.is_channel_first = is_channel_first
        self.num_quantizers = num_quantizers

        self.levels = levels
        self.layers = nn.ModuleList([])

        levels_tensor = torch.Tensor(levels)

        scales = []

        for ind in range(num_quantizers):
            scales.append((levels_tensor - 1) ** -ind)

            fsq = FSQ(levels=levels, dim=codebook_dim, **kwargs)

            self.layers.append(fsq)

        assert all([not fsq.has_projections for fsq in self.layers])

        self.codebook_size = self.layers[0].codebook_size

        self.register_buffer("scales", torch.stack(scales), persistent=False)
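        # e.g. with levels = [4, 4] and num_quantizers = 3, scales is
        # [[1, 1], [1/3, 1/3], [1/9, 1/9]]: each layer divides its residual by the
        # scale before quantization and multiplies it back afterwards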

        self.quantize_dropout = quantize_dropout and num_quantizers > 1

        assert quantize_dropout_cutoff_index >= 0

        self.quantize_dropout_cutoff_index = quantize_dropout_cutoff_index
        self.quantize_dropout_multiple_of = quantize_dropout_multiple_of  # encodec paper proposes structured dropout, believe this was set to 4

    @property
    def codebooks(self):
        codebooks = [layer.implicit_codebook for layer in self.layers]
        codebooks = torch.stack(codebooks, dim=0)
        return codebooks

    def get_codes_from_indices(self, indices):
        batch, quantize_dim = indices.shape[0], indices.shape[-1]

        # may also receive indices in the shape of 'b h w q' (accept_image_fmap)

        indices, ps = pack([indices], "b * q")

        # because of quantize dropout, one can pass in indices that are coarse
        # and the network should be able to reconstruct

        if quantize_dim < self.num_quantizers:
            assert (
                self.quantize_dropout > 0.0
            ), "quantize dropout must be greater than 0 if you wish to reconstruct from a signal with less fine quantizations"
            indices = F.pad(indices, (0, self.num_quantizers - quantize_dim), value=-1)

        # take care of quantizer dropout

        mask = indices == -1
        indices = indices.masked_fill(
            mask, 0
        )  # have it fetch a dummy code to be masked out later

        all_codes = get_at("q [c] d, b n q -> q b n d", self.codebooks, indices)

        # mask out any codes that were dropout-ed

        all_codes = all_codes.masked_fill(rearrange(mask, "b n q -> q b n 1"), 0.0)

        # scale the codes

        scales = rearrange(self.scales, "q d -> q 1 1 d")
        all_codes = all_codes * scales

        # if (accept_image_fmap = True) then return shape (quantize, batch, height, width, dimension)

        (all_codes,) = unpack(all_codes, ps, "q b * d")

        return all_codes
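
    # summing the per-quantizer codes over the quantizer dimension and projecting out
    # reproduces the quantized output of the forward pass (see the check in __main__ below)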
    def get_output_from_indices(self, indices):
        codes = self.get_codes_from_indices(indices)
        codes_summed = reduce(codes, "q ... -> ...", "sum")
        return self.project_out(codes_summed)

    def forward(
        self, x, return_all_codes=False, rand_quantize_dropout_fixed_seed=None
    ):
        num_quant, quant_dropout_multiple_of, device = (
            self.num_quantizers,
            self.quantize_dropout_multiple_of,
            x.device,
        )

        # handle channel first

        if self.is_channel_first:
            x = rearrange(x, "b d ... -> b ... d")
            x, ps = pack([x], "b * d")

        # maybe project in

        x = self.project_in(x)

        quantized_out = 0.0
        residual = x

        all_indices = []

        should_quantize_dropout = self.training and self.quantize_dropout

        # sample a layer index at which to dropout further residual quantization
        # also prepare null indices

        if should_quantize_dropout:
            # check if seed is manually passed in

            if not exists(rand_quantize_dropout_fixed_seed):
                rand_quantize_dropout_fixed_seed = get_maybe_sync_seed(device)

            rand = random.Random(rand_quantize_dropout_fixed_seed)

            rand_quantize_dropout_index = rand.randrange(
                self.quantize_dropout_cutoff_index, num_quant
            )

            if quant_dropout_multiple_of != 1:
                rand_quantize_dropout_index = (
                    round_up_multiple(
                        rand_quantize_dropout_index + 1, quant_dropout_multiple_of
                    )
                    - 1
                )

            null_indices = torch.full(
                x.shape[:2], -1, device=device, dtype=torch.long
            )
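
            # all layers past the dropout index will emit these null (-1) indices;
            # get_codes_from_indices treats -1 as a dropped position and masks it out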

        # go through the layers

        with autocast("cuda", enabled=False):
            for quantizer_index, (layer, scale) in enumerate(
                zip(self.layers, self.scales)
            ):
                if (
                    should_quantize_dropout
                    and quantizer_index > rand_quantize_dropout_index
                ):
                    all_indices.append(null_indices)
                    continue

                quantized, indices = layer(residual / scale)

                quantized = quantized * scale

                residual = residual - quantized.detach()
                quantized_out = quantized_out + quantized

                all_indices.append(indices)

        # project out, if needed

        quantized_out = self.project_out(quantized_out)

        # stack all indices

        all_indices = torch.stack(all_indices, dim=-1)

        # channel first out

        if self.is_channel_first:
            (quantized_out,) = unpack(quantized_out, ps, "b * d")
            (all_indices,) = unpack(all_indices, ps, "b * d")

            quantized_out = rearrange(quantized_out, "b ... d -> b d ...")
            all_indices = rearrange(all_indices, "b ... d -> b d ...")

        # return

        ret = (quantized_out, all_indices)

        if not return_all_codes:
            return ret

        # whether to return all codes from all codebooks across layers

        all_codes = self.get_codes_from_indices(all_indices)

        # will return all codes in shape (quantizer, batch, sequence length, codebook dimension)

        return (*ret, all_codes)


# grouped residual fsq
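# GroupedResidualFSQ splits the feature dimension into `groups` equal chunks and runs
# an independent ResidualFSQ on each, concatenating the quantized chunks on the way out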


class GroupedResidualFSQ(Module):
    def __init__(self, *, dim, groups=1, accept_image_fmap=False, **kwargs):
        super().__init__()
        self.dim = dim
        self.groups = groups
        assert (dim % groups) == 0
        dim_per_group = dim // groups

        self.accept_image_fmap = accept_image_fmap

        self.rvqs = nn.ModuleList([])

        for _ in range(groups):
            self.rvqs.append(ResidualFSQ(dim=dim_per_group, **kwargs))

        self.codebook_size = self.rvqs[0].codebook_size

    @property
    def codebooks(self):
        return torch.stack(tuple(rvq.codebooks for rvq in self.rvqs))

    @property
    def split_dim(self):
        return 1 if self.accept_image_fmap else -1

    def get_codes_from_indices(self, indices):
        codes = tuple(
            rvq.get_codes_from_indices(chunk_indices)
            for rvq, chunk_indices in zip(self.rvqs, indices)
        )
        return torch.stack(codes)

    def get_output_from_indices(self, indices):
        outputs = tuple(
            rvq.get_output_from_indices(chunk_indices)
            for rvq, chunk_indices in zip(self.rvqs, indices)
        )
        return torch.cat(outputs, dim=self.split_dim)

    def forward(self, x, return_all_codes=False):
        shape, split_dim, device = x.shape, self.split_dim, x.device
        assert shape[split_dim] == self.dim

        # split the feature dimension into groups

        x = x.chunk(self.groups, dim=split_dim)

        forward_kwargs = dict(
            return_all_codes=return_all_codes,
            rand_quantize_dropout_fixed_seed=(
                get_maybe_sync_seed(device) if self.training else None
            ),
        )

        # invoke residual vq on each group

        out = tuple(rvq(chunk, **forward_kwargs) for rvq, chunk in zip(self.rvqs, x))
        out = tuple(zip(*out))

        # combine the zipped outputs across groups

        quantized, all_indices, *maybe_all_codes = out

        quantized = torch.cat(quantized, dim=split_dim)
        all_indices = torch.stack(all_indices)

        ret = (quantized, all_indices, *maybe_all_codes)
        return ret


if __name__ == "__main__":
    model = ResidualFSQ(
        levels=[4, 4, 4, 4, 4, 4],
        num_quantizers=1,
        dim=30,
        is_channel_first=True,
        quantize_dropout=False,
    )
    x = torch.randn(2, 30, 10)

    quantize, embed_ind = model(x)

    emb_from_ind = model.get_output_from_indices(embed_ind.transpose(1, 2))

    print(quantize == emb_from_ind.transpose(1, 2))

    print("quantize shape", quantize.shape)
    print("embed_ind", embed_ind)