| | import torch |
| | import torch.nn as nn |
| | from einops import rearrange |
| |
|
| | from .gfq import GFQ |
| |
|
def swish(x):
    """Swish / SiLU activation, applied elementwise: ``x * sigmoid(x)``.

    Args:
        x: input tensor.

    Returns:
        Tensor with the same shape as ``x``.
    """
    gate = torch.sigmoid(x)
    return x * gate
| |
|
class ResBlock(nn.Module):
    """Pre-activation residual block: 2 x (GroupNorm -> swish -> 3x3 conv) plus a skip path.

    Args:
        in_filters: number of input channels (must be divisible by 32 for GroupNorm).
        out_filters: number of output channels (must be divisible by 32 for GroupNorm).
        use_conv_shortcut: when the channel count changes, project the skip path with a
            3x3 conv instead of the default 1x1 conv.
        use_agn: when True the first GroupNorm (``norm1``) is not created and is skipped in
            ``forward`` — presumably because an adaptive GroupNorm is applied to the input
            beforehand. ``norm2`` is always used.
    """

    def __init__(self,
                 in_filters,
                 out_filters,
                 use_conv_shortcut = False,
                 use_agn = False,
                 ) -> None:
        super().__init__()

        self.in_filters = in_filters
        self.out_filters = out_filters
        self.use_conv_shortcut = use_conv_shortcut
        self.use_agn = use_agn

        # Only norm1 depends on use_agn. forward() always applies norm2, so it must
        # always exist; otherwise use_agn=True raises AttributeError on the first call.
        if not use_agn:
            self.norm1 = nn.GroupNorm(32, in_filters, eps=1e-6)
        self.norm2 = nn.GroupNorm(32, out_filters, eps=1e-6)

        self.conv1 = nn.Conv2d(in_filters, out_filters, kernel_size=(3, 3), padding=1, bias=False)
        self.conv2 = nn.Conv2d(out_filters, out_filters, kernel_size=(3, 3), padding=1, bias=False)

        # The skip path needs a projection only when the channel count changes.
        if in_filters != out_filters:
            if self.use_conv_shortcut:
                self.conv_shortcut = nn.Conv2d(in_filters, out_filters, kernel_size=(3, 3), padding=1, bias=False)
            else:
                self.nin_shortcut = nn.Conv2d(in_filters, out_filters, kernel_size=(1, 1), padding=0, bias=False)

    def forward(self, x, **kwargs):
        """Apply the residual block.

        Args:
            x: input feature map with ``in_filters`` channels.
            **kwargs: accepted for call-site compatibility and ignored.

        Returns:
            Feature map with ``out_filters`` channels and the same spatial size as ``x``.
        """
        residual = x

        if not self.use_agn:
            x = self.norm1(x)
        x = swish(x)
        x = self.conv1(x)
        x = self.norm2(x)
        x = swish(x)
        x = self.conv2(x)

        if self.in_filters != self.out_filters:
            if self.use_conv_shortcut:
                residual = self.conv_shortcut(residual)
            else:
                residual = self.nin_shortcut(residual)

        return x + residual
| | |
class Encoder(nn.Module):
    """Convolutional encoder producing a ``z_channels`` latent feature map.

    Layout: 3x3 ``conv_in`` -> one stage per entry of ``ch_mult`` (``num_res_blocks``
    ResBlocks each, followed by a stride-2 3x3 conv on every stage except the last)
    -> ``num_res_blocks`` middle ResBlocks -> GroupNorm -> swish -> 1x1 ``conv_out``.

    Args:
        ch: base channel width; stage ``i`` outputs ``ch * ch_mult[i]`` channels.
        out_ch: accepted for config compatibility; not used by this class.
        in_channels: channels of the input image.
        num_res_blocks: ResBlocks per stage and in the middle section.
        z_channels: channels of the produced latent.
        ch_mult: per-stage channel multipliers.
        resolution: stored as ``self.resolution``; not otherwise used here.
        double_z: accepted for config compatibility; not used by this class.
    """

    def __init__(self, *, ch, out_ch, in_channels, num_res_blocks, z_channels, ch_mult=(1, 2, 2, 4),
                 resolution=None, double_z=False,
                 ):
        super().__init__()

        self.in_channels = in_channels
        self.z_channels = z_channels
        self.resolution = resolution

        self.num_res_blocks = num_res_blocks
        self.num_blocks = len(ch_mult)

        self.conv_in = nn.Conv2d(in_channels, ch, kernel_size=(3, 3), padding=1, bias=False)

        self.down = nn.ModuleList()
        # Stage i takes ch * in_mults[i] channels in and emits ch * ch_mult[i].
        in_mults = (1,) + tuple(ch_mult)
        for level, (mult_in, mult_out) in enumerate(zip(in_mults, ch_mult)):
            stage_in = ch * mult_in
            stage_out = ch * mult_out
            res_blocks = nn.ModuleList()
            for _ in range(self.num_res_blocks):
                res_blocks.append(ResBlock(stage_in, stage_out))
                stage_in = stage_out

            stage = nn.Module()
            stage.block = res_blocks
            if level < self.num_blocks - 1:
                stage.downsample = nn.Conv2d(stage_out, stage_out, kernel_size=(3, 3), stride=(2, 2), padding=1)
            self.down.append(stage)

        self.mid_block = nn.ModuleList(
            ResBlock(stage_in, stage_in) for _ in range(self.num_res_blocks)
        )

        self.norm_out = nn.GroupNorm(32, stage_out, eps=1e-6)
        self.conv_out = nn.Conv2d(stage_out, z_channels, kernel_size=(1, 1))

    def forward(self, x):
        """Encode ``x`` (``in_channels`` channels) into a ``z_channels`` latent map."""
        h = self.conv_in(x)
        for level, stage in enumerate(self.down):
            for res_block in stage.block:
                h = res_block(h)
            if level < self.num_blocks - 1:
                h = stage.downsample(h)

        for res_block in self.mid_block:
            h = res_block(h)

        h = swish(self.norm_out(h))
        return self.conv_out(h)
