from __future__ import annotations |
|
|
|
|
|
import importlib.util |
|
|
import math |
|
|
from collections.abc import Sequence |
|
|
|
|
|
import torch |
|
|
import torch.nn.functional as F |
|
|
from monai.networks.blocks import Convolution, MLPBlock |
|
|
from monai.networks.layers.factories import Pool |
|
|
from monai.utils import ensure_tuple_rep |
|
|
from torch import nn |
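# xformers is an optional dependency: when it is installed, its memory-efficient attention
# kernels back the `use_flash_attention=True` code paths below.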
|
|
|
|
|
|
|
|
if importlib.util.find_spec("xformers") is not None: |
|
|
import xformers |
|
|
import xformers.ops |
|
|
|
|
|
has_xformers = True |
|
|
else: |
|
|
xformers = None |
|
|
has_xformers = False |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
__all__ = ["DiffusionModelUNet"] |
|
|
|
|
|
|
|
|
def zero_module(module: nn.Module) -> nn.Module: |
|
|
""" |
|
|
Zero out the parameters of a module and return it. |
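    Typically applied to the last layer of a residual branch so that the branch initially contributes nothing.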
|
|
""" |
|
|
for p in module.parameters(): |
|
|
p.detach().zero_() |
|
|
return module |
|
|
|
|
|
|
|
|
class CrossAttention(nn.Module): |
|
|
""" |
|
|
A cross attention layer. |
|
|
|
|
|
Args: |
|
|
query_dim: number of channels in the query. |
|
|
cross_attention_dim: number of channels in the context. |
|
|
num_attention_heads: number of heads to use for multi-head attention. |
|
|
num_head_channels: number of channels in each head. |
|
|
dropout: dropout probability to use. |
|
|
upcast_attention: if True, upcast attention operations to full precision. |
|
|
use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. |
|
|
""" |
|
|
|
|
|
def __init__( |
|
|
self, |
|
|
query_dim: int, |
|
|
cross_attention_dim: int | None = None, |
|
|
num_attention_heads: int = 8, |
|
|
num_head_channels: int = 64, |
|
|
dropout: float = 0.0, |
|
|
upcast_attention: bool = False, |
|
|
use_flash_attention: bool = False, |
|
|
) -> None: |
|
|
super().__init__() |
|
|
self.use_flash_attention = use_flash_attention |
|
|
inner_dim = num_head_channels * num_attention_heads |
|
|
cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim |
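        # scaled dot-product attention: query-key scores are scaled by 1 / sqrt(num_head_channels)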
|
|
|
|
|
self.scale = 1 / math.sqrt(num_head_channels) |
|
|
self.num_heads = num_attention_heads |
|
|
|
|
|
self.upcast_attention = upcast_attention |
|
|
|
|
|
self.to_q = nn.Linear(query_dim, inner_dim, bias=False) |
|
|
self.to_k = nn.Linear(cross_attention_dim, inner_dim, bias=False) |
|
|
self.to_v = nn.Linear(cross_attention_dim, inner_dim, bias=False) |
|
|
|
|
|
self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)) |
|
|
|
|
|
def reshape_heads_to_batch_dim(self, x: torch.Tensor) -> torch.Tensor: |
|
|
""" |
|
|
        Split the hidden-state dimension across the attention heads and fold the heads into the batch dimension.
|
|
""" |
|
|
batch_size, seq_len, dim = x.shape |
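        # (batch, seq, heads * dim_head) -> (batch * heads, seq, dim_head)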
|
|
x = x.reshape(batch_size, seq_len, self.num_heads, dim // self.num_heads) |
|
|
x = x.permute(0, 2, 1, 3).reshape(batch_size * self.num_heads, seq_len, dim // self.num_heads) |
|
|
return x |
|
|
|
|
|
def reshape_batch_dim_to_heads(self, x: torch.Tensor) -> torch.Tensor: |
|
|
"""Combine the output of the attention heads back into the hidden state dimension.""" |
|
|
batch_size, seq_len, dim = x.shape |
|
|
x = x.reshape(batch_size // self.num_heads, self.num_heads, seq_len, dim) |
|
|
x = x.permute(0, 2, 1, 3).reshape(batch_size // self.num_heads, seq_len, dim * self.num_heads) |
|
|
return x |
|
|
|
|
|
def _memory_efficient_attention_xformers( |
|
|
self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor |
|
|
) -> torch.Tensor: |
|
|
query = query.contiguous() |
|
|
key = key.contiguous() |
|
|
value = value.contiguous() |
|
|
x = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=None) |
|
|
return x |
|
|
|
|
|
def _attention(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> torch.Tensor: |
|
|
dtype = query.dtype |
|
|
if self.upcast_attention: |
|
|
query = query.float() |
|
|
key = key.float() |
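        # baddbmm with beta=0 computes alpha * (query @ key^T); the empty tensor only fixes shape, dtype and device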
|
|
|
|
|
attention_scores = torch.baddbmm( |
|
|
torch.empty(query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device), |
|
|
query, |
|
|
key.transpose(-1, -2), |
|
|
beta=0, |
|
|
alpha=self.scale, |
|
|
) |
|
|
attention_probs = attention_scores.softmax(dim=-1) |
|
|
attention_probs = attention_probs.to(dtype=dtype) |
|
|
|
|
|
x = torch.bmm(attention_probs, value) |
|
|
return x |
|
|
|
|
|
def forward(self, x: torch.Tensor, context: torch.Tensor | None = None) -> torch.Tensor: |
|
|
query = self.to_q(x) |
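        # fall back to self-attention when no conditioning context is provided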
|
|
context = context if context is not None else x |
|
|
key = self.to_k(context) |
|
|
value = self.to_v(context) |
|
|
|
|
|
|
|
|
query = self.reshape_heads_to_batch_dim(query) |
|
|
key = self.reshape_heads_to_batch_dim(key) |
|
|
value = self.reshape_heads_to_batch_dim(value) |
|
|
|
|
|
if self.use_flash_attention: |
|
|
x = self._memory_efficient_attention_xformers(query, key, value) |
|
|
else: |
|
|
x = self._attention(query, key, value) |
|
|
|
|
|
x = self.reshape_batch_dim_to_heads(x) |
|
|
x = x.to(query.dtype) |
|
|
|
|
|
return self.to_out(x) |
|
|
|
|
|
|
|
|
class BasicTransformerBlock(nn.Module): |
|
|
""" |
|
|
A basic Transformer block. |
|
|
|
|
|
Args: |
|
|
num_channels: number of channels in the input and output. |
|
|
num_attention_heads: number of heads to use for multi-head attention. |
|
|
num_head_channels: number of channels in each attention head. |
|
|
dropout: dropout probability to use. |
|
|
cross_attention_dim: size of the context vector for cross attention. |
|
|
upcast_attention: if True, upcast attention operations to full precision. |
|
|
use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. |
|
|
""" |
|
|
|
|
|
def __init__( |
|
|
self, |
|
|
num_channels: int, |
|
|
num_attention_heads: int, |
|
|
num_head_channels: int, |
|
|
dropout: float = 0.0, |
|
|
cross_attention_dim: int | None = None, |
|
|
upcast_attention: bool = False, |
|
|
use_flash_attention: bool = False, |
|
|
) -> None: |
|
|
super().__init__() |
|
|
self.attn1 = CrossAttention( |
|
|
query_dim=num_channels, |
|
|
num_attention_heads=num_attention_heads, |
|
|
num_head_channels=num_head_channels, |
|
|
dropout=dropout, |
|
|
upcast_attention=upcast_attention, |
|
|
use_flash_attention=use_flash_attention, |
|
|
) |
|
|
self.ff = MLPBlock(hidden_size=num_channels, mlp_dim=num_channels * 4, act="GEGLU", dropout_rate=dropout) |
|
|
self.attn2 = CrossAttention( |
|
|
query_dim=num_channels, |
|
|
cross_attention_dim=cross_attention_dim, |
|
|
num_attention_heads=num_attention_heads, |
|
|
num_head_channels=num_head_channels, |
|
|
dropout=dropout, |
|
|
upcast_attention=upcast_attention, |
|
|
use_flash_attention=use_flash_attention, |
|
|
) |
|
|
self.norm1 = nn.LayerNorm(num_channels) |
|
|
self.norm2 = nn.LayerNorm(num_channels) |
|
|
self.norm3 = nn.LayerNorm(num_channels) |
|
|
|
|
|
def forward(self, x: torch.Tensor, context: torch.Tensor | None = None) -> torch.Tensor: |
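        # pre-norm transformer block: self-attention, then cross-attention over `context`, then a GEGLU
        # feed-forward, each wrapped in a residual connection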
|
|
|
|
|
x = self.attn1(self.norm1(x)) + x |
|
|
|
|
|
|
|
|
x = self.attn2(self.norm2(x), context=context) + x |
|
|
|
|
|
|
|
|
x = self.ff(self.norm3(x)) + x |
|
|
return x |
|
|
|
|
|
|
|
|
class SpatialTransformer(nn.Module): |
|
|
""" |
|
|
    Transformer block for image-like data. First, project the input (i.e. the feature map) and reshape it to
    (batch, sequence, channels). Then apply standard transformer layers. Finally, reshape back to an image.
|
|
|
|
|
Args: |
|
|
spatial_dims: number of spatial dimensions. |
|
|
in_channels: number of channels in the input and output. |
|
|
num_attention_heads: number of heads to use for multi-head attention. |
|
|
num_head_channels: number of channels in each attention head. |
|
|
num_layers: number of layers of Transformer blocks to use. |
|
|
dropout: dropout probability to use. |
|
|
norm_num_groups: number of groups for the normalization. |
|
|
norm_eps: epsilon for the normalization. |
|
|
cross_attention_dim: number of context dimensions to use. |
|
|
upcast_attention: if True, upcast attention operations to full precision. |
|
|
use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. |
|
|
""" |
|
|
|
|
|
def __init__( |
|
|
self, |
|
|
spatial_dims: int, |
|
|
in_channels: int, |
|
|
num_attention_heads: int, |
|
|
num_head_channels: int, |
|
|
num_layers: int = 1, |
|
|
dropout: float = 0.0, |
|
|
norm_num_groups: int = 32, |
|
|
norm_eps: float = 1e-6, |
|
|
cross_attention_dim: int | None = None, |
|
|
upcast_attention: bool = False, |
|
|
use_flash_attention: bool = False, |
|
|
) -> None: |
|
|
super().__init__() |
|
|
self.spatial_dims = spatial_dims |
|
|
self.in_channels = in_channels |
|
|
inner_dim = num_attention_heads * num_head_channels |
|
|
|
|
|
self.norm = nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=norm_eps, affine=True) |
|
|
|
|
|
self.proj_in = Convolution( |
|
|
spatial_dims=spatial_dims, |
|
|
in_channels=in_channels, |
|
|
out_channels=inner_dim, |
|
|
strides=1, |
|
|
kernel_size=1, |
|
|
padding=0, |
|
|
conv_only=True, |
|
|
) |
|
|
|
|
|
self.transformer_blocks = nn.ModuleList( |
|
|
[ |
|
|
BasicTransformerBlock( |
|
|
num_channels=inner_dim, |
|
|
num_attention_heads=num_attention_heads, |
|
|
num_head_channels=num_head_channels, |
|
|
dropout=dropout, |
|
|
cross_attention_dim=cross_attention_dim, |
|
|
upcast_attention=upcast_attention, |
|
|
use_flash_attention=use_flash_attention, |
|
|
) |
|
|
for _ in range(num_layers) |
|
|
] |
|
|
) |
|
|
|
|
|
self.proj_out = zero_module( |
|
|
Convolution( |
|
|
spatial_dims=spatial_dims, |
|
|
in_channels=inner_dim, |
|
|
out_channels=in_channels, |
|
|
strides=1, |
|
|
kernel_size=1, |
|
|
padding=0, |
|
|
conv_only=True, |
|
|
) |
|
|
) |
|
|
|
|
|
def forward(self, x: torch.Tensor, context: torch.Tensor | None = None) -> torch.Tensor: |
|
|
|
|
|
batch = channel = height = width = depth = -1 |
|
|
if self.spatial_dims == 2: |
|
|
batch, channel, height, width = x.shape |
|
|
if self.spatial_dims == 3: |
|
|
batch, channel, height, width, depth = x.shape |
|
|
|
|
|
residual = x |
|
|
x = self.norm(x) |
|
|
x = self.proj_in(x) |
|
|
|
|
|
inner_dim = x.shape[1] |
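        # flatten the spatial dimensions into a sequence: (b, c, *spatial) -> (b, num_positions, c)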
|
|
|
|
|
if self.spatial_dims == 2: |
|
|
x = x.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim) |
|
|
if self.spatial_dims == 3: |
|
|
x = x.permute(0, 2, 3, 4, 1).reshape(batch, height * width * depth, inner_dim) |
|
|
|
|
|
for block in self.transformer_blocks: |
|
|
x = block(x, context=context) |
|
|
|
|
|
if self.spatial_dims == 2: |
|
|
x = x.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous() |
|
|
if self.spatial_dims == 3: |
|
|
x = x.reshape(batch, height, width, depth, inner_dim).permute(0, 4, 1, 2, 3).contiguous() |
|
|
|
|
|
x = self.proj_out(x) |
|
|
return x + residual |
|
|
|
|
|
|
|
|
class AttentionBlock(nn.Module): |
|
|
""" |
|
|
An attention block that allows spatial positions to attend to each other. Uses three q, k, v linear layers to |
|
|
compute attention. |
|
|
|
|
|
Args: |
|
|
spatial_dims: number of spatial dimensions. |
|
|
num_channels: number of input channels. |
|
|
num_head_channels: number of channels in each attention head. |
|
|
norm_num_groups: number of groups involved for the group normalisation layer. Ensure that your number of |
|
|
channels is divisible by this number. |
|
|
norm_eps: epsilon value to use for the normalisation. |
|
|
use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. |
|
|
""" |
|
|
|
|
|
def __init__( |
|
|
self, |
|
|
spatial_dims: int, |
|
|
num_channels: int, |
|
|
num_head_channels: int | None = None, |
|
|
norm_num_groups: int = 32, |
|
|
norm_eps: float = 1e-6, |
|
|
use_flash_attention: bool = False, |
|
|
) -> None: |
|
|
super().__init__() |
|
|
self.use_flash_attention = use_flash_attention |
|
|
self.spatial_dims = spatial_dims |
|
|
self.num_channels = num_channels |
|
|
|
|
|
self.num_heads = num_channels // num_head_channels if num_head_channels is not None else 1 |
|
|
self.scale = 1 / math.sqrt(num_channels / self.num_heads) |
|
|
|
|
|
self.norm = nn.GroupNorm(num_groups=norm_num_groups, num_channels=num_channels, eps=norm_eps, affine=True) |
|
|
|
|
|
self.to_q = nn.Linear(num_channels, num_channels) |
|
|
self.to_k = nn.Linear(num_channels, num_channels) |
|
|
self.to_v = nn.Linear(num_channels, num_channels) |
|
|
|
|
|
self.proj_attn = nn.Linear(num_channels, num_channels) |
|
|
|
|
|
def reshape_heads_to_batch_dim(self, x: torch.Tensor) -> torch.Tensor: |
|
|
batch_size, seq_len, dim = x.shape |
|
|
x = x.reshape(batch_size, seq_len, self.num_heads, dim // self.num_heads) |
|
|
x = x.permute(0, 2, 1, 3).reshape(batch_size * self.num_heads, seq_len, dim // self.num_heads) |
|
|
return x |
|
|
|
|
|
def reshape_batch_dim_to_heads(self, x: torch.Tensor) -> torch.Tensor: |
|
|
batch_size, seq_len, dim = x.shape |
|
|
x = x.reshape(batch_size // self.num_heads, self.num_heads, seq_len, dim) |
|
|
x = x.permute(0, 2, 1, 3).reshape(batch_size // self.num_heads, seq_len, dim * self.num_heads) |
|
|
return x |
|
|
|
|
|
def _memory_efficient_attention_xformers( |
|
|
self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor |
|
|
) -> torch.Tensor: |
|
|
query = query.contiguous() |
|
|
key = key.contiguous() |
|
|
value = value.contiguous() |
|
|
x = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=None) |
|
|
return x |
|
|
|
|
|
def _attention(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> torch.Tensor: |
|
|
attention_scores = torch.baddbmm( |
|
|
torch.empty(query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device), |
|
|
query, |
|
|
key.transpose(-1, -2), |
|
|
beta=0, |
|
|
alpha=self.scale, |
|
|
) |
|
|
attention_probs = attention_scores.softmax(dim=-1) |
|
|
x = torch.bmm(attention_probs, value) |
|
|
return x |
|
|
|
|
|
def forward(self, x: torch.Tensor) -> torch.Tensor: |
|
|
residual = x |
|
|
|
|
|
batch = channel = height = width = depth = -1 |
|
|
if self.spatial_dims == 2: |
|
|
batch, channel, height, width = x.shape |
|
|
if self.spatial_dims == 3: |
|
|
batch, channel, height, width, depth = x.shape |
|
|
|
|
|
|
|
|
x = self.norm(x) |
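        # flatten spatial dimensions so attention runs over spatial positions: (b, c, *spatial) -> (b, num_positions, c)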
|
|
|
|
|
if self.spatial_dims == 2: |
|
|
x = x.view(batch, channel, height * width).transpose(1, 2) |
|
|
if self.spatial_dims == 3: |
|
|
x = x.view(batch, channel, height * width * depth).transpose(1, 2) |
|
|
|
|
|
|
|
|
query = self.to_q(x) |
|
|
key = self.to_k(x) |
|
|
value = self.to_v(x) |
|
|
|
|
|
|
|
|
query = self.reshape_heads_to_batch_dim(query) |
|
|
key = self.reshape_heads_to_batch_dim(key) |
|
|
value = self.reshape_heads_to_batch_dim(value) |
|
|
|
|
|
if self.use_flash_attention: |
|
|
x = self._memory_efficient_attention_xformers(query, key, value) |
|
|
else: |
|
|
x = self._attention(query, key, value) |
|
|
|
|
|
x = self.reshape_batch_dim_to_heads(x) |
|
|
x = x.to(query.dtype) |
|
|
|
|
|
if self.spatial_dims == 2: |
|
|
x = x.transpose(-1, -2).reshape(batch, channel, height, width) |
|
|
if self.spatial_dims == 3: |
|
|
x = x.transpose(-1, -2).reshape(batch, channel, height, width, depth) |
|
|
|
|
|
return x + residual |
|
|
|
|
|
|
|
|
def get_timestep_embedding(timesteps: torch.Tensor, embedding_dim: int, max_period: int = 10000) -> torch.Tensor: |
|
|
""" |
|
|
Create sinusoidal timestep embeddings following the implementation in Ho et al. "Denoising Diffusion Probabilistic |
|
|
Models" https://arxiv.org/abs/2006.11239. |
|
|
|
|
|
Args: |
|
|
timesteps: a 1-D Tensor of N indices, one per batch element. |
|
|
embedding_dim: the dimension of the output. |
|
|
max_period: controls the minimum frequency of the embeddings. |
|
|
""" |
|
|
if timesteps.ndim != 1: |
|
|
raise ValueError("Timesteps should be a 1d-array") |
|
|
|
|
|
half_dim = embedding_dim // 2 |
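    # geometric frequency progression: freqs[i] = max_period ** (-i / half_dim)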
|
|
exponent = -math.log(max_period) * torch.arange(start=0, end=half_dim, dtype=torch.float32, device=timesteps.device) |
|
|
freqs = torch.exp(exponent / half_dim) |
|
|
|
|
|
args = timesteps[:, None].float() * freqs[None, :] |
|
|
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) |
|
|
|
|
|
|
|
|
if embedding_dim % 2 == 1: |
|
|
embedding = torch.nn.functional.pad(embedding, (0, 1, 0, 0)) |
|
|
|
|
|
return embedding |
|
|
|
|
|
|
|
|
class Downsample(nn.Module): |
|
|
""" |
|
|
Downsampling layer. |
|
|
|
|
|
Args: |
|
|
spatial_dims: number of spatial dimensions. |
|
|
num_channels: number of input channels. |
|
|
        use_conv: if True uses Convolution instead of average pooling to perform downsampling. If use_conv is
            False, the number of output channels must be the same as the number of input channels.
|
|
out_channels: number of output channels. |
|
|
padding: controls the amount of implicit zero-paddings on both sides for padding number of points |
|
|
for each dimension. |
|
|
""" |
|
|
|
|
|
def __init__( |
|
|
self, spatial_dims: int, num_channels: int, use_conv: bool, out_channels: int | None = None, padding: int = 1 |
|
|
) -> None: |
|
|
super().__init__() |
|
|
self.num_channels = num_channels |
|
|
self.out_channels = out_channels or num_channels |
|
|
self.use_conv = use_conv |
|
|
if use_conv: |
|
|
self.op = Convolution( |
|
|
spatial_dims=spatial_dims, |
|
|
in_channels=self.num_channels, |
|
|
out_channels=self.out_channels, |
|
|
strides=2, |
|
|
kernel_size=3, |
|
|
padding=padding, |
|
|
conv_only=True, |
|
|
) |
|
|
else: |
|
|
if self.num_channels != self.out_channels: |
|
|
raise ValueError("num_channels and out_channels must be equal when use_conv=False") |
|
|
self.op = Pool[Pool.AVG, spatial_dims](kernel_size=2, stride=2) |
|
|
|
|
|
def forward(self, x: torch.Tensor, emb: torch.Tensor | None = None) -> torch.Tensor: |
|
|
del emb |
|
|
if x.shape[1] != self.num_channels: |
|
|
raise ValueError( |
|
|
f"Input number of channels ({x.shape[1]}) is not equal to expected number of channels " |
|
|
f"({self.num_channels})" |
|
|
) |
|
|
return self.op(x) |
|
|
|
|
|
|
|
|
class Upsample(nn.Module): |
|
|
""" |
|
|
Upsampling layer with an optional convolution. |
|
|
|
|
|
Args: |
|
|
spatial_dims: number of spatial dimensions. |
|
|
num_channels: number of input channels. |
|
|
        use_conv: if True applies a Convolution after the nearest-neighbour interpolation used for upsampling.
|
|
out_channels: number of output channels. |
|
|
padding: controls the amount of implicit zero-paddings on both sides for padding number of points for each |
|
|
dimension. |
|
|
""" |
|
|
|
|
|
def __init__( |
|
|
self, spatial_dims: int, num_channels: int, use_conv: bool, out_channels: int | None = None, padding: int = 1 |
|
|
) -> None: |
|
|
super().__init__() |
|
|
self.num_channels = num_channels |
|
|
self.out_channels = out_channels or num_channels |
|
|
self.use_conv = use_conv |
|
|
if use_conv: |
|
|
self.conv = Convolution( |
|
|
spatial_dims=spatial_dims, |
|
|
in_channels=self.num_channels, |
|
|
out_channels=self.out_channels, |
|
|
strides=1, |
|
|
kernel_size=3, |
|
|
padding=padding, |
|
|
conv_only=True, |
|
|
) |
|
|
else: |
|
|
self.conv = None |
|
|
|
|
|
def forward(self, x: torch.Tensor, emb: torch.Tensor | None = None) -> torch.Tensor: |
|
|
del emb |
|
|
if x.shape[1] != self.num_channels: |
|
|
raise ValueError("Input channels should be equal to num_channels") |
|
|
|
|
|
|
|
|
|
|
|
dtype = x.dtype |
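        # assumed workaround: nearest-neighbour interpolation has lacked a bfloat16 kernel in some PyTorch
        # versions, so upsample in float32 and cast back afterwards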
|
|
if dtype == torch.bfloat16: |
|
|
x = x.to(torch.float32) |
|
|
|
|
|
x = F.interpolate(x, scale_factor=2.0, mode="nearest") |
|
|
|
|
|
|
|
|
if dtype == torch.bfloat16: |
|
|
x = x.to(dtype) |
|
|
|
|
|
if self.use_conv: |
|
|
x = self.conv(x) |
|
|
return x |
|
|
|
|
|
|
|
|
class ResnetBlock(nn.Module): |
|
|
""" |
|
|
Residual block with timestep conditioning. |
|
|
|
|
|
Args: |
|
|
spatial_dims: The number of spatial dimensions. |
|
|
in_channels: number of input channels. |
|
|
temb_channels: number of timestep embedding channels. |
|
|
out_channels: number of output channels. |
|
|
up: if True, performs upsampling. |
|
|
down: if True, performs downsampling. |
|
|
norm_num_groups: number of groups for the group normalization. |
|
|
norm_eps: epsilon for the group normalization. |
|
|
""" |
|
|
|
|
|
def __init__( |
|
|
self, |
|
|
spatial_dims: int, |
|
|
in_channels: int, |
|
|
temb_channels: int, |
|
|
out_channels: int | None = None, |
|
|
up: bool = False, |
|
|
down: bool = False, |
|
|
norm_num_groups: int = 32, |
|
|
norm_eps: float = 1e-6, |
|
|
) -> None: |
|
|
super().__init__() |
|
|
self.spatial_dims = spatial_dims |
|
|
self.channels = in_channels |
|
|
self.emb_channels = temb_channels |
|
|
self.out_channels = out_channels or in_channels |
|
|
self.up = up |
|
|
self.down = down |
|
|
|
|
|
self.norm1 = nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=norm_eps, affine=True) |
|
|
self.nonlinearity = nn.SiLU() |
|
|
self.conv1 = Convolution( |
|
|
spatial_dims=spatial_dims, |
|
|
in_channels=in_channels, |
|
|
out_channels=self.out_channels, |
|
|
strides=1, |
|
|
kernel_size=3, |
|
|
padding=1, |
|
|
conv_only=True, |
|
|
|
|
) |
|
|
|
|
|
self.upsample = self.downsample = None |
|
|
if self.up: |
|
|
self.upsample = Upsample(spatial_dims, in_channels, use_conv=False) |
|
|
elif down: |
|
|
self.downsample = Downsample(spatial_dims, in_channels, use_conv=False) |
|
|
|
|
|
self.time_emb_proj = nn.Linear(temb_channels, self.out_channels) |
|
|
|
|
|
self.norm2 = nn.GroupNorm(num_groups=norm_num_groups, num_channels=self.out_channels, eps=norm_eps, affine=True) |
|
|
self.conv2 = zero_module( |
|
|
Convolution( |
|
|
spatial_dims=spatial_dims, |
|
|
in_channels=self.out_channels, |
|
|
out_channels=self.out_channels, |
|
|
strides=1, |
|
|
kernel_size=3, |
|
|
padding=1, |
|
|
conv_only=True, |
|
|
|
|
) |
|
|
) |
|
|
|
|
|
if self.out_channels == in_channels: |
|
|
self.skip_connection = nn.Identity() |
|
|
else: |
|
|
self.skip_connection = Convolution( |
|
|
spatial_dims=spatial_dims, |
|
|
in_channels=in_channels, |
|
|
out_channels=self.out_channels, |
|
|
strides=1, |
|
|
kernel_size=1, |
|
|
padding=0, |
|
|
conv_only=True, |
|
|
|
|
) |
|
|
|
|
|
def forward(self, x: torch.Tensor, emb: torch.Tensor) -> torch.Tensor: |
|
|
h = x |
|
|
h = self.norm1(h) |
|
|
h = self.nonlinearity(h) |
|
|
|
|
|
if self.upsample is not None: |
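            # assumed workaround: nearest-neighbour upsampling of non-contiguous tensors can fail for large
            # batches, hence the contiguity calls guarded by the batch-size check below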
|
|
if h.shape[0] >= 64: |
|
|
x = x.contiguous() |
|
|
h = h.contiguous() |
|
|
x = self.upsample(x) |
|
|
h = self.upsample(h) |
|
|
elif self.downsample is not None: |
|
|
x = self.downsample(x) |
|
|
h = self.downsample(h) |
|
|
|
|
|
h = self.conv1(h) |
|
|
|
|
|
if self.spatial_dims == 2: |
|
|
temb = self.time_emb_proj(self.nonlinearity(emb))[:, :, None, None] |
|
|
else: |
|
|
temb = self.time_emb_proj(self.nonlinearity(emb))[:, :, None, None, None] |
|
|
h = h + temb |
|
|
|
|
|
h = self.norm2(h) |
|
|
h = self.nonlinearity(h) |
|
|
h = self.conv2(h) |
|
|
|
|
|
return self.skip_connection(x) + h |
|
|
|
|
|
|
|
|
class DownBlock(nn.Module): |
|
|
""" |
|
|
    Unet's down block containing resnet and downsampler blocks.
|
|
|
|
|
Args: |
|
|
spatial_dims: The number of spatial dimensions. |
|
|
in_channels: number of input channels. |
|
|
out_channels: number of output channels. |
|
|
temb_channels: number of timestep embedding channels. |
|
|
num_res_blocks: number of residual blocks. |
|
|
norm_num_groups: number of groups for the group normalization. |
|
|
norm_eps: epsilon for the group normalization. |
|
|
add_downsample: if True add downsample block. |
|
|
resblock_updown: if True use residual blocks for downsampling. |
|
|
downsample_padding: padding used in the downsampling block. |
|
|
""" |
|
|
|
|
|
def __init__( |
|
|
self, |
|
|
spatial_dims: int, |
|
|
in_channels: int, |
|
|
out_channels: int, |
|
|
temb_channels: int, |
|
|
num_res_blocks: int = 1, |
|
|
norm_num_groups: int = 32, |
|
|
norm_eps: float = 1e-6, |
|
|
add_downsample: bool = True, |
|
|
resblock_updown: bool = False, |
|
|
downsample_padding: int = 1, |
|
|
) -> None: |
|
|
super().__init__() |
|
|
self.resblock_updown = resblock_updown |
|
|
|
|
|
resnets = [] |
|
|
|
|
|
for i in range(num_res_blocks): |
|
|
in_channels = in_channels if i == 0 else out_channels |
|
|
resnets.append( |
|
|
ResnetBlock( |
|
|
spatial_dims=spatial_dims, |
|
|
in_channels=in_channels, |
|
|
out_channels=out_channels, |
|
|
temb_channels=temb_channels, |
|
|
norm_num_groups=norm_num_groups, |
|
|
norm_eps=norm_eps, |
|
|
) |
|
|
) |
|
|
|
|
|
self.resnets = nn.ModuleList(resnets) |
|
|
|
|
|
if add_downsample: |
|
|
if resblock_updown: |
|
|
self.downsampler = ResnetBlock( |
|
|
spatial_dims=spatial_dims, |
|
|
in_channels=out_channels, |
|
|
out_channels=out_channels, |
|
|
temb_channels=temb_channels, |
|
|
norm_num_groups=norm_num_groups, |
|
|
norm_eps=norm_eps, |
|
|
down=True, |
|
|
) |
|
|
else: |
|
|
self.downsampler = Downsample( |
|
|
spatial_dims=spatial_dims, |
|
|
num_channels=out_channels, |
|
|
use_conv=True, |
|
|
out_channels=out_channels, |
|
|
padding=downsample_padding, |
|
|
) |
|
|
else: |
|
|
self.downsampler = None |
|
|
|
|
|
def forward( |
|
|
self, hidden_states: torch.Tensor, temb: torch.Tensor, context: torch.Tensor | None = None |
|
|
) -> tuple[torch.Tensor, list[torch.Tensor]]: |
|
|
del context |
|
|
output_states = [] |
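        # collect every intermediate feature map; the decoder later concatenates these as skip connections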
|
|
|
|
|
for resnet in self.resnets: |
|
|
hidden_states = resnet(hidden_states, temb) |
|
|
output_states.append(hidden_states) |
|
|
|
|
|
if self.downsampler is not None: |
|
|
hidden_states = self.downsampler(hidden_states, temb) |
|
|
output_states.append(hidden_states) |
|
|
|
|
|
return hidden_states, output_states |
|
|
|
|
|
|
|
|
class AttnDownBlock(nn.Module): |
|
|
""" |
|
|
Unet's down block containing resnet, downsamplers and self-attention blocks. |
|
|
|
|
|
Args: |
|
|
spatial_dims: The number of spatial dimensions. |
|
|
in_channels: number of input channels. |
|
|
out_channels: number of output channels. |
|
|
temb_channels: number of timestep embedding channels. |
|
|
num_res_blocks: number of residual blocks. |
|
|
norm_num_groups: number of groups for the group normalization. |
|
|
norm_eps: epsilon for the group normalization. |
|
|
add_downsample: if True add downsample block. |
|
|
resblock_updown: if True use residual blocks for downsampling. |
|
|
downsample_padding: padding used in the downsampling block. |
|
|
num_head_channels: number of channels in each attention head. |
|
|
use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. |
|
|
""" |
|
|
|
|
|
def __init__( |
|
|
self, |
|
|
spatial_dims: int, |
|
|
in_channels: int, |
|
|
out_channels: int, |
|
|
temb_channels: int, |
|
|
num_res_blocks: int = 1, |
|
|
norm_num_groups: int = 32, |
|
|
norm_eps: float = 1e-6, |
|
|
add_downsample: bool = True, |
|
|
resblock_updown: bool = False, |
|
|
downsample_padding: int = 1, |
|
|
num_head_channels: int = 1, |
|
|
use_flash_attention: bool = False, |
|
|
) -> None: |
|
|
super().__init__() |
|
|
self.resblock_updown = resblock_updown |
|
|
|
|
|
resnets = [] |
|
|
attentions = [] |
|
|
|
|
|
for i in range(num_res_blocks): |
|
|
in_channels = in_channels if i == 0 else out_channels |
|
|
resnets.append( |
|
|
ResnetBlock( |
|
|
spatial_dims=spatial_dims, |
|
|
in_channels=in_channels, |
|
|
out_channels=out_channels, |
|
|
temb_channels=temb_channels, |
|
|
norm_num_groups=norm_num_groups, |
|
|
norm_eps=norm_eps, |
|
|
) |
|
|
) |
|
|
attentions.append( |
|
|
AttentionBlock( |
|
|
spatial_dims=spatial_dims, |
|
|
num_channels=out_channels, |
|
|
num_head_channels=num_head_channels, |
|
|
norm_num_groups=norm_num_groups, |
|
|
norm_eps=norm_eps, |
|
|
use_flash_attention=use_flash_attention, |
|
|
) |
|
|
) |
|
|
|
|
|
self.attentions = nn.ModuleList(attentions) |
|
|
self.resnets = nn.ModuleList(resnets) |
|
|
|
|
|
if add_downsample: |
|
|
if resblock_updown: |
|
|
self.downsampler = ResnetBlock( |
|
|
spatial_dims=spatial_dims, |
|
|
in_channels=out_channels, |
|
|
out_channels=out_channels, |
|
|
temb_channels=temb_channels, |
|
|
norm_num_groups=norm_num_groups, |
|
|
norm_eps=norm_eps, |
|
|
down=True, |
|
|
) |
|
|
else: |
|
|
self.downsampler = Downsample( |
|
|
spatial_dims=spatial_dims, |
|
|
num_channels=out_channels, |
|
|
use_conv=True, |
|
|
out_channels=out_channels, |
|
|
padding=downsample_padding, |
|
|
) |
|
|
else: |
|
|
self.downsampler = None |
|
|
|
|
|
def forward( |
|
|
self, hidden_states: torch.Tensor, temb: torch.Tensor, context: torch.Tensor | None = None |
|
|
) -> tuple[torch.Tensor, list[torch.Tensor]]: |
|
|
del context |
|
|
output_states = [] |
|
|
|
|
|
for resnet, attn in zip(self.resnets, self.attentions): |
|
|
hidden_states = resnet(hidden_states, temb) |
|
|
hidden_states = attn(hidden_states) |
|
|
output_states.append(hidden_states) |
|
|
|
|
|
if self.downsampler is not None: |
|
|
hidden_states = self.downsampler(hidden_states, temb) |
|
|
output_states.append(hidden_states) |
|
|
|
|
|
return hidden_states, output_states |
|
|
|
|
|
|
|
|
class CrossAttnDownBlock(nn.Module): |
|
|
""" |
|
|
Unet's down block containing resnet, downsamplers and cross-attention blocks. |
|
|
|
|
|
Args: |
|
|
spatial_dims: number of spatial dimensions. |
|
|
in_channels: number of input channels. |
|
|
out_channels: number of output channels. |
|
|
temb_channels: number of timestep embedding channels. |
|
|
num_res_blocks: number of residual blocks. |
|
|
norm_num_groups: number of groups for the group normalization. |
|
|
norm_eps: epsilon for the group normalization. |
|
|
add_downsample: if True add downsample block. |
|
|
resblock_updown: if True use residual blocks for downsampling. |
|
|
downsample_padding: padding used in the downsampling block. |
|
|
num_head_channels: number of channels in each attention head. |
|
|
transformer_num_layers: number of layers of Transformer blocks to use. |
|
|
cross_attention_dim: number of context dimensions to use. |
|
|
upcast_attention: if True, upcast attention operations to full precision. |
|
|
use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. |
|
|
""" |
|
|
|
|
|
def __init__( |
|
|
self, |
|
|
spatial_dims: int, |
|
|
in_channels: int, |
|
|
out_channels: int, |
|
|
temb_channels: int, |
|
|
num_res_blocks: int = 1, |
|
|
norm_num_groups: int = 32, |
|
|
norm_eps: float = 1e-6, |
|
|
add_downsample: bool = True, |
|
|
resblock_updown: bool = False, |
|
|
downsample_padding: int = 1, |
|
|
num_head_channels: int = 1, |
|
|
transformer_num_layers: int = 1, |
|
|
cross_attention_dim: int | None = None, |
|
|
upcast_attention: bool = False, |
|
|
use_flash_attention: bool = False, |
|
|
) -> None: |
|
|
super().__init__() |
|
|
self.resblock_updown = resblock_updown |
|
|
|
|
|
resnets = [] |
|
|
attentions = [] |
|
|
|
|
|
for i in range(num_res_blocks): |
|
|
in_channels = in_channels if i == 0 else out_channels |
|
|
resnets.append( |
|
|
ResnetBlock( |
|
|
spatial_dims=spatial_dims, |
|
|
in_channels=in_channels, |
|
|
out_channels=out_channels, |
|
|
temb_channels=temb_channels, |
|
|
norm_num_groups=norm_num_groups, |
|
|
norm_eps=norm_eps, |
|
|
) |
|
|
) |
|
|
|
|
|
attentions.append( |
|
|
SpatialTransformer( |
|
|
spatial_dims=spatial_dims, |
|
|
in_channels=out_channels, |
|
|
num_attention_heads=out_channels // num_head_channels, |
|
|
num_head_channels=num_head_channels, |
|
|
num_layers=transformer_num_layers, |
|
|
norm_num_groups=norm_num_groups, |
|
|
norm_eps=norm_eps, |
|
|
cross_attention_dim=cross_attention_dim, |
|
|
upcast_attention=upcast_attention, |
|
|
use_flash_attention=use_flash_attention, |
|
|
) |
|
|
) |
|
|
|
|
|
self.attentions = nn.ModuleList(attentions) |
|
|
self.resnets = nn.ModuleList(resnets) |
|
|
|
|
|
if add_downsample: |
|
|
if resblock_updown: |
|
|
self.downsampler = ResnetBlock( |
|
|
spatial_dims=spatial_dims, |
|
|
in_channels=out_channels, |
|
|
out_channels=out_channels, |
|
|
temb_channels=temb_channels, |
|
|
norm_num_groups=norm_num_groups, |
|
|
norm_eps=norm_eps, |
|
|
down=True, |
|
|
) |
|
|
else: |
|
|
self.downsampler = Downsample( |
|
|
spatial_dims=spatial_dims, |
|
|
num_channels=out_channels, |
|
|
use_conv=True, |
|
|
out_channels=out_channels, |
|
|
padding=downsample_padding, |
|
|
) |
|
|
else: |
|
|
self.downsampler = None |
|
|
|
|
|
def forward( |
|
|
self, hidden_states: torch.Tensor, temb: torch.Tensor, context: torch.Tensor | None = None |
|
|
) -> tuple[torch.Tensor, list[torch.Tensor]]: |
|
|
output_states = [] |
|
|
|
|
|
for resnet, attn in zip(self.resnets, self.attentions): |
|
|
hidden_states = resnet(hidden_states, temb) |
|
|
hidden_states = attn(hidden_states, context=context) |
|
|
output_states.append(hidden_states) |
|
|
|
|
|
if self.downsampler is not None: |
|
|
hidden_states = self.downsampler(hidden_states, temb) |
|
|
output_states.append(hidden_states) |
|
|
|
|
|
return hidden_states, output_states |
|
|
|
|
|
|
|
|
class AttnMidBlock(nn.Module): |
|
|
""" |
|
|
Unet's mid block containing resnet and self-attention blocks. |
|
|
|
|
|
Args: |
|
|
spatial_dims: The number of spatial dimensions. |
|
|
in_channels: number of input channels. |
|
|
temb_channels: number of timestep embedding channels. |
|
|
norm_num_groups: number of groups for the group normalization. |
|
|
norm_eps: epsilon for the group normalization. |
|
|
num_head_channels: number of channels in each attention head. |
|
|
use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. |
|
|
""" |
|
|
|
|
|
def __init__( |
|
|
self, |
|
|
spatial_dims: int, |
|
|
in_channels: int, |
|
|
temb_channels: int, |
|
|
norm_num_groups: int = 32, |
|
|
norm_eps: float = 1e-6, |
|
|
num_head_channels: int = 1, |
|
|
use_flash_attention: bool = False, |
|
|
) -> None: |
|
|
super().__init__() |
|
|
self.attention = None |
|
|
|
|
|
self.resnet_1 = ResnetBlock( |
|
|
spatial_dims=spatial_dims, |
|
|
in_channels=in_channels, |
|
|
out_channels=in_channels, |
|
|
temb_channels=temb_channels, |
|
|
norm_num_groups=norm_num_groups, |
|
|
norm_eps=norm_eps, |
|
|
) |
|
|
self.attention = AttentionBlock( |
|
|
spatial_dims=spatial_dims, |
|
|
num_channels=in_channels, |
|
|
num_head_channels=num_head_channels, |
|
|
norm_num_groups=norm_num_groups, |
|
|
norm_eps=norm_eps, |
|
|
use_flash_attention=use_flash_attention, |
|
|
) |
|
|
|
|
|
self.resnet_2 = ResnetBlock( |
|
|
spatial_dims=spatial_dims, |
|
|
in_channels=in_channels, |
|
|
out_channels=in_channels, |
|
|
temb_channels=temb_channels, |
|
|
norm_num_groups=norm_num_groups, |
|
|
norm_eps=norm_eps, |
|
|
) |
|
|
|
|
|
def forward( |
|
|
self, hidden_states: torch.Tensor, temb: torch.Tensor, context: torch.Tensor | None = None |
|
|
) -> torch.Tensor: |
|
|
del context |
|
|
hidden_states = self.resnet_1(hidden_states, temb) |
|
|
hidden_states = self.attention(hidden_states) |
|
|
hidden_states = self.resnet_2(hidden_states, temb) |
|
|
|
|
|
return hidden_states |
|
|
|
|
|
|
|
|
class CrossAttnMidBlock(nn.Module): |
|
|
""" |
|
|
Unet's mid block containing resnet and cross-attention blocks. |
|
|
|
|
|
Args: |
|
|
spatial_dims: The number of spatial dimensions. |
|
|
in_channels: number of input channels. |
|
|
temb_channels: number of timestep embedding channels |
|
|
norm_num_groups: number of groups for the group normalization. |
|
|
norm_eps: epsilon for the group normalization. |
|
|
num_head_channels: number of channels in each attention head. |
|
|
transformer_num_layers: number of layers of Transformer blocks to use. |
|
|
cross_attention_dim: number of context dimensions to use. |
|
|
upcast_attention: if True, upcast attention operations to full precision. |
|
|
use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. |
|
|
""" |
|
|
|
|
|
def __init__( |
|
|
self, |
|
|
spatial_dims: int, |
|
|
in_channels: int, |
|
|
temb_channels: int, |
|
|
norm_num_groups: int = 32, |
|
|
norm_eps: float = 1e-6, |
|
|
num_head_channels: int = 1, |
|
|
transformer_num_layers: int = 1, |
|
|
cross_attention_dim: int | None = None, |
|
|
upcast_attention: bool = False, |
|
|
use_flash_attention: bool = False, |
|
|
) -> None: |
|
|
super().__init__() |
|
|
self.attention = None |
|
|
|
|
|
self.resnet_1 = ResnetBlock( |
|
|
spatial_dims=spatial_dims, |
|
|
in_channels=in_channels, |
|
|
out_channels=in_channels, |
|
|
temb_channels=temb_channels, |
|
|
norm_num_groups=norm_num_groups, |
|
|
norm_eps=norm_eps, |
|
|
) |
|
|
self.attention = SpatialTransformer( |
|
|
spatial_dims=spatial_dims, |
|
|
in_channels=in_channels, |
|
|
num_attention_heads=in_channels // num_head_channels, |
|
|
num_head_channels=num_head_channels, |
|
|
num_layers=transformer_num_layers, |
|
|
norm_num_groups=norm_num_groups, |
|
|
norm_eps=norm_eps, |
|
|
cross_attention_dim=cross_attention_dim, |
|
|
upcast_attention=upcast_attention, |
|
|
use_flash_attention=use_flash_attention, |
|
|
) |
|
|
self.resnet_2 = ResnetBlock( |
|
|
spatial_dims=spatial_dims, |
|
|
in_channels=in_channels, |
|
|
out_channels=in_channels, |
|
|
temb_channels=temb_channels, |
|
|
norm_num_groups=norm_num_groups, |
|
|
norm_eps=norm_eps, |
|
|
) |
|
|
|
|
|
def forward( |
|
|
self, hidden_states: torch.Tensor, temb: torch.Tensor, context: torch.Tensor | None = None |
|
|
) -> torch.Tensor: |
|
|
hidden_states = self.resnet_1(hidden_states, temb) |
|
|
hidden_states = self.attention(hidden_states, context=context) |
|
|
hidden_states = self.resnet_2(hidden_states, temb) |
|
|
|
|
|
return hidden_states |
|
|
|
|
|
|
|
|
class UpBlock(nn.Module): |
|
|
""" |
|
|
    Unet's up block containing resnet and upsampler blocks.
|
|
|
|
|
Args: |
|
|
spatial_dims: The number of spatial dimensions. |
|
|
in_channels: number of input channels. |
|
|
prev_output_channel: number of channels from residual connection. |
|
|
out_channels: number of output channels. |
|
|
temb_channels: number of timestep embedding channels. |
|
|
num_res_blocks: number of residual blocks. |
|
|
norm_num_groups: number of groups for the group normalization. |
|
|
norm_eps: epsilon for the group normalization. |
|
|
        add_upsample: if True add upsample block.
|
|
resblock_updown: if True use residual blocks for upsampling. |
|
|
""" |
|
|
|
|
|
def __init__( |
|
|
self, |
|
|
spatial_dims: int, |
|
|
in_channels: int, |
|
|
prev_output_channel: int, |
|
|
out_channels: int, |
|
|
temb_channels: int, |
|
|
num_res_blocks: int = 1, |
|
|
norm_num_groups: int = 32, |
|
|
norm_eps: float = 1e-6, |
|
|
add_upsample: bool = True, |
|
|
resblock_updown: bool = False, |
|
|
) -> None: |
|
|
super().__init__() |
|
|
self.resblock_updown = resblock_updown |
|
|
resnets = [] |
|
|
|
|
|
for i in range(num_res_blocks): |
|
|
res_skip_channels = in_channels if (i == num_res_blocks - 1) else out_channels |
|
|
resnet_in_channels = prev_output_channel if i == 0 else out_channels |
|
|
|
|
|
resnets.append( |
|
|
ResnetBlock( |
|
|
spatial_dims=spatial_dims, |
|
|
in_channels=resnet_in_channels + res_skip_channels, |
|
|
out_channels=out_channels, |
|
|
temb_channels=temb_channels, |
|
|
norm_num_groups=norm_num_groups, |
|
|
norm_eps=norm_eps, |
|
|
) |
|
|
) |
|
|
|
|
|
self.resnets = nn.ModuleList(resnets) |
|
|
|
|
|
if add_upsample: |
|
|
if resblock_updown: |
|
|
self.upsampler = ResnetBlock( |
|
|
spatial_dims=spatial_dims, |
|
|
in_channels=out_channels, |
|
|
out_channels=out_channels, |
|
|
temb_channels=temb_channels, |
|
|
norm_num_groups=norm_num_groups, |
|
|
norm_eps=norm_eps, |
|
|
up=True, |
|
|
) |
|
|
else: |
|
|
self.upsampler = Upsample( |
|
|
spatial_dims=spatial_dims, num_channels=out_channels, use_conv=True, out_channels=out_channels |
|
|
) |
|
|
else: |
|
|
self.upsampler = None |
|
|
|
|
|
def forward( |
|
|
self, |
|
|
hidden_states: torch.Tensor, |
|
|
res_hidden_states_list: list[torch.Tensor], |
|
|
temb: torch.Tensor, |
|
|
context: torch.Tensor | None = None, |
|
|
) -> torch.Tensor: |
|
|
del context |
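        # consume the stored skip connections in reverse order, concatenating each one channel-wise before its resnet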
|
|
for resnet in self.resnets: |
|
|
|
|
|
res_hidden_states = res_hidden_states_list[-1] |
|
|
res_hidden_states_list = res_hidden_states_list[:-1] |
|
|
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) |
|
|
|
|
|
hidden_states = resnet(hidden_states, temb) |
|
|
|
|
|
if self.upsampler is not None: |
|
|
hidden_states = self.upsampler(hidden_states, temb) |
|
|
|
|
|
return hidden_states |
|
|
|
|
|
|
|
|
class AttnUpBlock(nn.Module): |
|
|
""" |
|
|
Unet's up block containing resnet, upsamplers, and self-attention blocks. |
|
|
|
|
|
Args: |
|
|
spatial_dims: The number of spatial dimensions. |
|
|
in_channels: number of input channels. |
|
|
prev_output_channel: number of channels from residual connection. |
|
|
out_channels: number of output channels. |
|
|
temb_channels: number of timestep embedding channels. |
|
|
num_res_blocks: number of residual blocks. |
|
|
norm_num_groups: number of groups for the group normalization. |
|
|
norm_eps: epsilon for the group normalization. |
|
|
        add_upsample: if True add upsample block.
|
|
resblock_updown: if True use residual blocks for upsampling. |
|
|
num_head_channels: number of channels in each attention head. |
|
|
use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. |
|
|
""" |
|
|
|
|
|
def __init__( |
|
|
self, |
|
|
spatial_dims: int, |
|
|
in_channels: int, |
|
|
prev_output_channel: int, |
|
|
out_channels: int, |
|
|
temb_channels: int, |
|
|
num_res_blocks: int = 1, |
|
|
norm_num_groups: int = 32, |
|
|
norm_eps: float = 1e-6, |
|
|
add_upsample: bool = True, |
|
|
resblock_updown: bool = False, |
|
|
num_head_channels: int = 1, |
|
|
use_flash_attention: bool = False, |
|
|
) -> None: |
|
|
super().__init__() |
|
|
self.resblock_updown = resblock_updown |
|
|
|
|
|
resnets = [] |
|
|
attentions = [] |
|
|
|
|
|
for i in range(num_res_blocks): |
|
|
res_skip_channels = in_channels if (i == num_res_blocks - 1) else out_channels |
|
|
resnet_in_channels = prev_output_channel if i == 0 else out_channels |
|
|
|
|
|
resnets.append( |
|
|
ResnetBlock( |
|
|
spatial_dims=spatial_dims, |
|
|
in_channels=resnet_in_channels + res_skip_channels, |
|
|
out_channels=out_channels, |
|
|
temb_channels=temb_channels, |
|
|
norm_num_groups=norm_num_groups, |
|
|
norm_eps=norm_eps, |
|
|
) |
|
|
) |
|
|
attentions.append( |
|
|
AttentionBlock( |
|
|
spatial_dims=spatial_dims, |
|
|
num_channels=out_channels, |
|
|
num_head_channels=num_head_channels, |
|
|
norm_num_groups=norm_num_groups, |
|
|
norm_eps=norm_eps, |
|
|
use_flash_attention=use_flash_attention, |
|
|
) |
|
|
) |
|
|
|
|
|
self.resnets = nn.ModuleList(resnets) |
|
|
self.attentions = nn.ModuleList(attentions) |
|
|
|
|
|
if add_upsample: |
|
|
if resblock_updown: |
|
|
self.upsampler = ResnetBlock( |
|
|
spatial_dims=spatial_dims, |
|
|
in_channels=out_channels, |
|
|
out_channels=out_channels, |
|
|
temb_channels=temb_channels, |
|
|
norm_num_groups=norm_num_groups, |
|
|
norm_eps=norm_eps, |
|
|
up=True, |
|
|
) |
|
|
else: |
|
|
self.upsampler = Upsample( |
|
|
spatial_dims=spatial_dims, num_channels=out_channels, use_conv=True, out_channels=out_channels |
|
|
) |
|
|
else: |
|
|
self.upsampler = None |
|
|
|
|
|
def forward( |
|
|
self, |
|
|
hidden_states: torch.Tensor, |
|
|
res_hidden_states_list: list[torch.Tensor], |
|
|
temb: torch.Tensor, |
|
|
context: torch.Tensor | None = None, |
|
|
) -> torch.Tensor: |
|
|
del context |
|
|
for resnet, attn in zip(self.resnets, self.attentions): |
|
|
|
|
|
res_hidden_states = res_hidden_states_list[-1] |
|
|
res_hidden_states_list = res_hidden_states_list[:-1] |
|
|
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) |
|
|
|
|
|
hidden_states = resnet(hidden_states, temb) |
|
|
hidden_states = attn(hidden_states) |
|
|
|
|
|
if self.upsampler is not None: |
|
|
hidden_states = self.upsampler(hidden_states, temb) |
|
|
|
|
|
return hidden_states |
|
|
|
|
|
|
|
|
class CrossAttnUpBlock(nn.Module): |
|
|
""" |
|
|
    Unet's up block containing resnet, upsamplers, and cross-attention blocks.
|
|
|
|
|
Args: |
|
|
spatial_dims: The number of spatial dimensions. |
|
|
in_channels: number of input channels. |
|
|
prev_output_channel: number of channels from residual connection. |
|
|
out_channels: number of output channels. |
|
|
temb_channels: number of timestep embedding channels. |
|
|
num_res_blocks: number of residual blocks. |
|
|
norm_num_groups: number of groups for the group normalization. |
|
|
norm_eps: epsilon for the group normalization. |
|
|
        add_upsample: if True add upsample block.
|
|
resblock_updown: if True use residual blocks for upsampling. |
|
|
num_head_channels: number of channels in each attention head. |
|
|
transformer_num_layers: number of layers of Transformer blocks to use. |
|
|
cross_attention_dim: number of context dimensions to use. |
|
|
upcast_attention: if True, upcast attention operations to full precision. |
|
|
use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. |
|
|
""" |
|
|
|
|
|
def __init__( |
|
|
self, |
|
|
spatial_dims: int, |
|
|
in_channels: int, |
|
|
prev_output_channel: int, |
|
|
out_channels: int, |
|
|
temb_channels: int, |
|
|
num_res_blocks: int = 1, |
|
|
norm_num_groups: int = 32, |
|
|
norm_eps: float = 1e-6, |
|
|
add_upsample: bool = True, |
|
|
resblock_updown: bool = False, |
|
|
num_head_channels: int = 1, |
|
|
transformer_num_layers: int = 1, |
|
|
cross_attention_dim: int | None = None, |
|
|
upcast_attention: bool = False, |
|
|
use_flash_attention: bool = False, |
|
|
) -> None: |
|
|
super().__init__() |
|
|
self.resblock_updown = resblock_updown |
|
|
|
|
|
resnets = [] |
|
|
attentions = [] |
|
|
|
|
|
for i in range(num_res_blocks): |
|
|
res_skip_channels = in_channels if (i == num_res_blocks - 1) else out_channels |
|
|
resnet_in_channels = prev_output_channel if i == 0 else out_channels |
|
|
|
|
|
resnets.append( |
|
|
ResnetBlock( |
|
|
spatial_dims=spatial_dims, |
|
|
in_channels=resnet_in_channels + res_skip_channels, |
|
|
out_channels=out_channels, |
|
|
temb_channels=temb_channels, |
|
|
norm_num_groups=norm_num_groups, |
|
|
norm_eps=norm_eps, |
|
|
) |
|
|
) |
|
|
attentions.append( |
|
|
SpatialTransformer( |
|
|
spatial_dims=spatial_dims, |
|
|
in_channels=out_channels, |
|
|
num_attention_heads=out_channels // num_head_channels, |
|
|
num_head_channels=num_head_channels, |
|
|
norm_num_groups=norm_num_groups, |
|
|
norm_eps=norm_eps, |
|
|
num_layers=transformer_num_layers, |
|
|
cross_attention_dim=cross_attention_dim, |
|
|
upcast_attention=upcast_attention, |
|
|
use_flash_attention=use_flash_attention, |
|
|
) |
|
|
) |
|
|
|
|
|
self.attentions = nn.ModuleList(attentions) |
|
|
self.resnets = nn.ModuleList(resnets) |
|
|
|
|
|
if add_upsample: |
|
|
if resblock_updown: |
|
|
self.upsampler = ResnetBlock( |
|
|
spatial_dims=spatial_dims, |
|
|
in_channels=out_channels, |
|
|
out_channels=out_channels, |
|
|
temb_channels=temb_channels, |
|
|
norm_num_groups=norm_num_groups, |
|
|
norm_eps=norm_eps, |
|
|
up=True, |
|
|
) |
|
|
else: |
|
|
self.upsampler = Upsample( |
|
|
spatial_dims=spatial_dims, num_channels=out_channels, use_conv=True, out_channels=out_channels |
|
|
) |
|
|
else: |
|
|
self.upsampler = None |
|
|
|
|
|
def forward( |
|
|
self, |
|
|
hidden_states: torch.Tensor, |
|
|
res_hidden_states_list: list[torch.Tensor], |
|
|
temb: torch.Tensor, |
|
|
context: torch.Tensor | None = None, |
|
|
) -> torch.Tensor: |
|
|
for resnet, attn in zip(self.resnets, self.attentions): |
|
|
|
|
|
res_hidden_states = res_hidden_states_list[-1] |
|
|
res_hidden_states_list = res_hidden_states_list[:-1] |
|
|
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) |
|
|
|
|
|
hidden_states = resnet(hidden_states, temb) |
|
|
hidden_states = attn(hidden_states, context=context) |
|
|
|
|
|
if self.upsampler is not None: |
|
|
hidden_states = self.upsampler(hidden_states, temb) |
|
|
|
|
|
return hidden_states |
|
|
|
|
|
|
|
|
def get_down_block( |
|
|
spatial_dims: int, |
|
|
in_channels: int, |
|
|
out_channels: int, |
|
|
temb_channels: int, |
|
|
num_res_blocks: int, |
|
|
norm_num_groups: int, |
|
|
norm_eps: float, |
|
|
add_downsample: bool, |
|
|
resblock_updown: bool, |
|
|
with_attn: bool, |
|
|
with_cross_attn: bool, |
|
|
num_head_channels: int, |
|
|
transformer_num_layers: int, |
|
|
cross_attention_dim: int | None, |
|
|
upcast_attention: bool = False, |
|
|
use_flash_attention: bool = False, |
|
|
) -> nn.Module: |
|
|
if with_attn: |
|
|
return AttnDownBlock( |
|
|
spatial_dims=spatial_dims, |
|
|
in_channels=in_channels, |
|
|
out_channels=out_channels, |
|
|
temb_channels=temb_channels, |
|
|
num_res_blocks=num_res_blocks, |
|
|
norm_num_groups=norm_num_groups, |
|
|
norm_eps=norm_eps, |
|
|
add_downsample=add_downsample, |
|
|
resblock_updown=resblock_updown, |
|
|
num_head_channels=num_head_channels, |
|
|
use_flash_attention=use_flash_attention, |
|
|
) |
|
|
elif with_cross_attn: |
|
|
return CrossAttnDownBlock( |
|
|
spatial_dims=spatial_dims, |
|
|
in_channels=in_channels, |
|
|
out_channels=out_channels, |
|
|
temb_channels=temb_channels, |
|
|
num_res_blocks=num_res_blocks, |
|
|
norm_num_groups=norm_num_groups, |
|
|
norm_eps=norm_eps, |
|
|
add_downsample=add_downsample, |
|
|
resblock_updown=resblock_updown, |
|
|
num_head_channels=num_head_channels, |
|
|
transformer_num_layers=transformer_num_layers, |
|
|
cross_attention_dim=cross_attention_dim, |
|
|
upcast_attention=upcast_attention, |
|
|
use_flash_attention=use_flash_attention, |
|
|
) |
|
|
else: |
|
|
return DownBlock( |
|
|
spatial_dims=spatial_dims, |
|
|
in_channels=in_channels, |
|
|
out_channels=out_channels, |
|
|
temb_channels=temb_channels, |
|
|
num_res_blocks=num_res_blocks, |
|
|
norm_num_groups=norm_num_groups, |
|
|
norm_eps=norm_eps, |
|
|
add_downsample=add_downsample, |
|
|
resblock_updown=resblock_updown, |
|
|
) |
|
|
|
|
|
|
|
|
def get_mid_block( |
|
|
spatial_dims: int, |
|
|
in_channels: int, |
|
|
temb_channels: int, |
|
|
norm_num_groups: int, |
|
|
norm_eps: float, |
|
|
with_conditioning: bool, |
|
|
num_head_channels: int, |
|
|
transformer_num_layers: int, |
|
|
cross_attention_dim: int | None, |
|
|
upcast_attention: bool = False, |
|
|
use_flash_attention: bool = False, |
|
|
) -> nn.Module: |
|
|
if with_conditioning: |
|
|
return CrossAttnMidBlock( |
|
|
spatial_dims=spatial_dims, |
|
|
in_channels=in_channels, |
|
|
temb_channels=temb_channels, |
|
|
norm_num_groups=norm_num_groups, |
|
|
norm_eps=norm_eps, |
|
|
num_head_channels=num_head_channels, |
|
|
transformer_num_layers=transformer_num_layers, |
|
|
cross_attention_dim=cross_attention_dim, |
|
|
upcast_attention=upcast_attention, |
|
|
use_flash_attention=use_flash_attention, |
|
|
) |
|
|
else: |
|
|
return AttnMidBlock( |
|
|
spatial_dims=spatial_dims, |
|
|
in_channels=in_channels, |
|
|
temb_channels=temb_channels, |
|
|
norm_num_groups=norm_num_groups, |
|
|
norm_eps=norm_eps, |
|
|
num_head_channels=num_head_channels, |
|
|
use_flash_attention=use_flash_attention, |
|
|
) |
|
|
|
|
|
|
|
|
def get_up_block( |
|
|
spatial_dims: int, |
|
|
in_channels: int, |
|
|
prev_output_channel: int, |
|
|
out_channels: int, |
|
|
temb_channels: int, |
|
|
num_res_blocks: int, |
|
|
norm_num_groups: int, |
|
|
norm_eps: float, |
|
|
add_upsample: bool, |
|
|
resblock_updown: bool, |
|
|
with_attn: bool, |
|
|
with_cross_attn: bool, |
|
|
num_head_channels: int, |
|
|
transformer_num_layers: int, |
|
|
cross_attention_dim: int | None, |
|
|
upcast_attention: bool = False, |
|
|
use_flash_attention: bool = False, |
|
|
) -> nn.Module: |
|
|
if with_attn: |
|
|
return AttnUpBlock( |
|
|
spatial_dims=spatial_dims, |
|
|
in_channels=in_channels, |
|
|
prev_output_channel=prev_output_channel, |
|
|
out_channels=out_channels, |
|
|
temb_channels=temb_channels, |
|
|
num_res_blocks=num_res_blocks, |
|
|
norm_num_groups=norm_num_groups, |
|
|
norm_eps=norm_eps, |
|
|
add_upsample=add_upsample, |
|
|
resblock_updown=resblock_updown, |
|
|
num_head_channels=num_head_channels, |
|
|
use_flash_attention=use_flash_attention, |
|
|
) |
|
|
elif with_cross_attn: |
|
|
return CrossAttnUpBlock( |
|
|
spatial_dims=spatial_dims, |
|
|
in_channels=in_channels, |
|
|
prev_output_channel=prev_output_channel, |
|
|
out_channels=out_channels, |
|
|
temb_channels=temb_channels, |
|
|
num_res_blocks=num_res_blocks, |
|
|
norm_num_groups=norm_num_groups, |
|
|
norm_eps=norm_eps, |
|
|
add_upsample=add_upsample, |
|
|
resblock_updown=resblock_updown, |
|
|
num_head_channels=num_head_channels, |
|
|
transformer_num_layers=transformer_num_layers, |
|
|
cross_attention_dim=cross_attention_dim, |
|
|
upcast_attention=upcast_attention, |
|
|
use_flash_attention=use_flash_attention, |
|
|
) |
|
|
else: |
|
|
return UpBlock( |
|
|
spatial_dims=spatial_dims, |
|
|
in_channels=in_channels, |
|
|
prev_output_channel=prev_output_channel, |
|
|
out_channels=out_channels, |
|
|
temb_channels=temb_channels, |
|
|
num_res_blocks=num_res_blocks, |
|
|
norm_num_groups=norm_num_groups, |
|
|
norm_eps=norm_eps, |
|
|
add_upsample=add_upsample, |
|
|
resblock_updown=resblock_updown, |
|
|
) |
|
|
|
|
|
|
|
|
class DiffusionModelUNet(nn.Module): |
|
|
""" |
|
|
Unet network with timestep embedding and attention mechanisms for conditioning based on |
|
|
Rombach et al. "High-Resolution Image Synthesis with Latent Diffusion Models" https://arxiv.org/abs/2112.10752 |
|
|
and Pinaya et al. "Brain Imaging Generation with Latent Diffusion Models" https://arxiv.org/abs/2209.07162 |
|
|
|
|
|
Args: |
|
|
spatial_dims: number of spatial dimensions. |
|
|
in_channels: number of input channels. |
|
|
out_channels: number of output channels. |
|
|
num_res_blocks: number of residual blocks (see ResnetBlock) per level. |
|
|
num_channels: tuple of block output channels. |
|
|
attention_levels: list of levels to add attention. |
|
|
norm_num_groups: number of groups for the normalization. |
|
|
norm_eps: epsilon for the normalization. |
|
|
resblock_updown: if True use residual blocks for up/downsampling. |
|
|
num_head_channels: number of channels in each attention head. |
|
|
with_conditioning: if True add spatial transformers to perform conditioning. |
|
|
transformer_num_layers: number of layers of Transformer blocks to use. |
|
|
cross_attention_dim: number of context dimensions to use. |
|
|
num_class_embeds: if specified (as an int), then this model will be class-conditional with `num_class_embeds` |
|
|
classes. |
|
|
upcast_attention: if True, upcast attention operations to full precision. |
|
|
use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. |
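
    Example (illustrative sketch; the argument values are arbitrary and only the constructor defined in this file
    is exercised)::

        net = DiffusionModelUNet(spatial_dims=2, in_channels=1, out_channels=1)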
|
|
""" |
|
|
|
|
|
def __init__( |
|
|
self, |
|
|
spatial_dims: int, |
|
|
in_channels: int, |
|
|
out_channels: int, |
|
|
num_res_blocks: Sequence[int] | int = (2, 2, 2, 2), |
|
|
num_channels: Sequence[int] = (32, 64, 64, 64), |
|
|
attention_levels: Sequence[bool] = (False, False, True, True), |
|
|
norm_num_groups: int = 32, |
|
|
norm_eps: float = 1e-6, |
|
|
resblock_updown: bool = False, |
|
|
num_head_channels: int | Sequence[int] = 8, |
|
|
with_conditioning: bool = False, |
|
|
transformer_num_layers: int = 1, |
|
|
cross_attention_dim: int | None = None, |
|
|
num_class_embeds: int | None = None, |
|
|
upcast_attention: bool = False, |
|
|
use_flash_attention: bool = False, |
|
|
) -> None: |
|
|
super().__init__() |
|
|
if with_conditioning is True and cross_attention_dim is None: |
|
|
raise ValueError( |
|
|
"DiffusionModelUNet expects dimension of the cross-attention conditioning (cross_attention_dim) " |
|
|
"when using with_conditioning." |
|
|
) |
|
|
if cross_attention_dim is not None and with_conditioning is False: |
|
|
raise ValueError( |
|
|
"DiffusionModelUNet expects with_conditioning=True when specifying the cross_attention_dim." |
|
|
) |
|
|
|
|
|
|
|
|
if any((out_channel % norm_num_groups) != 0 for out_channel in num_channels): |
|
|
raise ValueError("DiffusionModelUNet expects all num_channels being multiple of norm_num_groups") |
|
|
|
|
|
if len(num_channels) != len(attention_levels): |
|
|
raise ValueError("DiffusionModelUNet expects num_channels being same size of attention_levels") |
|
|
|
|
|
if isinstance(num_head_channels, int): |
|
|
num_head_channels = ensure_tuple_rep(num_head_channels, len(attention_levels)) |
|
|
|
|
|
if len(num_head_channels) != len(attention_levels): |
|
|
raise ValueError( |
|
|
"num_head_channels should have the same length as attention_levels. For the i levels without attention," |
|
|
" i.e. `attention_level[i]=False`, the num_head_channels[i] will be ignored." |
|
|
) |
|
|
|
|
|
if isinstance(num_res_blocks, int): |
|
|
num_res_blocks = ensure_tuple_rep(num_res_blocks, len(num_channels)) |
|
|
|
|
|
if len(num_res_blocks) != len(num_channels): |
|
|
raise ValueError( |
|
|
"`num_res_blocks` should be a single integer or a tuple of integers with the same length as " |
|
|
"`num_channels`." |
|
|
) |
|
|
|
|
|
if use_flash_attention and not has_xformers: |
|
|
raise ValueError("use_flash_attention is True but xformers is not installed.") |
|
|
|
|
|
if use_flash_attention is True and not torch.cuda.is_available(): |
|
|
raise ValueError( |
|
|
"torch.cuda.is_available() should be True but is False. Flash attention is only available for GPU." |
|
|
) |
|
|
|
|
|
self.in_channels = in_channels |
|
|
self.block_out_channels = num_channels |
|
|
self.out_channels = out_channels |
|
|
self.num_res_blocks = num_res_blocks |
|
|
self.attention_levels = attention_levels |
|
|
self.num_head_channels = num_head_channels |
|
|
self.with_conditioning = with_conditioning |
|
|
|
|
|
|
|
|
self.conv_in = Convolution( |
|
|
spatial_dims=spatial_dims, |
|
|
in_channels=in_channels, |
|
|
out_channels=num_channels[0], |
|
|
strides=1, |
|
|
kernel_size=3, |
|
|
padding=1, |
|
|
conv_only=True, |
|
|
) |
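        # timestep embedding: an MLP maps the num_channels[0]-dim sinusoidal features to a 4x wider embedding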
|
|
|
|
|
|
|
|
time_embed_dim = num_channels[0] * 4 |
|
|
self.time_embed = nn.Sequential( |
|
|
nn.Linear(num_channels[0], time_embed_dim), nn.SiLU(), nn.Linear(time_embed_dim, time_embed_dim) |
|
|
) |
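        # optional class-conditioning embedding, added to the timestep embedding in forward()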
|
|
|
|
|
|
|
|
self.num_class_embeds = num_class_embeds |
|
|
if num_class_embeds is not None: |
|
|
self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim) |
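        # encoder (down) path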
|
|
|
|
|
|
|
|
self.down_blocks = nn.ModuleList([]) |
|
|
output_channel = num_channels[0] |
|
|
for i in range(len(num_channels)): |
|
|
input_channel = output_channel |
|
|
output_channel = num_channels[i] |
|
|
is_final_block = i == len(num_channels) - 1 |
|
|
|
|
|
down_block = get_down_block( |
|
|
spatial_dims=spatial_dims, |
|
|
in_channels=input_channel, |
|
|
out_channels=output_channel, |
|
|
temb_channels=time_embed_dim, |
|
|
num_res_blocks=num_res_blocks[i], |
|
|
norm_num_groups=norm_num_groups, |
|
|
norm_eps=norm_eps, |
|
|
add_downsample=not is_final_block, |
|
|
resblock_updown=resblock_updown, |
|
|
with_attn=(attention_levels[i] and not with_conditioning), |
|
|
with_cross_attn=(attention_levels[i] and with_conditioning), |
|
|
num_head_channels=num_head_channels[i], |
|
|
transformer_num_layers=transformer_num_layers, |
|
|
cross_attention_dim=cross_attention_dim, |
|
|
upcast_attention=upcast_attention, |
|
|
use_flash_attention=use_flash_attention, |
|
|
) |
|
|
|
|
|
self.down_blocks.append(down_block) |
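        # bottleneck (mid) block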
|
|
|
|
|
|
|
|
self.middle_block = get_mid_block( |
|
|
spatial_dims=spatial_dims, |
|
|
in_channels=num_channels[-1], |
|
|
temb_channels=time_embed_dim, |
|
|
norm_num_groups=norm_num_groups, |
|
|
norm_eps=norm_eps, |
|
|
with_conditioning=with_conditioning, |
|
|
num_head_channels=num_head_channels[-1], |
|
|
transformer_num_layers=transformer_num_layers, |
|
|
cross_attention_dim=cross_attention_dim, |
|
|
upcast_attention=upcast_attention, |
|
|
use_flash_attention=use_flash_attention, |
|
|
) |
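        # decoder (up) path: the per-level configuration is reversed so that the decoder mirrors the encoder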
|
|
|
|
|
|
|
|
self.up_blocks = nn.ModuleList([]) |
|
|
reversed_block_out_channels = list(reversed(num_channels)) |
|
|
reversed_num_res_blocks = list(reversed(num_res_blocks)) |
|
|
reversed_attention_levels = list(reversed(attention_levels)) |
|
|
reversed_num_head_channels = list(reversed(num_head_channels)) |
|
|
output_channel = reversed_block_out_channels[0] |
|
|
for i in range(len(reversed_block_out_channels)): |
|
|
prev_output_channel = output_channel |
|
|
output_channel = reversed_block_out_channels[i] |
|
|
input_channel = reversed_block_out_channels[min(i + 1, len(num_channels) - 1)] |
|
|
|
|
|
is_final_block = i == len(num_channels) - 1 |
|
|
|
|
|
up_block = get_up_block( |
|
|
spatial_dims=spatial_dims, |
|
|
in_channels=input_channel, |
|
|
prev_output_channel=prev_output_channel, |
|
|
out_channels=output_channel, |
|
|
temb_channels=time_embed_dim, |
|
|
num_res_blocks=reversed_num_res_blocks[i] + 1, |
|
|
norm_num_groups=norm_num_groups, |
|
|
norm_eps=norm_eps, |
|
|
add_upsample=not is_final_block, |
|
|
resblock_updown=resblock_updown, |
|
|
with_attn=(reversed_attention_levels[i] and not with_conditioning), |
|
|
with_cross_attn=(reversed_attention_levels[i] and with_conditioning), |
|
|
num_head_channels=reversed_num_head_channels[i], |
|
|
transformer_num_layers=transformer_num_layers, |
|
|
cross_attention_dim=cross_attention_dim, |
|
|
upcast_attention=upcast_attention, |
|
|
use_flash_attention=use_flash_attention, |
|
|
) |
|
|
|
|
|
self.up_blocks.append(up_block) |
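        # output head: group norm, SiLU and a zero-initialised 3x3 convolution to the requested out_channels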
|
|
|
|
|
|
|
|
self.out = nn.Sequential( |
|
|
nn.GroupNorm(num_groups=norm_num_groups, num_channels=num_channels[0], eps=norm_eps, affine=True), |
|
|
nn.SiLU(), |
|
|
zero_module( |
|
|
Convolution( |
|
|
spatial_dims=spatial_dims, |
|
|
in_channels=num_channels[0], |
|
|
out_channels=out_channels, |
|
|
strides=1, |
|
|
kernel_size=3, |
|
|
padding=1, |
|
|
conv_only=True, |
|
|
|
|
) |
|
|
), |
|
|
) |
|
|
|
|
|
def forward( |
|
|
self, |
|
|
x: torch.Tensor, |
|
|
timesteps: torch.Tensor, |
|
|
context: torch.Tensor | None = None, |
|
|
class_labels: torch.Tensor | None = None, |
|
|
down_block_additional_residuals: tuple[torch.Tensor] | None = None, |
|
|
mid_block_additional_residual: torch.Tensor | None = None, |
|
|
) -> torch.Tensor: |
|
|
""" |
|
|
Args: |
|
|
x: input tensor (N, C, SpatialDims). |
|
|
timesteps: timestep tensor (N,). |
|
|
context: context tensor (N, 1, ContextDim). |
|
|
            class_labels: class labels tensor (N,).
|
|
down_block_additional_residuals: additional residual tensors for down blocks (N, C, FeatureMapsDims). |
|
|
mid_block_additional_residual: additional residual tensor for mid block (N, C, FeatureMapsDims). |
|
|
""" |
|
|
|
|
|
t_emb = get_timestep_embedding(timesteps, self.block_out_channels[0]) |
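        # the sinusoidal timestep embedding is computed in float32; cast it to the input dtype before the
        # embedding MLP so the model also runs under reduced precision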
|
|
|
|
|
|
|
|
|
|
|
|
|
|
t_emb = t_emb.to(dtype=x.dtype) |
|
|
emb = self.time_embed(t_emb) |
|
|
|
|
|
|
|
|
if self.num_class_embeds is not None: |
|
|
if class_labels is None: |
|
|
raise ValueError("class_labels should be provided when num_class_embeds > 0") |
|
|
class_emb = self.class_embedding(class_labels) |
|
|
class_emb = class_emb.to(dtype=x.dtype) |
|
|
emb = emb + class_emb |
|
|
|
|
|
|
|
|
h = self.conv_in(x) |
|
|
|
|
|
|
|
|
if context is not None and self.with_conditioning is False: |
|
|
raise ValueError("model should have with_conditioning = True if context is provided") |
|
|
down_block_res_samples: list[torch.Tensor] = [h] |
|
|
for downsample_block in self.down_blocks: |
|
|
h, res_samples = downsample_block(hidden_states=h, temb=emb, context=context) |
|
|
for residual in res_samples: |
|
|
down_block_res_samples.append(residual) |
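        # optionally add externally supplied residuals (e.g. from a ControlNet-like network) to the skip connections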
|
|
|
|
|
|
|
|
if down_block_additional_residuals is not None: |
|
|
new_down_block_res_samples = () |
|
|
for down_block_res_sample, down_block_additional_residual in zip( |
|
|
down_block_res_samples, down_block_additional_residuals |
|
|
): |
|
|
down_block_res_sample = down_block_res_sample + down_block_additional_residual |
|
|
new_down_block_res_samples += (down_block_res_sample,) |
|
|
|
|
|
down_block_res_samples = new_down_block_res_samples |
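        # bottleneck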
|
|
|
|
|
|
|
|
h = self.middle_block(hidden_states=h, temb=emb, context=context) |
|
|
|
|
|
|
|
|
if mid_block_additional_residual is not None: |
|
|
h = h + mid_block_additional_residual |
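        # decoder path: each up block consumes its matching skip connections from the end of the list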
|
|
|
|
|
|
|
|
for upsample_block in self.up_blocks: |
|
|
res_samples = down_block_res_samples[-len(upsample_block.resnets) :] |
|
|
down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] |
|
|
h = upsample_block(hidden_states=h, res_hidden_states_list=res_samples, temb=emb, context=context) |
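        # project back to the requested number of output channels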
|
|
|
|
|
|
|
|
h = self.out(h) |
|
|
|
|
|
return h |
|
|
|
|
|
|
|
|
class DiffusionModelEncoder(nn.Module): |
|
|
""" |
|
|
    Classification network based on the encoder (downsampling path) of the diffusion model, followed by fully connected layers. This network follows
|
|
Wolleb et al. "Diffusion Models for Medical Anomaly Detection" (https://arxiv.org/abs/2203.04306). |
|
|
|
|
|
Args: |
|
|
spatial_dims: number of spatial dimensions. |
|
|
in_channels: number of input channels. |
|
|
out_channels: number of output channels. |
|
|
num_res_blocks: number of residual blocks (see ResnetBlock) per level. |
|
|
num_channels: tuple of block output channels. |
|
|
attention_levels: list of levels to add attention. |
|
|
norm_num_groups: number of groups for the normalization. |
|
|
norm_eps: epsilon for the normalization. |
|
|
resblock_updown: if True use residual blocks for downsampling. |
|
|
num_head_channels: number of channels in each attention head. |
|
|
with_conditioning: if True add spatial transformers to perform conditioning. |
|
|
transformer_num_layers: number of layers of Transformer blocks to use. |
|
|
cross_attention_dim: number of context dimensions to use. |
|
|
num_class_embeds: if specified (as an int), then this model will be class-conditional with `num_class_embeds` classes. |
|
|
upcast_attention: if True, upcast attention operations to full precision. |
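
    Example:

        A minimal 2D configuration sketch; all values are illustrative. The classification head expects the
        flattened encoder output to have 4096 features, so the input spatial size must match the chosen
        ``num_channels`` (e.g. 128x128 inputs with the defaults)::

            encoder = DiffusionModelEncoder(
                spatial_dims=2,
                in_channels=1,
                out_channels=2,
                num_res_blocks=1,
                num_channels=(32, 64, 64, 64),
                attention_levels=(False, False, False, True),
                num_head_channels=8,
            )
            logits = encoder(
                x=torch.rand(1, 1, 128, 128),
                timesteps=torch.randint(0, 1000, (1,)).long(),
            )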
|
|
""" |
|
|
|
|
|
def __init__( |
|
|
self, |
|
|
spatial_dims: int, |
|
|
in_channels: int, |
|
|
out_channels: int, |
|
|
num_res_blocks: Sequence[int] | int = (2, 2, 2, 2), |
|
|
num_channels: Sequence[int] = (32, 64, 64, 64), |
|
|
attention_levels: Sequence[bool] = (False, False, True, True), |
|
|
norm_num_groups: int = 32, |
|
|
norm_eps: float = 1e-6, |
|
|
resblock_updown: bool = False, |
|
|
num_head_channels: int | Sequence[int] = 8, |
|
|
with_conditioning: bool = False, |
|
|
transformer_num_layers: int = 1, |
|
|
cross_attention_dim: int | None = None, |
|
|
num_class_embeds: int | None = None, |
|
|
upcast_attention: bool = False, |
|
|
) -> None: |
|
|
super().__init__() |
|
|
if with_conditioning is True and cross_attention_dim is None: |
|
|
raise ValueError( |
|
|
"DiffusionModelEncoder expects dimension of the cross-attention conditioning (cross_attention_dim) " |
|
|
"when using with_conditioning." |
|
|
) |
|
|
if cross_attention_dim is not None and with_conditioning is False: |
|
|
raise ValueError( |
|
|
"DiffusionModelEncoder expects with_conditioning=True when specifying the cross_attention_dim." |
|
|
) |
|
|
|
|
|
|
|
|
if any((out_channel % norm_num_groups) != 0 for out_channel in num_channels): |
|
|
raise ValueError("DiffusionModelEncoder expects all num_channels being multiple of norm_num_groups") |
|
|
if len(num_channels) != len(attention_levels): |
|
|
raise ValueError("DiffusionModelEncoder expects num_channels being same size of attention_levels") |
|
|
|
|
|
if isinstance(num_head_channels, int): |
|
|
num_head_channels = ensure_tuple_rep(num_head_channels, len(attention_levels)) |
|
|
|
|
|
if len(num_head_channels) != len(attention_levels): |
|
|
raise ValueError( |
|
|
"num_head_channels should have the same length as attention_levels. For the i levels without attention," |
|
|
" i.e. `attention_level[i]=False`, the num_head_channels[i] will be ignored." |
|
|
) |
|
|
|
|
|
self.in_channels = in_channels |
|
|
self.block_out_channels = num_channels |
|
|
self.out_channels = out_channels |
|
|
self.num_res_blocks = num_res_blocks |
|
|
self.attention_levels = attention_levels |
|
|
self.num_head_channels = num_head_channels |
|
|
self.with_conditioning = with_conditioning |
|
|
|
|
|
|
|
|
self.conv_in = Convolution( |
|
|
spatial_dims=spatial_dims, |
|
|
in_channels=in_channels, |
|
|
out_channels=num_channels[0], |
|
|
strides=1, |
|
|
kernel_size=3, |
|
|
padding=1, |
|
|
conv_only=True, |
|
|
) |
|
|
|
|
|
|
|
|
time_embed_dim = num_channels[0] * 4 |
|
|
self.time_embed = nn.Sequential( |
|
|
nn.Linear(num_channels[0], time_embed_dim), nn.SiLU(), nn.Linear(time_embed_dim, time_embed_dim) |
|
|
) |
|
|
|
|
|
|
|
|
self.num_class_embeds = num_class_embeds |
|
|
if num_class_embeds is not None: |
|
|
self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim) |
|
|
|
|
|
|
|
|
self.down_blocks = nn.ModuleList([]) |
|
|
output_channel = num_channels[0] |
|
|
for i in range(len(num_channels)): |
|
|
input_channel = output_channel |
|
|
output_channel = num_channels[i] |
|
|
            is_final_block = i == len(num_channels)  # never True here: every level of the encoder downsamples
|
|
|
|
|
down_block = get_down_block( |
|
|
spatial_dims=spatial_dims, |
|
|
in_channels=input_channel, |
|
|
out_channels=output_channel, |
|
|
temb_channels=time_embed_dim, |
|
|
num_res_blocks=num_res_blocks[i], |
|
|
norm_num_groups=norm_num_groups, |
|
|
norm_eps=norm_eps, |
|
|
add_downsample=not is_final_block, |
|
|
resblock_updown=resblock_updown, |
|
|
with_attn=(attention_levels[i] and not with_conditioning), |
|
|
with_cross_attn=(attention_levels[i] and with_conditioning), |
|
|
num_head_channels=num_head_channels[i], |
|
|
transformer_num_layers=transformer_num_layers, |
|
|
cross_attention_dim=cross_attention_dim, |
|
|
upcast_attention=upcast_attention, |
|
|
) |
|
|
|
|
|
self.down_blocks.append(down_block) |
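        # classification head; NOTE: the first linear layer assumes the flattened encoder output has exactly
        # 4096 features, which ties the head to a specific input spatial size / channel configuration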
|
|
|
|
|
self.out = nn.Sequential(nn.Linear(4096, 512), nn.ReLU(), nn.Dropout(0.1), nn.Linear(512, self.out_channels)) |
|
|
|
|
|
def forward( |
|
|
self, |
|
|
x: torch.Tensor, |
|
|
timesteps: torch.Tensor, |
|
|
context: torch.Tensor | None = None, |
|
|
class_labels: torch.Tensor | None = None, |
|
|
) -> torch.Tensor: |
|
|
""" |
|
|
Args: |
|
|
x: input tensor (N, C, SpatialDims). |
|
|
timesteps: timestep tensor (N,). |
|
|
context: context tensor (N, 1, ContextDim). |
|
|
            class_labels: class labels tensor (N,).
|
|
""" |
|
|
|
|
|
t_emb = get_timestep_embedding(timesteps, self.block_out_channels[0]) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
t_emb = t_emb.to(dtype=x.dtype) |
|
|
emb = self.time_embed(t_emb) |
|
|
|
|
|
|
|
|
if self.num_class_embeds is not None: |
|
|
if class_labels is None: |
|
|
raise ValueError("class_labels should be provided when num_class_embeds > 0") |
|
|
class_emb = self.class_embedding(class_labels) |
|
|
class_emb = class_emb.to(dtype=x.dtype) |
|
|
emb = emb + class_emb |
|
|
|
|
|
|
|
|
h = self.conv_in(x) |
|
|
|
|
|
|
|
|
if context is not None and self.with_conditioning is False: |
|
|
raise ValueError("model should have with_conditioning = True if context is provided") |
|
|
for downsample_block in self.down_blocks: |
|
|
h, _ = downsample_block(hidden_states=h, temb=emb, context=context) |
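        # flatten the final feature map and apply the fully connected classification head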
|
|
|
|
|
h = h.reshape(h.shape[0], -1) |
|
|
output = self.out(h) |
|
|
|
|
|
return output |
|
|
|