|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import torch |
|
|
import torch.nn as nn |
|
|
|
|
|
|
|
|
class ModLN(nn.Module):
    """
    Modulation with adaLN: LayerNorm followed by a conditioning-dependent
    per-feature affine transform ``norm(x) * (1 + scale) + shift``.

    ``shift`` and ``scale`` are produced from the conditioning vector ``mod``
    by ``SiLU -> Linear(mod_dim, 2 * inner_dim)`` and split in half along the
    last axis.

    References:
        DiT: https://github.com/facebookresearch/DiT/blob/main/models.py#L101
    """

    def __init__(self, inner_dim: int, mod_dim: int, eps: float):
        """
        Args:
            inner_dim: Size of the last dimension of ``x`` (the normalized features).
            mod_dim: Size of the last dimension of the conditioning tensor ``mod``.
            eps: Epsilon passed to ``nn.LayerNorm`` for numerical stability.
        """
        super().__init__()
        self.norm = nn.LayerNorm(inner_dim, eps=eps)
        self.mlp = nn.Sequential(
            nn.SiLU(),
            nn.Linear(mod_dim, inner_dim * 2),
        )

    @staticmethod
    def modulate(x, shift, scale):
        """
        Apply ``x * (1 + scale) + shift`` with shift/scale broadcast over x.

        ``shift`` and ``scale`` are expected to share the batch dimension
        (dim 0) and the feature dimension (last dim) with ``x``. If they have
        fewer dimensions than ``x`` (e.g. ``(B, D)`` against ``(B, N, D)``),
        singleton axes are inserted right after the batch dim so the same
        shift/scale is applied at every intermediate position. If they
        already have ``x``'s rank (per-position modulation), they are used
        as-is.

        Args:
            x: Input tensor, presumably ``(B, ..., D)``.
            shift: Additive term, ``(B, D)`` or same rank as ``x``.
            scale: Multiplicative term (applied as ``1 + scale``), same
                shape convention as ``shift``.

        Returns:
            Tensor broadcast from ``x``, ``shift`` and ``scale`` — the shape
            of ``x`` in the expected use.
        """

        def _align(t):
            # Insert axes after the batch dim only as far as needed.
            # An unconditional unsqueeze(1) would turn a (B, D) input or
            # per-token (B, N, D) shift/scale into a silent (B, B, ...)
            # cross-product.
            for _ in range(x.dim() - t.dim()):
                t = t.unsqueeze(1)
            return t

        return x * (1 + _align(scale)) + _align(shift)

    def forward(self, x: torch.Tensor, mod: torch.Tensor) -> torch.Tensor:
        """
        Normalize ``x`` and modulate it with shift/scale derived from ``mod``.

        Args:
            x: Input features; last dimension must equal ``inner_dim``.
            mod: Conditioning tensor; last dimension must equal ``mod_dim``.
                Typically ``(B, mod_dim)`` — TODO confirm against callers.

        Returns:
            The modulated, layer-normalized features.
        """
        # The MLP emits 2 * inner_dim features; the first half is the shift,
        # the second half is the scale.
        shift, scale = self.mlp(mod).chunk(2, dim=-1)
        return self.modulate(self.norm(x), shift, scale)
|
|
|