| |
|
class Decoder(nn.Module):
    """Convolutional decoder whose stages are modulated by statistics of its input latent.

    Layout: 3x3 ``conv_in`` -> ``num_res_blocks`` middle ResBlocks -> for each stage from
    deepest to shallowest: AdaptiveGroupNorm (conditioned on the decoder input ``z``),
    ``num_res_blocks`` ResBlocks, and a 2x Upsampler on every stage except level 0 ->
    GroupNorm -> swish -> 3x3 ``conv_out`` to ``out_ch`` channels.

    Args:
        ch: base channel width; stage ``i`` outputs ``ch * ch_mult[i]`` channels.
        out_ch: channels of the reconstructed output.
        in_channels: stored as ``self.in_channels``; not otherwise used here.
        num_res_blocks: ResBlocks per stage and in the middle section.
        z_channels: channels of the input latent.
        ch_mult: per-stage channel multipliers.
        resolution: stored as ``self.resolution``; not otherwise used here.
        double_z: accepted for config compatibility; not used by this class.
    """

    def __init__(self, *, ch, out_ch, in_channels, num_res_blocks, z_channels, ch_mult=(1, 2, 2, 4),
                 resolution=None, double_z=False,) -> None:
        super().__init__()

        self.ch = ch
        self.num_blocks = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels

        width = ch * ch_mult[self.num_blocks - 1]

        self.conv_in = nn.Conv2d(z_channels, width, kernel_size=(3, 3), padding=1, bias=True)

        self.mid_block = nn.ModuleList(
            ResBlock(width, width) for _ in range(self.num_res_blocks)
        )

        self.up = nn.ModuleList()
        self.adaptive = nn.ModuleList()
        # Stages are built deepest-first but prepended, so index `level` in
        # self.up / self.adaptive corresponds to ch_mult[level].
        for level in range(self.num_blocks - 1, -1, -1):
            target = ch * ch_mult[level]
            self.adaptive.insert(0, AdaptiveGroupNorm(z_channels, width))
            res_blocks = nn.ModuleList()
            for _ in range(self.num_res_blocks):
                res_blocks.append(ResBlock(width, target))
                width = target

            stage = nn.Module()
            stage.block = res_blocks
            if level > 0:
                stage.upsample = Upsampler(width)
            self.up.insert(0, stage)

        self.norm_out = nn.GroupNorm(32, width, eps=1e-6)
        self.conv_out = nn.Conv2d(width, out_ch, kernel_size=(3, 3), padding=1)

    def forward(self, z):
        """Decode latent ``z``; every stage is modulated by the statistics of ``z`` itself."""
        style = z.clone()

        h = self.conv_in(z)
        for res_block in self.mid_block:
            h = res_block(h)

        for level in range(self.num_blocks - 1, -1, -1):
            stage = self.up[level]
            h = self.adaptive[level](h, style)
            for res_block in stage.block:
                h = res_block(h)
            if level > 0:
                h = stage.upsample(h)

        h = swish(self.norm_out(h))
        return self.conv_out(h)
