Spaces:
Paused
Paused
VAE Decoder: Inject noise between conv layers.
Browse files
1. Add inject_noise flag to res_x, res_x_y blocks.
2. Init the per-channel noise scales to zero in the ResnetBlock3D constructor.
3. Add _feed_spatial_noise method to inject noise between conv layers.
xora/models/autoencoders/causal_video_autoencoder.py
CHANGED
|
@@ -481,6 +481,7 @@ class Decoder(nn.Module):
|
|
| 481 |
resnet_eps=1e-6,
|
| 482 |
resnet_groups=norm_num_groups,
|
| 483 |
norm_layer=norm_layer,
|
|
|
|
| 484 |
)
|
| 485 |
elif block_name == "res_x_y":
|
| 486 |
output_channel = output_channel // block_params.get("multiplier", 2)
|
|
@@ -491,6 +492,7 @@ class Decoder(nn.Module):
|
|
| 491 |
eps=1e-6,
|
| 492 |
groups=norm_num_groups,
|
| 493 |
norm_layer=norm_layer,
|
|
|
|
| 494 |
)
|
| 495 |
elif block_name == "compress_time":
|
| 496 |
block = DepthToSpaceUpsample(
|
|
@@ -583,6 +585,7 @@ class UNetMidBlock3D(nn.Module):
|
|
| 583 |
resnet_eps: float = 1e-6,
|
| 584 |
resnet_groups: int = 32,
|
| 585 |
norm_layer: str = "group_norm",
|
|
|
|
| 586 |
):
|
| 587 |
super().__init__()
|
| 588 |
resnet_groups = (
|
|
@@ -599,6 +602,7 @@ class UNetMidBlock3D(nn.Module):
|
|
| 599 |
groups=resnet_groups,
|
| 600 |
dropout=dropout,
|
| 601 |
norm_layer=norm_layer,
|
|
|
|
| 602 |
)
|
| 603 |
for _ in range(num_layers)
|
| 604 |
]
|
|
@@ -690,11 +694,13 @@ class ResnetBlock3D(nn.Module):
|
|
| 690 |
groups: int = 32,
|
| 691 |
eps: float = 1e-6,
|
| 692 |
norm_layer: str = "group_norm",
|
|
|
|
| 693 |
):
|
| 694 |
super().__init__()
|
| 695 |
self.in_channels = in_channels
|
| 696 |
out_channels = in_channels if out_channels is None else out_channels
|
| 697 |
self.out_channels = out_channels
|
|
|
|
| 698 |
|
| 699 |
if norm_layer == "group_norm":
|
| 700 |
self.norm1 = nn.GroupNorm(
|
|
@@ -717,6 +723,9 @@ class ResnetBlock3D(nn.Module):
|
|
| 717 |
causal=True,
|
| 718 |
)
|
| 719 |
|
|
|
|
|
|
|
|
|
|
| 720 |
if norm_layer == "group_norm":
|
| 721 |
self.norm2 = nn.GroupNorm(
|
| 722 |
num_groups=groups, num_channels=out_channels, eps=eps, affine=True
|
|
@@ -738,6 +747,9 @@ class ResnetBlock3D(nn.Module):
|
|
| 738 |
causal=True,
|
| 739 |
)
|
| 740 |
|
|
|
|
|
|
|
|
|
|
| 741 |
self.conv_shortcut = (
|
| 742 |
make_linear_nd(
|
| 743 |
dims=dims, in_channels=in_channels, out_channels=out_channels
|
|
@@ -752,6 +764,20 @@ class ResnetBlock3D(nn.Module):
|
|
| 752 |
else nn.Identity()
|
| 753 |
)
|
| 754 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 755 |
def forward(
|
| 756 |
self,
|
| 757 |
input_tensor: torch.FloatTensor,
|
|
@@ -765,6 +791,11 @@ class ResnetBlock3D(nn.Module):
|
|
| 765 |
|
| 766 |
hidden_states = self.conv1(hidden_states, causal=causal)
|
| 767 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 768 |
hidden_states = self.norm2(hidden_states)
|
| 769 |
|
| 770 |
hidden_states = self.non_linearity(hidden_states)
|
|
@@ -773,6 +804,11 @@ class ResnetBlock3D(nn.Module):
|
|
| 773 |
|
| 774 |
hidden_states = self.conv2(hidden_states, causal=causal)
|
| 775 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 776 |
input_tensor = self.norm3(input_tensor)
|
| 777 |
|
| 778 |
input_tensor = self.conv_shortcut(input_tensor)
|
|
|
|
| 481 |
resnet_eps=1e-6,
|
| 482 |
resnet_groups=norm_num_groups,
|
| 483 |
norm_layer=norm_layer,
|
| 484 |
+
inject_noise=block_params.get("inject_noise", False),
|
| 485 |
)
|
| 486 |
elif block_name == "res_x_y":
|
| 487 |
output_channel = output_channel // block_params.get("multiplier", 2)
|
|
|
|
| 492 |
eps=1e-6,
|
| 493 |
groups=norm_num_groups,
|
| 494 |
norm_layer=norm_layer,
|
| 495 |
+
inject_noise=block_params.get("inject_noise", False),
|
| 496 |
)
|
| 497 |
elif block_name == "compress_time":
|
| 498 |
block = DepthToSpaceUpsample(
|
|
|
|
| 585 |
resnet_eps: float = 1e-6,
|
| 586 |
resnet_groups: int = 32,
|
| 587 |
norm_layer: str = "group_norm",
|
| 588 |
+
inject_noise: bool = False,
|
| 589 |
):
|
| 590 |
super().__init__()
|
| 591 |
resnet_groups = (
|
|
|
|
| 602 |
groups=resnet_groups,
|
| 603 |
dropout=dropout,
|
| 604 |
norm_layer=norm_layer,
|
| 605 |
+
inject_noise=inject_noise,
|
| 606 |
)
|
| 607 |
for _ in range(num_layers)
|
| 608 |
]
|
|
|
|
| 694 |
groups: int = 32,
|
| 695 |
eps: float = 1e-6,
|
| 696 |
norm_layer: str = "group_norm",
|
| 697 |
+
inject_noise: bool = False,
|
| 698 |
):
|
| 699 |
super().__init__()
|
| 700 |
self.in_channels = in_channels
|
| 701 |
out_channels = in_channels if out_channels is None else out_channels
|
| 702 |
self.out_channels = out_channels
|
| 703 |
+
self.inject_noise = inject_noise
|
| 704 |
|
| 705 |
if norm_layer == "group_norm":
|
| 706 |
self.norm1 = nn.GroupNorm(
|
|
|
|
| 723 |
causal=True,
|
| 724 |
)
|
| 725 |
|
| 726 |
+
if inject_noise:
|
| 727 |
+
self.per_channel_scale1 = nn.Parameter(torch.zeros((in_channels, 1, 1)))
|
| 728 |
+
|
| 729 |
if norm_layer == "group_norm":
|
| 730 |
self.norm2 = nn.GroupNorm(
|
| 731 |
num_groups=groups, num_channels=out_channels, eps=eps, affine=True
|
|
|
|
| 747 |
causal=True,
|
| 748 |
)
|
| 749 |
|
| 750 |
+
if inject_noise:
|
| 751 |
+
self.per_channel_scale2 = nn.Parameter(torch.zeros((in_channels, 1, 1)))
|
| 752 |
+
|
| 753 |
self.conv_shortcut = (
|
| 754 |
make_linear_nd(
|
| 755 |
dims=dims, in_channels=in_channels, out_channels=out_channels
|
|
|
|
| 764 |
else nn.Identity()
|
| 765 |
)
|
| 766 |
|
| 767 |
+
def _feed_spatial_noise(
    self, hidden_states: torch.FloatTensor, per_channel_scale: torch.FloatTensor
) -> torch.FloatTensor:
    """Add per-channel-scaled Gaussian noise to ``hidden_states``.

    A single ``(H, W)`` noise map is drawn per call and shared by every
    batch element and by every axis that sits between the channel axis and
    the two trailing spatial axes; only its amplitude varies per channel.

    Args:
        hidden_states: Feature tensor laid out as ``(B, C, ..., H, W)``,
            e.g. ``(B, C, H, W)`` or ``(B, C, T, H, W)``.
        per_channel_scale: Per-channel noise amplitude. It must broadcast
            against an ``(H, W)`` map to ``(C, H, W)``, e.g. shape
            ``(C, 1, 1)`` with ``C`` equal to the channel count of
            ``hidden_states``.

    Returns:
        A new tensor with the same shape as ``hidden_states``.
    """
    height, width = hidden_states.shape[-2:]

    # similar to the "explicit noise inputs" method in style-gan
    spatial_noise = torch.randn(
        (height, width),
        device=hidden_states.device,
        dtype=hidden_states.dtype,
    )

    # (H, W) * (C, 1, 1) -> (C, H, W)
    scaled_noise = spatial_noise * per_channel_scale

    # Build the broadcast shape from the input rank instead of hard-coding
    # the 5-D layout: a leading singleton for batch, plus one singleton for
    # every axis between channel and the spatial axes (none for 4-D input,
    # one for 5-D input).
    extra_axes = hidden_states.dim() - 4
    broadcast_shape = (
        (1, scaled_noise.shape[0]) + (1,) * extra_axes + (height, width)
    )

    return hidden_states + scaled_noise.reshape(broadcast_shape)
|
| 780 |
+
|
| 781 |
def forward(
|
| 782 |
self,
|
| 783 |
input_tensor: torch.FloatTensor,
|
|
|
|
| 791 |
|
| 792 |
hidden_states = self.conv1(hidden_states, causal=causal)
|
| 793 |
|
| 794 |
+
if self.inject_noise:
|
| 795 |
+
hidden_states = self._feed_spatial_noise(
|
| 796 |
+
hidden_states, self.per_channel_scale1
|
| 797 |
+
)
|
| 798 |
+
|
| 799 |
hidden_states = self.norm2(hidden_states)
|
| 800 |
|
| 801 |
hidden_states = self.non_linearity(hidden_states)
|
|
|
|
| 804 |
|
| 805 |
hidden_states = self.conv2(hidden_states, causal=causal)
|
| 806 |
|
| 807 |
+
if self.inject_noise:
|
| 808 |
+
hidden_states = self._feed_spatial_noise(
|
| 809 |
+
hidden_states, self.per_channel_scale2
|
| 810 |
+
)
|
| 811 |
+
|
| 812 |
input_tensor = self.norm3(input_tensor)
|
| 813 |
|
| 814 |
input_tensor = self.conv_shortcut(input_tensor)
|