# Copyright (C) 2025 AIDC-AI
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
from dataclasses import dataclass

import torch
from einops import rearrange
from safetensors.torch import load_file as load_sft
from torch import nn, Tensor


@dataclass
class AutoEncoderParams:
    resolution: int = 256
    in_channels: int = 3
    ch: int = 128
    out_ch: int = 3
    ch_mult: tuple[int, ...] = (1, 2, 4, 4)
    num_res_blocks: int = 2
    z_channels: int = 16
    scale_factor: float = 0.3611
    shift_factor: float = 0.1159
    use_quant_conv: bool = False
    use_post_quant_conv: bool = False


def swish(x: Tensor) -> Tensor:
    return x * torch.sigmoid(x)


class AttnBlock(nn.Module):
    def __init__(self, in_channels: int):
        super().__init__()
        self.in_channels = in_channels

        self.norm = nn.GroupNorm(
            num_groups=32, num_channels=in_channels, eps=1e-6, affine=True
        )
        self.q = nn.Conv2d(in_channels, in_channels, kernel_size=1)
        self.k = nn.Conv2d(in_channels, in_channels, kernel_size=1)
        self.v = nn.Conv2d(in_channels, in_channels, kernel_size=1)
        self.proj_out = nn.Conv2d(in_channels, in_channels, kernel_size=1)

    def attention(self, h_: Tensor) -> Tensor:
        h_ = self.norm(h_)
        q = self.q(h_)
        k = self.k(h_)
        v = self.v(h_)

        b, c, h, w = q.shape
        # flatten the spatial dims into a single token axis for SDPA
        q = rearrange(q, "b c h w -> b 1 (h w) c").contiguous()
        k = rearrange(k, "b c h w -> b 1 (h w) c").contiguous()
        v = rearrange(v, "b c h w -> b 1 (h w) c").contiguous()
        h_ = nn.functional.scaled_dot_product_attention(q, k, v)

        return rearrange(h_, "b 1 (h w) c -> b c h w", h=h, w=w, c=c, b=b)

    def forward(self, x: Tensor) -> Tensor:
        return x + self.proj_out(self.attention(x))


class ResnetBlock(nn.Module):
    def __init__(self, in_channels: int, out_channels: int | None = None):
        super().__init__()
        self.in_channels = in_channels
        out_channels = in_channels if out_channels is None else out_channels
        self.out_channels = out_channels

        self.norm1 = nn.GroupNorm(
            num_groups=32, num_channels=in_channels, eps=1e-6, affine=True
        )
        self.conv1 = nn.Conv2d(
            in_channels, out_channels, kernel_size=3, stride=1, padding=1
        )
        self.norm2 = nn.GroupNorm(
            num_groups=32, num_channels=out_channels, eps=1e-6, affine=True
        )
        self.conv2 = nn.Conv2d(
            out_channels, out_channels, kernel_size=3, stride=1, padding=1
        )
        if self.in_channels != self.out_channels:
            self.nin_shortcut = nn.Conv2d(
                in_channels, out_channels, kernel_size=1, stride=1, padding=0
            )

    def forward(self, x: Tensor) -> Tensor:
        h = x
        h = self.norm1(h)
        h = swish(h)
        h = self.conv1(h)

        h = self.norm2(h)
        h = swish(h)
        h = self.conv2(h)

        if self.in_channels != self.out_channels:
            x = self.nin_shortcut(x)

        return x + h


class Downsample(nn.Module):
    def __init__(self, in_channels: int):
        super().__init__()
        # no asymmetric padding in torch conv, must do it ourselves
        self.conv = nn.Conv2d(
            in_channels, in_channels, kernel_size=3, stride=2, padding=0
        )

    def forward(self, x: Tensor):
        pad = (0, 1, 0, 1)
        x = nn.functional.pad(x, pad, mode="constant", value=0)
        x = self.conv(x)
        return x


class Upsample(nn.Module):
    def __init__(self, in_channels: int):
        super().__init__()
        self.conv = nn.Conv2d(
            in_channels, in_channels, kernel_size=3, stride=1, padding=1
        )

    def forward(self, x: Tensor):
        x = nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
        x = self.conv(x)
        return x
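
# Shape sketch (illustrative, not part of the original module): Downsample halves
# the spatial dims with a stride-2 conv after asymmetric zero padding (0 on the
# left/top, 1 on the right/bottom), while Upsample doubles them with
# nearest-neighbour interpolation followed by a 3x3 conv. A hypothetical check:
#
#   x = torch.randn(1, 128, 64, 64)
#   Downsample(128)(x).shape  # torch.Size([1, 128, 32, 32])
#   Upsample(128)(x).shape    # torch.Size([1, 128, 128, 128])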
class Encoder(nn.Module):
    def __init__(
        self,
        resolution: int,
        in_channels: int,
        ch: int,
        ch_mult: list[int],
        num_res_blocks: int,
        z_channels: int,
    ):
        super().__init__()
        self.ch = ch
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels

        # downsampling
        self.conv_in = nn.Conv2d(
            in_channels, self.ch, kernel_size=3, stride=1, padding=1
        )

        curr_res = resolution
        in_ch_mult = (1,) + tuple(ch_mult)
        self.in_ch_mult = in_ch_mult
        self.down = nn.ModuleList()
        block_in = self.ch
        for i_level in range(self.num_resolutions):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_in = ch * in_ch_mult[i_level]
            block_out = ch * ch_mult[i_level]
            for _ in range(self.num_res_blocks):
                block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
                block_in = block_out
            down = nn.Module()
            down.block = block
            down.attn = attn
            if i_level != self.num_resolutions - 1:
                down.downsample = Downsample(block_in)
                curr_res = curr_res // 2
            self.down.append(down)

        # middle
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in)
        self.mid.attn_1 = AttnBlock(block_in)
        self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in)

        # end
        self.norm_out = nn.GroupNorm(
            num_groups=32, num_channels=block_in, eps=1e-6, affine=True
        )
        self.conv_out = nn.Conv2d(
            block_in, 2 * z_channels, kernel_size=3, stride=1, padding=1
        )

    def forward(self, x: Tensor) -> Tensor:
        # downsampling
        hs = [self.conv_in(x)]
        for i_level in range(self.num_resolutions):
            for i_block in range(self.num_res_blocks):
                h = self.down[i_level].block[i_block](hs[-1])
                if len(self.down[i_level].attn) > 0:
                    h = self.down[i_level].attn[i_block](h)
                hs.append(h)
            if i_level != self.num_resolutions - 1:
                hs.append(self.down[i_level].downsample(hs[-1]))

        # middle
        h = hs[-1]
        h = self.mid.block_1(h)
        h = self.mid.attn_1(h)
        h = self.mid.block_2(h)

        # end
        h = self.norm_out(h)
        h = swish(h)
        h = self.conv_out(h)
        return h


class Decoder(nn.Module):
    def __init__(
        self,
        ch: int,
        out_ch: int,
        ch_mult: list[int],
        num_res_blocks: int,
        in_channels: int,
        resolution: int,
        z_channels: int,
    ):
        super().__init__()
        self.ch = ch
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels
        self.ffactor = 2 ** (self.num_resolutions - 1)

        # compute in_ch_mult, block_in and curr_res at lowest res
        block_in = ch * ch_mult[self.num_resolutions - 1]
        curr_res = resolution // 2 ** (self.num_resolutions - 1)
        self.z_shape = (1, z_channels, curr_res, curr_res)

        # z to block_in
        self.conv_in = nn.Conv2d(
            z_channels, block_in, kernel_size=3, stride=1, padding=1
        )

        # middle
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in)
        self.mid.attn_1 = AttnBlock(block_in)
        self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in)

        # upsampling
        self.up = nn.ModuleList()
        for i_level in reversed(range(self.num_resolutions)):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_out = ch * ch_mult[i_level]
            for _ in range(self.num_res_blocks + 1):
                block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
                block_in = block_out
            up = nn.Module()
            up.block = block
            up.attn = attn
            if i_level != 0:
                up.upsample = Upsample(block_in)
                curr_res = curr_res * 2
            self.up.insert(0, up)  # prepend to get consistent order

        # end
        self.norm_out = nn.GroupNorm(
            num_groups=32, num_channels=block_in, eps=1e-6, affine=True
        )
        self.conv_out = nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1)

    def forward(self, z: Tensor) -> Tensor:
        # get dtype for proper tracing
        upscale_dtype = next(self.up.parameters()).dtype

        # z to block_in
        h = self.conv_in(z)

        # middle
        h = self.mid.block_1(h)
        h = self.mid.attn_1(h)
        h = self.mid.block_2(h)

        # cast to proper dtype
        h = h.to(upscale_dtype)
        # upsampling
        for i_level in reversed(range(self.num_resolutions)):
            for i_block in range(self.num_res_blocks + 1):
                h = self.up[i_level].block[i_block](h)
                if len(self.up[i_level].attn) > 0:
                    h = self.up[i_level].attn[i_block](h)
            if i_level != 0:
                h = self.up[i_level].upsample(h)

        # end
        h = self.norm_out(h)
        h = swish(h)
        h = self.conv_out(h)
        return h
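
# Sizing sketch (illustrative, not part of the original module): with the default
# ch_mult=(1, 2, 4, 4) the encoder downsamples len(ch_mult) - 1 = 3 times, i.e. by
# a factor of 8, and emits 2 * z_channels feature maps (mean and logvar, split by
# DiagonalGaussian below). For the defaults:
#
#   enc = Encoder(resolution=256, in_channels=3, ch=128,
#                 ch_mult=[1, 2, 4, 4], num_res_blocks=2, z_channels=16)
#   enc(torch.randn(1, 3, 256, 256)).shape   # torch.Size([1, 32, 32, 32])
#
#   dec = Decoder(ch=128, out_ch=3, ch_mult=[1, 2, 4, 4], num_res_blocks=2,
#                 in_channels=3, resolution=256, z_channels=16)
#   dec(torch.randn(1, 16, 32, 32)).shape    # torch.Size([1, 3, 256, 256])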
class DiagonalGaussian(nn.Module):
    def __init__(self, sample: bool = True, chunk_dim: int = 1):
        super().__init__()
        self.sample = sample
        self.chunk_dim = chunk_dim

    def forward(self, z: Tensor) -> Tensor:
        mean, logvar = torch.chunk(z, 2, dim=self.chunk_dim)
        if self.sample:
            std = torch.exp(0.5 * logvar)
            return mean + std * torch.randn_like(mean)
        else:
            return mean


class AutoEncoder(nn.Module):
    def __init__(self, params: AutoEncoderParams):
        super().__init__()
        self.params = params
        self.encoder = Encoder(
            resolution=params.resolution,
            in_channels=params.in_channels,
            ch=params.ch,
            ch_mult=params.ch_mult,
            num_res_blocks=params.num_res_blocks,
            z_channels=params.z_channels,
        )
        self.decoder = Decoder(
            resolution=params.resolution,
            in_channels=params.in_channels,
            ch=params.ch,
            out_ch=params.out_ch,
            ch_mult=params.ch_mult,
            num_res_blocks=params.num_res_blocks,
            z_channels=params.z_channels,
        )
        self.reg = DiagonalGaussian()

        self.scale_factor = params.scale_factor
        self.shift_factor = params.shift_factor

        self.quant_conv = (
            nn.Conv2d(2 * params.z_channels, 2 * params.z_channels, 1)
            if params.use_quant_conv
            else None
        )
        self.post_quant_conv = (
            nn.Conv2d(params.z_channels, params.z_channels, 1)
            if params.use_post_quant_conv
            else None
        )

    def encode(self, x: Tensor) -> Tensor:
        x = self.encoder(x)
        if self.quant_conv is not None:
            x = self.quant_conv(x)
        z = self.reg(x)
        z = self.scale_factor * (z - self.shift_factor)
        return z

    def decode(self, z: Tensor) -> Tensor:
        z = z / self.scale_factor + self.shift_factor
        if self.post_quant_conv is not None:
            z = self.post_quant_conv(z)
        return self.decoder(z)

    def forward(self, x: Tensor) -> Tensor:
        return self.decode(self.encode(x))


def load_ae(
    ckpt_path: str,
    autoencoder_params: AutoEncoderParams,
    device: str | torch.device = "cuda",
    dtype=torch.bfloat16,
    random_init: bool = False,
) -> AutoEncoder:
    """
    Load the autoencoder from the given checkpoint path.

    Args:
        ckpt_path (str): Path to the safetensors checkpoint.
        autoencoder_params (AutoEncoderParams): Architecture configuration.
        device (str or torch.device): The device to load the autoencoder to.
        dtype: The dtype to cast the weights to.
        random_init (bool): If True, skip loading the checkpoint and return
            randomly initialized weights.

    Returns:
        AutoEncoder: The loaded autoencoder.
    """
    # Loading the autoencoder
    with torch.device(device):
        ae = AutoEncoder(autoencoder_params)

    if random_init:
        print("Random Init VAE")
        return ae.to(dtype=dtype)

    if not os.path.exists(ckpt_path):
        raise ValueError(
            f"Autoencoder path {ckpt_path} does not exist. Please download it first."
        )

    print(f"Loading {ckpt_path}")
    sd = load_sft(ckpt_path, device=str(device))
    missing, unexpected = ae.load_state_dict(sd, strict=False, assign=True)
    if len(missing) > 0:
        print(f"Got {len(missing)} missing keys:\n\t" + "\n\t".join(missing))
    if len(unexpected) > 0:
        print(f"Got {len(unexpected)} unexpected keys:\n\t" + "\n\t".join(unexpected))
    return ae.to(dtype=dtype)
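

if __name__ == "__main__":
    # Minimal smoke test (a sketch, not part of the original module): builds a
    # randomly initialized autoencoder on CPU in float32 and round-trips one
    # image through encode/decode. The checkpoint path is a hypothetical
    # placeholder; with random_init=True it is never read.
    params = AutoEncoderParams()
    ae = load_ae(
        ckpt_path="ae.safetensors",  # hypothetical path, unused with random_init
        autoencoder_params=params,
        device="cpu",
        dtype=torch.float32,
        random_init=True,
    )
    x = torch.randn(1, 3, 256, 256)
    with torch.no_grad():
        z = ae.encode(x)
        x_rec = ae.decode(z)
    print(z.shape)      # torch.Size([1, 16, 32, 32])
    print(x_rec.shape)  # torch.Size([1, 3, 256, 256])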