| |
|
def depth_to_space(x: torch.Tensor, block_size: int) -> torch.Tensor:
    """Move channel blocks into spatial blocks (Depth-to-Space, DCR mode).

    In DCR ("depth-column-row") mode the channel axis is read as
    ``(block_row, block_col, out_channel)`` with ``out_channel`` varying fastest, so
    ``C = block_size**2 * C_out`` and::

        out[..., k, h*bs + i, w*bs + j] == x[..., (i*bs + j)*C_out + k, h, w]

    (``torch.nn.functional.pixel_shuffle`` uses the other, CRD, channel ordering.)

    Args:
        x: channels-first tensor of shape ``(*, C, H, W)`` with at least 3 dims. Leading
            dims need not be contiguous.
        block_size: positive spatial upscale factor ``bs``.

    Returns:
        Contiguous tensor of shape ``(*, C // bs**2, H * bs, W * bs)``.

    Raises:
        ValueError: if ``x`` has fewer than 3 dims, ``block_size`` is not positive, or
            ``C`` is not divisible by ``block_size**2``.
    """
    if x.dim() < 3:
        raise ValueError(
            "Expecting a channels-first (*CHW) tensor of at least 3 dimensions"
        )
    if block_size < 1:
        raise ValueError(f"block_size must be a positive integer, got {block_size}")

    c, h, w = x.shape[-3:]
    s = block_size ** 2
    if c % s != 0:
        raise ValueError(
            f"Expecting a channels-first (*CHW) tensor with C divisible by {s}, but got C={c} channels"
        )

    outer_dims = x.shape[:-3]
    c_out = c // s

    # reshape instead of view: flattening the leading dims must also work when they
    # are not contiguous (view raises in that case).
    x = x.reshape(-1, block_size, block_size, c_out, h, w)

    # (N, row, col, C_out, H, W) -> (N, C_out, H, row, W, col)
    x = x.permute(0, 3, 4, 1, 5, 2)

    return x.contiguous().view(*outer_dims, c_out, h * block_size, w * block_size)
| |
|
class Upsampler(nn.Module):
    """2x spatial upsampler: 3x3 conv to ``4 * dim_out`` channels, then DCR depth-to-space.

    Args:
        dim: input channels.
        dim_out: output channels; defaults to ``dim``. With the default, the conv has
            ``dim * 4`` output channels, exactly as before this argument was honoured.
    """

    def __init__(
        self,
        dim,
        dim_out = None
    ):
        super().__init__()
        if dim_out is None:
            dim_out = dim
        # depth_to_space with block_size=2 divides the channel count by 4, so the conv
        # must produce 4x the requested output channels.
        self.conv1 = nn.Conv2d(dim, dim_out * 4, (3, 3), padding=1)
        self.depth2space = depth_to_space

    def forward(self, x):
        """Upsample by 2 in both spatial dims.

        Args:
            x: input of shape ``[B, dim, H, W]``.

        Returns:
            Tensor of shape ``[B, dim_out, 2H, 2W]``.
        """
        out = self.conv1(x)
        out = self.depth2space(out, block_size=2)
        return out
| | |
class AdaptiveGroupNorm(nn.Module):
    """GroupNorm without learned affine, modulated by statistics of a conditioning map.

    The conditioning tensor (``quantizer``, shape ``(B, z_channel, H', W')``) is reduced
    over its spatial dims to a per-channel standard deviation and mean. The linear layers
    ``gamma`` and ``beta`` map these to a per-channel scale and shift applied to
    ``GroupNorm(x)``.

    Args:
        z_channel: channels of the conditioning tensor.
        in_filters: channels of ``x``.
        num_groups: number of GroupNorm groups; must divide ``in_filters``.
        eps: used both as GroupNorm's epsilon and as a floor on the conditioning variance.
    """

    def __init__(self, z_channel, in_filters, num_groups=32, eps=1e-6):
        super().__init__()
        self.gn = nn.GroupNorm(num_groups=num_groups, num_channels=in_filters, eps=eps, affine=False)

        self.gamma = nn.Linear(z_channel, in_filters)
        self.beta = nn.Linear(z_channel, in_filters)
        self.eps = eps

    def forward(self, x, quantizer):
        """Normalize ``x`` and modulate it with statistics of ``quantizer``.

        Args:
            x: feature map of shape ``(B, in_filters, H, W)``.
            quantizer: conditioning map of shape ``(B, z_channel, H', W')``.

        Returns:
            Tensor with the same shape as ``x``.
        """
        B, C, _, _ = x.shape
        qb, qc, qh, qw = quantizer.shape
        flat = quantizer.reshape(qb, qc, qh * qw)

        # The default unbiased variance is NaN for a single sample. For a 1x1
        # conditioning map, use the (zero) biased estimate so the scale stays finite.
        var = flat.var(dim=-1, unbiased=qh * qw > 1)
        scale = self.gamma((var + self.eps).sqrt()).view(B, C, 1, 1)
        bias = self.beta(flat.mean(dim=-1)).view(B, C, 1, 1)

        return scale * self.gn(x) + bias
