| from typing import Optional, Union | |
| import torch | |
| import diffusers | |
class LCMScheduler(diffusers.schedulers.LCMScheduler):
    """LCM scheduler that can use a fixed, precomputed timestep schedule.

    If ``timesteps_step_map`` is provided, ``set_timesteps`` uses the timestep
    list stored in that map for the requested number of inference steps.
    Otherwise all scheduling is delegated to the parent diffusers class.
    """

    def __init__(self, timesteps_step_map: Optional[dict] = None, **kwargs) -> None:
        """Create the scheduler.

        Args:
            timesteps_step_map: Optional mapping from a number of inference
                steps to the sequence of integer timesteps to use for it.
                ``None`` means the parent class computes the schedule.
            **kwargs: Forwarded unchanged to the parent ``LCMScheduler``.
        """
        super().__init__(**kwargs)
        # NOTE(review): stored as a plain attribute, not passed to the parent
        # constructor; confirm whether it needs to survive config save/load.
        self.timesteps_step_map = timesteps_step_map

    def set_timesteps(self, num_inference_steps: Optional[int] = None,
                      device: Optional[Union[str, torch.device]] = None, **kwargs) -> None:
        """Set the discrete timesteps used for the denoising chain.

        Args:
            num_inference_steps: Number of inference steps. Required when a
                ``timesteps_step_map`` is configured, where it selects the
                schedule to use.
            device: Device on which to place ``self.timesteps``.
            **kwargs: Forwarded to the parent ``set_timesteps`` when no map is
                configured. They are ignored when the map is used.

        Raises:
            ValueError: If a map is configured and ``num_inference_steps`` is
                ``None``, or if a mapped timestep is outside
                ``[0, num_train_timesteps)``.
            KeyError: If the map has no entry for ``num_inference_steps``.
        """
        if self.timesteps_step_map is None:
            super().set_timesteps(num_inference_steps=num_inference_steps, device=device, **kwargs)
            return

        # Explicit exceptions rather than `assert`, which is stripped under `python -O`.
        if num_inference_steps is None:
            raise ValueError("`num_inference_steps` must be provided when `timesteps_step_map` is set.")

        if num_inference_steps in self.timesteps_step_map:
            timesteps = self.timesteps_step_map[num_inference_steps]
        elif str(num_inference_steps) in self.timesteps_step_map:
            # Maps deserialized from JSON have string keys.
            timesteps = self.timesteps_step_map[str(num_inference_steps)]
        else:
            raise KeyError(
                f"No timestep schedule for num_inference_steps={num_inference_steps}; "
                f"available: {sorted(self.timesteps_step_map, key=str)}"
            )

        num_train_timesteps = self.config.num_train_timesteps
        # Negative values would silently index from the end of per-timestep tables.
        if not all(0 <= timestep < num_train_timesteps for timestep in timesteps):
            raise ValueError(
                f"All timesteps must lie in [0, {num_train_timesteps}); got {list(timesteps)}."
            )

        # Use the length of the schedule actually applied. The parent's step()
        # presumably compares step_index with num_inference_steps - 1 to detect
        # the final (noise-free) step, so the two must agree.
        self.num_inference_steps = len(timesteps)
        self.timesteps = torch.tensor(timesteps).to(device=device, dtype=torch.long)
        # Reset step bookkeeping as the parent's set_timesteps does, so the
        # next step() call re-derives its position in the new schedule.
        self._step_index = None
        self._begin_index = None