# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # http://www.apache.org/licenses/LICENSE-2.0 # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # # ========================================================================= # Adapted from https://github.com/huggingface/diffusers # which has the following license: # https://github.com/huggingface/diffusers/blob/main/LICENSE # # Copyright 2022 UC Berkeley Team and The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ========================================================================= from __future__ import annotations import importlib.util import math from collections.abc import Sequence import torch import torch.nn.functional as F from monai.networks.blocks import Convolution, MLPBlock from monai.networks.layers.factories import Pool from monai.utils import ensure_tuple_rep from torch import nn # To install xformers, use pip install xformers==0.0.16rc401 if importlib.util.find_spec("xformers") is not None: import xformers import xformers.ops has_xformers = True else: xformers = None has_xformers = False # TODO: Use MONAI's optional_import # from monai.utils import optional_import # xformers, has_xformers = optional_import("xformers.ops", name="xformers") __all__ = ["DiffusionModelUNet"] def zero_module(module: nn.Module) -> nn.Module: """ Zero out the parameters of a module and return it. """ for p in module.parameters(): p.detach().zero_() return module class CrossAttention(nn.Module): """ A cross attention layer. Args: query_dim: number of channels in the query. cross_attention_dim: number of channels in the context. num_attention_heads: number of heads to use for multi-head attention. num_head_channels: number of channels in each head. dropout: dropout probability to use. upcast_attention: if True, upcast attention operations to full precision. use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. 
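    Example (illustrative sketch; the tensor sizes below are arbitrary and only demonstrate the expected layouts):

        attn = CrossAttention(query_dim=128, cross_attention_dim=64, num_attention_heads=4, num_head_channels=32)
        x = torch.randn(2, 256, 128)   # (batch, sequence, query_dim), e.g. a flattened 16x16 feature map
        ctx = torch.randn(2, 77, 64)   # (batch, context tokens, cross_attention_dim)
        out = attn(x, context=ctx)     # cross-attended output, same shape as x: (2, 256, 128)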
""" def __init__( self, query_dim: int, cross_attention_dim: int | None = None, num_attention_heads: int = 8, num_head_channels: int = 64, dropout: float = 0.0, upcast_attention: bool = False, use_flash_attention: bool = False, ) -> None: super().__init__() self.use_flash_attention = use_flash_attention inner_dim = num_head_channels * num_attention_heads cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim self.scale = 1 / math.sqrt(num_head_channels) self.num_heads = num_attention_heads self.upcast_attention = upcast_attention self.to_q = nn.Linear(query_dim, inner_dim, bias=False) self.to_k = nn.Linear(cross_attention_dim, inner_dim, bias=False) self.to_v = nn.Linear(cross_attention_dim, inner_dim, bias=False) self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)) def reshape_heads_to_batch_dim(self, x: torch.Tensor) -> torch.Tensor: """ Divide hidden state dimension to the multiple attention heads and reshape their input as instances in the batch. """ batch_size, seq_len, dim = x.shape x = x.reshape(batch_size, seq_len, self.num_heads, dim // self.num_heads) x = x.permute(0, 2, 1, 3).reshape(batch_size * self.num_heads, seq_len, dim // self.num_heads) return x def reshape_batch_dim_to_heads(self, x: torch.Tensor) -> torch.Tensor: """Combine the output of the attention heads back into the hidden state dimension.""" batch_size, seq_len, dim = x.shape x = x.reshape(batch_size // self.num_heads, self.num_heads, seq_len, dim) x = x.permute(0, 2, 1, 3).reshape(batch_size // self.num_heads, seq_len, dim * self.num_heads) return x def _memory_efficient_attention_xformers( self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor ) -> torch.Tensor: query = query.contiguous() key = key.contiguous() value = value.contiguous() x = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=None) return x def _attention(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> torch.Tensor: dtype = query.dtype if self.upcast_attention: query = query.float() key = key.float() attention_scores = torch.baddbmm( torch.empty(query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device), query, key.transpose(-1, -2), beta=0, alpha=self.scale, ) attention_probs = attention_scores.softmax(dim=-1) attention_probs = attention_probs.to(dtype=dtype) x = torch.bmm(attention_probs, value) return x def forward(self, x: torch.Tensor, context: torch.Tensor | None = None) -> torch.Tensor: query = self.to_q(x) context = context if context is not None else x key = self.to_k(context) value = self.to_v(context) # Multi-Head Attention query = self.reshape_heads_to_batch_dim(query) key = self.reshape_heads_to_batch_dim(key) value = self.reshape_heads_to_batch_dim(value) if self.use_flash_attention: x = self._memory_efficient_attention_xformers(query, key, value) else: x = self._attention(query, key, value) x = self.reshape_batch_dim_to_heads(x) x = x.to(query.dtype) return self.to_out(x) class BasicTransformerBlock(nn.Module): """ A basic Transformer block. Args: num_channels: number of channels in the input and output. num_attention_heads: number of heads to use for multi-head attention. num_head_channels: number of channels in each attention head. dropout: dropout probability to use. cross_attention_dim: size of the context vector for cross attention. upcast_attention: if True, upcast attention operations to full precision. 
use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. """ def __init__( self, num_channels: int, num_attention_heads: int, num_head_channels: int, dropout: float = 0.0, cross_attention_dim: int | None = None, upcast_attention: bool = False, use_flash_attention: bool = False, ) -> None: super().__init__() self.attn1 = CrossAttention( query_dim=num_channels, num_attention_heads=num_attention_heads, num_head_channels=num_head_channels, dropout=dropout, upcast_attention=upcast_attention, use_flash_attention=use_flash_attention, ) # is a self-attention self.ff = MLPBlock(hidden_size=num_channels, mlp_dim=num_channels * 4, act="GEGLU", dropout_rate=dropout) self.attn2 = CrossAttention( query_dim=num_channels, cross_attention_dim=cross_attention_dim, num_attention_heads=num_attention_heads, num_head_channels=num_head_channels, dropout=dropout, upcast_attention=upcast_attention, use_flash_attention=use_flash_attention, ) # is a self-attention if context is None self.norm1 = nn.LayerNorm(num_channels) self.norm2 = nn.LayerNorm(num_channels) self.norm3 = nn.LayerNorm(num_channels) def forward(self, x: torch.Tensor, context: torch.Tensor | None = None) -> torch.Tensor: # 1. Self-Attention x = self.attn1(self.norm1(x)) + x # 2. Cross-Attention x = self.attn2(self.norm2(x), context=context) + x # 3. Feed-forward x = self.ff(self.norm3(x)) + x return x class SpatialTransformer(nn.Module): """ Transformer block for image-like data. First, project the input (aka embedding) and reshape to b, t, d. Then apply standard transformer action. Finally, reshape to image. Args: spatial_dims: number of spatial dimensions. in_channels: number of channels in the input and output. num_attention_heads: number of heads to use for multi-head attention. num_head_channels: number of channels in each attention head. num_layers: number of layers of Transformer blocks to use. dropout: dropout probability to use. norm_num_groups: number of groups for the normalization. norm_eps: epsilon for the normalization. cross_attention_dim: number of context dimensions to use. upcast_attention: if True, upcast attention operations to full precision. use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. 
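    Example (illustrative sketch; sizes are arbitrary, but in_channels must be divisible by norm_num_groups, 32 by
    default):

        st = SpatialTransformer(
            spatial_dims=2, in_channels=64, num_attention_heads=4, num_head_channels=16, cross_attention_dim=32
        )
        x = torch.randn(1, 64, 16, 16)   # (N, C, H, W) feature map
        ctx = torch.randn(1, 77, 32)     # (N, tokens, cross_attention_dim) conditioning
        out = st(x, context=ctx)         # same shape as x: (1, 64, 16, 16)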
""" def __init__( self, spatial_dims: int, in_channels: int, num_attention_heads: int, num_head_channels: int, num_layers: int = 1, dropout: float = 0.0, norm_num_groups: int = 32, norm_eps: float = 1e-6, cross_attention_dim: int | None = None, upcast_attention: bool = False, use_flash_attention: bool = False, ) -> None: super().__init__() self.spatial_dims = spatial_dims self.in_channels = in_channels inner_dim = num_attention_heads * num_head_channels self.norm = nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=norm_eps, affine=True) self.proj_in = Convolution( spatial_dims=spatial_dims, in_channels=in_channels, out_channels=inner_dim, strides=1, kernel_size=1, padding=0, conv_only=True, ) self.transformer_blocks = nn.ModuleList( [ BasicTransformerBlock( num_channels=inner_dim, num_attention_heads=num_attention_heads, num_head_channels=num_head_channels, dropout=dropout, cross_attention_dim=cross_attention_dim, upcast_attention=upcast_attention, use_flash_attention=use_flash_attention, ) for _ in range(num_layers) ] ) self.proj_out = zero_module( Convolution( spatial_dims=spatial_dims, in_channels=inner_dim, out_channels=in_channels, strides=1, kernel_size=1, padding=0, conv_only=True, ) ) def forward(self, x: torch.Tensor, context: torch.Tensor | None = None) -> torch.Tensor: # note: if no context is given, cross-attention defaults to self-attention batch = channel = height = width = depth = -1 if self.spatial_dims == 2: batch, channel, height, width = x.shape if self.spatial_dims == 3: batch, channel, height, width, depth = x.shape residual = x x = self.norm(x) x = self.proj_in(x) inner_dim = x.shape[1] if self.spatial_dims == 2: x = x.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim) if self.spatial_dims == 3: x = x.permute(0, 2, 3, 4, 1).reshape(batch, height * width * depth, inner_dim) for block in self.transformer_blocks: x = block(x, context=context) if self.spatial_dims == 2: x = x.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous() if self.spatial_dims == 3: x = x.reshape(batch, height, width, depth, inner_dim).permute(0, 4, 1, 2, 3).contiguous() x = self.proj_out(x) return x + residual class AttentionBlock(nn.Module): """ An attention block that allows spatial positions to attend to each other. Uses three q, k, v linear layers to compute attention. Args: spatial_dims: number of spatial dimensions. num_channels: number of input channels. num_head_channels: number of channels in each attention head. norm_num_groups: number of groups involved for the group normalisation layer. Ensure that your number of channels is divisible by this number. norm_eps: epsilon value to use for the normalisation. use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. 
""" def __init__( self, spatial_dims: int, num_channels: int, num_head_channels: int | None = None, norm_num_groups: int = 32, norm_eps: float = 1e-6, use_flash_attention: bool = False, ) -> None: super().__init__() self.use_flash_attention = use_flash_attention self.spatial_dims = spatial_dims self.num_channels = num_channels self.num_heads = num_channels // num_head_channels if num_head_channels is not None else 1 self.scale = 1 / math.sqrt(num_channels / self.num_heads) self.norm = nn.GroupNorm(num_groups=norm_num_groups, num_channels=num_channels, eps=norm_eps, affine=True) self.to_q = nn.Linear(num_channels, num_channels) self.to_k = nn.Linear(num_channels, num_channels) self.to_v = nn.Linear(num_channels, num_channels) self.proj_attn = nn.Linear(num_channels, num_channels) def reshape_heads_to_batch_dim(self, x: torch.Tensor) -> torch.Tensor: batch_size, seq_len, dim = x.shape x = x.reshape(batch_size, seq_len, self.num_heads, dim // self.num_heads) x = x.permute(0, 2, 1, 3).reshape(batch_size * self.num_heads, seq_len, dim // self.num_heads) return x def reshape_batch_dim_to_heads(self, x: torch.Tensor) -> torch.Tensor: batch_size, seq_len, dim = x.shape x = x.reshape(batch_size // self.num_heads, self.num_heads, seq_len, dim) x = x.permute(0, 2, 1, 3).reshape(batch_size // self.num_heads, seq_len, dim * self.num_heads) return x def _memory_efficient_attention_xformers( self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor ) -> torch.Tensor: query = query.contiguous() key = key.contiguous() value = value.contiguous() x = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=None) return x def _attention(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> torch.Tensor: attention_scores = torch.baddbmm( torch.empty(query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device), query, key.transpose(-1, -2), beta=0, alpha=self.scale, ) attention_probs = attention_scores.softmax(dim=-1) x = torch.bmm(attention_probs, value) return x def forward(self, x: torch.Tensor) -> torch.Tensor: residual = x batch = channel = height = width = depth = -1 if self.spatial_dims == 2: batch, channel, height, width = x.shape if self.spatial_dims == 3: batch, channel, height, width, depth = x.shape # norm x = self.norm(x) if self.spatial_dims == 2: x = x.view(batch, channel, height * width).transpose(1, 2) if self.spatial_dims == 3: x = x.view(batch, channel, height * width * depth).transpose(1, 2) # proj to q, k, v query = self.to_q(x) key = self.to_k(x) value = self.to_v(x) # Multi-Head Attention query = self.reshape_heads_to_batch_dim(query) key = self.reshape_heads_to_batch_dim(key) value = self.reshape_heads_to_batch_dim(value) if self.use_flash_attention: x = self._memory_efficient_attention_xformers(query, key, value) else: x = self._attention(query, key, value) x = self.reshape_batch_dim_to_heads(x) x = x.to(query.dtype) if self.spatial_dims == 2: x = x.transpose(-1, -2).reshape(batch, channel, height, width) if self.spatial_dims == 3: x = x.transpose(-1, -2).reshape(batch, channel, height, width, depth) return x + residual def get_timestep_embedding(timesteps: torch.Tensor, embedding_dim: int, max_period: int = 10000) -> torch.Tensor: """ Create sinusoidal timestep embeddings following the implementation in Ho et al. "Denoising Diffusion Probabilistic Models" https://arxiv.org/abs/2006.11239. Args: timesteps: a 1-D Tensor of N indices, one per batch element. embedding_dim: the dimension of the output. 
max_period: controls the minimum frequency of the embeddings. """ if timesteps.ndim != 1: raise ValueError("Timesteps should be a 1d-array") half_dim = embedding_dim // 2 exponent = -math.log(max_period) * torch.arange(start=0, end=half_dim, dtype=torch.float32, device=timesteps.device) freqs = torch.exp(exponent / half_dim) args = timesteps[:, None].float() * freqs[None, :] embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) # zero pad if embedding_dim % 2 == 1: embedding = torch.nn.functional.pad(embedding, (0, 1, 0, 0)) return embedding class Downsample(nn.Module): """ Downsampling layer. Args: spatial_dims: number of spatial dimensions. num_channels: number of input channels. use_conv: if True uses Convolution instead of Pool average to perform downsampling. In case that use_conv is False, the number of output channels must be the same as the number of input channels. out_channels: number of output channels. padding: controls the amount of implicit zero-paddings on both sides for padding number of points for each dimension. """ def __init__( self, spatial_dims: int, num_channels: int, use_conv: bool, out_channels: int | None = None, padding: int = 1 ) -> None: super().__init__() self.num_channels = num_channels self.out_channels = out_channels or num_channels self.use_conv = use_conv if use_conv: self.op = Convolution( spatial_dims=spatial_dims, in_channels=self.num_channels, out_channels=self.out_channels, strides=2, kernel_size=3, padding=padding, conv_only=True, ) else: if self.num_channels != self.out_channels: raise ValueError("num_channels and out_channels must be equal when use_conv=False") self.op = Pool[Pool.AVG, spatial_dims](kernel_size=2, stride=2) def forward(self, x: torch.Tensor, emb: torch.Tensor | None = None) -> torch.Tensor: del emb if x.shape[1] != self.num_channels: raise ValueError( f"Input number of channels ({x.shape[1]}) is not equal to expected number of channels " f"({self.num_channels})" ) return self.op(x) class Upsample(nn.Module): """ Upsampling layer with an optional convolution. Args: spatial_dims: number of spatial dimensions. num_channels: number of input channels. use_conv: if True uses Convolution instead of Pool average to perform downsampling. out_channels: number of output channels. padding: controls the amount of implicit zero-paddings on both sides for padding number of points for each dimension. """ def __init__( self, spatial_dims: int, num_channels: int, use_conv: bool, out_channels: int | None = None, padding: int = 1 ) -> None: super().__init__() self.num_channels = num_channels self.out_channels = out_channels or num_channels self.use_conv = use_conv if use_conv: self.conv = Convolution( spatial_dims=spatial_dims, in_channels=self.num_channels, out_channels=self.out_channels, strides=1, kernel_size=3, padding=padding, conv_only=True, ) else: self.conv = None def forward(self, x: torch.Tensor, emb: torch.Tensor | None = None) -> torch.Tensor: del emb if x.shape[1] != self.num_channels: raise ValueError("Input channels should be equal to num_channels") # Cast to float32 to as 'upsample_nearest2d_out_frame' op does not support bfloat16 # https://github.com/pytorch/pytorch/issues/86679 dtype = x.dtype if dtype == torch.bfloat16: x = x.to(torch.float32) x = F.interpolate(x, scale_factor=2.0, mode="nearest") # If the input is bfloat16, we cast back to bfloat16 if dtype == torch.bfloat16: x = x.to(dtype) if self.use_conv: x = self.conv(x) return x class ResnetBlock(nn.Module): """ Residual block with timestep conditioning. 
    Args:
        spatial_dims: The number of spatial dimensions.
        in_channels: number of input channels.
        temb_channels: number of timestep embedding channels.
        out_channels: number of output channels.
        up: if True, performs upsampling.
        down: if True, performs downsampling.
        norm_num_groups: number of groups for the group normalization.
        norm_eps: epsilon for the group normalization.
    """

    def __init__(
        self,
        spatial_dims: int,
        in_channels: int,
        temb_channels: int,
        out_channels: int | None = None,
        up: bool = False,
        down: bool = False,
        norm_num_groups: int = 32,
        norm_eps: float = 1e-6,
    ) -> None:
        super().__init__()
        self.spatial_dims = spatial_dims
        self.channels = in_channels
        self.emb_channels = temb_channels
        self.out_channels = out_channels or in_channels
        self.up = up
        self.down = down

        self.norm1 = nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=norm_eps, affine=True)
        self.nonlinearity = nn.SiLU()
        # 3x3 convolution with stride 1 and padding 1 keeps the spatial size unchanged,
        # which is required for the residual addition at the end of forward().
        self.conv1 = Convolution(
            spatial_dims=spatial_dims,
            in_channels=in_channels,
            out_channels=self.out_channels,
            strides=1,
            kernel_size=3,
            padding=1,
            conv_only=True,
        )
        self.upsample = self.downsample = None
        if self.up:
            self.upsample = Upsample(spatial_dims, in_channels, use_conv=False)
        elif down:
            self.downsample = Downsample(spatial_dims, in_channels, use_conv=False)

        self.time_emb_proj = nn.Linear(temb_channels, self.out_channels)

        self.norm2 = nn.GroupNorm(num_groups=norm_num_groups, num_channels=self.out_channels, eps=norm_eps, affine=True)
        self.conv2 = zero_module(
            Convolution(
                spatial_dims=spatial_dims,
                in_channels=self.out_channels,
                out_channels=self.out_channels,
                strides=1,
                kernel_size=3,
                padding=1,
                conv_only=True,
            )
        )

        if self.out_channels == in_channels:
            self.skip_connection = nn.Identity()
        else:
            self.skip_connection = Convolution(
                spatial_dims=spatial_dims,
                in_channels=in_channels,
                out_channels=self.out_channels,
                strides=1,
                kernel_size=1,
                padding=0,
                conv_only=True,
            )

    def forward(self, x: torch.Tensor, emb: torch.Tensor) -> torch.Tensor:
        h = x
        h = self.norm1(h)
        h = self.nonlinearity(h)

        if self.upsample is not None:
            if h.shape[0] >= 64:
                x = x.contiguous()
                h = h.contiguous()
            x = self.upsample(x)
            h = self.upsample(h)
        elif self.downsample is not None:
            x = self.downsample(x)
            h = self.downsample(h)

        h = self.conv1(h)

        if self.spatial_dims == 2:
            temb = self.time_emb_proj(self.nonlinearity(emb))[:, :, None, None]
        else:
            temb = self.time_emb_proj(self.nonlinearity(emb))[:, :, None, None, None]
        h = h + temb

        h = self.norm2(h)
        h = self.nonlinearity(h)
        h = self.conv2(h)

        return self.skip_connection(x) + h


class DownBlock(nn.Module):
    """
    Unet's down block containing resnet and downsamplers blocks.

    Args:
        spatial_dims: The number of spatial dimensions.
        in_channels: number of input channels.
        out_channels: number of output channels.
        temb_channels: number of timestep embedding channels.
        num_res_blocks: number of residual blocks.
        norm_num_groups: number of groups for the group normalization.
        norm_eps: epsilon for the group normalization.
        add_downsample: if True add downsample block.
        resblock_updown: if True use residual blocks for downsampling.
        downsample_padding: padding used in the downsampling block.
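    Example (illustrative sketch; sizes are arbitrary, channel counts must be divisible by norm_num_groups):

        block = DownBlock(spatial_dims=2, in_channels=32, out_channels=64, temb_channels=128)
        h = torch.randn(1, 32, 32, 32)   # (N, C, H, W) feature map
        temb = torch.randn(1, 128)       # timestep embedding
        h, states = block(h, temb)       # h: (1, 64, 16, 16); states holds the skip features for the up path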
""" def __init__( self, spatial_dims: int, in_channels: int, out_channels: int, temb_channels: int, num_res_blocks: int = 1, norm_num_groups: int = 32, norm_eps: float = 1e-6, add_downsample: bool = True, resblock_updown: bool = False, downsample_padding: int = 1, ) -> None: super().__init__() self.resblock_updown = resblock_updown resnets = [] for i in range(num_res_blocks): in_channels = in_channels if i == 0 else out_channels resnets.append( ResnetBlock( spatial_dims=spatial_dims, in_channels=in_channels, out_channels=out_channels, temb_channels=temb_channels, norm_num_groups=norm_num_groups, norm_eps=norm_eps, ) ) self.resnets = nn.ModuleList(resnets) if add_downsample: if resblock_updown: self.downsampler = ResnetBlock( spatial_dims=spatial_dims, in_channels=out_channels, out_channels=out_channels, temb_channels=temb_channels, norm_num_groups=norm_num_groups, norm_eps=norm_eps, down=True, ) else: self.downsampler = Downsample( spatial_dims=spatial_dims, num_channels=out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, ) else: self.downsampler = None def forward( self, hidden_states: torch.Tensor, temb: torch.Tensor, context: torch.Tensor | None = None ) -> tuple[torch.Tensor, list[torch.Tensor]]: del context output_states = [] for resnet in self.resnets: hidden_states = resnet(hidden_states, temb) output_states.append(hidden_states) if self.downsampler is not None: hidden_states = self.downsampler(hidden_states, temb) output_states.append(hidden_states) return hidden_states, output_states class AttnDownBlock(nn.Module): """ Unet's down block containing resnet, downsamplers and self-attention blocks. Args: spatial_dims: The number of spatial dimensions. in_channels: number of input channels. out_channels: number of output channels. temb_channels: number of timestep embedding channels. num_res_blocks: number of residual blocks. norm_num_groups: number of groups for the group normalization. norm_eps: epsilon for the group normalization. add_downsample: if True add downsample block. resblock_updown: if True use residual blocks for downsampling. downsample_padding: padding used in the downsampling block. num_head_channels: number of channels in each attention head. use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. 
""" def __init__( self, spatial_dims: int, in_channels: int, out_channels: int, temb_channels: int, num_res_blocks: int = 1, norm_num_groups: int = 32, norm_eps: float = 1e-6, add_downsample: bool = True, resblock_updown: bool = False, downsample_padding: int = 1, num_head_channels: int = 1, use_flash_attention: bool = False, ) -> None: super().__init__() self.resblock_updown = resblock_updown resnets = [] attentions = [] for i in range(num_res_blocks): in_channels = in_channels if i == 0 else out_channels resnets.append( ResnetBlock( spatial_dims=spatial_dims, in_channels=in_channels, out_channels=out_channels, temb_channels=temb_channels, norm_num_groups=norm_num_groups, norm_eps=norm_eps, ) ) attentions.append( AttentionBlock( spatial_dims=spatial_dims, num_channels=out_channels, num_head_channels=num_head_channels, norm_num_groups=norm_num_groups, norm_eps=norm_eps, use_flash_attention=use_flash_attention, ) ) self.attentions = nn.ModuleList(attentions) self.resnets = nn.ModuleList(resnets) if add_downsample: if resblock_updown: self.downsampler = ResnetBlock( spatial_dims=spatial_dims, in_channels=out_channels, out_channels=out_channels, temb_channels=temb_channels, norm_num_groups=norm_num_groups, norm_eps=norm_eps, down=True, ) else: self.downsampler = Downsample( spatial_dims=spatial_dims, num_channels=out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, ) else: self.downsampler = None def forward( self, hidden_states: torch.Tensor, temb: torch.Tensor, context: torch.Tensor | None = None ) -> tuple[torch.Tensor, list[torch.Tensor]]: del context output_states = [] for resnet, attn in zip(self.resnets, self.attentions): hidden_states = resnet(hidden_states, temb) hidden_states = attn(hidden_states) output_states.append(hidden_states) if self.downsampler is not None: hidden_states = self.downsampler(hidden_states, temb) output_states.append(hidden_states) return hidden_states, output_states class CrossAttnDownBlock(nn.Module): """ Unet's down block containing resnet, downsamplers and cross-attention blocks. Args: spatial_dims: number of spatial dimensions. in_channels: number of input channels. out_channels: number of output channels. temb_channels: number of timestep embedding channels. num_res_blocks: number of residual blocks. norm_num_groups: number of groups for the group normalization. norm_eps: epsilon for the group normalization. add_downsample: if True add downsample block. resblock_updown: if True use residual blocks for downsampling. downsample_padding: padding used in the downsampling block. num_head_channels: number of channels in each attention head. transformer_num_layers: number of layers of Transformer blocks to use. cross_attention_dim: number of context dimensions to use. upcast_attention: if True, upcast attention operations to full precision. use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. 
""" def __init__( self, spatial_dims: int, in_channels: int, out_channels: int, temb_channels: int, num_res_blocks: int = 1, norm_num_groups: int = 32, norm_eps: float = 1e-6, add_downsample: bool = True, resblock_updown: bool = False, downsample_padding: int = 1, num_head_channels: int = 1, transformer_num_layers: int = 1, cross_attention_dim: int | None = None, upcast_attention: bool = False, use_flash_attention: bool = False, ) -> None: super().__init__() self.resblock_updown = resblock_updown resnets = [] attentions = [] for i in range(num_res_blocks): in_channels = in_channels if i == 0 else out_channels resnets.append( ResnetBlock( spatial_dims=spatial_dims, in_channels=in_channels, out_channels=out_channels, temb_channels=temb_channels, norm_num_groups=norm_num_groups, norm_eps=norm_eps, ) ) attentions.append( SpatialTransformer( spatial_dims=spatial_dims, in_channels=out_channels, num_attention_heads=out_channels // num_head_channels, num_head_channels=num_head_channels, num_layers=transformer_num_layers, norm_num_groups=norm_num_groups, norm_eps=norm_eps, cross_attention_dim=cross_attention_dim, upcast_attention=upcast_attention, use_flash_attention=use_flash_attention, ) ) self.attentions = nn.ModuleList(attentions) self.resnets = nn.ModuleList(resnets) if add_downsample: if resblock_updown: self.downsampler = ResnetBlock( spatial_dims=spatial_dims, in_channels=out_channels, out_channels=out_channels, temb_channels=temb_channels, norm_num_groups=norm_num_groups, norm_eps=norm_eps, down=True, ) else: self.downsampler = Downsample( spatial_dims=spatial_dims, num_channels=out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, ) else: self.downsampler = None def forward( self, hidden_states: torch.Tensor, temb: torch.Tensor, context: torch.Tensor | None = None ) -> tuple[torch.Tensor, list[torch.Tensor]]: output_states = [] for resnet, attn in zip(self.resnets, self.attentions): hidden_states = resnet(hidden_states, temb) hidden_states = attn(hidden_states, context=context) output_states.append(hidden_states) if self.downsampler is not None: hidden_states = self.downsampler(hidden_states, temb) output_states.append(hidden_states) return hidden_states, output_states class AttnMidBlock(nn.Module): """ Unet's mid block containing resnet and self-attention blocks. Args: spatial_dims: The number of spatial dimensions. in_channels: number of input channels. temb_channels: number of timestep embedding channels. norm_num_groups: number of groups for the group normalization. norm_eps: epsilon for the group normalization. num_head_channels: number of channels in each attention head. use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. 
""" def __init__( self, spatial_dims: int, in_channels: int, temb_channels: int, norm_num_groups: int = 32, norm_eps: float = 1e-6, num_head_channels: int = 1, use_flash_attention: bool = False, ) -> None: super().__init__() self.attention = None self.resnet_1 = ResnetBlock( spatial_dims=spatial_dims, in_channels=in_channels, out_channels=in_channels, temb_channels=temb_channels, norm_num_groups=norm_num_groups, norm_eps=norm_eps, ) self.attention = AttentionBlock( spatial_dims=spatial_dims, num_channels=in_channels, num_head_channels=num_head_channels, norm_num_groups=norm_num_groups, norm_eps=norm_eps, use_flash_attention=use_flash_attention, ) self.resnet_2 = ResnetBlock( spatial_dims=spatial_dims, in_channels=in_channels, out_channels=in_channels, temb_channels=temb_channels, norm_num_groups=norm_num_groups, norm_eps=norm_eps, ) def forward( self, hidden_states: torch.Tensor, temb: torch.Tensor, context: torch.Tensor | None = None ) -> torch.Tensor: del context hidden_states = self.resnet_1(hidden_states, temb) hidden_states = self.attention(hidden_states) hidden_states = self.resnet_2(hidden_states, temb) return hidden_states class CrossAttnMidBlock(nn.Module): """ Unet's mid block containing resnet and cross-attention blocks. Args: spatial_dims: The number of spatial dimensions. in_channels: number of input channels. temb_channels: number of timestep embedding channels norm_num_groups: number of groups for the group normalization. norm_eps: epsilon for the group normalization. num_head_channels: number of channels in each attention head. transformer_num_layers: number of layers of Transformer blocks to use. cross_attention_dim: number of context dimensions to use. upcast_attention: if True, upcast attention operations to full precision. use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. """ def __init__( self, spatial_dims: int, in_channels: int, temb_channels: int, norm_num_groups: int = 32, norm_eps: float = 1e-6, num_head_channels: int = 1, transformer_num_layers: int = 1, cross_attention_dim: int | None = None, upcast_attention: bool = False, use_flash_attention: bool = False, ) -> None: super().__init__() self.attention = None self.resnet_1 = ResnetBlock( spatial_dims=spatial_dims, in_channels=in_channels, out_channels=in_channels, temb_channels=temb_channels, norm_num_groups=norm_num_groups, norm_eps=norm_eps, ) self.attention = SpatialTransformer( spatial_dims=spatial_dims, in_channels=in_channels, num_attention_heads=in_channels // num_head_channels, num_head_channels=num_head_channels, num_layers=transformer_num_layers, norm_num_groups=norm_num_groups, norm_eps=norm_eps, cross_attention_dim=cross_attention_dim, upcast_attention=upcast_attention, use_flash_attention=use_flash_attention, ) self.resnet_2 = ResnetBlock( spatial_dims=spatial_dims, in_channels=in_channels, out_channels=in_channels, temb_channels=temb_channels, norm_num_groups=norm_num_groups, norm_eps=norm_eps, ) def forward( self, hidden_states: torch.Tensor, temb: torch.Tensor, context: torch.Tensor | None = None ) -> torch.Tensor: hidden_states = self.resnet_1(hidden_states, temb) hidden_states = self.attention(hidden_states, context=context) hidden_states = self.resnet_2(hidden_states, temb) return hidden_states class UpBlock(nn.Module): """ Unet's up block containing resnet and upsamplers blocks. Args: spatial_dims: The number of spatial dimensions. in_channels: number of input channels. prev_output_channel: number of channels from residual connection. 
out_channels: number of output channels. temb_channels: number of timestep embedding channels. num_res_blocks: number of residual blocks. norm_num_groups: number of groups for the group normalization. norm_eps: epsilon for the group normalization. add_upsample: if True add downsample block. resblock_updown: if True use residual blocks for upsampling. """ def __init__( self, spatial_dims: int, in_channels: int, prev_output_channel: int, out_channels: int, temb_channels: int, num_res_blocks: int = 1, norm_num_groups: int = 32, norm_eps: float = 1e-6, add_upsample: bool = True, resblock_updown: bool = False, ) -> None: super().__init__() self.resblock_updown = resblock_updown resnets = [] for i in range(num_res_blocks): res_skip_channels = in_channels if (i == num_res_blocks - 1) else out_channels resnet_in_channels = prev_output_channel if i == 0 else out_channels resnets.append( ResnetBlock( spatial_dims=spatial_dims, in_channels=resnet_in_channels + res_skip_channels, out_channels=out_channels, temb_channels=temb_channels, norm_num_groups=norm_num_groups, norm_eps=norm_eps, ) ) self.resnets = nn.ModuleList(resnets) if add_upsample: if resblock_updown: self.upsampler = ResnetBlock( spatial_dims=spatial_dims, in_channels=out_channels, out_channels=out_channels, temb_channels=temb_channels, norm_num_groups=norm_num_groups, norm_eps=norm_eps, up=True, ) else: self.upsampler = Upsample( spatial_dims=spatial_dims, num_channels=out_channels, use_conv=True, out_channels=out_channels ) else: self.upsampler = None def forward( self, hidden_states: torch.Tensor, res_hidden_states_list: list[torch.Tensor], temb: torch.Tensor, context: torch.Tensor | None = None, ) -> torch.Tensor: del context for resnet in self.resnets: # pop res hidden states res_hidden_states = res_hidden_states_list[-1] res_hidden_states_list = res_hidden_states_list[:-1] hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) hidden_states = resnet(hidden_states, temb) if self.upsampler is not None: hidden_states = self.upsampler(hidden_states, temb) return hidden_states class AttnUpBlock(nn.Module): """ Unet's up block containing resnet, upsamplers, and self-attention blocks. Args: spatial_dims: The number of spatial dimensions. in_channels: number of input channels. prev_output_channel: number of channels from residual connection. out_channels: number of output channels. temb_channels: number of timestep embedding channels. num_res_blocks: number of residual blocks. norm_num_groups: number of groups for the group normalization. norm_eps: epsilon for the group normalization. add_upsample: if True add downsample block. resblock_updown: if True use residual blocks for upsampling. num_head_channels: number of channels in each attention head. use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. 
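    Example (illustrative sketch; the skip tensors in res_hidden_states_list are consumed from the end, one per
    residual block):

        block = AttnUpBlock(
            spatial_dims=2, in_channels=32, prev_output_channel=64, out_channels=32, temb_channels=128,
            num_head_channels=16,
        )
        h = torch.randn(1, 64, 16, 16)        # features coming up from the previous (deeper) block
        skips = [torch.randn(1, 32, 16, 16)]  # skip connections stored by the matching down block
        temb = torch.randn(1, 128)            # timestep embedding
        h = block(h, skips, temb)             # upsampled output: (1, 32, 32, 32)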
""" def __init__( self, spatial_dims: int, in_channels: int, prev_output_channel: int, out_channels: int, temb_channels: int, num_res_blocks: int = 1, norm_num_groups: int = 32, norm_eps: float = 1e-6, add_upsample: bool = True, resblock_updown: bool = False, num_head_channels: int = 1, use_flash_attention: bool = False, ) -> None: super().__init__() self.resblock_updown = resblock_updown resnets = [] attentions = [] for i in range(num_res_blocks): res_skip_channels = in_channels if (i == num_res_blocks - 1) else out_channels resnet_in_channels = prev_output_channel if i == 0 else out_channels resnets.append( ResnetBlock( spatial_dims=spatial_dims, in_channels=resnet_in_channels + res_skip_channels, out_channels=out_channels, temb_channels=temb_channels, norm_num_groups=norm_num_groups, norm_eps=norm_eps, ) ) attentions.append( AttentionBlock( spatial_dims=spatial_dims, num_channels=out_channels, num_head_channels=num_head_channels, norm_num_groups=norm_num_groups, norm_eps=norm_eps, use_flash_attention=use_flash_attention, ) ) self.resnets = nn.ModuleList(resnets) self.attentions = nn.ModuleList(attentions) if add_upsample: if resblock_updown: self.upsampler = ResnetBlock( spatial_dims=spatial_dims, in_channels=out_channels, out_channels=out_channels, temb_channels=temb_channels, norm_num_groups=norm_num_groups, norm_eps=norm_eps, up=True, ) else: self.upsampler = Upsample( spatial_dims=spatial_dims, num_channels=out_channels, use_conv=True, out_channels=out_channels ) else: self.upsampler = None def forward( self, hidden_states: torch.Tensor, res_hidden_states_list: list[torch.Tensor], temb: torch.Tensor, context: torch.Tensor | None = None, ) -> torch.Tensor: del context for resnet, attn in zip(self.resnets, self.attentions): # pop res hidden states res_hidden_states = res_hidden_states_list[-1] res_hidden_states_list = res_hidden_states_list[:-1] hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) hidden_states = resnet(hidden_states, temb) hidden_states = attn(hidden_states) if self.upsampler is not None: hidden_states = self.upsampler(hidden_states, temb) return hidden_states class CrossAttnUpBlock(nn.Module): """ Unet's up block containing resnet, upsamplers, and self-attention blocks. Args: spatial_dims: The number of spatial dimensions. in_channels: number of input channels. prev_output_channel: number of channels from residual connection. out_channels: number of output channels. temb_channels: number of timestep embedding channels. num_res_blocks: number of residual blocks. norm_num_groups: number of groups for the group normalization. norm_eps: epsilon for the group normalization. add_upsample: if True add downsample block. resblock_updown: if True use residual blocks for upsampling. num_head_channels: number of channels in each attention head. transformer_num_layers: number of layers of Transformer blocks to use. cross_attention_dim: number of context dimensions to use. upcast_attention: if True, upcast attention operations to full precision. use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. 
""" def __init__( self, spatial_dims: int, in_channels: int, prev_output_channel: int, out_channels: int, temb_channels: int, num_res_blocks: int = 1, norm_num_groups: int = 32, norm_eps: float = 1e-6, add_upsample: bool = True, resblock_updown: bool = False, num_head_channels: int = 1, transformer_num_layers: int = 1, cross_attention_dim: int | None = None, upcast_attention: bool = False, use_flash_attention: bool = False, ) -> None: super().__init__() self.resblock_updown = resblock_updown resnets = [] attentions = [] for i in range(num_res_blocks): res_skip_channels = in_channels if (i == num_res_blocks - 1) else out_channels resnet_in_channels = prev_output_channel if i == 0 else out_channels resnets.append( ResnetBlock( spatial_dims=spatial_dims, in_channels=resnet_in_channels + res_skip_channels, out_channels=out_channels, temb_channels=temb_channels, norm_num_groups=norm_num_groups, norm_eps=norm_eps, ) ) attentions.append( SpatialTransformer( spatial_dims=spatial_dims, in_channels=out_channels, num_attention_heads=out_channels // num_head_channels, num_head_channels=num_head_channels, norm_num_groups=norm_num_groups, norm_eps=norm_eps, num_layers=transformer_num_layers, cross_attention_dim=cross_attention_dim, upcast_attention=upcast_attention, use_flash_attention=use_flash_attention, ) ) self.attentions = nn.ModuleList(attentions) self.resnets = nn.ModuleList(resnets) if add_upsample: if resblock_updown: self.upsampler = ResnetBlock( spatial_dims=spatial_dims, in_channels=out_channels, out_channels=out_channels, temb_channels=temb_channels, norm_num_groups=norm_num_groups, norm_eps=norm_eps, up=True, ) else: self.upsampler = Upsample( spatial_dims=spatial_dims, num_channels=out_channels, use_conv=True, out_channels=out_channels ) else: self.upsampler = None def forward( self, hidden_states: torch.Tensor, res_hidden_states_list: list[torch.Tensor], temb: torch.Tensor, context: torch.Tensor | None = None, ) -> torch.Tensor: for resnet, attn in zip(self.resnets, self.attentions): # pop res hidden states res_hidden_states = res_hidden_states_list[-1] res_hidden_states_list = res_hidden_states_list[:-1] hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) hidden_states = resnet(hidden_states, temb) hidden_states = attn(hidden_states, context=context) if self.upsampler is not None: hidden_states = self.upsampler(hidden_states, temb) return hidden_states def get_down_block( spatial_dims: int, in_channels: int, out_channels: int, temb_channels: int, num_res_blocks: int, norm_num_groups: int, norm_eps: float, add_downsample: bool, resblock_updown: bool, with_attn: bool, with_cross_attn: bool, num_head_channels: int, transformer_num_layers: int, cross_attention_dim: int | None, upcast_attention: bool = False, use_flash_attention: bool = False, ) -> nn.Module: if with_attn: return AttnDownBlock( spatial_dims=spatial_dims, in_channels=in_channels, out_channels=out_channels, temb_channels=temb_channels, num_res_blocks=num_res_blocks, norm_num_groups=norm_num_groups, norm_eps=norm_eps, add_downsample=add_downsample, resblock_updown=resblock_updown, num_head_channels=num_head_channels, use_flash_attention=use_flash_attention, ) elif with_cross_attn: return CrossAttnDownBlock( spatial_dims=spatial_dims, in_channels=in_channels, out_channels=out_channels, temb_channels=temb_channels, num_res_blocks=num_res_blocks, norm_num_groups=norm_num_groups, norm_eps=norm_eps, add_downsample=add_downsample, resblock_updown=resblock_updown, num_head_channels=num_head_channels, 
transformer_num_layers=transformer_num_layers, cross_attention_dim=cross_attention_dim, upcast_attention=upcast_attention, use_flash_attention=use_flash_attention, ) else: return DownBlock( spatial_dims=spatial_dims, in_channels=in_channels, out_channels=out_channels, temb_channels=temb_channels, num_res_blocks=num_res_blocks, norm_num_groups=norm_num_groups, norm_eps=norm_eps, add_downsample=add_downsample, resblock_updown=resblock_updown, ) def get_mid_block( spatial_dims: int, in_channels: int, temb_channels: int, norm_num_groups: int, norm_eps: float, with_conditioning: bool, num_head_channels: int, transformer_num_layers: int, cross_attention_dim: int | None, upcast_attention: bool = False, use_flash_attention: bool = False, ) -> nn.Module: if with_conditioning: return CrossAttnMidBlock( spatial_dims=spatial_dims, in_channels=in_channels, temb_channels=temb_channels, norm_num_groups=norm_num_groups, norm_eps=norm_eps, num_head_channels=num_head_channels, transformer_num_layers=transformer_num_layers, cross_attention_dim=cross_attention_dim, upcast_attention=upcast_attention, use_flash_attention=use_flash_attention, ) else: return AttnMidBlock( spatial_dims=spatial_dims, in_channels=in_channels, temb_channels=temb_channels, norm_num_groups=norm_num_groups, norm_eps=norm_eps, num_head_channels=num_head_channels, use_flash_attention=use_flash_attention, ) def get_up_block( spatial_dims: int, in_channels: int, prev_output_channel: int, out_channels: int, temb_channels: int, num_res_blocks: int, norm_num_groups: int, norm_eps: float, add_upsample: bool, resblock_updown: bool, with_attn: bool, with_cross_attn: bool, num_head_channels: int, transformer_num_layers: int, cross_attention_dim: int | None, upcast_attention: bool = False, use_flash_attention: bool = False, ) -> nn.Module: if with_attn: return AttnUpBlock( spatial_dims=spatial_dims, in_channels=in_channels, prev_output_channel=prev_output_channel, out_channels=out_channels, temb_channels=temb_channels, num_res_blocks=num_res_blocks, norm_num_groups=norm_num_groups, norm_eps=norm_eps, add_upsample=add_upsample, resblock_updown=resblock_updown, num_head_channels=num_head_channels, use_flash_attention=use_flash_attention, ) elif with_cross_attn: return CrossAttnUpBlock( spatial_dims=spatial_dims, in_channels=in_channels, prev_output_channel=prev_output_channel, out_channels=out_channels, temb_channels=temb_channels, num_res_blocks=num_res_blocks, norm_num_groups=norm_num_groups, norm_eps=norm_eps, add_upsample=add_upsample, resblock_updown=resblock_updown, num_head_channels=num_head_channels, transformer_num_layers=transformer_num_layers, cross_attention_dim=cross_attention_dim, upcast_attention=upcast_attention, use_flash_attention=use_flash_attention, ) else: return UpBlock( spatial_dims=spatial_dims, in_channels=in_channels, prev_output_channel=prev_output_channel, out_channels=out_channels, temb_channels=temb_channels, num_res_blocks=num_res_blocks, norm_num_groups=norm_num_groups, norm_eps=norm_eps, add_upsample=add_upsample, resblock_updown=resblock_updown, ) class DiffusionModelUNet(nn.Module): """ Unet network with timestep embedding and attention mechanisms for conditioning based on Rombach et al. "High-Resolution Image Synthesis with Latent Diffusion Models" https://arxiv.org/abs/2112.10752 and Pinaya et al. "Brain Imaging Generation with Latent Diffusion Models" https://arxiv.org/abs/2209.07162 Args: spatial_dims: number of spatial dimensions. in_channels: number of input channels. out_channels: number of output channels. 
num_res_blocks: number of residual blocks (see ResnetBlock) per level. num_channels: tuple of block output channels. attention_levels: list of levels to add attention. norm_num_groups: number of groups for the normalization. norm_eps: epsilon for the normalization. resblock_updown: if True use residual blocks for up/downsampling. num_head_channels: number of channels in each attention head. with_conditioning: if True add spatial transformers to perform conditioning. transformer_num_layers: number of layers of Transformer blocks to use. cross_attention_dim: number of context dimensions to use. num_class_embeds: if specified (as an int), then this model will be class-conditional with `num_class_embeds` classes. upcast_attention: if True, upcast attention operations to full precision. use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. """ def __init__( self, spatial_dims: int, in_channels: int, out_channels: int, num_res_blocks: Sequence[int] | int = (2, 2, 2, 2), num_channels: Sequence[int] = (32, 64, 64, 64), attention_levels: Sequence[bool] = (False, False, True, True), norm_num_groups: int = 32, norm_eps: float = 1e-6, resblock_updown: bool = False, num_head_channels: int | Sequence[int] = 8, with_conditioning: bool = False, transformer_num_layers: int = 1, cross_attention_dim: int | None = None, num_class_embeds: int | None = None, upcast_attention: bool = False, use_flash_attention: bool = False, ) -> None: super().__init__() if with_conditioning is True and cross_attention_dim is None: raise ValueError( "DiffusionModelUNet expects dimension of the cross-attention conditioning (cross_attention_dim) " "when using with_conditioning." ) if cross_attention_dim is not None and with_conditioning is False: raise ValueError( "DiffusionModelUNet expects with_conditioning=True when specifying the cross_attention_dim." ) # All number of channels should be multiple of num_groups if any((out_channel % norm_num_groups) != 0 for out_channel in num_channels): raise ValueError("DiffusionModelUNet expects all num_channels being multiple of norm_num_groups") if len(num_channels) != len(attention_levels): raise ValueError("DiffusionModelUNet expects num_channels being same size of attention_levels") if isinstance(num_head_channels, int): num_head_channels = ensure_tuple_rep(num_head_channels, len(attention_levels)) if len(num_head_channels) != len(attention_levels): raise ValueError( "num_head_channels should have the same length as attention_levels. For the i levels without attention," " i.e. `attention_level[i]=False`, the num_head_channels[i] will be ignored." ) if isinstance(num_res_blocks, int): num_res_blocks = ensure_tuple_rep(num_res_blocks, len(num_channels)) if len(num_res_blocks) != len(num_channels): raise ValueError( "`num_res_blocks` should be a single integer or a tuple of integers with the same length as " "`num_channels`." ) if use_flash_attention and not has_xformers: raise ValueError("use_flash_attention is True but xformers is not installed.") if use_flash_attention is True and not torch.cuda.is_available(): raise ValueError( "torch.cuda.is_available() should be True but is False. Flash attention is only available for GPU." 
) self.in_channels = in_channels self.block_out_channels = num_channels self.out_channels = out_channels self.num_res_blocks = num_res_blocks self.attention_levels = attention_levels self.num_head_channels = num_head_channels self.with_conditioning = with_conditioning # input self.conv_in = Convolution( spatial_dims=spatial_dims, in_channels=in_channels, out_channels=num_channels[0], strides=1, kernel_size=3, padding=1, conv_only=True, ) # time time_embed_dim = num_channels[0] * 4 self.time_embed = nn.Sequential( nn.Linear(num_channels[0], time_embed_dim), nn.SiLU(), nn.Linear(time_embed_dim, time_embed_dim) ) # class embedding self.num_class_embeds = num_class_embeds if num_class_embeds is not None: self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim) # down self.down_blocks = nn.ModuleList([]) output_channel = num_channels[0] for i in range(len(num_channels)): input_channel = output_channel output_channel = num_channels[i] is_final_block = i == len(num_channels) - 1 down_block = get_down_block( spatial_dims=spatial_dims, in_channels=input_channel, out_channels=output_channel, temb_channels=time_embed_dim, num_res_blocks=num_res_blocks[i], norm_num_groups=norm_num_groups, norm_eps=norm_eps, add_downsample=not is_final_block, resblock_updown=resblock_updown, with_attn=(attention_levels[i] and not with_conditioning), with_cross_attn=(attention_levels[i] and with_conditioning), num_head_channels=num_head_channels[i], transformer_num_layers=transformer_num_layers, cross_attention_dim=cross_attention_dim, upcast_attention=upcast_attention, use_flash_attention=use_flash_attention, ) self.down_blocks.append(down_block) # mid self.middle_block = get_mid_block( spatial_dims=spatial_dims, in_channels=num_channels[-1], temb_channels=time_embed_dim, norm_num_groups=norm_num_groups, norm_eps=norm_eps, with_conditioning=with_conditioning, num_head_channels=num_head_channels[-1], transformer_num_layers=transformer_num_layers, cross_attention_dim=cross_attention_dim, upcast_attention=upcast_attention, use_flash_attention=use_flash_attention, ) # up self.up_blocks = nn.ModuleList([]) reversed_block_out_channels = list(reversed(num_channels)) reversed_num_res_blocks = list(reversed(num_res_blocks)) reversed_attention_levels = list(reversed(attention_levels)) reversed_num_head_channels = list(reversed(num_head_channels)) output_channel = reversed_block_out_channels[0] for i in range(len(reversed_block_out_channels)): prev_output_channel = output_channel output_channel = reversed_block_out_channels[i] input_channel = reversed_block_out_channels[min(i + 1, len(num_channels) - 1)] is_final_block = i == len(num_channels) - 1 up_block = get_up_block( spatial_dims=spatial_dims, in_channels=input_channel, prev_output_channel=prev_output_channel, out_channels=output_channel, temb_channels=time_embed_dim, num_res_blocks=reversed_num_res_blocks[i] + 1, norm_num_groups=norm_num_groups, norm_eps=norm_eps, add_upsample=not is_final_block, resblock_updown=resblock_updown, with_attn=(reversed_attention_levels[i] and not with_conditioning), with_cross_attn=(reversed_attention_levels[i] and with_conditioning), num_head_channels=reversed_num_head_channels[i], transformer_num_layers=transformer_num_layers, cross_attention_dim=cross_attention_dim, upcast_attention=upcast_attention, use_flash_attention=use_flash_attention, ) self.up_blocks.append(up_block) # out self.out = nn.Sequential( nn.GroupNorm(num_groups=norm_num_groups, num_channels=num_channels[0], eps=norm_eps, affine=True), nn.SiLU(), zero_module( 
                Convolution(
                    spatial_dims=spatial_dims,
                    in_channels=num_channels[0],
                    out_channels=out_channels,
                    strides=1,
                    kernel_size=3,
                    padding=1,
                    conv_only=True,
                )
            ),
        )

    def forward(
        self,
        x: torch.Tensor,
        timesteps: torch.Tensor,
        context: torch.Tensor | None = None,
        class_labels: torch.Tensor | None = None,
        down_block_additional_residuals: tuple[torch.Tensor] | None = None,
        mid_block_additional_residual: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """
        Args:
            x: input tensor (N, C, SpatialDims).
            timesteps: timestep tensor (N,).
            context: context tensor (N, 1, ContextDim).
            class_labels: label tensor (N,).
            down_block_additional_residuals: additional residual tensors for down blocks (N, C, FeatureMapsDims).
            mid_block_additional_residual: additional residual tensor for mid block (N, C, FeatureMapsDims).
        """
        # 1. time
        t_emb = get_timestep_embedding(timesteps, self.block_out_channels[0])

        # timesteps does not contain any weights and will always return f32 tensors
        # but time_embedding might actually be running in fp16. so we need to cast here.
        # there might be better ways to encapsulate this.
        t_emb = t_emb.to(dtype=x.dtype)
        emb = self.time_embed(t_emb)

        # 2. class
        if self.num_class_embeds is not None:
            if class_labels is None:
                raise ValueError("class_labels should be provided when num_class_embeds > 0")
            class_emb = self.class_embedding(class_labels)
            class_emb = class_emb.to(dtype=x.dtype)
            emb = emb + class_emb

        # 3. initial convolution
        h = self.conv_in(x)

        # 4. down
        if context is not None and self.with_conditioning is False:
            raise ValueError("model should have with_conditioning = True if context is provided")
        down_block_res_samples: list[torch.Tensor] = [h]
        for downsample_block in self.down_blocks:
            h, res_samples = downsample_block(hidden_states=h, temb=emb, context=context)
            for residual in res_samples:
                down_block_res_samples.append(residual)

        # Additional residual connections for ControlNets
        if down_block_additional_residuals is not None:
            new_down_block_res_samples = ()
            for down_block_res_sample, down_block_additional_residual in zip(
                down_block_res_samples, down_block_additional_residuals
            ):
                down_block_res_sample = down_block_res_sample + down_block_additional_residual
                new_down_block_res_samples += (down_block_res_sample,)

            down_block_res_samples = new_down_block_res_samples

        # 5. mid
        h = self.middle_block(hidden_states=h, temb=emb, context=context)

        # Additional residual connections for ControlNets
        if mid_block_additional_residual is not None:
            h = h + mid_block_additional_residual

        # 6. up
        for upsample_block in self.up_blocks:
            res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
            down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
            h = upsample_block(hidden_states=h, res_hidden_states_list=res_samples, temb=emb, context=context)

        # 7. output block
        h = self.out(h)

        return h


class DiffusionModelEncoder(nn.Module):
    """
    Classification Network based on the Encoder of the Diffusion Model, followed by fully connected layers. This
    network is based on Wolleb et al. "Diffusion Models for Medical Anomaly Detection"
    (https://arxiv.org/abs/2203.04306).

    Args:
        spatial_dims: number of spatial dimensions.
        in_channels: number of input channels.
        out_channels: number of output channels.
        num_res_blocks: number of residual blocks (see ResnetBlock) per level.
        num_channels: tuple of block output channels.
        attention_levels: list of levels to add attention.
        norm_num_groups: number of groups for the normalization.
        norm_eps: epsilon for the normalization.
resblock_updown: if True use residual blocks for downsampling. num_head_channels: number of channels in each attention head. with_conditioning: if True add spatial transformers to perform conditioning. transformer_num_layers: number of layers of Transformer blocks to use. cross_attention_dim: number of context dimensions to use. num_class_embeds: if specified (as an int), then this model will be class-conditional with `num_class_embeds` classes. upcast_attention: if True, upcast attention operations to full precision. """ def __init__( self, spatial_dims: int, in_channels: int, out_channels: int, num_res_blocks: Sequence[int] | int = (2, 2, 2, 2), num_channels: Sequence[int] = (32, 64, 64, 64), attention_levels: Sequence[bool] = (False, False, True, True), norm_num_groups: int = 32, norm_eps: float = 1e-6, resblock_updown: bool = False, num_head_channels: int | Sequence[int] = 8, with_conditioning: bool = False, transformer_num_layers: int = 1, cross_attention_dim: int | None = None, num_class_embeds: int | None = None, upcast_attention: bool = False, ) -> None: super().__init__() if with_conditioning is True and cross_attention_dim is None: raise ValueError( "DiffusionModelEncoder expects dimension of the cross-attention conditioning (cross_attention_dim) " "when using with_conditioning." ) if cross_attention_dim is not None and with_conditioning is False: raise ValueError( "DiffusionModelEncoder expects with_conditioning=True when specifying the cross_attention_dim." ) # All number of channels should be multiple of num_groups if any((out_channel % norm_num_groups) != 0 for out_channel in num_channels): raise ValueError("DiffusionModelEncoder expects all num_channels being multiple of norm_num_groups") if len(num_channels) != len(attention_levels): raise ValueError("DiffusionModelEncoder expects num_channels being same size of attention_levels") if isinstance(num_head_channels, int): num_head_channels = ensure_tuple_rep(num_head_channels, len(attention_levels)) if len(num_head_channels) != len(attention_levels): raise ValueError( "num_head_channels should have the same length as attention_levels. For the i levels without attention," " i.e. `attention_level[i]=False`, the num_head_channels[i] will be ignored." 
) self.in_channels = in_channels self.block_out_channels = num_channels self.out_channels = out_channels self.num_res_blocks = num_res_blocks self.attention_levels = attention_levels self.num_head_channels = num_head_channels self.with_conditioning = with_conditioning # input self.conv_in = Convolution( spatial_dims=spatial_dims, in_channels=in_channels, out_channels=num_channels[0], strides=1, kernel_size=3, padding=1, conv_only=True, ) # time time_embed_dim = num_channels[0] * 4 self.time_embed = nn.Sequential( nn.Linear(num_channels[0], time_embed_dim), nn.SiLU(), nn.Linear(time_embed_dim, time_embed_dim) ) # class embedding self.num_class_embeds = num_class_embeds if num_class_embeds is not None: self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim) # down self.down_blocks = nn.ModuleList([]) output_channel = num_channels[0] for i in range(len(num_channels)): input_channel = output_channel output_channel = num_channels[i] is_final_block = i == len(num_channels) # - 1 down_block = get_down_block( spatial_dims=spatial_dims, in_channels=input_channel, out_channels=output_channel, temb_channels=time_embed_dim, num_res_blocks=num_res_blocks[i], norm_num_groups=norm_num_groups, norm_eps=norm_eps, add_downsample=not is_final_block, resblock_updown=resblock_updown, with_attn=(attention_levels[i] and not with_conditioning), with_cross_attn=(attention_levels[i] and with_conditioning), num_head_channels=num_head_channels[i], transformer_num_layers=transformer_num_layers, cross_attention_dim=cross_attention_dim, upcast_attention=upcast_attention, ) self.down_blocks.append(down_block) self.out = nn.Sequential(nn.Linear(4096, 512), nn.ReLU(), nn.Dropout(0.1), nn.Linear(512, self.out_channels)) def forward( self, x: torch.Tensor, timesteps: torch.Tensor, context: torch.Tensor | None = None, class_labels: torch.Tensor | None = None, ) -> torch.Tensor: """ Args: x: input tensor (N, C, SpatialDims). timesteps: timestep tensor (N,). context: context tensor (N, 1, ContextDim). class_labels: context tensor (N, ). """ # 1. time t_emb = get_timestep_embedding(timesteps, self.block_out_channels[0]) # timesteps does not contain any weights and will always return f32 tensors # but time_embedding might actually be running in fp16. so we need to cast here. # there might be better ways to encapsulate this. t_emb = t_emb.to(dtype=x.dtype) emb = self.time_embed(t_emb) # 2. class if self.num_class_embeds is not None: if class_labels is None: raise ValueError("class_labels should be provided when num_class_embeds > 0") class_emb = self.class_embedding(class_labels) class_emb = class_emb.to(dtype=x.dtype) emb = emb + class_emb # 3. initial convolution h = self.conv_in(x) # 4. down if context is not None and self.with_conditioning is False: raise ValueError("model should have with_conditioning = True if context is provided") for downsample_block in self.down_blocks: h, _ = downsample_block(hidden_states=h, temb=emb, context=context) h = h.reshape(h.shape[0], -1) output = self.out(h) return output
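

# ---------------------------------------------------------------------------
# Example usage (illustrative sketch only; the shapes and hyper-parameters below
# are arbitrary and are not part of the library API):
#
#     model = DiffusionModelUNet(
#         spatial_dims=2,
#         in_channels=1,
#         out_channels=1,
#         num_channels=(32, 64, 64),
#         attention_levels=(False, False, True),
#         num_res_blocks=1,
#         num_head_channels=16,
#     )
#     image = torch.randn(2, 1, 64, 64)                  # (N, C, H, W)
#     timesteps = torch.randint(0, 1000, (2,)).long()    # one diffusion timestep per batch element
#     prediction = model(x=image, timesteps=timesteps)   # output with the same shape as `image`
# ---------------------------------------------------------------------------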