# Copyright (C) 2025 AIDC-AI
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License.
# You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.

import os
from dataclasses import dataclass

import torch
from einops import rearrange
from safetensors.torch import load_file as load_sft
from torch import nn, Tensor


@dataclass
class AutoEncoderParams:
    resolution: int = 256
    in_channels: int = 3
    ch: int = 128
    out_ch: int = 3
    ch_mult: tuple[int, ...] = (1, 2, 4, 4)
    num_res_blocks: int = 2
    z_channels: int = 16
    scale_factor: float = 0.3611
    shift_factor: float = 0.1159
    use_quant_conv: bool = False
    use_post_quant_conv: bool = False


def swish(x: Tensor) -> Tensor:
    return x * torch.sigmoid(x)
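
# Note: swish(x) = x * sigmoid(x) is the SiLU activation; this hand-rolled
# version is equivalent to torch.nn.functional.silu(x).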


class AttnBlock(nn.Module):
    def __init__(self, in_channels: int):
        super().__init__()
        self.in_channels = in_channels
        self.norm = nn.GroupNorm(
            num_groups=32, num_channels=in_channels, eps=1e-6, affine=True
        )
        self.q = nn.Conv2d(in_channels, in_channels, kernel_size=1)
        self.k = nn.Conv2d(in_channels, in_channels, kernel_size=1)
        self.v = nn.Conv2d(in_channels, in_channels, kernel_size=1)
        self.proj_out = nn.Conv2d(in_channels, in_channels, kernel_size=1)

    def attention(self, h_: Tensor) -> Tensor:
        h_ = self.norm(h_)
        q = self.q(h_)
        k = self.k(h_)
        v = self.v(h_)
        b, c, h, w = q.shape
        q = rearrange(q, "b c h w -> b 1 (h w) c").contiguous()
        k = rearrange(k, "b c h w -> b 1 (h w) c").contiguous()
        v = rearrange(v, "b c h w -> b 1 (h w) c").contiguous()
        h_ = nn.functional.scaled_dot_product_attention(q, k, v)
        return rearrange(h_, "b 1 (h w) c -> b c h w", h=h, w=w, c=c, b=b)

    def forward(self, x: Tensor) -> Tensor:
        return x + self.proj_out(self.attention(x))
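
# Note: AttnBlock is single-head self-attention over the flattened spatial grid.
# q/k/v are 1x1 convolutions, the (h w) positions act as the sequence dimension,
# and the attended output is projected and added back to the input as a residual.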


class ResnetBlock(nn.Module):
    def __init__(self, in_channels: int, out_channels: int | None = None):
        super().__init__()
        self.in_channels = in_channels
        out_channels = in_channels if out_channels is None else out_channels
        self.out_channels = out_channels
        self.norm1 = nn.GroupNorm(
            num_groups=32, num_channels=in_channels, eps=1e-6, affine=True
        )
        self.conv1 = nn.Conv2d(
            in_channels, out_channels, kernel_size=3, stride=1, padding=1
        )
        self.norm2 = nn.GroupNorm(
            num_groups=32, num_channels=out_channels, eps=1e-6, affine=True
        )
        self.conv2 = nn.Conv2d(
            out_channels, out_channels, kernel_size=3, stride=1, padding=1
        )
        if self.in_channels != self.out_channels:
            self.nin_shortcut = nn.Conv2d(
                in_channels, out_channels, kernel_size=1, stride=1, padding=0
            )

    def forward(self, x):
        h = x
        h = self.norm1(h)
        h = swish(h)
        h = self.conv1(h)
        h = self.norm2(h)
        h = swish(h)
        h = self.conv2(h)
        if self.in_channels != self.out_channels:
            x = self.nin_shortcut(x)
        return x + h


class Downsample(nn.Module):
    def __init__(self, in_channels: int):
        super().__init__()
        # no asymmetric padding in torch conv, must do it ourselves
        self.conv = nn.Conv2d(
            in_channels, in_channels, kernel_size=3, stride=2, padding=0
        )

    def forward(self, x: Tensor):
        pad = (0, 1, 0, 1)
        x = nn.functional.pad(x, pad, mode="constant", value=0)
        x = self.conv(x)
        return x


class Upsample(nn.Module):
    def __init__(self, in_channels: int):
        super().__init__()
        self.conv = nn.Conv2d(
            in_channels, in_channels, kernel_size=3, stride=1, padding=1
        )

    def forward(self, x: Tensor):
        x = nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
        x = self.conv(x)
        return x
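
# Note: Downsample halves the spatial resolution with a stride-2 conv (padding
# right/bottom manually, since PyTorch convs only pad symmetrically), while
# Upsample doubles it with nearest-neighbor interpolation followed by a conv.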


class Encoder(nn.Module):
    def __init__(
        self,
        resolution: int,
        in_channels: int,
        ch: int,
        ch_mult: list[int],
        num_res_blocks: int,
        z_channels: int,
    ):
        super().__init__()
        self.ch = ch
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels
        # downsampling
        self.conv_in = nn.Conv2d(
            in_channels, self.ch, kernel_size=3, stride=1, padding=1
        )
        curr_res = resolution
        in_ch_mult = (1,) + tuple(ch_mult)
        self.in_ch_mult = in_ch_mult
        self.down = nn.ModuleList()
        block_in = self.ch
        for i_level in range(self.num_resolutions):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_in = ch * in_ch_mult[i_level]
            block_out = ch * ch_mult[i_level]
            for _ in range(self.num_res_blocks):
                block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
                block_in = block_out
            down = nn.Module()
            down.block = block
            down.attn = attn
            if i_level != self.num_resolutions - 1:
                down.downsample = Downsample(block_in)
                curr_res = curr_res // 2
            self.down.append(down)
        # middle
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in)
        self.mid.attn_1 = AttnBlock(block_in)
        self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in)
        # end
        self.norm_out = nn.GroupNorm(
            num_groups=32, num_channels=block_in, eps=1e-6, affine=True
        )
        self.conv_out = nn.Conv2d(
            block_in, 2 * z_channels, kernel_size=3, stride=1, padding=1
        )

    def forward(self, x: Tensor) -> Tensor:
        # downsampling
        hs = [self.conv_in(x)]
        for i_level in range(self.num_resolutions):
            for i_block in range(self.num_res_blocks):
                h = self.down[i_level].block[i_block](hs[-1])
                if len(self.down[i_level].attn) > 0:
                    h = self.down[i_level].attn[i_block](h)
                hs.append(h)
            if i_level != self.num_resolutions - 1:
                hs.append(self.down[i_level].downsample(hs[-1]))
        # middle
        h = hs[-1]
        h = self.mid.block_1(h)
        h = self.mid.attn_1(h)
        h = self.mid.block_2(h)
        # end
        h = self.norm_out(h)
        h = swish(h)
        h = self.conv_out(h)
        return h
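
# Note: with the default ch_mult of (1, 2, 4, 4), the encoder reduces spatial
# resolution by a factor of 2 ** (len(ch_mult) - 1) = 8 and emits 2 * z_channels
# feature maps: the mean and log-variance consumed by DiagonalGaussian below.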


class Decoder(nn.Module):
    def __init__(
        self,
        ch: int,
        out_ch: int,
        ch_mult: list[int],
        num_res_blocks: int,
        in_channels: int,
        resolution: int,
        z_channels: int,
    ):
        super().__init__()
        self.ch = ch
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels
        self.ffactor = 2 ** (self.num_resolutions - 1)
        # compute in_ch_mult, block_in and curr_res at lowest res
        block_in = ch * ch_mult[self.num_resolutions - 1]
        curr_res = resolution // 2 ** (self.num_resolutions - 1)
        self.z_shape = (1, z_channels, curr_res, curr_res)
        # z to block_in
        self.conv_in = nn.Conv2d(
            z_channels, block_in, kernel_size=3, stride=1, padding=1
        )
        # middle
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in)
        self.mid.attn_1 = AttnBlock(block_in)
        self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in)
        # upsampling
        self.up = nn.ModuleList()
        for i_level in reversed(range(self.num_resolutions)):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_out = ch * ch_mult[i_level]
            for _ in range(self.num_res_blocks + 1):
                block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
                block_in = block_out
            up = nn.Module()
            up.block = block
            up.attn = attn
            if i_level != 0:
                up.upsample = Upsample(block_in)
                curr_res = curr_res * 2
            self.up.insert(0, up)  # prepend to get consistent order
        # end
        self.norm_out = nn.GroupNorm(
            num_groups=32, num_channels=block_in, eps=1e-6, affine=True
        )
        self.conv_out = nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1)

    def forward(self, z: Tensor) -> Tensor:
        # get dtype for proper tracing
        upscale_dtype = next(self.up.parameters()).dtype
        # z to block_in
        h = self.conv_in(z)
        # middle
        h = self.mid.block_1(h)
        h = self.mid.attn_1(h)
        h = self.mid.block_2(h)
        # cast to proper dtype
        h = h.to(upscale_dtype)
        # upsampling
        for i_level in reversed(range(self.num_resolutions)):
            for i_block in range(self.num_res_blocks + 1):
                h = self.up[i_level].block[i_block](h)
                if len(self.up[i_level].attn) > 0:
                    h = self.up[i_level].attn[i_block](h)
            if i_level != 0:
                h = self.up[i_level].upsample(h)
        # end
        h = self.norm_out(h)
        h = swish(h)
        h = self.conv_out(h)
        return h


class DiagonalGaussian(nn.Module):
    def __init__(self, sample: bool = True, chunk_dim: int = 1):
        super().__init__()
        self.sample = sample
        self.chunk_dim = chunk_dim

    def forward(self, z: Tensor) -> Tensor:
        mean, logvar = torch.chunk(z, 2, dim=self.chunk_dim)
        if self.sample:
            std = torch.exp(0.5 * logvar)
            return mean + std * torch.randn_like(mean)
        else:
            return mean
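
# Example (hypothetical shapes): for z of shape (B, 2 * z_channels, H, W),
# DiagonalGaussian splits z along dim 1 into mean and logvar and returns the
# reparameterized sample mean + exp(0.5 * logvar) * eps, with eps ~ N(0, I),
# of shape (B, z_channels, H, W); with sample=False it returns the mean only.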


class AutoEncoder(nn.Module):
    def __init__(self, params: AutoEncoderParams):
        super().__init__()
        self.params = params
        self.encoder = Encoder(
            resolution=params.resolution,
            in_channels=params.in_channels,
            ch=params.ch,
            ch_mult=params.ch_mult,
            num_res_blocks=params.num_res_blocks,
            z_channels=params.z_channels,
        )
        self.decoder = Decoder(
            resolution=params.resolution,
            in_channels=params.in_channels,
            ch=params.ch,
            out_ch=params.out_ch,
            ch_mult=params.ch_mult,
            num_res_blocks=params.num_res_blocks,
            z_channels=params.z_channels,
        )
        self.reg = DiagonalGaussian()
        self.scale_factor = params.scale_factor
        self.shift_factor = params.shift_factor
        self.quant_conv = (
            nn.Conv2d(2 * params.z_channels, 2 * params.z_channels, 1)
            if params.use_quant_conv
            else None
        )
        self.post_quant_conv = (
            nn.Conv2d(params.z_channels, params.z_channels, 1)
            if params.use_post_quant_conv
            else None
        )

    def encode(self, x: Tensor) -> Tensor:
        x = self.encoder(x)
        if self.quant_conv is not None:
            x = self.quant_conv(x)
        z = self.reg(x)
        z = self.scale_factor * (z - self.shift_factor)
        return z

    def decode(self, z: Tensor) -> Tensor:
        z = z / self.scale_factor + self.shift_factor
        if self.post_quant_conv is not None:
            z = self.post_quant_conv(z)
        return self.decoder(z)

    def forward(self, x: Tensor) -> Tensor:
        return self.decode(self.encode(x))
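
# Note: encode() normalizes latents as scale_factor * (z - shift_factor) and
# decode() inverts that mapping, so downstream models see roughly unit-variance
# latents; the 0.3611 / 0.1159 defaults are presumably calibrated for the
# checkpoint this module ships with.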


def load_ae(
    ckpt_path: str,
    autoencoder_params: AutoEncoderParams,
    device: str | torch.device = "cuda",
    dtype: torch.dtype = torch.bfloat16,
    random_init: bool = False,
) -> AutoEncoder:
    """
    Build the autoencoder and load its weights from a safetensors checkpoint.

    Args:
        ckpt_path (str): Path to the autoencoder checkpoint.
        autoencoder_params (AutoEncoderParams): Architecture configuration.
        device (str or torch.device): The device to load the autoencoder to.
        dtype (torch.dtype): The dtype to cast the autoencoder to.
        random_init (bool): If True, skip loading and return a randomly
            initialized autoencoder.

    Returns:
        AutoEncoder: The loaded autoencoder.
    """
    # Instantiate the autoencoder on the target device.
    with torch.device(device):
        ae = AutoEncoder(autoencoder_params)
    if random_init:
        print("Randomly initializing VAE weights")
        return ae.to(dtype=dtype)
    if not os.path.exists(ckpt_path):
        raise ValueError(
            f"Autoencoder path {ckpt_path} does not exist. Please download it first."
        )
    print(f"Loading {ckpt_path}")
    sd = load_sft(ckpt_path, device=str(device))
    missing, unexpected = ae.load_state_dict(sd, strict=False, assign=True)
    if len(missing) > 0:
        print(f"Got {len(missing)} missing keys:\n\t" + "\n\t".join(missing))
    if len(unexpected) > 0:
        print(
            f"Got {len(unexpected)} unexpected keys:\n\t" + "\n\t".join(unexpected)
        )
    return ae.to(dtype=dtype)
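

# A minimal smoke-test sketch, not part of the original module: it round-trips
# a dummy image through a randomly initialized autoencoder. The checkpoint path
# below is a placeholder; set random_init=False and point it at a real
# safetensors file to load actual weights.
if __name__ == "__main__":
    params = AutoEncoderParams()
    ae = load_ae(
        ckpt_path="/path/to/ae.safetensors",  # placeholder path
        autoencoder_params=params,
        device="cuda" if torch.cuda.is_available() else "cpu",
        dtype=torch.float32,
        random_init=True,  # skip checkpoint loading for this smoke test
    )
    # Dummy input of shape (batch, channels, height, width) on the model's device/dtype.
    p = next(ae.parameters())
    x = torch.randn(1, params.in_channels, params.resolution, params.resolution)
    x = x.to(device=p.device, dtype=p.dtype)
    with torch.no_grad():
        z = ae.encode(x)  # latent: (1, z_channels, resolution // 8, resolution // 8)
        recon = ae.decode(z)
    print(f"latent shape: {tuple(z.shape)}, reconstruction shape: {tuple(recon.shape)}")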