#!/usr/bin/env python3
# Copyright 2024 Xiaomi Corp. (authors: Han Zhu)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional, Union

import torch

class DiffusionModel(torch.nn.Module):
    """A wrapper of diffusion models for inference.

    Args:
        model: The diffusion model.
        func_name: The function name to call.
    """

    def __init__(
        self,
        model: torch.nn.Module,
        func_name: str = "forward_fm_decoder",
    ):
        super().__init__()
        self.model = model
        self.func_name = func_name
        self.model_func = getattr(self.model, func_name)

    def forward(
        self,
        t: torch.Tensor,
        x: torch.Tensor,
        text_condition: torch.Tensor,
        speech_condition: torch.Tensor,
        padding_mask: Optional[torch.Tensor] = None,
        guidance_scale: Union[float, torch.Tensor] = 0.0,
        **kwargs
    ) -> torch.Tensor:
| """ | |
| Forward function that Handles the classifier-free guidance. | |
| Args: | |
| t: The current timestep, a tensor of a tensor of a single float. | |
| x: The initial value, with the shape (batch, seq_len, emb_dim). | |
| text_condition: The text_condition of the diffision model, with | |
| the shape (batch, seq_len, emb_dim). | |
| speech_condition: The speech_condition of the diffision model, with the | |
| shape (batch, seq_len, emb_dim). | |
| padding_mask: The mask for padding; True means masked position, with the | |
| shape (batch, seq_len). | |
| guidance_scale: The scale of classifier-free guidance, a float or a tensor | |
| of shape (batch, 1, 1). | |
| Retrun: | |
| The prediction with the shape (batch, seq_len, emb_dim). | |
| """ | |
        if not torch.is_tensor(guidance_scale):
            guidance_scale = torch.tensor(
                guidance_scale, dtype=t.dtype, device=t.device
            )

        if (guidance_scale == 0.0).all():
            return self.model_func(
                t=t,
                xt=x,
                text_condition=text_condition,
                speech_condition=speech_condition,
                padding_mask=padding_mask,
                **kwargs
            )
        else:
            assert t.dim() == 0
            # Duplicate the batch: the first half is the unconditional branch,
            # the second half is the conditional branch.
            x = torch.cat([x] * 2, dim=0)
            padding_mask = torch.cat([padding_mask] * 2, dim=0)
            text_condition = torch.cat(
                [torch.zeros_like(text_condition), text_condition], dim=0
            )
            if t > 0.5:
                # At large t, drop the speech condition in the unconditional branch.
                speech_condition = torch.cat(
                    [torch.zeros_like(speech_condition), speech_condition], dim=0
                )
            else:
                # At small t, keep the speech condition in both branches and
                # double the guidance scale.
                guidance_scale = guidance_scale * 2
                speech_condition = torch.cat(
                    [speech_condition, speech_condition], dim=0
                )
            data_uncond, data_cond = self.model_func(
                t=t,
                xt=x,
                text_condition=text_condition,
                speech_condition=speech_condition,
                padding_mask=padding_mask,
                **kwargs
            ).chunk(2, dim=0)
            # Classifier-free guidance combination.
            res = (1 + guidance_scale) * data_cond - guidance_scale * data_uncond
            return res
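
# A minimal sketch (not part of the original file) of how the classifier-free
# guidance combination above behaves: with dummy conditional / unconditional
# predictions, the output is extrapolated away from the unconditional branch.
# The tensor shapes and the guidance scale value here are illustrative
# assumptions only.
#
#     w = torch.tensor(1.5)                        # guidance scale
#     data_cond = torch.full((1, 4, 2), 1.0)       # "conditional" prediction
#     data_uncond = torch.full((1, 4, 2), 0.5)     # "unconditional" prediction
#     res = (1 + w) * data_cond - w * data_uncond  # == 1.75 everywhere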


class DistillDiffusionModel(DiffusionModel):
    """A wrapper of distilled diffusion models for inference.

    Args:
        model: The distilled diffusion model.
        func_name: The function name to call.
    """

    def __init__(
        self,
        model: torch.nn.Module,
        func_name: str = "forward_fm_decoder",
    ):
        super().__init__(model=model, func_name=func_name)

    def forward(
        self,
        t: torch.Tensor,
        x: torch.Tensor,
        text_condition: torch.Tensor,
        speech_condition: torch.Tensor,
        padding_mask: Optional[torch.Tensor] = None,
        guidance_scale: Union[float, torch.Tensor] = 0.0,
        **kwargs
    ) -> torch.Tensor:
| """ | |
| Forward function that Handles the classifier-free guidance. | |
| Args: | |
| t: The current timestep, a tensor of a single float. | |
| x: The initial value, with the shape (batch, seq_len, emb_dim). | |
| text_condition: The text_condition of the diffision model, with | |
| the shape (batch, seq_len, emb_dim). | |
| speech_condition: The speech_condition of the diffision model, with the | |
| shape (batch, seq_len, emb_dim). | |
| padding_mask: The mask for padding; True means masked position, with the | |
| shape (batch, seq_len). | |
| guidance_scale: The scale of classifier-free guidance, a float or a tensor | |
| of shape (batch, 1, 1). | |
| Retrun: | |
| The prediction with the shape (batch, seq_len, emb_dim). | |
| """ | |
        if not torch.is_tensor(guidance_scale):
            guidance_scale = torch.tensor(
                guidance_scale, dtype=t.dtype, device=t.device
            )
        return self.model_func(
            t=t,
            xt=x,
            text_condition=text_condition,
            speech_condition=speech_condition,
            padding_mask=padding_mask,
            guidance_scale=guidance_scale,
            **kwargs
        )


class EulerSolver:
    def __init__(
        self,
        model: torch.nn.Module,
        func_name: str = "forward_fm_decoder",
    ):
        """Construct an Euler solver.

        Args:
            model: The diffusion model.
            func_name: The function name to call.
        """
        self.model = DiffusionModel(model, func_name=func_name)

    def sample(
        self,
        x: torch.Tensor,
        text_condition: torch.Tensor,
        speech_condition: torch.Tensor,
        padding_mask: torch.Tensor,
        num_step: int = 10,
        guidance_scale: Union[float, torch.Tensor] = 0.0,
        t_start: float = 0.0,
        t_end: float = 1.0,
        t_shift: float = 1.0,
        **kwargs
    ) -> torch.Tensor:
        """
        Compute the sample at time `t_end` with the Euler solver.

        Args:
            x: The initial value at time `t_start`, with the shape
                (batch, seq_len, emb_dim).
            text_condition: The text condition of the diffusion model, with the
                shape (batch, seq_len, emb_dim).
            speech_condition: The speech condition of the diffusion model, with
                the shape (batch, seq_len, emb_dim).
            padding_mask: The mask for padding; True means masked position, with
                the shape (batch, seq_len).
            num_step: The number of ODE steps.
            guidance_scale: The scale for classifier-free guidance, a float or a
                tensor with the shape (batch, 1, 1).
            t_start: The start timestep, in the range [0, 1].
            t_end: The end timestep, in the range [0, 1].
            t_shift: Shift the timesteps toward smaller values so that the
                sampling emphasizes the low-SNR region. Should be in the range
                (0, 1]; the shift is more significant when the value is smaller.

        Returns:
            The approximated solution at time `t_end`.
        """
        device = x.device
        assert isinstance(t_start, float) and isinstance(t_end, float)

        timesteps = get_time_steps(
            t_start=t_start,
            t_end=t_end,
            num_step=num_step,
            t_shift=t_shift,
            device=device,
        )
        for step in range(num_step):
            # Predict the velocity field at the current timestep.
            v = self.model(
                t=timesteps[step],
                x=x,
                text_condition=text_condition,
                speech_condition=speech_condition,
                padding_mask=padding_mask,
                guidance_scale=guidance_scale,
                **kwargs
            )
            # One explicit Euler step: x <- x + v * (t_{k+1} - t_k).
            x = x + v * (timesteps[step + 1] - timesteps[step])
        return x
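
# A minimal usage sketch (an assumption, not part of the original file): it
# presumes a trained model exposing `forward_fm_decoder(t, xt, text_condition,
# speech_condition, padding_mask, ...)` and illustrative tensor shapes.
#
#     solver = EulerSolver(model, func_name="forward_fm_decoder")
#     x0 = torch.randn(batch, seq_len, emb_dim)    # noise at t_start = 0
#     y = solver.sample(
#         x=x0,
#         text_condition=text_cond,                # (batch, seq_len, emb_dim)
#         speech_condition=speech_cond,            # (batch, seq_len, emb_dim)
#         padding_mask=mask,                       # (batch, seq_len), True = padded
#         num_step=16,
#         guidance_scale=1.0,
#         t_shift=0.5,
#     )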


class DistillEulerSolver(EulerSolver):
    def __init__(
        self,
        model: torch.nn.Module,
        func_name: str = "forward_fm_decoder",
    ):
        """Construct an Euler solver for distilled diffusion models.

        Args:
            model: The distilled diffusion model.
            func_name: The function name to call.
        """
        self.model = DistillDiffusionModel(model, func_name=func_name)


def get_time_steps(
    t_start: float = 0.0,
    t_end: float = 1.0,
    num_step: int = 10,
    t_shift: float = 1.0,
    device: torch.device = torch.device("cpu"),
) -> torch.Tensor:
    """Compute the intermediate timesteps for sampling.

    Args:
        t_start: The start time of the sampling (default is 0).
        t_end: The end time of the sampling (default is 1).
        num_step: The number of sampling steps.
        t_shift: Shift the timesteps toward smaller values so that the sampling
            emphasizes the low-SNR region. Should be in the range (0, 1]; the
            shift is more significant when the value is smaller.
        device: A torch device.

    Returns:
        The timesteps, with the shape (num_step + 1,).
    """
    timesteps = torch.linspace(t_start, t_end, num_step + 1).to(device)
    # Rational shift: t <- t_shift * t / (1 + (t_shift - 1) * t), which maps
    # [0, 1] onto [0, 1] and pushes interior timesteps toward 0 when t_shift < 1.
    timesteps = t_shift * timesteps / (1 + (t_shift - 1) * timesteps)
    return timesteps
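

# A quick, self-contained check (an illustrative addition, not in the original
# file) of how `t_shift` skews the timestep grid toward the low-SNR (small t)
# region; the printed grids compare t_shift = 1.0 with t_shift = 0.5.
if __name__ == "__main__":
    uniform = get_time_steps(num_step=4, t_shift=1.0)
    shifted = get_time_steps(num_step=4, t_shift=0.5)
    # t_shift=1.0 keeps the uniform grid:          [0.00, 0.25, 0.50, 0.75, 1.00]
    # t_shift=0.5 moves interior points toward 0:  [0.00, 0.143, 0.333, 0.60, 1.00]
    print("uniform:", uniform.tolist())
    print("shifted:", shifted.tolist())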