Spaces:
Running
on
Zero
Running
on
Zero
| from typing import Optional, Tuple | |
| import torch | |
| from diffusers.models.normalization import AdaLayerNorm, AdaLayerNormContinuous, AdaLayerNormZero, SD35AdaLayerNormZeroX | |
class TruncAdaLayerNorm(AdaLayerNorm):
    """``AdaLayerNorm`` whose conditioning inputs are cut to the batch size of ``x``.

    ``forward`` slices ``timestep`` and ``temb`` to their first ``x.shape[0]``
    rows and then delegates to ``self.forward_old``. ``forward_old`` is not
    defined in this class; presumably it is the original
    ``AdaLayerNorm.forward`` saved by patching code elsewhere before this
    ``forward`` is installed — TODO confirm.
    """

    def forward(
        self, x: torch.Tensor, timestep: Optional[torch.Tensor] = None, temb: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """Trim ``timestep`` / ``temb`` to ``x``'s batch size and call ``forward_old``.

        Args:
            x: Input tensor; its first dimension is taken as the batch size.
            timestep: Optional conditioning tensor. It is truncated along dim 0
                unless it is ``None`` or 0-d.
            temb: Optional conditioning embedding, truncated the same way.

        Returns:
            Whatever ``self.forward_old`` returns.
        """
        batch_size = x.shape[0]

        # Kept local rather than as a class member: if this ``forward`` is
        # patched onto the base class, ``self`` may not be a TruncAdaLayerNorm.
        def _trunc(t: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
            # A 0-d tensor has no leading dimension to slice (``t[:n]`` raises),
            # so it is passed through unchanged, as is ``None``.
            if t is None or t.ndim == 0:
                return t
            return t[:batch_size]

        # Pass both inputs by keyword. Positionally, the slot after ``x`` is
        # ``timestep`` (see this method's signature), so a positional ``temb``
        # would be received as ``timestep``. Assumes ``forward_old`` shares
        # this signature — TODO confirm.
        # NOTE(review): if ``temb`` can be an unbatched 1-D vector, slicing
        # dim 0 would cut features rather than batch rows — confirm callers
        # only pass a batched ``temb``.
        return self.forward_old(x, timestep=_trunc(timestep), temb=_trunc(temb))
class TruncAdaLayerNormContinuous(AdaLayerNormContinuous):
    """``AdaLayerNormContinuous`` that clips its conditioning to the batch of ``x``.

    The heavy lifting is delegated to ``self.forward_old``, which this class
    does not define — presumably the original ``AdaLayerNormContinuous.forward``
    stashed by patching code elsewhere (TODO confirm).
    """

    def forward(self, x: torch.Tensor, conditioning_embedding: torch.Tensor) -> torch.Tensor:
        """Keep the first ``x.shape[0]`` rows of ``conditioning_embedding``, then normalize.

        Args:
            x: Input tensor; its leading dimension sets how many rows are kept.
            conditioning_embedding: Conditioning tensor sliced along dim 0.

        Returns:
            Whatever ``self.forward_old`` returns.
        """
        leading = x.shape[0]
        clipped_cond = conditioning_embedding[:leading]
        return self.forward_old(x, clipped_cond)
class TruncAdaLayerNormZero(AdaLayerNormZero):
    """``AdaLayerNormZero`` variant that trims its conditioning to the batch of ``x``.

    Each tensor-valued conditioning argument (``timestep``, ``class_labels``,
    ``emb``) is cut to its first ``x.shape[0]`` rows; ``hidden_dtype`` is
    forwarded untouched. The work is delegated to ``self.forward_old``, which
    this class does not define — presumably the original
    ``AdaLayerNormZero.forward`` saved by patching code elsewhere (TODO confirm).
    """

    def forward(
        self,
        x: torch.Tensor,
        timestep: Optional[torch.Tensor] = None,
        class_labels: Optional[torch.LongTensor] = None,
        hidden_dtype: Optional[torch.dtype] = None,
        emb: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Trim the conditioning inputs to ``x``'s batch and call ``forward_old``.

        Args:
            x: Input tensor; its leading dimension is the number of rows kept.
            timestep: Optional tensor, sliced along dim 0 when given.
            class_labels: Optional tensor, sliced along dim 0 when given.
            hidden_dtype: Passed through unchanged.
            emb: Optional tensor, sliced along dim 0 when given.

        Returns:
            Whatever ``self.forward_old`` returns (annotated as a 5-tuple of tensors).
        """
        rows = x.shape[0]

        def clip(value: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
            # ``None`` means "not supplied" and is forwarded as-is.
            return None if value is None else value[:rows]

        return self.forward_old(x, clip(timestep), clip(class_labels), hidden_dtype, clip(emb))
class TruncSD35AdaLayerNormZeroX(SD35AdaLayerNormZeroX):
    """``SD35AdaLayerNormZeroX`` that trims ``emb`` to the batch of ``hidden_states``.

    Delegates to ``self.forward_old``, which this class does not define —
    presumably the original ``SD35AdaLayerNormZeroX.forward`` saved by
    patching code elsewhere (TODO confirm).
    """

    def forward(
        self,
        hidden_states: torch.Tensor,
        emb: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, ...]:
        """Keep the first ``hidden_states.shape[0]`` rows of ``emb`` and call ``forward_old``.

        Args:
            hidden_states: Input tensor; its leading dimension sets the row count.
            emb: Optional conditioning tensor, sliced along dim 0 when given.

        Returns:
            Whatever ``self.forward_old`` returns.
        """
        n_rows = hidden_states.shape[0]
        if emb is None:
            trimmed_emb = None
        else:
            trimmed_emb = emb[:n_rows]
        return self.forward_old(hidden_states, trimmed_emb)