| |
|
class GANDecoder(nn.Module):
    """Noise-injected variant of the decoder.

    Identical in structure to ``Decoder``, except that a standard-normal tensor of the
    same shape as the input latent is concatenated along the channel dim before
    ``conv_in`` (hence ``conv_in`` takes ``2 * z_channels`` channels). Outputs are
    therefore stochastic. The AdaptiveGroupNorm style statistics come from the
    noise-free input ``z``.

    Args:
        ch: base channel width; stage ``i`` outputs ``ch * ch_mult[i]`` channels.
        out_ch: channels of the reconstructed output.
        in_channels: stored as ``self.in_channels``; not otherwise used here.
        num_res_blocks: ResBlocks per stage and in the middle section.
        z_channels: channels of the input latent.
        ch_mult: per-stage channel multipliers.
        resolution: stored as ``self.resolution``; not otherwise used here.
        double_z: accepted for config compatibility; not used by this class.
    """

    def __init__(self, *, ch, out_ch, in_channels, num_res_blocks, z_channels, ch_mult=(1, 2, 2, 4),
                 resolution=None, double_z=False,) -> None:
        super().__init__()

        self.ch = ch
        self.num_blocks = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels

        width = ch * ch_mult[self.num_blocks - 1]

        # Input is [z, noise] concatenated on channels.
        self.conv_in = nn.Conv2d(z_channels * 2, width, kernel_size=(3, 3), padding=1, bias=True)

        self.mid_block = nn.ModuleList(
            ResBlock(width, width) for _ in range(self.num_res_blocks)
        )

        self.up = nn.ModuleList()
        self.adaptive = nn.ModuleList()
        # Built deepest-first and prepended, so index `level` matches ch_mult[level].
        for level in range(self.num_blocks - 1, -1, -1):
            target = ch * ch_mult[level]
            self.adaptive.insert(0, AdaptiveGroupNorm(z_channels, width))
            res_blocks = nn.ModuleList()
            for _ in range(self.num_res_blocks):
                res_blocks.append(ResBlock(width, target))
                width = target

            stage = nn.Module()
            stage.block = res_blocks
            if level > 0:
                stage.upsample = Upsampler(width)
            self.up.insert(0, stage)

        self.norm_out = nn.GroupNorm(32, width, eps=1e-6)
        self.conv_out = nn.Conv2d(width, out_ch, kernel_size=(3, 3), padding=1)

    def forward(self, z):
        """Decode ``z`` with fresh Gaussian noise concatenated to the input."""
        style = z.clone()

        # randn_like already allocates on z's device and dtype.
        noise = torch.randn_like(z)
        h = self.conv_in(torch.cat([z, noise], dim=1))

        for res_block in self.mid_block:
            h = res_block(h)

        for level in range(self.num_blocks - 1, -1, -1):
            stage = self.up[level]
            h = self.adaptive[level](h, style)
            for res_block in stage.block:
                h = res_block(h)
            if level > 0:
                h = stage.upsample(h)

        h = swish(self.norm_out(h))
        return self.conv_out(h)
| | |
| |
|
class VQModel(nn.Module):
    """Encoder -> GFQ quantizer -> decoder autoencoder.

    Args:
        ddconfig: keyword arguments forwarded to both the Encoder and the decoder. Its
            ``z_channels`` entry (default 32 if absent) is also the quantizer dimension.
        num_codebooks: forwarded to GFQ.
        sample_minimization_weight: forwarded to GFQ.
        batch_maximization_weight: forwarded to GFQ.
        gan_decoder: if True use ``GANDecoder`` (noise-injected), otherwise ``Decoder``.
    """

    def __init__(self,
                 ddconfig,
                 num_codebooks = 1,
                 sample_minimization_weight=1,
                 batch_maximization_weight=1,
                 gan_decoder = False,
                 ):
        super().__init__()
        self.encoder = Encoder(**ddconfig)
        decoder_cls = GANDecoder if gan_decoder else Decoder
        self.decoder = decoder_cls(**ddconfig)
        self.quantize = GFQ(
            dim=ddconfig.get("z_channels", 32),
            num_codebooks=num_codebooks,
            sample_minimization_weight=sample_minimization_weight,
            batch_maximization_weight=batch_maximization_weight,
        )

    def encode(self, x):
        """Encode and quantize ``x``.

        Returns:
            ``(quant, emb_loss, info, loss_breakdown)`` as produced by the GFQ quantizer
            when called with ``return_loss_breakdown=True``.
        """
        latent = self.encoder(x)
        (quant, emb_loss, info), loss_breakdown = self.quantize(latent, return_loss_breakdown=True)
        return quant, emb_loss, info, loss_breakdown

    def decode(self, quant):
        """Reconstruct from a quantized latent."""
        return self.decoder(quant)

    def forward(self, input):
        """Full round trip; returns ``(reconstruction, loss_breakdown)``."""
        quant, _emb_loss, _info, loss_breakdown = self.encode(input)
        return self.decode(quant), loss_breakdown