import torch

class HypernetworkModule(torch.nn.Module):
    """Residual two-layer MLP applied to attention contexts of a fixed dim."""

    def __init__(self, dim, multiplier=1.0):
        super().__init__()

        # Bottleneck that expands to 2 * dim and projects back to dim. The
        # near-zero initialization keeps a fresh module close to the identity.
        linear1 = torch.nn.Linear(dim, dim * 2)
        linear2 = torch.nn.Linear(dim * 2, dim)
        linear1.weight.data.normal_(mean=0.0, std=0.01)
        linear1.bias.data.zero_()
        linear2.weight.data.normal_(mean=0.0, std=0.01)
        linear2.bias.data.zero_()

        self.linear = torch.nn.Sequential(linear1, linear2)
        self.multiplier = multiplier

    def forward(self, x):
        # Residual connection: the module learns a scaled perturbation of x.
        return x + self.linear(x) * self.multiplier
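
# Illustrative sanity check (an added helper, not part of the original code):
# because the weights above start near zero, a freshly constructed module is
# approximately the identity mapping.
def _demo_near_identity():
    module = HypernetworkModule(768)
    context = torch.randn(2, 77, 768)
    out = module(context)
    assert out.shape == context.shape
    return (out - context).abs().max()  # tiny, by construction
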

class Hypernetwork(torch.nn.Module):
    # Supported context dims: the self-attention hidden sizes (320/640/1280)
    # and the CLIP text context size (768) of Stable Diffusion v1 U-Nets.
    enable_sizes = [320, 640, 768, 1280]

    def __init__(self, multiplier=1.0) -> None:
        super().__init__()
        # One (key, value) pair of modules per supported size. Note that this
        # attribute shadows nn.Module.modules(); it is kept for compatibility
        # with the original layout, so avoid calling self.modules() here.
        self.modules = []
        for size in Hypernetwork.enable_sizes:
            self.modules.append((HypernetworkModule(size, multiplier), HypernetworkModule(size, multiplier)))
            self.register_module(f"{size}_0", self.modules[-1][0])
            self.register_module(f"{size}_1", self.modules[-1][1])

    def apply_to_stable_diffusion(self, text_encoder, vae, unet):
        # Attach this hypernetwork to every attention layer of the LDM U-Net
        # whose context dim is supported. Assigning an nn.Module attribute
        # also registers the hypernetwork as a submodule of attn.
        blocks = unet.input_blocks + [unet.middle_block] + unet.output_blocks
        for block in blocks:
            for subblk in block:
                if 'SpatialTransformer' in str(type(subblk)):
                    for tf_block in subblk.transformer_blocks:
                        for attn in [tf_block.attn1, tf_block.attn2]:
                            size = attn.context_dim
                            if size in Hypernetwork.enable_sizes:
                                attn.hypernetwork = self
                            else:
                                attn.hypernetwork = None

    def apply_to_diffusers(self, text_encoder, vae, unet):
        # Same as above for a Diffusers U-Net. Diffusers attention layers do
        # not expose context_dim, so it is inferred from to_k's input width.
        blocks = unet.down_blocks + [unet.mid_block] + unet.up_blocks
        for block in blocks:
            if hasattr(block, 'attentions'):
                for subblk in block.attentions:
                    if 'SpatialTransformer' in str(type(subblk)) or 'Transformer2DModel' in str(type(subblk)):
                        for tf_block in subblk.transformer_blocks:
                            for attn in [tf_block.attn1, tf_block.attn2]:
                                size = attn.to_k.in_features
                                if size in Hypernetwork.enable_sizes:
                                    attn.hypernetwork = self
                                else:
                                    attn.hypernetwork = None
        return True
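
    # The two methods above only attach the hypernetwork; the attention
    # forward is expected to consume it. A minimal sketch of such a patched
    # forward (hypothetical, not part of this module) might look like:
    #
    #     def attn_forward(attn, x, context=None):
    #         context = context if context is not None else x
    #         if getattr(attn, "hypernetwork", None) is not None:
    #             context_k, context_v = attn.hypernetwork(x, context)
    #         else:
    #             context_k = context_v = context
    #         q, k, v = attn.to_q(x), attn.to_k(context_k), attn.to_v(context_v)
    #         ...  # scaled dot-product attention as usual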

    def forward(self, x, context):
        # x (the attention input) is accepted but unused; only the context is
        # transformed. Returns transformed contexts for keys and values.
        size = context.shape[-1]
        assert size in Hypernetwork.enable_sizes
        module = self.modules[Hypernetwork.enable_sizes.index(size)]
        return module[0](context), module[1](context)

    def load_from_state_dict(self, state_dict):
        # Rename keys written by the old HypernetworkModule layout (separate
        # linear1/linear2 attributes) to the current nn.Sequential layout.
        # The renames apply inside each per-size module state dict, not at
        # the top level of the checkpoint.
        changes = {
            'linear1.bias': 'linear.0.bias',
            'linear1.weight': 'linear.0.weight',
            'linear2.bias': 'linear.1.bias',
            'linear2.weight': 'linear.1.weight',
        }
        for size, sd in state_dict.items():
            if isinstance(size, int):  # non-int keys are checkpoint metadata
                for module_sd in sd:
                    for key_from, key_to in changes.items():
                        if key_from in module_sd:
                            module_sd[key_to] = module_sd[key_from]
                            del module_sd[key_from]
                self.modules[Hypernetwork.enable_sizes.index(size)][0].load_state_dict(sd[0], strict=True)
                self.modules[Hypernetwork.enable_sizes.index(size)][1].load_state_dict(sd[1], strict=True)
        return True

    def get_state_dict(self):
        # Serialize in the size-keyed format expected by load_from_state_dict:
        # each supported dim maps to [key_module_sd, value_module_sd].
        state_dict = {}
        for i, size in enumerate(Hypernetwork.enable_sizes):
            sd0 = self.modules[i][0].state_dict()
            sd1 = self.modules[i][1].state_dict()
            state_dict[size] = [sd0, sd1]
        return state_dict
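

# Minimal end-to-end sketch (illustrative; the file name and shapes are
# assumptions, not part of the original code): round-trip a hypernetwork
# through its size-keyed state dict and query the 768-dim module pair.
if __name__ == "__main__":
    hypernet = Hypernetwork(multiplier=1.0)
    torch.save(hypernet.get_state_dict(), "hypernetwork.pt")  # hypothetical path

    loaded = Hypernetwork()
    loaded.load_from_state_dict(torch.load("hypernetwork.pt"))

    context = torch.randn(1, 77, 768)  # e.g. CLIP text embeddings
    context_k, context_v = loaded(None, context)
    print(context_k.shape, context_v.shape)  # both torch.Size([1, 77, 768])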