ouclxy committed
Commit 0daa129 · verified · 1 parent: 8ca3766

Upload 72 files

This view is limited to 50 files because the commit contains too many changes; see the raw diff for the complete list.
Files changed (50)
  1. .gitattributes +6 -0
  2. hair_resposity/2.jpg +0 -0
  3. hair_resposity/6.jpg +3 -0
  4. hair_resposity/7.jpg +3 -0
  5. imgs/background.png +3 -0
  6. imgs/multiview1.gif +3 -0
  7. imgs/multiview2.gif +3 -0
  8. imgs/teaser.jpg +3 -0
  9. ref_encoder/__init__.py +0 -0
  10. ref_encoder/adapter.py +40 -0
  11. ref_encoder/attention.py +481 -0
  12. ref_encoder/attention_processor.py +391 -0
  13. ref_encoder/latent_controlnet.py +0 -0
  14. ref_encoder/motion_module.py +388 -0
  15. ref_encoder/mutual_self_attention.py +365 -0
  16. ref_encoder/pose_guider.py +57 -0
  17. ref_encoder/reference_control.py +528 -0
  18. ref_encoder/reference_unet.py +1053 -0
  19. ref_encoder/reference_unetv2.py +1037 -0
  20. ref_encoder/resnet.py +263 -0
  21. ref_encoder/transformer_2d.py +396 -0
  22. ref_encoder/transformer_3d.py +202 -0
  23. ref_encoder/unet_2d_blocks.py +1074 -0
  24. ref_encoder/unet_2d_condition.py +1308 -0
  25. ref_encoder/unet_3d.py +702 -0
  26. ref_encoder/unet_3d_blocks.py +906 -0
  27. src/__init__.py +0 -0
  28. src/dataset/dance_image.py +124 -0
  29. src/dataset/dance_video.py +137 -0
  30. src/dwpose/__init__.py +123 -0
  31. src/dwpose/onnxdet.py +130 -0
  32. src/dwpose/onnxpose.py +370 -0
  33. src/dwpose/util.py +378 -0
  34. src/dwpose/wholebody.py +48 -0
  35. src/models/attention.py +481 -0
  36. src/models/motion_module.py +388 -0
  37. src/models/mutual_self_attention.py +365 -0
  38. src/models/pose_guider.py +57 -0
  39. src/models/resnet.py +252 -0
  40. src/models/transformer_2d.py +396 -0
  41. src/models/transformer_3d.py +202 -0
  42. src/models/unet_2d_blocks.py +1074 -0
  43. src/models/unet_2d_condition.py +1308 -0
  44. src/models/unet_3d.py +707 -0
  45. src/models/unet_3d_blocks.py +906 -0
  46. src/pipelines/__init__.py +0 -0
  47. src/pipelines/context.py +76 -0
  48. src/pipelines/pipeline_lmks2vid_long.py +622 -0
  49. src/pipelines/pipeline_pose2img.py +360 -0
  50. src/pipelines/pipeline_pose2vid.py +458 -0
.gitattributes CHANGED
@@ -44,3 +44,9 @@ hair_resposity/8.jpg filter=lfs diff=lfs merge=lfs -text
  test_imgs/ref2.jpg filter=lfs diff=lfs merge=lfs -text
  test_imgs/wjl.jpg filter=lfs diff=lfs merge=lfs -text
  test_imgs/zzf.jpg filter=lfs diff=lfs merge=lfs -text
+ hair_resposity/6.jpg filter=lfs diff=lfs merge=lfs -text
+ hair_resposity/7.jpg filter=lfs diff=lfs merge=lfs -text
+ imgs/background.png filter=lfs diff=lfs merge=lfs -text
+ imgs/multiview1.gif filter=lfs diff=lfs merge=lfs -text
+ imgs/multiview2.gif filter=lfs diff=lfs merge=lfs -text
+ imgs/teaser.jpg filter=lfs diff=lfs merge=lfs -text
hair_resposity/2.jpg ADDED
hair_resposity/6.jpg ADDED

Git LFS Details

  • SHA256: e5268b6ab06260858316ba7ef965b7d352d2992d4f8351f940affc2e6959c5cb
  • Pointer size: 132 Bytes
  • Size of remote file: 3.15 MB
hair_resposity/7.jpg ADDED

Git LFS Details

  • SHA256: ed3641ab12d0e0129bb21c12207410f46cdb80241cb1e7c5c3bab26371734fc8
  • Pointer size: 132 Bytes
  • Size of remote file: 1.42 MB
imgs/background.png ADDED

Git LFS Details

  • SHA256: 5235db98e1e2665efde4ca215472c0eb99fcd543659235ca0a6050d92d94117c
  • Pointer size: 132 Bytes
  • Size of remote file: 1.34 MB
imgs/multiview1.gif ADDED

Git LFS Details

  • SHA256: 2af598ad60dd0b0d71cc962308ab8128dd0b3abf9202aa9c75e646850ce6f04c
  • Pointer size: 133 Bytes
  • Size of remote file: 18.2 MB
imgs/multiview2.gif ADDED

Git LFS Details

  • SHA256: 62b066dedf7663ba00fb85a204dbeca75d97feab9b231c7b2fc983bb27832cb8
  • Pointer size: 133 Bytes
  • Size of remote file: 19.6 MB
imgs/teaser.jpg ADDED

Git LFS Details

  • SHA256: b689abe4524c6c79c75b062c822643d0f52d6026ee3f84aa645897eb613b45b3
  • Pointer size: 131 Bytes
  • Size of remote file: 267 kB
ref_encoder/__init__.py ADDED
(empty file)
ref_encoder/adapter.py ADDED
@@ -0,0 +1,40 @@
+ import torch
+ import torch.nn.functional as F
+
+ def is_torch2_available():
+     return hasattr(F, "scaled_dot_product_attention")
+ if is_torch2_available():
+     from .attention_processor import HairAttnProcessor2_0 as HairAttnProcessor, AttnProcessor2_0 as AttnProcessor
+ else:
+     from .attention_processor import HairAttnProcessor, AttnProcessor
+
+ def adapter_injection(unet, device="cuda", dtype=torch.float32, use_resampler=False, continue_learning_path=None):
+     device = device
+     dtype = dtype
+     # load Hair attention layers
+     attn_procs = {}
+     for name in unet.attn_processors.keys():
+         cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
+         if name.startswith("mid_block"):
+             hidden_size = unet.config.block_out_channels[-1]
+         elif name.startswith("up_blocks"):
+             block_id = int(name[len("up_blocks.")])
+             hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
+         elif name.startswith("down_blocks"):
+             block_id = int(name[len("down_blocks.")])
+             hidden_size = unet.config.block_out_channels[block_id]
+         if cross_attention_dim is None:
+             attn_procs[name] = HairAttnProcessor(hidden_size=hidden_size, cross_attention_dim=hidden_size, scale=1, use_resampler=use_resampler).to(device, dtype=dtype)
+         else:
+             attn_procs[name] = AttnProcessor()
+     unet.set_attn_processor(attn_procs)
+     adapter_modules = torch.nn.ModuleList(unet.attn_processors.values())
+     adapter_layers = adapter_modules
+     adapter_layers.to(device, dtype=dtype)
+
+     return adapter_layers
+
+ def set_scale(unet, scale):
+     for attn_processor in unet.attn_processors.values():
+         if isinstance(attn_processor, HairAttnProcessor):
+             attn_processor.scale = scale
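
A minimal usage sketch for the two helpers added above, assuming a standard diffusers UNet2DConditionModel and the `ref_encoder.adapter` import path from this upload; the base checkpoint id is a placeholder, not a value taken from this commit:

import torch
from diffusers import UNet2DConditionModel
from ref_encoder.adapter import adapter_injection, set_scale

# Any SD-1.5-style UNet works here; the checkpoint id below is only a placeholder.
unet = UNet2DConditionModel.from_pretrained(
    "runwayml/stable-diffusion-v1-5", subfolder="unet", torch_dtype=torch.float16
)

# Swap every self-attention (attn1) processor for HairAttnProcessor and collect
# the newly created layers; cross-attention layers keep a plain AttnProcessor.
# A CUDA device is assumed, matching the function's default.
adapter_layers = adapter_injection(unet, device="cuda", dtype=torch.float16)

# The strength of the injected reference features can be adjusted later in one call.
set_scale(unet, 0.8)
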
ref_encoder/attention.py ADDED
@@ -0,0 +1,481 @@
1
+ # Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention.py
2
+
3
+ from typing import Any, Dict, Optional
4
+
5
+ import torch
6
+ from diffusers.models.attention import AdaLayerNorm, Attention, FeedForward
7
+ from diffusers.models.embeddings import SinusoidalPositionalEmbedding
8
+ from einops import rearrange
9
+ from torch import nn
10
+
11
+
12
+ class BasicTransformerBlock(nn.Module):
13
+ r"""
14
+ A basic Transformer block.
15
+
16
+ Parameters:
17
+ dim (`int`): The number of channels in the input and output.
18
+ num_attention_heads (`int`): The number of heads to use for multi-head attention.
19
+ attention_head_dim (`int`): The number of channels in each head.
20
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
21
+ cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
22
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
23
+ num_embeds_ada_norm (:
24
+ obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`.
25
+ attention_bias (:
26
+ obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter.
27
+ only_cross_attention (`bool`, *optional*):
28
+ Whether to use only cross-attention layers. In this case two cross attention layers are used.
29
+ double_self_attention (`bool`, *optional*):
30
+ Whether to use two self-attention layers. In this case no cross attention layers are used.
31
+ upcast_attention (`bool`, *optional*):
32
+ Whether to upcast the attention computation to float32. This is useful for mixed precision training.
33
+ norm_elementwise_affine (`bool`, *optional*, defaults to `True`):
34
+ Whether to use learnable elementwise affine parameters for normalization.
35
+ norm_type (`str`, *optional*, defaults to `"layer_norm"`):
36
+ The normalization layer to use. Can be `"layer_norm"`, `"ada_norm"` or `"ada_norm_zero"`.
37
+ final_dropout (`bool` *optional*, defaults to False):
38
+ Whether to apply a final dropout after the last feed-forward layer.
39
+ attention_type (`str`, *optional*, defaults to `"default"`):
40
+ The type of attention to use. Can be `"default"` or `"gated"` or `"gated-text-image"`.
41
+ positional_embeddings (`str`, *optional*, defaults to `None`):
42
+ The type of positional embeddings to apply to.
43
+ num_positional_embeddings (`int`, *optional*, defaults to `None`):
44
+ The maximum number of positional embeddings to apply.
45
+ """
46
+
47
+ def __init__(
48
+ self,
49
+ dim: int,
50
+ num_attention_heads: int,
51
+ attention_head_dim: int,
52
+ dropout=0.0,
53
+ cross_attention_dim: Optional[int] = None,
54
+ activation_fn: str = "geglu",
55
+ num_embeds_ada_norm: Optional[int] = None,
56
+ attention_bias: bool = False,
57
+ only_cross_attention: bool = False,
58
+ double_self_attention: bool = False,
59
+ upcast_attention: bool = False,
60
+ norm_elementwise_affine: bool = True,
61
+ norm_type: str = "layer_norm", # 'layer_norm', 'ada_norm', 'ada_norm_zero', 'ada_norm_single'
62
+ norm_eps: float = 1e-5,
63
+ final_dropout: bool = False,
64
+ attention_type: str = "default",
65
+ positional_embeddings: Optional[str] = None,
66
+ num_positional_embeddings: Optional[int] = None,
67
+ ):
68
+ super().__init__()
69
+ self.only_cross_attention = only_cross_attention
70
+
71
+ self.use_ada_layer_norm_zero = (
72
+ num_embeds_ada_norm is not None
73
+ ) and norm_type == "ada_norm_zero"
74
+ self.use_ada_layer_norm = (
75
+ num_embeds_ada_norm is not None
76
+ ) and norm_type == "ada_norm"
77
+ self.use_ada_layer_norm_single = norm_type == "ada_norm_single"
78
+ self.use_layer_norm = norm_type == "layer_norm"
79
+
80
+ if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None:
81
+ raise ValueError(
82
+ f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to"
83
+ f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}."
84
+ )
85
+
86
+ if positional_embeddings and (num_positional_embeddings is None):
87
+ raise ValueError(
88
+ "If `positional_embedding` type is defined, `num_positition_embeddings` must also be defined."
89
+ )
90
+
91
+ if positional_embeddings == "sinusoidal":
92
+ self.pos_embed = SinusoidalPositionalEmbedding(
93
+ dim, max_seq_length=num_positional_embeddings
94
+ )
95
+ else:
96
+ self.pos_embed = None
97
+
98
+ # Define 3 blocks. Each block has its own normalization layer.
99
+ # 1. Self-Attn
100
+ if self.use_ada_layer_norm:
101
+ self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm)
102
+ elif self.use_ada_layer_norm_zero:
103
+ self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm)
104
+ else:
105
+ self.norm1 = nn.LayerNorm(
106
+ dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps
107
+ )
108
+
109
+ self.attn1 = Attention(
110
+ query_dim=dim,
111
+ heads=num_attention_heads,
112
+ dim_head=attention_head_dim,
113
+ dropout=dropout,
114
+ bias=attention_bias,
115
+ cross_attention_dim=cross_attention_dim if only_cross_attention else None,
116
+ upcast_attention=upcast_attention,
117
+ )
118
+
119
+ # 2. Cross-Attn
120
+ if cross_attention_dim is not None or double_self_attention:
121
+ # We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
122
+ # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during
123
+ # the second cross attention block.
124
+ self.norm2 = (
125
+ AdaLayerNorm(dim, num_embeds_ada_norm)
126
+ if self.use_ada_layer_norm
127
+ else nn.LayerNorm(
128
+ dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps
129
+ )
130
+ )
131
+ self.attn2 = Attention(
132
+ query_dim=dim,
133
+ cross_attention_dim=cross_attention_dim
134
+ if not double_self_attention
135
+ else None,
136
+ heads=num_attention_heads,
137
+ dim_head=attention_head_dim,
138
+ dropout=dropout,
139
+ bias=attention_bias,
140
+ upcast_attention=upcast_attention,
141
+ ) # is self-attn if encoder_hidden_states is none
142
+ else:
143
+ self.norm2 = None
144
+ self.attn2 = None
145
+
146
+ # 3. Feed-forward
147
+ if not self.use_ada_layer_norm_single:
148
+ self.norm3 = nn.LayerNorm(
149
+ dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps
150
+ )
151
+
152
+ self.ff = FeedForward(
153
+ dim,
154
+ dropout=dropout,
155
+ activation_fn=activation_fn,
156
+ final_dropout=final_dropout,
157
+ )
158
+
159
+ # 4. Fuser
160
+ if attention_type == "gated" or attention_type == "gated-text-image":
161
+ self.fuser = GatedSelfAttentionDense(
162
+ dim, cross_attention_dim, num_attention_heads, attention_head_dim
163
+ )
164
+
165
+ # 5. Scale-shift for PixArt-Alpha.
166
+ if self.use_ada_layer_norm_single:
167
+ self.scale_shift_table = nn.Parameter(torch.randn(6, dim) / dim**0.5)
168
+
169
+ # let chunk size default to None
170
+ self._chunk_size = None
171
+ self._chunk_dim = 0
172
+
173
+ def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0):
174
+ # Sets chunk feed-forward
175
+ self._chunk_size = chunk_size
176
+ self._chunk_dim = dim
177
+
178
+ def forward(
179
+ self,
180
+ hidden_states: torch.FloatTensor,
181
+ attention_mask: Optional[torch.FloatTensor] = None,
182
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
183
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
184
+ timestep: Optional[torch.LongTensor] = None,
185
+ cross_attention_kwargs: Dict[str, Any] = None,
186
+ class_labels: Optional[torch.LongTensor] = None,
187
+ ) -> torch.FloatTensor:
188
+ # Notice that normalization is always applied before the real computation in the following blocks.
189
+ # 0. Self-Attention
190
+ batch_size = hidden_states.shape[0]
191
+
192
+ if self.use_ada_layer_norm:
193
+ norm_hidden_states = self.norm1(hidden_states, timestep)
194
+ elif self.use_ada_layer_norm_zero:
195
+ norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
196
+ hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
197
+ )
198
+ elif self.use_layer_norm:
199
+ norm_hidden_states = self.norm1(hidden_states)
200
+ elif self.use_ada_layer_norm_single:
201
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
202
+ self.scale_shift_table[None] + timestep.reshape(batch_size, 6, -1)
203
+ ).chunk(6, dim=1)
204
+ norm_hidden_states = self.norm1(hidden_states)
205
+ norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa
206
+ norm_hidden_states = norm_hidden_states.squeeze(1)
207
+ else:
208
+ raise ValueError("Incorrect norm used")
209
+
210
+ if self.pos_embed is not None:
211
+ norm_hidden_states = self.pos_embed(norm_hidden_states)
212
+
213
+ # 1. Retrieve lora scale.
214
+ lora_scale = (
215
+ cross_attention_kwargs.get("scale", 1.0)
216
+ if cross_attention_kwargs is not None
217
+ else 1.0
218
+ )
219
+
220
+ # 2. Prepare GLIGEN inputs
221
+ cross_attention_kwargs = (
222
+ cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {}
223
+ )
224
+ gligen_kwargs = cross_attention_kwargs.pop("gligen", None)
225
+
226
+ attn_output = self.attn1(
227
+ norm_hidden_states,
228
+ encoder_hidden_states=encoder_hidden_states
229
+ if self.only_cross_attention
230
+ else None,
231
+ attention_mask=attention_mask,
232
+ **cross_attention_kwargs,
233
+ )
234
+ if self.use_ada_layer_norm_zero:
235
+ attn_output = gate_msa.unsqueeze(1) * attn_output
236
+ elif self.use_ada_layer_norm_single:
237
+ attn_output = gate_msa * attn_output
238
+
239
+ hidden_states = attn_output + hidden_states
240
+ if hidden_states.ndim == 4:
241
+ hidden_states = hidden_states.squeeze(1)
242
+
243
+ # 2.5 GLIGEN Control
244
+ if gligen_kwargs is not None:
245
+ hidden_states = self.fuser(hidden_states, gligen_kwargs["objs"])
246
+
247
+ # 3. Cross-Attention
248
+ if self.attn2 is not None:
249
+ if self.use_ada_layer_norm:
250
+ norm_hidden_states = self.norm2(hidden_states, timestep)
251
+ elif self.use_ada_layer_norm_zero or self.use_layer_norm:
252
+ norm_hidden_states = self.norm2(hidden_states)
253
+ elif self.use_ada_layer_norm_single:
254
+ # For PixArt norm2 isn't applied here:
255
+ # https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L70C1-L76C103
256
+ norm_hidden_states = hidden_states
257
+ else:
258
+ raise ValueError("Incorrect norm")
259
+
260
+ if self.pos_embed is not None and self.use_ada_layer_norm_single is False:
261
+ norm_hidden_states = self.pos_embed(norm_hidden_states)
262
+
263
+ attn_output = self.attn2(
264
+ norm_hidden_states,
265
+ encoder_hidden_states=encoder_hidden_states,
266
+ attention_mask=encoder_attention_mask,
267
+ **cross_attention_kwargs,
268
+ )
269
+ hidden_states = attn_output + hidden_states
270
+
271
+ # 4. Feed-forward
272
+ if not self.use_ada_layer_norm_single:
273
+ norm_hidden_states = self.norm3(hidden_states)
274
+
275
+ if self.use_ada_layer_norm_zero:
276
+ norm_hidden_states = (
277
+ norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
278
+ )
279
+
280
+ if self.use_ada_layer_norm_single:
281
+ norm_hidden_states = self.norm2(hidden_states)
282
+ norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp
283
+
284
+ ff_output = self.ff(norm_hidden_states, scale=lora_scale)
285
+
286
+ if self.use_ada_layer_norm_zero:
287
+ ff_output = gate_mlp.unsqueeze(1) * ff_output
288
+ elif self.use_ada_layer_norm_single:
289
+ ff_output = gate_mlp * ff_output
290
+
291
+ hidden_states = ff_output + hidden_states
292
+ if hidden_states.ndim == 4:
293
+ hidden_states = hidden_states.squeeze(1)
294
+
295
+ return hidden_states
296
+
297
+
298
+ class TemporalBasicTransformerBlock(nn.Module):
299
+ def __init__(
300
+ self,
301
+ dim: int,
302
+ num_attention_heads: int,
303
+ attention_head_dim: int,
304
+ dropout=0.0,
305
+ cross_attention_dim: Optional[int] = None,
306
+ activation_fn: str = "geglu",
307
+ num_embeds_ada_norm: Optional[int] = None,
308
+ attention_bias: bool = False,
309
+ only_cross_attention: bool = False,
310
+ upcast_attention: bool = False,
311
+ unet_use_cross_frame_attention=None,
312
+ unet_use_temporal_attention=None,
313
+ name=None,
314
+ ):
315
+ super().__init__()
316
+ self.only_cross_attention = only_cross_attention
317
+ self.use_ada_layer_norm = num_embeds_ada_norm is not None
318
+ self.unet_use_cross_frame_attention = unet_use_cross_frame_attention
319
+ self.unet_use_temporal_attention = unet_use_temporal_attention
320
+ self.name=name
321
+
322
+ # SC-Attn
323
+ self.attn1 = Attention(
324
+ query_dim=dim,
325
+ heads=num_attention_heads,
326
+ dim_head=attention_head_dim,
327
+ dropout=dropout,
328
+ bias=attention_bias,
329
+ upcast_attention=upcast_attention,
330
+ )
331
+ self.norm1 = (
332
+ AdaLayerNorm(dim, num_embeds_ada_norm)
333
+ if self.use_ada_layer_norm
334
+ else nn.LayerNorm(dim)
335
+ )
336
+
337
+ # Cross-Attn
338
+ if cross_attention_dim is not None:
339
+ self.attn2 = Attention(
340
+ query_dim=dim,
341
+ cross_attention_dim=cross_attention_dim,
342
+ heads=num_attention_heads,
343
+ dim_head=attention_head_dim,
344
+ dropout=dropout,
345
+ bias=attention_bias,
346
+ upcast_attention=upcast_attention,
347
+ )
348
+ else:
349
+ self.attn2 = None
350
+
351
+ if cross_attention_dim is not None:
352
+ self.norm2 = (
353
+ AdaLayerNorm(dim, num_embeds_ada_norm)
354
+ if self.use_ada_layer_norm
355
+ else nn.LayerNorm(dim)
356
+ )
357
+ else:
358
+ self.norm2 = None
359
+
360
+ # Feed-forward
361
+ self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn)
362
+ self.norm3 = nn.LayerNorm(dim)
363
+ self.use_ada_layer_norm_zero = False
364
+ # Temp-Attn
365
+ assert unet_use_temporal_attention is not None
366
+ if unet_use_temporal_attention:
367
+ self.attn_temp = Attention(
368
+ query_dim=dim,
369
+ heads=num_attention_heads,
370
+ dim_head=attention_head_dim,
371
+ dropout=dropout,
372
+ bias=attention_bias,
373
+ upcast_attention=upcast_attention,
374
+ )
375
+ nn.init.zeros_(self.attn_temp.to_out[0].weight.data)
376
+ self.norm_temp = (
377
+ AdaLayerNorm(dim, num_embeds_ada_norm)
378
+ if self.use_ada_layer_norm
379
+ else nn.LayerNorm(dim)
380
+ )
381
+
382
+ def forward(
383
+ self,
384
+ hidden_states,
385
+ encoder_hidden_states=None,
386
+ timestep=None,
387
+ attention_mask=None,
388
+ video_length=None,
389
+ self_attention_additional_feats=None,
390
+ mode=None,
391
+ ):
392
+ norm_hidden_states = (
393
+ self.norm1(hidden_states, timestep)
394
+ if self.use_ada_layer_norm
395
+ else self.norm1(hidden_states)
396
+ )
397
+ if self.name:
398
+ modify_norm_hidden_states = norm_hidden_states
399
+ if mode == "write":
400
+ self_attention_additional_feats[self.name]=norm_hidden_states
401
+ elif mode == "read" and self_attention_additional_feats:
402
+ ref_states = self_attention_additional_feats[self.name]
403
+ bank_fea = [
404
+ rearrange(
405
+ ref_states.unsqueeze(1).repeat(1, video_length, 1, 1),
406
+ "b t l c -> (b t) l c",
407
+ )
408
+ ]
409
+ modify_norm_hidden_states = torch.cat(
410
+ [norm_hidden_states] + bank_fea, dim=1
411
+ )
412
+
413
+ if self.unet_use_cross_frame_attention:
414
+ hidden_states = (
415
+ self.attn1(
416
+ norm_hidden_states,
417
+ attention_mask=attention_mask,
418
+ encoder_hidden_states=modify_norm_hidden_states,
419
+ video_length=video_length,
420
+ )
421
+ + hidden_states
422
+ )
423
+ else:
424
+ hidden_states = (
425
+ self.attn1(
426
+ norm_hidden_states,
427
+ encoder_hidden_states=modify_norm_hidden_states,
428
+ attention_mask=attention_mask
429
+ )
430
+ + hidden_states
431
+ )
432
+ else:
433
+ if self.unet_use_cross_frame_attention:
434
+ hidden_states = (
435
+ self.attn1(
436
+ norm_hidden_states,
437
+ attention_mask=attention_mask,
438
+ video_length=video_length,
439
+ )
440
+ + hidden_states
441
+ )
442
+ else:
443
+ hidden_states = (
444
+ self.attn1(norm_hidden_states, attention_mask=attention_mask)
445
+ + hidden_states
446
+ )
447
+
448
+ if self.attn2 is not None:
449
+ # Cross-Attention
450
+ norm_hidden_states = (
451
+ self.norm2(hidden_states, timestep)
452
+ if self.use_ada_layer_norm
453
+ else self.norm2(hidden_states)
454
+ )
455
+ hidden_states = (
456
+ self.attn2(
457
+ norm_hidden_states,
458
+ encoder_hidden_states=encoder_hidden_states,
459
+ attention_mask=attention_mask,
460
+ )
461
+ + hidden_states
462
+ )
463
+
464
+ # Feed-forward
465
+ hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states
466
+
467
+ # Temporal-Attention
468
+ if self.unet_use_temporal_attention:
469
+ d = hidden_states.shape[1]
470
+ hidden_states = rearrange(
471
+ hidden_states, "(b f) d c -> (b d) f c", f=video_length
472
+ )
473
+ norm_hidden_states = (
474
+ self.norm_temp(hidden_states, timestep)
475
+ if self.use_ada_layer_norm
476
+ else self.norm_temp(hidden_states)
477
+ )
478
+ hidden_states = self.attn_temp(norm_hidden_states) + hidden_states
479
+ hidden_states = rearrange(hidden_states, "(b d) f c -> (b f) d c", d=d)
480
+
481
+ return hidden_states
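
A minimal smoke test for the TemporalBasicTransformerBlock defined above, assuming the `ref_encoder.attention` import path from this upload and a diffusers version compatible with its imports; all sizes below are illustrative:

import torch
from ref_encoder.attention import TemporalBasicTransformerBlock

block = TemporalBasicTransformerBlock(
    dim=320,
    num_attention_heads=8,
    attention_head_dim=40,                # 8 heads x 40 = 320, matching dim
    cross_attention_dim=None,             # skip the cross-attention branch in this sketch
    unet_use_cross_frame_attention=False,
    unet_use_temporal_attention=False,    # must be set explicitly; the block asserts it is not None
)

batch, frames, tokens = 2, 4, 64
# The block expects video frames folded into the batch dimension: (b*f, tokens, dim).
x = torch.randn(batch * frames, tokens, 320)
# video_length only matters for cross-frame attention or reference reads, but is passed for clarity.
out = block(x, video_length=frames)
print(out.shape)  # torch.Size([8, 64, 320])
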
ref_encoder/attention_processor.py ADDED
@@ -0,0 +1,391 @@
1
+ # modified from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ from diffusers.utils.import_utils import is_xformers_available
6
+ if is_xformers_available():
7
+ import xformers
8
+ import xformers.ops
9
+ else:
10
+ xformers = None
11
+
12
+ class HairAttnProcessor(nn.Module):
13
+ def __init__(self, hidden_size, cross_attention_dim=None, scale=1.0, use_resampler=False):
14
+ super().__init__()
15
+
16
+ self.hidden_size = hidden_size
17
+ self.cross_attention_dim = cross_attention_dim
18
+ self.scale = scale
19
+ self.use_resampler = use_resampler
20
+ if self.use_resampler:
21
+ self.resampler = Resampler(query_dim=hidden_size)
22
+ self.to_k_SSR = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
23
+ self.to_v_SSR = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
24
+
25
+ def __call__(
26
+ self,
27
+ attn,
28
+ hidden_states,
29
+ encoder_hidden_states=None,
30
+ attention_mask=None,
31
+ temb=None,
32
+ ):
33
+ residual = hidden_states
34
+
35
+ if attn.spatial_norm is not None:
36
+ hidden_states = attn.spatial_norm(hidden_states, temb)
37
+
38
+ input_ndim = hidden_states.ndim
39
+
40
+ if input_ndim == 4:
41
+ batch_size, channel, height, width = hidden_states.shape
42
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
43
+
44
+ batch_size, sequence_length, _ = (
45
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
46
+ )
47
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
48
+
49
+ if attn.group_norm is not None:
50
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
51
+
52
+ query = attn.to_q(hidden_states)
53
+
54
+ if encoder_hidden_states is None:
55
+ encoder_hidden_states = hidden_states
56
+ elif attn.norm_cross:
57
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
58
+
59
+ # split hidden states
60
+ split_num = encoder_hidden_states.shape[1] // 2
61
+ encoder_hidden_states, _hidden_states = encoder_hidden_states[:, :split_num,
62
+ :], encoder_hidden_states[:, split_num:, :]
63
+
64
+ if self.use_resampler:
65
+ _hidden_states = self.resampler(_hidden_states)
66
+
67
+ key = attn.to_k(encoder_hidden_states)
68
+ value = attn.to_v(encoder_hidden_states)
69
+
70
+ query = attn.head_to_batch_dim(query)
71
+ key = attn.head_to_batch_dim(key)
72
+ value = attn.head_to_batch_dim(value)
73
+
74
+ attention_probs = attn.get_attention_scores(query, key, attention_mask)
75
+ hidden_states = torch.bmm(attention_probs, value)
76
+ hidden_states = attn.batch_to_head_dim(hidden_states)
77
+
78
+ _key = self.to_k_SSR(_hidden_states)
79
+ _value = self.to_v_SSR(_hidden_states)
80
+
81
+ _key = attn.head_to_batch_dim(_key)
82
+ _value = attn.head_to_batch_dim(_value)
83
+
84
+ _attention_probs = attn.get_attention_scores(query, _key, None)
85
+ _hidden_states = torch.bmm(_attention_probs, _value)
86
+ _hidden_states = attn.batch_to_head_dim(_hidden_states)
87
+
88
+ # # assume _hidden_states is a tensor of shape (batch_size, num_patches, hidden_size)
89
+ # batch_size, num_patches, hidden_size = _hidden_states.shape
90
+ # # create a mask tensor of shape (batch_size, num_patches)
91
+ # mask = torch.zeros((batch_size, num_patches), device="cuda", dtype=torch.float16)
92
+ # mask[:, 0:num_patches // 2] = 1
93
+ # # reshape the mask tensor to match the shape of _hidden_states
94
+ # mask = mask.unsqueeze(-1).expand(-1, -1, hidden_size)
95
+ # # apply the mask to _hidden_states
96
+ # _hidden_states = _hidden_states * mask
97
+
98
+ hidden_states = hidden_states + self.scale * _hidden_states
99
+
100
+ # linear proj
101
+ hidden_states = attn.to_out[0](hidden_states)
102
+ # dropout
103
+ hidden_states = attn.to_out[1](hidden_states)
104
+
105
+ if input_ndim == 4:
106
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
107
+
108
+ if attn.residual_connection:
109
+ hidden_states = hidden_states + residual
110
+
111
+ hidden_states = hidden_states / attn.rescale_output_factor
112
+
113
+ return hidden_states
114
+
115
+
116
+ class HairAttnProcessor2_0(torch.nn.Module):
117
+
118
+ def __init__(self, hidden_size, cross_attention_dim=None, scale=1.0, use_resampler=False):
119
+ super().__init__()
120
+
121
+ if not hasattr(F, "scaled_dot_product_attention"):
122
+ raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
123
+
124
+ self.hidden_size = hidden_size
125
+ self.cross_attention_dim = cross_attention_dim
126
+ self.scale = scale
127
+ self.use_resampler = use_resampler
128
+ if self.use_resampler:
129
+ self.resampler = Resampler(query_dim=hidden_size)
130
+ self.to_k_SSR = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
131
+ self.to_v_SSR = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
132
+
133
+ def __call__(
134
+ self,
135
+ attn,
136
+ hidden_states,
137
+ encoder_hidden_states=None,
138
+ attention_mask=None,
139
+ temb=None,
140
+ ):
141
+ residual = hidden_states
142
+
143
+ if attn.spatial_norm is not None:
144
+ hidden_states = attn.spatial_norm(hidden_states, temb)
145
+
146
+ input_ndim = hidden_states.ndim
147
+
148
+ if input_ndim == 4:
149
+ batch_size, channel, height, width = hidden_states.shape
150
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
151
+
152
+ batch_size, sequence_length, _ = (
153
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
154
+ )
155
+
156
+ if attention_mask is not None:
157
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
158
+ # scaled_dot_product_attention expects attention_mask shape to be
159
+ # (batch, heads, source_length, target_length)
160
+ attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
161
+
162
+ if attn.group_norm is not None:
163
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
164
+
165
+ query = attn.to_q(hidden_states)
166
+
167
+ if encoder_hidden_states is None:
168
+ encoder_hidden_states = hidden_states
169
+ elif attn.norm_cross:
170
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
171
+
172
+ # split hidden states
173
+ split_num = encoder_hidden_states.shape[1] // 2
174
+ encoder_hidden_states, _hidden_states = encoder_hidden_states[:, :split_num,
175
+ :], encoder_hidden_states[:, split_num:, :]
176
+
177
+ if self.use_resampler:
178
+ _hidden_states = self.resampler(_hidden_states)
179
+
180
+ key = attn.to_k(encoder_hidden_states)
181
+ value = attn.to_v(encoder_hidden_states)
182
+
183
+ inner_dim = key.shape[-1]
184
+ head_dim = inner_dim // attn.heads
185
+
186
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
187
+
188
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
189
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
190
+
191
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
192
+ # TODO: add support for attn.scale when we move to Torch 2.1
193
+ hidden_states = F.scaled_dot_product_attention(
194
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
195
+ )
196
+
197
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
198
+ hidden_states = hidden_states.to(query.dtype)
199
+
200
+ _hidden_states = _hidden_states.to(self.to_k_SSR.weight.dtype)
201
+ _key = self.to_k_SSR(_hidden_states)
202
+ _value = self.to_v_SSR(_hidden_states)
203
+
204
+ _key = _key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
205
+ _value = _value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
206
+
207
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
208
+ # TODO: add support for attn.scale when we move to Torch 2.1
209
+ _hidden_states = F.scaled_dot_product_attention(
210
+ query.to(self.to_k_SSR.weight.dtype), _key, _value, attn_mask=None, dropout_p=0.0, is_causal=False
211
+ )
212
+
213
+ _hidden_states = _hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
214
+ _hidden_states = _hidden_states.to(query.dtype)
215
+
216
+ hidden_states = hidden_states + self.scale * _hidden_states
217
+
218
+ # linear proj
219
+ hidden_states = attn.to_out[0](hidden_states)
220
+ # dropout
221
+ hidden_states = attn.to_out[1](hidden_states)
222
+
223
+ if input_ndim == 4:
224
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
225
+
226
+ if attn.residual_connection:
227
+ hidden_states = hidden_states + residual
228
+
229
+ hidden_states = hidden_states / attn.rescale_output_factor
230
+
231
+ return hidden_states
232
+
233
+
234
+ class AttnProcessor(nn.Module):
235
+ r"""
236
+ Default processor for performing attention-related computations.
237
+ """
238
+ def __init__(
239
+ self,
240
+ hidden_size=None,
241
+ cross_attention_dim=None,
242
+ ):
243
+ super().__init__()
244
+
245
+ def __call__(
246
+ self,
247
+ attn,
248
+ hidden_states,
249
+ encoder_hidden_states=None,
250
+ attention_mask=None,
251
+ temb=None,
252
+ ):
253
+ residual = hidden_states
254
+
255
+ if attn.spatial_norm is not None:
256
+ hidden_states = attn.spatial_norm(hidden_states, temb)
257
+
258
+ input_ndim = hidden_states.ndim
259
+
260
+ if input_ndim == 4:
261
+ batch_size, channel, height, width = hidden_states.shape
262
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
263
+
264
+ batch_size, sequence_length, _ = (
265
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
266
+ )
267
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
268
+
269
+ if attn.group_norm is not None:
270
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
271
+
272
+ query = attn.to_q(hidden_states)
273
+
274
+ if encoder_hidden_states is None:
275
+ encoder_hidden_states = hidden_states
276
+ elif attn.norm_cross:
277
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
278
+
279
+ key = attn.to_k(encoder_hidden_states)
280
+ value = attn.to_v(encoder_hidden_states)
281
+
282
+ query = attn.head_to_batch_dim(query)
283
+ key = attn.head_to_batch_dim(key)
284
+ value = attn.head_to_batch_dim(value)
285
+
286
+ attention_probs = attn.get_attention_scores(query, key, attention_mask)
287
+ hidden_states = torch.bmm(attention_probs, value)
288
+ hidden_states = attn.batch_to_head_dim(hidden_states)
289
+
290
+ # linear proj
291
+ hidden_states = attn.to_out[0](hidden_states)
292
+ # dropout
293
+ hidden_states = attn.to_out[1](hidden_states)
294
+
295
+ if input_ndim == 4:
296
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
297
+
298
+ if attn.residual_connection:
299
+ hidden_states = hidden_states + residual
300
+
301
+ hidden_states = hidden_states / attn.rescale_output_factor
302
+
303
+ return hidden_states
304
+
305
+ class AttnProcessor2_0(torch.nn.Module):
306
+ r"""
307
+ Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
308
+ """
309
+
310
+ def __init__(
311
+ self,
312
+ hidden_size=None,
313
+ cross_attention_dim=None,
314
+ ):
315
+ super().__init__()
316
+ if not hasattr(F, "scaled_dot_product_attention"):
317
+ raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
318
+
319
+ def __call__(
320
+ self,
321
+ attn,
322
+ hidden_states,
323
+ encoder_hidden_states=None,
324
+ attention_mask=None,
325
+ temb=None,
326
+ ):
327
+ residual = hidden_states
328
+
329
+ if attn.spatial_norm is not None:
330
+ hidden_states = attn.spatial_norm(hidden_states, temb)
331
+
332
+ input_ndim = hidden_states.ndim
333
+
334
+ if input_ndim == 4:
335
+ batch_size, channel, height, width = hidden_states.shape
336
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
337
+
338
+ batch_size, sequence_length, _ = (
339
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
340
+ )
341
+
342
+ if attention_mask is not None:
343
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
344
+ # scaled_dot_product_attention expects attention_mask shape to be
345
+ # (batch, heads, source_length, target_length)
346
+ attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
347
+
348
+ if attn.group_norm is not None:
349
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
350
+
351
+ query = attn.to_q(hidden_states)
352
+
353
+ if encoder_hidden_states is None:
354
+ encoder_hidden_states = hidden_states
355
+ elif attn.norm_cross:
356
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
357
+
358
+ key = attn.to_k(encoder_hidden_states)
359
+ value = attn.to_v(encoder_hidden_states)
360
+
361
+ inner_dim = key.shape[-1]
362
+ head_dim = inner_dim // attn.heads
363
+
364
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
365
+
366
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
367
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
368
+
369
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
370
+ # TODO: add support for attn.scale when we move to Torch 2.1
371
+ hidden_states = F.scaled_dot_product_attention(
372
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
373
+ )
374
+
375
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
376
+ hidden_states = hidden_states.to(query.dtype)
377
+
378
+ # linear proj
379
+ hidden_states = attn.to_out[0](hidden_states)
380
+ # dropout
381
+ hidden_states = attn.to_out[1](hidden_states)
382
+
383
+ if input_ndim == 4:
384
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
385
+
386
+ if attn.residual_connection:
387
+ hidden_states = hidden_states + residual
388
+
389
+ hidden_states = hidden_states / attn.rescale_output_factor
390
+
391
+ return hidden_states
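
To show how the hair-specific processor above is wired in, this sketch attaches a HairAttnProcessor to a bare diffusers Attention layer and feeds it a context whose second half stands in for the reference (hair) features; the dimensions and the `ref_encoder.attention_processor` import path are assumptions for illustration:

import torch
from diffusers.models.attention_processor import Attention
from ref_encoder.attention_processor import HairAttnProcessor

dim, tokens = 320, 64
attn = Attention(query_dim=dim, heads=8, dim_head=40)
attn.set_processor(HairAttnProcessor(hidden_size=dim, cross_attention_dim=dim, scale=1.0))

hidden_states = torch.randn(1, tokens, dim)      # UNet features
reference_states = torch.randn(1, tokens, dim)   # reference (hair) features of equal length
# The processor splits encoder_hidden_states in half along the token axis: the first
# half goes through attn.to_k / attn.to_v, the second half through to_k_SSR / to_v_SSR,
# and the two attention outputs are blended with the processor's scale.
context = torch.cat([hidden_states, reference_states], dim=1)
out = attn(hidden_states, encoder_hidden_states=context)
print(out.shape)  # torch.Size([1, 64, 320])
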
ref_encoder/latent_controlnet.py ADDED
The diff for this file is too large to render. See raw diff
 
ref_encoder/motion_module.py ADDED
@@ -0,0 +1,388 @@
1
+ # Adapt from https://github.com/guoyww/AnimateDiff/blob/main/animatediff/models/motion_module.py
2
+ import math
3
+ from dataclasses import dataclass
4
+ from typing import Callable, Optional
5
+
6
+ import torch
7
+ from diffusers.models.attention import FeedForward
8
+ from diffusers.models.attention_processor import Attention, AttnProcessor
9
+ from diffusers.utils import BaseOutput
10
+ from diffusers.utils.import_utils import is_xformers_available
11
+ from einops import rearrange, repeat
12
+ from torch import nn
13
+
14
+
15
+ def zero_module(module):
16
+ # Zero out the parameters of a module and return it.
17
+ for p in module.parameters():
18
+ p.detach().zero_()
19
+ return module
20
+
21
+
22
+ @dataclass
23
+ class TemporalTransformer3DModelOutput(BaseOutput):
24
+ sample: torch.FloatTensor
25
+
26
+
27
+ if is_xformers_available():
28
+ import xformers
29
+ import xformers.ops
30
+ else:
31
+ xformers = None
32
+
33
+
34
+ def get_motion_module(in_channels, motion_module_type: str, motion_module_kwargs: dict):
35
+ if motion_module_type == "Vanilla":
36
+ return VanillaTemporalModule(
37
+ in_channels=in_channels,
38
+ **motion_module_kwargs,
39
+ )
40
+ else:
41
+ raise ValueError
42
+
43
+
44
+ class VanillaTemporalModule(nn.Module):
45
+ def __init__(
46
+ self,
47
+ in_channels,
48
+ num_attention_heads=8,
49
+ num_transformer_block=2,
50
+ attention_block_types=("Temporal_Self", "Temporal_Self"),
51
+ cross_frame_attention_mode=None,
52
+ temporal_position_encoding=False,
53
+ temporal_position_encoding_max_len=24,
54
+ temporal_attention_dim_div=1,
55
+ zero_initialize=True,
56
+ ):
57
+ super().__init__()
58
+
59
+ self.temporal_transformer = TemporalTransformer3DModel(
60
+ in_channels=in_channels,
61
+ num_attention_heads=num_attention_heads,
62
+ attention_head_dim=in_channels
63
+ // num_attention_heads
64
+ // temporal_attention_dim_div,
65
+ num_layers=num_transformer_block,
66
+ attention_block_types=attention_block_types,
67
+ cross_frame_attention_mode=cross_frame_attention_mode,
68
+ temporal_position_encoding=temporal_position_encoding,
69
+ temporal_position_encoding_max_len=temporal_position_encoding_max_len,
70
+ )
71
+
72
+ if zero_initialize:
73
+ self.temporal_transformer.proj_out = zero_module(
74
+ self.temporal_transformer.proj_out
75
+ )
76
+
77
+ def forward(
78
+ self,
79
+ input_tensor,
80
+ temb,
81
+ encoder_hidden_states,
82
+ attention_mask=None,
83
+ anchor_frame_idx=None,
84
+ ):
85
+ hidden_states = input_tensor
86
+ hidden_states = self.temporal_transformer(
87
+ hidden_states, encoder_hidden_states, attention_mask
88
+ )
89
+
90
+ output = hidden_states
91
+ return output
92
+
93
+
94
+ class TemporalTransformer3DModel(nn.Module):
95
+ def __init__(
96
+ self,
97
+ in_channels,
98
+ num_attention_heads,
99
+ attention_head_dim,
100
+ num_layers,
101
+ attention_block_types=(
102
+ "Temporal_Self",
103
+ "Temporal_Self",
104
+ ),
105
+ dropout=0.0,
106
+ norm_num_groups=32,
107
+ cross_attention_dim=768,
108
+ activation_fn="geglu",
109
+ attention_bias=False,
110
+ upcast_attention=False,
111
+ cross_frame_attention_mode=None,
112
+ temporal_position_encoding=False,
113
+ temporal_position_encoding_max_len=24,
114
+ ):
115
+ super().__init__()
116
+
117
+ inner_dim = num_attention_heads * attention_head_dim
118
+
119
+ self.norm = torch.nn.GroupNorm(
120
+ num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True
121
+ )
122
+ self.proj_in = nn.Linear(in_channels, inner_dim)
123
+
124
+ self.transformer_blocks = nn.ModuleList(
125
+ [
126
+ TemporalTransformerBlock(
127
+ dim=inner_dim,
128
+ num_attention_heads=num_attention_heads,
129
+ attention_head_dim=attention_head_dim,
130
+ attention_block_types=attention_block_types,
131
+ dropout=dropout,
132
+ norm_num_groups=norm_num_groups,
133
+ cross_attention_dim=cross_attention_dim,
134
+ activation_fn=activation_fn,
135
+ attention_bias=attention_bias,
136
+ upcast_attention=upcast_attention,
137
+ cross_frame_attention_mode=cross_frame_attention_mode,
138
+ temporal_position_encoding=temporal_position_encoding,
139
+ temporal_position_encoding_max_len=temporal_position_encoding_max_len,
140
+ )
141
+ for d in range(num_layers)
142
+ ]
143
+ )
144
+ self.proj_out = nn.Linear(inner_dim, in_channels)
145
+
146
+ def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None):
147
+ assert (
148
+ hidden_states.dim() == 5
149
+ ), f"Expected hidden_states to have ndim=5, but got ndim={hidden_states.dim()}."
150
+ video_length = hidden_states.shape[2]
151
+ hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w")
152
+
153
+ batch, channel, height, weight = hidden_states.shape
154
+ residual = hidden_states
155
+
156
+ hidden_states = self.norm(hidden_states)
157
+ inner_dim = hidden_states.shape[1]
158
+ hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(
159
+ batch, height * weight, inner_dim
160
+ )
161
+ hidden_states = self.proj_in(hidden_states)
162
+
163
+ # Transformer Blocks
164
+ for block in self.transformer_blocks:
165
+ hidden_states = block(
166
+ hidden_states,
167
+ encoder_hidden_states=encoder_hidden_states,
168
+ video_length=video_length,
169
+ )
170
+
171
+ # output
172
+ hidden_states = self.proj_out(hidden_states)
173
+ hidden_states = (
174
+ hidden_states.reshape(batch, height, weight, inner_dim)
175
+ .permute(0, 3, 1, 2)
176
+ .contiguous()
177
+ )
178
+
179
+ output = hidden_states + residual
180
+ output = rearrange(output, "(b f) c h w -> b c f h w", f=video_length)
181
+
182
+ return output
183
+
184
+
185
+ class TemporalTransformerBlock(nn.Module):
186
+ def __init__(
187
+ self,
188
+ dim,
189
+ num_attention_heads,
190
+ attention_head_dim,
191
+ attention_block_types=(
192
+ "Temporal_Self",
193
+ "Temporal_Self",
194
+ ),
195
+ dropout=0.0,
196
+ norm_num_groups=32,
197
+ cross_attention_dim=768,
198
+ activation_fn="geglu",
199
+ attention_bias=False,
200
+ upcast_attention=False,
201
+ cross_frame_attention_mode=None,
202
+ temporal_position_encoding=False,
203
+ temporal_position_encoding_max_len=24,
204
+ ):
205
+ super().__init__()
206
+
207
+ attention_blocks = []
208
+ norms = []
209
+
210
+ for block_name in attention_block_types:
211
+ attention_blocks.append(
212
+ VersatileAttention(
213
+ attention_mode=block_name.split("_")[0],
214
+ cross_attention_dim=cross_attention_dim
215
+ if block_name.endswith("_Cross")
216
+ else None,
217
+ query_dim=dim,
218
+ heads=num_attention_heads,
219
+ dim_head=attention_head_dim,
220
+ dropout=dropout,
221
+ bias=attention_bias,
222
+ upcast_attention=upcast_attention,
223
+ cross_frame_attention_mode=cross_frame_attention_mode,
224
+ temporal_position_encoding=temporal_position_encoding,
225
+ temporal_position_encoding_max_len=temporal_position_encoding_max_len,
226
+ )
227
+ )
228
+ norms.append(nn.LayerNorm(dim))
229
+
230
+ self.attention_blocks = nn.ModuleList(attention_blocks)
231
+ self.norms = nn.ModuleList(norms)
232
+
233
+ self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn)
234
+ self.ff_norm = nn.LayerNorm(dim)
235
+
236
+ def forward(
237
+ self,
238
+ hidden_states,
239
+ encoder_hidden_states=None,
240
+ attention_mask=None,
241
+ video_length=None,
242
+ ):
243
+ for attention_block, norm in zip(self.attention_blocks, self.norms):
244
+ norm_hidden_states = norm(hidden_states)
245
+ hidden_states = (
246
+ attention_block(
247
+ norm_hidden_states,
248
+ encoder_hidden_states=encoder_hidden_states
249
+ if attention_block.is_cross_attention
250
+ else None,
251
+ video_length=video_length,
252
+ )
253
+ + hidden_states
254
+ )
255
+
256
+ hidden_states = self.ff(self.ff_norm(hidden_states)) + hidden_states
257
+
258
+ output = hidden_states
259
+ return output
260
+
261
+
262
+ class PositionalEncoding(nn.Module):
263
+ def __init__(self, d_model, dropout=0.0, max_len=24):
264
+ super().__init__()
265
+ self.dropout = nn.Dropout(p=dropout)
266
+ position = torch.arange(max_len).unsqueeze(1)
267
+ div_term = torch.exp(
268
+ torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model)
269
+ )
270
+ pe = torch.zeros(1, max_len, d_model)
271
+ pe[0, :, 0::2] = torch.sin(position * div_term)
272
+ pe[0, :, 1::2] = torch.cos(position * div_term)
273
+ self.register_buffer("pe", pe)
274
+
275
+ def forward(self, x):
276
+ x = x + self.pe[:, : x.size(1)]
277
+ return self.dropout(x)
278
+
279
+
280
+ class VersatileAttention(Attention):
281
+ def __init__(
282
+ self,
283
+ attention_mode=None,
284
+ cross_frame_attention_mode=None,
285
+ temporal_position_encoding=False,
286
+ temporal_position_encoding_max_len=24,
287
+ *args,
288
+ **kwargs,
289
+ ):
290
+ super().__init__(*args, **kwargs)
291
+ assert attention_mode == "Temporal"
292
+
293
+ self.attention_mode = attention_mode
294
+ self.is_cross_attention = kwargs["cross_attention_dim"] is not None
295
+
296
+ self.pos_encoder = (
297
+ PositionalEncoding(
298
+ kwargs["query_dim"],
299
+ dropout=0.0,
300
+ max_len=temporal_position_encoding_max_len,
301
+ )
302
+ if (temporal_position_encoding and attention_mode == "Temporal")
303
+ else None
304
+ )
305
+
306
+ def extra_repr(self):
307
+ return f"(Module Info) Attention_Mode: {self.attention_mode}, Is_Cross_Attention: {self.is_cross_attention}"
308
+
309
+ def set_use_memory_efficient_attention_xformers(
310
+ self,
311
+ use_memory_efficient_attention_xformers: bool,
312
+ attention_op: Optional[Callable] = None,
313
+ ):
314
+ if use_memory_efficient_attention_xformers:
315
+ if not is_xformers_available():
316
+ raise ModuleNotFoundError(
317
+ (
318
+ "Refer to https://github.com/facebookresearch/xformers for more information on how to install"
319
+ " xformers"
320
+ ),
321
+ name="xformers",
322
+ )
323
+ elif not torch.cuda.is_available():
324
+ raise ValueError(
325
+ "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is"
326
+ " only available for GPU "
327
+ )
328
+ else:
329
+ try:
330
+ # Make sure we can run the memory efficient attention
331
+ _ = xformers.ops.memory_efficient_attention(
332
+ torch.randn((1, 2, 40), device="cuda"),
333
+ torch.randn((1, 2, 40), device="cuda"),
334
+ torch.randn((1, 2, 40), device="cuda"),
335
+ )
336
+ except Exception as e:
337
+ raise e
338
+
339
+ # XFormersAttnProcessor corrupts video generation and work with Pytorch 1.13.
340
+ # Pytorch 2.0.1 AttnProcessor works the same as XFormersAttnProcessor in Pytorch 1.13.
341
+ # You don't need XFormersAttnProcessor here.
342
+ # processor = XFormersAttnProcessor(
343
+ # attention_op=attention_op,
344
+ # )
345
+ processor = AttnProcessor()
346
+ else:
347
+ processor = AttnProcessor()
348
+
349
+ self.set_processor(processor)
350
+
351
+ def forward(
352
+ self,
353
+ hidden_states,
354
+ encoder_hidden_states=None,
355
+ attention_mask=None,
356
+ video_length=None,
357
+ **cross_attention_kwargs,
358
+ ):
359
+ if self.attention_mode == "Temporal":
360
+ d = hidden_states.shape[1] # d means HxW
361
+ hidden_states = rearrange(
362
+ hidden_states, "(b f) d c -> (b d) f c", f=video_length
363
+ )
364
+
365
+ if self.pos_encoder is not None:
366
+ hidden_states = self.pos_encoder(hidden_states)
367
+
368
+ encoder_hidden_states = (
369
+ repeat(encoder_hidden_states, "b n c -> (b d) n c", d=d)
370
+ if encoder_hidden_states is not None
371
+ else encoder_hidden_states
372
+ )
373
+
374
+ else:
375
+ raise NotImplementedError
376
+
377
+ hidden_states = self.processor(
378
+ self,
379
+ hidden_states,
380
+ encoder_hidden_states=encoder_hidden_states,
381
+ attention_mask=attention_mask,
382
+ **cross_attention_kwargs,
383
+ )
384
+
385
+ if self.attention_mode == "Temporal":
386
+ hidden_states = rearrange(hidden_states, "(b d) f c -> (b f) d c", d=d)
387
+
388
+ return hidden_states
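
A minimal sketch of running the motion module defined above on a (batch, channels, frames, height, width) feature map, assuming the `ref_encoder.motion_module` import path from this upload; channel and frame counts are illustrative, and channels must be divisible by the GroupNorm's 32 groups:

import torch
from ref_encoder.motion_module import get_motion_module

motion_module = get_motion_module(
    in_channels=320,
    motion_module_type="Vanilla",
    motion_module_kwargs={
        "num_attention_heads": 8,
        "num_transformer_block": 1,
        "temporal_position_encoding": True,
        "temporal_position_encoding_max_len": 24,   # must be >= the number of frames
    },
)

feats = torch.randn(1, 320, 8, 16, 16)   # (b, c, f, h, w)
# Because proj_out is zero-initialized, the module starts out as an identity mapping
# and only deviates from it once its temporal attention layers are trained.
out = motion_module(feats, temb=None, encoder_hidden_states=None)
print(out.shape)  # torch.Size([1, 320, 8, 16, 16])
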
ref_encoder/mutual_self_attention.py ADDED
@@ -0,0 +1,365 @@
1
+ # Adapted from https://github.com/magic-research/magic-animate/blob/main/magicanimate/models/mutual_self_attention.py
2
+ from typing import Any, Dict, Optional
3
+
4
+ import torch
5
+ from einops import rearrange
6
+
7
+ from src.models.attention import TemporalBasicTransformerBlock
8
+
9
+ #from .attention import BasicTransformerBlock
10
+ from diffusers.models.attention import BasicTransformerBlock
11
+
12
+ def torch_dfs(model: torch.nn.Module):
13
+ result = [model]
14
+ for child in model.children():
15
+ result += torch_dfs(child)
16
+ return result
17
+
18
+
19
+ class ReferenceAttentionControl:
20
+ def __init__(
21
+ self,
22
+ unet,
23
+ mode="write",
24
+ do_classifier_free_guidance=False,
25
+ attention_auto_machine_weight=float("inf"),
26
+ gn_auto_machine_weight=1.0,
27
+ style_fidelity=1.0,
28
+ reference_attn=True,
29
+ reference_adain=False,
30
+ fusion_blocks="midup",
31
+ batch_size=1,
32
+ ) -> None:
33
+ # 10. Modify self attention and group norm
34
+ self.unet = unet
35
+ assert mode in ["read", "write"]
36
+ assert fusion_blocks in ["midup", "full"]
37
+ self.reference_attn = reference_attn
38
+ self.reference_adain = reference_adain
39
+ self.fusion_blocks = fusion_blocks
40
+ self.register_reference_hooks(
41
+ mode,
42
+ do_classifier_free_guidance,
43
+ attention_auto_machine_weight,
44
+ gn_auto_machine_weight,
45
+ style_fidelity,
46
+ reference_attn,
47
+ reference_adain,
48
+ fusion_blocks,
49
+ batch_size=batch_size,
50
+ )
51
+
52
+ def register_reference_hooks(
53
+ self,
54
+ mode,
55
+ do_classifier_free_guidance,
56
+ attention_auto_machine_weight,
57
+ gn_auto_machine_weight,
58
+ style_fidelity,
59
+ reference_attn,
60
+ reference_adain,
61
+ dtype=torch.float16,
62
+ batch_size=1,
63
+ num_images_per_prompt=1,
64
+ device=torch.device("cpu"),
65
+ fusion_blocks="midup",
66
+ ):
67
+ MODE = mode
68
+ do_classifier_free_guidance = do_classifier_free_guidance
69
+ attention_auto_machine_weight = attention_auto_machine_weight
70
+ gn_auto_machine_weight = gn_auto_machine_weight
71
+ style_fidelity = style_fidelity
72
+ reference_attn = reference_attn
73
+ reference_adain = reference_adain
74
+ fusion_blocks = fusion_blocks
75
+ num_images_per_prompt = num_images_per_prompt
76
+ dtype = dtype
77
+ if do_classifier_free_guidance:
78
+ uc_mask = (
79
+ torch.Tensor(
80
+ [1] * batch_size * num_images_per_prompt * 16
81
+ + [0] * batch_size * num_images_per_prompt * 16
82
+ )
83
+ .to(device)
84
+ .bool()
85
+ )
86
+ else:
87
+ uc_mask = (
88
+ torch.Tensor([0] * batch_size * num_images_per_prompt * 2)
89
+ .to(device)
90
+ .bool()
91
+ )
92
+
93
+ def hacked_basic_transformer_inner_forward(
94
+ self,
95
+ hidden_states: torch.FloatTensor,
96
+ attention_mask: Optional[torch.FloatTensor] = None,
97
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
98
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
99
+ timestep: Optional[torch.LongTensor] = None,
100
+ cross_attention_kwargs: Dict[str, Any] = None,
101
+ class_labels: Optional[torch.LongTensor] = None,
102
+ video_length=None,
103
+ self_attention_additional_feats=None,
104
+ mode=None,
105
+ ):
106
+ if self.use_ada_layer_norm: # False
107
+ norm_hidden_states = self.norm1(hidden_states, timestep)
108
+ elif self.use_ada_layer_norm_zero:
109
+ (
110
+ norm_hidden_states,
111
+ gate_msa,
112
+ shift_mlp,
113
+ scale_mlp,
114
+ gate_mlp,
115
+ ) = self.norm1(
116
+ hidden_states,
117
+ timestep,
118
+ class_labels,
119
+ hidden_dtype=hidden_states.dtype,
120
+ )
121
+ else:
122
+ norm_hidden_states = self.norm1(hidden_states)
123
+
124
+ # 1. Self-Attention
125
+ # self.only_cross_attention = False
126
+ cross_attention_kwargs = (
127
+ cross_attention_kwargs if cross_attention_kwargs is not None else {}
128
+ )
129
+ if self.only_cross_attention:
130
+ attn_output = self.attn1(
131
+ norm_hidden_states,
132
+ encoder_hidden_states=encoder_hidden_states
133
+ if self.only_cross_attention
134
+ else None,
135
+ attention_mask=attention_mask,
136
+ **cross_attention_kwargs,
137
+ )
138
+ else:
139
+ if MODE == "write":
140
+ self.bank.append(norm_hidden_states.clone())
141
+ attn_output = self.attn1(
142
+ norm_hidden_states,
143
+ encoder_hidden_states=encoder_hidden_states
144
+ if self.only_cross_attention
145
+ else None,
146
+ attention_mask=attention_mask,
147
+ **cross_attention_kwargs,
148
+ )
149
+ if MODE == "read":
150
+ bank_fea = [
151
+ rearrange(
152
+ d.unsqueeze(1).repeat(1, video_length, 1, 1),
153
+ "b t l c -> (b t) l c",
154
+ )
155
+ for d in self.bank
156
+ ]
157
+ modify_norm_hidden_states = torch.cat(
158
+ [norm_hidden_states] + bank_fea, dim=1
159
+ )
160
+ hidden_states_uc = (
161
+ self.attn1(
162
+ norm_hidden_states,
163
+ encoder_hidden_states=modify_norm_hidden_states,
164
+ attention_mask=attention_mask,
165
+ )
166
+ + hidden_states
167
+ )
168
+ if do_classifier_free_guidance:
169
+ hidden_states_c = hidden_states_uc.clone()
170
+ _uc_mask = uc_mask.clone()
171
+ if hidden_states.shape[0] != _uc_mask.shape[0]:
172
+ _uc_mask = (
173
+ torch.Tensor(
174
+ [1] * (hidden_states.shape[0] // 2)
175
+ + [0] * (hidden_states.shape[0] // 2)
176
+ )
177
+ .to(device)
178
+ .bool()
179
+ )
180
+ hidden_states_c[_uc_mask] = (
181
+ self.attn1(
182
+ norm_hidden_states[_uc_mask],
183
+ encoder_hidden_states=norm_hidden_states[_uc_mask],
184
+ attention_mask=attention_mask,
185
+ )
186
+ + hidden_states[_uc_mask]
187
+ )
188
+ hidden_states = hidden_states_c.clone()
189
+ else:
190
+ hidden_states = hidden_states_uc
191
+
192
+ # self.bank.clear()
193
+ if self.attn2 is not None:
194
+ # Cross-Attention
195
+ norm_hidden_states = (
196
+ self.norm2(hidden_states, timestep)
197
+ if self.use_ada_layer_norm
198
+ else self.norm2(hidden_states)
199
+ )
200
+ hidden_states = (
201
+ self.attn2(
202
+ norm_hidden_states,
203
+ encoder_hidden_states=encoder_hidden_states,
204
+ attention_mask=attention_mask,
205
+ )
206
+ + hidden_states
207
+ )
208
+
209
+ # Feed-forward
210
+ hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states
211
+
212
+ # Temporal-Attention
213
+ if self.unet_use_temporal_attention:
214
+ d = hidden_states.shape[1]
215
+ hidden_states = rearrange(
216
+ hidden_states, "(b f) d c -> (b d) f c", f=video_length
217
+ )
218
+ norm_hidden_states = (
219
+ self.norm_temp(hidden_states, timestep)
220
+ if self.use_ada_layer_norm
221
+ else self.norm_temp(hidden_states)
222
+ )
223
+ hidden_states = (
224
+ self.attn_temp(norm_hidden_states) + hidden_states
225
+ )
226
+ hidden_states = rearrange(
227
+ hidden_states, "(b d) f c -> (b f) d c", d=d
228
+ )
229
+
230
+ return hidden_states
231
+
232
+ if self.use_ada_layer_norm_zero:
233
+ attn_output = gate_msa.unsqueeze(1) * attn_output
234
+ hidden_states = attn_output + hidden_states
235
+
236
+ if self.attn2 is not None:
237
+ norm_hidden_states = (
238
+ self.norm2(hidden_states, timestep)
239
+ if self.use_ada_layer_norm
240
+ else self.norm2(hidden_states)
241
+ )
242
+
243
+ # 2. Cross-Attention
244
+ attn_output = self.attn2(
245
+ norm_hidden_states,
246
+ encoder_hidden_states=encoder_hidden_states,
247
+ attention_mask=encoder_attention_mask,
248
+ **cross_attention_kwargs,
249
+ )
250
+ hidden_states = attn_output + hidden_states
251
+
252
+ # 3. Feed-forward
253
+ norm_hidden_states = self.norm3(hidden_states)
254
+
255
+ if self.use_ada_layer_norm_zero:
256
+ norm_hidden_states = (
257
+ norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
258
+ )
259
+
260
+ ff_output = self.ff(norm_hidden_states)
261
+
262
+ if self.use_ada_layer_norm_zero:
263
+ ff_output = gate_mlp.unsqueeze(1) * ff_output
264
+
265
+ hidden_states = ff_output + hidden_states
266
+
267
+ return hidden_states
268
+
269
+ if self.reference_attn:
270
+ if self.fusion_blocks == "midup":
271
+ attn_modules = [
272
+ module
273
+ for module in (
274
+ torch_dfs(self.unet.mid_block) + torch_dfs(self.unet.up_blocks)
275
+ )
276
+ if isinstance(module, BasicTransformerBlock)
277
+ or isinstance(module, TemporalBasicTransformerBlock)
278
+ ]
279
+ elif self.fusion_blocks == "full":
280
+ attn_modules = [
281
+ module
282
+ for module in torch_dfs(self.unet)
283
+ if isinstance(module, BasicTransformerBlock)
284
+ or isinstance(module, TemporalBasicTransformerBlock)
285
+ ]
286
+ attn_modules = sorted(
287
+ attn_modules, key=lambda x: -x.norm1.normalized_shape[0]
288
+ )
289
+
290
+ for i, module in enumerate(attn_modules):
291
+ module._original_inner_forward = module.forward
292
+ if isinstance(module, BasicTransformerBlock):
293
+ module.forward = hacked_basic_transformer_inner_forward.__get__(
294
+ module, BasicTransformerBlock
295
+ )
296
+ if isinstance(module, TemporalBasicTransformerBlock):
297
+ module.forward = hacked_basic_transformer_inner_forward.__get__(
298
+ module, TemporalBasicTransformerBlock
299
+ )
300
+
301
+ module.bank = []
302
+ module.attn_weight = float(i) / float(len(attn_modules))
303
+
304
+ def update(self, writer, dtype=torch.float16):
305
+ if self.reference_attn:
306
+ if self.fusion_blocks == "midup":
307
+ reader_attn_modules = [
308
+ module
309
+ for module in (
310
+ torch_dfs(self.unet.mid_block) + torch_dfs(self.unet.up_blocks)
311
+ )
312
+ if isinstance(module, TemporalBasicTransformerBlock)
313
+ ]
314
+ writer_attn_modules = [
315
+ module
316
+ for module in (
317
+ torch_dfs(writer.unet.mid_block)
318
+ + torch_dfs(writer.unet.up_blocks)
319
+ )
320
+ if isinstance(module, BasicTransformerBlock)
321
+ ]
322
+ elif self.fusion_blocks == "full":
323
+ reader_attn_modules = [
324
+ module
325
+ for module in torch_dfs(self.unet)
326
+ if isinstance(module, TemporalBasicTransformerBlock)
327
+ ]
328
+ writer_attn_modules = [
329
+ module
330
+ for module in torch_dfs(writer.unet)
331
+ if isinstance(module, BasicTransformerBlock)
332
+ ]
333
+ reader_attn_modules = sorted(
334
+ reader_attn_modules, key=lambda x: -x.norm1.normalized_shape[0]
335
+ )
336
+ writer_attn_modules = sorted(
337
+ writer_attn_modules, key=lambda x: -x.norm1.normalized_shape[0]
338
+ )
339
+ for r, w in zip(reader_attn_modules, writer_attn_modules):
340
+ r.bank = [v.clone().to(dtype) for v in w.bank]
341
+ # w.bank.clear()
342
+
343
+ def clear(self):
344
+ if self.reference_attn:
345
+ if self.fusion_blocks == "midup":
346
+ reader_attn_modules = [
347
+ module
348
+ for module in (
349
+ torch_dfs(self.unet.mid_block) + torch_dfs(self.unet.up_blocks)
350
+ )
351
+ if isinstance(module, BasicTransformerBlock)
352
+ or isinstance(module, TemporalBasicTransformerBlock)
353
+ ]
354
+ elif self.fusion_blocks == "full":
355
+ reader_attn_modules = [
356
+ module
357
+ for module in torch_dfs(self.unet)
358
+ if isinstance(module, BasicTransformerBlock)
359
+ or isinstance(module, TemporalBasicTransformerBlock)
360
+ ]
361
+ reader_attn_modules = sorted(
362
+ reader_attn_modules, key=lambda x: -x.norm1.normalized_shape[0]
363
+ )
364
+ for r in reader_attn_modules:
365
+ r.bank.clear()
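A minimal usage sketch for the write/read flow implemented in ref_encoder/mutual_self_attention.py above. The 2D reference UNet (`reference_unet`), the 3D denoising UNet (`denoising_unet`), and the tensors (`ref_latents`, `noisy_latents`, `t`, `clip_embeds`) are hypothetical placeholders; only the `ReferenceAttentionControl` API defined in this file is taken as given.

import torch
from ref_encoder.mutual_self_attention import ReferenceAttentionControl

# Writer caches the reference UNet's self-attention features; reader injects them into the video UNet.
writer = ReferenceAttentionControl(reference_unet, mode="write", fusion_blocks="full", batch_size=1)
reader = ReferenceAttentionControl(denoising_unet, mode="read", fusion_blocks="full", batch_size=1)

# 1. One forward pass through the reference UNet fills each hooked block's `bank`.
reference_unet(ref_latents, torch.zeros(1, dtype=torch.long), encoder_hidden_states=clip_embeds)
# 2. Copy the cached features into the matching reader blocks (both lists sorted by norm1 width, as above).
reader.update(writer)
# 3. Denoise: "read" blocks concatenate the banked features into their self-attention keys/values.
denoising_unet(noisy_latents, t, encoder_hidden_states=clip_embeds)
# 4. Clear the banks before switching to a new reference image.
reader.clear()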
ref_encoder/pose_guider.py ADDED
@@ -0,0 +1,57 @@
1
+ from typing import Tuple
2
+
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ import torch.nn.init as init
6
+ from diffusers.models.modeling_utils import ModelMixin
7
+
8
+ from src.models.motion_module import zero_module
9
+ from src.models.resnet import InflatedConv3d
10
+
11
+
12
+ class PoseGuider(ModelMixin):
13
+ def __init__(
14
+ self,
15
+ conditioning_embedding_channels: int,
16
+ conditioning_channels: int = 3,
17
+ block_out_channels: Tuple[int] = (16, 32, 64, 128),
18
+ ):
19
+ super().__init__()
20
+ self.conv_in = InflatedConv3d(
21
+ conditioning_channels, block_out_channels[0], kernel_size=3, padding=1
22
+ )
23
+
24
+ self.blocks = nn.ModuleList([])
25
+
26
+ for i in range(len(block_out_channels) - 1):
27
+ channel_in = block_out_channels[i]
28
+ channel_out = block_out_channels[i + 1]
29
+ self.blocks.append(
30
+ InflatedConv3d(channel_in, channel_in, kernel_size=3, padding=1)
31
+ )
32
+ self.blocks.append(
33
+ InflatedConv3d(
34
+ channel_in, channel_out, kernel_size=3, padding=1, stride=2
35
+ )
36
+ )
37
+
38
+ self.conv_out = zero_module(
39
+ InflatedConv3d(
40
+ block_out_channels[-1],
41
+ conditioning_embedding_channels,
42
+ kernel_size=3,
43
+ padding=1,
44
+ )
45
+ )
46
+
47
+ def forward(self, conditioning):
48
+ embedding = self.conv_in(conditioning)
49
+ embedding = F.silu(embedding)
50
+
51
+ for block in self.blocks:
52
+ embedding = block(embedding)
53
+ embedding = F.silu(embedding)
54
+
55
+ embedding = self.conv_out(embedding)
56
+
57
+ return embedding
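A shape sketch for the `PoseGuider` above, for illustration only. The 320-channel output (matching a typical Stable Diffusion `conv_in` width) and the 512x512 input size are assumptions about how the guider is wired into the denoising UNet, not something stated in this file.

import torch
from ref_encoder.pose_guider import PoseGuider

guider = PoseGuider(conditioning_embedding_channels=320, block_out_channels=(16, 32, 64, 128))
# InflatedConv3d treats the input as (batch, channels, frames, height, width).
pose = torch.randn(1, 3, 8, 512, 512)
feat = guider(pose)
# Three stride-2 stages downsample 512 -> 64; conv_out is zero-initialized, so feat starts at zero.
print(feat.shape)  # expected: torch.Size([1, 320, 8, 64, 64])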
ref_encoder/reference_control.py ADDED
@@ -0,0 +1,528 @@
1
+ import torch
2
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
3
+ from diffusers.models.attention import BasicTransformerBlock
4
+ from diffusers.models.unet_2d_blocks import CrossAttnDownBlock2D, CrossAttnUpBlock2D, DownBlock2D, UpBlock2D
5
+
6
+ def torch_dfs(model: torch.nn.Module):
7
+ result = [model]
8
+ for child in model.children():
9
+ result += torch_dfs(child)
10
+ return result
11
+
12
+ class ReferenceAttentionControl():
13
+
14
+ def __init__(self,
15
+ unet,
16
+ mode="write",
17
+ do_classifier_free_guidance=False,
18
+ attention_auto_machine_weight=float('inf'),
19
+ gn_auto_machine_weight=1.0,
20
+ style_fidelity=1.0,
21
+ reference_attn=True,
22
+ reference_adain=False,
23
+ fusion_blocks="full",
24
+ batch_size=1,
25
+ ) -> None:
26
+ # 10. Modify self attention and group norm
27
+ self.unet = unet
28
+ assert mode in ["read", "write"]
29
+ assert fusion_blocks in ["midup", "full"]
30
+ self.reference_attn = reference_attn
31
+ self.reference_adain = reference_adain
32
+ self.fusion_blocks = fusion_blocks
33
+ self.register_reference_hooks(
34
+ mode,
35
+ do_classifier_free_guidance,
36
+ attention_auto_machine_weight,
37
+ gn_auto_machine_weight,
38
+ style_fidelity,
39
+ reference_attn,
40
+ reference_adain,
41
+ fusion_blocks,
42
+ batch_size=batch_size,
43
+ )
44
+
45
+ def register_reference_hooks(
46
+ self,
47
+ mode,
48
+ do_classifier_free_guidance,
49
+ attention_auto_machine_weight,
50
+ gn_auto_machine_weight,
51
+ style_fidelity,
52
+ reference_attn,
53
+ reference_adain,
54
+ dtype=torch.float16,
55
+ batch_size=1,
56
+ num_images_per_prompt=1,
57
+ device=torch.device("cpu"),
58
+ fusion_blocks='midup',
59
+ ):
60
+ MODE = mode
61
+ do_classifier_free_guidance = do_classifier_free_guidance
62
+ attention_auto_machine_weight = attention_auto_machine_weight
63
+ gn_auto_machine_weight = gn_auto_machine_weight
64
+ style_fidelity = style_fidelity
65
+ reference_attn = reference_attn
66
+ reference_adain = reference_adain
67
+ fusion_blocks = fusion_blocks
68
+ num_images_per_prompt = num_images_per_prompt
69
+ dtype = dtype
70
+ if do_classifier_free_guidance:
71
+ uc_mask = (
72
+ torch.Tensor(
73
+ [1] * batch_size * num_images_per_prompt * 16 + [0] * batch_size * num_images_per_prompt * 16)
74
+ .to(device)
75
+ .bool()
76
+ )
77
+ else:
78
+ uc_mask = (
79
+ torch.Tensor([0] * batch_size * num_images_per_prompt * 2)
80
+ .to(device)
81
+ .bool()
82
+ )
83
+
84
+ def hacked_basic_transformer_inner_forward(
85
+ self,
86
+ hidden_states: torch.FloatTensor,
87
+ attention_mask: Optional[torch.FloatTensor] = None,
88
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
89
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
90
+ timestep: Optional[torch.LongTensor] = None,
91
+ cross_attention_kwargs: Dict[str, Any] = None,
92
+ class_labels: Optional[torch.LongTensor] = None,
93
+ ):
94
+ if self.use_ada_layer_norm:
95
+ norm_hidden_states = self.norm1(hidden_states, timestep)
96
+ elif self.use_ada_layer_norm_zero:
97
+ norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
98
+ hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
99
+ )
100
+ else:
101
+ norm_hidden_states = self.norm1(hidden_states)
102
+
103
+ # 1. Self-Attention
104
+ cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
105
+ if self.only_cross_attention:
106
+ attn_output = self.attn1(
107
+ norm_hidden_states,
108
+ encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
109
+ attention_mask=attention_mask,
110
+ **cross_attention_kwargs,
111
+ )
112
+ else:
113
+ if MODE == "write":
114
+ self.bank.append(norm_hidden_states.clone())
115
+ attn_output = self.attn1(
116
+ norm_hidden_states,
117
+ encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
118
+ attention_mask=attention_mask,
119
+ **cross_attention_kwargs,
120
+ )
121
+ if MODE == "read":
122
+ hidden_states_uc = self.attn1(norm_hidden_states,
123
+ encoder_hidden_states=torch.cat([norm_hidden_states] + self.bank,
124
+ dim=1),
125
+ attention_mask=attention_mask) + hidden_states
126
+ hidden_states_c = hidden_states_uc.clone()
127
+ _uc_mask = uc_mask.clone()
128
+ if do_classifier_free_guidance:
129
+ if hidden_states.shape[0] != _uc_mask.shape[0]:
130
+ _uc_mask = (
131
+ torch.Tensor([1] * (hidden_states.shape[0] // 2) + [0] * (hidden_states.shape[0] // 2))
132
+ .to(device)
133
+ .bool()
134
+ )
135
+
136
+ hidden_states_c[_uc_mask] = self.attn1(
137
+ norm_hidden_states[_uc_mask],
138
+ encoder_hidden_states=norm_hidden_states[_uc_mask],
139
+ attention_mask=attention_mask,
140
+ ) + hidden_states[_uc_mask]
141
+ hidden_states = hidden_states_c.clone()
142
+
143
+ self.bank.clear()
144
+ if self.attn2 is not None:
145
+ # Cross-Attention
146
+ norm_hidden_states = (
147
+ self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(
148
+ hidden_states)
149
+ )
150
+ hidden_states = (
151
+ self.attn2(
152
+ norm_hidden_states, encoder_hidden_states=encoder_hidden_states,
153
+ attention_mask=attention_mask
154
+ )
155
+ + hidden_states
156
+ )
157
+
158
+ # Feed-forward
159
+ hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states
160
+
161
+ return hidden_states
162
+
163
+ if self.use_ada_layer_norm_zero:
164
+ attn_output = gate_msa.unsqueeze(1) * attn_output
165
+ hidden_states = attn_output + hidden_states
166
+
167
+ if self.attn2 is not None:
168
+ norm_hidden_states = (
169
+ self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
170
+ )
171
+
172
+ # 2. Cross-Attention
173
+ attn_output = self.attn2(
174
+ norm_hidden_states,
175
+ encoder_hidden_states=encoder_hidden_states,
176
+ attention_mask=encoder_attention_mask,
177
+ **cross_attention_kwargs,
178
+ )
179
+ hidden_states = attn_output + hidden_states
180
+
181
+ # 3. Feed-forward
182
+ norm_hidden_states = self.norm3(hidden_states)
183
+
184
+ if self.use_ada_layer_norm_zero:
185
+ norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
186
+
187
+ ff_output = self.ff(norm_hidden_states)
188
+
189
+ if self.use_ada_layer_norm_zero:
190
+ ff_output = gate_mlp.unsqueeze(1) * ff_output
191
+
192
+ hidden_states = ff_output + hidden_states
193
+
194
+ return hidden_states
195
+
196
+ def hacked_mid_forward(self, *args, **kwargs):
197
+ eps = 1e-6
198
+ x = self.original_forward(*args, **kwargs)
199
+ if MODE == "write":
200
+ if gn_auto_machine_weight >= self.gn_weight:
201
+ var, mean = torch.var_mean(x, dim=(2, 3), keepdim=True, correction=0)
202
+ self.mean_bank.append(mean)
203
+ self.var_bank.append(var)
204
+ if MODE == "read":
205
+ if len(self.mean_bank) > 0 and len(self.var_bank) > 0:
206
+ var, mean = torch.var_mean(x, dim=(2, 3), keepdim=True, correction=0)
207
+ std = torch.maximum(var, torch.zeros_like(var) + eps) ** 0.5
208
+ mean_acc = sum(self.mean_bank) / float(len(self.mean_bank))
209
+ var_acc = sum(self.var_bank) / float(len(self.var_bank))
210
+ std_acc = torch.maximum(var_acc, torch.zeros_like(var_acc) + eps) ** 0.5
211
+ x_uc = (((x - mean) / std) * std_acc) + mean_acc
212
+ x_c = x_uc.clone()
213
+ if do_classifier_free_guidance and style_fidelity > 0:
214
+ x_c[uc_mask] = x[uc_mask]
215
+ x = style_fidelity * x_c + (1.0 - style_fidelity) * x_uc
216
+ self.mean_bank = []
217
+ self.var_bank = []
218
+ return x
219
+
220
+ def hack_CrossAttnDownBlock2D_forward(
221
+ self,
222
+ hidden_states: torch.FloatTensor,
223
+ temb: Optional[torch.FloatTensor] = None,
224
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
225
+ attention_mask: Optional[torch.FloatTensor] = None,
226
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
227
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
228
+ ):
229
+ eps = 1e-6
230
+
231
+ # TODO(Patrick, William) - attention mask is not used
232
+ output_states = ()
233
+
234
+ for i, (resnet, attn) in enumerate(zip(self.resnets, self.attentions)):
235
+ hidden_states = resnet(hidden_states, temb)
236
+ hidden_states = attn(
237
+ hidden_states,
238
+ encoder_hidden_states=encoder_hidden_states,
239
+ cross_attention_kwargs=cross_attention_kwargs,
240
+ attention_mask=attention_mask,
241
+ encoder_attention_mask=encoder_attention_mask,
242
+ return_dict=False,
243
+ )[0]
244
+ if MODE == "write":
245
+ if gn_auto_machine_weight >= self.gn_weight:
246
+ var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0)
247
+ self.mean_bank.append([mean])
248
+ self.var_bank.append([var])
249
+ if MODE == "read":
250
+ if len(self.mean_bank) > 0 and len(self.var_bank) > 0:
251
+ var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0)
252
+ std = torch.maximum(var, torch.zeros_like(var) + eps) ** 0.5
253
+ mean_acc = sum(self.mean_bank[i]) / float(len(self.mean_bank[i]))
254
+ var_acc = sum(self.var_bank[i]) / float(len(self.var_bank[i]))
255
+ std_acc = torch.maximum(var_acc, torch.zeros_like(var_acc) + eps) ** 0.5
256
+ hidden_states_uc = (((hidden_states - mean) / std) * std_acc) + mean_acc
257
+ hidden_states_c = hidden_states_uc.clone()
258
+ if do_classifier_free_guidance and style_fidelity > 0:
259
+ hidden_states_c[uc_mask] = hidden_states[uc_mask].to(hidden_states_c.dtype)
260
+ hidden_states = style_fidelity * hidden_states_c + (1.0 - style_fidelity) * hidden_states_uc
261
+
262
+ output_states = output_states + (hidden_states,)
263
+
264
+ if MODE == "read":
265
+ self.mean_bank = []
266
+ self.var_bank = []
267
+
268
+ if self.downsamplers is not None:
269
+ for downsampler in self.downsamplers:
270
+ hidden_states = downsampler(hidden_states)
271
+
272
+ output_states = output_states + (hidden_states,)
273
+
274
+ return hidden_states, output_states
275
+
276
+ def hacked_DownBlock2D_forward(self, hidden_states, temb=None):
277
+ eps = 1e-6
278
+
279
+ output_states = ()
280
+
281
+ for i, resnet in enumerate(self.resnets):
282
+ hidden_states = resnet(hidden_states, temb)
283
+
284
+ if MODE == "write":
285
+ if gn_auto_machine_weight >= self.gn_weight:
286
+ var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0)
287
+ self.mean_bank.append([mean])
288
+ self.var_bank.append([var])
289
+ if MODE == "read":
290
+ if len(self.mean_bank) > 0 and len(self.var_bank) > 0:
291
+ var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0)
292
+ std = torch.maximum(var, torch.zeros_like(var) + eps) ** 0.5
293
+ mean_acc = sum(self.mean_bank[i]) / float(len(self.mean_bank[i]))
294
+ var_acc = sum(self.var_bank[i]) / float(len(self.var_bank[i]))
295
+ std_acc = torch.maximum(var_acc, torch.zeros_like(var_acc) + eps) ** 0.5
296
+ hidden_states_uc = (((hidden_states - mean) / std) * std_acc) + mean_acc
297
+ hidden_states_c = hidden_states_uc.clone()
298
+ if do_classifier_free_guidance and style_fidelity > 0:
299
+ hidden_states_c[uc_mask] = hidden_states[uc_mask].to(hidden_states_c.dtype)
300
+ hidden_states = style_fidelity * hidden_states_c + (1.0 - style_fidelity) * hidden_states_uc
301
+
302
+ output_states = output_states + (hidden_states,)
303
+
304
+ if MODE == "read":
305
+ self.mean_bank = []
306
+ self.var_bank = []
307
+
308
+ if self.downsamplers is not None:
309
+ for downsampler in self.downsamplers:
310
+ hidden_states = downsampler(hidden_states)
311
+
312
+ output_states = output_states + (hidden_states,)
313
+
314
+ return hidden_states, output_states
315
+
316
+ def hacked_CrossAttnUpBlock2D_forward(
317
+ self,
318
+ hidden_states: torch.FloatTensor,
319
+ res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
320
+ temb: Optional[torch.FloatTensor] = None,
321
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
322
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
323
+ upsample_size: Optional[int] = None,
324
+ attention_mask: Optional[torch.FloatTensor] = None,
325
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
326
+ ):
327
+ eps = 1e-6
328
+ # TODO(Patrick, William) - attention mask is not used
329
+ for i, (resnet, attn) in enumerate(zip(self.resnets, self.attentions)):
330
+ # pop res hidden states
331
+ res_hidden_states = res_hidden_states_tuple[-1]
332
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
333
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
334
+ hidden_states = resnet(hidden_states, temb)
335
+ hidden_states = attn(
336
+ hidden_states,
337
+ encoder_hidden_states=encoder_hidden_states,
338
+ cross_attention_kwargs=cross_attention_kwargs,
339
+ attention_mask=attention_mask,
340
+ encoder_attention_mask=encoder_attention_mask,
341
+ return_dict=False,
342
+ )[0]
343
+
344
+ if MODE == "write":
345
+ if gn_auto_machine_weight >= self.gn_weight:
346
+ var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0)
347
+ self.mean_bank.append([mean])
348
+ self.var_bank.append([var])
349
+ if MODE == "read":
350
+ if len(self.mean_bank) > 0 and len(self.var_bank) > 0:
351
+ var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0)
352
+ std = torch.maximum(var, torch.zeros_like(var) + eps) ** 0.5
353
+ mean_acc = sum(self.mean_bank[i]) / float(len(self.mean_bank[i]))
354
+ var_acc = sum(self.var_bank[i]) / float(len(self.var_bank[i]))
355
+ std_acc = torch.maximum(var_acc, torch.zeros_like(var_acc) + eps) ** 0.5
356
+ hidden_states_uc = (((hidden_states - mean) / std) * std_acc) + mean_acc
357
+ hidden_states_c = hidden_states_uc.clone()
358
+ if do_classifier_free_guidance and style_fidelity > 0:
359
+ hidden_states_c[uc_mask] = hidden_states[uc_mask].to(hidden_states_c.dtype)
360
+ hidden_states = style_fidelity * hidden_states_c + (1.0 - style_fidelity) * hidden_states_uc
361
+
362
+ if MODE == "read":
363
+ self.mean_bank = []
364
+ self.var_bank = []
365
+
366
+ if self.upsamplers is not None:
367
+ for upsampler in self.upsamplers:
368
+ hidden_states = upsampler(hidden_states, upsample_size)
369
+
370
+ return hidden_states
371
+
372
+ def hacked_UpBlock2D_forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None):
373
+ eps = 1e-6
374
+ for i, resnet in enumerate(self.resnets):
375
+ # pop res hidden states
376
+ res_hidden_states = res_hidden_states_tuple[-1]
377
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
378
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
379
+ hidden_states = resnet(hidden_states, temb)
380
+
381
+ if MODE == "write":
382
+ if gn_auto_machine_weight >= self.gn_weight:
383
+ var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0)
384
+ self.mean_bank.append([mean])
385
+ self.var_bank.append([var])
386
+ if MODE == "read":
387
+ if len(self.mean_bank) > 0 and len(self.var_bank) > 0:
388
+ var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0)
389
+ std = torch.maximum(var, torch.zeros_like(var) + eps) ** 0.5
390
+ mean_acc = sum(self.mean_bank[i]) / float(len(self.mean_bank[i]))
391
+ var_acc = sum(self.var_bank[i]) / float(len(self.var_bank[i]))
392
+ std_acc = torch.maximum(var_acc, torch.zeros_like(var_acc) + eps) ** 0.5
393
+ hidden_states_uc = (((hidden_states - mean) / std) * std_acc) + mean_acc
394
+ hidden_states_c = hidden_states_uc.clone()
395
+ if do_classifier_free_guidance and style_fidelity > 0:
396
+ hidden_states_c[uc_mask] = hidden_states[uc_mask].to(hidden_states_c.dtype)
397
+ hidden_states = style_fidelity * hidden_states_c + (1.0 - style_fidelity) * hidden_states_uc
398
+
399
+ if MODE == "read":
400
+ self.mean_bank = []
401
+ self.var_bank = []
402
+
403
+ if self.upsamplers is not None:
404
+ for upsampler in self.upsamplers:
405
+ hidden_states = upsampler(hidden_states, upsample_size)
406
+
407
+ return hidden_states
408
+
409
+ if self.reference_attn:
410
+ if self.fusion_blocks == "midup":
411
+ attn_modules = [module for module in (torch_dfs(self.unet.mid_block) + torch_dfs(self.unet.up_blocks))
412
+ if isinstance(module, BasicTransformerBlock)]
413
+ elif self.fusion_blocks == "full":
414
+ attn_modules = [module for module in torch_dfs(self.unet) if isinstance(module, BasicTransformerBlock)]
415
+ attn_modules = sorted(attn_modules, key=lambda x: -x.norm1.normalized_shape[0])
416
+
417
+ for i, module in enumerate(attn_modules):
418
+ module._original_inner_forward = module.forward
419
+ module.forward = hacked_basic_transformer_inner_forward.__get__(module, BasicTransformerBlock)
420
+ module.bank = []
421
+ module.attn_weight = float(i) / float(len(attn_modules))
422
+
423
+ if self.reference_adain:
424
+ gn_modules = [self.unet.mid_block]
425
+ self.unet.mid_block.gn_weight = 0
426
+
427
+ down_blocks = self.unet.down_blocks
428
+ for w, module in enumerate(down_blocks):
429
+ module.gn_weight = 1.0 - float(w) / float(len(down_blocks))
430
+ gn_modules.append(module)
431
+
432
+ up_blocks = self.unet.up_blocks
433
+ for w, module in enumerate(up_blocks):
434
+ module.gn_weight = float(w) / float(len(up_blocks))
435
+ gn_modules.append(module)
436
+
437
+ for i, module in enumerate(gn_modules):
438
+ if getattr(module, "original_forward", None) is None:
439
+ module.original_forward = module.forward
440
+ if i == 0:
441
+ # mid_block
442
+ module.forward = hacked_mid_forward.__get__(module, torch.nn.Module)
443
+ elif isinstance(module, CrossAttnDownBlock2D):
444
+ module.forward = hack_CrossAttnDownBlock2D_forward.__get__(module, CrossAttnDownBlock2D)
445
+ elif isinstance(module, DownBlock2D):
446
+ module.forward = hacked_DownBlock2D_forward.__get__(module, DownBlock2D)
447
+ elif isinstance(module, CrossAttnUpBlock2D):
448
+ module.forward = hacked_CrossAttnUpBlock2D_forward.__get__(module, CrossAttnUpBlock2D)
449
+ elif isinstance(module, UpBlock2D):
450
+ module.forward = hacked_UpBlock2D_forward.__get__(module, UpBlock2D)
451
+ module.mean_bank = []
452
+ module.var_bank = []
453
+ module.gn_weight *= 2
454
+
455
+ def update(self, writer, dtype=torch.float16):
456
+ if self.reference_attn:
457
+ if self.fusion_blocks == "midup":
458
+ reader_attn_modules = [module for module in
459
+ (torch_dfs(self.unet.mid_block) + torch_dfs(self.unet.up_blocks)) if
460
+ isinstance(module, BasicTransformerBlock)]
461
+ writer_attn_modules = [module for module in
462
+ (torch_dfs(writer.unet.mid_block) + torch_dfs(writer.unet.up_blocks)) if
463
+ isinstance(module, BasicTransformerBlock)]
464
+ elif self.fusion_blocks == "full":
465
+ reader_attn_modules = [module for module in torch_dfs(self.unet) if
466
+ isinstance(module, BasicTransformerBlock)]
467
+ writer_attn_modules = [module for module in torch_dfs(writer.unet) if
468
+ isinstance(module, BasicTransformerBlock)]
469
+ reader_attn_modules = sorted(reader_attn_modules, key=lambda x: -x.norm1.normalized_shape[0])
470
+ writer_attn_modules = sorted(writer_attn_modules, key=lambda x: -x.norm1.normalized_shape[0])
471
+ for r, w in zip(reader_attn_modules, writer_attn_modules):
472
+ r.bank = [v.clone().to(dtype) for v in w.bank]
473
+
474
+ if self.reference_adain:
475
+ reader_gn_modules = [self.unet.mid_block]
476
+
477
+ down_blocks = self.unet.down_blocks
478
+ for w, module in enumerate(down_blocks):
479
+ reader_gn_modules.append(module)
480
+
481
+ up_blocks = self.unet.up_blocks
482
+ for w, module in enumerate(up_blocks):
483
+ reader_gn_modules.append(module)
484
+
485
+ writer_gn_modules = [writer.unet.mid_block]
486
+
487
+ down_blocks = writer.unet.down_blocks
488
+ for w, module in enumerate(down_blocks):
489
+ writer_gn_modules.append(module)
490
+
491
+ up_blocks = writer.unet.up_blocks
492
+ for w, module in enumerate(up_blocks):
493
+ writer_gn_modules.append(module)
494
+
495
+ for r, w in zip(reader_gn_modules, writer_gn_modules):
496
+ if len(w.mean_bank) > 0 and isinstance(w.mean_bank[0], list):
497
+ r.mean_bank = [[v.clone().to(dtype) for v in vl] for vl in w.mean_bank]
498
+ r.var_bank = [[v.clone().to(dtype) for v in vl] for vl in w.var_bank]
499
+ else:
500
+ r.mean_bank = [v.clone().to(dtype) for v in w.mean_bank]
501
+ r.var_bank = [v.clone().to(dtype) for v in w.var_bank]
502
+
503
+ def clear(self):
504
+ if self.reference_attn:
505
+ if self.fusion_blocks == "midup":
506
+ reader_attn_modules = [module for module in
507
+ (torch_dfs(self.unet.mid_block) + torch_dfs(self.unet.up_blocks)) if
508
+ isinstance(module, BasicTransformerBlock)]
509
+ elif self.fusion_blocks == "full":
510
+ reader_attn_modules = [module for module in torch_dfs(self.unet) if
511
+ isinstance(module, BasicTransformerBlock)]
512
+ reader_attn_modules = sorted(reader_attn_modules, key=lambda x: -x.norm1.normalized_shape[0])
513
+ for r in reader_attn_modules:
514
+ r.bank.clear()
515
+ if self.reference_adain:
516
+ reader_gn_modules = [self.unet.mid_block]
517
+
518
+ down_blocks = self.unet.down_blocks
519
+ for w, module in enumerate(down_blocks):
520
+ reader_gn_modules.append(module)
521
+
522
+ up_blocks = self.unet.up_blocks
523
+ for w, module in enumerate(up_blocks):
524
+ reader_gn_modules.append(module)
525
+
526
+ for r in reader_gn_modules:
527
+ r.mean_bank.clear()
528
+ r.var_bank.clear()
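The AdaIN-style statistic transfer used in the group-norm "read" branches above, restated as a standalone helper. This is a sketch that mirrors the mean_bank/var_bank math from the hooked block forwards; it does not call into this file.

import torch

def adain_transfer(x: torch.Tensor, mean_ref: torch.Tensor, var_ref: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # Re-normalize x per channel over the spatial dims, then re-scale/shift with the writer's cached statistics.
    var, mean = torch.var_mean(x, dim=(2, 3), keepdim=True, correction=0)
    std = torch.maximum(var, torch.zeros_like(var) + eps) ** 0.5
    std_ref = torch.maximum(var_ref, torch.zeros_like(var_ref) + eps) ** 0.5
    return ((x - mean) / std) * std_ref + mean_ref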
ref_encoder/reference_unet.py ADDED
@@ -0,0 +1,1053 @@
1
+ from dataclasses import dataclass
2
+ from typing import Any, Dict, List, Optional, Tuple, Union
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.utils.checkpoint
7
+
8
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
9
+ from diffusers.loaders import UNet2DConditionLoadersMixin
10
+ from diffusers.utils import BaseOutput, logging
11
+ from diffusers.models.activations import get_activation
12
+ from diffusers.models.attention_processor import (
13
+ ADDED_KV_ATTENTION_PROCESSORS,
14
+ CROSS_ATTENTION_PROCESSORS,
15
+ AttentionProcessor,
16
+ AttnAddedKVProcessor,
17
+ AttnProcessor,
18
+ )
19
+ from diffusers.models.lora import LoRALinearLayer
20
+ from diffusers.models.embeddings import (
21
+ GaussianFourierProjection,
22
+ ImageHintTimeEmbedding,
23
+ ImageProjection,
24
+ ImageTimeEmbedding,
25
+ PositionNet,
26
+ TextImageProjection,
27
+ TextImageTimeEmbedding,
28
+ TextTimeEmbedding,
29
+ TimestepEmbedding,
30
+ Timesteps,
31
+ )
32
+ from diffusers.models.modeling_utils import ModelMixin
33
+ from diffusers.models.unet_2d_blocks import (
34
+ UNetMidBlock2DCrossAttn,
35
+ UNetMidBlock2DSimpleCrossAttn,
36
+ get_down_block,
37
+ get_up_block,
38
+ )
39
+
40
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
41
+
42
+
43
+ class CCProjection(ModelMixin, ConfigMixin):
44
+ def __init__(self, in_channel=772, out_channel=768):
45
+ super().__init__()
46
+ self.in_channel = in_channel
47
+ self.out_channel = out_channel
48
+ self.projection = torch.nn.Linear(in_channel, out_channel)
49
+
50
+ def forward(self, x):
51
+ return self.projection(x)
52
+
53
+
54
+ class Identity(torch.nn.Module):
55
+ def __init__(self, scale=None, *args, **kwargs) -> None:
56
+ super(Identity, self).__init__()
57
+ def forward(self, input, *args, **kwargs):
58
+ return input
59
+
60
+
61
+ class _LoRACompatibleLinear(nn.Module):
62
+ """
63
+ A Linear layer that can be used with LoRA.
64
+ """
65
+
66
+ def __init__(self, *args, lora_layer: Optional[LoRALinearLayer] = None, **kwargs):
67
+ super().__init__(*args, **kwargs)
68
+ self.lora_layer = lora_layer
69
+
70
+ def set_lora_layer(self, lora_layer: Optional[LoRALinearLayer]):
71
+ self.lora_layer = lora_layer
72
+
73
+ def _fuse_lora(self):
74
+ pass
75
+
76
+ def _unfuse_lora(self):
77
+ pass
78
+
79
+ def forward(self, hidden_states, scale=None, lora_scale: int = 1):
80
+ return hidden_states
81
+
82
+
83
+ @dataclass
84
+ class UNet2DConditionOutput(BaseOutput):
85
+ """
86
+ The output of [`UNet2DConditionModel`].
87
+
88
+ Args:
89
+ sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
90
+ The hidden states output conditioned on `encoder_hidden_states` input. Output of last layer of model.
91
+ """
92
+
93
+ sample: torch.FloatTensor = None
94
+
95
+
96
+ class ref_unet(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
97
+ r"""
98
+ A conditional 2D UNet model that takes a noisy sample, conditional state, and a timestep and returns a sample
99
+ shaped output.
100
+
101
+ This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented
102
+ for all models (such as downloading or saving).
103
+
104
+ Parameters:
105
+ sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):
106
+ Height and width of input/output sample.
107
+ in_channels (`int`, *optional*, defaults to 4): Number of channels in the input sample.
108
+ out_channels (`int`, *optional*, defaults to 4): Number of channels in the output.
109
+ center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample.
110
+ flip_sin_to_cos (`bool`, *optional*, defaults to `False`):
111
+ Whether to flip the sin to cos in the time embedding.
112
+ freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding.
113
+ down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
114
+ The tuple of downsample blocks to use.
115
+ mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2DCrossAttn"`):
116
+ Block type for middle of UNet, it can be either `UNetMidBlock2DCrossAttn` or
117
+ `UNetMidBlock2DSimpleCrossAttn`. If `None`, the mid block layer is skipped.
118
+ up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D")`):
119
+ The tuple of upsample blocks to use.
120
+ only_cross_attention(`bool` or `Tuple[bool]`, *optional*, default to `False`):
121
+ Whether to include self-attention in the basic transformer blocks, see
122
+ [`~models.attention.BasicTransformerBlock`].
123
+ block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
124
+ The tuple of output channels for each block.
125
+ layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block.
126
+ downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution.
127
+ mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block.
128
+ act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
129
+ norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization.
130
+ If `None`, normalization and activation layers are skipped in post-processing.
131
+ norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization.
132
+ cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1280):
133
+ The dimension of the cross attention features.
134
+ transformer_layers_per_block (`int` or `Tuple[int]`, *optional*, defaults to 1):
135
+ The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for
136
+ [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`],
137
+ [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
138
+ encoder_hid_dim (`int`, *optional*, defaults to None):
139
+ If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim`
140
+ dimension to `cross_attention_dim`.
141
+ encoder_hid_dim_type (`str`, *optional*, defaults to `None`):
142
+ If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text
143
+ embeddings of dimension `cross_attention` according to `encoder_hid_dim_type`.
144
+ attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads.
145
+ num_attention_heads (`int`, *optional*):
146
+ The number of attention heads. If not defined, defaults to `attention_head_dim`
147
+ resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config
148
+ for ResNet blocks (see [`~models.resnet.ResnetBlock2D`]). Choose from `default` or `scale_shift`.
149
+ class_embed_type (`str`, *optional*, defaults to `None`):
150
+ The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`,
151
+ `"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`.
152
+ addition_embed_type (`str`, *optional*, defaults to `None`):
153
+ Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or
154
+ "text". "text" will use the `TextTimeEmbedding` layer.
155
+ addition_time_embed_dim: (`int`, *optional*, defaults to `None`):
156
+ Dimension for the timestep embeddings.
157
+ num_class_embeds (`int`, *optional*, defaults to `None`):
158
+ Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
159
+ class conditioning with `class_embed_type` equal to `None`.
160
+ time_embedding_type (`str`, *optional*, defaults to `positional`):
161
+ The type of position embedding to use for timesteps. Choose from `positional` or `fourier`.
162
+ time_embedding_dim (`int`, *optional*, defaults to `None`):
163
+ An optional override for the dimension of the projected time embedding.
164
+ time_embedding_act_fn (`str`, *optional*, defaults to `None`):
165
+ Optional activation function to use only once on the time embeddings before they are passed to the rest of
166
+ the UNet. Choose from `silu`, `mish`, `gelu`, and `swish`.
167
+ timestep_post_act (`str`, *optional*, defaults to `None`):
168
+ The second activation function to use in timestep embedding. Choose from `silu`, `mish` and `gelu`.
169
+ time_cond_proj_dim (`int`, *optional*, defaults to `None`):
170
+ The dimension of `cond_proj` layer in the timestep embedding.
171
+ conv_in_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_in` layer.
172
+ conv_out_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_out` layer.
173
+ projection_class_embeddings_input_dim (`int`, *optional*): The dimension of the `class_labels` input when
174
+ `class_embed_type="projection"`. Required when `class_embed_type="projection"`.
175
+ class_embeddings_concat (`bool`, *optional*, defaults to `False`): Whether to concatenate the time
176
+ embeddings with the class embeddings.
177
+ mid_block_only_cross_attention (`bool`, *optional*, defaults to `None`):
178
+ Whether to use cross attention with the mid block when using the `UNetMidBlock2DSimpleCrossAttn`. If
179
+ `only_cross_attention` is given as a single boolean and `mid_block_only_cross_attention` is `None`, the
180
+ `only_cross_attention` value is used as the value for `mid_block_only_cross_attention`. Default to `False`
181
+ otherwise.
182
+ """
183
+
184
+ _supports_gradient_checkpointing = True
185
+
186
+ @register_to_config
187
+ def __init__(
188
+ self,
189
+ sample_size: Optional[int] = None,
190
+ in_channels: int = 4,
191
+ out_channels: int = 4,
192
+ center_input_sample: bool = False,
193
+ flip_sin_to_cos: bool = True,
194
+ freq_shift: int = 0,
195
+ down_block_types: Tuple[str] = (
196
+ "CrossAttnDownBlock2D",
197
+ "CrossAttnDownBlock2D",
198
+ "CrossAttnDownBlock2D",
199
+ "DownBlock2D",
200
+ ),
201
+ mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn",
202
+ up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"),
203
+ only_cross_attention: Union[bool, Tuple[bool]] = False,
204
+ block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
205
+ layers_per_block: Union[int, Tuple[int]] = 2,
206
+ downsample_padding: int = 1,
207
+ mid_block_scale_factor: float = 1,
208
+ act_fn: str = "silu",
209
+ norm_num_groups: Optional[int] = 32,
210
+ norm_eps: float = 1e-5,
211
+ cross_attention_dim: Union[int, Tuple[int]] = 1280,
212
+ transformer_layers_per_block: Union[int, Tuple[int]] = 1,
213
+ encoder_hid_dim: Optional[int] = None,
214
+ encoder_hid_dim_type: Optional[str] = None,
215
+ attention_head_dim: Union[int, Tuple[int]] = 8,
216
+ num_attention_heads: Optional[Union[int, Tuple[int]]] = None,
217
+ dual_cross_attention: bool = False,
218
+ use_linear_projection: bool = False,
219
+ class_embed_type: Optional[str] = None,
220
+ addition_embed_type: Optional[str] = None,
221
+ addition_time_embed_dim: Optional[int] = None,
222
+ num_class_embeds: Optional[int] = None,
223
+ upcast_attention: bool = False,
224
+ resnet_time_scale_shift: str = "default",
225
+ resnet_skip_time_act: bool = False,
226
+ resnet_out_scale_factor: int = 1.0,
227
+ time_embedding_type: str = "positional",
228
+ time_embedding_dim: Optional[int] = None,
229
+ time_embedding_act_fn: Optional[str] = None,
230
+ timestep_post_act: Optional[str] = None,
231
+ time_cond_proj_dim: Optional[int] = None,
232
+ conv_in_kernel: int = 3,
233
+ conv_out_kernel: int = 3,
234
+ projection_class_embeddings_input_dim: Optional[int] = None,
235
+ attention_type: str = "default",
236
+ class_embeddings_concat: bool = False,
237
+ mid_block_only_cross_attention: Optional[bool] = None,
238
+ cross_attention_norm: Optional[str] = None,
239
+ addition_embed_type_num_heads=64,
240
+ ):
241
+ super().__init__()
242
+
243
+ self.sample_size = sample_size
244
+
245
+ if num_attention_heads is not None:
246
+ raise ValueError(
247
+ "At the moment it is not possible to define the number of attention heads via `num_attention_heads` because of a naming issue as described in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131. Passing `num_attention_heads` will only be supported in diffusers v0.19."
248
+ )
249
+
250
+ # If `num_attention_heads` is not defined (which is the case for most models)
251
+ # it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
252
+ # The reason for this behavior is to correct for incorrectly named variables that were introduced
253
+ # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
254
+ # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
255
+ # which is why we correct for the naming here.
256
+ num_attention_heads = num_attention_heads or attention_head_dim
257
+
258
+ # Check inputs
259
+ if len(down_block_types) != len(up_block_types):
260
+ raise ValueError(
261
+ f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}."
262
+ )
263
+
264
+ if len(block_out_channels) != len(down_block_types):
265
+ raise ValueError(
266
+ f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
267
+ )
268
+
269
+ if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types):
270
+ raise ValueError(
271
+ f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
272
+ )
273
+
274
+ if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
275
+ raise ValueError(
276
+ f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
277
+ )
278
+
279
+ if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(down_block_types):
280
+ raise ValueError(
281
+ f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}."
282
+ )
283
+
284
+ if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(down_block_types):
285
+ raise ValueError(
286
+ f"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}."
287
+ )
288
+
289
+ if not isinstance(layers_per_block, int) and len(layers_per_block) != len(down_block_types):
290
+ raise ValueError(
291
+ f"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}."
292
+ )
293
+
294
+ # input
295
+ conv_in_padding = (conv_in_kernel - 1) // 2
296
+ self.conv_in = nn.Conv2d(
297
+ in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding
298
+ )
299
+
300
+ self.add_time_proj = Timesteps(256, True, downscale_freq_shift=0)
301
+ self.add_angle_proj = Timesteps(512, True, downscale_freq_shift=0) # encode camera angles
302
+ self.add_embedding = TimestepEmbedding(1280, 1280)
303
+
304
+ # time
305
+ if time_embedding_type == "fourier":
306
+ time_embed_dim = time_embedding_dim or block_out_channels[0] * 2
307
+ if time_embed_dim % 2 != 0:
308
+ raise ValueError(f"`time_embed_dim` should be divisible by 2, but is {time_embed_dim}.")
309
+ self.time_proj = GaussianFourierProjection(
310
+ time_embed_dim // 2, set_W_to_weight=False, log=False, flip_sin_to_cos=flip_sin_to_cos
311
+ )
312
+ timestep_input_dim = time_embed_dim
313
+ elif time_embedding_type == "positional":
314
+ time_embed_dim = time_embedding_dim or block_out_channels[0] * 4
315
+
316
+ self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
317
+ timestep_input_dim = block_out_channels[0]
318
+ else:
319
+ raise ValueError(
320
+ f"{time_embedding_type} does not exist. Please make sure to use one of `fourier` or `positional`."
321
+ )
322
+
323
+ self.time_embedding = TimestepEmbedding(
324
+ timestep_input_dim,
325
+ time_embed_dim,
326
+ act_fn=act_fn,
327
+ post_act_fn=timestep_post_act,
328
+ cond_proj_dim=time_cond_proj_dim,
329
+ )
330
+
331
+ if encoder_hid_dim_type is None and encoder_hid_dim is not None:
332
+ encoder_hid_dim_type = "text_proj"
333
+ self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type)
334
+ logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.")
335
+
336
+ if encoder_hid_dim is None and encoder_hid_dim_type is not None:
337
+ raise ValueError(
338
+ f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}."
339
+ )
340
+
341
+ if encoder_hid_dim_type == "text_proj":
342
+ self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim)
343
+ elif encoder_hid_dim_type == "text_image_proj":
344
+ # image_embed_dim DOESN'T have to be `cross_attention_dim`. To not clutter the __init__ too much
345
+ # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
346
+ # case when `addition_embed_type == "text_image_proj"` (Kandinsky 2.1)
347
+ self.encoder_hid_proj = TextImageProjection(
348
+ text_embed_dim=encoder_hid_dim,
349
+ image_embed_dim=cross_attention_dim,
350
+ cross_attention_dim=cross_attention_dim,
351
+ )
352
+ elif encoder_hid_dim_type == "image_proj":
353
+ # Kandinsky 2.2
354
+ self.encoder_hid_proj = ImageProjection(
355
+ image_embed_dim=encoder_hid_dim,
356
+ cross_attention_dim=cross_attention_dim,
357
+ )
358
+ elif encoder_hid_dim_type is not None:
359
+ raise ValueError(
360
+ f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj' or 'text_image_proj'."
361
+ )
362
+ else:
363
+ self.encoder_hid_proj = None
364
+
365
+ # class embedding
366
+ if class_embed_type is None and num_class_embeds is not None:
367
+ self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
368
+ elif class_embed_type == "timestep":
369
+ self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim, act_fn=act_fn)
370
+ elif class_embed_type == "identity":
371
+ self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
372
+ elif class_embed_type == "projection":
373
+ if projection_class_embeddings_input_dim is None:
374
+ raise ValueError(
375
+ "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set"
376
+ )
377
+ # The projection `class_embed_type` is the same as the timestep `class_embed_type` except
378
+ # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings
379
+ # 2. it projects from an arbitrary input dimension.
380
+ #
381
+ # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations.
382
+ # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings.
383
+ # As a result, `TimestepEmbedding` can be passed arbitrary vectors.
384
+ self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
385
+ elif class_embed_type == "simple_projection":
386
+ if projection_class_embeddings_input_dim is None:
387
+ raise ValueError(
388
+ "`class_embed_type`: 'simple_projection' requires `projection_class_embeddings_input_dim` be set"
389
+ )
390
+ self.class_embedding = nn.Linear(projection_class_embeddings_input_dim, time_embed_dim)
391
+ else:
392
+ self.class_embedding = None
393
+
394
+ if addition_embed_type == "text":
395
+ if encoder_hid_dim is not None:
396
+ text_time_embedding_from_dim = encoder_hid_dim
397
+ else:
398
+ text_time_embedding_from_dim = cross_attention_dim
399
+
400
+ self.add_embedding = TextTimeEmbedding(
401
+ text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads
402
+ )
403
+ elif addition_embed_type == "text_image":
404
+ # text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the __init__ too much
405
+ # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
406
+ # case when `addition_embed_type == "text_image"` (Kandinsky 2.1)
407
+ self.add_embedding = TextImageTimeEmbedding(
408
+ text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim
409
+ )
410
+ elif addition_embed_type == "text_time":
411
+ self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift)
412
+ self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
413
+ elif addition_embed_type == "image":
414
+ # Kandinsky 2.2
415
+ self.add_embedding = ImageTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim)
416
+ elif addition_embed_type == "image_hint":
417
+ # Kandinsky 2.2 ControlNet
418
+ self.add_embedding = ImageHintTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim)
419
+ elif addition_embed_type is not None:
420
+ raise ValueError(f"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'.")
421
+
422
+ if time_embedding_act_fn is None:
423
+ self.time_embed_act = None
424
+ else:
425
+ self.time_embed_act = get_activation(time_embedding_act_fn)
426
+
427
+ self.down_blocks = nn.ModuleList([])
428
+ self.up_blocks = nn.ModuleList([])
429
+
430
+ if isinstance(only_cross_attention, bool):
431
+ if mid_block_only_cross_attention is None:
432
+ mid_block_only_cross_attention = only_cross_attention
433
+
434
+ only_cross_attention = [only_cross_attention] * len(down_block_types)
435
+
436
+ if mid_block_only_cross_attention is None:
437
+ mid_block_only_cross_attention = False
438
+
439
+ if isinstance(num_attention_heads, int):
440
+ num_attention_heads = (num_attention_heads,) * len(down_block_types)
441
+
442
+ if isinstance(attention_head_dim, int):
443
+ attention_head_dim = (attention_head_dim,) * len(down_block_types)
444
+
445
+ if isinstance(cross_attention_dim, int):
446
+ cross_attention_dim = (cross_attention_dim,) * len(down_block_types)
447
+
448
+ if isinstance(layers_per_block, int):
449
+ layers_per_block = [layers_per_block] * len(down_block_types)
450
+
451
+ if isinstance(transformer_layers_per_block, int):
452
+ transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)
453
+
454
+ if class_embeddings_concat:
455
+ # The time embeddings are concatenated with the class embeddings. The dimension of the
456
+ # time embeddings passed to the down, middle, and up blocks is twice the dimension of the
457
+ # regular time embeddings
458
+ blocks_time_embed_dim = time_embed_dim * 2
459
+ else:
460
+ blocks_time_embed_dim = time_embed_dim
461
+
462
+ # down
463
+ output_channel = block_out_channels[0]
464
+ for i, down_block_type in enumerate(down_block_types):
465
+ input_channel = output_channel
466
+ output_channel = block_out_channels[i]
467
+ is_final_block = i == len(block_out_channels) - 1
468
+
469
+ down_block = get_down_block(
470
+ down_block_type,
471
+ num_layers=layers_per_block[i],
472
+ transformer_layers_per_block=transformer_layers_per_block[i],
473
+ in_channels=input_channel,
474
+ out_channels=output_channel,
475
+ temb_channels=blocks_time_embed_dim,
476
+ add_downsample=not is_final_block,
477
+ resnet_eps=norm_eps,
478
+ resnet_act_fn=act_fn,
479
+ resnet_groups=norm_num_groups,
480
+ cross_attention_dim=cross_attention_dim[i],
481
+ num_attention_heads=num_attention_heads[i],
482
+ downsample_padding=downsample_padding,
483
+ dual_cross_attention=dual_cross_attention,
484
+ use_linear_projection=use_linear_projection,
485
+ only_cross_attention=only_cross_attention[i],
486
+ upcast_attention=upcast_attention,
487
+ resnet_time_scale_shift=resnet_time_scale_shift,
488
+ attention_type=attention_type,
489
+ resnet_skip_time_act=resnet_skip_time_act,
490
+ resnet_out_scale_factor=resnet_out_scale_factor,
491
+ cross_attention_norm=cross_attention_norm,
492
+ attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
493
+ )
494
+ self.down_blocks.append(down_block)
495
+
496
+ # mid
497
+ if mid_block_type == "UNetMidBlock2DCrossAttn":
498
+ self.mid_block = UNetMidBlock2DCrossAttn(
499
+ transformer_layers_per_block=transformer_layers_per_block[-1],
500
+ in_channels=block_out_channels[-1],
501
+ temb_channels=blocks_time_embed_dim,
502
+ resnet_eps=norm_eps,
503
+ resnet_act_fn=act_fn,
504
+ output_scale_factor=mid_block_scale_factor,
505
+ resnet_time_scale_shift=resnet_time_scale_shift,
506
+ cross_attention_dim=cross_attention_dim[-1],
507
+ num_attention_heads=num_attention_heads[-1],
508
+ resnet_groups=norm_num_groups,
509
+ dual_cross_attention=dual_cross_attention,
510
+ use_linear_projection=use_linear_projection,
511
+ upcast_attention=upcast_attention,
512
+ attention_type=attention_type,
513
+ )
514
+ elif mid_block_type == "UNetMidBlock2DSimpleCrossAttn":
515
+ self.mid_block = UNetMidBlock2DSimpleCrossAttn(
516
+ in_channels=block_out_channels[-1],
517
+ temb_channels=blocks_time_embed_dim,
518
+ resnet_eps=norm_eps,
519
+ resnet_act_fn=act_fn,
520
+ output_scale_factor=mid_block_scale_factor,
521
+ cross_attention_dim=cross_attention_dim[-1],
522
+ attention_head_dim=attention_head_dim[-1],
523
+ resnet_groups=norm_num_groups,
524
+ resnet_time_scale_shift=resnet_time_scale_shift,
525
+ skip_time_act=resnet_skip_time_act,
526
+ only_cross_attention=mid_block_only_cross_attention,
527
+ cross_attention_norm=cross_attention_norm,
528
+ )
529
+ elif mid_block_type is None:
530
+ self.mid_block = None
531
+ else:
532
+ raise ValueError(f"unknown mid_block_type : {mid_block_type}")
533
+
534
+ # count how many layers upsample the images
535
+ self.num_upsamplers = 0
536
+
537
+ # up
538
+ reversed_block_out_channels = list(reversed(block_out_channels))
539
+ reversed_num_attention_heads = list(reversed(num_attention_heads))
540
+ reversed_layers_per_block = list(reversed(layers_per_block))
541
+ reversed_cross_attention_dim = list(reversed(cross_attention_dim))
542
+ reversed_transformer_layers_per_block = list(reversed(transformer_layers_per_block))
543
+ only_cross_attention = list(reversed(only_cross_attention))
544
+
545
+ output_channel = reversed_block_out_channels[0]
546
+ for i, up_block_type in enumerate(up_block_types):
547
+ is_final_block = i == len(block_out_channels) - 1
548
+
549
+ prev_output_channel = output_channel
550
+ output_channel = reversed_block_out_channels[i]
551
+ input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
552
+
553
+ # add upsample block for all BUT final layer
554
+ if not is_final_block:
555
+ add_upsample = True
556
+ self.num_upsamplers += 1
557
+ else:
558
+ add_upsample = False
559
+
560
+ up_block = get_up_block(
561
+ up_block_type,
562
+ num_layers=reversed_layers_per_block[i] + 1,
563
+ transformer_layers_per_block=reversed_transformer_layers_per_block[i],
564
+ in_channels=input_channel,
565
+ out_channels=output_channel,
566
+ prev_output_channel=prev_output_channel,
567
+ temb_channels=blocks_time_embed_dim,
568
+ add_upsample=add_upsample,
569
+ resnet_eps=norm_eps,
570
+ resnet_act_fn=act_fn,
571
+ resnet_groups=norm_num_groups,
572
+ cross_attention_dim=reversed_cross_attention_dim[i],
573
+ num_attention_heads=reversed_num_attention_heads[i],
574
+ dual_cross_attention=dual_cross_attention,
575
+ use_linear_projection=use_linear_projection,
576
+ only_cross_attention=only_cross_attention[i],
577
+ upcast_attention=upcast_attention,
578
+ resnet_time_scale_shift=resnet_time_scale_shift,
579
+ attention_type=attention_type,
580
+ resnet_skip_time_act=resnet_skip_time_act,
581
+ resnet_out_scale_factor=resnet_out_scale_factor,
582
+ cross_attention_norm=cross_attention_norm,
583
+ attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
584
+ )
585
+ self.up_blocks.append(up_block)
586
+ prev_output_channel = output_channel
587
+ self.up_blocks[3].attentions[2].transformer_blocks[0].attn1.to_q = _LoRACompatibleLinear()
588
+ self.up_blocks[3].attentions[2].transformer_blocks[0].attn1.to_k = _LoRACompatibleLinear()
589
+ self.up_blocks[3].attentions[2].transformer_blocks[0].attn1.to_v = _LoRACompatibleLinear()
590
+ self.up_blocks[3].attentions[2].transformer_blocks[0].attn1.to_out = nn.ModuleList([Identity(), Identity()])
591
+ self.up_blocks[3].attentions[2].transformer_blocks[0].norm2 = Identity()
592
+ self.up_blocks[3].attentions[2].transformer_blocks[0].attn2 = None
593
+ self.up_blocks[3].attentions[2].transformer_blocks[0].norm3 = Identity()
594
+ self.up_blocks[3].attentions[2].transformer_blocks[0].ff = Identity()
595
+ self.up_blocks[3].attentions[2].proj_out = Identity()
596
+
597
+ if attention_type in ["gated", "gated-text-image"]:
598
+ positive_len = 768
599
+ if isinstance(cross_attention_dim, int):
600
+ positive_len = cross_attention_dim
601
+ elif isinstance(cross_attention_dim, tuple) or isinstance(cross_attention_dim, list):
602
+ positive_len = cross_attention_dim[0]
603
+
604
+ feature_type = "text-only" if attention_type == "gated" else "text-image"
605
+ self.position_net = PositionNet(
606
+ positive_len=positive_len, out_dim=cross_attention_dim, feature_type=feature_type
607
+ )
608
+
609
+ @property
610
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
611
+ r"""
612
+ Returns:
613
+ `dict` of attention processors: A dictionary containing all attention processors used in the model,
614
+ indexed by its weight name.
615
+ """
616
+ # set recursively
617
+ processors = {}
618
+
619
+ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
620
+ if hasattr(module, "get_processor"):
621
+ processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True)
622
+
623
+ for sub_name, child in module.named_children():
624
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
625
+
626
+ return processors
627
+
628
+ for name, module in self.named_children():
629
+ fn_recursive_add_processors(name, module, processors)
630
+
631
+ return processors
632
+
633
+ def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
634
+ r"""
635
+ Sets the attention processor to use to compute attention.
636
+
637
+ Parameters:
638
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
639
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
640
+ for **all** `Attention` layers.
641
+
642
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
643
+ processor. This is strongly recommended when setting trainable attention processors.
644
+
645
+ """
646
+ count = len(self.attn_processors.keys())
647
+
648
+ if isinstance(processor, dict) and len(processor) != count:
649
+ raise ValueError(
650
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
651
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
652
+ )
653
+
654
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
655
+ if hasattr(module, "set_processor"):
656
+ if not isinstance(processor, dict):
657
+ module.set_processor(processor)
658
+ else:
659
+ module.set_processor(processor.pop(f"{name}.processor"))
660
+
661
+ for sub_name, child in module.named_children():
662
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
663
+
664
+ for name, module in self.named_children():
665
+ fn_recursive_attn_processor(name, module, processor)
666
+
667
+ def set_default_attn_processor(self):
668
+ """
669
+ Disables custom attention processors and sets the default attention implementation.
670
+ """
671
+ if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
672
+ processor = AttnAddedKVProcessor()
673
+ elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
674
+ processor = AttnProcessor()
675
+ else:
676
+ raise ValueError(
677
+ f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
678
+ )
679
+
680
+ self.set_attn_processor(processor)
681
+
682
+ def set_attention_slice(self, slice_size):
683
+ r"""
684
+ Enable sliced attention computation.
685
+
686
+ When this option is enabled, the attention module splits the input tensor in slices to compute attention in
687
+ several steps. This is useful for saving some memory in exchange for a small decrease in speed.
688
+
689
+ Args:
690
+ slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
691
+ When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If
692
+ `"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is
693
+ provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
694
+ must be a multiple of `slice_size`.
695
+ """
696
+ sliceable_head_dims = []
697
+
698
+ def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module):
699
+ if hasattr(module, "set_attention_slice"):
700
+ sliceable_head_dims.append(module.sliceable_head_dim)
701
+
702
+ for child in module.children():
703
+ fn_recursive_retrieve_sliceable_dims(child)
704
+
705
+ # retrieve number of attention layers
706
+ for module in self.children():
707
+ fn_recursive_retrieve_sliceable_dims(module)
708
+
709
+ num_sliceable_layers = len(sliceable_head_dims)
710
+
711
+ if slice_size == "auto":
712
+ # half the attention head size is usually a good trade-off between
713
+ # speed and memory
714
+ slice_size = [dim // 2 for dim in sliceable_head_dims]
715
+ elif slice_size == "max":
716
+ # make smallest slice possible
717
+ slice_size = num_sliceable_layers * [1]
718
+
719
+ slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
720
+
721
+ if len(slice_size) != len(sliceable_head_dims):
722
+ raise ValueError(
723
+ f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
724
+ f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
725
+ )
726
+
727
+ for i in range(len(slice_size)):
728
+ size = slice_size[i]
729
+ dim = sliceable_head_dims[i]
730
+ if size is not None and size > dim:
731
+ raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
732
+
733
+ # Recursively walk through all the children.
734
+ # Any children which exposes the set_attention_slice method
735
+ # gets the message
736
+ def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
737
+ if hasattr(module, "set_attention_slice"):
738
+ module.set_attention_slice(slice_size.pop())
739
+
740
+ for child in module.children():
741
+ fn_recursive_set_attention_slice(child, slice_size)
742
+
743
+ reversed_slice_size = list(reversed(slice_size))
744
+ for module in self.children():
745
+ fn_recursive_set_attention_slice(module, reversed_slice_size)
746
+
747
+ def _set_gradient_checkpointing(self, module, value=False):
748
+ if hasattr(module, "gradient_checkpointing"):
749
+ module.gradient_checkpointing = value
750
+
751
+ def forward(
752
+ self,
753
+ sample: torch.FloatTensor,
754
+ timestep: Union[torch.Tensor, float, int],
755
+ encoder_hidden_states: torch.Tensor,
756
+ class_labels: Optional[torch.Tensor] = None,
757
+ timestep_cond: Optional[torch.Tensor] = None,
758
+ attention_mask: Optional[torch.Tensor] = None,
759
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
760
+ added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
761
+ down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
762
+ mid_block_additional_residual: Optional[torch.Tensor] = None,
763
+ encoder_attention_mask: Optional[torch.Tensor] = None,
764
+ return_dict: bool = True,
765
+ add_time_ids: Optional[List] = None,
766
+ ) -> Union[UNet2DConditionOutput, Tuple]:
767
+ r"""
768
+ The [`UNet2DConditionModel`] forward method.
769
+
770
+ Args:
771
+ sample (`torch.FloatTensor`):
772
+ The noisy input tensor with the following shape `(batch, channel, height, width)`.
773
+ timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input.
774
+ encoder_hidden_states (`torch.FloatTensor`):
775
+ The encoder hidden states with shape `(batch, sequence_length, feature_dim)`.
776
+ encoder_attention_mask (`torch.Tensor`):
777
+ A cross-attention mask of shape `(batch, sequence_length)` is applied to `encoder_hidden_states`. If
778
+ `True` the mask is kept, otherwise if `False` it is discarded. Mask will be converted into a bias,
779
+ which adds large negative values to the attention scores corresponding to "discard" tokens.
780
+ return_dict (`bool`, *optional*, defaults to `True`):
781
+ Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
782
+ tuple.
783
+ cross_attention_kwargs (`dict`, *optional*):
784
+ A kwargs dictionary that if specified is passed along to the [`AttnProcessor`].
785
+ added_cond_kwargs: (`dict`, *optional*):
786
+ A kwargs dictionary containing additional embeddings that, if specified, are added to the embeddings that
787
+ are passed along to the UNet blocks.
788
+
789
+ Returns:
790
+ [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
791
+ If `return_dict` is True, an [`~models.unet_2d_condition.UNet2DConditionOutput`] is returned, otherwise
792
+ a `tuple` is returned where the first element is the sample tensor.
793
+ """
794
+ # By default samples have to be at least a multiple of the overall upsampling factor.
795
+ # The overall upsampling factor is equal to 2 ** (# num of upsampling layers).
796
+ # However, the upsampling interpolation output size can be forced to fit any upsampling size
797
+ # on the fly if necessary.
798
+ default_overall_up_factor = 2**self.num_upsamplers
799
+
800
+ # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
801
+ forward_upsample_size = False
802
+ upsample_size = None
803
+
804
+ if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
805
+ logger.info("Forward upsample size to force interpolation output size.")
806
+ forward_upsample_size = True
807
+
808
+ if attention_mask is not None:
809
+ attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
810
+ attention_mask = attention_mask.unsqueeze(1)
811
+
812
+ # convert encoder_attention_mask to a bias the same way we do for attention_mask
813
+ if encoder_attention_mask is not None:
814
+ encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0
815
+ encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
816
+
817
+ # 0. center input if necessary
818
+ if self.config.center_input_sample:
819
+ sample = 2 * sample - 1.0
820
+
821
+ # 1. time
822
+ timesteps = timestep
823
+ if not torch.is_tensor(timesteps):
824
+ # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
825
+ # This would be a good case for the `match` statement (Python 3.10+)
826
+ is_mps = sample.device.type == "mps"
827
+ if isinstance(timestep, float):
828
+ dtype = torch.float32 if is_mps else torch.float64
829
+ else:
830
+ dtype = torch.int32 if is_mps else torch.int64
831
+ timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
832
+ elif len(timesteps.shape) == 0:
833
+ timesteps = timesteps[None].to(sample.device)
834
+
835
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
836
+ timesteps = timesteps.expand(sample.shape[0])
837
+
838
+ t_emb = self.time_proj(timesteps)
839
+
840
+ # `Timesteps` does not contain any weights and will always return f32 tensors
841
+ # but time_embedding might actually be running in fp16. so we need to cast here.
842
+ # there might be better ways to encapsulate this.
843
+ t_emb = t_emb.to(dtype=sample.dtype)
844
+
845
+ emb = self.time_embedding(t_emb, timestep_cond)
846
+ aug_emb = None
847
+
848
+ if self.class_embedding is not None:
849
+ if class_labels is None:
850
+ raise ValueError("class_labels should be provided when num_class_embeds > 0")
851
+
852
+ if self.config.class_embed_type == "timestep":
853
+ class_labels = self.time_proj(class_labels)
854
+
855
+ # `Timesteps` does not contain any weights and will always return f32 tensors
856
+ # there might be better ways to encapsulate this.
857
+ class_labels = class_labels.to(dtype=sample.dtype)
858
+
859
+ class_emb = self.class_embedding(class_labels).to(dtype=sample.dtype)
860
+
861
+ if self.config.class_embeddings_concat:
862
+ emb = torch.cat([emb, class_emb], dim=-1)
863
+ else:
864
+ emb = emb + class_emb
865
+
866
+ if self.config.addition_embed_type == "text":
867
+ aug_emb = self.add_embedding(encoder_hidden_states)
868
+ elif self.config.addition_embed_type == "text_image":
869
+ # Kandinsky 2.1 - style
870
+ if "image_embeds" not in added_cond_kwargs:
871
+ raise ValueError(
872
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
873
+ )
874
+
875
+ image_embs = added_cond_kwargs.get("image_embeds")
876
+ text_embs = added_cond_kwargs.get("text_embeds", encoder_hidden_states)
877
+ aug_emb = self.add_embedding(text_embs, image_embs)
878
+ elif self.config.addition_embed_type == "text_time":
879
+ # SDXL - style
880
+ if "text_embeds" not in added_cond_kwargs:
881
+ raise ValueError(
882
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
883
+ )
884
+ text_embeds = added_cond_kwargs.get("text_embeds")
885
+ if "time_ids" not in added_cond_kwargs:
886
+ raise ValueError(
887
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
888
+ )
889
+ time_ids = added_cond_kwargs.get("time_ids")
890
+ time_embeds = self.add_time_proj(time_ids.flatten())
891
+ time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))
892
+
893
+ add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
894
+ add_embeds = add_embeds.to(emb.dtype)
895
+ aug_emb = self.add_embedding(add_embeds)
896
+ elif self.config.addition_embed_type == "image":
897
+ # Kandinsky 2.2 - style
898
+ if "image_embeds" not in added_cond_kwargs:
899
+ raise ValueError(
900
+ f"{self.__class__} has the config param `addition_embed_type` set to 'image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
901
+ )
902
+ image_embs = added_cond_kwargs.get("image_embeds")
903
+ aug_emb = self.add_embedding(image_embs)
904
+ elif self.config.addition_embed_type == "image_hint":
905
+ # Kandinsky 2.2 - style
906
+ if "image_embeds" not in added_cond_kwargs or "hint" not in added_cond_kwargs:
907
+ raise ValueError(
908
+ f"{self.__class__} has the config param `addition_embed_type` set to 'image_hint' which requires the keyword arguments `image_embeds` and `hint` to be passed in `added_cond_kwargs`"
909
+ )
910
+ image_embs = added_cond_kwargs.get("image_embeds")
911
+ hint = added_cond_kwargs.get("hint")
912
+ aug_emb, hint = self.add_embedding(image_embs, hint)
913
+ sample = torch.cat([sample, hint], dim=1)
914
+
915
+ emb = emb + aug_emb if aug_emb is not None else emb
916
+
917
+ if add_time_ids is not None:
918
+ cond_aug, polars, azimuths = add_time_ids
919
+ cond_aug_emb = self.add_time_proj(cond_aug.flatten())
920
+ polars_emb = self.add_angle_proj(polars.flatten())
921
+ azimuths_emb = self.add_angle_proj(azimuths.flatten())
922
+ time_embeds = torch.cat([cond_aug_emb, polars_emb, azimuths_emb],dim=1)
923
+ time_embeds = time_embeds.to(emb.dtype)
924
+ aug_emb = self.add_embedding(time_embeds)
925
+ emb = emb + aug_emb
926
+
927
+ if self.time_embed_act is not None:
928
+ emb = self.time_embed_act(emb)
929
+
930
+ if self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_proj":
931
+ encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states)
932
+ elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_image_proj":
933
+ # Kandinsky 2.1 - style
934
+ if "image_embeds" not in added_cond_kwargs:
935
+ raise ValueError(
936
+ f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'text_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
937
+ )
938
+
939
+ image_embeds = added_cond_kwargs.get("image_embeds")
940
+ encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states, image_embeds)
941
+ elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "image_proj":
942
+ # Kandinsky 2.2 - style
943
+ if "image_embeds" not in added_cond_kwargs:
944
+ raise ValueError(
945
+ f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
946
+ )
947
+ image_embeds = added_cond_kwargs.get("image_embeds")
948
+ encoder_hidden_states = self.encoder_hid_proj(image_embeds)
949
+ # 2. pre-process
950
+ sample = self.conv_in(sample)
951
+
952
+ # 2.5 GLIGEN position net
953
+ if cross_attention_kwargs is not None and cross_attention_kwargs.get("gligen", None) is not None:
954
+ cross_attention_kwargs = cross_attention_kwargs.copy()
955
+ gligen_args = cross_attention_kwargs.pop("gligen")
956
+ cross_attention_kwargs["gligen"] = {"objs": self.position_net(**gligen_args)}
957
+
958
+ # 3. down
959
+
960
+ is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None
961
+ is_adapter = mid_block_additional_residual is None and down_block_additional_residuals is not None
962
+
963
+ down_block_res_samples = (sample,)
964
+ for downsample_block in self.down_blocks:
965
+ if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
966
+ # For t2i-adapter CrossAttnDownBlock2D
967
+ additional_residuals = {}
968
+ if is_adapter and len(down_block_additional_residuals) > 0:
969
+ additional_residuals["additional_residuals"] = down_block_additional_residuals.pop(0)
970
+
971
+ sample, res_samples = downsample_block(
972
+ hidden_states=sample,
973
+ temb=emb,
974
+ encoder_hidden_states=encoder_hidden_states,
975
+ attention_mask=attention_mask,
976
+ cross_attention_kwargs=cross_attention_kwargs,
977
+ encoder_attention_mask=encoder_attention_mask,
978
+ **additional_residuals,
979
+ )
980
+ else:
981
+ sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
982
+
983
+ if is_adapter and len(down_block_additional_residuals) > 0:
984
+ sample += down_block_additional_residuals.pop(0)
985
+
986
+ down_block_res_samples += res_samples
987
+
988
+ if is_controlnet:
989
+ new_down_block_res_samples = ()
990
+
991
+ for down_block_res_sample, down_block_additional_residual in zip(
992
+ down_block_res_samples, down_block_additional_residuals
993
+ ):
994
+ down_block_res_sample = down_block_res_sample + down_block_additional_residual
995
+ new_down_block_res_samples = new_down_block_res_samples + (down_block_res_sample,)
996
+
997
+ down_block_res_samples = new_down_block_res_samples
998
+
999
+ # 4. mid
1000
+ if self.mid_block is not None:
1001
+ sample = self.mid_block(
1002
+ sample,
1003
+ emb,
1004
+ encoder_hidden_states=encoder_hidden_states,
1005
+ attention_mask=attention_mask,
1006
+ cross_attention_kwargs=cross_attention_kwargs,
1007
+ encoder_attention_mask=encoder_attention_mask,
1008
+ )
1009
+ # To support T2I-Adapter-XL
1010
+ if (
1011
+ is_adapter
1012
+ and len(down_block_additional_residuals) > 0
1013
+ and sample.shape == down_block_additional_residuals[0].shape
1014
+ ):
1015
+ sample += down_block_additional_residuals.pop(0)
1016
+
1017
+ if is_controlnet:
1018
+ sample = sample + mid_block_additional_residual
1019
+
1020
+ # 5. up
1021
+ for i, upsample_block in enumerate(self.up_blocks):
1022
+ is_final_block = i == len(self.up_blocks) - 1
1023
+
1024
+ res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
1025
+ down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
1026
+
1027
+ # if we have not reached the final block and need to forward the
1028
+ # upsample size, we do it here
1029
+ if not is_final_block and forward_upsample_size:
1030
+
1031
+ upsample_size = down_block_res_samples[-1].shape[2:]
1032
+
1033
+ if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
1034
+
1035
+ sample = upsample_block(
1036
+ hidden_states=sample,
1037
+ temb=emb,
1038
+ res_hidden_states_tuple=res_samples,
1039
+ encoder_hidden_states=encoder_hidden_states,
1040
+ cross_attention_kwargs=cross_attention_kwargs,
1041
+ upsample_size=upsample_size,
1042
+ attention_mask=attention_mask,
1043
+ encoder_attention_mask=encoder_attention_mask,
1044
+ )
1045
+ else:
1046
+ sample = upsample_block(
1047
+ hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size
1048
+ )
1049
+
1050
+ if not return_dict:
1051
+ return (sample,)
1052
+
1053
+ return UNet2DConditionOutput(sample=sample)
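
Note on the `forward` above: it keeps the calling convention of diffusers' `UNet2DConditionModel` (noisy `sample`, `timestep`, `encoder_hidden_states`, plus optional ControlNet/adapter residuals), extended here with an optional `add_time_ids` branch that additionally expects `add_angle_proj` / `add_embedding` modules on the model. The snippet below is a minimal sketch of that shared calling convention only; it drives a tiny stock `UNet2DConditionModel` from diffusers as an illustration and is not an invocation of the reference UNet defined in this file (whose default config also strips the last up-block attention).

# Minimal sketch (assumption: illustrates the shared forward() signature only,
# using a small stock diffusers UNet2DConditionModel rather than this repo's reference UNet).
import torch
from diffusers import UNet2DConditionModel

unet = UNet2DConditionModel(
    sample_size=32,
    in_channels=4,
    out_channels=4,
    layers_per_block=1,
    block_out_channels=(32, 64),
    down_block_types=("CrossAttnDownBlock2D", "DownBlock2D"),
    up_block_types=("UpBlock2D", "CrossAttnUpBlock2D"),
    cross_attention_dim=32,
)

sample = torch.randn(1, 4, 32, 32)              # noisy latents: (batch, channel, height, width)
timestep = torch.tensor([10])                   # denoising timestep
encoder_hidden_states = torch.randn(1, 77, 32)  # conditioning tokens: (batch, seq_len, cross_attention_dim)

out = unet(sample, timestep, encoder_hidden_states=encoder_hidden_states)
print(out.sample.shape)                         # -> torch.Size([1, 4, 32, 32])
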
ref_encoder/reference_unetv2.py ADDED
@@ -0,0 +1,1037 @@
1
+ from dataclasses import dataclass
2
+ from typing import Any, Dict, List, Optional, Tuple, Union
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.utils.checkpoint
7
+
8
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
9
+ from diffusers.loaders import UNet2DConditionLoadersMixin
10
+ from diffusers.utils import BaseOutput, logging
11
+ from diffusers.models.activations import get_activation
12
+ from diffusers.models.attention_processor import (
13
+ ADDED_KV_ATTENTION_PROCESSORS,
14
+ CROSS_ATTENTION_PROCESSORS,
15
+ AttentionProcessor,
16
+ AttnAddedKVProcessor,
17
+ AttnProcessor,
18
+ )
19
+ from diffusers.models.lora import LoRALinearLayer
20
+ from diffusers.models.embeddings import (
21
+ GaussianFourierProjection,
22
+ ImageHintTimeEmbedding,
23
+ ImageProjection,
24
+ ImageTimeEmbedding,
25
+ PositionNet,
26
+ TextImageProjection,
27
+ TextImageTimeEmbedding,
28
+ TextTimeEmbedding,
29
+ TimestepEmbedding,
30
+ Timesteps,
31
+ )
32
+ from diffusers.models.modeling_utils import ModelMixin
33
+ from diffusers.models.unet_2d_blocks import (
34
+ UNetMidBlock2DCrossAttn,
35
+ UNetMidBlock2DSimpleCrossAttn,
36
+ get_down_block,
37
+ get_up_block,
38
+ )
39
+
40
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
41
+
42
+
43
+
44
+ class CCProjection(ModelMixin, ConfigMixin):
45
+ def __init__(self, in_channel=772, out_channel=768):
46
+ super().__init__()
47
+ self.in_channel = in_channel
48
+ self.out_channel = out_channel
49
+ self.projection = torch.nn.Linear(in_channel, out_channel)
50
+
51
+ def forward(self, x):
52
+ return self.projection(x)
53
+
54
+
55
+ class Identity(torch.nn.Module):
56
+ def __init__(self, scale=None, *args, **kwargs) -> None:
57
+ super(Identity, self).__init__()
58
+ def forward(self, input, *args, **kwargs):
59
+ return input
60
+
61
+
62
+ class _LoRACompatibleLinear(nn.Module):
63
+ """
64
+ A no-op stand-in for `LoRACompatibleLinear`: it keeps the LoRA-layer API but its forward returns the input unchanged.
65
+ """
66
+
67
+ def __init__(self, *args, lora_layer: Optional[LoRALinearLayer] = None, **kwargs):
68
+ super().__init__(*args, **kwargs)
69
+ self.lora_layer = lora_layer
70
+
71
+ def set_lora_layer(self, lora_layer: Optional[LoRALinearLayer]):
72
+ self.lora_layer = lora_layer
73
+
74
+ def _fuse_lora(self):
75
+ pass
76
+
77
+ def _unfuse_lora(self):
78
+ pass
79
+
80
+ def forward(self, hidden_states, scale=None, lora_scale: int = 1):
81
+ return hidden_states
82
+
83
+
84
+ @dataclass
85
+ class UNet2DConditionOutput(BaseOutput):
86
+ """
87
+ The output of [`UNet2DConditionModel`].
88
+
89
+ Args:
90
+ sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
91
+ The hidden states output conditioned on `encoder_hidden_states` input. Output of last layer of model.
92
+ """
93
+
94
+ sample: torch.FloatTensor = None
95
+
96
+
97
+ class ref_unet(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
98
+ r"""
99
+ A conditional 2D UNet model that takes a noisy sample, conditional state, and a timestep and returns a sample
100
+ shaped output.
101
+
102
+ This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented
103
+ for all models (such as downloading or saving).
104
+
105
+ Parameters:
106
+ sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):
107
+ Height and width of input/output sample.
108
+ in_channels (`int`, *optional*, defaults to 4): Number of channels in the input sample.
109
+ out_channels (`int`, *optional*, defaults to 4): Number of channels in the output.
110
+ center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample.
111
+ flip_sin_to_cos (`bool`, *optional*, defaults to `True`):
112
+ Whether to flip the sin to cos in the time embedding.
113
+ freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding.
114
+ down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
115
+ The tuple of downsample blocks to use.
116
+ mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2DCrossAttn"`):
117
+ Block type for middle of UNet, it can be either `UNetMidBlock2DCrossAttn` or
118
+ `UNetMidBlock2DSimpleCrossAttn`. If `None`, the mid block layer is skipped.
119
+ up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D")`):
120
+ The tuple of upsample blocks to use.
121
+ only_cross_attention(`bool` or `Tuple[bool]`, *optional*, default to `False`):
122
+ Whether to include self-attention in the basic transformer blocks, see
123
+ [`~models.attention.BasicTransformerBlock`].
124
+ block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
125
+ The tuple of output channels for each block.
126
+ layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block.
127
+ downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution.
128
+ mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block.
129
+ act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
130
+ norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization.
131
+ If `None`, normalization and activation layers is skipped in post-processing.
132
+ norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization.
133
+ cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1280):
134
+ The dimension of the cross attention features.
135
+ transformer_layers_per_block (`int` or `Tuple[int]`, *optional*, defaults to 1):
136
+ The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for
137
+ [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`],
138
+ [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
139
+ encoder_hid_dim (`int`, *optional*, defaults to None):
140
+ If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim`
141
+ dimension to `cross_attention_dim`.
142
+ encoder_hid_dim_type (`str`, *optional*, defaults to `None`):
143
+ If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text
144
+ embeddings of dimension `cross_attention_dim` according to `encoder_hid_dim_type`.
145
+ attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads.
146
+ num_attention_heads (`int`, *optional*):
147
+ The number of attention heads. If not defined, defaults to `attention_head_dim`
148
+ resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config
149
+ for ResNet blocks (see [`~models.resnet.ResnetBlock2D`]). Choose from `default` or `scale_shift`.
150
+ class_embed_type (`str`, *optional*, defaults to `None`):
151
+ The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`,
152
+ `"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`.
153
+ addition_embed_type (`str`, *optional*, defaults to `None`):
154
+ Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or
155
+ "text". "text" will use the `TextTimeEmbedding` layer.
156
+ addition_time_embed_dim: (`int`, *optional*, defaults to `None`):
157
+ Dimension for the timestep embeddings.
158
+ num_class_embeds (`int`, *optional*, defaults to `None`):
159
+ Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
160
+ class conditioning with `class_embed_type` equal to `None`.
161
+ time_embedding_type (`str`, *optional*, defaults to `positional`):
162
+ The type of position embedding to use for timesteps. Choose from `positional` or `fourier`.
163
+ time_embedding_dim (`int`, *optional*, defaults to `None`):
164
+ An optional override for the dimension of the projected time embedding.
165
+ time_embedding_act_fn (`str`, *optional*, defaults to `None`):
166
+ Optional activation function to use only once on the time embeddings before they are passed to the rest of
167
+ the UNet. Choose from `silu`, `mish`, `gelu`, and `swish`.
168
+ timestep_post_act (`str`, *optional*, defaults to `None`):
169
+ The second activation function to use in timestep embedding. Choose from `silu`, `mish` and `gelu`.
170
+ time_cond_proj_dim (`int`, *optional*, defaults to `None`):
171
+ The dimension of `cond_proj` layer in the timestep embedding.
172
+ conv_in_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_in` layer.
173
+ conv_out_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_out` layer.
174
+ projection_class_embeddings_input_dim (`int`, *optional*): The dimension of the `class_labels` input when
175
+ `class_embed_type="projection"`. Required when `class_embed_type="projection"`.
176
+ class_embeddings_concat (`bool`, *optional*, defaults to `False`): Whether to concatenate the time
177
+ embeddings with the class embeddings.
178
+ mid_block_only_cross_attention (`bool`, *optional*, defaults to `None`):
179
+ Whether to use cross attention with the mid block when using the `UNetMidBlock2DSimpleCrossAttn`. If
180
+ `only_cross_attention` is given as a single boolean and `mid_block_only_cross_attention` is `None`, the
181
+ `only_cross_attention` value is used as the value for `mid_block_only_cross_attention`. Defaults to `False`
182
+ otherwise.
183
+ """
184
+
185
+ _supports_gradient_checkpointing = True
186
+
187
+ @register_to_config
188
+ def __init__(
189
+ self,
190
+ sample_size: Optional[int] = None,
191
+ in_channels: int = 4,
192
+ out_channels: int = 4,
193
+ center_input_sample: bool = False,
194
+ flip_sin_to_cos: bool = True,
195
+ freq_shift: int = 0,
196
+ down_block_types: Tuple[str] = (
197
+ "CrossAttnDownBlock2D",
198
+ "CrossAttnDownBlock2D",
199
+ "CrossAttnDownBlock2D",
200
+ "DownBlock2D",
201
+ ),
202
+ mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn",
203
+ up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"),
204
+ only_cross_attention: Union[bool, Tuple[bool]] = False,
205
+ block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
206
+ layers_per_block: Union[int, Tuple[int]] = 2,
207
+ downsample_padding: int = 1,
208
+ mid_block_scale_factor: float = 1,
209
+ act_fn: str = "silu",
210
+ norm_num_groups: Optional[int] = 32,
211
+ norm_eps: float = 1e-5,
212
+ cross_attention_dim: Union[int, Tuple[int]] = 1280,
213
+ transformer_layers_per_block: Union[int, Tuple[int]] = 1,
214
+ encoder_hid_dim: Optional[int] = None,
215
+ encoder_hid_dim_type: Optional[str] = None,
216
+ attention_head_dim: Union[int, Tuple[int]] = 8,
217
+ num_attention_heads: Optional[Union[int, Tuple[int]]] = None,
218
+ dual_cross_attention: bool = False,
219
+ use_linear_projection: bool = False,
220
+ class_embed_type: Optional[str] = None,
221
+ addition_embed_type: Optional[str] = None,
222
+ addition_time_embed_dim: Optional[int] = None,
223
+ num_class_embeds: Optional[int] = None,
224
+ upcast_attention: bool = False,
225
+ resnet_time_scale_shift: str = "default",
226
+ resnet_skip_time_act: bool = False,
227
+ resnet_out_scale_factor: float = 1.0,
228
+ time_embedding_type: str = "positional",
229
+ time_embedding_dim: Optional[int] = None,
230
+ time_embedding_act_fn: Optional[str] = None,
231
+ timestep_post_act: Optional[str] = None,
232
+ time_cond_proj_dim: Optional[int] = None,
233
+ conv_in_kernel: int = 3,
234
+ conv_out_kernel: int = 3,
235
+ projection_class_embeddings_input_dim: Optional[int] = None,
236
+ attention_type: str = "default",
237
+ class_embeddings_concat: bool = False,
238
+ mid_block_only_cross_attention: Optional[bool] = None,
239
+ cross_attention_norm: Optional[str] = None,
240
+ addition_embed_type_num_heads=64,
241
+ ):
242
+ super().__init__()
243
+
244
+ self.sample_size = sample_size
245
+
246
+ if num_attention_heads is not None:
247
+ raise ValueError(
248
+ "At the moment it is not possible to define the number of attention heads via `num_attention_heads` because of a naming issue as described in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131. Passing `num_attention_heads` will only be supported in diffusers v0.19."
249
+ )
250
+
251
+ # If `num_attention_heads` is not defined (which is the case for most models)
252
+ # it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
253
+ # The reason for this behavior is to correct for incorrectly named variables that were introduced
254
+ # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
255
+ # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
256
+ # which is why we correct for the naming here.
257
+ num_attention_heads = num_attention_heads or attention_head_dim
258
+
259
+ # Check inputs
260
+ if len(down_block_types) != len(up_block_types):
261
+ raise ValueError(
262
+ f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}."
263
+ )
264
+
265
+ if len(block_out_channels) != len(down_block_types):
266
+ raise ValueError(
267
+ f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
268
+ )
269
+
270
+ if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types):
271
+ raise ValueError(
272
+ f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
273
+ )
274
+
275
+ if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
276
+ raise ValueError(
277
+ f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
278
+ )
279
+
280
+ if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(down_block_types):
281
+ raise ValueError(
282
+ f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}."
283
+ )
284
+
285
+ if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(down_block_types):
286
+ raise ValueError(
287
+ f"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}."
288
+ )
289
+
290
+ if not isinstance(layers_per_block, int) and len(layers_per_block) != len(down_block_types):
291
+ raise ValueError(
292
+ f"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}."
293
+ )
294
+
295
+ # input
296
+ conv_in_padding = (conv_in_kernel - 1) // 2
297
+ self.conv_in = nn.Conv2d(
298
+ in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding
299
+ )
300
+
301
+ # time
302
+ if time_embedding_type == "fourier":
303
+ time_embed_dim = time_embedding_dim or block_out_channels[0] * 2
304
+ if time_embed_dim % 2 != 0:
305
+ raise ValueError(f"`time_embed_dim` should be divisible by 2, but is {time_embed_dim}.")
306
+ self.time_proj = GaussianFourierProjection(
307
+ time_embed_dim // 2, set_W_to_weight=False, log=False, flip_sin_to_cos=flip_sin_to_cos
308
+ )
309
+ timestep_input_dim = time_embed_dim
310
+ elif time_embedding_type == "positional":
311
+ time_embed_dim = time_embedding_dim or block_out_channels[0] * 4
312
+
313
+ self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
314
+ timestep_input_dim = block_out_channels[0]
315
+ else:
316
+ raise ValueError(
317
+ f"{time_embedding_type} does not exist. Please make sure to use one of `fourier` or `positional`."
318
+ )
319
+
320
+ self.time_embedding = TimestepEmbedding(
321
+ timestep_input_dim,
322
+ time_embed_dim,
323
+ act_fn=act_fn,
324
+ post_act_fn=timestep_post_act,
325
+ cond_proj_dim=time_cond_proj_dim,
326
+ )
327
+
328
+ if encoder_hid_dim_type is None and encoder_hid_dim is not None:
329
+ encoder_hid_dim_type = "text_proj"
330
+ self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type)
331
+ logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.")
332
+
333
+ if encoder_hid_dim is None and encoder_hid_dim_type is not None:
334
+ raise ValueError(
335
+ f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}."
336
+ )
337
+
338
+ if encoder_hid_dim_type == "text_proj":
339
+ self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim)
340
+ elif encoder_hid_dim_type == "text_image_proj":
341
+ # image_embed_dim DOESN'T have to be `cross_attention_dim`. To not clutter the __init__ too much
342
+ # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
343
+ # case when `addition_embed_type == "text_image_proj"` (Kandinsky 2.1)
344
+ self.encoder_hid_proj = TextImageProjection(
345
+ text_embed_dim=encoder_hid_dim,
346
+ image_embed_dim=cross_attention_dim,
347
+ cross_attention_dim=cross_attention_dim,
348
+ )
349
+ elif encoder_hid_dim_type == "image_proj":
350
+ # Kandinsky 2.2
351
+ self.encoder_hid_proj = ImageProjection(
352
+ image_embed_dim=encoder_hid_dim,
353
+ cross_attention_dim=cross_attention_dim,
354
+ )
355
+ elif encoder_hid_dim_type is not None:
356
+ raise ValueError(
357
+ f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj' or 'text_image_proj'."
358
+ )
359
+ else:
360
+ self.encoder_hid_proj = None
361
+
362
+ # class embedding
363
+ if class_embed_type is None and num_class_embeds is not None:
364
+ self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
365
+ elif class_embed_type == "timestep":
366
+ self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim, act_fn=act_fn)
367
+ elif class_embed_type == "identity":
368
+ self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
369
+ elif class_embed_type == "projection":
370
+ if projection_class_embeddings_input_dim is None:
371
+ raise ValueError(
372
+ "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set"
373
+ )
374
+ # The projection `class_embed_type` is the same as the timestep `class_embed_type` except
375
+ # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings
376
+ # 2. it projects from an arbitrary input dimension.
377
+ #
378
+ # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations.
379
+ # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings.
380
+ # As a result, `TimestepEmbedding` can be passed arbitrary vectors.
381
+ self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
382
+ elif class_embed_type == "simple_projection":
383
+ if projection_class_embeddings_input_dim is None:
384
+ raise ValueError(
385
+ "`class_embed_type`: 'simple_projection' requires `projection_class_embeddings_input_dim` be set"
386
+ )
387
+ self.class_embedding = nn.Linear(projection_class_embeddings_input_dim, time_embed_dim)
388
+ else:
389
+ self.class_embedding = None
390
+
391
+ if addition_embed_type == "text":
392
+ if encoder_hid_dim is not None:
393
+ text_time_embedding_from_dim = encoder_hid_dim
394
+ else:
395
+ text_time_embedding_from_dim = cross_attention_dim
396
+
397
+ self.add_embedding = TextTimeEmbedding(
398
+ text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads
399
+ )
400
+ elif addition_embed_type == "text_image":
401
+ # text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the __init__ too much
402
+ # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
403
+ # case when `addition_embed_type == "text_image"` (Kandinsky 2.1)
404
+ self.add_embedding = TextImageTimeEmbedding(
405
+ text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim
406
+ )
407
+ elif addition_embed_type == "text_time":
408
+ self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift)
409
+ self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
410
+ elif addition_embed_type == "image":
411
+ # Kandinsky 2.2
412
+ self.add_embedding = ImageTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim)
413
+ elif addition_embed_type == "image_hint":
414
+ # Kandinsky 2.2 ControlNet
415
+ self.add_embedding = ImageHintTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim)
416
+ elif addition_embed_type is not None:
417
+ raise ValueError(f"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'.")
418
+
419
+ if time_embedding_act_fn is None:
420
+ self.time_embed_act = None
421
+ else:
422
+ self.time_embed_act = get_activation(time_embedding_act_fn)
423
+
424
+ self.down_blocks = nn.ModuleList([])
425
+ self.up_blocks = nn.ModuleList([])
426
+
427
+ if isinstance(only_cross_attention, bool):
428
+ if mid_block_only_cross_attention is None:
429
+ mid_block_only_cross_attention = only_cross_attention
430
+
431
+ only_cross_attention = [only_cross_attention] * len(down_block_types)
432
+
433
+ if mid_block_only_cross_attention is None:
434
+ mid_block_only_cross_attention = False
435
+
436
+ if isinstance(num_attention_heads, int):
437
+ num_attention_heads = (num_attention_heads,) * len(down_block_types)
438
+
439
+ if isinstance(attention_head_dim, int):
440
+ attention_head_dim = (attention_head_dim,) * len(down_block_types)
441
+
442
+ if isinstance(cross_attention_dim, int):
443
+ cross_attention_dim = (cross_attention_dim,) * len(down_block_types)
444
+
445
+ if isinstance(layers_per_block, int):
446
+ layers_per_block = [layers_per_block] * len(down_block_types)
447
+
448
+ if isinstance(transformer_layers_per_block, int):
449
+ transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)
450
+
451
+ if class_embeddings_concat:
452
+ # The time embeddings are concatenated with the class embeddings. The dimension of the
453
+ # time embeddings passed to the down, middle, and up blocks is twice the dimension of the
454
+ # regular time embeddings
455
+ blocks_time_embed_dim = time_embed_dim * 2
456
+ else:
457
+ blocks_time_embed_dim = time_embed_dim
458
+
459
+ # down
460
+ output_channel = block_out_channels[0]
461
+ for i, down_block_type in enumerate(down_block_types):
462
+ input_channel = output_channel
463
+ output_channel = block_out_channels[i]
464
+ is_final_block = i == len(block_out_channels) - 1
465
+
466
+ down_block = get_down_block(
467
+ down_block_type,
468
+ num_layers=layers_per_block[i],
469
+ transformer_layers_per_block=transformer_layers_per_block[i],
470
+ in_channels=input_channel,
471
+ out_channels=output_channel,
472
+ temb_channels=blocks_time_embed_dim,
473
+ add_downsample=not is_final_block,
474
+ resnet_eps=norm_eps,
475
+ resnet_act_fn=act_fn,
476
+ resnet_groups=norm_num_groups,
477
+ cross_attention_dim=cross_attention_dim[i],
478
+ num_attention_heads=num_attention_heads[i],
479
+ downsample_padding=downsample_padding,
480
+ dual_cross_attention=dual_cross_attention,
481
+ use_linear_projection=use_linear_projection,
482
+ only_cross_attention=only_cross_attention[i],
483
+ upcast_attention=upcast_attention,
484
+ resnet_time_scale_shift=resnet_time_scale_shift,
485
+ attention_type=attention_type,
486
+ resnet_skip_time_act=resnet_skip_time_act,
487
+ resnet_out_scale_factor=resnet_out_scale_factor,
488
+ cross_attention_norm=cross_attention_norm,
489
+ attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
490
+ )
491
+ self.down_blocks.append(down_block)
492
+
493
+ # mid
494
+ if mid_block_type == "UNetMidBlock2DCrossAttn":
495
+ self.mid_block = UNetMidBlock2DCrossAttn(
496
+ transformer_layers_per_block=transformer_layers_per_block[-1],
497
+ in_channels=block_out_channels[-1],
498
+ temb_channels=blocks_time_embed_dim,
499
+ resnet_eps=norm_eps,
500
+ resnet_act_fn=act_fn,
501
+ output_scale_factor=mid_block_scale_factor,
502
+ resnet_time_scale_shift=resnet_time_scale_shift,
503
+ cross_attention_dim=cross_attention_dim[-1],
504
+ num_attention_heads=num_attention_heads[-1],
505
+ resnet_groups=norm_num_groups,
506
+ dual_cross_attention=dual_cross_attention,
507
+ use_linear_projection=use_linear_projection,
508
+ upcast_attention=upcast_attention,
509
+ attention_type=attention_type,
510
+ )
511
+ elif mid_block_type == "UNetMidBlock2DSimpleCrossAttn":
512
+ self.mid_block = UNetMidBlock2DSimpleCrossAttn(
513
+ in_channels=block_out_channels[-1],
514
+ temb_channels=blocks_time_embed_dim,
515
+ resnet_eps=norm_eps,
516
+ resnet_act_fn=act_fn,
517
+ output_scale_factor=mid_block_scale_factor,
518
+ cross_attention_dim=cross_attention_dim[-1],
519
+ attention_head_dim=attention_head_dim[-1],
520
+ resnet_groups=norm_num_groups,
521
+ resnet_time_scale_shift=resnet_time_scale_shift,
522
+ skip_time_act=resnet_skip_time_act,
523
+ only_cross_attention=mid_block_only_cross_attention,
524
+ cross_attention_norm=cross_attention_norm,
525
+ )
526
+ elif mid_block_type is None:
527
+ self.mid_block = None
528
+ else:
529
+ raise ValueError(f"unknown mid_block_type : {mid_block_type}")
530
+
531
+ # count how many layers upsample the images
532
+ self.num_upsamplers = 0
533
+
534
+ # up
535
+ reversed_block_out_channels = list(reversed(block_out_channels))
536
+ reversed_num_attention_heads = list(reversed(num_attention_heads))
537
+ reversed_layers_per_block = list(reversed(layers_per_block))
538
+ reversed_cross_attention_dim = list(reversed(cross_attention_dim))
539
+ reversed_transformer_layers_per_block = list(reversed(transformer_layers_per_block))
540
+ only_cross_attention = list(reversed(only_cross_attention))
541
+
542
+ output_channel = reversed_block_out_channels[0]
543
+ for i, up_block_type in enumerate(up_block_types):
544
+ is_final_block = i == len(block_out_channels) - 1
545
+
546
+ prev_output_channel = output_channel
547
+ output_channel = reversed_block_out_channels[i]
548
+ input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
549
+
550
+ # add upsample block for all BUT final layer
551
+ if not is_final_block:
552
+ add_upsample = True
553
+ self.num_upsamplers += 1
554
+ else:
555
+ add_upsample = False
556
+
557
+ up_block = get_up_block(
558
+ up_block_type,
559
+ num_layers=reversed_layers_per_block[i] + 1,
560
+ transformer_layers_per_block=reversed_transformer_layers_per_block[i],
561
+ in_channels=input_channel,
562
+ out_channels=output_channel,
563
+ prev_output_channel=prev_output_channel,
564
+ temb_channels=blocks_time_embed_dim,
565
+ add_upsample=add_upsample,
566
+ resnet_eps=norm_eps,
567
+ resnet_act_fn=act_fn,
568
+ resnet_groups=norm_num_groups,
569
+ cross_attention_dim=reversed_cross_attention_dim[i],
570
+ num_attention_heads=reversed_num_attention_heads[i],
571
+ dual_cross_attention=dual_cross_attention,
572
+ use_linear_projection=use_linear_projection,
573
+ only_cross_attention=only_cross_attention[i],
574
+ upcast_attention=upcast_attention,
575
+ resnet_time_scale_shift=resnet_time_scale_shift,
576
+ attention_type=attention_type,
577
+ resnet_skip_time_act=resnet_skip_time_act,
578
+ resnet_out_scale_factor=resnet_out_scale_factor,
579
+ cross_attention_norm=cross_attention_norm,
580
+ attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
581
+ )
582
+ self.up_blocks.append(up_block)
583
+ prev_output_channel = output_channel
584
+ self.up_blocks[3].attentions[2].transformer_blocks[0].attn1.to_q = _LoRACompatibleLinear()
585
+ self.up_blocks[3].attentions[2].transformer_blocks[0].attn1.to_k = _LoRACompatibleLinear()
586
+ self.up_blocks[3].attentions[2].transformer_blocks[0].attn1.to_v = _LoRACompatibleLinear()
587
+ self.up_blocks[3].attentions[2].transformer_blocks[0].attn1.to_out = nn.ModuleList([Identity(), Identity()])
588
+ self.up_blocks[3].attentions[2].transformer_blocks[0].norm2 = Identity()
589
+ self.up_blocks[3].attentions[2].transformer_blocks[0].attn2 = None
590
+ self.up_blocks[3].attentions[2].transformer_blocks[0].norm3 = Identity()
591
+ self.up_blocks[3].attentions[2].transformer_blocks[0].ff = Identity()
592
+ self.up_blocks[3].attentions[2].proj_out = Identity()
593
+
594
+ if attention_type in ["gated", "gated-text-image"]:
595
+ positive_len = 768
596
+ if isinstance(cross_attention_dim, int):
597
+ positive_len = cross_attention_dim
598
+ elif isinstance(cross_attention_dim, tuple) or isinstance(cross_attention_dim, list):
599
+ positive_len = cross_attention_dim[0]
600
+
601
+ feature_type = "text-only" if attention_type == "gated" else "text-image"
602
+ self.position_net = PositionNet(
603
+ positive_len=positive_len, out_dim=cross_attention_dim, feature_type=feature_type
604
+ )
605
+
606
+ @property
607
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
608
+ r"""
609
+ Returns:
610
+ `dict` of attention processors: A dictionary containing all attention processors used in the model,
611
+ indexed by their weight names.
612
+ """
613
+ # set recursively
614
+ processors = {}
615
+
616
+ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
617
+ if hasattr(module, "get_processor"):
618
+ processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True)
619
+
620
+ for sub_name, child in module.named_children():
621
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
622
+
623
+ return processors
624
+
625
+ for name, module in self.named_children():
626
+ fn_recursive_add_processors(name, module, processors)
627
+
628
+ return processors
629
+
630
+ def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
631
+ r"""
632
+ Sets the attention processor to use to compute attention.
633
+
634
+ Parameters:
635
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
636
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
637
+ for **all** `Attention` layers.
638
+
639
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
640
+ processor. This is strongly recommended when setting trainable attention processors.
641
+
642
+ """
643
+ count = len(self.attn_processors.keys())
644
+
645
+ if isinstance(processor, dict) and len(processor) != count:
646
+ raise ValueError(
647
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
648
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
649
+ )
650
+
651
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
652
+ if hasattr(module, "set_processor"):
653
+ if not isinstance(processor, dict):
654
+ module.set_processor(processor)
655
+ else:
656
+ module.set_processor(processor.pop(f"{name}.processor"))
657
+
658
+ for sub_name, child in module.named_children():
659
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
660
+
661
+ for name, module in self.named_children():
662
+ fn_recursive_attn_processor(name, module, processor)
663
+
664
+ def set_default_attn_processor(self):
665
+ """
666
+ Disables custom attention processors and sets the default attention implementation.
667
+ """
668
+ if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
669
+ processor = AttnAddedKVProcessor()
670
+ elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
671
+ processor = AttnProcessor()
672
+ else:
673
+ raise ValueError(
674
+ f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
675
+ )
676
+
677
+ self.set_attn_processor(processor)
678
+
679
+ def set_attention_slice(self, slice_size):
680
+ r"""
681
+ Enable sliced attention computation.
682
+
683
+ When this option is enabled, the attention module splits the input tensor in slices to compute attention in
684
+ several steps. This is useful for saving some memory in exchange for a small decrease in speed.
685
+
686
+ Args:
687
+ slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
688
+ When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If
689
+ `"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is
690
+ provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
691
+ must be a multiple of `slice_size`.
692
+ """
693
+ sliceable_head_dims = []
694
+
695
+ def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module):
696
+ if hasattr(module, "set_attention_slice"):
697
+ sliceable_head_dims.append(module.sliceable_head_dim)
698
+
699
+ for child in module.children():
700
+ fn_recursive_retrieve_sliceable_dims(child)
701
+
702
+ # retrieve number of attention layers
703
+ for module in self.children():
704
+ fn_recursive_retrieve_sliceable_dims(module)
705
+
706
+ num_sliceable_layers = len(sliceable_head_dims)
707
+
708
+ if slice_size == "auto":
709
+ # half the attention head size is usually a good trade-off between
710
+ # speed and memory
711
+ slice_size = [dim // 2 for dim in sliceable_head_dims]
712
+ elif slice_size == "max":
713
+ # make smallest slice possible
714
+ slice_size = num_sliceable_layers * [1]
715
+
716
+ slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
717
+
718
+ if len(slice_size) != len(sliceable_head_dims):
719
+ raise ValueError(
720
+ f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
721
+ f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
722
+ )
723
+
724
+ for i in range(len(slice_size)):
725
+ size = slice_size[i]
726
+ dim = sliceable_head_dims[i]
727
+ if size is not None and size > dim:
728
+ raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
729
+
730
+ # Recursively walk through all the children.
731
+ # Any children which exposes the set_attention_slice method
732
+ # gets the message
733
+ def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
734
+ if hasattr(module, "set_attention_slice"):
735
+ module.set_attention_slice(slice_size.pop())
736
+
737
+ for child in module.children():
738
+ fn_recursive_set_attention_slice(child, slice_size)
739
+
740
+ reversed_slice_size = list(reversed(slice_size))
741
+ for module in self.children():
742
+ fn_recursive_set_attention_slice(module, reversed_slice_size)
743
+
744
+ def _set_gradient_checkpointing(self, module, value=False):
745
+ if hasattr(module, "gradient_checkpointing"):
746
+ module.gradient_checkpointing = value
747
+
748
+ def forward(
749
+ self,
750
+ sample: torch.FloatTensor,
751
+ timestep: Union[torch.Tensor, float, int],
752
+ encoder_hidden_states: torch.Tensor,
753
+ class_labels: Optional[torch.Tensor] = None,
754
+ timestep_cond: Optional[torch.Tensor] = None,
755
+ attention_mask: Optional[torch.Tensor] = None,
756
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
757
+ added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
758
+ down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
759
+ mid_block_additional_residual: Optional[torch.Tensor] = None,
760
+ encoder_attention_mask: Optional[torch.Tensor] = None,
761
+ return_dict: bool = True,
762
+ ) -> Union[UNet2DConditionOutput, Tuple]:
763
+ r"""
764
+ The [`UNet2DConditionModel`] forward method.
765
+
766
+ Args:
767
+ sample (`torch.FloatTensor`):
768
+ The noisy input tensor with the following shape `(batch, channel, height, width)`.
769
+ timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input.
770
+ encoder_hidden_states (`torch.FloatTensor`):
771
+ The encoder hidden states with shape `(batch, sequence_length, feature_dim)`.
772
+ encoder_attention_mask (`torch.Tensor`):
773
+ A cross-attention mask of shape `(batch, sequence_length)` is applied to `encoder_hidden_states`. If
774
+ `True` the mask is kept, otherwise if `False` it is discarded. The mask will be converted into a bias,
775
+ which adds large negative values to the attention scores corresponding to "discard" tokens.
776
+ return_dict (`bool`, *optional*, defaults to `True`):
777
+ Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
778
+ tuple.
779
+ cross_attention_kwargs (`dict`, *optional*):
780
+ A kwargs dictionary that if specified is passed along to the [`AttnProcessor`].
781
+ added_cond_kwargs: (`dict`, *optional*):
782
+ A kwargs dictionary containing additional embeddings that if specified are added to the embeddings that
783
+ are passed along to the UNet blocks.
784
+
785
+ Returns:
786
+ [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
787
+ If `return_dict` is True, an [`~models.unet_2d_condition.UNet2DConditionOutput`] is returned, otherwise
788
+ a `tuple` is returned where the first element is the sample tensor.
789
+ """
790
+ # By default samples have to be at least a multiple of the overall upsampling factor.
791
+ # The overall upsampling factor is equal to 2 ** (# num of upsampling layers).
792
+ # However, the upsampling interpolation output size can be forced to fit any upsampling size
793
+ # on the fly if necessary.
794
+ default_overall_up_factor = 2**self.num_upsamplers
795
+
796
+ # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
797
+ forward_upsample_size = False
798
+ upsample_size = None
799
+
800
+ if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
801
+ logger.info("Forward upsample size to force interpolation output size.")
802
+ forward_upsample_size = True
803
+
804
+ if attention_mask is not None:
805
+ attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
806
+ attention_mask = attention_mask.unsqueeze(1)
807
+
808
+ # convert encoder_attention_mask to a bias the same way we do for attention_mask
809
+ if encoder_attention_mask is not None:
810
+ encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0
811
+ encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
812
+
813
+ # 0. center input if necessary
814
+ if self.config.center_input_sample:
815
+ sample = 2 * sample - 1.0
816
+
817
+ # 1. time
818
+ timesteps = timestep
819
+ if not torch.is_tensor(timesteps):
820
+ # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
821
+ # This would be a good case for the `match` statement (Python 3.10+)
822
+ is_mps = sample.device.type == "mps"
823
+ if isinstance(timestep, float):
824
+ dtype = torch.float32 if is_mps else torch.float64
825
+ else:
826
+ dtype = torch.int32 if is_mps else torch.int64
827
+ timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
828
+ elif len(timesteps.shape) == 0:
829
+ timesteps = timesteps[None].to(sample.device)
830
+
831
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
832
+ timesteps = timesteps.expand(sample.shape[0])
833
+
834
+ t_emb = self.time_proj(timesteps)
835
+
836
+ # `Timesteps` does not contain any weights and will always return f32 tensors
837
+ # but time_embedding might actually be running in fp16. so we need to cast here.
838
+ # there might be better ways to encapsulate this.
839
+ t_emb = t_emb.to(dtype=sample.dtype)
840
+
841
+ emb = self.time_embedding(t_emb, timestep_cond)
842
+ aug_emb = None
843
+
844
+ if self.class_embedding is not None:
845
+ if class_labels is None:
846
+ raise ValueError("class_labels should be provided when num_class_embeds > 0")
847
+
848
+ if self.config.class_embed_type == "timestep":
849
+ class_labels = self.time_proj(class_labels)
850
+
851
+ # `Timesteps` does not contain any weights and will always return f32 tensors
852
+ # there might be better ways to encapsulate this.
853
+ class_labels = class_labels.to(dtype=sample.dtype)
854
+
855
+ class_emb = self.class_embedding(class_labels).to(dtype=sample.dtype)
856
+
857
+ if self.config.class_embeddings_concat:
858
+ emb = torch.cat([emb, class_emb], dim=-1)
859
+ else:
860
+ emb = emb + class_emb
861
+
862
+ if self.config.addition_embed_type == "text":
863
+ aug_emb = self.add_embedding(encoder_hidden_states)
864
+ elif self.config.addition_embed_type == "text_image":
865
+ # Kandinsky 2.1 - style
866
+ if "image_embeds" not in added_cond_kwargs:
867
+ raise ValueError(
868
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
869
+ )
870
+
871
+ image_embs = added_cond_kwargs.get("image_embeds")
872
+ text_embs = added_cond_kwargs.get("text_embeds", encoder_hidden_states)
873
+ aug_emb = self.add_embedding(text_embs, image_embs)
874
+ elif self.config.addition_embed_type == "text_time":
875
+ # SDXL - style
876
+ if "text_embeds" not in added_cond_kwargs:
877
+ raise ValueError(
878
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
879
+ )
880
+ text_embeds = added_cond_kwargs.get("text_embeds")
881
+ if "time_ids" not in added_cond_kwargs:
882
+ raise ValueError(
883
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
884
+ )
885
+ time_ids = added_cond_kwargs.get("time_ids")
886
+ time_embeds = self.add_time_proj(time_ids.flatten())
887
+ time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))
888
+
889
+ add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
890
+ add_embeds = add_embeds.to(emb.dtype)
891
+ aug_emb = self.add_embedding(add_embeds)
892
+ elif self.config.addition_embed_type == "image":
893
+ # Kandinsky 2.2 - style
894
+ if "image_embeds" not in added_cond_kwargs:
895
+ raise ValueError(
896
+ f"{self.__class__} has the config param `addition_embed_type` set to 'image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
897
+ )
898
+ image_embs = added_cond_kwargs.get("image_embeds")
899
+ aug_emb = self.add_embedding(image_embs)
900
+ elif self.config.addition_embed_type == "image_hint":
901
+ # Kandinsky 2.2 - style
902
+ if "image_embeds" not in added_cond_kwargs or "hint" not in added_cond_kwargs:
903
+ raise ValueError(
904
+ f"{self.__class__} has the config param `addition_embed_type` set to 'image_hint' which requires the keyword arguments `image_embeds` and `hint` to be passed in `added_cond_kwargs`"
905
+ )
906
+ image_embs = added_cond_kwargs.get("image_embeds")
907
+ hint = added_cond_kwargs.get("hint")
908
+ aug_emb, hint = self.add_embedding(image_embs, hint)
909
+ sample = torch.cat([sample, hint], dim=1)
910
+
911
+ emb = emb + aug_emb if aug_emb is not None else emb
912
+
913
+ if self.time_embed_act is not None:
914
+ emb = self.time_embed_act(emb)
915
+
916
+ if self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_proj":
917
+ encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states)
918
+ elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_image_proj":
919
+ # Kandinsky 2.1 - style
920
+ if "image_embeds" not in added_cond_kwargs:
921
+ raise ValueError(
922
+ f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'text_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
923
+ )
924
+
925
+ image_embeds = added_cond_kwargs.get("image_embeds")
926
+ encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states, image_embeds)
927
+ elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "image_proj":
928
+ # Kandinsky 2.2 - style
929
+ if "image_embeds" not in added_cond_kwargs:
930
+ raise ValueError(
931
+ f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
932
+ )
933
+ image_embeds = added_cond_kwargs.get("image_embeds")
934
+ encoder_hidden_states = self.encoder_hid_proj(image_embeds)
935
+ # 2. pre-process
936
+ sample = self.conv_in(sample)
937
+
938
+ # 2.5 GLIGEN position net
939
+ if cross_attention_kwargs is not None and cross_attention_kwargs.get("gligen", None) is not None:
940
+ cross_attention_kwargs = cross_attention_kwargs.copy()
941
+ gligen_args = cross_attention_kwargs.pop("gligen")
942
+ cross_attention_kwargs["gligen"] = {"objs": self.position_net(**gligen_args)}
943
+
944
+ # 3. down
945
+
946
+ is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None
947
+ is_adapter = mid_block_additional_residual is None and down_block_additional_residuals is not None
948
+
949
+ down_block_res_samples = (sample,)
950
+ for downsample_block in self.down_blocks:
951
+ if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
952
+ # For t2i-adapter CrossAttnDownBlock2D
953
+ additional_residuals = {}
954
+ if is_adapter and len(down_block_additional_residuals) > 0:
955
+ additional_residuals["additional_residuals"] = down_block_additional_residuals.pop(0)
956
+
957
+ sample, res_samples = downsample_block(
958
+ hidden_states=sample,
959
+ temb=emb,
960
+ encoder_hidden_states=encoder_hidden_states,
961
+ attention_mask=attention_mask,
962
+ cross_attention_kwargs=cross_attention_kwargs,
963
+ encoder_attention_mask=encoder_attention_mask,
964
+ **additional_residuals,
965
+ )
966
+ else:
967
+ sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
968
+
969
+ if is_adapter and len(down_block_additional_residuals) > 0:
970
+ sample += down_block_additional_residuals.pop(0)
971
+
972
+ down_block_res_samples += res_samples
973
+
974
+ if is_controlnet:
975
+ new_down_block_res_samples = ()
976
+
977
+ for down_block_res_sample, down_block_additional_residual in zip(
978
+ down_block_res_samples, down_block_additional_residuals
979
+ ):
980
+ down_block_res_sample = down_block_res_sample + down_block_additional_residual
981
+ new_down_block_res_samples = new_down_block_res_samples + (down_block_res_sample,)
982
+
983
+ down_block_res_samples = new_down_block_res_samples
984
+
985
+ # 4. mid
986
+ if self.mid_block is not None:
987
+ sample = self.mid_block(
988
+ sample,
989
+ emb,
990
+ encoder_hidden_states=encoder_hidden_states,
991
+ attention_mask=attention_mask,
992
+ cross_attention_kwargs=cross_attention_kwargs,
993
+ encoder_attention_mask=encoder_attention_mask,
994
+ )
995
+ # To support T2I-Adapter-XL
996
+ if (
997
+ is_adapter
998
+ and len(down_block_additional_residuals) > 0
999
+ and sample.shape == down_block_additional_residuals[0].shape
1000
+ ):
1001
+ sample += down_block_additional_residuals.pop(0)
1002
+
1003
+ if is_controlnet:
1004
+ sample = sample + mid_block_additional_residual
1005
+
1006
+ # 5. up
1007
+ for i, upsample_block in enumerate(self.up_blocks):
1008
+ is_final_block = i == len(self.up_blocks) - 1
1009
+
1010
+ res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
1011
+ down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
1012
+
1013
+ # if we have not reached the final block and need to forward the
1014
+ # upsample size, we do it here
1015
+ if not is_final_block and forward_upsample_size:
1016
+ upsample_size = down_block_res_samples[-1].shape[2:]
1017
+
1018
+ if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
1019
+ sample = upsample_block(
1020
+ hidden_states=sample,
1021
+ temb=emb,
1022
+ res_hidden_states_tuple=res_samples,
1023
+ encoder_hidden_states=encoder_hidden_states,
1024
+ cross_attention_kwargs=cross_attention_kwargs,
1025
+ upsample_size=upsample_size,
1026
+ attention_mask=attention_mask,
1027
+ encoder_attention_mask=encoder_attention_mask,
1028
+ )
1029
+ else:
1030
+ sample = upsample_block(
1031
+ hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size
1032
+ )
1033
+
1034
+ if not return_dict:
1035
+ return (sample,)
1036
+
1037
+ return UNet2DConditionOutput(sample=sample)
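Note: as a standalone illustration of the mask handling in the forward pass above, the sketch below mirrors how a boolean attention mask (1 = keep, 0 = discard) is converted into an additive bias with a singleton head dimension; the mask values are made up for the example.

    import torch

    attention_mask = torch.tensor([[1, 1, 0, 0]])               # (batch, key_tokens), 1 = keep, 0 = discard
    bias = (1 - attention_mask.to(torch.float32)) * -10000.0    # kept tokens -> ~0, discarded -> -10000.0
    bias = bias.unsqueeze(1)                                     # (batch, 1, key_tokens), broadcastable over queries
    print(bias.shape)                                            # torch.Size([1, 1, 4])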
ref_encoder/resnet.py ADDED
@@ -0,0 +1,263 @@
1
+ # Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/resnet.py
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+ from einops import rearrange
7
+
8
+
9
+ class InflatedConv3d(nn.Conv2d):
10
+ def forward(self, x):
11
+ video_length = x.shape[2]
12
+
13
+ x = rearrange(x, "b c f h w -> (b f) c h w")
14
+ x = super().forward(x)
15
+ x = rearrange(x, "(b f) c h w -> b c f h w", f=video_length)
16
+
17
+ return x
18
+
19
+
20
+ class InflatedGroupNorm(nn.GroupNorm):
21
+ def forward(self, x):
22
+ video_length = x.shape[2]
23
+
24
+ x = rearrange(x, "b c f h w -> (b f) c h w")
25
+ x = super().forward(x)
26
+ x = rearrange(x, "(b f) c h w -> b c f h w", f=video_length)
27
+
28
+ return x
29
+
30
+
31
+ class Upsample3D(nn.Module):
32
+ def __init__(
33
+ self,
34
+ channels,
35
+ use_conv=False,
36
+ use_conv_transpose=False,
37
+ out_channels=None,
38
+ name="conv",
39
+ ):
40
+ super().__init__()
41
+ self.channels = channels
42
+ self.out_channels = out_channels or channels
43
+ self.use_conv = use_conv
44
+ self.use_conv_transpose = use_conv_transpose
45
+ self.name = name
46
+
47
+ conv = None
48
+ if use_conv_transpose:
49
+ raise NotImplementedError
50
+ elif use_conv:
51
+ self.conv = InflatedConv3d(self.channels, self.out_channels, 3, padding=1)
52
+
53
+ def forward(self, hidden_states, output_size=None):
54
+ assert hidden_states.shape[1] == self.channels
55
+
56
+ if self.use_conv_transpose:
57
+ raise NotImplementedError
58
+
59
+ # Cast to float32 as the 'upsample_nearest2d_out_frame' op does not support bfloat16
60
+ dtype = hidden_states.dtype
61
+ if dtype == torch.bfloat16:
62
+ hidden_states = hidden_states.to(torch.float32)
63
+
64
+ # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
65
+ if hidden_states.shape[0] >= 64:
66
+ hidden_states = hidden_states.contiguous()
67
+
68
+ # if `output_size` is passed we force the interpolation output
69
+ # size and do not make use of `scale_factor=2`
70
+ if output_size is None:
71
+ hidden_states = F.interpolate(
72
+ hidden_states, scale_factor=[1.0, 2.0, 2.0], mode="nearest"
73
+ )
74
+ else:
75
+ hidden_states = F.interpolate(
76
+ hidden_states, size=output_size, mode="nearest"
77
+ )
78
+
79
+ # If the input is bfloat16, we cast back to bfloat16
80
+ if dtype == torch.bfloat16:
81
+ hidden_states = hidden_states.to(dtype)
82
+
83
+ # if self.use_conv:
84
+ # if self.name == "conv":
85
+ # hidden_states = self.conv(hidden_states)
86
+ # else:
87
+ # hidden_states = self.Conv2d_0(hidden_states)
88
+ hidden_states = self.conv(hidden_states)
89
+
90
+ return hidden_states
91
+
92
+
93
+ class Downsample3D(nn.Module):
94
+ def __init__(
95
+ self, channels, use_conv=False, out_channels=None, padding=1, name="conv"
96
+ ):
97
+ super().__init__()
98
+ self.channels = channels
99
+ self.out_channels = out_channels or channels
100
+ self.use_conv = use_conv
101
+ self.padding = padding
102
+ stride = 2
103
+ self.name = name
104
+
105
+ if use_conv:
106
+ self.conv = InflatedConv3d(
107
+ self.channels, self.out_channels, 3, stride=stride, padding=padding
108
+ )
109
+ else:
110
+ raise NotImplementedError
111
+
112
+ def forward(self, hidden_states):
113
+ assert hidden_states.shape[1] == self.channels
114
+ if self.use_conv and self.padding == 0:
115
+ raise NotImplementedError
116
+
117
+ assert hidden_states.shape[1] == self.channels
118
+ hidden_states = self.conv(hidden_states)
119
+
120
+ return hidden_states
121
+
122
+
123
+ class ResnetBlock3D(nn.Module):
124
+ def __init__(
125
+ self,
126
+ *,
127
+ in_channels,
128
+ out_channels=None,
129
+ conv_shortcut=False,
130
+ dropout=0.0,
131
+ temb_channels=512,
132
+ groups=32,
133
+ groups_out=None,
134
+ pre_norm=True,
135
+ eps=1e-6,
136
+ non_linearity="swish",
137
+ time_embedding_norm="default",
138
+ output_scale_factor=1.0,
139
+ use_in_shortcut=None,
140
+ use_inflated_groupnorm=None,
141
+ ):
142
+ super().__init__()
143
+ self.pre_norm = pre_norm
144
+ self.pre_norm = True
145
+ self.in_channels = in_channels
146
+ out_channels = in_channels if out_channels is None else out_channels
147
+ self.out_channels = out_channels
148
+ self.use_conv_shortcut = conv_shortcut
149
+ self.time_embedding_norm = time_embedding_norm
150
+ self.output_scale_factor = output_scale_factor
151
+
152
+ if groups_out is None:
153
+ groups_out = groups
154
+
155
+ assert use_inflated_groupnorm is not None
156
+ if use_inflated_groupnorm:
157
+ self.norm1 = InflatedGroupNorm(
158
+ num_groups=groups, num_channels=in_channels, eps=eps, affine=True
159
+ )
160
+ else:
161
+ self.norm1 = torch.nn.GroupNorm(
162
+ num_groups=groups, num_channels=in_channels, eps=eps, affine=True
163
+ )
164
+
165
+ self.conv1 = InflatedConv3d(
166
+ in_channels, out_channels, kernel_size=3, stride=1, padding=1
167
+ )
168
+
169
+ if temb_channels is not None:
170
+ if self.time_embedding_norm == "default":
171
+ time_emb_proj_out_channels = out_channels
172
+ elif self.time_embedding_norm == "scale_shift":
173
+ time_emb_proj_out_channels = out_channels * 2
174
+ else:
175
+ raise ValueError(
176
+ f"unknown time_embedding_norm : {self.time_embedding_norm} "
177
+ )
178
+
179
+ self.time_emb_proj = torch.nn.Linear(
180
+ temb_channels, time_emb_proj_out_channels
181
+ )
182
+ else:
183
+ self.time_emb_proj = None
184
+
185
+ if use_inflated_groupnorm:
186
+ self.norm2 = InflatedGroupNorm(
187
+ num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True
188
+ )
189
+ else:
190
+ self.norm2 = torch.nn.GroupNorm(
191
+ num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True
192
+ )
193
+ self.dropout = torch.nn.Dropout(dropout)
194
+ self.conv2 = InflatedConv3d(
195
+ out_channels, out_channels, kernel_size=3, stride=1, padding=1
196
+ )
197
+
198
+ if non_linearity == "swish":
199
+ self.nonlinearity = lambda x: F.silu(x)
200
+ elif non_linearity == "mish":
201
+ self.nonlinearity = Mish()
202
+ elif non_linearity == "silu":
203
+ self.nonlinearity = nn.SiLU()
204
+
205
+ self.use_in_shortcut = (
206
+ self.in_channels != self.out_channels
207
+ if use_in_shortcut is None
208
+ else use_in_shortcut
209
+ )
210
+
211
+ self.conv_shortcut = None
212
+ if self.use_in_shortcut:
213
+ self.conv_shortcut = InflatedConv3d(
214
+ in_channels, out_channels, kernel_size=1, stride=1, padding=0
215
+ )
216
+
217
+ def forward(self, input_tensor, temb):
218
+ hidden_states = input_tensor
219
+
220
+ video_length = input_tensor.shape[2]
221
+ batch_size = input_tensor.shape[0]
222
+ hidden_states = self.norm1(hidden_states)
223
+ hidden_states = self.nonlinearity(hidden_states)
224
+
225
+ hidden_states = self.conv1(hidden_states)
226
+
227
+ if temb is not None:
228
+ temb_shape = temb.shape[0]
229
+ temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None, None]
230
+ if temb_shape > batch_size:
231
+ temb = temb.permute(2, 1, 0, 3, 4)
232
+
233
+
234
+
235
+ if temb is not None and self.time_embedding_norm == "default":
236
+ if hidden_states.shape[0] > 1 and temb.shape[2] == hidden_states.shape[0] * hidden_states.shape[2]:
237
+ temb = temb.reshape(temb.shape[0], temb.shape[1], hidden_states.shape[0], -1, temb.shape[-2], temb.shape[-1])
238
+ temb = temb.permute(2, 0, 1, 3, 4, 5)
239
+ temb = temb.reshape(temb.shape[0], -1, temb.shape[3], temb.shape[4], temb.shape[5])
240
+ hidden_states = hidden_states + temb
241
+
242
+ hidden_states = self.norm2(hidden_states)
243
+
244
+ if temb is not None and self.time_embedding_norm == "scale_shift":
245
+ scale, shift = torch.chunk(temb, 2, dim=1)
246
+ hidden_states = hidden_states * (1 + scale) + shift
247
+
248
+ hidden_states = self.nonlinearity(hidden_states)
249
+
250
+ hidden_states = self.dropout(hidden_states)
251
+ hidden_states = self.conv2(hidden_states)
252
+
253
+ if self.conv_shortcut is not None:
254
+ input_tensor = self.conv_shortcut(input_tensor)
255
+
256
+ output_tensor = (input_tensor + hidden_states) / self.output_scale_factor
257
+
258
+ return output_tensor
259
+
260
+
261
+ class Mish(torch.nn.Module):
262
+ def forward(self, hidden_states):
263
+ return hidden_states * torch.tanh(torch.nn.functional.softplus(hidden_states))
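Note: a minimal shape check for the InflatedConv3d defined above (a 2D convolution applied frame-by-frame to a 5D video tensor); the import path and tensor sizes here are assumptions for illustration.

    import torch
    from ref_encoder.resnet import InflatedConv3d   # assumed import path for the module above

    conv = InflatedConv3d(4, 8, kernel_size=3, padding=1)
    video = torch.randn(2, 4, 16, 32, 32)            # (batch, channels, frames, height, width)
    out = conv(video)                                # conv runs on (batch*frames, c, h, w) internally
    print(out.shape)                                 # torch.Size([2, 8, 16, 32, 32])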
ref_encoder/transformer_2d.py ADDED
@@ -0,0 +1,396 @@
1
+ # Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/transformer_2d.py
2
+ from dataclasses import dataclass
3
+ from typing import Any, Dict, Optional
4
+
5
+ import torch
6
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
7
+ from diffusers.models.embeddings import CaptionProjection
8
+ from diffusers.models.lora import LoRACompatibleConv, LoRACompatibleLinear
9
+ from diffusers.models.modeling_utils import ModelMixin
10
+ from diffusers.models.normalization import AdaLayerNormSingle
11
+ from diffusers.utils import USE_PEFT_BACKEND, BaseOutput, deprecate, is_torch_version
12
+ from torch import nn
13
+
14
+ from .attention import BasicTransformerBlock
15
+
16
+
17
+ @dataclass
18
+ class Transformer2DModelOutput(BaseOutput):
19
+ """
20
+ The output of [`Transformer2DModel`].
21
+
22
+ Args:
23
+ sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` or `(batch size, num_vector_embeds - 1, num_latent_pixels)` if [`Transformer2DModel`] is discrete):
24
+ The hidden states output conditioned on the `encoder_hidden_states` input. If discrete, returns probability
25
+ distributions for the unnoised latent pixels.
26
+ """
27
+
28
+ sample: torch.FloatTensor
29
+ ref_feature: torch.FloatTensor
30
+
31
+
32
+ class Transformer2DModel(ModelMixin, ConfigMixin):
33
+ """
34
+ A 2D Transformer model for image-like data.
35
+
36
+ Parameters:
37
+ num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
38
+ attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
39
+ in_channels (`int`, *optional*):
40
+ The number of channels in the input and output (specify if the input is **continuous**).
41
+ num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
42
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
43
+ cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
44
+ sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**).
45
+ This is fixed during training since it is used to learn a number of position embeddings.
46
+ num_vector_embeds (`int`, *optional*):
47
+ The number of classes of the vector embeddings of the latent pixels (specify if the input is **discrete**).
48
+ Includes the class for the masked latent pixel.
49
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to use in feed-forward.
50
+ num_embeds_ada_norm ( `int`, *optional*):
51
+ The number of diffusion steps used during training. Pass if at least one of the norm_layers is
52
+ `AdaLayerNorm`. This is fixed during training since it is used to learn a number of embeddings that are
53
+ added to the hidden states.
54
+
55
+ During inference, you can denoise for up to but not more steps than `num_embeds_ada_norm`.
56
+ attention_bias (`bool`, *optional*):
57
+ Configure if the `TransformerBlocks` attention should contain a bias parameter.
58
+ """
59
+
60
+ _supports_gradient_checkpointing = True
61
+
62
+ @register_to_config
63
+ def __init__(
64
+ self,
65
+ num_attention_heads: int = 16,
66
+ attention_head_dim: int = 88,
67
+ in_channels: Optional[int] = None,
68
+ out_channels: Optional[int] = None,
69
+ num_layers: int = 1,
70
+ dropout: float = 0.0,
71
+ norm_num_groups: int = 32,
72
+ cross_attention_dim: Optional[int] = None,
73
+ attention_bias: bool = False,
74
+ sample_size: Optional[int] = None,
75
+ num_vector_embeds: Optional[int] = None,
76
+ patch_size: Optional[int] = None,
77
+ activation_fn: str = "geglu",
78
+ num_embeds_ada_norm: Optional[int] = None,
79
+ use_linear_projection: bool = False,
80
+ only_cross_attention: bool = False,
81
+ double_self_attention: bool = False,
82
+ upcast_attention: bool = False,
83
+ norm_type: str = "layer_norm",
84
+ norm_elementwise_affine: bool = True,
85
+ norm_eps: float = 1e-5,
86
+ attention_type: str = "default",
87
+ caption_channels: int = None,
88
+ ):
89
+ super().__init__()
90
+ self.use_linear_projection = use_linear_projection
91
+ self.num_attention_heads = num_attention_heads
92
+ self.attention_head_dim = attention_head_dim
93
+ inner_dim = num_attention_heads * attention_head_dim
94
+
95
+ conv_cls = nn.Conv2d if USE_PEFT_BACKEND else LoRACompatibleConv
96
+ linear_cls = nn.Linear if USE_PEFT_BACKEND else LoRACompatibleLinear
97
+
98
+ # 1. Transformer2DModel can process both standard continuous images of shape `(batch_size, num_channels, width, height)` as well as quantized image embeddings of shape `(batch_size, num_image_vectors)`
99
+ # Define whether input is continuous or discrete depending on configuration
100
+ self.is_input_continuous = (in_channels is not None) and (patch_size is None)
101
+ self.is_input_vectorized = num_vector_embeds is not None
102
+ self.is_input_patches = in_channels is not None and patch_size is not None
103
+
104
+ if norm_type == "layer_norm" and num_embeds_ada_norm is not None:
105
+ deprecation_message = (
106
+ f"The configuration file of this model: {self.__class__} is outdated. `norm_type` is either not set or"
107
+ " incorrectly set to `'layer_norm'`.Make sure to set `norm_type` to `'ada_norm'` in the config."
108
+ " Please make sure to update the config accordingly as leaving `norm_type` might led to incorrect"
109
+ " results in future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it"
110
+ " would be very nice if you could open a Pull request for the `transformer/config.json` file"
111
+ )
112
+ deprecate(
113
+ "norm_type!=num_embeds_ada_norm",
114
+ "1.0.0",
115
+ deprecation_message,
116
+ standard_warn=False,
117
+ )
118
+ norm_type = "ada_norm"
119
+
120
+ if self.is_input_continuous and self.is_input_vectorized:
121
+ raise ValueError(
122
+ f"Cannot define both `in_channels`: {in_channels} and `num_vector_embeds`: {num_vector_embeds}. Make"
123
+ " sure that either `in_channels` or `num_vector_embeds` is None."
124
+ )
125
+ elif self.is_input_vectorized and self.is_input_patches:
126
+ raise ValueError(
127
+ f"Cannot define both `num_vector_embeds`: {num_vector_embeds} and `patch_size`: {patch_size}. Make"
128
+ " sure that either `num_vector_embeds` or `num_patches` is None."
129
+ )
130
+ elif (
131
+ not self.is_input_continuous
132
+ and not self.is_input_vectorized
133
+ and not self.is_input_patches
134
+ ):
135
+ raise ValueError(
136
+ f"Has to define `in_channels`: {in_channels}, `num_vector_embeds`: {num_vector_embeds}, or patch_size:"
137
+ f" {patch_size}. Make sure that `in_channels`, `num_vector_embeds` or `num_patches` is not None."
138
+ )
139
+
140
+ # 2. Define input layers
141
+ self.in_channels = in_channels
142
+
143
+ self.norm = torch.nn.GroupNorm(
144
+ num_groups=norm_num_groups,
145
+ num_channels=in_channels,
146
+ eps=1e-6,
147
+ affine=True,
148
+ )
149
+ if use_linear_projection:
150
+ self.proj_in = linear_cls(in_channels, inner_dim)
151
+ else:
152
+ self.proj_in = conv_cls(
153
+ in_channels, inner_dim, kernel_size=1, stride=1, padding=0
154
+ )
155
+
156
+ # 3. Define transformers blocks
157
+ self.transformer_blocks = nn.ModuleList(
158
+ [
159
+ BasicTransformerBlock(
160
+ inner_dim,
161
+ num_attention_heads,
162
+ attention_head_dim,
163
+ dropout=dropout,
164
+ cross_attention_dim=cross_attention_dim,
165
+ activation_fn=activation_fn,
166
+ num_embeds_ada_norm=num_embeds_ada_norm,
167
+ attention_bias=attention_bias,
168
+ only_cross_attention=only_cross_attention,
169
+ double_self_attention=double_self_attention,
170
+ upcast_attention=upcast_attention,
171
+ norm_type=norm_type,
172
+ norm_elementwise_affine=norm_elementwise_affine,
173
+ norm_eps=norm_eps,
174
+ attention_type=attention_type,
175
+ )
176
+ for d in range(num_layers)
177
+ ]
178
+ )
179
+
180
+ # 4. Define output layers
181
+ self.out_channels = in_channels if out_channels is None else out_channels
182
+ # TODO: should use out_channels for continuous projections
183
+ if use_linear_projection:
184
+ self.proj_out = linear_cls(inner_dim, in_channels)
185
+ else:
186
+ self.proj_out = conv_cls(
187
+ inner_dim, in_channels, kernel_size=1, stride=1, padding=0
188
+ )
189
+
190
+ # 5. PixArt-Alpha blocks.
191
+ self.adaln_single = None
192
+ self.use_additional_conditions = False
193
+ if norm_type == "ada_norm_single":
194
+ self.use_additional_conditions = self.config.sample_size == 128
195
+ # TODO(Sayak, PVP) clean this, for now we use sample size to determine whether to use
196
+ # additional conditions until we find better name
197
+ self.adaln_single = AdaLayerNormSingle(
198
+ inner_dim, use_additional_conditions=self.use_additional_conditions
199
+ )
200
+
201
+ self.caption_projection = None
202
+ if caption_channels is not None:
203
+ self.caption_projection = CaptionProjection(
204
+ in_features=caption_channels, hidden_size=inner_dim
205
+ )
206
+
207
+ self.gradient_checkpointing = False
208
+
209
+ def _set_gradient_checkpointing(self, module, value=False):
210
+ if hasattr(module, "gradient_checkpointing"):
211
+ module.gradient_checkpointing = value
212
+
213
+ def forward(
214
+ self,
215
+ hidden_states: torch.Tensor,
216
+ encoder_hidden_states: Optional[torch.Tensor] = None,
217
+ timestep: Optional[torch.LongTensor] = None,
218
+ added_cond_kwargs: Dict[str, torch.Tensor] = None,
219
+ class_labels: Optional[torch.LongTensor] = None,
220
+ cross_attention_kwargs: Dict[str, Any] = None,
221
+ attention_mask: Optional[torch.Tensor] = None,
222
+ encoder_attention_mask: Optional[torch.Tensor] = None,
223
+ return_dict: bool = True,
224
+ ):
225
+ """
226
+ The [`Transformer2DModel`] forward method.
227
+
228
+ Args:
229
+ hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.FloatTensor` of shape `(batch size, channel, height, width)` if continuous):
230
+ Input `hidden_states`.
231
+ encoder_hidden_states ( `torch.FloatTensor` of shape `(batch size, sequence len, embed dims)`, *optional*):
232
+ Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
233
+ self-attention.
234
+ timestep ( `torch.LongTensor`, *optional*):
235
+ Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`.
236
+ class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*):
237
+ Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in
238
+ `AdaLayerZeroNorm`.
239
+ cross_attention_kwargs ( `Dict[str, Any]`, *optional*):
240
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
241
+ `self.processor` in
242
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
243
+ attention_mask ( `torch.Tensor`, *optional*):
244
+ An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
245
+ is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
246
+ negative values to the attention scores corresponding to "discard" tokens.
247
+ encoder_attention_mask ( `torch.Tensor`, *optional*):
248
+ Cross-attention mask applied to `encoder_hidden_states`. Two formats supported:
249
+
250
+ * Mask `(batch, sequence_length)` True = keep, False = discard.
251
+ * Bias `(batch, 1, sequence_length)` 0 = keep, -10000 = discard.
252
+
253
+ If `ndim == 2`: will be interpreted as a mask, then converted into a bias consistent with the format
254
+ above. This bias will be added to the cross-attention scores.
255
+ return_dict (`bool`, *optional*, defaults to `True`):
256
+ Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
257
+ tuple.
258
+
259
+ Returns:
260
+ If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
261
+ `tuple` where the first element is the sample tensor.
262
+ """
263
+ # ensure attention_mask is a bias, and give it a singleton query_tokens dimension.
264
+ # we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward.
265
+ # we can tell by counting dims; if ndim == 2: it's a mask rather than a bias.
266
+ # expects mask of shape:
267
+ # [batch, key_tokens]
268
+ # adds singleton query_tokens dimension:
269
+ # [batch, 1, key_tokens]
270
+ # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
271
+ # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
272
+ # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
273
+ if attention_mask is not None and attention_mask.ndim == 2:
274
+ # assume that mask is expressed as:
275
+ # (1 = keep, 0 = discard)
276
+ # convert mask into a bias that can be added to attention scores:
277
+ # (keep = +0, discard = -10000.0)
278
+ attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0
279
+ attention_mask = attention_mask.unsqueeze(1)
280
+
281
+ # convert encoder_attention_mask to a bias the same way we do for attention_mask
282
+ if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
283
+ encoder_attention_mask = (
284
+ 1 - encoder_attention_mask.to(hidden_states.dtype)
285
+ ) * -10000.0
286
+ encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
287
+
288
+ # Retrieve lora scale.
289
+ lora_scale = (
290
+ cross_attention_kwargs.get("scale", 1.0)
291
+ if cross_attention_kwargs is not None
292
+ else 1.0
293
+ )
294
+
295
+ # 1. Input
296
+ batch, _, height, width = hidden_states.shape
297
+ residual = hidden_states
298
+
299
+ hidden_states = self.norm(hidden_states)
300
+ if not self.use_linear_projection:
301
+ hidden_states = (
302
+ self.proj_in(hidden_states, scale=lora_scale)
303
+ if not USE_PEFT_BACKEND
304
+ else self.proj_in(hidden_states)
305
+ )
306
+ inner_dim = hidden_states.shape[1]
307
+ hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(
308
+ batch, height * width, inner_dim
309
+ )
310
+ else:
311
+ inner_dim = hidden_states.shape[1]
312
+ hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(
313
+ batch, height * width, inner_dim
314
+ )
315
+ hidden_states = (
316
+ self.proj_in(hidden_states, scale=lora_scale)
317
+ if not USE_PEFT_BACKEND
318
+ else self.proj_in(hidden_states)
319
+ )
320
+
321
+ # 2. Blocks
322
+ if self.caption_projection is not None:
323
+ batch_size = hidden_states.shape[0]
324
+ encoder_hidden_states = self.caption_projection(encoder_hidden_states)
325
+ encoder_hidden_states = encoder_hidden_states.view(
326
+ batch_size, -1, hidden_states.shape[-1]
327
+ )
328
+
329
+ ref_feature = hidden_states.reshape(batch, height, width, inner_dim)
330
+ for block in self.transformer_blocks:
331
+ if self.training and self.gradient_checkpointing:
332
+
333
+ def create_custom_forward(module, return_dict=None):
334
+ def custom_forward(*inputs):
335
+ if return_dict is not None:
336
+ return module(*inputs, return_dict=return_dict)
337
+ else:
338
+ return module(*inputs)
339
+
340
+ return custom_forward
341
+
342
+ ckpt_kwargs: Dict[str, Any] = (
343
+ {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
344
+ )
345
+ hidden_states = torch.utils.checkpoint.checkpoint(
346
+ create_custom_forward(block),
347
+ hidden_states,
348
+ attention_mask,
349
+ encoder_hidden_states,
350
+ encoder_attention_mask,
351
+ timestep,
352
+ cross_attention_kwargs,
353
+ class_labels,
354
+ **ckpt_kwargs,
355
+ )
356
+ else:
357
+ hidden_states = block(
358
+ hidden_states,
359
+ attention_mask=attention_mask,
360
+ encoder_hidden_states=encoder_hidden_states,
361
+ encoder_attention_mask=encoder_attention_mask,
362
+ timestep=timestep,
363
+ cross_attention_kwargs=cross_attention_kwargs,
364
+ class_labels=class_labels,
365
+ )
366
+
367
+ # 3. Output
368
+ if self.is_input_continuous:
369
+ if not self.use_linear_projection:
370
+ hidden_states = (
371
+ hidden_states.reshape(batch, height, width, inner_dim)
372
+ .permute(0, 3, 1, 2)
373
+ .contiguous()
374
+ )
375
+ hidden_states = (
376
+ self.proj_out(hidden_states, scale=lora_scale)
377
+ if not USE_PEFT_BACKEND
378
+ else self.proj_out(hidden_states)
379
+ )
380
+ else:
381
+ hidden_states = (
382
+ self.proj_out(hidden_states, scale=lora_scale)
383
+ if not USE_PEFT_BACKEND
384
+ else self.proj_out(hidden_states)
385
+ )
386
+ hidden_states = (
387
+ hidden_states.reshape(batch, height, width, inner_dim)
388
+ .permute(0, 3, 1, 2)
389
+ .contiguous()
390
+ )
391
+
392
+ output = hidden_states + residual
393
+ if not return_dict:
394
+ return (output, ref_feature)
395
+
396
+ return Transformer2DModelOutput(sample=output, ref_feature=ref_feature)
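Note: the forward above flattens the (b, c, h, w) feature map into (b, h*w, c) tokens and keeps a spatially shaped copy as ref_feature of shape (b, h, w, c). The standalone sketch below reproduces just that shape bookkeeping with made-up sizes.

    import torch

    b, c, h, w = 1, 64, 16, 16
    x = torch.randn(b, c, h, w)
    tokens = x.permute(0, 2, 3, 1).reshape(b, h * w, c)   # (1, 256, 64) token sequence fed to the blocks
    ref_feature = tokens.reshape(b, h, w, c)              # (1, 16, 16, 64) spatial copy returned alongside the sample
    print(tokens.shape, ref_feature.shape)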
ref_encoder/transformer_3d.py ADDED
@@ -0,0 +1,202 @@
1
+ from dataclasses import dataclass
2
+ from typing import Optional
3
+
4
+ import torch
5
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
6
+ from diffusers.models import ModelMixin
7
+ from diffusers.utils import BaseOutput
8
+ from diffusers.utils.import_utils import is_xformers_available
9
+ from einops import rearrange, repeat
10
+ from torch import nn
11
+
12
+ from .attention import TemporalBasicTransformerBlock
13
+
14
+
15
+ @dataclass
16
+ class Transformer3DModelOutput(BaseOutput):
17
+ sample: torch.FloatTensor
18
+
19
+
20
+ if is_xformers_available():
21
+ import xformers
22
+ import xformers.ops
23
+ else:
24
+ xformers = None
25
+
26
+
27
+ class Transformer3DModel(ModelMixin, ConfigMixin):
28
+ _supports_gradient_checkpointing = True
29
+
30
+ @register_to_config
31
+ def __init__(
32
+ self,
33
+ num_attention_heads: int = 16,
34
+ attention_head_dim: int = 88,
35
+ in_channels: Optional[int] = None,
36
+ num_layers: int = 1,
37
+ dropout: float = 0.0,
38
+ norm_num_groups: int = 32,
39
+ cross_attention_dim: Optional[int] = None,
40
+ attention_bias: bool = False,
41
+ activation_fn: str = "geglu",
42
+ num_embeds_ada_norm: Optional[int] = None,
43
+ use_linear_projection: bool = False,
44
+ only_cross_attention: bool = False,
45
+ upcast_attention: bool = False,
46
+ unet_use_cross_frame_attention=None,
47
+ unet_use_temporal_attention=None,
48
+ name=None,
49
+ ):
50
+ super().__init__()
51
+ self.use_linear_projection = use_linear_projection
52
+ self.num_attention_heads = num_attention_heads
53
+ self.attention_head_dim = attention_head_dim
54
+ inner_dim = num_attention_heads * attention_head_dim
55
+
56
+ # Define input layers
57
+ self.in_channels = in_channels
58
+ self.name = name
59
+
60
+ self.norm = torch.nn.GroupNorm(
61
+ num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True
62
+ )
63
+ if use_linear_projection:
64
+ self.proj_in = nn.Linear(in_channels, inner_dim)
65
+ else:
66
+ self.proj_in = nn.Conv2d(
67
+ in_channels, inner_dim, kernel_size=1, stride=1, padding=0
68
+ )
69
+
70
+ # Define transformers blocks
71
+ self.transformer_blocks = nn.ModuleList(
72
+ [
73
+ TemporalBasicTransformerBlock(
74
+ inner_dim,
75
+ num_attention_heads,
76
+ attention_head_dim,
77
+ dropout=dropout,
78
+ cross_attention_dim=cross_attention_dim,
79
+ activation_fn=activation_fn,
80
+ num_embeds_ada_norm=num_embeds_ada_norm,
81
+ attention_bias=attention_bias,
82
+ only_cross_attention=only_cross_attention,
83
+ upcast_attention=upcast_attention,
84
+ unet_use_cross_frame_attention=unet_use_cross_frame_attention,
85
+ unet_use_temporal_attention=unet_use_temporal_attention,
86
+ name=f"{self.name}_{d}_TransformerBlock" if self.name else None,
87
+ )
88
+ for d in range(num_layers)
89
+ ]
90
+ )
91
+
92
+ # 4. Define output layers
93
+ if use_linear_projection:
94
+ self.proj_out = nn.Linear(in_channels, inner_dim)
95
+ else:
96
+ self.proj_out = nn.Conv2d(
97
+ inner_dim, in_channels, kernel_size=1, stride=1, padding=0
98
+ )
99
+
100
+ self.gradient_checkpointing = False
101
+
102
+ def _set_gradient_checkpointing(self, module, value=False):
103
+ if hasattr(module, "gradient_checkpointing"):
104
+ module.gradient_checkpointing = value
105
+
106
+ def forward(
107
+ self,
108
+ hidden_states,
109
+ encoder_hidden_states=None,
110
+ self_attention_additional_feats=None,
111
+ mode=None,
112
+ timestep=None,
113
+ return_dict: bool = True,
114
+ ):
115
+ # Input
116
+ assert (
117
+ hidden_states.dim() == 5
118
+ ), f"Expected hidden_states to have ndim=5, but got ndim={hidden_states.dim()}."
119
+ video_length = hidden_states.shape[2]
120
+ hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w")
121
+ if encoder_hidden_states.shape[0] != hidden_states.shape[0]:
122
+ encoder_hidden_states = repeat(
123
+ encoder_hidden_states, "b n c -> (b f) n c", f=video_length
124
+ )
125
+
126
+ batch, channel, height, weight = hidden_states.shape
127
+ residual = hidden_states
128
+
129
+ hidden_states = self.norm(hidden_states)
130
+ if not self.use_linear_projection:
131
+ hidden_states = self.proj_in(hidden_states)
132
+ inner_dim = hidden_states.shape[1]
133
+ hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(
134
+ batch, height * weight, inner_dim
135
+ )
136
+ else:
137
+ inner_dim = hidden_states.shape[1]
138
+ hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(
139
+ batch, height * weight, inner_dim
140
+ )
141
+ hidden_states = self.proj_in(hidden_states)
142
+
143
+ # Blocks
144
+ for i, block in enumerate(self.transformer_blocks):
145
+
146
+ if self.training and self.gradient_checkpointing:
147
+
148
+ def create_custom_forward(module, return_dict=None):
149
+ def custom_forward(*inputs):
150
+ if return_dict is not None:
151
+ return module(*inputs, return_dict=return_dict)
152
+ else:
153
+ return module(*inputs)
154
+
155
+ return custom_forward
156
+
157
+ # if hasattr(self.block, 'bank') and len(self.block.bank) > 0:
158
+ # hidden_states
159
+ hidden_states = torch.utils.checkpoint.checkpoint(
160
+ create_custom_forward(block),
161
+ hidden_states,
162
+ encoder_hidden_states=encoder_hidden_states,
163
+ timestep=timestep,
164
+ attention_mask=None,
165
+ video_length=video_length,
166
+ self_attention_additional_feats=self_attention_additional_feats,
167
+ mode=mode,
168
+ )
169
+ else:
170
+
171
+ hidden_states = block(
172
+ hidden_states,
173
+ encoder_hidden_states=encoder_hidden_states,
174
+ timestep=timestep,
175
+ self_attention_additional_feats=self_attention_additional_feats,
176
+ mode=mode,
177
+ video_length=video_length,
178
+ )
179
+
180
+ # Output
181
+ if not self.use_linear_projection:
182
+ hidden_states = (
183
+ hidden_states.reshape(batch, height, weight, inner_dim)
184
+ .permute(0, 3, 1, 2)
185
+ .contiguous()
186
+ )
187
+ hidden_states = self.proj_out(hidden_states)
188
+ else:
189
+ hidden_states = self.proj_out(hidden_states)
190
+ hidden_states = (
191
+ hidden_states.reshape(batch, height, weight, inner_dim)
192
+ .permute(0, 3, 1, 2)
193
+ .contiguous()
194
+ )
195
+
196
+ output = hidden_states + residual
197
+
198
+ output = rearrange(output, "(b f) c h w -> b c f h w", f=video_length)
199
+ if not return_dict:
200
+ return (output,)
201
+
202
+ return Transformer3DModelOutput(sample=output)
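Note: the 3D transformer above folds the frame axis into the batch dimension and repeats the text embedding per frame before running 2D attention. A small sketch of that reshaping, with illustrative sizes and an assumed 768-dim text embedding:

    import torch
    from einops import rearrange, repeat

    b, c, f, h, w = 1, 64, 8, 16, 16
    hidden_states = torch.randn(b, c, f, h, w)
    encoder_hidden_states = torch.randn(b, 77, 768)                       # assumed CLIP-sized embedding
    hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w")  # (8, 64, 16, 16)
    encoder_hidden_states = repeat(encoder_hidden_states, "b n c -> (b f) n c", f=f)  # (8, 77, 768)
    print(hidden_states.shape, encoder_hidden_states.shape)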
ref_encoder/unet_2d_blocks.py ADDED
@@ -0,0 +1,1074 @@
1
+ # Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/unet_2d_blocks.py
2
+ from typing import Any, Dict, Optional, Tuple, Union
3
+
4
+ import numpy as np
5
+ import torch
6
+ import torch.nn.functional as F
7
+ from diffusers.models.activations import get_activation
8
+ from diffusers.models.attention_processor import Attention
9
+ from diffusers.models.dual_transformer_2d import DualTransformer2DModel
10
+ from diffusers.models.resnet import Downsample2D, ResnetBlock2D, Upsample2D
11
+ from diffusers.utils import is_torch_version, logging
12
+ from diffusers.utils.torch_utils import apply_freeu
13
+ from torch import nn
14
+
15
+ from .transformer_2d import Transformer2DModel
16
+
17
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
18
+
19
+
20
+ def get_down_block(
21
+ down_block_type: str,
22
+ num_layers: int,
23
+ in_channels: int,
24
+ out_channels: int,
25
+ temb_channels: int,
26
+ add_downsample: bool,
27
+ resnet_eps: float,
28
+ resnet_act_fn: str,
29
+ transformer_layers_per_block: int = 1,
30
+ num_attention_heads: Optional[int] = None,
31
+ resnet_groups: Optional[int] = None,
32
+ cross_attention_dim: Optional[int] = None,
33
+ downsample_padding: Optional[int] = None,
34
+ dual_cross_attention: bool = False,
35
+ use_linear_projection: bool = False,
36
+ only_cross_attention: bool = False,
37
+ upcast_attention: bool = False,
38
+ resnet_time_scale_shift: str = "default",
39
+ attention_type: str = "default",
40
+ resnet_skip_time_act: bool = False,
41
+ resnet_out_scale_factor: float = 1.0,
42
+ cross_attention_norm: Optional[str] = None,
43
+ attention_head_dim: Optional[int] = None,
44
+ downsample_type: Optional[str] = None,
45
+ dropout: float = 0.0,
46
+ ):
47
+ # If attn head dim is not defined, we default it to the number of heads
48
+ if attention_head_dim is None:
49
+ logger.warning(
50
+ f"It is recommended to provide `attention_head_dim` when calling `get_down_block`. Defaulting `attention_head_dim` to {num_attention_heads}."
51
+ )
52
+ attention_head_dim = num_attention_heads
53
+
54
+ down_block_type = (
55
+ down_block_type[7:]
56
+ if down_block_type.startswith("UNetRes")
57
+ else down_block_type
58
+ )
59
+ if down_block_type == "DownBlock2D":
60
+ return DownBlock2D(
61
+ num_layers=num_layers,
62
+ in_channels=in_channels,
63
+ out_channels=out_channels,
64
+ temb_channels=temb_channels,
65
+ dropout=dropout,
66
+ add_downsample=add_downsample,
67
+ resnet_eps=resnet_eps,
68
+ resnet_act_fn=resnet_act_fn,
69
+ resnet_groups=resnet_groups,
70
+ downsample_padding=downsample_padding,
71
+ resnet_time_scale_shift=resnet_time_scale_shift,
72
+ )
73
+ elif down_block_type == "CrossAttnDownBlock2D":
74
+ if cross_attention_dim is None:
75
+ raise ValueError(
76
+ "cross_attention_dim must be specified for CrossAttnDownBlock2D"
77
+ )
78
+ return CrossAttnDownBlock2D(
79
+ num_layers=num_layers,
80
+ transformer_layers_per_block=transformer_layers_per_block,
81
+ in_channels=in_channels,
82
+ out_channels=out_channels,
83
+ temb_channels=temb_channels,
84
+ dropout=dropout,
85
+ add_downsample=add_downsample,
86
+ resnet_eps=resnet_eps,
87
+ resnet_act_fn=resnet_act_fn,
88
+ resnet_groups=resnet_groups,
89
+ downsample_padding=downsample_padding,
90
+ cross_attention_dim=cross_attention_dim,
91
+ num_attention_heads=num_attention_heads,
92
+ dual_cross_attention=dual_cross_attention,
93
+ use_linear_projection=use_linear_projection,
94
+ only_cross_attention=only_cross_attention,
95
+ upcast_attention=upcast_attention,
96
+ resnet_time_scale_shift=resnet_time_scale_shift,
97
+ attention_type=attention_type,
98
+ )
99
+ raise ValueError(f"{down_block_type} does not exist.")
100
+
101
+
102
+ def get_up_block(
103
+ up_block_type: str,
104
+ num_layers: int,
105
+ in_channels: int,
106
+ out_channels: int,
107
+ prev_output_channel: int,
108
+ temb_channels: int,
109
+ add_upsample: bool,
110
+ resnet_eps: float,
111
+ resnet_act_fn: str,
112
+ resolution_idx: Optional[int] = None,
113
+ transformer_layers_per_block: int = 1,
114
+ num_attention_heads: Optional[int] = None,
115
+ resnet_groups: Optional[int] = None,
116
+ cross_attention_dim: Optional[int] = None,
117
+ dual_cross_attention: bool = False,
118
+ use_linear_projection: bool = False,
119
+ only_cross_attention: bool = False,
120
+ upcast_attention: bool = False,
121
+ resnet_time_scale_shift: str = "default",
122
+ attention_type: str = "default",
123
+ resnet_skip_time_act: bool = False,
124
+ resnet_out_scale_factor: float = 1.0,
125
+ cross_attention_norm: Optional[str] = None,
126
+ attention_head_dim: Optional[int] = None,
127
+ upsample_type: Optional[str] = None,
128
+ dropout: float = 0.0,
129
+ ) -> nn.Module:
130
+ # If attn head dim is not defined, we default it to the number of heads
131
+ if attention_head_dim is None:
132
+ logger.warning(
133
+ f"It is recommended to provide `attention_head_dim` when calling `get_up_block`. Defaulting `attention_head_dim` to {num_attention_heads}."
134
+ )
135
+ attention_head_dim = num_attention_heads
136
+
137
+ up_block_type = (
138
+ up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type
139
+ )
140
+ if up_block_type == "UpBlock2D":
141
+ return UpBlock2D(
142
+ num_layers=num_layers,
143
+ in_channels=in_channels,
144
+ out_channels=out_channels,
145
+ prev_output_channel=prev_output_channel,
146
+ temb_channels=temb_channels,
147
+ resolution_idx=resolution_idx,
148
+ dropout=dropout,
149
+ add_upsample=add_upsample,
150
+ resnet_eps=resnet_eps,
151
+ resnet_act_fn=resnet_act_fn,
152
+ resnet_groups=resnet_groups,
153
+ resnet_time_scale_shift=resnet_time_scale_shift,
154
+ )
155
+ elif up_block_type == "CrossAttnUpBlock2D":
156
+ if cross_attention_dim is None:
157
+ raise ValueError(
158
+ "cross_attention_dim must be specified for CrossAttnUpBlock2D"
159
+ )
160
+ return CrossAttnUpBlock2D(
161
+ num_layers=num_layers,
162
+ transformer_layers_per_block=transformer_layers_per_block,
163
+ in_channels=in_channels,
164
+ out_channels=out_channels,
165
+ prev_output_channel=prev_output_channel,
166
+ temb_channels=temb_channels,
167
+ resolution_idx=resolution_idx,
168
+ dropout=dropout,
169
+ add_upsample=add_upsample,
170
+ resnet_eps=resnet_eps,
171
+ resnet_act_fn=resnet_act_fn,
172
+ resnet_groups=resnet_groups,
173
+ cross_attention_dim=cross_attention_dim,
174
+ num_attention_heads=num_attention_heads,
175
+ dual_cross_attention=dual_cross_attention,
176
+ use_linear_projection=use_linear_projection,
177
+ only_cross_attention=only_cross_attention,
178
+ upcast_attention=upcast_attention,
179
+ resnet_time_scale_shift=resnet_time_scale_shift,
180
+ attention_type=attention_type,
181
+ )
182
+
183
+ raise ValueError(f"{up_block_type} does not exist.")
184
+
185
+
186
+ class AutoencoderTinyBlock(nn.Module):
187
+ """
188
+ Tiny Autoencoder block used in [`AutoencoderTiny`]. It is a mini residual module consisting of plain conv + ReLU
189
+ blocks.
190
+
191
+ Args:
192
+ in_channels (`int`): The number of input channels.
193
+ out_channels (`int`): The number of output channels.
194
+ act_fn (`str`):
195
+ The activation function to use. Supported values are `"swish"`, `"mish"`, `"gelu"`, and `"relu"`.
196
+
197
+ Returns:
198
+ `torch.FloatTensor`: A tensor with the same shape as the input tensor, but with the number of channels equal to
199
+ `out_channels`.
200
+ """
201
+
202
+ def __init__(self, in_channels: int, out_channels: int, act_fn: str):
203
+ super().__init__()
204
+ act_fn = get_activation(act_fn)
205
+ self.conv = nn.Sequential(
206
+ nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
207
+ act_fn,
208
+ nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
209
+ act_fn,
210
+ nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
211
+ )
212
+ self.skip = (
213
+ nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False)
214
+ if in_channels != out_channels
215
+ else nn.Identity()
216
+ )
217
+ self.fuse = nn.ReLU()
218
+
219
+ def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
220
+ return self.fuse(self.conv(x) + self.skip(x))
221
+
222
+
223
+ class UNetMidBlock2D(nn.Module):
224
+ """
225
+ A 2D UNet mid-block [`UNetMidBlock2D`] with multiple residual blocks and optional attention blocks.
226
+
227
+ Args:
228
+ in_channels (`int`): The number of input channels.
229
+ temb_channels (`int`): The number of temporal embedding channels.
230
+ dropout (`float`, *optional*, defaults to 0.0): The dropout rate.
231
+ num_layers (`int`, *optional*, defaults to 1): The number of residual blocks.
232
+ resnet_eps (`float`, *optional*, defaults to 1e-6): The epsilon value for the resnet blocks.
233
+ resnet_time_scale_shift (`str`, *optional*, defaults to `default`):
234
+ The type of normalization to apply to the time embeddings. This can help to improve the performance of the
235
+ model on tasks with long-range temporal dependencies.
236
+ resnet_act_fn (`str`, *optional*, defaults to `swish`): The activation function for the resnet blocks.
237
+ resnet_groups (`int`, *optional*, defaults to 32):
238
+ The number of groups to use in the group normalization layers of the resnet blocks.
239
+ attn_groups (`Optional[int]`, *optional*, defaults to None): The number of groups for the attention blocks.
240
+ resnet_pre_norm (`bool`, *optional*, defaults to `True`):
241
+ Whether to use pre-normalization for the resnet blocks.
242
+ add_attention (`bool`, *optional*, defaults to `True`): Whether to add attention blocks.
243
+ attention_head_dim (`int`, *optional*, defaults to 1):
244
+ Dimension of a single attention head. The number of attention heads is determined based on this value and
245
+ the number of input channels.
246
+ output_scale_factor (`float`, *optional*, defaults to 1.0): The output scale factor.
247
+
248
+ Returns:
249
+ `torch.FloatTensor`: The output of the last residual block, which is a tensor of shape `(batch_size,
250
+ in_channels, height, width)`.
251
+
252
+ """
253
+
254
+ def __init__(
255
+ self,
256
+ in_channels: int,
257
+ temb_channels: int,
258
+ dropout: float = 0.0,
259
+ num_layers: int = 1,
260
+ resnet_eps: float = 1e-6,
261
+ resnet_time_scale_shift: str = "default", # default, spatial
262
+ resnet_act_fn: str = "swish",
263
+ resnet_groups: int = 32,
264
+ attn_groups: Optional[int] = None,
265
+ resnet_pre_norm: bool = True,
266
+ add_attention: bool = True,
267
+ attention_head_dim: int = 1,
268
+ output_scale_factor: float = 1.0,
269
+ ):
270
+ super().__init__()
271
+ resnet_groups = (
272
+ resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
273
+ )
274
+ self.add_attention = add_attention
275
+
276
+ if attn_groups is None:
277
+ attn_groups = (
278
+ resnet_groups if resnet_time_scale_shift == "default" else None
279
+ )
280
+
281
+ # there is always at least one resnet
282
+ resnets = [
283
+ ResnetBlock2D(
284
+ in_channels=in_channels,
285
+ out_channels=in_channels,
286
+ temb_channels=temb_channels,
287
+ eps=resnet_eps,
288
+ groups=resnet_groups,
289
+ dropout=dropout,
290
+ time_embedding_norm=resnet_time_scale_shift,
291
+ non_linearity=resnet_act_fn,
292
+ output_scale_factor=output_scale_factor,
293
+ pre_norm=resnet_pre_norm,
294
+ )
295
+ ]
296
+ attentions = []
297
+
298
+ if attention_head_dim is None:
299
+ logger.warning(
300
+ f"It is not recommended to pass `attention_head_dim=None`. Defaulting `attention_head_dim` to `in_channels`: {in_channels}."
301
+ )
302
+ attention_head_dim = in_channels
303
+
304
+ for _ in range(num_layers):
305
+ if self.add_attention:
306
+ attentions.append(
307
+ Attention(
308
+ in_channels,
309
+ heads=in_channels // attention_head_dim,
310
+ dim_head=attention_head_dim,
311
+ rescale_output_factor=output_scale_factor,
312
+ eps=resnet_eps,
313
+ norm_num_groups=attn_groups,
314
+ spatial_norm_dim=temb_channels
315
+ if resnet_time_scale_shift == "spatial"
316
+ else None,
317
+ residual_connection=True,
318
+ bias=True,
319
+ upcast_softmax=True,
320
+ _from_deprecated_attn_block=True,
321
+ )
322
+ )
323
+ else:
324
+ attentions.append(None)
325
+
326
+ resnets.append(
327
+ ResnetBlock2D(
328
+ in_channels=in_channels,
329
+ out_channels=in_channels,
330
+ temb_channels=temb_channels,
331
+ eps=resnet_eps,
332
+ groups=resnet_groups,
333
+ dropout=dropout,
334
+ time_embedding_norm=resnet_time_scale_shift,
335
+ non_linearity=resnet_act_fn,
336
+ output_scale_factor=output_scale_factor,
337
+ pre_norm=resnet_pre_norm,
338
+ )
339
+ )
340
+
341
+ self.attentions = nn.ModuleList(attentions)
342
+ self.resnets = nn.ModuleList(resnets)
343
+
344
+ def forward(
345
+ self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None
346
+ ) -> torch.FloatTensor:
347
+ hidden_states = self.resnets[0](hidden_states, temb)
348
+ for attn, resnet in zip(self.attentions, self.resnets[1:]):
349
+ if attn is not None:
350
+ hidden_states = attn(hidden_states, temb=temb)
351
+ hidden_states = resnet(hidden_states, temb)
352
+
353
+ return hidden_states
354
+
355
+
356
+ class UNetMidBlock2DCrossAttn(nn.Module):
357
+ def __init__(
358
+ self,
359
+ in_channels: int,
360
+ temb_channels: int,
361
+ dropout: float = 0.0,
362
+ num_layers: int = 1,
363
+ transformer_layers_per_block: Union[int, Tuple[int]] = 1,
364
+ resnet_eps: float = 1e-6,
365
+ resnet_time_scale_shift: str = "default",
366
+ resnet_act_fn: str = "swish",
367
+ resnet_groups: int = 32,
368
+ resnet_pre_norm: bool = True,
369
+ num_attention_heads: int = 1,
370
+ output_scale_factor: float = 1.0,
371
+ cross_attention_dim: int = 1280,
372
+ dual_cross_attention: bool = False,
373
+ use_linear_projection: bool = False,
374
+ upcast_attention: bool = False,
375
+ attention_type: str = "default",
376
+ ):
377
+ super().__init__()
378
+
379
+ self.has_cross_attention = True
380
+ self.num_attention_heads = num_attention_heads
381
+ resnet_groups = (
382
+ resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
383
+ )
384
+
385
+ # support for variable transformer layers per block
386
+ if isinstance(transformer_layers_per_block, int):
387
+ transformer_layers_per_block = [transformer_layers_per_block] * num_layers
388
+
389
+ # there is always at least one resnet
390
+ resnets = [
391
+ ResnetBlock2D(
392
+ in_channels=in_channels,
393
+ out_channels=in_channels,
394
+ temb_channels=temb_channels,
395
+ eps=resnet_eps,
396
+ groups=resnet_groups,
397
+ dropout=dropout,
398
+ time_embedding_norm=resnet_time_scale_shift,
399
+ non_linearity=resnet_act_fn,
400
+ output_scale_factor=output_scale_factor,
401
+ pre_norm=resnet_pre_norm,
402
+ )
403
+ ]
404
+ attentions = []
405
+
406
+ for i in range(num_layers):
407
+ if not dual_cross_attention:
408
+ attentions.append(
409
+ Transformer2DModel(
410
+ num_attention_heads,
411
+ in_channels // num_attention_heads,
412
+ in_channels=in_channels,
413
+ num_layers=transformer_layers_per_block[i],
414
+ cross_attention_dim=cross_attention_dim,
415
+ norm_num_groups=resnet_groups,
416
+ use_linear_projection=use_linear_projection,
417
+ upcast_attention=upcast_attention,
418
+ attention_type=attention_type,
419
+ )
420
+ )
421
+ else:
422
+ attentions.append(
423
+ DualTransformer2DModel(
424
+ num_attention_heads,
425
+ in_channels // num_attention_heads,
426
+ in_channels=in_channels,
427
+ num_layers=1,
428
+ cross_attention_dim=cross_attention_dim,
429
+ norm_num_groups=resnet_groups,
430
+ )
431
+ )
432
+ resnets.append(
433
+ ResnetBlock2D(
434
+ in_channels=in_channels,
435
+ out_channels=in_channels,
436
+ temb_channels=temb_channels,
437
+ eps=resnet_eps,
438
+ groups=resnet_groups,
439
+ dropout=dropout,
440
+ time_embedding_norm=resnet_time_scale_shift,
441
+ non_linearity=resnet_act_fn,
442
+ output_scale_factor=output_scale_factor,
443
+ pre_norm=resnet_pre_norm,
444
+ )
445
+ )
446
+
447
+ self.attentions = nn.ModuleList(attentions)
448
+ self.resnets = nn.ModuleList(resnets)
449
+
450
+ self.gradient_checkpointing = False
451
+
452
+ def forward(
453
+ self,
454
+ hidden_states: torch.FloatTensor,
455
+ temb: Optional[torch.FloatTensor] = None,
456
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
457
+ attention_mask: Optional[torch.FloatTensor] = None,
458
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
459
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
460
+ ) -> torch.FloatTensor:
461
+ lora_scale = (
462
+ cross_attention_kwargs.get("scale", 1.0)
463
+ if cross_attention_kwargs is not None
464
+ else 1.0
465
+ )
466
+ hidden_states = self.resnets[0](hidden_states, temb, scale=lora_scale)
467
+ for attn, resnet in zip(self.attentions, self.resnets[1:]):
468
+ if self.training and self.gradient_checkpointing:
469
+
470
+ def create_custom_forward(module, return_dict=None):
471
+ def custom_forward(*inputs):
472
+ if return_dict is not None:
473
+ return module(*inputs, return_dict=return_dict)
474
+ else:
475
+ return module(*inputs)
476
+
477
+ return custom_forward
478
+
479
+ ckpt_kwargs: Dict[str, Any] = (
480
+ {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
481
+ )
482
+ hidden_states, ref_feature = attn(
483
+ hidden_states,
484
+ encoder_hidden_states=encoder_hidden_states,
485
+ cross_attention_kwargs=cross_attention_kwargs,
486
+ attention_mask=attention_mask,
487
+ encoder_attention_mask=encoder_attention_mask,
488
+ return_dict=False,
489
+ )
490
+ hidden_states = torch.utils.checkpoint.checkpoint(
491
+ create_custom_forward(resnet),
492
+ hidden_states,
493
+ temb,
494
+ **ckpt_kwargs,
495
+ )
496
+ else:
497
+ hidden_states, ref_feature = attn(
498
+ hidden_states,
499
+ encoder_hidden_states=encoder_hidden_states,
500
+ cross_attention_kwargs=cross_attention_kwargs,
501
+ attention_mask=attention_mask,
502
+ encoder_attention_mask=encoder_attention_mask,
503
+ return_dict=False,
504
+ )
505
+ hidden_states = resnet(hidden_states, temb, scale=lora_scale)
506
+
507
+ return hidden_states
508
+
509
+
510
+ class CrossAttnDownBlock2D(nn.Module):
511
+ def __init__(
512
+ self,
513
+ in_channels: int,
514
+ out_channels: int,
515
+ temb_channels: int,
516
+ dropout: float = 0.0,
517
+ num_layers: int = 1,
518
+ transformer_layers_per_block: Union[int, Tuple[int]] = 1,
519
+ resnet_eps: float = 1e-6,
520
+ resnet_time_scale_shift: str = "default",
521
+ resnet_act_fn: str = "swish",
522
+ resnet_groups: int = 32,
523
+ resnet_pre_norm: bool = True,
524
+ num_attention_heads: int = 1,
525
+ cross_attention_dim: int = 1280,
526
+ output_scale_factor: float = 1.0,
527
+ downsample_padding: int = 1,
528
+ add_downsample: bool = True,
529
+ dual_cross_attention: bool = False,
530
+ use_linear_projection: bool = False,
531
+ only_cross_attention: bool = False,
532
+ upcast_attention: bool = False,
533
+ attention_type: str = "default",
534
+ ):
535
+ super().__init__()
536
+ resnets = []
537
+ attentions = []
538
+
539
+ self.has_cross_attention = True
540
+ self.num_attention_heads = num_attention_heads
541
+ if isinstance(transformer_layers_per_block, int):
542
+ transformer_layers_per_block = [transformer_layers_per_block] * num_layers
543
+
544
+ for i in range(num_layers):
545
+ in_channels = in_channels if i == 0 else out_channels
546
+ resnets.append(
547
+ ResnetBlock2D(
548
+ in_channels=in_channels,
549
+ out_channels=out_channels,
550
+ temb_channels=temb_channels,
551
+ eps=resnet_eps,
552
+ groups=resnet_groups,
553
+ dropout=dropout,
554
+ time_embedding_norm=resnet_time_scale_shift,
555
+ non_linearity=resnet_act_fn,
556
+ output_scale_factor=output_scale_factor,
557
+ pre_norm=resnet_pre_norm,
558
+ )
559
+ )
560
+ if not dual_cross_attention:
561
+ attentions.append(
562
+ Transformer2DModel(
563
+ num_attention_heads,
564
+ out_channels // num_attention_heads,
565
+ in_channels=out_channels,
566
+ num_layers=transformer_layers_per_block[i],
567
+ cross_attention_dim=cross_attention_dim,
568
+ norm_num_groups=resnet_groups,
569
+ use_linear_projection=use_linear_projection,
570
+ only_cross_attention=only_cross_attention,
571
+ upcast_attention=upcast_attention,
572
+ attention_type=attention_type,
573
+ )
574
+ )
575
+ else:
576
+ attentions.append(
577
+ DualTransformer2DModel(
578
+ num_attention_heads,
579
+ out_channels // num_attention_heads,
580
+ in_channels=out_channels,
581
+ num_layers=1,
582
+ cross_attention_dim=cross_attention_dim,
583
+ norm_num_groups=resnet_groups,
584
+ )
585
+ )
586
+ self.attentions = nn.ModuleList(attentions)
587
+ self.resnets = nn.ModuleList(resnets)
588
+
589
+ if add_downsample:
590
+ self.downsamplers = nn.ModuleList(
591
+ [
592
+ Downsample2D(
593
+ out_channels,
594
+ use_conv=True,
595
+ out_channels=out_channels,
596
+ padding=downsample_padding,
597
+ name="op",
598
+ )
599
+ ]
600
+ )
601
+ else:
602
+ self.downsamplers = None
603
+
604
+ self.gradient_checkpointing = False
605
+
606
+ def forward(
607
+ self,
608
+ hidden_states: torch.FloatTensor,
609
+ temb: Optional[torch.FloatTensor] = None,
610
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
611
+ attention_mask: Optional[torch.FloatTensor] = None,
612
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
613
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
614
+ additional_residuals: Optional[torch.FloatTensor] = None,
615
+ ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:
616
+ output_states = ()
617
+
618
+ lora_scale = (
619
+ cross_attention_kwargs.get("scale", 1.0)
620
+ if cross_attention_kwargs is not None
621
+ else 1.0
622
+ )
623
+
624
+ blocks = list(zip(self.resnets, self.attentions))
625
+
626
+ for i, (resnet, attn) in enumerate(blocks):
627
+ if self.training and self.gradient_checkpointing:
628
+
629
+ def create_custom_forward(module, return_dict=None):
630
+ def custom_forward(*inputs):
631
+ if return_dict is not None:
632
+ return module(*inputs, return_dict=return_dict)
633
+ else:
634
+ return module(*inputs)
635
+
636
+ return custom_forward
637
+
638
+ ckpt_kwargs: Dict[str, Any] = (
639
+ {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
640
+ )
641
+ hidden_states = torch.utils.checkpoint.checkpoint(
642
+ create_custom_forward(resnet),
643
+ hidden_states,
644
+ temb,
645
+ **ckpt_kwargs,
646
+ )
647
+ hidden_states, ref_feature = attn(
648
+ hidden_states,
649
+ encoder_hidden_states=encoder_hidden_states,
650
+ cross_attention_kwargs=cross_attention_kwargs,
651
+ attention_mask=attention_mask,
652
+ encoder_attention_mask=encoder_attention_mask,
653
+ return_dict=False,
654
+ )
655
+ else:
656
+ hidden_states = resnet(hidden_states, temb, scale=lora_scale)
657
+ hidden_states, ref_feature = attn(
658
+ hidden_states,
659
+ encoder_hidden_states=encoder_hidden_states,
660
+ cross_attention_kwargs=cross_attention_kwargs,
661
+ attention_mask=attention_mask,
662
+ encoder_attention_mask=encoder_attention_mask,
663
+ return_dict=False,
664
+ )
665
+
666
+ # apply additional residuals to the output of the last pair of resnet and attention blocks
667
+ if i == len(blocks) - 1 and additional_residuals is not None:
668
+ hidden_states = hidden_states + additional_residuals
669
+
670
+ output_states = output_states + (hidden_states,)
671
+
672
+ if self.downsamplers is not None:
673
+ for downsampler in self.downsamplers:
674
+ hidden_states = downsampler(hidden_states, scale=lora_scale)
675
+
676
+ output_states = output_states + (hidden_states,)
677
+
678
+ return hidden_states, output_states
679
+
680
+
681
+ class DownBlock2D(nn.Module):
682
+ def __init__(
683
+ self,
684
+ in_channels: int,
685
+ out_channels: int,
686
+ temb_channels: int,
687
+ dropout: float = 0.0,
688
+ num_layers: int = 1,
689
+ resnet_eps: float = 1e-6,
690
+ resnet_time_scale_shift: str = "default",
691
+ resnet_act_fn: str = "swish",
692
+ resnet_groups: int = 32,
693
+ resnet_pre_norm: bool = True,
694
+ output_scale_factor: float = 1.0,
695
+ add_downsample: bool = True,
696
+ downsample_padding: int = 1,
697
+ ):
698
+ super().__init__()
699
+ resnets = []
700
+
701
+ for i in range(num_layers):
702
+ in_channels = in_channels if i == 0 else out_channels
703
+ resnets.append(
704
+ ResnetBlock2D(
705
+ in_channels=in_channels,
706
+ out_channels=out_channels,
707
+ temb_channels=temb_channels,
708
+ eps=resnet_eps,
709
+ groups=resnet_groups,
710
+ dropout=dropout,
711
+ time_embedding_norm=resnet_time_scale_shift,
712
+ non_linearity=resnet_act_fn,
713
+ output_scale_factor=output_scale_factor,
714
+ pre_norm=resnet_pre_norm,
715
+ )
716
+ )
717
+
718
+ self.resnets = nn.ModuleList(resnets)
719
+
720
+ if add_downsample:
721
+ self.downsamplers = nn.ModuleList(
722
+ [
723
+ Downsample2D(
724
+ out_channels,
725
+ use_conv=True,
726
+ out_channels=out_channels,
727
+ padding=downsample_padding,
728
+ name="op",
729
+ )
730
+ ]
731
+ )
732
+ else:
733
+ self.downsamplers = None
734
+
735
+ self.gradient_checkpointing = False
736
+
737
+ def forward(
738
+ self,
739
+ hidden_states: torch.FloatTensor,
740
+ temb: Optional[torch.FloatTensor] = None,
741
+ scale: float = 1.0,
742
+ ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:
743
+ output_states = ()
744
+
745
+ for resnet in self.resnets:
746
+ if self.training and self.gradient_checkpointing:
747
+
748
+ def create_custom_forward(module):
749
+ def custom_forward(*inputs):
750
+ return module(*inputs)
751
+
752
+ return custom_forward
753
+
754
+ if is_torch_version(">=", "1.11.0"):
755
+ hidden_states = torch.utils.checkpoint.checkpoint(
756
+ create_custom_forward(resnet),
757
+ hidden_states,
758
+ temb,
759
+ use_reentrant=False,
760
+ )
761
+ else:
762
+ hidden_states = torch.utils.checkpoint.checkpoint(
763
+ create_custom_forward(resnet), hidden_states, temb
764
+ )
765
+ else:
766
+ hidden_states = resnet(hidden_states, temb, scale=scale)
767
+
768
+ output_states = output_states + (hidden_states,)
769
+
770
+ if self.downsamplers is not None:
771
+ for downsampler in self.downsamplers:
772
+ hidden_states = downsampler(hidden_states, scale=scale)
773
+
774
+ output_states = output_states + (hidden_states,)
775
+
776
+ return hidden_states, output_states
777
+
778
+
779
+ class CrossAttnUpBlock2D(nn.Module):
780
+ def __init__(
781
+ self,
782
+ in_channels: int,
783
+ out_channels: int,
784
+ prev_output_channel: int,
785
+ temb_channels: int,
786
+ resolution_idx: Optional[int] = None,
787
+ dropout: float = 0.0,
788
+ num_layers: int = 1,
789
+ transformer_layers_per_block: Union[int, Tuple[int]] = 1,
790
+ resnet_eps: float = 1e-6,
791
+ resnet_time_scale_shift: str = "default",
792
+ resnet_act_fn: str = "swish",
793
+ resnet_groups: int = 32,
794
+ resnet_pre_norm: bool = True,
795
+ num_attention_heads: int = 1,
796
+ cross_attention_dim: int = 1280,
797
+ output_scale_factor: float = 1.0,
798
+ add_upsample: bool = True,
799
+ dual_cross_attention: bool = False,
800
+ use_linear_projection: bool = False,
801
+ only_cross_attention: bool = False,
802
+ upcast_attention: bool = False,
803
+ attention_type: str = "default",
804
+ ):
805
+ super().__init__()
806
+ resnets = []
807
+ attentions = []
808
+
809
+ self.has_cross_attention = True
810
+ self.num_attention_heads = num_attention_heads
811
+
812
+ if isinstance(transformer_layers_per_block, int):
813
+ transformer_layers_per_block = [transformer_layers_per_block] * num_layers
814
+
815
+ for i in range(num_layers):
816
+ res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
817
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
818
+
819
+ resnets.append(
820
+ ResnetBlock2D(
821
+ in_channels=resnet_in_channels + res_skip_channels,
822
+ out_channels=out_channels,
823
+ temb_channels=temb_channels,
824
+ eps=resnet_eps,
825
+ groups=resnet_groups,
826
+ dropout=dropout,
827
+ time_embedding_norm=resnet_time_scale_shift,
828
+ non_linearity=resnet_act_fn,
829
+ output_scale_factor=output_scale_factor,
830
+ pre_norm=resnet_pre_norm,
831
+ )
832
+ )
833
+ if not dual_cross_attention:
834
+ attentions.append(
835
+ Transformer2DModel(
836
+ num_attention_heads,
837
+ out_channels // num_attention_heads,
838
+ in_channels=out_channels,
839
+ num_layers=transformer_layers_per_block[i],
840
+ cross_attention_dim=cross_attention_dim,
841
+ norm_num_groups=resnet_groups,
842
+ use_linear_projection=use_linear_projection,
843
+ only_cross_attention=only_cross_attention,
844
+ upcast_attention=upcast_attention,
845
+ attention_type=attention_type,
846
+ )
847
+ )
848
+ else:
849
+ attentions.append(
850
+ DualTransformer2DModel(
851
+ num_attention_heads,
852
+ out_channels // num_attention_heads,
853
+ in_channels=out_channels,
854
+ num_layers=1,
855
+ cross_attention_dim=cross_attention_dim,
856
+ norm_num_groups=resnet_groups,
857
+ )
858
+ )
859
+ self.attentions = nn.ModuleList(attentions)
860
+ self.resnets = nn.ModuleList(resnets)
861
+
862
+ if add_upsample:
863
+ self.upsamplers = nn.ModuleList(
864
+ [Upsample2D(out_channels, use_conv=True, out_channels=out_channels)]
865
+ )
866
+ else:
867
+ self.upsamplers = None
868
+
869
+ self.gradient_checkpointing = False
870
+ self.resolution_idx = resolution_idx
871
+
872
+ def forward(
873
+ self,
874
+ hidden_states: torch.FloatTensor,
875
+ res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
876
+ temb: Optional[torch.FloatTensor] = None,
877
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
878
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
879
+ upsample_size: Optional[int] = None,
880
+ attention_mask: Optional[torch.FloatTensor] = None,
881
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
882
+ ) -> torch.FloatTensor:
883
+ lora_scale = (
884
+ cross_attention_kwargs.get("scale", 1.0)
885
+ if cross_attention_kwargs is not None
886
+ else 1.0
887
+ )
888
+ is_freeu_enabled = (
889
+ getattr(self, "s1", None)
890
+ and getattr(self, "s2", None)
891
+ and getattr(self, "b1", None)
892
+ and getattr(self, "b2", None)
893
+ )
894
+
895
+ for resnet, attn in zip(self.resnets, self.attentions):
896
+ # pop res hidden states
897
+ res_hidden_states = res_hidden_states_tuple[-1]
898
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
899
+
900
+ # FreeU: Only operate on the first two stages
901
+ if is_freeu_enabled:
902
+ hidden_states, res_hidden_states = apply_freeu(
903
+ self.resolution_idx,
904
+ hidden_states,
905
+ res_hidden_states,
906
+ s1=self.s1,
907
+ s2=self.s2,
908
+ b1=self.b1,
909
+ b2=self.b2,
910
+ )
911
+
912
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
913
+
914
+ if self.training and self.gradient_checkpointing:
915
+
916
+ def create_custom_forward(module, return_dict=None):
917
+ def custom_forward(*inputs):
918
+ if return_dict is not None:
919
+ return module(*inputs, return_dict=return_dict)
920
+ else:
921
+ return module(*inputs)
922
+
923
+ return custom_forward
924
+
925
+ ckpt_kwargs: Dict[str, Any] = (
926
+ {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
927
+ )
928
+ hidden_states = torch.utils.checkpoint.checkpoint(
929
+ create_custom_forward(resnet),
930
+ hidden_states,
931
+ temb,
932
+ **ckpt_kwargs,
933
+ )
934
+ hidden_states, ref_feature = attn(
935
+ hidden_states,
936
+ encoder_hidden_states=encoder_hidden_states,
937
+ cross_attention_kwargs=cross_attention_kwargs,
938
+ attention_mask=attention_mask,
939
+ encoder_attention_mask=encoder_attention_mask,
940
+ return_dict=False,
941
+ )
942
+ else:
943
+ hidden_states = resnet(hidden_states, temb, scale=lora_scale)
944
+ hidden_states, ref_feature = attn(
945
+ hidden_states,
946
+ encoder_hidden_states=encoder_hidden_states,
947
+ cross_attention_kwargs=cross_attention_kwargs,
948
+ attention_mask=attention_mask,
949
+ encoder_attention_mask=encoder_attention_mask,
950
+ return_dict=False,
951
+ )
952
+
953
+ if self.upsamplers is not None:
954
+ for upsampler in self.upsamplers:
955
+ hidden_states = upsampler(
956
+ hidden_states, upsample_size, scale=lora_scale
957
+ )
958
+
959
+ return hidden_states
960
+
961
+
962
+ class UpBlock2D(nn.Module):
963
+ def __init__(
964
+ self,
965
+ in_channels: int,
966
+ prev_output_channel: int,
967
+ out_channels: int,
968
+ temb_channels: int,
969
+ resolution_idx: Optional[int] = None,
970
+ dropout: float = 0.0,
971
+ num_layers: int = 1,
972
+ resnet_eps: float = 1e-6,
973
+ resnet_time_scale_shift: str = "default",
974
+ resnet_act_fn: str = "swish",
975
+ resnet_groups: int = 32,
976
+ resnet_pre_norm: bool = True,
977
+ output_scale_factor: float = 1.0,
978
+ add_upsample: bool = True,
979
+ ):
980
+ super().__init__()
981
+ resnets = []
982
+
983
+ for i in range(num_layers):
984
+ res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
985
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
986
+
987
+ resnets.append(
988
+ ResnetBlock2D(
989
+ in_channels=resnet_in_channels + res_skip_channels,
990
+ out_channels=out_channels,
991
+ temb_channels=temb_channels,
992
+ eps=resnet_eps,
993
+ groups=resnet_groups,
994
+ dropout=dropout,
995
+ time_embedding_norm=resnet_time_scale_shift,
996
+ non_linearity=resnet_act_fn,
997
+ output_scale_factor=output_scale_factor,
998
+ pre_norm=resnet_pre_norm,
999
+ )
1000
+ )
1001
+
1002
+ self.resnets = nn.ModuleList(resnets)
1003
+
1004
+ if add_upsample:
1005
+ self.upsamplers = nn.ModuleList(
1006
+ [Upsample2D(out_channels, use_conv=True, out_channels=out_channels)]
1007
+ )
1008
+ else:
1009
+ self.upsamplers = None
1010
+
1011
+ self.gradient_checkpointing = False
1012
+ self.resolution_idx = resolution_idx
1013
+
1014
+ def forward(
1015
+ self,
1016
+ hidden_states: torch.FloatTensor,
1017
+ res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
1018
+ temb: Optional[torch.FloatTensor] = None,
1019
+ upsample_size: Optional[int] = None,
1020
+ scale: float = 1.0,
1021
+ ) -> torch.FloatTensor:
1022
+ is_freeu_enabled = (
1023
+ getattr(self, "s1", None)
1024
+ and getattr(self, "s2", None)
1025
+ and getattr(self, "b1", None)
1026
+ and getattr(self, "b2", None)
1027
+ )
1028
+
1029
+ for resnet in self.resnets:
1030
+ # pop res hidden states
1031
+ res_hidden_states = res_hidden_states_tuple[-1]
1032
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
1033
+
1034
+ # FreeU: Only operate on the first two stages
1035
+ if is_freeu_enabled:
1036
+ hidden_states, res_hidden_states = apply_freeu(
1037
+ self.resolution_idx,
1038
+ hidden_states,
1039
+ res_hidden_states,
1040
+ s1=self.s1,
1041
+ s2=self.s2,
1042
+ b1=self.b1,
1043
+ b2=self.b2,
1044
+ )
1045
+
1046
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
1047
+
1048
+ if self.training and self.gradient_checkpointing:
1049
+
1050
+ def create_custom_forward(module):
1051
+ def custom_forward(*inputs):
1052
+ return module(*inputs)
1053
+
1054
+ return custom_forward
1055
+
1056
+ if is_torch_version(">=", "1.11.0"):
1057
+ hidden_states = torch.utils.checkpoint.checkpoint(
1058
+ create_custom_forward(resnet),
1059
+ hidden_states,
1060
+ temb,
1061
+ use_reentrant=False,
1062
+ )
1063
+ else:
1064
+ hidden_states = torch.utils.checkpoint.checkpoint(
1065
+ create_custom_forward(resnet), hidden_states, temb
1066
+ )
1067
+ else:
1068
+ hidden_states = resnet(hidden_states, temb, scale=scale)
1069
+
1070
+ if self.upsamplers is not None:
1071
+ for upsampler in self.upsamplers:
1072
+ hidden_states = upsampler(hidden_states, upsample_size, scale=scale)
1073
+
1074
+ return hidden_states
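Note on the block factories above: `get_down_block` / `get_up_block` dispatch on the block-type string and forward the resnet/attention hyperparameters to the chosen block class. A hedged construction sketch (the values are illustrative assumptions, not the repo's actual config; it assumes this module and its diffusers dependencies are importable):

from ref_encoder.unet_2d_blocks import get_down_block

down_block = get_down_block(
    "CrossAttnDownBlock2D",
    num_layers=2,
    in_channels=320,
    out_channels=640,
    temb_channels=1280,
    add_downsample=True,
    resnet_eps=1e-5,
    resnet_act_fn="silu",
    resnet_groups=32,
    cross_attention_dim=768,   # text/CLIP hidden size (illustrative)
    num_attention_heads=8,
    attention_head_dim=80,     # passed explicitly to avoid the default warning
    downsample_padding=1,
    use_linear_projection=False,
)
# down_block(hidden_states, temb, encoder_hidden_states=...) returns
# (hidden_states, output_states), the latter feeding the UNet skip connections.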
ref_encoder/unet_2d_condition.py ADDED
@@ -0,0 +1,1308 @@
1
+ # Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/unet_2d_condition.py
2
+ from dataclasses import dataclass
3
+ from typing import Any, Dict, List, Optional, Tuple, Union
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.utils.checkpoint
8
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
9
+ from diffusers.loaders import UNet2DConditionLoadersMixin
10
+ from diffusers.models.activations import get_activation
11
+ from diffusers.models.attention_processor import (
12
+ ADDED_KV_ATTENTION_PROCESSORS,
13
+ CROSS_ATTENTION_PROCESSORS,
14
+ AttentionProcessor,
15
+ AttnAddedKVProcessor,
16
+ AttnProcessor,
17
+ )
18
+ from diffusers.models.embeddings import (
19
+ GaussianFourierProjection,
20
+ ImageHintTimeEmbedding,
21
+ ImageProjection,
22
+ ImageTimeEmbedding,
23
+ PositionNet,
24
+ TextImageProjection,
25
+ TextImageTimeEmbedding,
26
+ TextTimeEmbedding,
27
+ TimestepEmbedding,
28
+ Timesteps,
29
+ )
30
+ from diffusers.models.modeling_utils import ModelMixin
31
+ from diffusers.utils import (
32
+ USE_PEFT_BACKEND,
33
+ BaseOutput,
34
+ deprecate,
35
+ logging,
36
+ scale_lora_layers,
37
+ unscale_lora_layers,
38
+ )
39
+
40
+ from .unet_2d_blocks import (
41
+ UNetMidBlock2D,
42
+ UNetMidBlock2DCrossAttn,
43
+ get_down_block,
44
+ get_up_block,
45
+ )
46
+
47
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
48
+
49
+
50
+ @dataclass
51
+ class UNet2DConditionOutput(BaseOutput):
52
+ """
53
+ The output of [`UNet2DConditionModel`].
54
+
55
+ Args:
56
+ sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
57
+ The hidden states output conditioned on `encoder_hidden_states` input. Output of last layer of model.
58
+ """
59
+
60
+ sample: torch.FloatTensor = None
61
+ ref_features: Tuple[torch.FloatTensor] = None
62
+
63
+
64
+ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
65
+ r"""
66
+ A conditional 2D UNet model that takes a noisy sample, conditional state, and a timestep and returns a sample
67
+ shaped output.
68
+
69
+ This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented
70
+ for all models (such as downloading or saving).
71
+
72
+ Parameters:
73
+ sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):
74
+ Height and width of input/output sample.
75
+ in_channels (`int`, *optional*, defaults to 4): Number of channels in the input sample.
76
+ out_channels (`int`, *optional*, defaults to 4): Number of channels in the output.
77
+ center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample.
78
+ flip_sin_to_cos (`bool`, *optional*, defaults to `False`):
79
+ Whether to flip the sin to cos in the time embedding.
80
+ freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding.
81
+ down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
82
+ The tuple of downsample blocks to use.
83
+ mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2DCrossAttn"`):
84
+ Block type for middle of UNet, it can be one of `UNetMidBlock2DCrossAttn`, `UNetMidBlock2D`, or
85
+ `UNetMidBlock2DSimpleCrossAttn`. If `None`, the mid block layer is skipped.
86
+ up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D")`):
87
+ The tuple of upsample blocks to use.
88
+ only_cross_attention(`bool` or `Tuple[bool]`, *optional*, default to `False`):
89
+ Whether to include self-attention in the basic transformer blocks, see
90
+ [`~models.attention.BasicTransformerBlock`].
91
+ block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
92
+ The tuple of output channels for each block.
93
+ layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block.
94
+ downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution.
95
+ mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block.
96
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
97
+ act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
98
+ norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization.
99
+ If `None`, normalization and activation layers is skipped in post-processing.
100
+ norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization.
101
+ cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1280):
102
+ The dimension of the cross attention features.
103
+ transformer_layers_per_block (`int`, `Tuple[int]`, or `Tuple[Tuple]` , *optional*, defaults to 1):
104
+ The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for
105
+ [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`],
106
+ [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
107
+ reverse_transformer_layers_per_block : (`Tuple[Tuple]`, *optional*, defaults to None):
108
+ The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`], in the upsampling
109
+ blocks of the U-Net. Only relevant if `transformer_layers_per_block` is of type `Tuple[Tuple]` and for
110
+ [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`],
111
+ [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
112
+ encoder_hid_dim (`int`, *optional*, defaults to None):
113
+ If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim`
114
+ dimension to `cross_attention_dim`.
115
+ encoder_hid_dim_type (`str`, *optional*, defaults to `None`):
116
+ If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text
117
+ embeddings of dimension `cross_attention` according to `encoder_hid_dim_type`.
118
+ attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads.
119
+ num_attention_heads (`int`, *optional*):
120
+ The number of attention heads. If not defined, defaults to `attention_head_dim`
121
+ resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config
122
+ for ResNet blocks (see [`~models.resnet.ResnetBlock2D`]). Choose from `default` or `scale_shift`.
123
+ class_embed_type (`str`, *optional*, defaults to `None`):
124
+ The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`,
125
+ `"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`.
126
+ addition_embed_type (`str`, *optional*, defaults to `None`):
127
+ Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or
128
+ "text". "text" will use the `TextTimeEmbedding` layer.
129
+ addition_time_embed_dim: (`int`, *optional*, defaults to `None`):
130
+ Dimension for the timestep embeddings.
131
+ num_class_embeds (`int`, *optional*, defaults to `None`):
132
+ Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
133
+ class conditioning with `class_embed_type` equal to `None`.
134
+ time_embedding_type (`str`, *optional*, defaults to `positional`):
135
+ The type of position embedding to use for timesteps. Choose from `positional` or `fourier`.
136
+ time_embedding_dim (`int`, *optional*, defaults to `None`):
137
+ An optional override for the dimension of the projected time embedding.
138
+ time_embedding_act_fn (`str`, *optional*, defaults to `None`):
139
+ Optional activation function to use only once on the time embeddings before they are passed to the rest of
140
+ the UNet. Choose from `silu`, `mish`, `gelu`, and `swish`.
141
+ timestep_post_act (`str`, *optional*, defaults to `None`):
142
+ The second activation function to use in timestep embedding. Choose from `silu`, `mish` and `gelu`.
143
+ time_cond_proj_dim (`int`, *optional*, defaults to `None`):
144
+ The dimension of `cond_proj` layer in the timestep embedding.
145
+ conv_in_kernel (`int`, *optional*, defaults to `3`): The kernel size of the `conv_in` layer.
146
+ conv_out_kernel (`int`, *optional*, defaults to `3`): The kernel size of the `conv_out` layer.
147
+ projection_class_embeddings_input_dim (`int`, *optional*): The dimension of the `class_labels` input when
148
+ `class_embed_type="projection"`. Required when `class_embed_type="projection"`.
149
+ class_embeddings_concat (`bool`, *optional*, defaults to `False`): Whether to concatenate the time
150
+ embeddings with the class embeddings.
151
+ mid_block_only_cross_attention (`bool`, *optional*, defaults to `None`):
152
+ Whether to use cross attention with the mid block when using the `UNetMidBlock2DSimpleCrossAttn`. If
153
+ `only_cross_attention` is given as a single boolean and `mid_block_only_cross_attention` is `None`, the
154
+ `only_cross_attention` value is used as the value for `mid_block_only_cross_attention`. Default to `False`
155
+ otherwise.
156
+ """
157
+
158
+ _supports_gradient_checkpointing = True
159
+
160
+ @register_to_config
161
+ def __init__(
162
+ self,
163
+ sample_size: Optional[int] = None,
164
+ in_channels: int = 4,
165
+ out_channels: int = 4,
166
+ center_input_sample: bool = False,
167
+ flip_sin_to_cos: bool = True,
168
+ freq_shift: int = 0,
169
+ down_block_types: Tuple[str] = (
170
+ "CrossAttnDownBlock2D",
171
+ "CrossAttnDownBlock2D",
172
+ "CrossAttnDownBlock2D",
173
+ "DownBlock2D",
174
+ ),
175
+ mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn",
176
+ up_block_types: Tuple[str] = (
177
+ "UpBlock2D",
178
+ "CrossAttnUpBlock2D",
179
+ "CrossAttnUpBlock2D",
180
+ "CrossAttnUpBlock2D",
181
+ ),
182
+ only_cross_attention: Union[bool, Tuple[bool]] = False,
183
+ block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
184
+ layers_per_block: Union[int, Tuple[int]] = 2,
185
+ downsample_padding: int = 1,
186
+ mid_block_scale_factor: float = 1,
187
+ dropout: float = 0.0,
188
+ act_fn: str = "silu",
189
+ norm_num_groups: Optional[int] = 32,
190
+ norm_eps: float = 1e-5,
191
+ cross_attention_dim: Union[int, Tuple[int]] = 1280,
192
+ transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple]] = 1,
193
+ reverse_transformer_layers_per_block: Optional[Tuple[Tuple[int]]] = None,
194
+ encoder_hid_dim: Optional[int] = None,
195
+ encoder_hid_dim_type: Optional[str] = None,
196
+ attention_head_dim: Union[int, Tuple[int]] = 8,
197
+ num_attention_heads: Optional[Union[int, Tuple[int]]] = None,
198
+ dual_cross_attention: bool = False,
199
+ use_linear_projection: bool = False,
200
+ class_embed_type: Optional[str] = None,
201
+ addition_embed_type: Optional[str] = None,
202
+ addition_time_embed_dim: Optional[int] = None,
203
+ num_class_embeds: Optional[int] = None,
204
+ upcast_attention: bool = False,
205
+ resnet_time_scale_shift: str = "default",
206
+ resnet_skip_time_act: bool = False,
207
+ resnet_out_scale_factor: int = 1.0,
208
+ time_embedding_type: str = "positional",
209
+ time_embedding_dim: Optional[int] = None,
210
+ time_embedding_act_fn: Optional[str] = None,
211
+ timestep_post_act: Optional[str] = None,
212
+ time_cond_proj_dim: Optional[int] = None,
213
+ conv_in_kernel: int = 3,
214
+ conv_out_kernel: int = 3,
215
+ projection_class_embeddings_input_dim: Optional[int] = None,
216
+ attention_type: str = "default",
217
+ class_embeddings_concat: bool = False,
218
+ mid_block_only_cross_attention: Optional[bool] = None,
219
+ cross_attention_norm: Optional[str] = None,
220
+ addition_embed_type_num_heads=64,
221
+ ):
222
+ super().__init__()
223
+
224
+ self.sample_size = sample_size
225
+
226
+ if num_attention_heads is not None:
227
+ raise ValueError(
228
+ "At the moment it is not possible to define the number of attention heads via `num_attention_heads` because of a naming issue as described in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131. Passing `num_attention_heads` will only be supported in diffusers v0.19."
229
+ )
230
+
231
+ # If `num_attention_heads` is not defined (which is the case for most models)
232
+ # it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
233
+ # The reason for this behavior is to correct for incorrectly named variables that were introduced
234
+ # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
235
+ # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
236
+ # which is why we correct for the naming here.
237
+ num_attention_heads = num_attention_heads or attention_head_dim
238
+
239
+ # Check inputs
240
+ if len(down_block_types) != len(up_block_types):
241
+ raise ValueError(
242
+ f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}."
243
+ )
244
+
245
+ if len(block_out_channels) != len(down_block_types):
246
+ raise ValueError(
247
+ f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
248
+ )
249
+
250
+ if not isinstance(only_cross_attention, bool) and len(
251
+ only_cross_attention
252
+ ) != len(down_block_types):
253
+ raise ValueError(
254
+ f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
255
+ )
256
+
257
+ if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(
258
+ down_block_types
259
+ ):
260
+ raise ValueError(
261
+ f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
262
+ )
263
+
264
+ if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(
265
+ down_block_types
266
+ ):
267
+ raise ValueError(
268
+ f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}."
269
+ )
270
+
271
+ if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(
272
+ down_block_types
273
+ ):
274
+ raise ValueError(
275
+ f"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}."
276
+ )
277
+
278
+ if not isinstance(layers_per_block, int) and len(layers_per_block) != len(
279
+ down_block_types
280
+ ):
281
+ raise ValueError(
282
+ f"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}."
283
+ )
284
+ if (
285
+ isinstance(transformer_layers_per_block, list)
286
+ and reverse_transformer_layers_per_block is None
287
+ ):
288
+ for layer_number_per_block in transformer_layers_per_block:
289
+ if isinstance(layer_number_per_block, list):
290
+ raise ValueError(
291
+ "Must provide 'reverse_transformer_layers_per_block` if using asymmetrical UNet."
292
+ )
293
+
294
+ # input
295
+ conv_in_padding = (conv_in_kernel - 1) // 2
296
+ self.conv_in = nn.Conv2d(
297
+ in_channels,
298
+ block_out_channels[0],
299
+ kernel_size=conv_in_kernel,
300
+ padding=conv_in_padding,
301
+ )
302
+
303
+ # time
304
+ if time_embedding_type == "fourier":
305
+ time_embed_dim = time_embedding_dim or block_out_channels[0] * 2
306
+ if time_embed_dim % 2 != 0:
307
+ raise ValueError(
308
+ f"`time_embed_dim` should be divisible by 2, but is {time_embed_dim}."
309
+ )
310
+ self.time_proj = GaussianFourierProjection(
311
+ time_embed_dim // 2,
312
+ set_W_to_weight=False,
313
+ log=False,
314
+ flip_sin_to_cos=flip_sin_to_cos,
315
+ )
316
+ timestep_input_dim = time_embed_dim
317
+ elif time_embedding_type == "positional":
318
+ time_embed_dim = time_embedding_dim or block_out_channels[0] * 4
319
+
320
+ self.time_proj = Timesteps(
321
+ block_out_channels[0], flip_sin_to_cos, freq_shift
322
+ )
323
+ timestep_input_dim = block_out_channels[0]
324
+ else:
325
+ raise ValueError(
326
+ f"{time_embedding_type} does not exist. Please make sure to use one of `fourier` or `positional`."
327
+ )
328
+
329
+ self.time_embedding = TimestepEmbedding(
330
+ timestep_input_dim,
331
+ time_embed_dim,
332
+ act_fn=act_fn,
333
+ post_act_fn=timestep_post_act,
334
+ cond_proj_dim=time_cond_proj_dim,
335
+ )
336
+
337
+ if encoder_hid_dim_type is None and encoder_hid_dim is not None:
338
+ encoder_hid_dim_type = "text_proj"
339
+ self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type)
340
+ logger.info(
341
+ "encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined."
342
+ )
343
+
344
+ if encoder_hid_dim is None and encoder_hid_dim_type is not None:
345
+ raise ValueError(
346
+ f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}."
347
+ )
348
+
349
+ if encoder_hid_dim_type == "text_proj":
350
+ self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim)
351
+ elif encoder_hid_dim_type == "text_image_proj":
352
+ # image_embed_dim DOESN'T have to be `cross_attention_dim`. To not clutter the __init__ too much
353
+ # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
354
+ # case when `addition_embed_type == "text_image_proj"` (Kandinsky 2.1)
355
+ self.encoder_hid_proj = TextImageProjection(
356
+ text_embed_dim=encoder_hid_dim,
357
+ image_embed_dim=cross_attention_dim,
358
+ cross_attention_dim=cross_attention_dim,
359
+ )
360
+ elif encoder_hid_dim_type == "image_proj":
361
+ # Kandinsky 2.2
362
+ self.encoder_hid_proj = ImageProjection(
363
+ image_embed_dim=encoder_hid_dim,
364
+ cross_attention_dim=cross_attention_dim,
365
+ )
366
+ elif encoder_hid_dim_type is not None:
367
+ raise ValueError(
368
+ f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj' or 'text_image_proj'."
369
+ )
370
+ else:
371
+ self.encoder_hid_proj = None
372
+
373
+ # class embedding
374
+ if class_embed_type is None and num_class_embeds is not None:
375
+ self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
376
+ elif class_embed_type == "timestep":
377
+ self.class_embedding = TimestepEmbedding(
378
+ timestep_input_dim, time_embed_dim, act_fn=act_fn
379
+ )
380
+ elif class_embed_type == "identity":
381
+ self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
382
+ elif class_embed_type == "projection":
383
+ if projection_class_embeddings_input_dim is None:
384
+ raise ValueError(
385
+ "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set"
386
+ )
387
+ # The projection `class_embed_type` is the same as the timestep `class_embed_type` except
388
+ # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings
389
+ # 2. it projects from an arbitrary input dimension.
390
+ #
391
+ # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations.
392
+ # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings.
393
+ # As a result, `TimestepEmbedding` can be passed arbitrary vectors.
394
+ self.class_embedding = TimestepEmbedding(
395
+ projection_class_embeddings_input_dim, time_embed_dim
396
+ )
397
+ elif class_embed_type == "simple_projection":
398
+ if projection_class_embeddings_input_dim is None:
399
+ raise ValueError(
400
+ "`class_embed_type`: 'simple_projection' requires `projection_class_embeddings_input_dim` be set"
401
+ )
402
+ self.class_embedding = nn.Linear(
403
+ projection_class_embeddings_input_dim, time_embed_dim
404
+ )
405
+ else:
406
+ self.class_embedding = None
407
+
408
+ if addition_embed_type == "text":
409
+ if encoder_hid_dim is not None:
410
+ text_time_embedding_from_dim = encoder_hid_dim
411
+ else:
412
+ text_time_embedding_from_dim = cross_attention_dim
413
+
414
+ self.add_embedding = TextTimeEmbedding(
415
+ text_time_embedding_from_dim,
416
+ time_embed_dim,
417
+ num_heads=addition_embed_type_num_heads,
418
+ )
419
+ elif addition_embed_type == "text_image":
420
+ # text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the __init__ too much
421
+ # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
422
+ # case when `addition_embed_type == "text_image"` (Kandinsky 2.1).
423
+ self.add_embedding = TextImageTimeEmbedding(
424
+ text_embed_dim=cross_attention_dim,
425
+ image_embed_dim=cross_attention_dim,
426
+ time_embed_dim=time_embed_dim,
427
+ )
428
+ elif addition_embed_type == "text_time":
429
+ self.add_time_proj = Timesteps(
430
+ addition_time_embed_dim, flip_sin_to_cos, freq_shift
431
+ )
432
+ self.add_embedding = TimestepEmbedding(
433
+ projection_class_embeddings_input_dim, time_embed_dim
434
+ )
435
+ elif addition_embed_type == "image":
436
+ # Kandinsky 2.2
437
+ self.add_embedding = ImageTimeEmbedding(
438
+ image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim
439
+ )
440
+ elif addition_embed_type == "image_hint":
441
+ # Kandinsky 2.2 ControlNet
442
+ self.add_embedding = ImageHintTimeEmbedding(
443
+ image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim
444
+ )
445
+ elif addition_embed_type is not None:
446
+ raise ValueError(
447
+ f"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'."
448
+ )
449
+
450
+ if time_embedding_act_fn is None:
451
+ self.time_embed_act = None
452
+ else:
453
+ self.time_embed_act = get_activation(time_embedding_act_fn)
454
+
455
+ self.down_blocks = nn.ModuleList([])
456
+ self.up_blocks = nn.ModuleList([])
457
+
458
+ if isinstance(only_cross_attention, bool):
459
+ if mid_block_only_cross_attention is None:
460
+ mid_block_only_cross_attention = only_cross_attention
461
+
462
+ only_cross_attention = [only_cross_attention] * len(down_block_types)
463
+
464
+ if mid_block_only_cross_attention is None:
465
+ mid_block_only_cross_attention = False
466
+
467
+ if isinstance(num_attention_heads, int):
468
+ num_attention_heads = (num_attention_heads,) * len(down_block_types)
469
+
470
+ if isinstance(attention_head_dim, int):
471
+ attention_head_dim = (attention_head_dim,) * len(down_block_types)
472
+
473
+ if isinstance(cross_attention_dim, int):
474
+ cross_attention_dim = (cross_attention_dim,) * len(down_block_types)
475
+
476
+ if isinstance(layers_per_block, int):
477
+ layers_per_block = [layers_per_block] * len(down_block_types)
478
+
479
+ if isinstance(transformer_layers_per_block, int):
480
+ transformer_layers_per_block = [transformer_layers_per_block] * len(
481
+ down_block_types
482
+ )
483
+
484
+ if class_embeddings_concat:
485
+ # The time embeddings are concatenated with the class embeddings. The dimension of the
486
+ # time embeddings passed to the down, middle, and up blocks is twice the dimension of the
487
+ # regular time embeddings
488
+ blocks_time_embed_dim = time_embed_dim * 2
489
+ else:
490
+ blocks_time_embed_dim = time_embed_dim
491
+
492
+ # down
493
+ output_channel = block_out_channels[0]
494
+ for i, down_block_type in enumerate(down_block_types):
495
+ input_channel = output_channel
496
+ output_channel = block_out_channels[i]
497
+ is_final_block = i == len(block_out_channels) - 1
498
+
499
+ down_block = get_down_block(
500
+ down_block_type,
501
+ num_layers=layers_per_block[i],
502
+ transformer_layers_per_block=transformer_layers_per_block[i],
503
+ in_channels=input_channel,
504
+ out_channels=output_channel,
505
+ temb_channels=blocks_time_embed_dim,
506
+ add_downsample=not is_final_block,
507
+ resnet_eps=norm_eps,
508
+ resnet_act_fn=act_fn,
509
+ resnet_groups=norm_num_groups,
510
+ cross_attention_dim=cross_attention_dim[i],
511
+ num_attention_heads=num_attention_heads[i],
512
+ downsample_padding=downsample_padding,
513
+ dual_cross_attention=dual_cross_attention,
514
+ use_linear_projection=use_linear_projection,
515
+ only_cross_attention=only_cross_attention[i],
516
+ upcast_attention=upcast_attention,
517
+ resnet_time_scale_shift=resnet_time_scale_shift,
518
+ attention_type=attention_type,
519
+ resnet_skip_time_act=resnet_skip_time_act,
520
+ resnet_out_scale_factor=resnet_out_scale_factor,
521
+ cross_attention_norm=cross_attention_norm,
522
+ attention_head_dim=attention_head_dim[i]
523
+ if attention_head_dim[i] is not None
524
+ else output_channel,
525
+ dropout=dropout,
526
+ )
527
+ self.down_blocks.append(down_block)
528
+
529
+ # mid
530
+ if mid_block_type == "UNetMidBlock2DCrossAttn":
531
+ self.mid_block = UNetMidBlock2DCrossAttn(
532
+ transformer_layers_per_block=transformer_layers_per_block[-1],
533
+ in_channels=block_out_channels[-1],
534
+ temb_channels=blocks_time_embed_dim,
535
+ dropout=dropout,
536
+ resnet_eps=norm_eps,
537
+ resnet_act_fn=act_fn,
538
+ output_scale_factor=mid_block_scale_factor,
539
+ resnet_time_scale_shift=resnet_time_scale_shift,
540
+ cross_attention_dim=cross_attention_dim[-1],
541
+ num_attention_heads=num_attention_heads[-1],
542
+ resnet_groups=norm_num_groups,
543
+ dual_cross_attention=dual_cross_attention,
544
+ use_linear_projection=use_linear_projection,
545
+ upcast_attention=upcast_attention,
546
+ attention_type=attention_type,
547
+ )
548
+ elif mid_block_type == "UNetMidBlock2DSimpleCrossAttn":
549
+ raise NotImplementedError(f"Unsupported mid_block_type: {mid_block_type}")
550
+ elif mid_block_type == "UNetMidBlock2D":
551
+ self.mid_block = UNetMidBlock2D(
552
+ in_channels=block_out_channels[-1],
553
+ temb_channels=blocks_time_embed_dim,
554
+ dropout=dropout,
555
+ num_layers=0,
556
+ resnet_eps=norm_eps,
557
+ resnet_act_fn=act_fn,
558
+ output_scale_factor=mid_block_scale_factor,
559
+ resnet_groups=norm_num_groups,
560
+ resnet_time_scale_shift=resnet_time_scale_shift,
561
+ add_attention=False,
562
+ )
563
+ elif mid_block_type is None:
564
+ self.mid_block = None
565
+ else:
566
+ raise ValueError(f"unknown mid_block_type : {mid_block_type}")
567
+
568
+ # count how many layers upsample the images
569
+ self.num_upsamplers = 0
570
+
571
+ # up
572
+ reversed_block_out_channels = list(reversed(block_out_channels))
573
+ reversed_num_attention_heads = list(reversed(num_attention_heads))
574
+ reversed_layers_per_block = list(reversed(layers_per_block))
575
+ reversed_cross_attention_dim = list(reversed(cross_attention_dim))
576
+ reversed_transformer_layers_per_block = (
577
+ list(reversed(transformer_layers_per_block))
578
+ if reverse_transformer_layers_per_block is None
579
+ else reverse_transformer_layers_per_block
580
+ )
581
+ only_cross_attention = list(reversed(only_cross_attention))
582
+
583
+ output_channel = reversed_block_out_channels[0]
584
+ for i, up_block_type in enumerate(up_block_types):
585
+ is_final_block = i == len(block_out_channels) - 1
586
+
587
+ prev_output_channel = output_channel
588
+ output_channel = reversed_block_out_channels[i]
589
+ input_channel = reversed_block_out_channels[
590
+ min(i + 1, len(block_out_channels) - 1)
591
+ ]
592
+
593
+ # add upsample block for all BUT final layer
594
+ if not is_final_block:
595
+ add_upsample = True
596
+ self.num_upsamplers += 1
597
+ else:
598
+ add_upsample = False
599
+
600
+ up_block = get_up_block(
601
+ up_block_type,
602
+ num_layers=reversed_layers_per_block[i] + 1,
603
+ transformer_layers_per_block=reversed_transformer_layers_per_block[i],
604
+ in_channels=input_channel,
605
+ out_channels=output_channel,
606
+ prev_output_channel=prev_output_channel,
607
+ temb_channels=blocks_time_embed_dim,
608
+ add_upsample=add_upsample,
609
+ resnet_eps=norm_eps,
610
+ resnet_act_fn=act_fn,
611
+ resolution_idx=i,
612
+ resnet_groups=norm_num_groups,
613
+ cross_attention_dim=reversed_cross_attention_dim[i],
614
+ num_attention_heads=reversed_num_attention_heads[i],
615
+ dual_cross_attention=dual_cross_attention,
616
+ use_linear_projection=use_linear_projection,
617
+ only_cross_attention=only_cross_attention[i],
618
+ upcast_attention=upcast_attention,
619
+ resnet_time_scale_shift=resnet_time_scale_shift,
620
+ attention_type=attention_type,
621
+ resnet_skip_time_act=resnet_skip_time_act,
622
+ resnet_out_scale_factor=resnet_out_scale_factor,
623
+ cross_attention_norm=cross_attention_norm,
624
+ attention_head_dim=attention_head_dim[i]
625
+ if attention_head_dim[i] is not None
626
+ else output_channel,
627
+ dropout=dropout,
628
+ )
629
+ self.up_blocks.append(up_block)
630
+ prev_output_channel = output_channel
631
+
632
+ # out
633
+ if norm_num_groups is not None:
634
+ self.conv_norm_out = nn.GroupNorm(
635
+ num_channels=block_out_channels[0],
636
+ num_groups=norm_num_groups,
637
+ eps=norm_eps,
638
+ )
639
+
640
+ self.conv_act = get_activation(act_fn)
641
+
642
+ else:
643
+ self.conv_norm_out = None
644
+ self.conv_act = None
645
+ self.conv_norm_out = None
646
+
647
+ conv_out_padding = (conv_out_kernel - 1) // 2
648
+ # self.conv_out = nn.Conv2d(
649
+ # block_out_channels[0],
650
+ # out_channels,
651
+ # kernel_size=conv_out_kernel,
652
+ # padding=conv_out_padding,
653
+ # )
654
+
655
+ if attention_type in ["gated", "gated-text-image"]:
656
+ positive_len = 768
657
+ if isinstance(cross_attention_dim, int):
658
+ positive_len = cross_attention_dim
659
+ elif isinstance(cross_attention_dim, tuple) or isinstance(
660
+ cross_attention_dim, list
661
+ ):
662
+ positive_len = cross_attention_dim[0]
663
+
664
+ feature_type = "text-only" if attention_type == "gated" else "text-image"
665
+ self.position_net = PositionNet(
666
+ positive_len=positive_len,
667
+ out_dim=cross_attention_dim,
668
+ feature_type=feature_type,
669
+ )
670
+
671
+ @property
672
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
673
+ r"""
674
+ Returns:
675
+ `dict` of attention processors: A dictionary containing all attention processors used in the model,
676
+ indexed by their weight names.
677
+ """
678
+ # set recursively
679
+ processors = {}
680
+
681
+ def fn_recursive_add_processors(
682
+ name: str,
683
+ module: torch.nn.Module,
684
+ processors: Dict[str, AttentionProcessor],
685
+ ):
686
+ if hasattr(module, "get_processor"):
687
+ processors[f"{name}.processor"] = module.get_processor(
688
+ return_deprecated_lora=True
689
+ )
690
+
691
+ for sub_name, child in module.named_children():
692
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
693
+
694
+ return processors
695
+
696
+ for name, module in self.named_children():
697
+ fn_recursive_add_processors(name, module, processors)
698
+
699
+ return processors
700
+
701
+ def set_attn_processor(
702
+ self,
703
+ processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]],
704
+ _remove_lora=False,
705
+ ):
706
+ r"""
707
+ Sets the attention processor to use to compute attention.
708
+
709
+ Parameters:
710
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
711
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
712
+ for **all** `Attention` layers.
713
+
714
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
715
+ processor. This is strongly recommended when setting trainable attention processors.
716
+
717
+ """
718
+ count = len(self.attn_processors.keys())
719
+
720
+ if isinstance(processor, dict) and len(processor) != count:
721
+ raise ValueError(
722
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
723
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
724
+ )
725
+
726
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
727
+ if hasattr(module, "set_processor"):
728
+ if not isinstance(processor, dict):
729
+ module.set_processor(processor, _remove_lora=_remove_lora)
730
+ else:
731
+ module.set_processor(
732
+ processor.pop(f"{name}.processor"), _remove_lora=_remove_lora
733
+ )
734
+
735
+ for sub_name, child in module.named_children():
736
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
737
+
738
+ for name, module in self.named_children():
739
+ fn_recursive_attn_processor(name, module, processor)
740
+
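As a usage sketch (not part of this diff): `attn_processors` and `set_attn_processor` above are the standard diffusers hooks for inspecting and swapping attention processors. `unet` is an assumed instance of this class, and `AttnProcessor2_0` is the stock scaled-dot-product processor shipped with diffusers.

from diffusers.models.attention_processor import AttnProcessor2_0

# Enumerate every processor, keyed by the weight path of its Attention module.
for key, proc in unet.attn_processors.items():
    print(key, type(proc).__name__)

# Swap all processors at once with a single instance...
unet.set_attn_processor(AttnProcessor2_0())

# ...or per layer, by passing a dict keyed by the same paths.
unet.set_attn_processor({k: AttnProcessor2_0() for k in unet.attn_processors})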
741
+ def set_default_attn_processor(self):
742
+ """
743
+ Disables custom attention processors and sets the default attention implementation.
744
+ """
745
+ if all(
746
+ proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS
747
+ for proc in self.attn_processors.values()
748
+ ):
749
+ processor = AttnAddedKVProcessor()
750
+ elif all(
751
+ proc.__class__ in CROSS_ATTENTION_PROCESSORS
752
+ for proc in self.attn_processors.values()
753
+ ):
754
+ processor = AttnProcessor()
755
+ else:
756
+ raise ValueError(
757
+ f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
758
+ )
759
+
760
+ self.set_attn_processor(processor, _remove_lora=True)
761
+
762
+ def set_attention_slice(self, slice_size):
763
+ r"""
764
+ Enable sliced attention computation.
765
+
766
+ When this option is enabled, the attention module splits the input tensor in slices to compute attention in
767
+ several steps. This is useful for saving some memory in exchange for a small decrease in speed.
768
+
769
+ Args:
770
+ slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
771
+ When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If
772
+ `"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is
773
+ provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
774
+ must be a multiple of `slice_size`.
775
+ """
776
+ sliceable_head_dims = []
777
+
778
+ def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module):
779
+ if hasattr(module, "set_attention_slice"):
780
+ sliceable_head_dims.append(module.sliceable_head_dim)
781
+
782
+ for child in module.children():
783
+ fn_recursive_retrieve_sliceable_dims(child)
784
+
785
+ # retrieve number of attention layers
786
+ for module in self.children():
787
+ fn_recursive_retrieve_sliceable_dims(module)
788
+
789
+ num_sliceable_layers = len(sliceable_head_dims)
790
+
791
+ if slice_size == "auto":
792
+ # half the attention head size is usually a good trade-off between
793
+ # speed and memory
794
+ slice_size = [dim // 2 for dim in sliceable_head_dims]
795
+ elif slice_size == "max":
796
+ # make smallest slice possible
797
+ slice_size = num_sliceable_layers * [1]
798
+
799
+ slice_size = (
800
+ num_sliceable_layers * [slice_size]
801
+ if not isinstance(slice_size, list)
802
+ else slice_size
803
+ )
804
+
805
+ if len(slice_size) != len(sliceable_head_dims):
806
+ raise ValueError(
807
+ f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
808
+ f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
809
+ )
810
+
811
+ for i in range(len(slice_size)):
812
+ size = slice_size[i]
813
+ dim = sliceable_head_dims[i]
814
+ if size is not None and size > dim:
815
+ raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
816
+
817
+ # Recursively walk through all the children.
818
+ # Any children which exposes the set_attention_slice method
819
+ # gets the message
820
+ def fn_recursive_set_attention_slice(
821
+ module: torch.nn.Module, slice_size: List[int]
822
+ ):
823
+ if hasattr(module, "set_attention_slice"):
824
+ module.set_attention_slice(slice_size.pop())
825
+
826
+ for child in module.children():
827
+ fn_recursive_set_attention_slice(child, slice_size)
828
+
829
+ reversed_slice_size = list(reversed(slice_size))
830
+ for module in self.children():
831
+ fn_recursive_set_attention_slice(module, reversed_slice_size)
832
+
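A short sketch of how the slicing hook above is typically driven (`unet` is again an assumed instance; the accepted values follow the docstring):

unet.set_attention_slice("auto")  # half of each sliceable head dim, attention in two steps
unet.set_attention_slice("max")   # one slice at a time: lowest peak memory, slowest
unet.set_attention_slice(2)       # the same fixed slice size for every sliceable layer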
833
+ def _set_gradient_checkpointing(self, module, value=False):
834
+ if hasattr(module, "gradient_checkpointing"):
835
+ module.gradient_checkpointing = value
836
+
837
+ def enable_freeu(self, s1, s2, b1, b2):
838
+ r"""Enables the FreeU mechanism from https://arxiv.org/abs/2309.11497.
839
+
840
+ The suffixes after the scaling factors represent the stage blocks where they are being applied.
841
+
842
+ Please refer to the [official repository](https://github.com/ChenyangSi/FreeU) for combinations of values that
843
+ are known to work well for different pipelines such as Stable Diffusion v1, v2, and Stable Diffusion XL.
844
+
845
+ Args:
846
+ s1 (`float`):
847
+ Scaling factor for stage 1 to attenuate the contributions of the skip features. This is done to
848
+ mitigate the "oversmoothing effect" in the enhanced denoising process.
849
+ s2 (`float`):
850
+ Scaling factor for stage 2 to attenuate the contributions of the skip features. This is done to
851
+ mitigate the "oversmoothing effect" in the enhanced denoising process.
852
+ b1 (`float`): Scaling factor for stage 1 to amplify the contributions of backbone features.
853
+ b2 (`float`): Scaling factor for stage 2 to amplify the contributions of backbone features.
854
+ """
855
+ for i, upsample_block in enumerate(self.up_blocks):
856
+ setattr(upsample_block, "s1", s1)
857
+ setattr(upsample_block, "s2", s2)
858
+ setattr(upsample_block, "b1", b1)
859
+ setattr(upsample_block, "b2", b2)
860
+
861
+ def disable_freeu(self):
862
+ """Disables the FreeU mechanism."""
863
+ freeu_keys = {"s1", "s2", "b1", "b2"}
864
+ for i, upsample_block in enumerate(self.up_blocks):
865
+ for k in freeu_keys:
866
+ if (
867
+ hasattr(upsample_block, k)
868
+ or getattr(upsample_block, k, None) is not None
869
+ ):
870
+ setattr(upsample_block, k, None)
871
+
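The FreeU toggles above only set four scalar attributes on each up block, so switching them is cheap. A hedged example using the scales the FreeU authors suggest for SD 1.x checkpoints (values come from the FreeU repository, not this diff; `unet` is an assumed instance):

# b1/b2 amplify backbone features, s1/s2 attenuate the skip features.
unet.enable_freeu(s1=0.9, s2=0.2, b1=1.5, b2=1.6)
# ... run the denoising loop ...
unet.disable_freeu()  # clears the four attributes so the up blocks behave as before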
872
+ def forward(
873
+ self,
874
+ sample: torch.FloatTensor,
875
+ timestep: Union[torch.Tensor, float, int],
876
+ encoder_hidden_states: torch.Tensor,
877
+ class_labels: Optional[torch.Tensor] = None,
878
+ timestep_cond: Optional[torch.Tensor] = None,
879
+ attention_mask: Optional[torch.Tensor] = None,
880
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
881
+ added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
882
+ down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
883
+ mid_block_additional_residual: Optional[torch.Tensor] = None,
884
+ down_intrablock_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
885
+ encoder_attention_mask: Optional[torch.Tensor] = None,
886
+ return_dict: bool = True,
887
+ ) -> Union[UNet2DConditionOutput, Tuple]:
888
+ r"""
889
+ The [`UNet2DConditionModel`] forward method.
890
+
891
+ Args:
892
+ sample (`torch.FloatTensor`):
893
+ The noisy input tensor with the following shape `(batch, channel, height, width)`.
894
+ timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input.
895
+ encoder_hidden_states (`torch.FloatTensor`):
896
+ The encoder hidden states with shape `(batch, sequence_length, feature_dim)`.
897
+ class_labels (`torch.Tensor`, *optional*, defaults to `None`):
898
+ Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings.
899
+ timestep_cond: (`torch.Tensor`, *optional*, defaults to `None`):
900
+ Conditional embeddings for timestep. If provided, the embeddings will be summed with the samples passed
901
+ through the `self.time_embedding` layer to obtain the timestep embeddings.
902
+ attention_mask (`torch.Tensor`, *optional*, defaults to `None`):
903
+ An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
904
+ is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
905
+ negative values to the attention scores corresponding to "discard" tokens.
906
+ cross_attention_kwargs (`dict`, *optional*):
907
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
908
+ `self.processor` in
909
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
910
+ added_cond_kwargs: (`dict`, *optional*):
911
+ A kwargs dictionary containing additional embeddings that if specified are added to the embeddings that
912
+ are passed along to the UNet blocks.
913
+ down_block_additional_residuals: (`tuple` of `torch.Tensor`, *optional*):
914
+ A tuple of tensors that if specified are added to the residuals of down unet blocks.
915
+ mid_block_additional_residual: (`torch.Tensor`, *optional*):
916
+ A tensor that if specified is added to the residual of the middle unet block.
917
+ encoder_attention_mask (`torch.Tensor`):
918
+ A cross-attention mask of shape `(batch, sequence_length)` is applied to `encoder_hidden_states`. If
919
+ `True` the mask is kept, otherwise if `False` it is discarded. Mask will be converted into a bias,
920
+ which adds large negative values to the attention scores corresponding to "discard" tokens.
921
+ return_dict (`bool`, *optional*, defaults to `True`):
922
+ Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
923
+ tuple.
924
+ cross_attention_kwargs (`dict`, *optional*):
925
+ A kwargs dictionary that if specified is passed along to the [`AttnProcessor`].
926
+ added_cond_kwargs: (`dict`, *optional*):
927
+ A kwargs dictionary containing additional embeddings that if specified are added to the embeddings that
928
+ are passed along to the UNet blocks.
929
+ down_block_additional_residuals (`tuple` of `torch.Tensor`, *optional*):
930
+ additional residuals to be added to UNet long skip connections from down blocks to up blocks for
931
+ example from ControlNet side model(s)
932
+ mid_block_additional_residual (`torch.Tensor`, *optional*):
933
+ additional residual to be added to UNet mid block output, for example from ControlNet side model
934
+ down_intrablock_additional_residuals (`tuple` of `torch.Tensor`, *optional*):
935
+ additional residuals to be added within UNet down blocks, for example from T2I-Adapter side model(s)
936
+
937
+ Returns:
938
+ [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
939
+ If `return_dict` is True, an [`~models.unet_2d_condition.UNet2DConditionOutput`] is returned, otherwise
940
+ a `tuple` is returned where the first element is the sample tensor.
941
+ """
942
+ # By default samples have to be AT least a multiple of the overall upsampling factor.
943
+ # The overall upsampling factor is equal to 2 ** (# num of upsampling layers).
944
+ # However, the upsampling interpolation output size can be forced to fit any upsampling size
945
+ # on the fly if necessary.
946
+ default_overall_up_factor = 2**self.num_upsamplers
947
+
948
+ # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
949
+ forward_upsample_size = False
950
+ upsample_size = None
951
+
952
+ for dim in sample.shape[-2:]:
953
+ if dim % default_overall_up_factor != 0:
954
+ # Forward upsample size to force interpolation output size.
955
+ forward_upsample_size = True
956
+ break
957
+
958
+ # ensure attention_mask is a bias, and give it a singleton query_tokens dimension
959
+ # expects mask of shape:
960
+ # [batch, key_tokens]
961
+ # adds singleton query_tokens dimension:
962
+ # [batch, 1, key_tokens]
963
+ # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
964
+ # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
965
+ # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
966
+ if attention_mask is not None:
967
+ # assume that mask is expressed as:
968
+ # (1 = keep, 0 = discard)
969
+ # convert mask into a bias that can be added to attention scores:
970
+ # (keep = +0, discard = -10000.0)
971
+ attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
972
+ attention_mask = attention_mask.unsqueeze(1)
973
+
974
+ # convert encoder_attention_mask to a bias the same way we do for attention_mask
975
+ if encoder_attention_mask is not None:
976
+ encoder_attention_mask = (
977
+ 1 - encoder_attention_mask.to(sample.dtype)
978
+ ) * -10000.0
979
+ encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
980
+
981
+ # 0. center input if necessary
982
+ if self.config.center_input_sample:
983
+ sample = 2 * sample - 1.0
984
+
985
+ # 1. time
986
+ timesteps = timestep
987
+ if not torch.is_tensor(timesteps):
988
+ # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
989
+ # This would be a good case for the `match` statement (Python 3.10+)
990
+ is_mps = sample.device.type == "mps"
991
+ if isinstance(timestep, float):
992
+ dtype = torch.float32 if is_mps else torch.float64
993
+ else:
994
+ dtype = torch.int32 if is_mps else torch.int64
995
+ timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
996
+ elif len(timesteps.shape) == 0:
997
+ timesteps = timesteps[None].to(sample.device)
998
+
999
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
1000
+ timesteps = timesteps.expand(sample.shape[0])
1001
+
1002
+ t_emb = self.time_proj(timesteps)
1003
+
1004
+ # `Timesteps` does not contain any weights and will always return f32 tensors
1005
+ # but time_embedding might actually be running in fp16. so we need to cast here.
1006
+ # there might be better ways to encapsulate this.
1007
+ t_emb = t_emb.to(dtype=sample.dtype)
1008
+
1009
+ emb = self.time_embedding(t_emb, timestep_cond)
1010
+ aug_emb = None
1011
+
1012
+ if self.class_embedding is not None:
1013
+ if class_labels is None:
1014
+ raise ValueError(
1015
+ "class_labels should be provided when num_class_embeds > 0"
1016
+ )
1017
+
1018
+ if self.config.class_embed_type == "timestep":
1019
+ class_labels = self.time_proj(class_labels)
1020
+
1021
+ # `Timesteps` does not contain any weights and will always return f32 tensors
1022
+ # there might be better ways to encapsulate this.
1023
+ class_labels = class_labels.to(dtype=sample.dtype)
1024
+
1025
+ class_emb = self.class_embedding(class_labels).to(dtype=sample.dtype)
1026
+
1027
+ if self.config.class_embeddings_concat:
1028
+ emb = torch.cat([emb, class_emb], dim=-1)
1029
+ else:
1030
+ emb = emb + class_emb
1031
+
1032
+ if self.config.addition_embed_type == "text":
1033
+ aug_emb = self.add_embedding(encoder_hidden_states)
1034
+ elif self.config.addition_embed_type == "text_image":
1035
+ # Kandinsky 2.1 - style
1036
+ if "image_embeds" not in added_cond_kwargs:
1037
+ raise ValueError(
1038
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
1039
+ )
1040
+
1041
+ image_embs = added_cond_kwargs.get("image_embeds")
1042
+ text_embs = added_cond_kwargs.get("text_embeds", encoder_hidden_states)
1043
+ aug_emb = self.add_embedding(text_embs, image_embs)
1044
+ elif self.config.addition_embed_type == "text_time":
1045
+ # SDXL - style
1046
+ if "text_embeds" not in added_cond_kwargs:
1047
+ raise ValueError(
1048
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
1049
+ )
1050
+ text_embeds = added_cond_kwargs.get("text_embeds")
1051
+ if "time_ids" not in added_cond_kwargs:
1052
+ raise ValueError(
1053
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
1054
+ )
1055
+ time_ids = added_cond_kwargs.get("time_ids")
1056
+ time_embeds = self.add_time_proj(time_ids.flatten())
1057
+ time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))
1058
+ add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
1059
+ add_embeds = add_embeds.to(emb.dtype)
1060
+ aug_emb = self.add_embedding(add_embeds)
1061
+ elif self.config.addition_embed_type == "image":
1062
+ # Kandinsky 2.2 - style
1063
+ if "image_embeds" not in added_cond_kwargs:
1064
+ raise ValueError(
1065
+ f"{self.__class__} has the config param `addition_embed_type` set to 'image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
1066
+ )
1067
+ image_embs = added_cond_kwargs.get("image_embeds")
1068
+ aug_emb = self.add_embedding(image_embs)
1069
+ elif self.config.addition_embed_type == "image_hint":
1070
+ # Kandinsky 2.2 - style
1071
+ if (
1072
+ "image_embeds" not in added_cond_kwargs
1073
+ or "hint" not in added_cond_kwargs
1074
+ ):
1075
+ raise ValueError(
1076
+ f"{self.__class__} has the config param `addition_embed_type` set to 'image_hint' which requires the keyword arguments `image_embeds` and `hint` to be passed in `added_cond_kwargs`"
1077
+ )
1078
+ image_embs = added_cond_kwargs.get("image_embeds")
1079
+ hint = added_cond_kwargs.get("hint")
1080
+ aug_emb, hint = self.add_embedding(image_embs, hint)
1081
+ sample = torch.cat([sample, hint], dim=1)
1082
+
1083
+ emb = emb + aug_emb if aug_emb is not None else emb
1084
+
1085
+ if self.time_embed_act is not None:
1086
+ emb = self.time_embed_act(emb)
1087
+
1088
+ if (
1089
+ self.encoder_hid_proj is not None
1090
+ and self.config.encoder_hid_dim_type == "text_proj"
1091
+ ):
1092
+ encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states)
1093
+ elif (
1094
+ self.encoder_hid_proj is not None
1095
+ and self.config.encoder_hid_dim_type == "text_image_proj"
1096
+ ):
1097
+ # Kandinsky 2.1 - style
1098
+ if "image_embeds" not in added_cond_kwargs:
1099
+ raise ValueError(
1100
+ f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'text_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
1101
+ )
1102
+
1103
+ image_embeds = added_cond_kwargs.get("image_embeds")
1104
+ encoder_hidden_states = self.encoder_hid_proj(
1105
+ encoder_hidden_states, image_embeds
1106
+ )
1107
+ elif (
1108
+ self.encoder_hid_proj is not None
1109
+ and self.config.encoder_hid_dim_type == "image_proj"
1110
+ ):
1111
+ # Kandinsky 2.2 - style
1112
+ if "image_embeds" not in added_cond_kwargs:
1113
+ raise ValueError(
1114
+ f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
1115
+ )
1116
+ image_embeds = added_cond_kwargs.get("image_embeds")
1117
+ encoder_hidden_states = self.encoder_hid_proj(image_embeds)
1118
+ elif (
1119
+ self.encoder_hid_proj is not None
1120
+ and self.config.encoder_hid_dim_type == "ip_image_proj"
1121
+ ):
1122
+ if "image_embeds" not in added_cond_kwargs:
1123
+ raise ValueError(
1124
+ f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'ip_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
1125
+ )
1126
+ image_embeds = added_cond_kwargs.get("image_embeds")
1127
+ image_embeds = self.encoder_hid_proj(image_embeds).to(
1128
+ encoder_hidden_states.dtype
1129
+ )
1130
+ encoder_hidden_states = torch.cat(
1131
+ [encoder_hidden_states, image_embeds], dim=1
1132
+ )
1133
+
1134
+ # 2. pre-process
1135
+ sample = self.conv_in(sample)
1136
+
1137
+ # 2.5 GLIGEN position net
1138
+ if (
1139
+ cross_attention_kwargs is not None
1140
+ and cross_attention_kwargs.get("gligen", None) is not None
1141
+ ):
1142
+ cross_attention_kwargs = cross_attention_kwargs.copy()
1143
+ gligen_args = cross_attention_kwargs.pop("gligen")
1144
+ cross_attention_kwargs["gligen"] = {
1145
+ "objs": self.position_net(**gligen_args)
1146
+ }
1147
+
1148
+ # 3. down
1149
+ lora_scale = (
1150
+ cross_attention_kwargs.get("scale", 1.0)
1151
+ if cross_attention_kwargs is not None
1152
+ else 1.0
1153
+ )
1154
+ if USE_PEFT_BACKEND:
1155
+ # weight the lora layers by setting `lora_scale` for each PEFT layer
1156
+ scale_lora_layers(self, lora_scale)
1157
+
1158
+ is_controlnet = (
1159
+ mid_block_additional_residual is not None
1160
+ and down_block_additional_residuals is not None
1161
+ )
1162
+ # using new arg down_intrablock_additional_residuals for T2I-Adapters, to distinguish from controlnets
1163
+ is_adapter = down_intrablock_additional_residuals is not None
1164
+ # maintain backward compatibility for legacy usage, where
1165
+ # T2I-Adapter and ControlNet both use down_block_additional_residuals arg
1166
+ # but can only use one or the other
1167
+ if (
1168
+ not is_adapter
1169
+ and mid_block_additional_residual is None
1170
+ and down_block_additional_residuals is not None
1171
+ ):
1172
+ deprecate(
1173
+ "T2I should not use down_block_additional_residuals",
1174
+ "1.3.0",
1175
+ "Passing intrablock residual connections with `down_block_additional_residuals` is deprecated \
1176
+ and will be removed in diffusers 1.3.0. `down_block_additional_residuals` should only be used \
1177
+ for ControlNet. Please make sure to use `down_intrablock_additional_residuals` instead. ",
1178
+ standard_warn=False,
1179
+ )
1180
+ down_intrablock_additional_residuals = down_block_additional_residuals
1181
+ is_adapter = True
1182
+
1183
+ down_block_res_samples = (sample,)
1184
+ tot_referece_features = ()
1185
+ for downsample_block in self.down_blocks:
1186
+ if (
1187
+ hasattr(downsample_block, "has_cross_attention")
1188
+ and downsample_block.has_cross_attention
1189
+ ):
1190
+ # For t2i-adapter CrossAttnDownBlock2D
1191
+ additional_residuals = {}
1192
+ if is_adapter and len(down_intrablock_additional_residuals) > 0:
1193
+ additional_residuals[
1194
+ "additional_residuals"
1195
+ ] = down_intrablock_additional_residuals.pop(0)
1196
+
1197
+ sample, res_samples = downsample_block(
1198
+ hidden_states=sample,
1199
+ temb=emb,
1200
+ encoder_hidden_states=encoder_hidden_states,
1201
+ attention_mask=attention_mask,
1202
+ cross_attention_kwargs=cross_attention_kwargs,
1203
+ encoder_attention_mask=encoder_attention_mask,
1204
+ **additional_residuals,
1205
+ )
1206
+ else:
1207
+ sample, res_samples = downsample_block(
1208
+ hidden_states=sample, temb=emb, scale=lora_scale
1209
+ )
1210
+ if is_adapter and len(down_intrablock_additional_residuals) > 0:
1211
+ sample += down_intrablock_additional_residuals.pop(0)
1212
+
1213
+ down_block_res_samples += res_samples
1214
+
1215
+ if is_controlnet:
1216
+ new_down_block_res_samples = ()
1217
+
1218
+ for down_block_res_sample, down_block_additional_residual in zip(
1219
+ down_block_res_samples, down_block_additional_residuals
1220
+ ):
1221
+ down_block_res_sample = (
1222
+ down_block_res_sample + down_block_additional_residual
1223
+ )
1224
+ new_down_block_res_samples = new_down_block_res_samples + (
1225
+ down_block_res_sample,
1226
+ )
1227
+
1228
+ down_block_res_samples = new_down_block_res_samples
1229
+
1230
+ # 4. mid
1231
+ if self.mid_block is not None:
1232
+ if (
1233
+ hasattr(self.mid_block, "has_cross_attention")
1234
+ and self.mid_block.has_cross_attention
1235
+ ):
1236
+ sample = self.mid_block(
1237
+ sample,
1238
+ emb,
1239
+ encoder_hidden_states=encoder_hidden_states,
1240
+ attention_mask=attention_mask,
1241
+ cross_attention_kwargs=cross_attention_kwargs,
1242
+ encoder_attention_mask=encoder_attention_mask,
1243
+ )
1244
+ else:
1245
+ sample = self.mid_block(sample, emb)
1246
+
1247
+ # To support T2I-Adapter-XL
1248
+ if (
1249
+ is_adapter
1250
+ and len(down_intrablock_additional_residuals) > 0
1251
+ and sample.shape == down_intrablock_additional_residuals[0].shape
1252
+ ):
1253
+ sample += down_intrablock_additional_residuals.pop(0)
1254
+
1255
+ if is_controlnet:
1256
+ sample = sample + mid_block_additional_residual
1257
+
1258
+ # 5. up
1259
+ for i, upsample_block in enumerate(self.up_blocks):
1260
+ is_final_block = i == len(self.up_blocks) - 1
1261
+
1262
+ res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
1263
+ down_block_res_samples = down_block_res_samples[
1264
+ : -len(upsample_block.resnets)
1265
+ ]
1266
+
1267
+ # if we have not reached the final block and need to forward the
1268
+ # upsample size, we do it here
1269
+ if not is_final_block and forward_upsample_size:
1270
+ upsample_size = down_block_res_samples[-1].shape[2:]
1271
+
1272
+ if (
1273
+ hasattr(upsample_block, "has_cross_attention")
1274
+ and upsample_block.has_cross_attention
1275
+ ):
1276
+ sample = upsample_block(
1277
+ hidden_states=sample,
1278
+ temb=emb,
1279
+ res_hidden_states_tuple=res_samples,
1280
+ encoder_hidden_states=encoder_hidden_states,
1281
+ cross_attention_kwargs=cross_attention_kwargs,
1282
+ upsample_size=upsample_size,
1283
+ attention_mask=attention_mask,
1284
+ encoder_attention_mask=encoder_attention_mask,
1285
+ )
1286
+ else:
1287
+ sample = upsample_block(
1288
+ hidden_states=sample,
1289
+ temb=emb,
1290
+ res_hidden_states_tuple=res_samples,
1291
+ upsample_size=upsample_size,
1292
+ scale=lora_scale,
1293
+ )
1294
+
1295
+ # 6. post-process
1296
+ # if self.conv_norm_out:
1297
+ # sample = self.conv_norm_out(sample)
1298
+ # sample = self.conv_act(sample)
1299
+ # sample = self.conv_out(sample)
1300
+
1301
+ if USE_PEFT_BACKEND:
1302
+ # remove `lora_scale` from each PEFT layer
1303
+ unscale_lora_layers(self, lora_scale)
1304
+
1305
+ if not return_dict:
1306
+ return (sample,)
1307
+
1308
+ return UNet2DConditionOutput(sample=sample)
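Note that, because the `# 6. post-process` step is commented out above and `conv_out` is never built, this variant returns the last up-block feature map at `block_out_channels[0]` channels rather than projecting back to `out_channels`. A minimal smoke-test sketch, with all shapes assumed for illustration and `ref_unet` an assumed instance of this class:

import torch

latents = torch.randn(2, 4, 64, 64)    # (batch, in_channels, H/8, W/8)
timesteps = torch.tensor([10, 10])
context = torch.randn(2, 77, 768)      # must match the configured cross_attention_dim

features = ref_unet(latents, timesteps, encoder_hidden_states=context).sample
# features: (2, block_out_channels[0], 64, 64), e.g. (2, 320, 64, 64) with SD-1.x
# defaults, i.e. reference features rather than a predicted 4-channel latent.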
ref_encoder/unet_3d.py ADDED
@@ -0,0 +1,702 @@
1
+ # Adapted from https://github.com/guoyww/AnimateDiff/blob/main/animatediff/models/unet_blocks.py
2
+
3
+ from collections import OrderedDict
4
+ from dataclasses import dataclass
5
+ from os import PathLike
6
+ from pathlib import Path
7
+ from typing import Dict, List, Optional, Tuple, Union
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.utils.checkpoint
12
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
13
+ from diffusers.models.attention_processor import AttentionProcessor
14
+ from diffusers.models.embeddings import TimestepEmbedding, Timesteps
15
+ try:
16
+ from diffusers.modeling_utils import ModelMixin
17
+ except ImportError:
18
+ from diffusers.models.modeling_utils import ModelMixin
19
+ from diffusers.utils import SAFETENSORS_WEIGHTS_NAME, WEIGHTS_NAME, BaseOutput, logging
20
+ from safetensors.torch import load_file
21
+
22
+ from .resnet import InflatedConv3d, InflatedGroupNorm
23
+ from .unet_3d_blocks import UNetMidBlock3DCrossAttn, get_down_block, get_up_block
24
+
25
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
26
+ from einops import rearrange
27
+
28
+
29
+ @dataclass
30
+ class UNet3DConditionOutput(BaseOutput):
31
+ sample: torch.FloatTensor
32
+
33
+
34
+ class UNet3DConditionModel(ModelMixin, ConfigMixin):
35
+ _supports_gradient_checkpointing = True
36
+
37
+ @register_to_config
38
+ def __init__(
39
+ self,
40
+ sample_size: Optional[int] = None,
41
+ in_channels: int = 4,
42
+ out_channels: int = 4,
43
+ center_input_sample: bool = False,
44
+ flip_sin_to_cos: bool = True,
45
+ freq_shift: int = 0,
46
+ down_block_types: Tuple[str] = (
47
+ "CrossAttnDownBlock3D",
48
+ "CrossAttnDownBlock3D",
49
+ "CrossAttnDownBlock3D",
50
+ "DownBlock3D",
51
+ ),
52
+ mid_block_type: str = "UNetMidBlock3DCrossAttn",
53
+ up_block_types: Tuple[str] = (
54
+ "UpBlock3D",
55
+ "CrossAttnUpBlock3D",
56
+ "CrossAttnUpBlock3D",
57
+ "CrossAttnUpBlock3D",
58
+ ),
59
+ only_cross_attention: Union[bool, Tuple[bool]] = False,
60
+ block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
61
+ layers_per_block: int = 2,
62
+ downsample_padding: int = 1,
63
+ mid_block_scale_factor: float = 1,
64
+ act_fn: str = "silu",
65
+ norm_num_groups: int = 32,
66
+ norm_eps: float = 1e-5,
67
+ cross_attention_dim: int = 1280,
68
+ attention_head_dim: Union[int, Tuple[int]] = 8,
69
+ dual_cross_attention: bool = False,
70
+ use_linear_projection: bool = False,
71
+ class_embed_type: Optional[str] = None,
72
+ num_class_embeds: Optional[int] = None,
73
+ upcast_attention: bool = False,
74
+ resnet_time_scale_shift: str = "default",
75
+ use_inflated_groupnorm=False,
76
+ # Additional
77
+ use_motion_module=False, ######
78
+ motion_module_resolutions=(1, 2, 4, 8), ####
79
+ motion_module_mid_block=False, #####
80
+ motion_module_decoder_only=False, #####
81
+ motion_module_type=None, #####
82
+ motion_module_kwargs={}, #####
83
+ unet_use_cross_frame_attention=None, #####
84
+ unet_use_temporal_attention=None, #####
85
+ mode=None, #####
86
+ task_type="action", #####
87
+ ):
88
+ super().__init__()
89
+
90
+ self.sample_size = sample_size
91
+ time_embed_dim = block_out_channels[0] * 4
92
+
93
+ # input
94
+ self.conv_in = InflatedConv3d(
95
+ in_channels, block_out_channels[0], kernel_size=3, padding=(1, 1)
96
+ )
97
+
98
+ # time
99
+ self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
100
+ timestep_input_dim = block_out_channels[0]
101
+
102
+ self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
103
+
104
+ # class embedding
105
+ if class_embed_type is None and num_class_embeds is not None:
106
+ self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
107
+ elif class_embed_type == "timestep":
108
+ self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
109
+ elif class_embed_type == "identity":
110
+ self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
111
+ else:
112
+ self.class_embedding = None
113
+
114
+ self.down_blocks = nn.ModuleList([])
115
+ self.mid_block = None
116
+ self.up_blocks = nn.ModuleList([])
117
+
118
+ if isinstance(only_cross_attention, bool):
119
+ only_cross_attention = [only_cross_attention] * len(down_block_types)
120
+
121
+ if isinstance(attention_head_dim, int):
122
+ attention_head_dim = (attention_head_dim,) * len(down_block_types)
123
+
124
+ # down
125
+ output_channel = block_out_channels[0]
126
+ for i, down_block_type in enumerate(down_block_types):
127
+ if task_type == "action":
128
+ name_index, mid_name = None, None
129
+ else:
130
+ name_index, mid_name = i, "MidBlock"
131
+ res = 2**i
132
+ input_channel = output_channel
133
+ output_channel = block_out_channels[i]
134
+ is_final_block = i == len(block_out_channels) - 1
135
+
136
+ down_block = get_down_block(
137
+ down_block_type,
138
+ num_layers=layers_per_block,
139
+ in_channels=input_channel,
140
+ out_channels=output_channel,
141
+ temb_channels=time_embed_dim,
142
+ add_downsample=not is_final_block,
143
+ resnet_eps=norm_eps,
144
+ resnet_act_fn=act_fn,
145
+ resnet_groups=norm_num_groups,
146
+ cross_attention_dim=cross_attention_dim,
147
+ attn_num_head_channels=attention_head_dim[i],
148
+ downsample_padding=downsample_padding,
149
+ dual_cross_attention=dual_cross_attention,
150
+ use_linear_projection=use_linear_projection,
151
+ only_cross_attention=only_cross_attention[i],
152
+ upcast_attention=upcast_attention,
153
+ resnet_time_scale_shift=resnet_time_scale_shift,
154
+ unet_use_cross_frame_attention=unet_use_cross_frame_attention,
155
+ unet_use_temporal_attention=unet_use_temporal_attention,
156
+ use_inflated_groupnorm=use_inflated_groupnorm,
157
+ use_motion_module=use_motion_module
158
+ and (res in motion_module_resolutions)
159
+ and (not motion_module_decoder_only),
160
+ motion_module_type=motion_module_type,
161
+ motion_module_kwargs=motion_module_kwargs,
162
+ name_index=name_index, #####
163
+ )
164
+ self.down_blocks.append(down_block)
165
+
166
+ # mid
167
+
168
+ if mid_block_type == "UNetMidBlock3DCrossAttn":
169
+ self.mid_block = UNetMidBlock3DCrossAttn(
170
+ in_channels=block_out_channels[-1],
171
+ temb_channels=time_embed_dim,
172
+ resnet_eps=norm_eps,
173
+ resnet_act_fn=act_fn,
174
+ output_scale_factor=mid_block_scale_factor,
175
+ resnet_time_scale_shift=resnet_time_scale_shift,
176
+ cross_attention_dim=cross_attention_dim,
177
+ attn_num_head_channels=attention_head_dim[-1],
178
+ resnet_groups=norm_num_groups,
179
+ dual_cross_attention=dual_cross_attention,
180
+ use_linear_projection=use_linear_projection,
181
+ upcast_attention=upcast_attention,
182
+ unet_use_cross_frame_attention=unet_use_cross_frame_attention,
183
+ unet_use_temporal_attention=unet_use_temporal_attention,
184
+ use_inflated_groupnorm=use_inflated_groupnorm,
185
+ use_motion_module=use_motion_module and motion_module_mid_block,
186
+ motion_module_type=motion_module_type,
187
+ motion_module_kwargs=motion_module_kwargs,
188
+ name=mid_name,
189
+ )
190
+ else:
191
+ raise ValueError(f"unknown mid_block_type : {mid_block_type}")
192
+
193
+ # count how many layers upsample the videos
194
+ self.num_upsamplers = 0
195
+
196
+ # up
197
+ reversed_block_out_channels = list(reversed(block_out_channels))
198
+ reversed_attention_head_dim = list(reversed(attention_head_dim))
199
+ only_cross_attention = list(reversed(only_cross_attention))
200
+ output_channel = reversed_block_out_channels[0]
201
+ for i, up_block_type in enumerate(up_block_types):
202
+ res = 2 ** (3 - i)
203
+ is_final_block = i == len(block_out_channels) - 1
204
+
205
+ if task_type == "action":
206
+ name_index = None
207
+ else:
208
+ name_index = i
209
+
210
+ prev_output_channel = output_channel
211
+ output_channel = reversed_block_out_channels[i]
212
+ input_channel = reversed_block_out_channels[
213
+ min(i + 1, len(block_out_channels) - 1)
214
+ ]
215
+
216
+ # add upsample block for all BUT final layer
217
+ if not is_final_block:
218
+ add_upsample = True
219
+ self.num_upsamplers += 1
220
+ else:
221
+ add_upsample = False
222
+
223
+ up_block = get_up_block(
224
+ up_block_type,
225
+ num_layers=layers_per_block + 1,
226
+ in_channels=input_channel,
227
+ out_channels=output_channel,
228
+ prev_output_channel=prev_output_channel,
229
+ temb_channels=time_embed_dim,
230
+ add_upsample=add_upsample,
231
+ resnet_eps=norm_eps,
232
+ resnet_act_fn=act_fn,
233
+ resnet_groups=norm_num_groups,
234
+ cross_attention_dim=cross_attention_dim,
235
+ attn_num_head_channels=reversed_attention_head_dim[i],
236
+ dual_cross_attention=dual_cross_attention,
237
+ use_linear_projection=use_linear_projection,
238
+ only_cross_attention=only_cross_attention[i],
239
+ upcast_attention=upcast_attention,
240
+ resnet_time_scale_shift=resnet_time_scale_shift,
241
+ unet_use_cross_frame_attention=unet_use_cross_frame_attention,
242
+ unet_use_temporal_attention=unet_use_temporal_attention,
243
+ use_inflated_groupnorm=use_inflated_groupnorm,
244
+ use_motion_module=use_motion_module
245
+ and (res in motion_module_resolutions),
246
+ motion_module_type=motion_module_type,
247
+ motion_module_kwargs=motion_module_kwargs,
248
+ name_index=name_index,
249
+ )
250
+ self.up_blocks.append(up_block)
251
+ prev_output_channel = output_channel
252
+
253
+ # out
254
+ if use_inflated_groupnorm:
255
+ self.conv_norm_out = InflatedGroupNorm(
256
+ num_channels=block_out_channels[0],
257
+ num_groups=norm_num_groups,
258
+ eps=norm_eps,
259
+ )
260
+ else:
261
+ self.conv_norm_out = nn.GroupNorm(
262
+ num_channels=block_out_channels[0],
263
+ num_groups=norm_num_groups,
264
+ eps=norm_eps,
265
+ )
266
+ self.conv_act = nn.SiLU()
267
+ self.conv_out = InflatedConv3d(
268
+ block_out_channels[0], out_channels, kernel_size=3, padding=1
269
+ )
270
+
271
+ self.mode = mode
272
+
273
+ @property
274
+ # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors
275
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
276
+ r"""
277
+ Returns:
278
+ `dict` of attention processors: A dictionary containing all attention processors used in the model,
279
+ indexed by their weight names.
280
+ """
281
+ # set recursively
282
+ processors = {}
283
+
284
+ def fn_recursive_add_processors(
285
+ name: str,
286
+ module: torch.nn.Module,
287
+ processors: Dict[str, AttentionProcessor],
288
+ ):
289
+ if hasattr(module, "set_processor"):
290
+ processors[f"{name}.processor"] = module.processor
291
+
292
+ for sub_name, child in module.named_children():
293
+ if "temporal_transformer" not in sub_name:
294
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
295
+
296
+ return processors
297
+
298
+ for name, module in self.named_children():
299
+ if "temporal_transformer" not in name:
300
+ fn_recursive_add_processors(name, module, processors)
301
+
302
+ return processors
303
+
304
+ def set_attention_slice(self, slice_size):
305
+ r"""
306
+ Enable sliced attention computation.
307
+
308
+ When this option is enabled, the attention module will split the input tensor in slices, to compute attention
309
+ in several steps. This is useful to save some memory in exchange for a small speed decrease.
310
+
311
+ Args:
312
+ slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
313
+ When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
314
+ `"max"`, maximum amount of memory will be saved by running only one slice at a time. If a number is
315
+ provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
316
+ must be a multiple of `slice_size`.
317
+ """
318
+ sliceable_head_dims = []
319
+
320
+ def fn_recursive_retrieve_slicable_dims(module: torch.nn.Module):
321
+ if hasattr(module, "set_attention_slice"):
322
+ sliceable_head_dims.append(module.sliceable_head_dim)
323
+
324
+ for child in module.children():
325
+ fn_recursive_retrieve_slicable_dims(child)
326
+
327
+ # retrieve number of attention layers
328
+ for module in self.children():
329
+ fn_recursive_retrieve_slicable_dims(module)
330
+
331
+ num_slicable_layers = len(sliceable_head_dims)
332
+
333
+ if slice_size == "auto":
334
+ # half the attention head size is usually a good trade-off between
335
+ # speed and memory
336
+ slice_size = [dim // 2 for dim in sliceable_head_dims]
337
+ elif slice_size == "max":
338
+ # make smallest slice possible
339
+ slice_size = num_slicable_layers * [1]
340
+
341
+ slice_size = (
342
+ num_slicable_layers * [slice_size]
343
+ if not isinstance(slice_size, list)
344
+ else slice_size
345
+ )
346
+
347
+ if len(slice_size) != len(sliceable_head_dims):
348
+ raise ValueError(
349
+ f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
350
+ f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
351
+ )
352
+
353
+ for i in range(len(slice_size)):
354
+ size = slice_size[i]
355
+ dim = sliceable_head_dims[i]
356
+ if size is not None and size > dim:
357
+ raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
358
+
359
+ # Recursively walk through all the children.
360
+ # Any children which exposes the set_attention_slice method
361
+ # gets the message
362
+ def fn_recursive_set_attention_slice(
363
+ module: torch.nn.Module, slice_size: List[int]
364
+ ):
365
+ if hasattr(module, "set_attention_slice"):
366
+ module.set_attention_slice(slice_size.pop())
367
+
368
+ for child in module.children():
369
+ fn_recursive_set_attention_slice(child, slice_size)
370
+
371
+ reversed_slice_size = list(reversed(slice_size))
372
+ for module in self.children():
373
+ fn_recursive_set_attention_slice(module, reversed_slice_size)
374
+
375
+ def _set_gradient_checkpointing(self, module, value=False):
376
+ if hasattr(module, "gradient_checkpointing"):
377
+ module.gradient_checkpointing = value
378
+
379
+ # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
380
+ def set_attn_processor(
381
+ self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]
382
+ ):
383
+ r"""
384
+ Sets the attention processor to use to compute attention.
385
+
386
+ Parameters:
387
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
388
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
389
+ for **all** `Attention` layers.
390
+
391
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
392
+ processor. This is strongly recommended when setting trainable attention processors.
393
+
394
+ """
395
+ count = len(self.attn_processors.keys())
396
+
397
+ if isinstance(processor, dict) and len(processor) != count:
398
+ raise ValueError(
399
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
400
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
401
+ )
402
+
403
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
404
+ if hasattr(module, "set_processor"):
405
+ if not isinstance(processor, dict):
406
+ module.set_processor(processor)
407
+ else:
408
+ module.set_processor(processor.pop(f"{name}.processor"))
409
+
410
+ for sub_name, child in module.named_children():
411
+ if "temporal_transformer" not in sub_name:
412
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
413
+
414
+ for name, module in self.named_children():
415
+ if "temporal_transformer" not in name:
416
+ fn_recursive_attn_processor(name, module, processor)
417
+
418
+ def forward(
419
+ self,
420
+ sample: torch.FloatTensor,
421
+ timestep: Union[torch.Tensor, float, int],
422
+ encoder_hidden_states: torch.Tensor,
423
+ class_labels: Optional[torch.Tensor] = None,
424
+ pose_cond_fea: Optional[torch.Tensor] = None,
425
+ attention_mask: Optional[torch.Tensor] = None,
426
+ down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
427
+ mid_block_additional_residual: Optional[torch.Tensor] = None,
428
+ return_dict: bool = True,
429
+        self_attention_additional_feats=None,
430
+ ) -> Union[UNet3DConditionOutput, Tuple]:
431
+ r"""
432
+ Args:
433
+            sample (`torch.FloatTensor`): (batch, channel, num_frames, height, width) noisy input tensor
434
+ timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps
435
+ encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states
436
+ return_dict (`bool`, *optional*, defaults to `True`):
437
+ Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple.
438
+
439
+ Returns:
440
+ [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
441
+ [`~models.unet_2d_condition.UNet2DConditionOutput`] if `return_dict` is True, otherwise a `tuple`. When
442
+ returning a tuple, the first element is the sample tensor.
443
+ """
444
+        # By default samples have to be at least a multiple of the overall upsampling factor.
445
+        # The overall upsampling factor is equal to 2 ** (# num of upsampling layers).
446
+ # However, the upsampling interpolation output size can be forced to fit any upsampling size
447
+ # on the fly if necessary.
448
+ default_overall_up_factor = 2**self.num_upsamplers
449
+
450
+ # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
451
+ forward_upsample_size = False
452
+ upsample_size = None
453
+
454
+ if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
455
+ logger.info("Forward upsample size to force interpolation output size.")
456
+ forward_upsample_size = True
457
+
458
+ # prepare attention_mask
459
+ if attention_mask is not None:
460
+ attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
461
+ attention_mask = attention_mask.unsqueeze(1)
462
+
463
+ # center input if necessary
464
+ if self.config.center_input_sample:
465
+ sample = 2 * sample - 1.0
466
+
467
+ # time
468
+ timesteps = timestep
469
+ if not torch.is_tensor(timesteps):
470
+ # This would be a good case for the `match` statement (Python 3.10+)
471
+ is_mps = sample.device.type == "mps"
472
+ if isinstance(timestep, float):
473
+ dtype = torch.float32 if is_mps else torch.float64
474
+ else:
475
+ dtype = torch.int32 if is_mps else torch.int64
476
+ timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
477
+ elif len(timesteps.shape) == 0:
478
+ timesteps = timesteps[None].to(sample.device)
479
+
480
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
481
+ timesteps = timesteps.expand(sample.shape[0])
482
+
483
+ t_emb = self.time_proj(timesteps)
484
+
485
+ # timesteps does not contain any weights and will always return f32 tensors
486
+ # but time_embedding might actually be running in fp16. so we need to cast here.
487
+ # there might be better ways to encapsulate this.
488
+ t_emb = t_emb.to(dtype=self.dtype)
489
+ emb = self.time_embedding(t_emb)
490
+
491
+ if self.class_embedding is not None:
492
+ if class_labels is None:
493
+ raise ValueError(
494
+ "class_labels should be provided when num_class_embeds > 0"
495
+ )
496
+
497
+ if self.config.class_embed_type == "timestep":
498
+ class_labels = self.time_proj(class_labels)
499
+
500
+ class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
501
+ emb = emb + class_emb
502
+
503
+ # pre-process
504
+ sample = self.conv_in(sample)
505
+ if pose_cond_fea is not None:
506
+ sample = sample + pose_cond_fea
507
+
508
+ # down
509
+ down_block_res_samples = (sample,)
510
+ for downsample_block in self.down_blocks:
511
+ if (
512
+ hasattr(downsample_block, "has_cross_attention")
513
+ and downsample_block.has_cross_attention
514
+ ):
515
+ sample, res_samples = downsample_block(
516
+ hidden_states=sample,
517
+ temb=emb,
518
+ encoder_hidden_states=encoder_hidden_states,
519
+ attention_mask=attention_mask,
520
+ self_attention_additional_feats=self_attention_additional_feats,
521
+ mode=self.mode,
522
+ )
523
+ else:
524
+ sample, res_samples = downsample_block(
525
+ hidden_states=sample,
526
+ temb=emb,
527
+ encoder_hidden_states=encoder_hidden_states,
528
+ )
529
+
530
+ down_block_res_samples += res_samples
531
+
532
+ if down_block_additional_residuals is not None:
533
+ new_down_block_res_samples = ()
534
+
535
+ for down_block_res_sample, down_block_additional_residual in zip(
536
+ down_block_res_samples, down_block_additional_residuals
537
+ ):
538
+ down_block_additional_residual = rearrange(
539
+ down_block_additional_residual, "(b f) c h w -> b c f h w", f=down_block_res_sample.shape[2]
540
+ )
541
+ down_block_res_sample = (
542
+ down_block_res_sample + down_block_additional_residual
543
+ )
544
+ new_down_block_res_samples += (down_block_res_sample,)
545
+
546
+ down_block_res_samples = new_down_block_res_samples
547
+
548
+ # mid
549
+ sample = self.mid_block(
550
+ sample,
551
+ emb,
552
+ encoder_hidden_states=encoder_hidden_states,
553
+ attention_mask=attention_mask,
554
+ self_attention_additional_feats=self_attention_additional_feats,
555
+ mode=self.mode,
556
+ )
557
+
558
+ if mid_block_additional_residual is not None:
559
+ mid_block_additional_residual = rearrange(
560
+ mid_block_additional_residual, "(b f) c h w -> b c f h w", f=sample.shape[2]
561
+ )
562
+ sample = sample + mid_block_additional_residual
563
+
564
+ # up
565
+ for i, upsample_block in enumerate(self.up_blocks):
566
+ is_final_block = i == len(self.up_blocks) - 1
567
+
568
+ res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
569
+ down_block_res_samples = down_block_res_samples[
570
+ : -len(upsample_block.resnets)
571
+ ]
572
+
573
+ # if we have not reached the final block and need to forward the
574
+ # upsample size, we do it here
575
+ if not is_final_block and forward_upsample_size:
576
+ upsample_size = down_block_res_samples[-1].shape[2:]
577
+
578
+ if (
579
+ hasattr(upsample_block, "has_cross_attention")
580
+ and upsample_block.has_cross_attention
581
+ ):
582
+ sample = upsample_block(
583
+ hidden_states=sample,
584
+ temb=emb,
585
+ res_hidden_states_tuple=res_samples,
586
+ encoder_hidden_states=encoder_hidden_states,
587
+ upsample_size=upsample_size,
588
+ attention_mask=attention_mask,
589
+ self_attention_additional_feats=self_attention_additional_feats,
590
+ mode=self.mode,
591
+ )
592
+ else:
593
+ sample = upsample_block(
594
+ hidden_states=sample,
595
+ temb=emb,
596
+ res_hidden_states_tuple=res_samples,
597
+ upsample_size=upsample_size,
598
+ encoder_hidden_states=encoder_hidden_states,
599
+ )
600
+
601
+ # post-process
602
+ sample = self.conv_norm_out(sample)
603
+ sample = self.conv_act(sample)
604
+ sample = self.conv_out(sample)
605
+
606
+ if not return_dict:
607
+ return (sample,)
608
+
609
+ return UNet3DConditionOutput(sample=sample)
610
+
611
+ @classmethod
612
+ def from_pretrained_2d(
613
+ cls,
614
+ pretrained_model_path: PathLike,
615
+ motion_module_path: PathLike,
616
+ subfolder=None,
617
+ unet_additional_kwargs=None,
618
+ mm_zero_proj_out=False,
619
+ ):
620
+ pretrained_model_path = Path(pretrained_model_path)
621
+ motion_module_path = Path(motion_module_path)
622
+ if subfolder is not None:
623
+ pretrained_model_path = pretrained_model_path.joinpath(subfolder)
624
+ logger.info(
625
+ f"loaded temporal unet's pretrained weights from {pretrained_model_path} ..."
626
+ )
627
+
628
+ config_file = pretrained_model_path / "config.json"
629
+ if not (config_file.exists() and config_file.is_file()):
630
+ raise RuntimeError(f"{config_file} does not exist or is not a file")
631
+
632
+ unet_config = cls.load_config(config_file)
633
+ unet_config["_class_name"] = cls.__name__
634
+ unet_config["down_block_types"] = [
635
+ "CrossAttnDownBlock3D",
636
+ "CrossAttnDownBlock3D",
637
+ "CrossAttnDownBlock3D",
638
+ "DownBlock3D",
639
+ ]
640
+ unet_config["up_block_types"] = [
641
+ "UpBlock3D",
642
+ "CrossAttnUpBlock3D",
643
+ "CrossAttnUpBlock3D",
644
+ "CrossAttnUpBlock3D",
645
+ ]
646
+ unet_config["mid_block_type"] = "UNetMidBlock3DCrossAttn"
647
+
648
+ model = cls.from_config(unet_config, **unet_additional_kwargs)
649
+ # load the vanilla weights
650
+ if pretrained_model_path.joinpath(SAFETENSORS_WEIGHTS_NAME).exists():
651
+ logger.debug(
652
+                f"loading safetensors weights from {pretrained_model_path} ..."
653
+ )
654
+ state_dict = load_file(
655
+ pretrained_model_path.joinpath(SAFETENSORS_WEIGHTS_NAME), device="cpu"
656
+ )
657
+
658
+ elif pretrained_model_path.joinpath(WEIGHTS_NAME).exists():
659
+ logger.debug(f"loading weights from {pretrained_model_path} ...")
660
+ state_dict = torch.load(
661
+ pretrained_model_path.joinpath(WEIGHTS_NAME),
662
+ map_location="cpu",
663
+ weights_only=True,
664
+ )
665
+ else:
666
+ raise FileNotFoundError(f"no weights file found in {pretrained_model_path}")
667
+
668
+ # load the motion module weights
669
+ if motion_module_path.exists() and motion_module_path.is_file():
670
+ if motion_module_path.suffix.lower() in [".pth", ".pt", ".ckpt"]:
671
+ logger.info(f"Load motion module params from {motion_module_path}")
672
+ motion_state_dict = torch.load(
673
+ motion_module_path, map_location="cpu", weights_only=True
674
+ )
675
+ elif motion_module_path.suffix.lower() == ".safetensors":
676
+ motion_state_dict = load_file(motion_module_path, device="cpu")
677
+ else:
678
+ raise RuntimeError(
679
+ f"unknown file format for motion module weights: {motion_module_path.suffix}"
680
+ )
681
+ if mm_zero_proj_out:
682
+                logger.info("Zero-initializing proj_out layers in motion module...")
683
+ new_motion_state_dict = OrderedDict()
684
+ for k in motion_state_dict:
685
+ if "proj_out" in k:
686
+ continue
687
+ new_motion_state_dict[k] = motion_state_dict[k]
688
+ motion_state_dict = new_motion_state_dict
689
+
690
+ # merge the state dicts
691
+ state_dict.update(motion_state_dict)
692
+
693
+ # load the weights into the model
694
+ m, u = model.load_state_dict(state_dict, strict=False)
695
+ logger.debug(f"### missing keys: {len(m)}; \n### unexpected keys: {len(u)};")
696
+
697
+ params = [
698
+ p.numel() if "temporal" in n else 0 for n, p in model.named_parameters()
699
+ ]
700
+ logger.info(f"Loaded {sum(params) / 1e6}M-parameter motion module")
701
+
702
+ return model
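For orientation, here is a minimal, hypothetical sketch of how this loader might be called. The class name (`UNet3DConditionModel`), the checkpoint paths, and the extra kwargs are illustrative assumptions, not part of this commit:

from pathlib import Path

from ref_encoder.unet_3d import UNet3DConditionModel  # assumed export name

# Assumed local checkpoint layout: an SD-1.5 UNet plus an AnimateDiff-style motion module.
denoising_unet = UNet3DConditionModel.from_pretrained_2d(
    Path("./pretrained_weights/stable-diffusion-v1-5"),   # assumed path
    Path("./pretrained_weights/mm_sd_v15_v2.ckpt"),       # assumed path
    subfolder="unet",
    unet_additional_kwargs={},  # real configs usually pass motion-module settings here
)
denoising_unet.eval()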
ref_encoder/unet_3d_blocks.py ADDED
@@ -0,0 +1,906 @@
1
+ # Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/unet_2d_blocks.py
2
+
3
+ import pdb
4
+
5
+ import torch
6
+ from torch import nn
7
+
8
+ from .motion_module import get_motion_module
9
+
10
+ # from .motion_module import get_motion_module
11
+ from .resnet import Downsample3D, ResnetBlock3D, Upsample3D
12
+ from .transformer_3d import Transformer3DModel
13
+
14
+
15
+ def get_down_block(
16
+ down_block_type,
17
+ num_layers,
18
+ in_channels,
19
+ out_channels,
20
+ temb_channels,
21
+ add_downsample,
22
+ resnet_eps,
23
+ resnet_act_fn,
24
+ attn_num_head_channels,
25
+ resnet_groups=None,
26
+ cross_attention_dim=None,
27
+ downsample_padding=None,
28
+ dual_cross_attention=False,
29
+ use_linear_projection=False,
30
+ only_cross_attention=False,
31
+ upcast_attention=False,
32
+ resnet_time_scale_shift="default",
33
+ unet_use_cross_frame_attention=None,
34
+ unet_use_temporal_attention=None,
35
+ use_inflated_groupnorm=None,
36
+ use_motion_module=None,
37
+ motion_module_type=None,
38
+ motion_module_kwargs=None,
39
+ name_index=None,
40
+ ):
41
+ down_block_type = (
42
+ down_block_type[7:]
43
+ if down_block_type.startswith("UNetRes")
44
+ else down_block_type
45
+ )
46
+ if down_block_type == "DownBlock3D":
47
+ return DownBlock3D(
48
+ num_layers=num_layers,
49
+ in_channels=in_channels,
50
+ out_channels=out_channels,
51
+ temb_channels=temb_channels,
52
+ add_downsample=add_downsample,
53
+ resnet_eps=resnet_eps,
54
+ resnet_act_fn=resnet_act_fn,
55
+ resnet_groups=resnet_groups,
56
+ downsample_padding=downsample_padding,
57
+ resnet_time_scale_shift=resnet_time_scale_shift,
58
+ use_inflated_groupnorm=use_inflated_groupnorm,
59
+ use_motion_module=use_motion_module,
60
+ motion_module_type=motion_module_type,
61
+ motion_module_kwargs=motion_module_kwargs,
62
+ )
63
+ elif down_block_type == "CrossAttnDownBlock3D":
64
+ if cross_attention_dim is None:
65
+ raise ValueError(
66
+ "cross_attention_dim must be specified for CrossAttnDownBlock3D"
67
+ )
68
+ if name_index is not None:
69
+ name_index = f"CrossAttnDownBlock_{name_index}_"
70
+ return CrossAttnDownBlock3D(
71
+ num_layers=num_layers,
72
+ in_channels=in_channels,
73
+ out_channels=out_channels,
74
+ temb_channels=temb_channels,
75
+ add_downsample=add_downsample,
76
+ resnet_eps=resnet_eps,
77
+ resnet_act_fn=resnet_act_fn,
78
+ resnet_groups=resnet_groups,
79
+ downsample_padding=downsample_padding,
80
+ cross_attention_dim=cross_attention_dim,
81
+ attn_num_head_channels=attn_num_head_channels,
82
+ dual_cross_attention=dual_cross_attention,
83
+ use_linear_projection=use_linear_projection,
84
+ only_cross_attention=only_cross_attention,
85
+ upcast_attention=upcast_attention,
86
+ resnet_time_scale_shift=resnet_time_scale_shift,
87
+ unet_use_cross_frame_attention=unet_use_cross_frame_attention,
88
+ unet_use_temporal_attention=unet_use_temporal_attention,
89
+ use_inflated_groupnorm=use_inflated_groupnorm,
90
+ use_motion_module=use_motion_module,
91
+ motion_module_type=motion_module_type,
92
+ motion_module_kwargs=motion_module_kwargs,
93
+ name=name_index,
94
+ )
95
+ raise ValueError(f"{down_block_type} does not exist.")
96
+
97
+
98
+ def get_up_block(
99
+ up_block_type,
100
+ num_layers,
101
+ in_channels,
102
+ out_channels,
103
+ prev_output_channel,
104
+ temb_channels,
105
+ add_upsample,
106
+ resnet_eps,
107
+ resnet_act_fn,
108
+ attn_num_head_channels,
109
+ resnet_groups=None,
110
+ cross_attention_dim=None,
111
+ dual_cross_attention=False,
112
+ use_linear_projection=False,
113
+ only_cross_attention=False,
114
+ upcast_attention=False,
115
+ resnet_time_scale_shift="default",
116
+ unet_use_cross_frame_attention=None,
117
+ unet_use_temporal_attention=None,
118
+ use_inflated_groupnorm=None,
119
+ use_motion_module=None,
120
+ motion_module_type=None,
121
+ motion_module_kwargs=None,
122
+ name_index=None,
123
+ ):
124
+ up_block_type = (
125
+ up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type
126
+ )
127
+ if up_block_type == "UpBlock3D":
128
+ return UpBlock3D(
129
+ num_layers=num_layers,
130
+ in_channels=in_channels,
131
+ out_channels=out_channels,
132
+ prev_output_channel=prev_output_channel,
133
+ temb_channels=temb_channels,
134
+ add_upsample=add_upsample,
135
+ resnet_eps=resnet_eps,
136
+ resnet_act_fn=resnet_act_fn,
137
+ resnet_groups=resnet_groups,
138
+ resnet_time_scale_shift=resnet_time_scale_shift,
139
+ use_inflated_groupnorm=use_inflated_groupnorm,
140
+ use_motion_module=use_motion_module,
141
+ motion_module_type=motion_module_type,
142
+ motion_module_kwargs=motion_module_kwargs,
143
+ )
144
+ elif up_block_type == "CrossAttnUpBlock3D":
145
+ if cross_attention_dim is None:
146
+ raise ValueError(
147
+ "cross_attention_dim must be specified for CrossAttnUpBlock3D"
148
+ )
149
+ if name_index is not None:
150
+ name_index = f"CrossAttnUpBlock_{name_index}_"
151
+ return CrossAttnUpBlock3D(
152
+ num_layers=num_layers,
153
+ in_channels=in_channels,
154
+ out_channels=out_channels,
155
+ prev_output_channel=prev_output_channel,
156
+ temb_channels=temb_channels,
157
+ add_upsample=add_upsample,
158
+ resnet_eps=resnet_eps,
159
+ resnet_act_fn=resnet_act_fn,
160
+ resnet_groups=resnet_groups,
161
+ cross_attention_dim=cross_attention_dim,
162
+ attn_num_head_channels=attn_num_head_channels,
163
+ dual_cross_attention=dual_cross_attention,
164
+ use_linear_projection=use_linear_projection,
165
+ only_cross_attention=only_cross_attention,
166
+ upcast_attention=upcast_attention,
167
+ resnet_time_scale_shift=resnet_time_scale_shift,
168
+ unet_use_cross_frame_attention=unet_use_cross_frame_attention,
169
+ unet_use_temporal_attention=unet_use_temporal_attention,
170
+ use_inflated_groupnorm=use_inflated_groupnorm,
171
+ use_motion_module=use_motion_module,
172
+ motion_module_type=motion_module_type,
173
+ motion_module_kwargs=motion_module_kwargs,
174
+ name=name_index,
175
+ )
176
+ raise ValueError(f"{up_block_type} does not exist.")
177
+
178
+
179
+ class UNetMidBlock3DCrossAttn(nn.Module):
180
+ def __init__(
181
+ self,
182
+ in_channels: int,
183
+ temb_channels: int,
184
+ dropout: float = 0.0,
185
+ num_layers: int = 1,
186
+ resnet_eps: float = 1e-6,
187
+ resnet_time_scale_shift: str = "default",
188
+ resnet_act_fn: str = "swish",
189
+ resnet_groups: int = 32,
190
+ resnet_pre_norm: bool = True,
191
+ attn_num_head_channels=1,
192
+ output_scale_factor=1.0,
193
+ cross_attention_dim=1280,
194
+ dual_cross_attention=False,
195
+ use_linear_projection=False,
196
+ upcast_attention=False,
197
+ unet_use_cross_frame_attention=None,
198
+ unet_use_temporal_attention=None,
199
+ use_inflated_groupnorm=None,
200
+ use_motion_module=None,
201
+ motion_module_type=None,
202
+ motion_module_kwargs=None,
203
+ name=None
204
+ ):
205
+ super().__init__()
206
+
207
+ self.has_cross_attention = True
208
+ self.attn_num_head_channels = attn_num_head_channels
209
+ resnet_groups = (
210
+ resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
211
+ )
212
+ self.name = name
213
+ # there is always at least one resnet
214
+ resnets = [
215
+ ResnetBlock3D(
216
+ in_channels=in_channels,
217
+ out_channels=in_channels,
218
+ temb_channels=temb_channels,
219
+ eps=resnet_eps,
220
+ groups=resnet_groups,
221
+ dropout=dropout,
222
+ time_embedding_norm=resnet_time_scale_shift,
223
+ non_linearity=resnet_act_fn,
224
+ output_scale_factor=output_scale_factor,
225
+ pre_norm=resnet_pre_norm,
226
+ use_inflated_groupnorm=use_inflated_groupnorm,
227
+ )
228
+ ]
229
+ attentions = []
230
+ motion_modules = []
231
+ for i in range(num_layers):
232
+ if dual_cross_attention:
233
+ raise NotImplementedError
234
+ if self.name is not None:
235
+ attn_name = f"{self.name}_{i}_TransformerModel"
236
+ else:
237
+ attn_name = None
238
+ attentions.append(
239
+ Transformer3DModel(
240
+ attn_num_head_channels,
241
+ in_channels // attn_num_head_channels,
242
+ in_channels=in_channels,
243
+ num_layers=1,
244
+ cross_attention_dim=cross_attention_dim,
245
+ norm_num_groups=resnet_groups,
246
+ use_linear_projection=use_linear_projection,
247
+ upcast_attention=upcast_attention,
248
+ unet_use_cross_frame_attention=unet_use_cross_frame_attention,
249
+ unet_use_temporal_attention=unet_use_temporal_attention,
250
+ name=attn_name,
251
+ )
252
+ )
253
+ motion_modules.append(
254
+ get_motion_module(
255
+ in_channels=in_channels,
256
+ motion_module_type=motion_module_type,
257
+ motion_module_kwargs=motion_module_kwargs,
258
+ )
259
+ if use_motion_module
260
+ else None
261
+ )
262
+ resnets.append(
263
+ ResnetBlock3D(
264
+ in_channels=in_channels,
265
+ out_channels=in_channels,
266
+ temb_channels=temb_channels,
267
+ eps=resnet_eps,
268
+ groups=resnet_groups,
269
+ dropout=dropout,
270
+ time_embedding_norm=resnet_time_scale_shift,
271
+ non_linearity=resnet_act_fn,
272
+ output_scale_factor=output_scale_factor,
273
+ pre_norm=resnet_pre_norm,
274
+ use_inflated_groupnorm=use_inflated_groupnorm,
275
+ )
276
+ )
277
+
278
+ self.attentions = nn.ModuleList(attentions)
279
+ self.resnets = nn.ModuleList(resnets)
280
+ self.motion_modules = nn.ModuleList(motion_modules)
281
+
282
+ def forward(
283
+ self,
284
+ hidden_states,
285
+ temb=None,
286
+ encoder_hidden_states=None,
287
+ attention_mask=None,
288
+ self_attention_additional_feats=None,
289
+ mode=None,
290
+ ):
291
+ hidden_states = self.resnets[0](hidden_states, temb)
292
+ for attn, resnet, motion_module in zip(
293
+ self.attentions, self.resnets[1:], self.motion_modules
294
+ ):
295
+ hidden_states = attn(
296
+ hidden_states,
297
+ encoder_hidden_states=encoder_hidden_states,
298
+ self_attention_additional_feats=self_attention_additional_feats,
299
+ mode=mode,
300
+ ).sample
301
+ hidden_states = (
302
+ motion_module(
303
+ hidden_states, temb, encoder_hidden_states=encoder_hidden_states
304
+ )
305
+ if motion_module is not None
306
+ else hidden_states
307
+ )
308
+ hidden_states = resnet(hidden_states, temb)
309
+
310
+ return hidden_states
311
+
312
+
313
+ class CrossAttnDownBlock3D(nn.Module):
314
+ def __init__(
315
+ self,
316
+ in_channels: int,
317
+ out_channels: int,
318
+ temb_channels: int,
319
+ dropout: float = 0.0,
320
+ num_layers: int = 1,
321
+ resnet_eps: float = 1e-6,
322
+ resnet_time_scale_shift: str = "default",
323
+ resnet_act_fn: str = "swish",
324
+ resnet_groups: int = 32,
325
+ resnet_pre_norm: bool = True,
326
+ attn_num_head_channels=1,
327
+ cross_attention_dim=1280,
328
+ output_scale_factor=1.0,
329
+ downsample_padding=1,
330
+ add_downsample=True,
331
+ dual_cross_attention=False,
332
+ use_linear_projection=False,
333
+ only_cross_attention=False,
334
+ upcast_attention=False,
335
+ unet_use_cross_frame_attention=None,
336
+ unet_use_temporal_attention=None,
337
+ use_inflated_groupnorm=None,
338
+ use_motion_module=None,
339
+ motion_module_type=None,
340
+ motion_module_kwargs=None,
341
+ name=None,
342
+ ):
343
+ super().__init__()
344
+ resnets = []
345
+ attentions = []
346
+ motion_modules = []
347
+
348
+ self.has_cross_attention = True
349
+ self.attn_num_head_channels = attn_num_head_channels
350
+        self.name = name
351
+
352
+ for i in range(num_layers):
353
+ in_channels = in_channels if i == 0 else out_channels
354
+ resnets.append(
355
+ ResnetBlock3D(
356
+ in_channels=in_channels,
357
+ out_channels=out_channels,
358
+ temb_channels=temb_channels,
359
+ eps=resnet_eps,
360
+ groups=resnet_groups,
361
+ dropout=dropout,
362
+ time_embedding_norm=resnet_time_scale_shift,
363
+ non_linearity=resnet_act_fn,
364
+ output_scale_factor=output_scale_factor,
365
+ pre_norm=resnet_pre_norm,
366
+ use_inflated_groupnorm=use_inflated_groupnorm,
367
+ )
368
+ )
369
+ if dual_cross_attention:
370
+ raise NotImplementedError
371
+ if self.name is not None:
372
+ attn_name = f"{self.name}_{i}_TransformerModel"
373
+ else:
374
+ attn_name = None
375
+ attentions.append(
376
+ Transformer3DModel(
377
+ attn_num_head_channels,
378
+ out_channels // attn_num_head_channels,
379
+ in_channels=out_channels,
380
+ num_layers=1,
381
+ cross_attention_dim=cross_attention_dim,
382
+ norm_num_groups=resnet_groups,
383
+ use_linear_projection=use_linear_projection,
384
+ only_cross_attention=only_cross_attention,
385
+ upcast_attention=upcast_attention,
386
+ unet_use_cross_frame_attention=unet_use_cross_frame_attention,
387
+ unet_use_temporal_attention=unet_use_temporal_attention,
388
+ name=attn_name,
389
+ )
390
+ )
391
+ motion_modules.append(
392
+ get_motion_module(
393
+ in_channels=out_channels,
394
+ motion_module_type=motion_module_type,
395
+ motion_module_kwargs=motion_module_kwargs,
396
+ )
397
+ if use_motion_module
398
+ else None
399
+ )
400
+
401
+ self.attentions = nn.ModuleList(attentions)
402
+ self.resnets = nn.ModuleList(resnets)
403
+ self.motion_modules = nn.ModuleList(motion_modules)
404
+
405
+ if add_downsample:
406
+ self.downsamplers = nn.ModuleList(
407
+ [
408
+ Downsample3D(
409
+ out_channels,
410
+ use_conv=True,
411
+ out_channels=out_channels,
412
+ padding=downsample_padding,
413
+ name="op",
414
+ )
415
+ ]
416
+ )
417
+ else:
418
+ self.downsamplers = None
419
+
420
+ self.gradient_checkpointing = False
421
+
422
+ def forward(
423
+ self,
424
+ hidden_states,
425
+ temb=None,
426
+ encoder_hidden_states=None,
427
+ attention_mask=None,
428
+ self_attention_additional_feats=None,
429
+ mode=None,
430
+ ):
431
+ output_states = ()
432
+
433
+ for i, (resnet, attn, motion_module) in enumerate(
434
+ zip(self.resnets, self.attentions, self.motion_modules)
435
+ ):
436
+ # self.gradient_checkpointing = False
437
+ if self.training and self.gradient_checkpointing:
438
+
439
+ def create_custom_forward(module, return_dict=None):
440
+ def custom_forward(*inputs):
441
+ if return_dict is not None:
442
+ return module(*inputs, return_dict=return_dict)
443
+ else:
444
+ return module(*inputs)
445
+
446
+ return custom_forward
447
+
448
+ hidden_states = torch.utils.checkpoint.checkpoint(
449
+ create_custom_forward(resnet), hidden_states, temb
450
+ )
451
+ hidden_states = torch.utils.checkpoint.checkpoint(
452
+ create_custom_forward(attn, return_dict=False),
453
+ hidden_states,
454
+ encoder_hidden_states,
455
+ self_attention_additional_feats,
456
+ mode,
457
+ )[0]
458
+
459
+ # add motion module
460
+ if motion_module is not None:
461
+ hidden_states = torch.utils.checkpoint.checkpoint(
462
+ create_custom_forward(motion_module),
463
+ hidden_states.requires_grad_(),
464
+ temb,
465
+ encoder_hidden_states,
466
+ )
467
+
468
+ else:
469
+ hidden_states = resnet(hidden_states, temb)
470
+ hidden_states = attn(
471
+ hidden_states,
472
+ encoder_hidden_states=encoder_hidden_states,
473
+ self_attention_additional_feats=self_attention_additional_feats,
474
+ mode=mode,
475
+ ).sample
476
+
477
+ # add motion module
478
+ hidden_states = (
479
+ motion_module(
480
+ hidden_states, temb, encoder_hidden_states=encoder_hidden_states
481
+ )
482
+ if motion_module is not None
483
+ else hidden_states
484
+ )
485
+
486
+ output_states += (hidden_states,)
487
+
488
+ if self.downsamplers is not None:
489
+ for downsampler in self.downsamplers:
490
+ hidden_states = downsampler(hidden_states)
491
+
492
+ output_states += (hidden_states,)
493
+
494
+ return hidden_states, output_states
495
+
496
+
497
+ class DownBlock3D(nn.Module):
498
+ def __init__(
499
+ self,
500
+ in_channels: int,
501
+ out_channels: int,
502
+ temb_channels: int,
503
+ dropout: float = 0.0,
504
+ num_layers: int = 1,
505
+ resnet_eps: float = 1e-6,
506
+ resnet_time_scale_shift: str = "default",
507
+ resnet_act_fn: str = "swish",
508
+ resnet_groups: int = 32,
509
+ resnet_pre_norm: bool = True,
510
+ output_scale_factor=1.0,
511
+ add_downsample=True,
512
+ downsample_padding=1,
513
+ use_inflated_groupnorm=None,
514
+ use_motion_module=None,
515
+ motion_module_type=None,
516
+ motion_module_kwargs=None,
517
+ ):
518
+ super().__init__()
519
+ resnets = []
520
+ motion_modules = []
521
+
522
+ # use_motion_module = False
523
+ for i in range(num_layers):
524
+ in_channels = in_channels if i == 0 else out_channels
525
+ resnets.append(
526
+ ResnetBlock3D(
527
+ in_channels=in_channels,
528
+ out_channels=out_channels,
529
+ temb_channels=temb_channels,
530
+ eps=resnet_eps,
531
+ groups=resnet_groups,
532
+ dropout=dropout,
533
+ time_embedding_norm=resnet_time_scale_shift,
534
+ non_linearity=resnet_act_fn,
535
+ output_scale_factor=output_scale_factor,
536
+ pre_norm=resnet_pre_norm,
537
+ use_inflated_groupnorm=use_inflated_groupnorm,
538
+ )
539
+ )
540
+ motion_modules.append(
541
+ get_motion_module(
542
+ in_channels=out_channels,
543
+ motion_module_type=motion_module_type,
544
+ motion_module_kwargs=motion_module_kwargs,
545
+ )
546
+ if use_motion_module
547
+ else None
548
+ )
549
+
550
+ self.resnets = nn.ModuleList(resnets)
551
+ self.motion_modules = nn.ModuleList(motion_modules)
552
+
553
+ if add_downsample:
554
+ self.downsamplers = nn.ModuleList(
555
+ [
556
+ Downsample3D(
557
+ out_channels,
558
+ use_conv=True,
559
+ out_channels=out_channels,
560
+ padding=downsample_padding,
561
+ name="op",
562
+ )
563
+ ]
564
+ )
565
+ else:
566
+ self.downsamplers = None
567
+
568
+ self.gradient_checkpointing = False
569
+
570
+ def forward(self, hidden_states, temb=None, encoder_hidden_states=None):
571
+ output_states = ()
572
+
573
+ for resnet, motion_module in zip(self.resnets, self.motion_modules):
574
+ # print(f"DownBlock3D {self.gradient_checkpointing = }")
575
+ if self.training and self.gradient_checkpointing:
576
+
577
+ def create_custom_forward(module):
578
+ def custom_forward(*inputs):
579
+ return module(*inputs)
580
+
581
+ return custom_forward
582
+
583
+ hidden_states = torch.utils.checkpoint.checkpoint(
584
+ create_custom_forward(resnet), hidden_states, temb
585
+ )
586
+ if motion_module is not None:
587
+ hidden_states = torch.utils.checkpoint.checkpoint(
588
+ create_custom_forward(motion_module),
589
+ hidden_states.requires_grad_(),
590
+ temb,
591
+ encoder_hidden_states,
592
+ )
593
+ else:
594
+ hidden_states = resnet(hidden_states, temb)
595
+
596
+ # add motion module
597
+ hidden_states = (
598
+ motion_module(
599
+ hidden_states, temb, encoder_hidden_states=encoder_hidden_states
600
+ )
601
+ if motion_module is not None
602
+ else hidden_states
603
+ )
604
+
605
+ output_states += (hidden_states,)
606
+
607
+ if self.downsamplers is not None:
608
+ for downsampler in self.downsamplers:
609
+ hidden_states = downsampler(hidden_states)
610
+
611
+ output_states += (hidden_states,)
612
+
613
+ return hidden_states, output_states
614
+
615
+
616
+ class CrossAttnUpBlock3D(nn.Module):
617
+ def __init__(
618
+ self,
619
+ in_channels: int,
620
+ out_channels: int,
621
+ prev_output_channel: int,
622
+ temb_channels: int,
623
+ dropout: float = 0.0,
624
+ num_layers: int = 1,
625
+ resnet_eps: float = 1e-6,
626
+ resnet_time_scale_shift: str = "default",
627
+ resnet_act_fn: str = "swish",
628
+ resnet_groups: int = 32,
629
+ resnet_pre_norm: bool = True,
630
+ attn_num_head_channels=1,
631
+ cross_attention_dim=1280,
632
+ output_scale_factor=1.0,
633
+ add_upsample=True,
634
+ dual_cross_attention=False,
635
+ use_linear_projection=False,
636
+ only_cross_attention=False,
637
+ upcast_attention=False,
638
+ unet_use_cross_frame_attention=None,
639
+ unet_use_temporal_attention=None,
640
+ use_motion_module=None,
641
+ use_inflated_groupnorm=None,
642
+ motion_module_type=None,
643
+ motion_module_kwargs=None,
644
+ name=None
645
+ ):
646
+ super().__init__()
647
+ resnets = []
648
+ attentions = []
649
+ motion_modules = []
650
+
651
+ self.has_cross_attention = True
652
+ self.attn_num_head_channels = attn_num_head_channels
653
+ self.name = name
654
+
655
+ for i in range(num_layers):
656
+ res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
657
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
658
+
659
+ resnets.append(
660
+ ResnetBlock3D(
661
+ in_channels=resnet_in_channels + res_skip_channels,
662
+ out_channels=out_channels,
663
+ temb_channels=temb_channels,
664
+ eps=resnet_eps,
665
+ groups=resnet_groups,
666
+ dropout=dropout,
667
+ time_embedding_norm=resnet_time_scale_shift,
668
+ non_linearity=resnet_act_fn,
669
+ output_scale_factor=output_scale_factor,
670
+ pre_norm=resnet_pre_norm,
671
+ use_inflated_groupnorm=use_inflated_groupnorm,
672
+ )
673
+ )
674
+ if dual_cross_attention:
675
+ raise NotImplementedError
676
+ if self.name is not None:
677
+ attn_name = f"{self.name}_{i}_TransformerModel"
678
+ else:
679
+ attn_name = None
680
+ attentions.append(
681
+ Transformer3DModel(
682
+ attn_num_head_channels,
683
+ out_channels // attn_num_head_channels,
684
+ in_channels=out_channels,
685
+ num_layers=1,
686
+ cross_attention_dim=cross_attention_dim,
687
+ norm_num_groups=resnet_groups,
688
+ use_linear_projection=use_linear_projection,
689
+ only_cross_attention=only_cross_attention,
690
+ upcast_attention=upcast_attention,
691
+ unet_use_cross_frame_attention=unet_use_cross_frame_attention,
692
+ unet_use_temporal_attention=unet_use_temporal_attention,
693
+ name=attn_name,
694
+ )
695
+ )
696
+ motion_modules.append(
697
+ get_motion_module(
698
+ in_channels=out_channels,
699
+ motion_module_type=motion_module_type,
700
+ motion_module_kwargs=motion_module_kwargs,
701
+ )
702
+ if use_motion_module
703
+ else None
704
+ )
705
+
706
+ self.attentions = nn.ModuleList(attentions)
707
+ self.resnets = nn.ModuleList(resnets)
708
+ self.motion_modules = nn.ModuleList(motion_modules)
709
+
710
+ if add_upsample:
711
+ self.upsamplers = nn.ModuleList(
712
+ [Upsample3D(out_channels, use_conv=True, out_channels=out_channels)]
713
+ )
714
+ else:
715
+ self.upsamplers = None
716
+
717
+ self.gradient_checkpointing = False
718
+
719
+ def forward(
720
+ self,
721
+ hidden_states,
722
+ res_hidden_states_tuple,
723
+ temb=None,
724
+ encoder_hidden_states=None,
725
+ upsample_size=None,
726
+ attention_mask=None,
727
+ self_attention_additional_feats=None,
728
+ mode=None,
729
+ ):
730
+ for i, (resnet, attn, motion_module) in enumerate(
731
+ zip(self.resnets, self.attentions, self.motion_modules)
732
+ ):
733
+ # pop res hidden states
734
+ res_hidden_states = res_hidden_states_tuple[-1]
735
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
736
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
737
+
738
+ if self.training and self.gradient_checkpointing:
739
+
740
+ def create_custom_forward(module, return_dict=None):
741
+ def custom_forward(*inputs):
742
+ if return_dict is not None:
743
+ return module(*inputs, return_dict=return_dict)
744
+ else:
745
+ return module(*inputs)
746
+
747
+ return custom_forward
748
+
749
+ hidden_states = torch.utils.checkpoint.checkpoint(
750
+ create_custom_forward(resnet), hidden_states, temb
751
+ )
752
+ hidden_states = torch.utils.checkpoint.checkpoint(
753
+ create_custom_forward(attn, return_dict=False),
754
+ hidden_states,
755
+ encoder_hidden_states,
756
+ self_attention_additional_feats,
757
+ mode,
758
+ )[0]
759
+ if motion_module is not None:
760
+ hidden_states = torch.utils.checkpoint.checkpoint(
761
+ create_custom_forward(motion_module),
762
+ hidden_states.requires_grad_(),
763
+ temb,
764
+ encoder_hidden_states,
765
+ )
766
+
767
+ else:
768
+ hidden_states = resnet(hidden_states, temb)
769
+ hidden_states = attn(
770
+ hidden_states,
771
+ encoder_hidden_states=encoder_hidden_states,
772
+ self_attention_additional_feats=self_attention_additional_feats,
773
+ mode=mode,
774
+ ).sample
775
+
776
+ # add motion module
777
+ hidden_states = (
778
+ motion_module(
779
+ hidden_states, temb, encoder_hidden_states=encoder_hidden_states
780
+ )
781
+ if motion_module is not None
782
+ else hidden_states
783
+ )
784
+
785
+ if self.upsamplers is not None:
786
+ for upsampler in self.upsamplers:
787
+ hidden_states = upsampler(hidden_states, upsample_size)
788
+
789
+ return hidden_states
790
+
791
+
792
+ class UpBlock3D(nn.Module):
793
+ def __init__(
794
+ self,
795
+ in_channels: int,
796
+ prev_output_channel: int,
797
+ out_channels: int,
798
+ temb_channels: int,
799
+ dropout: float = 0.0,
800
+ num_layers: int = 1,
801
+ resnet_eps: float = 1e-6,
802
+ resnet_time_scale_shift: str = "default",
803
+ resnet_act_fn: str = "swish",
804
+ resnet_groups: int = 32,
805
+ resnet_pre_norm: bool = True,
806
+ output_scale_factor=1.0,
807
+ add_upsample=True,
808
+ use_inflated_groupnorm=None,
809
+ use_motion_module=None,
810
+ motion_module_type=None,
811
+ motion_module_kwargs=None,
812
+ ):
813
+ super().__init__()
814
+ resnets = []
815
+ motion_modules = []
816
+
817
+ # use_motion_module = False
818
+ for i in range(num_layers):
819
+ res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
820
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
821
+
822
+ resnets.append(
823
+ ResnetBlock3D(
824
+ in_channels=resnet_in_channels + res_skip_channels,
825
+ out_channels=out_channels,
826
+ temb_channels=temb_channels,
827
+ eps=resnet_eps,
828
+ groups=resnet_groups,
829
+ dropout=dropout,
830
+ time_embedding_norm=resnet_time_scale_shift,
831
+ non_linearity=resnet_act_fn,
832
+ output_scale_factor=output_scale_factor,
833
+ pre_norm=resnet_pre_norm,
834
+ use_inflated_groupnorm=use_inflated_groupnorm,
835
+ )
836
+ )
837
+ motion_modules.append(
838
+ get_motion_module(
839
+ in_channels=out_channels,
840
+ motion_module_type=motion_module_type,
841
+ motion_module_kwargs=motion_module_kwargs,
842
+ )
843
+ if use_motion_module
844
+ else None
845
+ )
846
+
847
+ self.resnets = nn.ModuleList(resnets)
848
+ self.motion_modules = nn.ModuleList(motion_modules)
849
+
850
+ if add_upsample:
851
+ self.upsamplers = nn.ModuleList(
852
+ [Upsample3D(out_channels, use_conv=True, out_channels=out_channels)]
853
+ )
854
+ else:
855
+ self.upsamplers = None
856
+
857
+ self.gradient_checkpointing = False
858
+
859
+ def forward(
860
+ self,
861
+ hidden_states,
862
+ res_hidden_states_tuple,
863
+ temb=None,
864
+ upsample_size=None,
865
+ encoder_hidden_states=None,
866
+ ):
867
+ for resnet, motion_module in zip(self.resnets, self.motion_modules):
868
+ # pop res hidden states
869
+ res_hidden_states = res_hidden_states_tuple[-1]
870
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
871
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
872
+
873
+ # print(f"UpBlock3D {self.gradient_checkpointing = }")
874
+ if self.training and self.gradient_checkpointing:
875
+
876
+ def create_custom_forward(module):
877
+ def custom_forward(*inputs):
878
+ return module(*inputs)
879
+
880
+ return custom_forward
881
+
882
+ hidden_states = torch.utils.checkpoint.checkpoint(
883
+ create_custom_forward(resnet), hidden_states, temb
884
+ )
885
+ if motion_module is not None:
886
+ hidden_states = torch.utils.checkpoint.checkpoint(
887
+ create_custom_forward(motion_module),
888
+ hidden_states.requires_grad_(),
889
+ temb,
890
+ encoder_hidden_states,
891
+ )
892
+ else:
893
+ hidden_states = resnet(hidden_states, temb)
894
+ hidden_states = (
895
+ motion_module(
896
+ hidden_states, temb, encoder_hidden_states=encoder_hidden_states
897
+ )
898
+ if motion_module is not None
899
+ else hidden_states
900
+ )
901
+
902
+ if self.upsamplers is not None:
903
+ for upsampler in self.upsamplers:
904
+ hidden_states = upsampler(hidden_states, upsample_size)
905
+
906
+ return hidden_states
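As a quick illustration of how these block factories are meant to be driven (normally from the UNet constructor above), a hedged sketch follows; the channel widths, head size, and cross-attention dimension are illustrative values, not taken from this repo's config:

from ref_encoder.unet_3d_blocks import get_down_block  # assumed import path

down_block = get_down_block(
    "CrossAttnDownBlock3D",
    num_layers=2,
    in_channels=320,
    out_channels=640,
    temb_channels=1280,
    add_downsample=True,
    resnet_eps=1e-5,
    resnet_act_fn="swish",
    attn_num_head_channels=8,
    resnet_groups=32,
    cross_attention_dim=768,
    downsample_padding=1,
    use_motion_module=False,  # temporal modules disabled for this sketch
    name_index=0,
)
# The resulting block is then called with 5D hidden_states of shape (b, c, f, h, w).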
src/__init__.py ADDED
File without changes
src/dataset/dance_image.py ADDED
@@ -0,0 +1,124 @@
1
+ import json
2
+ import random
3
+
4
+ import torch
5
+ import torchvision.transforms as transforms
6
+ from decord import VideoReader
7
+ from PIL import Image
8
+ from torch.utils.data import Dataset
9
+ from transformers import CLIPImageProcessor
10
+
11
+
12
+ class HumanDanceDataset(Dataset):
13
+ def __init__(
14
+ self,
15
+ img_size,
16
+ img_scale=(1.0, 1.0),
17
+ img_ratio=(0.9, 1.0),
18
+ drop_ratio=0.1,
19
+ data_meta_paths=["./data/fahsion_meta.json"],
20
+ sample_margin=30,
21
+ ):
22
+ super().__init__()
23
+
24
+ self.img_size = img_size
25
+ self.img_scale = img_scale
26
+ self.img_ratio = img_ratio
27
+ self.sample_margin = sample_margin
28
+
29
+ # -----
30
+ # vid_meta format:
31
+ # [{'video_path': , 'kps_path': , 'other':},
32
+ # {'video_path': , 'kps_path': , 'other':}]
33
+ # -----
34
+ vid_meta = []
35
+ for data_meta_path in data_meta_paths:
36
+ vid_meta.extend(json.load(open(data_meta_path, "r")))
37
+ self.vid_meta = vid_meta
38
+
39
+ self.clip_image_processor = CLIPImageProcessor()
40
+
41
+ self.transform = transforms.Compose(
42
+ [
43
+ transforms.RandomResizedCrop(
44
+ self.img_size,
45
+ scale=self.img_scale,
46
+ ratio=self.img_ratio,
47
+ interpolation=transforms.InterpolationMode.BILINEAR,
48
+ ),
49
+ transforms.ToTensor(),
50
+ transforms.Normalize([0.5], [0.5]),
51
+ ]
52
+ )
53
+
54
+ self.cond_transform = transforms.Compose(
55
+ [
56
+ transforms.RandomResizedCrop(
57
+ self.img_size,
58
+ scale=self.img_scale,
59
+ ratio=self.img_ratio,
60
+ interpolation=transforms.InterpolationMode.BILINEAR,
61
+ ),
62
+ transforms.ToTensor(),
63
+ ]
64
+ )
65
+
66
+ self.drop_ratio = drop_ratio
67
+
68
+ def augmentation(self, image, transform, state=None):
69
+ if state is not None:
70
+ torch.set_rng_state(state)
71
+ return transform(image)
72
+
73
+ def __getitem__(self, index):
74
+ video_meta = self.vid_meta[index]
75
+ video_path = video_meta["video_path"]
76
+ kps_path = video_meta["kps_path"]
77
+
78
+ video_reader = VideoReader(video_path)
79
+ kps_reader = VideoReader(kps_path)
80
+
81
+ assert len(video_reader) == len(
82
+ kps_reader
83
+ ), f"{len(video_reader) = } != {len(kps_reader) = } in {video_path}"
84
+
85
+ video_length = len(video_reader)
86
+
87
+ margin = min(self.sample_margin, video_length)
88
+
89
+ ref_img_idx = random.randint(0, video_length - 1)
90
+ if ref_img_idx + margin < video_length:
91
+ tgt_img_idx = random.randint(ref_img_idx + margin, video_length - 1)
92
+ elif ref_img_idx - margin > 0:
93
+ tgt_img_idx = random.randint(0, ref_img_idx - margin)
94
+ else:
95
+ tgt_img_idx = random.randint(0, video_length - 1)
96
+
97
+ ref_img = video_reader[ref_img_idx]
98
+ ref_img_pil = Image.fromarray(ref_img.asnumpy())
99
+ tgt_img = video_reader[tgt_img_idx]
100
+ tgt_img_pil = Image.fromarray(tgt_img.asnumpy())
101
+
102
+ tgt_pose = kps_reader[tgt_img_idx]
103
+ tgt_pose_pil = Image.fromarray(tgt_pose.asnumpy())
104
+
105
+ state = torch.get_rng_state()
106
+ tgt_img = self.augmentation(tgt_img_pil, self.transform, state)
107
+ tgt_pose_img = self.augmentation(tgt_pose_pil, self.cond_transform, state)
108
+ ref_img_vae = self.augmentation(ref_img_pil, self.transform, state)
109
+ clip_image = self.clip_image_processor(
110
+ images=ref_img_pil, return_tensors="pt"
111
+ ).pixel_values[0]
112
+
113
+ sample = dict(
114
+ video_dir=video_path,
115
+ img=tgt_img,
116
+ tgt_pose=tgt_pose_img,
117
+ ref_img=ref_img_vae,
118
+ clip_images=clip_image,
119
+ )
120
+
121
+ return sample
122
+
123
+ def __len__(self):
124
+ return len(self.vid_meta)
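A hedged usage sketch for this image-pair dataset; the meta path and crop size are assumptions and must match your own data preparation:

from torch.utils.data import DataLoader

train_dataset = HumanDanceDataset(
    img_size=(768, 576),                           # illustrative (h, w) crop size
    data_meta_paths=["./data/fashion_meta.json"],  # assumed meta file
    sample_margin=30,
)
train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True, num_workers=4)

batch = next(iter(train_loader))
# img / ref_img: (b, 3, h, w) in [-1, 1]; tgt_pose: (b, 3, h, w) in [0, 1]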
src/dataset/dance_video.py ADDED
@@ -0,0 +1,137 @@
1
+ import json
2
+ import random
3
+ from typing import List
4
+
5
+ import numpy as np
6
+ import pandas as pd
7
+ import torch
8
+ import torchvision.transforms as transforms
9
+ from decord import VideoReader
10
+ from PIL import Image
11
+ from torch.utils.data import Dataset
12
+ from transformers import CLIPImageProcessor
13
+
14
+
15
+ class HumanDanceVideoDataset(Dataset):
16
+ def __init__(
17
+ self,
18
+ sample_rate,
19
+ n_sample_frames,
20
+ width,
21
+ height,
22
+ img_scale=(1.0, 1.0),
23
+ img_ratio=(0.9, 1.0),
24
+ drop_ratio=0.1,
25
+ data_meta_paths=["./data/fashion_meta.json"],
26
+ ):
27
+ super().__init__()
28
+ self.sample_rate = sample_rate
29
+ self.n_sample_frames = n_sample_frames
30
+ self.width = width
31
+ self.height = height
32
+ self.img_scale = img_scale
33
+ self.img_ratio = img_ratio
34
+
35
+ vid_meta = []
36
+ for data_meta_path in data_meta_paths:
37
+ vid_meta.extend(json.load(open(data_meta_path, "r")))
38
+ self.vid_meta = vid_meta
39
+
40
+ self.clip_image_processor = CLIPImageProcessor()
41
+
42
+ self.pixel_transform = transforms.Compose(
43
+ [
44
+ transforms.RandomResizedCrop(
45
+ (height, width),
46
+ scale=self.img_scale,
47
+ ratio=self.img_ratio,
48
+ interpolation=transforms.InterpolationMode.BILINEAR,
49
+ ),
50
+ transforms.ToTensor(),
51
+ transforms.Normalize([0.5], [0.5]),
52
+ ]
53
+ )
54
+
55
+ self.cond_transform = transforms.Compose(
56
+ [
57
+ transforms.RandomResizedCrop(
58
+ (height, width),
59
+ scale=self.img_scale,
60
+ ratio=self.img_ratio,
61
+ interpolation=transforms.InterpolationMode.BILINEAR,
62
+ ),
63
+ transforms.ToTensor(),
64
+ ]
65
+ )
66
+
67
+ self.drop_ratio = drop_ratio
68
+
69
+ def augmentation(self, images, transform, state=None):
70
+ if state is not None:
71
+ torch.set_rng_state(state)
72
+ if isinstance(images, List):
73
+ transformed_images = [transform(img) for img in images]
74
+ ret_tensor = torch.stack(transformed_images, dim=0) # (f, c, h, w)
75
+ else:
76
+ ret_tensor = transform(images) # (c, h, w)
77
+ return ret_tensor
78
+
79
+ def __getitem__(self, index):
80
+ video_meta = self.vid_meta[index]
81
+ video_path = video_meta["video_path"]
82
+ kps_path = video_meta["kps_path"]
83
+
84
+ video_reader = VideoReader(video_path)
85
+ kps_reader = VideoReader(kps_path)
86
+
87
+ assert len(video_reader) == len(
88
+ kps_reader
89
+ ), f"{len(video_reader) = } != {len(kps_reader) = } in {video_path}"
90
+
91
+ video_length = len(video_reader)
92
+
93
+ clip_length = min(
94
+ video_length, (self.n_sample_frames - 1) * self.sample_rate + 1
95
+ )
96
+ start_idx = random.randint(0, video_length - clip_length)
97
+ batch_index = np.linspace(
98
+ start_idx, start_idx + clip_length - 1, self.n_sample_frames, dtype=int
99
+ ).tolist()
100
+
101
+ # read frames and kps
102
+ vid_pil_image_list = []
103
+ pose_pil_image_list = []
104
+ for index in batch_index:
105
+ img = video_reader[index]
106
+ vid_pil_image_list.append(Image.fromarray(img.asnumpy()))
107
+ img = kps_reader[index]
108
+ pose_pil_image_list.append(Image.fromarray(img.asnumpy()))
109
+
110
+ ref_img_idx = random.randint(0, video_length - 1)
111
+ ref_img = Image.fromarray(video_reader[ref_img_idx].asnumpy())
112
+
113
+ # transform
114
+ state = torch.get_rng_state()
115
+ pixel_values_vid = self.augmentation(
116
+ vid_pil_image_list, self.pixel_transform, state
117
+ )
118
+ pixel_values_pose = self.augmentation(
119
+ pose_pil_image_list, self.cond_transform, state
120
+ )
121
+ pixel_values_ref_img = self.augmentation(ref_img, self.pixel_transform, state)
122
+ clip_ref_img = self.clip_image_processor(
123
+ images=ref_img, return_tensors="pt"
124
+ ).pixel_values[0]
125
+
126
+ sample = dict(
127
+ video_dir=video_path,
128
+ pixel_values_vid=pixel_values_vid,
129
+ pixel_values_pose=pixel_values_pose,
130
+ pixel_values_ref_img=pixel_values_ref_img,
131
+ clip_ref_img=clip_ref_img,
132
+ )
133
+
134
+ return sample
135
+
136
+ def __len__(self):
137
+ return len(self.vid_meta)
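Similarly, a hypothetical sketch of driving the video dataset; frame count, resolution, and the meta path are illustrative:

from torch.utils.data import DataLoader

video_dataset = HumanDanceVideoDataset(
    sample_rate=4,
    n_sample_frames=24,
    width=512,
    height=768,
    data_meta_paths=["./data/fashion_meta.json"],  # assumed meta file
)
video_loader = DataLoader(video_dataset, batch_size=1, shuffle=True)

sample = next(iter(video_loader))
# pixel_values_vid / pixel_values_pose: (b, f, 3, h, w); pixel_values_ref_img: (b, 3, h, w)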
src/dwpose/__init__.py ADDED
@@ -0,0 +1,123 @@
1
+ # https://github.com/IDEA-Research/DWPose
2
+ # Openpose
3
+ # Original from CMU https://github.com/CMU-Perceptual-Computing-Lab/openpose
4
+ # 2nd Edited by https://github.com/Hzzone/pytorch-openpose
5
+ # 3rd Edited by ControlNet
6
+ # 4th Edited by ControlNet (added face and correct hands)
7
+
8
+ import copy
9
+ import os
10
+
11
+ os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
12
+ import cv2
13
+ import numpy as np
14
+ import torch
15
+ from controlnet_aux.util import HWC3, resize_image
16
+ from PIL import Image
17
+
18
+ from . import util
19
+ from .wholebody import Wholebody
20
+
21
+
22
+ def draw_pose(pose, H, W):
23
+ bodies = pose["bodies"]
24
+ faces = pose["faces"]
25
+ hands = pose["hands"]
26
+ candidate = bodies["candidate"]
27
+ subset = bodies["subset"]
28
+ canvas = np.zeros(shape=(H, W, 3), dtype=np.uint8)
29
+
30
+ canvas = util.draw_bodypose(canvas, candidate, subset)
31
+
32
+ canvas = util.draw_handpose(canvas, hands)
33
+
34
+ canvas = util.draw_facepose(canvas, faces)
35
+
36
+ return canvas
37
+
38
+
39
+ class DWposeDetector:
40
+ def __init__(self):
41
+ pass
42
+
43
+ def to(self, device):
44
+ self.pose_estimation = Wholebody(device)
45
+ return self
46
+
47
+ def cal_height(self, input_image):
48
+ input_image = cv2.cvtColor(
49
+ np.array(input_image, dtype=np.uint8), cv2.COLOR_RGB2BGR
50
+ )
51
+
52
+ input_image = HWC3(input_image)
53
+ H, W, C = input_image.shape
54
+ with torch.no_grad():
55
+ candidate, subset = self.pose_estimation(input_image)
56
+ nums, keys, locs = candidate.shape
57
+ # candidate[..., 0] /= float(W)
58
+ # candidate[..., 1] /= float(H)
59
+ body = candidate
60
+ return body[0, ..., 1].min(), body[..., 1].max() - body[..., 1].min()
61
+
62
+ def __call__(
63
+ self,
64
+ input_image,
65
+ detect_resolution=512,
66
+ image_resolution=512,
67
+ output_type="pil",
68
+ **kwargs,
69
+ ):
70
+ input_image = cv2.cvtColor(
71
+ np.array(input_image, dtype=np.uint8), cv2.COLOR_RGB2BGR
72
+ )
73
+
74
+ input_image = HWC3(input_image)
75
+ input_image = resize_image(input_image, detect_resolution)
76
+ H, W, C = input_image.shape
77
+ with torch.no_grad():
78
+ candidate, subset = self.pose_estimation(input_image)
79
+ nums, keys, locs = candidate.shape
80
+ candidate[..., 0] /= float(W)
81
+ candidate[..., 1] /= float(H)
82
+ score = subset[:, :18]
83
+ max_ind = np.mean(score, axis=-1).argmax(axis=0)
84
+ score = score[[max_ind]]
85
+ body = candidate[:, :18].copy()
86
+ body = body[[max_ind]]
87
+ nums = 1
88
+ body = body.reshape(nums * 18, locs)
89
+ body_score = copy.deepcopy(score)
90
+ for i in range(len(score)):
91
+ for j in range(len(score[i])):
92
+ if score[i][j] > 0.3:
93
+ score[i][j] = int(18 * i + j)
94
+ else:
95
+ score[i][j] = -1
96
+
97
+ un_visible = subset < 0.3
98
+ candidate[un_visible] = -1
99
+
100
+ foot = candidate[:, 18:24]
101
+
102
+ faces = candidate[[max_ind], 24:92]
103
+
104
+ hands = candidate[[max_ind], 92:113]
105
+ hands = np.vstack([hands, candidate[[max_ind], 113:]])
106
+
107
+ bodies = dict(candidate=body, subset=score)
108
+ pose = dict(bodies=bodies, hands=hands, faces=faces)
109
+
110
+ detected_map = draw_pose(pose, H, W)
111
+ detected_map = HWC3(detected_map)
112
+
113
+ img = resize_image(input_image, image_resolution)
114
+ H, W, C = img.shape
115
+
116
+ detected_map = cv2.resize(
117
+ detected_map, (W, H), interpolation=cv2.INTER_LINEAR
118
+ )
119
+
120
+ if output_type == "pil":
121
+ detected_map = Image.fromarray(detected_map)
122
+
123
+ return detected_map, body_score
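A minimal usage sketch, assuming the DWPose ONNX checkpoints expected by `Wholebody` are already downloaded; the image path is a placeholder:

from PIL import Image

detector = DWposeDetector().to("cuda")  # builds the Wholebody ONNX estimator on the chosen device
ref_image = Image.open("example.jpg").convert("RGB")

pose_map, body_score = detector(ref_image, detect_resolution=512, image_resolution=512)
pose_map.save("pose_vis.png")  # rendered body/hand/face skeleton as a PIL image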
src/dwpose/onnxdet.py ADDED
@@ -0,0 +1,130 @@
1
+ # https://github.com/IDEA-Research/DWPose
2
+ import cv2
3
+ import numpy as np
4
+ import onnxruntime
5
+
6
+
7
+ def nms(boxes, scores, nms_thr):
8
+ """Single class NMS implemented in Numpy."""
9
+ x1 = boxes[:, 0]
10
+ y1 = boxes[:, 1]
11
+ x2 = boxes[:, 2]
12
+ y2 = boxes[:, 3]
13
+
14
+ areas = (x2 - x1 + 1) * (y2 - y1 + 1)
15
+ order = scores.argsort()[::-1]
16
+
17
+ keep = []
18
+ while order.size > 0:
19
+ i = order[0]
20
+ keep.append(i)
21
+ xx1 = np.maximum(x1[i], x1[order[1:]])
22
+ yy1 = np.maximum(y1[i], y1[order[1:]])
23
+ xx2 = np.minimum(x2[i], x2[order[1:]])
24
+ yy2 = np.minimum(y2[i], y2[order[1:]])
25
+
26
+ w = np.maximum(0.0, xx2 - xx1 + 1)
27
+ h = np.maximum(0.0, yy2 - yy1 + 1)
28
+ inter = w * h
29
+ ovr = inter / (areas[i] + areas[order[1:]] - inter)
30
+
31
+ inds = np.where(ovr <= nms_thr)[0]
32
+ order = order[inds + 1]
33
+
34
+ return keep
35
+
36
+
37
+ def multiclass_nms(boxes, scores, nms_thr, score_thr):
38
+ """Multiclass NMS implemented in Numpy. Class-aware version."""
39
+ final_dets = []
40
+ num_classes = scores.shape[1]
41
+ for cls_ind in range(num_classes):
42
+ cls_scores = scores[:, cls_ind]
43
+ valid_score_mask = cls_scores > score_thr
44
+ if valid_score_mask.sum() == 0:
45
+ continue
46
+ else:
47
+ valid_scores = cls_scores[valid_score_mask]
48
+ valid_boxes = boxes[valid_score_mask]
49
+ keep = nms(valid_boxes, valid_scores, nms_thr)
50
+ if len(keep) > 0:
51
+ cls_inds = np.ones((len(keep), 1)) * cls_ind
52
+ dets = np.concatenate(
53
+ [valid_boxes[keep], valid_scores[keep, None], cls_inds], 1
54
+ )
55
+ final_dets.append(dets)
56
+ if len(final_dets) == 0:
57
+ return None
58
+ return np.concatenate(final_dets, 0)
59
+
60
+
61
+ def demo_postprocess(outputs, img_size, p6=False):
62
+ grids = []
63
+ expanded_strides = []
64
+ strides = [8, 16, 32] if not p6 else [8, 16, 32, 64]
65
+
66
+ hsizes = [img_size[0] // stride for stride in strides]
67
+ wsizes = [img_size[1] // stride for stride in strides]
68
+
69
+ for hsize, wsize, stride in zip(hsizes, wsizes, strides):
70
+ xv, yv = np.meshgrid(np.arange(wsize), np.arange(hsize))
71
+ grid = np.stack((xv, yv), 2).reshape(1, -1, 2)
72
+ grids.append(grid)
73
+ shape = grid.shape[:2]
74
+ expanded_strides.append(np.full((*shape, 1), stride))
75
+
76
+ grids = np.concatenate(grids, 1)
77
+ expanded_strides = np.concatenate(expanded_strides, 1)
78
+ outputs[..., :2] = (outputs[..., :2] + grids) * expanded_strides
79
+ outputs[..., 2:4] = np.exp(outputs[..., 2:4]) * expanded_strides
80
+
81
+ return outputs
82
+
83
+
84
+ def preprocess(img, input_size, swap=(2, 0, 1)):
85
+ if len(img.shape) == 3:
86
+ padded_img = np.ones((input_size[0], input_size[1], 3), dtype=np.uint8) * 114
87
+ else:
88
+ padded_img = np.ones(input_size, dtype=np.uint8) * 114
89
+
90
+ r = min(input_size[0] / img.shape[0], input_size[1] / img.shape[1])
91
+ resized_img = cv2.resize(
92
+ img,
93
+ (int(img.shape[1] * r), int(img.shape[0] * r)),
94
+ interpolation=cv2.INTER_LINEAR,
95
+ ).astype(np.uint8)
96
+ padded_img[: int(img.shape[0] * r), : int(img.shape[1] * r)] = resized_img
97
+
98
+ padded_img = padded_img.transpose(swap)
99
+ padded_img = np.ascontiguousarray(padded_img, dtype=np.float32)
100
+ return padded_img, r
101
+
102
+
103
+ def inference_detector(session, oriImg):
104
+ input_shape = (640, 640)
105
+ img, ratio = preprocess(oriImg, input_shape)
106
+
107
+ ort_inputs = {session.get_inputs()[0].name: img[None, :, :, :]}
108
+ output = session.run(None, ort_inputs)
109
+ predictions = demo_postprocess(output[0], input_shape)[0]
110
+
111
+ boxes = predictions[:, :4]
112
+ scores = predictions[:, 4:5] * predictions[:, 5:]
113
+
114
+ boxes_xyxy = np.ones_like(boxes)
115
+ boxes_xyxy[:, 0] = boxes[:, 0] - boxes[:, 2] / 2.0
116
+ boxes_xyxy[:, 1] = boxes[:, 1] - boxes[:, 3] / 2.0
117
+ boxes_xyxy[:, 2] = boxes[:, 0] + boxes[:, 2] / 2.0
118
+ boxes_xyxy[:, 3] = boxes[:, 1] + boxes[:, 3] / 2.0
119
+ boxes_xyxy /= ratio
120
+ dets = multiclass_nms(boxes_xyxy, scores, nms_thr=0.45, score_thr=0.1)
121
+ if dets is not None:
122
+ final_boxes, final_scores, final_cls_inds = dets[:, :4], dets[:, 4], dets[:, 5]
123
+ isscore = final_scores > 0.3
124
+ iscat = final_cls_inds == 0
125
+ isbbox = [i and j for (i, j) in zip(isscore, iscat)]
126
+ final_boxes = final_boxes[isbbox]
127
+ else:
128
+ return []
129
+
130
+ return final_boxes
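For orientation, a minimal sketch of how the detector above is typically driven (illustrative only, not part of the uploaded file; the weight path mirrors the layout used by src/dwpose/wholebody.py and the image path is a placeholder):

import cv2
import onnxruntime as ort

from src.dwpose.onnxdet import inference_detector

# YOLOX-L person detector exported to ONNX (path assumed from wholebody.py)
session = ort.InferenceSession(
    "pretrained_weights/DWPose/yolox_l.onnx", providers=["CPUExecutionProvider"]
)
img = cv2.imread("some_image.jpg")        # BGR image, any resolution
boxes = inference_detector(session, img)  # person boxes, xyxy in original-image coordinates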
src/dwpose/onnxpose.py ADDED
@@ -0,0 +1,370 @@
1
+ # https://github.com/IDEA-Research/DWPose
2
+ from typing import List, Tuple
3
+
4
+ import cv2
5
+ import numpy as np
6
+ import onnxruntime as ort
7
+
8
+
9
+ def preprocess(
10
+ img: np.ndarray, out_bbox, input_size: Tuple[int, int] = (192, 256)
11
+ ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
12
+ """Do preprocessing for RTMPose model inference.
13
+
14
+ Args:
15
+ img (np.ndarray): Input image in shape (H, W, C).
16
+ input_size (tuple): Input image size in shape (w, h).
17
+
18
+ Returns:
19
+ tuple:
20
+ - resized_img (np.ndarray): Preprocessed image.
21
+ - center (np.ndarray): Center of image.
22
+ - scale (np.ndarray): Scale of image.
23
+ """
24
+ # get shape of image
25
+ img_shape = img.shape[:2]
26
+ out_img, out_center, out_scale = [], [], []
27
+ if len(out_bbox) == 0:
28
+ out_bbox = [[0, 0, img_shape[1], img_shape[0]]]
29
+ for i in range(len(out_bbox)):
30
+ x0 = out_bbox[i][0]
31
+ y0 = out_bbox[i][1]
32
+ x1 = out_bbox[i][2]
33
+ y1 = out_bbox[i][3]
34
+ bbox = np.array([x0, y0, x1, y1])
35
+
36
+ # get center and scale
37
+ center, scale = bbox_xyxy2cs(bbox, padding=1.25)
38
+
39
+ # do affine transformation
40
+ resized_img, scale = top_down_affine(input_size, scale, center, img)
41
+
42
+ # normalize image
43
+ mean = np.array([123.675, 116.28, 103.53])
44
+ std = np.array([58.395, 57.12, 57.375])
45
+ resized_img = (resized_img - mean) / std
46
+
47
+ out_img.append(resized_img)
48
+ out_center.append(center)
49
+ out_scale.append(scale)
50
+
51
+ return out_img, out_center, out_scale
52
+
53
+
54
+ def inference(sess: ort.InferenceSession, img: np.ndarray) -> np.ndarray:
55
+ """Inference RTMPose model.
56
+
57
+ Args:
58
+ sess (ort.InferenceSession): ONNXRuntime session.
59
+ img (List[np.ndarray]): Preprocessed image crops returned by `preprocess`.
60
+
61
+ Returns:
62
+ outputs (np.ndarray): Output of RTMPose model.
63
+ """
64
+ all_out = []
65
+ # build input
66
+ for i in range(len(img)):
67
+ input = [img[i].transpose(2, 0, 1)]
68
+
69
+ # build output
70
+ sess_input = {sess.get_inputs()[0].name: input}
71
+ sess_output = []
72
+ for out in sess.get_outputs():
73
+ sess_output.append(out.name)
74
+
75
+ # run model
76
+ outputs = sess.run(sess_output, sess_input)
77
+ all_out.append(outputs)
78
+
79
+ return all_out
80
+
81
+
82
+ def postprocess(
83
+ outputs: List[np.ndarray],
84
+ model_input_size: Tuple[int, int],
85
+ center: Tuple[int, int],
86
+ scale: Tuple[int, int],
87
+ simcc_split_ratio: float = 2.0,
88
+ ) -> Tuple[np.ndarray, np.ndarray]:
89
+ """Postprocess for RTMPose model output.
90
+
91
+ Args:
92
+ outputs (np.ndarray): Output of RTMPose model.
93
+ model_input_size (tuple): RTMPose model Input image size.
94
+ center (tuple): Center of bbox in shape (x, y).
95
+ scale (tuple): Scale of bbox in shape (w, h).
96
+ simcc_split_ratio (float): Split ratio of simcc.
97
+
98
+ Returns:
99
+ tuple:
100
+ - keypoints (np.ndarray): Rescaled keypoints.
101
+ - scores (np.ndarray): Model predict scores.
102
+ """
103
+ all_key = []
104
+ all_score = []
105
+ for i in range(len(outputs)):
106
+ # use simcc to decode
107
+ simcc_x, simcc_y = outputs[i]
108
+ keypoints, scores = decode(simcc_x, simcc_y, simcc_split_ratio)
109
+
110
+ # rescale keypoints
111
+ keypoints = keypoints / model_input_size * scale[i] + center[i] - scale[i] / 2
112
+ all_key.append(keypoints[0])
113
+ all_score.append(scores[0])
114
+
115
+ return np.array(all_key), np.array(all_score)
116
+
117
+
118
+ def bbox_xyxy2cs(
119
+ bbox: np.ndarray, padding: float = 1.0
120
+ ) -> Tuple[np.ndarray, np.ndarray]:
121
+ """Transform the bbox format from (x,y,w,h) into (center, scale)
122
+
123
+ Args:
124
+ bbox (ndarray): Bounding box(es) in shape (4,) or (n, 4), formatted
125
+ as (left, top, right, bottom)
126
+ padding (float): BBox padding factor that will be multiplied by the scale.
127
+ Default: 1.0
128
+
129
+ Returns:
130
+ tuple: A tuple containing center and scale.
131
+ - np.ndarray[float32]: Center (x, y) of the bbox in shape (2,) or
132
+ (n, 2)
133
+ - np.ndarray[float32]: Scale (w, h) of the bbox in shape (2,) or
134
+ (n, 2)
135
+ """
136
+ # convert single bbox from (4, ) to (1, 4)
137
+ dim = bbox.ndim
138
+ if dim == 1:
139
+ bbox = bbox[None, :]
140
+
141
+ # get bbox center and scale
142
+ x1, y1, x2, y2 = np.hsplit(bbox, [1, 2, 3])
143
+ center = np.hstack([x1 + x2, y1 + y2]) * 0.5
144
+ scale = np.hstack([x2 - x1, y2 - y1]) * padding
145
+
146
+ if dim == 1:
147
+ center = center[0]
148
+ scale = scale[0]
149
+
150
+ return center, scale
151
+
152
+
153
+ def _fix_aspect_ratio(bbox_scale: np.ndarray, aspect_ratio: float) -> np.ndarray:
154
+ """Extend the scale to match the given aspect ratio.
155
+
156
+ Args:
157
+ scale (np.ndarray): The image scale (w, h) in shape (2, )
158
+ aspect_ratio (float): The ratio of ``w/h``
159
+
160
+ Returns:
161
+ np.ndarray: The reshaped image scale in (2, )
162
+ """
163
+ w, h = np.hsplit(bbox_scale, [1])
164
+ bbox_scale = np.where(
165
+ w > h * aspect_ratio,
166
+ np.hstack([w, w / aspect_ratio]),
167
+ np.hstack([h * aspect_ratio, h]),
168
+ )
169
+ return bbox_scale
170
+
171
+
172
+ def _rotate_point(pt: np.ndarray, angle_rad: float) -> np.ndarray:
173
+ """Rotate a point by an angle.
174
+
175
+ Args:
176
+ pt (np.ndarray): 2D point coordinates (x, y) in shape (2, )
177
+ angle_rad (float): rotation angle in radian
178
+
179
+ Returns:
180
+ np.ndarray: Rotated point in shape (2, )
181
+ """
182
+ sn, cs = np.sin(angle_rad), np.cos(angle_rad)
183
+ rot_mat = np.array([[cs, -sn], [sn, cs]])
184
+ return rot_mat @ pt
185
+
186
+
187
+ def _get_3rd_point(a: np.ndarray, b: np.ndarray) -> np.ndarray:
188
+ """To calculate the affine matrix, three pairs of points are required. This
189
+ function is used to get the 3rd point, given 2D points a & b.
190
+
191
+ The 3rd point is defined by rotating vector `a - b` by 90 degrees
192
+ anticlockwise, using b as the rotation center.
193
+
194
+ Args:
195
+ a (np.ndarray): The 1st point (x,y) in shape (2, )
196
+ b (np.ndarray): The 2nd point (x,y) in shape (2, )
197
+
198
+ Returns:
199
+ np.ndarray: The 3rd point.
200
+ """
201
+ direction = a - b
202
+ c = b + np.r_[-direction[1], direction[0]]
203
+ return c
204
+
205
+
206
+ def get_warp_matrix(
207
+ center: np.ndarray,
208
+ scale: np.ndarray,
209
+ rot: float,
210
+ output_size: Tuple[int, int],
211
+ shift: Tuple[float, float] = (0.0, 0.0),
212
+ inv: bool = False,
213
+ ) -> np.ndarray:
214
+ """Calculate the affine transformation matrix that can warp the bbox area
215
+ in the input image to the output size.
216
+
217
+ Args:
218
+ center (np.ndarray[2, ]): Center of the bounding box (x, y).
219
+ scale (np.ndarray[2, ]): Scale of the bounding box
220
+ wrt [width, height].
221
+ rot (float): Rotation angle (degree).
222
+ output_size (np.ndarray[2, ] | list(2,)): Size of the
223
+ destination heatmaps.
224
+ shift (0-100%): Shift translation ratio wrt the width/height.
225
+ Default (0., 0.).
226
+ inv (bool): Option to inverse the affine transform direction.
227
+ (inv=False: src->dst or inv=True: dst->src)
228
+
229
+ Returns:
230
+ np.ndarray: A 2x3 transformation matrix
231
+ """
232
+ shift = np.array(shift)
233
+ src_w = scale[0]
234
+ dst_w = output_size[0]
235
+ dst_h = output_size[1]
236
+
237
+ # compute transformation matrix
238
+ rot_rad = np.deg2rad(rot)
239
+ src_dir = _rotate_point(np.array([0.0, src_w * -0.5]), rot_rad)
240
+ dst_dir = np.array([0.0, dst_w * -0.5])
241
+
242
+ # get four corners of the src rectangle in the original image
243
+ src = np.zeros((3, 2), dtype=np.float32)
244
+ src[0, :] = center + scale * shift
245
+ src[1, :] = center + src_dir + scale * shift
246
+ src[2, :] = _get_3rd_point(src[0, :], src[1, :])
247
+
248
+ # get four corners of the dst rectangle in the input image
249
+ dst = np.zeros((3, 2), dtype=np.float32)
250
+ dst[0, :] = [dst_w * 0.5, dst_h * 0.5]
251
+ dst[1, :] = np.array([dst_w * 0.5, dst_h * 0.5]) + dst_dir
252
+ dst[2, :] = _get_3rd_point(dst[0, :], dst[1, :])
253
+
254
+ if inv:
255
+ warp_mat = cv2.getAffineTransform(np.float32(dst), np.float32(src))
256
+ else:
257
+ warp_mat = cv2.getAffineTransform(np.float32(src), np.float32(dst))
258
+
259
+ return warp_mat
260
+
261
+
262
+ def top_down_affine(
263
+ input_size: Tuple[int, int], bbox_scale: np.ndarray, bbox_center: np.ndarray, img: np.ndarray
264
+ ) -> Tuple[np.ndarray, np.ndarray]:
265
+ """Get the bbox image as the model input by affine transform.
266
+
267
+ Args:
268
+ input_size (tuple): The (w, h) input size of the model.
269
+ bbox_scale (np.ndarray): The bbox scale (w, h) of the img.
270
+ bbox_center (np.ndarray): The bbox center (x, y) of the img.
271
+ img (np.ndarray): The original image.
272
+
273
+ Returns:
274
+ tuple: A tuple containing center and scale.
275
+ - np.ndarray[float32]: img after affine transform.
276
+ - np.ndarray[float32]: bbox scale after affine transform.
277
+ """
278
+ w, h = input_size
279
+ warp_size = (int(w), int(h))
280
+
281
+ # reshape bbox to fixed aspect ratio
282
+ bbox_scale = _fix_aspect_ratio(bbox_scale, aspect_ratio=w / h)
283
+
284
+ # get the affine matrix
285
+ center = bbox_center
286
+ scale = bbox_scale
287
+ rot = 0
288
+ warp_mat = get_warp_matrix(center, scale, rot, output_size=(w, h))
289
+
290
+ # do affine transform
291
+ img = cv2.warpAffine(img, warp_mat, warp_size, flags=cv2.INTER_LINEAR)
292
+
293
+ return img, bbox_scale
294
+
295
+
296
+ def get_simcc_maximum(
297
+ simcc_x: np.ndarray, simcc_y: np.ndarray
298
+ ) -> Tuple[np.ndarray, np.ndarray]:
299
+ """Get maximum response location and value from simcc representations.
300
+
301
+ Note:
302
+ instance number: N
303
+ num_keypoints: K
304
+ heatmap height: H
305
+ heatmap width: W
306
+
307
+ Args:
308
+ simcc_x (np.ndarray): x-axis SimCC in shape (K, Wx) or (N, K, Wx)
309
+ simcc_y (np.ndarray): y-axis SimCC in shape (K, Wy) or (N, K, Wy)
310
+
311
+ Returns:
312
+ tuple:
313
+ - locs (np.ndarray): locations of maximum heatmap responses in shape
314
+ (K, 2) or (N, K, 2)
315
+ - vals (np.ndarray): values of maximum heatmap responses in shape
316
+ (K,) or (N, K)
317
+ """
318
+ N, K, Wx = simcc_x.shape
319
+ simcc_x = simcc_x.reshape(N * K, -1)
320
+ simcc_y = simcc_y.reshape(N * K, -1)
321
+
322
+ # get maximum value locations
323
+ x_locs = np.argmax(simcc_x, axis=1)
324
+ y_locs = np.argmax(simcc_y, axis=1)
325
+ locs = np.stack((x_locs, y_locs), axis=-1).astype(np.float32)
326
+ max_val_x = np.amax(simcc_x, axis=1)
327
+ max_val_y = np.amax(simcc_y, axis=1)
328
+
329
+ # take the smaller of the x/y maxima as the keypoint confidence
330
+ mask = max_val_x > max_val_y
331
+ max_val_x[mask] = max_val_y[mask]
332
+ vals = max_val_x
333
+ locs[vals <= 0.0] = -1
334
+
335
+ # reshape
336
+ locs = locs.reshape(N, K, 2)
337
+ vals = vals.reshape(N, K)
338
+
339
+ return locs, vals
340
+
341
+
342
+ def decode(
343
+ simcc_x: np.ndarray, simcc_y: np.ndarray, simcc_split_ratio
344
+ ) -> Tuple[np.ndarray, np.ndarray]:
345
+ """Modulate simcc distribution with Gaussian.
346
+
347
+ Args:
348
+ simcc_x (np.ndarray[K, Wx]): model predicted simcc in x.
349
+ simcc_y (np.ndarray[K, Wy]): model predicted simcc in y.
350
+ simcc_split_ratio (int): The split ratio of simcc.
351
+
352
+ Returns:
353
+ tuple: A tuple containing center and scale.
354
+ - np.ndarray[float32]: keypoints in shape (K, 2) or (n, K, 2)
355
+ - np.ndarray[float32]: scores in shape (K,) or (n, K)
356
+ """
357
+ keypoints, scores = get_simcc_maximum(simcc_x, simcc_y)
358
+ keypoints /= simcc_split_ratio
359
+
360
+ return keypoints, scores
361
+
362
+
363
+ def inference_pose(session, out_bbox, oriImg):
364
+ h, w = session.get_inputs()[0].shape[2:]
365
+ model_input_size = (w, h)
366
+ resized_img, center, scale = preprocess(oriImg, out_bbox, model_input_size)
367
+ outputs = inference(session, resized_img)
368
+ keypoints, scores = postprocess(outputs, model_input_size, center, scale)
369
+
370
+ return keypoints, scores
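A hedged sketch of how inference_pose consumes the detector output (not part of the commit; the model and image paths are assumptions based on the repo layout):

import cv2
import onnxruntime as ort

from src.dwpose.onnxdet import inference_detector
from src.dwpose.onnxpose import inference_pose

providers = ["CPUExecutionProvider"]
det = ort.InferenceSession("pretrained_weights/DWPose/yolox_l.onnx", providers=providers)
pose = ort.InferenceSession("pretrained_weights/DWPose/dw-ll_ucoco_384.onnx", providers=providers)

img = cv2.imread("some_image.jpg")
boxes = inference_detector(det, img)                  # person boxes in original-image scale
keypoints, scores = inference_pose(pose, boxes, img)  # per-person keypoints (pixels) and confidences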
src/dwpose/util.py ADDED
@@ -0,0 +1,378 @@
1
+ # https://github.com/IDEA-Research/DWPose
2
+ import math
3
+ import numpy as np
4
+ import matplotlib
5
+ import cv2
6
+
7
+
8
+ eps = 0.01
9
+
10
+
11
+ def smart_resize(x, s):
12
+ Ht, Wt = s
13
+ if x.ndim == 2:
14
+ Ho, Wo = x.shape
15
+ Co = 1
16
+ else:
17
+ Ho, Wo, Co = x.shape
18
+ if Co == 3 or Co == 1:
19
+ k = float(Ht + Wt) / float(Ho + Wo)
20
+ return cv2.resize(
21
+ x,
22
+ (int(Wt), int(Ht)),
23
+ interpolation=cv2.INTER_AREA if k < 1 else cv2.INTER_LANCZOS4,
24
+ )
25
+ else:
26
+ return np.stack([smart_resize(x[:, :, i], s) for i in range(Co)], axis=2)
27
+
28
+
29
+ def smart_resize_k(x, fx, fy):
30
+ if x.ndim == 2:
31
+ Ho, Wo = x.shape
32
+ Co = 1
33
+ else:
34
+ Ho, Wo, Co = x.shape
35
+ Ht, Wt = Ho * fy, Wo * fx
36
+ if Co == 3 or Co == 1:
37
+ k = float(Ht + Wt) / float(Ho + Wo)
38
+ return cv2.resize(
39
+ x,
40
+ (int(Wt), int(Ht)),
41
+ interpolation=cv2.INTER_AREA if k < 1 else cv2.INTER_LANCZOS4,
42
+ )
43
+ else:
44
+ return np.stack([smart_resize_k(x[:, :, i], fx, fy) for i in range(Co)], axis=2)
45
+
46
+
47
+ def padRightDownCorner(img, stride, padValue):
48
+ h = img.shape[0]
49
+ w = img.shape[1]
50
+
51
+ pad = 4 * [None]
52
+ pad[0] = 0 # up
53
+ pad[1] = 0 # left
54
+ pad[2] = 0 if (h % stride == 0) else stride - (h % stride) # down
55
+ pad[3] = 0 if (w % stride == 0) else stride - (w % stride) # right
56
+
57
+ img_padded = img
58
+ pad_up = np.tile(img_padded[0:1, :, :] * 0 + padValue, (pad[0], 1, 1))
59
+ img_padded = np.concatenate((pad_up, img_padded), axis=0)
60
+ pad_left = np.tile(img_padded[:, 0:1, :] * 0 + padValue, (1, pad[1], 1))
61
+ img_padded = np.concatenate((pad_left, img_padded), axis=1)
62
+ pad_down = np.tile(img_padded[-2:-1, :, :] * 0 + padValue, (pad[2], 1, 1))
63
+ img_padded = np.concatenate((img_padded, pad_down), axis=0)
64
+ pad_right = np.tile(img_padded[:, -2:-1, :] * 0 + padValue, (1, pad[3], 1))
65
+ img_padded = np.concatenate((img_padded, pad_right), axis=1)
66
+
67
+ return img_padded, pad
68
+
69
+
70
+ def transfer(model, model_weights):
71
+ transfered_model_weights = {}
72
+ for weights_name in model.state_dict().keys():
73
+ transfered_model_weights[weights_name] = model_weights[
74
+ ".".join(weights_name.split(".")[1:])
75
+ ]
76
+ return transfered_model_weights
77
+
78
+
79
+ def draw_bodypose(canvas, candidate, subset):
80
+ H, W, C = canvas.shape
81
+ candidate = np.array(candidate)
82
+ subset = np.array(subset)
83
+
84
+ stickwidth = 4
85
+
86
+ limbSeq = [
87
+ [2, 3],
88
+ [2, 6],
89
+ [3, 4],
90
+ [4, 5],
91
+ [6, 7],
92
+ [7, 8],
93
+ [2, 9],
94
+ [9, 10],
95
+ [10, 11],
96
+ [2, 12],
97
+ [12, 13],
98
+ [13, 14],
99
+ [2, 1],
100
+ [1, 15],
101
+ [15, 17],
102
+ [1, 16],
103
+ [16, 18],
104
+ [3, 17],
105
+ [6, 18],
106
+ ]
107
+
108
+ colors = [
109
+ [255, 0, 0],
110
+ [255, 85, 0],
111
+ [255, 170, 0],
112
+ [255, 255, 0],
113
+ [170, 255, 0],
114
+ [85, 255, 0],
115
+ [0, 255, 0],
116
+ [0, 255, 85],
117
+ [0, 255, 170],
118
+ [0, 255, 255],
119
+ [0, 170, 255],
120
+ [0, 85, 255],
121
+ [0, 0, 255],
122
+ [85, 0, 255],
123
+ [170, 0, 255],
124
+ [255, 0, 255],
125
+ [255, 0, 170],
126
+ [255, 0, 85],
127
+ ]
128
+
129
+ for i in range(17):
130
+ for n in range(len(subset)):
131
+ index = subset[n][np.array(limbSeq[i]) - 1]
132
+ if -1 in index:
133
+ continue
134
+ Y = candidate[index.astype(int), 0] * float(W)
135
+ X = candidate[index.astype(int), 1] * float(H)
136
+ mX = np.mean(X)
137
+ mY = np.mean(Y)
138
+ length = ((X[0] - X[1]) ** 2 + (Y[0] - Y[1]) ** 2) ** 0.5
139
+ angle = math.degrees(math.atan2(X[0] - X[1], Y[0] - Y[1]))
140
+ polygon = cv2.ellipse2Poly(
141
+ (int(mY), int(mX)), (int(length / 2), stickwidth), int(angle), 0, 360, 1
142
+ )
143
+ cv2.fillConvexPoly(canvas, polygon, colors[i])
144
+
145
+ canvas = (canvas * 0.6).astype(np.uint8)
146
+
147
+ for i in range(18):
148
+ for n in range(len(subset)):
149
+ index = int(subset[n][i])
150
+ if index == -1:
151
+ continue
152
+ x, y = candidate[index][0:2]
153
+ x = int(x * W)
154
+ y = int(y * H)
155
+ cv2.circle(canvas, (int(x), int(y)), 4, colors[i], thickness=-1)
156
+
157
+ return canvas
158
+
159
+
160
+ def draw_handpose(canvas, all_hand_peaks):
161
+ H, W, C = canvas.shape
162
+
163
+ edges = [
164
+ [0, 1],
165
+ [1, 2],
166
+ [2, 3],
167
+ [3, 4],
168
+ [0, 5],
169
+ [5, 6],
170
+ [6, 7],
171
+ [7, 8],
172
+ [0, 9],
173
+ [9, 10],
174
+ [10, 11],
175
+ [11, 12],
176
+ [0, 13],
177
+ [13, 14],
178
+ [14, 15],
179
+ [15, 16],
180
+ [0, 17],
181
+ [17, 18],
182
+ [18, 19],
183
+ [19, 20],
184
+ ]
185
+
186
+ for peaks in all_hand_peaks:
187
+ peaks = np.array(peaks)
188
+
189
+ for ie, e in enumerate(edges):
190
+ x1, y1 = peaks[e[0]]
191
+ x2, y2 = peaks[e[1]]
192
+ x1 = int(x1 * W)
193
+ y1 = int(y1 * H)
194
+ x2 = int(x2 * W)
195
+ y2 = int(y2 * H)
196
+ if x1 > eps and y1 > eps and x2 > eps and y2 > eps:
197
+ cv2.line(
198
+ canvas,
199
+ (x1, y1),
200
+ (x2, y2),
201
+ matplotlib.colors.hsv_to_rgb([ie / float(len(edges)), 1.0, 1.0])
202
+ * 255,
203
+ thickness=2,
204
+ )
205
+
206
+ for i, keypoint in enumerate(peaks):
207
+ x, y = keypoint
208
+ x = int(x * W)
209
+ y = int(y * H)
210
+ if x > eps and y > eps:
211
+ cv2.circle(canvas, (x, y), 4, (0, 0, 255), thickness=-1)
212
+ return canvas
213
+
214
+
215
+ def draw_facepose(canvas, all_lmks):
216
+ H, W, C = canvas.shape
217
+ for lmks in all_lmks:
218
+ lmks = np.array(lmks)
219
+ for lmk in lmks:
220
+ x, y = lmk
221
+ x = int(x * W)
222
+ y = int(y * H)
223
+ if x > eps and y > eps:
224
+ cv2.circle(canvas, (x, y), 3, (255, 255, 255), thickness=-1)
225
+ return canvas
226
+
227
+
228
+ # detect hand according to body pose keypoints
229
+ # please refer to https://github.com/CMU-Perceptual-Computing-Lab/openpose/blob/master/src/openpose/hand/handDetector.cpp
230
+ def handDetect(candidate, subset, oriImg):
231
+ # right hand: wrist 4, elbow 3, shoulder 2
232
+ # left hand: wrist 7, elbow 6, shoulder 5
233
+ ratioWristElbow = 0.33
234
+ detect_result = []
235
+ image_height, image_width = oriImg.shape[0:2]
236
+ for person in subset.astype(int):
237
+ # if any of three not detected
238
+ has_left = np.sum(person[[5, 6, 7]] == -1) == 0
239
+ has_right = np.sum(person[[2, 3, 4]] == -1) == 0
240
+ if not (has_left or has_right):
241
+ continue
242
+ hands = []
243
+ # left hand
244
+ if has_left:
245
+ left_shoulder_index, left_elbow_index, left_wrist_index = person[[5, 6, 7]]
246
+ x1, y1 = candidate[left_shoulder_index][:2]
247
+ x2, y2 = candidate[left_elbow_index][:2]
248
+ x3, y3 = candidate[left_wrist_index][:2]
249
+ hands.append([x1, y1, x2, y2, x3, y3, True])
250
+ # right hand
251
+ if has_right:
252
+ right_shoulder_index, right_elbow_index, right_wrist_index = person[
253
+ [2, 3, 4]
254
+ ]
255
+ x1, y1 = candidate[right_shoulder_index][:2]
256
+ x2, y2 = candidate[right_elbow_index][:2]
257
+ x3, y3 = candidate[right_wrist_index][:2]
258
+ hands.append([x1, y1, x2, y2, x3, y3, False])
259
+
260
+ for x1, y1, x2, y2, x3, y3, is_left in hands:
261
+ # pos_hand = pos_wrist + ratio * (pos_wrist - pos_elbox) = (1 + ratio) * pos_wrist - ratio * pos_elbox
262
+ # handRectangle.x = posePtr[wrist*3] + ratioWristElbow * (posePtr[wrist*3] - posePtr[elbow*3]);
263
+ # handRectangle.y = posePtr[wrist*3+1] + ratioWristElbow * (posePtr[wrist*3+1] - posePtr[elbow*3+1]);
264
+ # const auto distanceWristElbow = getDistance(poseKeypoints, person, wrist, elbow);
265
+ # const auto distanceElbowShoulder = getDistance(poseKeypoints, person, elbow, shoulder);
266
+ # handRectangle.width = 1.5f * fastMax(distanceWristElbow, 0.9f * distanceElbowShoulder);
267
+ x = x3 + ratioWristElbow * (x3 - x2)
268
+ y = y3 + ratioWristElbow * (y3 - y2)
269
+ distanceWristElbow = math.sqrt((x3 - x2) ** 2 + (y3 - y2) ** 2)
270
+ distanceElbowShoulder = math.sqrt((x2 - x1) ** 2 + (y2 - y1) ** 2)
271
+ width = 1.5 * max(distanceWristElbow, 0.9 * distanceElbowShoulder)
272
+ # x-y refers to the center --> offset to topLeft point
273
+ # handRectangle.x -= handRectangle.width / 2.f;
274
+ # handRectangle.y -= handRectangle.height / 2.f;
275
+ x -= width / 2
276
+ y -= width / 2 # width = height
277
+ # overflow the image
278
+ if x < 0:
279
+ x = 0
280
+ if y < 0:
281
+ y = 0
282
+ width1 = width
283
+ width2 = width
284
+ if x + width > image_width:
285
+ width1 = image_width - x
286
+ if y + width > image_height:
287
+ width2 = image_height - y
288
+ width = min(width1, width2)
289
+ # the max hand box value is 20 pixels
290
+ if width >= 20:
291
+ detect_result.append([int(x), int(y), int(width), is_left])
292
+
293
+ """
294
+ return value: [[x, y, w, True if left hand else False]].
295
+ width=height since the network require squared input.
296
+ x, y is the coordinate of top left
297
+ """
298
+ return detect_result
299
+
300
+
301
+ # Written by Lvmin
302
+ def faceDetect(candidate, subset, oriImg):
303
+ # left right eye ear 14 15 16 17
304
+ detect_result = []
305
+ image_height, image_width = oriImg.shape[0:2]
306
+ for person in subset.astype(int):
307
+ has_head = person[0] > -1
308
+ if not has_head:
309
+ continue
310
+
311
+ has_left_eye = person[14] > -1
312
+ has_right_eye = person[15] > -1
313
+ has_left_ear = person[16] > -1
314
+ has_right_ear = person[17] > -1
315
+
316
+ if not (has_left_eye or has_right_eye or has_left_ear or has_right_ear):
317
+ continue
318
+
319
+ head, left_eye, right_eye, left_ear, right_ear = person[[0, 14, 15, 16, 17]]
320
+
321
+ width = 0.0
322
+ x0, y0 = candidate[head][:2]
323
+
324
+ if has_left_eye:
325
+ x1, y1 = candidate[left_eye][:2]
326
+ d = max(abs(x0 - x1), abs(y0 - y1))
327
+ width = max(width, d * 3.0)
328
+
329
+ if has_right_eye:
330
+ x1, y1 = candidate[right_eye][:2]
331
+ d = max(abs(x0 - x1), abs(y0 - y1))
332
+ width = max(width, d * 3.0)
333
+
334
+ if has_left_ear:
335
+ x1, y1 = candidate[left_ear][:2]
336
+ d = max(abs(x0 - x1), abs(y0 - y1))
337
+ width = max(width, d * 1.5)
338
+
339
+ if has_right_ear:
340
+ x1, y1 = candidate[right_ear][:2]
341
+ d = max(abs(x0 - x1), abs(y0 - y1))
342
+ width = max(width, d * 1.5)
343
+
344
+ x, y = x0, y0
345
+
346
+ x -= width
347
+ y -= width
348
+
349
+ if x < 0:
350
+ x = 0
351
+
352
+ if y < 0:
353
+ y = 0
354
+
355
+ width1 = width * 2
356
+ width2 = width * 2
357
+
358
+ if x + width > image_width:
359
+ width1 = image_width - x
360
+
361
+ if y + width > image_height:
362
+ width2 = image_height - y
363
+
364
+ width = min(width1, width2)
365
+
366
+ if width >= 20:
367
+ detect_result.append([int(x), int(y), int(width)])
368
+
369
+ return detect_result
370
+
371
+
372
+ # get max index of 2d array
373
+ def npmax(array):
374
+ arrayindex = array.argmax(1)
375
+ arrayvalue = array.max(1)
376
+ i = arrayvalue.argmax()
377
+ j = arrayindex[i]
378
+ return i, j
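The drawing helpers above expect keypoints normalized to [0, 1]; a small sketch with synthetic face landmarks (illustrative only; the landmark count is arbitrary):

import numpy as np

from src.dwpose.util import draw_facepose

canvas = np.zeros((512, 512, 3), dtype=np.uint8)  # black H x W x 3 canvas
face_lmks = np.random.rand(68, 2)                 # synthetic normalized (x, y) landmarks
canvas = draw_facepose(canvas, [face_lmks])       # draws one white dot per landmark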
src/dwpose/wholebody.py ADDED
@@ -0,0 +1,48 @@
1
+ # https://github.com/IDEA-Research/DWPose
2
+ from pathlib import Path
3
+
4
+ import cv2
5
+ import numpy as np
6
+ import onnxruntime as ort
7
+
8
+ from .onnxdet import inference_detector
9
+ from .onnxpose import inference_pose
10
+
11
+ ModelDataPathPrefix = Path("./pretrained_weights")
12
+
13
+
14
+ class Wholebody:
15
+ def __init__(self, device="cuda:0"):
16
+ providers = (
17
+ ["CPUExecutionProvider"] if device == "cpu" else ["CUDAExecutionProvider"]
18
+ )
19
+ onnx_det = ModelDataPathPrefix.joinpath("DWPose/yolox_l.onnx")
20
+ onnx_pose = ModelDataPathPrefix.joinpath("DWPose/dw-ll_ucoco_384.onnx")
21
+
22
+ self.session_det = ort.InferenceSession(
23
+ path_or_bytes=onnx_det, providers=providers
24
+ )
25
+ self.session_pose = ort.InferenceSession(
26
+ path_or_bytes=onnx_pose, providers=providers
27
+ )
28
+
29
+ def __call__(self, oriImg):
30
+ det_result = inference_detector(self.session_det, oriImg)
31
+ keypoints, scores = inference_pose(self.session_pose, det_result, oriImg)
32
+
33
+ keypoints_info = np.concatenate((keypoints, scores[..., None]), axis=-1)
34
+ # compute neck joint
35
+ neck = np.mean(keypoints_info[:, [5, 6]], axis=1)
36
+ # neck score when visualizing pred
37
+ neck[:, 2:4] = np.logical_and(
38
+ keypoints_info[:, 5, 2:4] > 0.3, keypoints_info[:, 6, 2:4] > 0.3
39
+ ).astype(int)
40
+ new_keypoints_info = np.insert(keypoints_info, 17, neck, axis=1)
41
+ mmpose_idx = [17, 6, 8, 10, 7, 9, 12, 14, 16, 13, 15, 2, 1, 4, 3]
42
+ openpose_idx = [1, 2, 3, 4, 6, 7, 8, 9, 10, 12, 13, 14, 15, 16, 17]
43
+ new_keypoints_info[:, openpose_idx] = new_keypoints_info[:, mmpose_idx]
44
+ keypoints_info = new_keypoints_info
45
+
46
+ keypoints, scores = keypoints_info[..., :2], keypoints_info[..., 2]
47
+
48
+ return keypoints, scores
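Wholebody wires the two ONNX sessions together; a minimal usage sketch, assuming the DWPose weights are present under ./pretrained_weights/DWPose/ as the class expects and with a placeholder image path:

import cv2

from src.dwpose.wholebody import Wholebody

detector = Wholebody(device="cpu")    # use "cuda:0" with a GPU build of onnxruntime
img = cv2.imread("some_image.jpg")
keypoints, scores = detector(img)     # OpenPose-ordered whole-body keypoints and confidences
print(keypoints.shape, scores.shape)  # (num_persons, num_keypoints, 2), (num_persons, num_keypoints)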
src/models/attention.py ADDED
@@ -0,0 +1,481 @@
1
+ # Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention.py
2
+
3
+ from typing import Any, Dict, Optional
4
+
5
+ import torch
6
+ from diffusers.models.attention import AdaLayerNorm, AdaLayerNormZero, Attention, FeedForward, GatedSelfAttentionDense
7
+ from diffusers.models.embeddings import SinusoidalPositionalEmbedding
8
+ from einops import rearrange
9
+ from torch import nn
10
+
11
+
12
+ class BasicTransformerBlock(nn.Module):
13
+ r"""
14
+ A basic Transformer block.
15
+
16
+ Parameters:
17
+ dim (`int`): The number of channels in the input and output.
18
+ num_attention_heads (`int`): The number of heads to use for multi-head attention.
19
+ attention_head_dim (`int`): The number of channels in each head.
20
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
21
+ cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
22
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
23
+ num_embeds_ada_norm (:
24
+ obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`.
25
+ attention_bias (:
26
+ obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter.
27
+ only_cross_attention (`bool`, *optional*):
28
+ Whether to use only cross-attention layers. In this case two cross attention layers are used.
29
+ double_self_attention (`bool`, *optional*):
30
+ Whether to use two self-attention layers. In this case no cross attention layers are used.
31
+ upcast_attention (`bool`, *optional*):
32
+ Whether to upcast the attention computation to float32. This is useful for mixed precision training.
33
+ norm_elementwise_affine (`bool`, *optional*, defaults to `True`):
34
+ Whether to use learnable elementwise affine parameters for normalization.
35
+ norm_type (`str`, *optional*, defaults to `"layer_norm"`):
36
+ The normalization layer to use. Can be `"layer_norm"`, `"ada_norm"` or `"ada_norm_zero"`.
37
+ final_dropout (`bool` *optional*, defaults to False):
38
+ Whether to apply a final dropout after the last feed-forward layer.
39
+ attention_type (`str`, *optional*, defaults to `"default"`):
40
+ The type of attention to use. Can be `"default"` or `"gated"` or `"gated-text-image"`.
41
+ positional_embeddings (`str`, *optional*, defaults to `None`):
42
+ The type of positional embeddings to apply.
43
+ num_positional_embeddings (`int`, *optional*, defaults to `None`):
44
+ The maximum number of positional embeddings to apply.
45
+ """
46
+
47
+ def __init__(
48
+ self,
49
+ dim: int,
50
+ num_attention_heads: int,
51
+ attention_head_dim: int,
52
+ dropout=0.0,
53
+ cross_attention_dim: Optional[int] = None,
54
+ activation_fn: str = "geglu",
55
+ num_embeds_ada_norm: Optional[int] = None,
56
+ attention_bias: bool = False,
57
+ only_cross_attention: bool = False,
58
+ double_self_attention: bool = False,
59
+ upcast_attention: bool = False,
60
+ norm_elementwise_affine: bool = True,
61
+ norm_type: str = "layer_norm", # 'layer_norm', 'ada_norm', 'ada_norm_zero', 'ada_norm_single'
62
+ norm_eps: float = 1e-5,
63
+ final_dropout: bool = False,
64
+ attention_type: str = "default",
65
+ positional_embeddings: Optional[str] = None,
66
+ num_positional_embeddings: Optional[int] = None,
67
+ ):
68
+ super().__init__()
69
+ self.only_cross_attention = only_cross_attention
70
+
71
+ self.use_ada_layer_norm_zero = (
72
+ num_embeds_ada_norm is not None
73
+ ) and norm_type == "ada_norm_zero"
74
+ self.use_ada_layer_norm = (
75
+ num_embeds_ada_norm is not None
76
+ ) and norm_type == "ada_norm"
77
+ self.use_ada_layer_norm_single = norm_type == "ada_norm_single"
78
+ self.use_layer_norm = norm_type == "layer_norm"
79
+
80
+ if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None:
81
+ raise ValueError(
82
+ f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to"
83
+ f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}."
84
+ )
85
+
86
+ if positional_embeddings and (num_positional_embeddings is None):
87
+ raise ValueError(
88
+ "If `positional_embedding` type is defined, `num_positition_embeddings` must also be defined."
89
+ )
90
+
91
+ if positional_embeddings == "sinusoidal":
92
+ self.pos_embed = SinusoidalPositionalEmbedding(
93
+ dim, max_seq_length=num_positional_embeddings
94
+ )
95
+ else:
96
+ self.pos_embed = None
97
+
98
+ # Define 3 blocks. Each block has its own normalization layer.
99
+ # 1. Self-Attn
100
+ if self.use_ada_layer_norm:
101
+ self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm)
102
+ elif self.use_ada_layer_norm_zero:
103
+ self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm)
104
+ else:
105
+ self.norm1 = nn.LayerNorm(
106
+ dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps
107
+ )
108
+
109
+ self.attn1 = Attention(
110
+ query_dim=dim,
111
+ heads=num_attention_heads,
112
+ dim_head=attention_head_dim,
113
+ dropout=dropout,
114
+ bias=attention_bias,
115
+ cross_attention_dim=cross_attention_dim if only_cross_attention else None,
116
+ upcast_attention=upcast_attention,
117
+ )
118
+
119
+ # 2. Cross-Attn
120
+ if cross_attention_dim is not None or double_self_attention:
121
+ # We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
122
+ # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during
123
+ # the second cross attention block.
124
+ self.norm2 = (
125
+ AdaLayerNorm(dim, num_embeds_ada_norm)
126
+ if self.use_ada_layer_norm
127
+ else nn.LayerNorm(
128
+ dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps
129
+ )
130
+ )
131
+ self.attn2 = Attention(
132
+ query_dim=dim,
133
+ cross_attention_dim=cross_attention_dim
134
+ if not double_self_attention
135
+ else None,
136
+ heads=num_attention_heads,
137
+ dim_head=attention_head_dim,
138
+ dropout=dropout,
139
+ bias=attention_bias,
140
+ upcast_attention=upcast_attention,
141
+ ) # is self-attn if encoder_hidden_states is none
142
+ else:
143
+ self.norm2 = None
144
+ self.attn2 = None
145
+
146
+ # 3. Feed-forward
147
+ if not self.use_ada_layer_norm_single:
148
+ self.norm3 = nn.LayerNorm(
149
+ dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps
150
+ )
151
+
152
+ self.ff = FeedForward(
153
+ dim,
154
+ dropout=dropout,
155
+ activation_fn=activation_fn,
156
+ final_dropout=final_dropout,
157
+ )
158
+
159
+ # 4. Fuser
160
+ if attention_type == "gated" or attention_type == "gated-text-image":
161
+ self.fuser = GatedSelfAttentionDense(
162
+ dim, cross_attention_dim, num_attention_heads, attention_head_dim
163
+ )
164
+
165
+ # 5. Scale-shift for PixArt-Alpha.
166
+ if self.use_ada_layer_norm_single:
167
+ self.scale_shift_table = nn.Parameter(torch.randn(6, dim) / dim**0.5)
168
+
169
+ # let chunk size default to None
170
+ self._chunk_size = None
171
+ self._chunk_dim = 0
172
+
173
+ def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0):
174
+ # Sets chunk feed-forward
175
+ self._chunk_size = chunk_size
176
+ self._chunk_dim = dim
177
+
178
+ def forward(
179
+ self,
180
+ hidden_states: torch.FloatTensor,
181
+ attention_mask: Optional[torch.FloatTensor] = None,
182
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
183
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
184
+ timestep: Optional[torch.LongTensor] = None,
185
+ cross_attention_kwargs: Dict[str, Any] = None,
186
+ class_labels: Optional[torch.LongTensor] = None,
187
+ ) -> torch.FloatTensor:
188
+ # Notice that normalization is always applied before the real computation in the following blocks.
189
+ # 0. Self-Attention
190
+ batch_size = hidden_states.shape[0]
191
+
192
+ if self.use_ada_layer_norm:
193
+ norm_hidden_states = self.norm1(hidden_states, timestep)
194
+ elif self.use_ada_layer_norm_zero:
195
+ norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
196
+ hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
197
+ )
198
+ elif self.use_layer_norm:
199
+ norm_hidden_states = self.norm1(hidden_states)
200
+ elif self.use_ada_layer_norm_single:
201
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
202
+ self.scale_shift_table[None] + timestep.reshape(batch_size, 6, -1)
203
+ ).chunk(6, dim=1)
204
+ norm_hidden_states = self.norm1(hidden_states)
205
+ norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa
206
+ norm_hidden_states = norm_hidden_states.squeeze(1)
207
+ else:
208
+ raise ValueError("Incorrect norm used")
209
+
210
+ if self.pos_embed is not None:
211
+ norm_hidden_states = self.pos_embed(norm_hidden_states)
212
+
213
+ # 1. Retrieve lora scale.
214
+ lora_scale = (
215
+ cross_attention_kwargs.get("scale", 1.0)
216
+ if cross_attention_kwargs is not None
217
+ else 1.0
218
+ )
219
+
220
+ # 2. Prepare GLIGEN inputs
221
+ cross_attention_kwargs = (
222
+ cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {}
223
+ )
224
+ gligen_kwargs = cross_attention_kwargs.pop("gligen", None)
225
+
226
+ attn_output = self.attn1(
227
+ norm_hidden_states,
228
+ encoder_hidden_states=encoder_hidden_states
229
+ if self.only_cross_attention
230
+ else None,
231
+ attention_mask=attention_mask,
232
+ **cross_attention_kwargs,
233
+ )
234
+ if self.use_ada_layer_norm_zero:
235
+ attn_output = gate_msa.unsqueeze(1) * attn_output
236
+ elif self.use_ada_layer_norm_single:
237
+ attn_output = gate_msa * attn_output
238
+
239
+ hidden_states = attn_output + hidden_states
240
+ if hidden_states.ndim == 4:
241
+ hidden_states = hidden_states.squeeze(1)
242
+
243
+ # 2.5 GLIGEN Control
244
+ if gligen_kwargs is not None:
245
+ hidden_states = self.fuser(hidden_states, gligen_kwargs["objs"])
246
+
247
+ # 3. Cross-Attention
248
+ if self.attn2 is not None:
249
+ if self.use_ada_layer_norm:
250
+ norm_hidden_states = self.norm2(hidden_states, timestep)
251
+ elif self.use_ada_layer_norm_zero or self.use_layer_norm:
252
+ norm_hidden_states = self.norm2(hidden_states)
253
+ elif self.use_ada_layer_norm_single:
254
+ # For PixArt norm2 isn't applied here:
255
+ # https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L70C1-L76C103
256
+ norm_hidden_states = hidden_states
257
+ else:
258
+ raise ValueError("Incorrect norm")
259
+
260
+ if self.pos_embed is not None and self.use_ada_layer_norm_single is False:
261
+ norm_hidden_states = self.pos_embed(norm_hidden_states)
262
+
263
+ attn_output = self.attn2(
264
+ norm_hidden_states,
265
+ encoder_hidden_states=encoder_hidden_states,
266
+ attention_mask=encoder_attention_mask,
267
+ **cross_attention_kwargs,
268
+ )
269
+ hidden_states = attn_output + hidden_states
270
+
271
+ # 4. Feed-forward
272
+ if not self.use_ada_layer_norm_single:
273
+ norm_hidden_states = self.norm3(hidden_states)
274
+
275
+ if self.use_ada_layer_norm_zero:
276
+ norm_hidden_states = (
277
+ norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
278
+ )
279
+
280
+ if self.use_ada_layer_norm_single:
281
+ norm_hidden_states = self.norm2(hidden_states)
282
+ norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp
283
+
284
+ ff_output = self.ff(norm_hidden_states, scale=lora_scale)
285
+
286
+ if self.use_ada_layer_norm_zero:
287
+ ff_output = gate_mlp.unsqueeze(1) * ff_output
288
+ elif self.use_ada_layer_norm_single:
289
+ ff_output = gate_mlp * ff_output
290
+
291
+ hidden_states = ff_output + hidden_states
292
+ if hidden_states.ndim == 4:
293
+ hidden_states = hidden_states.squeeze(1)
294
+
295
+ return hidden_states
296
+
297
+
298
+ class TemporalBasicTransformerBlock(nn.Module):
299
+ def __init__(
300
+ self,
301
+ dim: int,
302
+ num_attention_heads: int,
303
+ attention_head_dim: int,
304
+ dropout=0.0,
305
+ cross_attention_dim: Optional[int] = None,
306
+ activation_fn: str = "geglu",
307
+ num_embeds_ada_norm: Optional[int] = None,
308
+ attention_bias: bool = False,
309
+ only_cross_attention: bool = False,
310
+ upcast_attention: bool = False,
311
+ unet_use_cross_frame_attention=None,
312
+ unet_use_temporal_attention=None,
313
+ name=None,
314
+ ):
315
+ super().__init__()
316
+ self.only_cross_attention = only_cross_attention
317
+ self.use_ada_layer_norm = num_embeds_ada_norm is not None
318
+ self.unet_use_cross_frame_attention = unet_use_cross_frame_attention
319
+ self.unet_use_temporal_attention = unet_use_temporal_attention
320
+ self.name = name
321
+
322
+ # SC-Attn
323
+ self.attn1 = Attention(
324
+ query_dim=dim,
325
+ heads=num_attention_heads,
326
+ dim_head=attention_head_dim,
327
+ dropout=dropout,
328
+ bias=attention_bias,
329
+ upcast_attention=upcast_attention,
330
+ )
331
+ self.norm1 = (
332
+ AdaLayerNorm(dim, num_embeds_ada_norm)
333
+ if self.use_ada_layer_norm
334
+ else nn.LayerNorm(dim)
335
+ )
336
+
337
+ # Cross-Attn
338
+ if cross_attention_dim is not None:
339
+ self.attn2 = Attention(
340
+ query_dim=dim,
341
+ cross_attention_dim=cross_attention_dim,
342
+ heads=num_attention_heads,
343
+ dim_head=attention_head_dim,
344
+ dropout=dropout,
345
+ bias=attention_bias,
346
+ upcast_attention=upcast_attention,
347
+ )
348
+ else:
349
+ self.attn2 = None
350
+
351
+ if cross_attention_dim is not None:
352
+ self.norm2 = (
353
+ AdaLayerNorm(dim, num_embeds_ada_norm)
354
+ if self.use_ada_layer_norm
355
+ else nn.LayerNorm(dim)
356
+ )
357
+ else:
358
+ self.norm2 = None
359
+
360
+ # Feed-forward
361
+ self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn)
362
+ self.norm3 = nn.LayerNorm(dim)
363
+ self.use_ada_layer_norm_zero = False
364
+ # Temp-Attn
365
+ assert unet_use_temporal_attention is not None
366
+ if unet_use_temporal_attention:
367
+ self.attn_temp = Attention(
368
+ query_dim=dim,
369
+ heads=num_attention_heads,
370
+ dim_head=attention_head_dim,
371
+ dropout=dropout,
372
+ bias=attention_bias,
373
+ upcast_attention=upcast_attention,
374
+ )
375
+ nn.init.zeros_(self.attn_temp.to_out[0].weight.data)
376
+ self.norm_temp = (
377
+ AdaLayerNorm(dim, num_embeds_ada_norm)
378
+ if self.use_ada_layer_norm
379
+ else nn.LayerNorm(dim)
380
+ )
381
+
382
+ def forward(
383
+ self,
384
+ hidden_states,
385
+ encoder_hidden_states=None,
386
+ timestep=None,
387
+ attention_mask=None,
388
+ video_length=None,
389
+ self_attention_additional_feats=None,
390
+ mode=None,
391
+ ):
392
+ norm_hidden_states = (
393
+ self.norm1(hidden_states, timestep)
394
+ if self.use_ada_layer_norm
395
+ else self.norm1(hidden_states)
396
+ )
397
+ if self.name:
398
+ modify_norm_hidden_states = norm_hidden_states
399
+ if mode == "write":
400
+ self_attention_additional_feats[self.name] = norm_hidden_states
401
+ elif mode == "read" and self_attention_additional_feats:
402
+ ref_states = self_attention_additional_feats[self.name]
403
+ bank_fea = [
404
+ rearrange(
405
+ ref_states.unsqueeze(1).repeat(1, video_length, 1, 1),
406
+ "b t l c -> (b t) l c",
407
+ )
408
+ ]
409
+ modify_norm_hidden_states = torch.cat(
410
+ [norm_hidden_states] + bank_fea, dim=1
411
+ )
412
+
413
+ if self.unet_use_cross_frame_attention:
414
+ hidden_states = (
415
+ self.attn1(
416
+ norm_hidden_states,
417
+ attention_mask=attention_mask,
418
+ encoder_hidden_states=modify_norm_hidden_states,
419
+ video_length=video_length,
420
+ )
421
+ + hidden_states
422
+ )
423
+ else:
424
+ hidden_states = (
425
+ self.attn1(
426
+ norm_hidden_states,
427
+ encoder_hidden_states=modify_norm_hidden_states,
428
+ attention_mask=attention_mask
429
+ )
430
+ + hidden_states
431
+ )
432
+ else:
433
+ if self.unet_use_cross_frame_attention:
434
+ hidden_states = (
435
+ self.attn1(
436
+ norm_hidden_states,
437
+ attention_mask=attention_mask,
438
+ video_length=video_length,
439
+ )
440
+ + hidden_states
441
+ )
442
+ else:
443
+ hidden_states = (
444
+ self.attn1(norm_hidden_states, attention_mask=attention_mask)
445
+ + hidden_states
446
+ )
447
+
448
+ if self.attn2 is not None:
449
+ # Cross-Attention
450
+ norm_hidden_states = (
451
+ self.norm2(hidden_states, timestep)
452
+ if self.use_ada_layer_norm
453
+ else self.norm2(hidden_states)
454
+ )
455
+ hidden_states = (
456
+ self.attn2(
457
+ norm_hidden_states,
458
+ encoder_hidden_states=encoder_hidden_states,
459
+ attention_mask=attention_mask,
460
+ )
461
+ + hidden_states
462
+ )
463
+
464
+ # Feed-forward
465
+ hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states
466
+
467
+ # Temporal-Attention
468
+ if self.unet_use_temporal_attention:
469
+ d = hidden_states.shape[1]
470
+ hidden_states = rearrange(
471
+ hidden_states, "(b f) d c -> (b d) f c", f=video_length
472
+ )
473
+ norm_hidden_states = (
474
+ self.norm_temp(hidden_states, timestep)
475
+ if self.use_ada_layer_norm
476
+ else self.norm_temp(hidden_states)
477
+ )
478
+ hidden_states = self.attn_temp(norm_hidden_states) + hidden_states
479
+ hidden_states = rearrange(hidden_states, "(b d) f c -> (b f) d c", d=d)
480
+
481
+ return hidden_states
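A quick shape check for the TemporalBasicTransformerBlock defined above (a sketch under the assumption that the pinned diffusers version exposes AdaLayerNorm at the import path used in this file; the sizes are arbitrary):

import torch

from src.models.attention import TemporalBasicTransformerBlock

block = TemporalBasicTransformerBlock(
    dim=320,
    num_attention_heads=8,
    attention_head_dim=40,
    unet_use_cross_frame_attention=False,
    unet_use_temporal_attention=False,
)
video_length = 4
x = torch.randn(2 * video_length, 64, 320)  # (batch * frames, tokens, channels)
out = block(x, video_length=video_length)
print(out.shape)                            # torch.Size([8, 64, 320])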
src/models/motion_module.py ADDED
@@ -0,0 +1,388 @@
1
+ # Adapt from https://github.com/guoyww/AnimateDiff/blob/main/animatediff/models/motion_module.py
2
+ import math
3
+ from dataclasses import dataclass
4
+ from typing import Callable, Optional
5
+
6
+ import torch
7
+ from diffusers.models.attention import FeedForward
8
+ from diffusers.models.attention_processor import Attention, AttnProcessor
9
+ from diffusers.utils import BaseOutput
10
+ from diffusers.utils.import_utils import is_xformers_available
11
+ from einops import rearrange, repeat
12
+ from torch import nn
13
+
14
+
15
+ def zero_module(module):
16
+ # Zero out the parameters of a module and return it.
17
+ for p in module.parameters():
18
+ p.detach().zero_()
19
+ return module
20
+
21
+
22
+ @dataclass
23
+ class TemporalTransformer3DModelOutput(BaseOutput):
24
+ sample: torch.FloatTensor
25
+
26
+
27
+ if is_xformers_available():
28
+ import xformers
29
+ import xformers.ops
30
+ else:
31
+ xformers = None
32
+
33
+
34
+ def get_motion_module(in_channels, motion_module_type: str, motion_module_kwargs: dict):
35
+ if motion_module_type == "Vanilla":
36
+ return VanillaTemporalModule(
37
+ in_channels=in_channels,
38
+ **motion_module_kwargs,
39
+ )
40
+ else:
41
+ raise ValueError
42
+
43
+
44
+ class VanillaTemporalModule(nn.Module):
45
+ def __init__(
46
+ self,
47
+ in_channels,
48
+ num_attention_heads=8,
49
+ num_transformer_block=2,
50
+ attention_block_types=("Temporal_Self", "Temporal_Self"),
51
+ cross_frame_attention_mode=None,
52
+ temporal_position_encoding=False,
53
+ temporal_position_encoding_max_len=24,
54
+ temporal_attention_dim_div=1,
55
+ zero_initialize=True,
56
+ ):
57
+ super().__init__()
58
+
59
+ self.temporal_transformer = TemporalTransformer3DModel(
60
+ in_channels=in_channels,
61
+ num_attention_heads=num_attention_heads,
62
+ attention_head_dim=in_channels
63
+ // num_attention_heads
64
+ // temporal_attention_dim_div,
65
+ num_layers=num_transformer_block,
66
+ attention_block_types=attention_block_types,
67
+ cross_frame_attention_mode=cross_frame_attention_mode,
68
+ temporal_position_encoding=temporal_position_encoding,
69
+ temporal_position_encoding_max_len=temporal_position_encoding_max_len,
70
+ )
71
+
72
+ if zero_initialize:
73
+ self.temporal_transformer.proj_out = zero_module(
74
+ self.temporal_transformer.proj_out
75
+ )
76
+
77
+ def forward(
78
+ self,
79
+ input_tensor,
80
+ temb,
81
+ encoder_hidden_states,
82
+ attention_mask=None,
83
+ anchor_frame_idx=None,
84
+ ):
85
+ hidden_states = input_tensor
86
+ hidden_states = self.temporal_transformer(
87
+ hidden_states, encoder_hidden_states, attention_mask
88
+ )
89
+
90
+ output = hidden_states
91
+ return output
92
+
93
+
94
+ class TemporalTransformer3DModel(nn.Module):
95
+ def __init__(
96
+ self,
97
+ in_channels,
98
+ num_attention_heads,
99
+ attention_head_dim,
100
+ num_layers,
101
+ attention_block_types=(
102
+ "Temporal_Self",
103
+ "Temporal_Self",
104
+ ),
105
+ dropout=0.0,
106
+ norm_num_groups=32,
107
+ cross_attention_dim=768,
108
+ activation_fn="geglu",
109
+ attention_bias=False,
110
+ upcast_attention=False,
111
+ cross_frame_attention_mode=None,
112
+ temporal_position_encoding=False,
113
+ temporal_position_encoding_max_len=24,
114
+ ):
115
+ super().__init__()
116
+
117
+ inner_dim = num_attention_heads * attention_head_dim
118
+
119
+ self.norm = torch.nn.GroupNorm(
120
+ num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True
121
+ )
122
+ self.proj_in = nn.Linear(in_channels, inner_dim)
123
+
124
+ self.transformer_blocks = nn.ModuleList(
125
+ [
126
+ TemporalTransformerBlock(
127
+ dim=inner_dim,
128
+ num_attention_heads=num_attention_heads,
129
+ attention_head_dim=attention_head_dim,
130
+ attention_block_types=attention_block_types,
131
+ dropout=dropout,
132
+ norm_num_groups=norm_num_groups,
133
+ cross_attention_dim=cross_attention_dim,
134
+ activation_fn=activation_fn,
135
+ attention_bias=attention_bias,
136
+ upcast_attention=upcast_attention,
137
+ cross_frame_attention_mode=cross_frame_attention_mode,
138
+ temporal_position_encoding=temporal_position_encoding,
139
+ temporal_position_encoding_max_len=temporal_position_encoding_max_len,
140
+ )
141
+ for d in range(num_layers)
142
+ ]
143
+ )
144
+ self.proj_out = nn.Linear(inner_dim, in_channels)
145
+
146
+ def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None):
147
+ assert (
148
+ hidden_states.dim() == 5
149
+ ), f"Expected hidden_states to have ndim=5, but got ndim={hidden_states.dim()}."
150
+ video_length = hidden_states.shape[2]
151
+ hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w")
152
+
153
+ batch, channel, height, width = hidden_states.shape
154
+ residual = hidden_states
155
+
156
+ hidden_states = self.norm(hidden_states)
157
+ inner_dim = hidden_states.shape[1]
158
+ hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(
159
+ batch, height * width, inner_dim
160
+ )
161
+ hidden_states = self.proj_in(hidden_states)
162
+
163
+ # Transformer Blocks
164
+ for block in self.transformer_blocks:
165
+ hidden_states = block(
166
+ hidden_states,
167
+ encoder_hidden_states=encoder_hidden_states,
168
+ video_length=video_length,
169
+ )
170
+
171
+ # output
172
+ hidden_states = self.proj_out(hidden_states)
173
+ hidden_states = (
174
+ hidden_states.reshape(batch, height, width, inner_dim)
175
+ .permute(0, 3, 1, 2)
176
+ .contiguous()
177
+ )
178
+
179
+ output = hidden_states + residual
180
+ output = rearrange(output, "(b f) c h w -> b c f h w", f=video_length)
181
+
182
+ return output
183
+
184
+
185
+ class TemporalTransformerBlock(nn.Module):
186
+ def __init__(
187
+ self,
188
+ dim,
189
+ num_attention_heads,
190
+ attention_head_dim,
191
+ attention_block_types=(
192
+ "Temporal_Self",
193
+ "Temporal_Self",
194
+ ),
195
+ dropout=0.0,
196
+ norm_num_groups=32,
197
+ cross_attention_dim=768,
198
+ activation_fn="geglu",
199
+ attention_bias=False,
200
+ upcast_attention=False,
201
+ cross_frame_attention_mode=None,
202
+ temporal_position_encoding=False,
203
+ temporal_position_encoding_max_len=24,
204
+ ):
205
+ super().__init__()
206
+
207
+ attention_blocks = []
208
+ norms = []
209
+
210
+ for block_name in attention_block_types:
211
+ attention_blocks.append(
212
+ VersatileAttention(
213
+ attention_mode=block_name.split("_")[0],
214
+ cross_attention_dim=cross_attention_dim
215
+ if block_name.endswith("_Cross")
216
+ else None,
217
+ query_dim=dim,
218
+ heads=num_attention_heads,
219
+ dim_head=attention_head_dim,
220
+ dropout=dropout,
221
+ bias=attention_bias,
222
+ upcast_attention=upcast_attention,
223
+ cross_frame_attention_mode=cross_frame_attention_mode,
224
+ temporal_position_encoding=temporal_position_encoding,
225
+ temporal_position_encoding_max_len=temporal_position_encoding_max_len,
226
+ )
227
+ )
228
+ norms.append(nn.LayerNorm(dim))
229
+
230
+ self.attention_blocks = nn.ModuleList(attention_blocks)
231
+ self.norms = nn.ModuleList(norms)
232
+
233
+ self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn)
234
+ self.ff_norm = nn.LayerNorm(dim)
235
+
236
+ def forward(
237
+ self,
238
+ hidden_states,
239
+ encoder_hidden_states=None,
240
+ attention_mask=None,
241
+ video_length=None,
242
+ ):
243
+ for attention_block, norm in zip(self.attention_blocks, self.norms):
244
+ norm_hidden_states = norm(hidden_states)
245
+ hidden_states = (
246
+ attention_block(
247
+ norm_hidden_states,
248
+ encoder_hidden_states=encoder_hidden_states
249
+ if attention_block.is_cross_attention
250
+ else None,
251
+ video_length=video_length,
252
+ )
253
+ + hidden_states
254
+ )
255
+
256
+ hidden_states = self.ff(self.ff_norm(hidden_states)) + hidden_states
257
+
258
+ output = hidden_states
259
+ return output
260
+
261
+
262
+ class PositionalEncoding(nn.Module):
263
+ def __init__(self, d_model, dropout=0.0, max_len=24):
264
+ super().__init__()
265
+ self.dropout = nn.Dropout(p=dropout)
266
+ position = torch.arange(max_len).unsqueeze(1)
267
+ div_term = torch.exp(
268
+ torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model)
269
+ )
270
+ pe = torch.zeros(1, max_len, d_model)
271
+ pe[0, :, 0::2] = torch.sin(position * div_term)
272
+ pe[0, :, 1::2] = torch.cos(position * div_term)
273
+ self.register_buffer("pe", pe)
274
+
275
+ def forward(self, x):
276
+ x = x + self.pe[:, : x.size(1)]
277
+ return self.dropout(x)
278
+
279
+
280
+ class VersatileAttention(Attention):
281
+ def __init__(
282
+ self,
283
+ attention_mode=None,
284
+ cross_frame_attention_mode=None,
285
+ temporal_position_encoding=False,
286
+ temporal_position_encoding_max_len=24,
287
+ *args,
288
+ **kwargs,
289
+ ):
290
+ super().__init__(*args, **kwargs)
291
+ assert attention_mode == "Temporal"
292
+
293
+ self.attention_mode = attention_mode
294
+ self.is_cross_attention = kwargs["cross_attention_dim"] is not None
295
+
296
+ self.pos_encoder = (
297
+ PositionalEncoding(
298
+ kwargs["query_dim"],
299
+ dropout=0.0,
300
+ max_len=temporal_position_encoding_max_len,
301
+ )
302
+ if (temporal_position_encoding and attention_mode == "Temporal")
303
+ else None
304
+ )
305
+
306
+ def extra_repr(self):
307
+ return f"(Module Info) Attention_Mode: {self.attention_mode}, Is_Cross_Attention: {self.is_cross_attention}"
308
+
309
+ def set_use_memory_efficient_attention_xformers(
310
+ self,
311
+ use_memory_efficient_attention_xformers: bool,
312
+ attention_op: Optional[Callable] = None,
313
+ ):
314
+ if use_memory_efficient_attention_xformers:
315
+ if not is_xformers_available():
316
+ raise ModuleNotFoundError(
317
+ (
318
+ "Refer to https://github.com/facebookresearch/xformers for more information on how to install"
319
+ " xformers"
320
+ ),
321
+ name="xformers",
322
+ )
323
+ elif not torch.cuda.is_available():
324
+ raise ValueError(
325
+ "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is"
326
+ " only available for GPU "
327
+ )
328
+ else:
329
+ try:
330
+ # Make sure we can run the memory efficient attention
331
+ _ = xformers.ops.memory_efficient_attention(
332
+ torch.randn((1, 2, 40), device="cuda"),
333
+ torch.randn((1, 2, 40), device="cuda"),
334
+ torch.randn((1, 2, 40), device="cuda"),
335
+ )
336
+ except Exception as e:
337
+ raise e
338
+
339
+ # XFormersAttnProcessor corrupts video generation and work with Pytorch 1.13.
340
+ # Pytorch 2.0.1 AttnProcessor works the same as XFormersAttnProcessor in Pytorch 1.13.
341
+ # You don't need XFormersAttnProcessor here.
342
+ # processor = XFormersAttnProcessor(
343
+ # attention_op=attention_op,
344
+ # )
345
+ processor = AttnProcessor()
346
+ else:
347
+ processor = AttnProcessor()
348
+
349
+ self.set_processor(processor)
350
+
351
+ def forward(
352
+ self,
353
+ hidden_states,
354
+ encoder_hidden_states=None,
355
+ attention_mask=None,
356
+ video_length=None,
357
+ **cross_attention_kwargs,
358
+ ):
359
+ if self.attention_mode == "Temporal":
360
+ d = hidden_states.shape[1] # d means HxW
361
+ hidden_states = rearrange(
362
+ hidden_states, "(b f) d c -> (b d) f c", f=video_length
363
+ )
364
+
365
+ if self.pos_encoder is not None:
366
+ hidden_states = self.pos_encoder(hidden_states)
367
+
368
+ encoder_hidden_states = (
369
+ repeat(encoder_hidden_states, "b n c -> (b d) n c", d=d)
370
+ if encoder_hidden_states is not None
371
+ else encoder_hidden_states
372
+ )
373
+
374
+ else:
375
+ raise NotImplementedError
376
+
377
+ hidden_states = self.processor(
378
+ self,
379
+ hidden_states,
380
+ encoder_hidden_states=encoder_hidden_states,
381
+ attention_mask=attention_mask,
382
+ **cross_attention_kwargs,
383
+ )
384
+
385
+ if self.attention_mode == "Temporal":
386
+ hidden_states = rearrange(hidden_states, "(b d) f c -> (b f) d c", d=d)
387
+
388
+ return hidden_states
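+ 
+ (Illustration, not part of the uploaded file.) As a sanity check on the reshaping used in temporal mode, a small sketch with made-up sizes of the `(b f) d c -> (b d) f c` round trip: frames move onto the attention axis and spatial positions move onto the batch axis, and the transform is lossless.
+ 
+ ```python
+ import torch
+ from einops import rearrange
+ 
+ b, f, d, c = 2, 16, 64, 320             # hypothetical batch, frames, H*W tokens, channels
+ hidden = torch.randn(b * f, d, c)       # layout coming out of the 2D blocks: (b f) d c
+ 
+ # Temporal mode: attend across frames for every spatial location independently.
+ temporal = rearrange(hidden, "(b f) d c -> (b d) f c", f=f)
+ assert temporal.shape == (b * d, f, c)
+ 
+ # ... attention over the f axis would run here ...
+ 
+ restored = rearrange(temporal, "(b d) f c -> (b f) d c", d=d)
+ assert torch.equal(restored, hidden)    # the round trip recovers the original layout
+ ```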
src/models/mutual_self_attention.py ADDED
@@ -0,0 +1,365 @@
1
+ # Adapted from https://github.com/magic-research/magic-animate/blob/main/magicanimate/models/mutual_self_attention.py
2
+ from typing import Any, Dict, Optional
3
+
4
+ import torch
5
+ from einops import rearrange
6
+
7
+ from src.models.attention import TemporalBasicTransformerBlock
8
+
9
+ #from .attention import BasicTransformerBlock
10
+ from diffusers.models.attention import BasicTransformerBlock
11
+
12
+ def torch_dfs(model: torch.nn.Module):
13
+ result = [model]
14
+ for child in model.children():
15
+ result += torch_dfs(child)
16
+ return result
17
+
18
+
19
+ class ReferenceAttentionControl:
20
+ def __init__(
21
+ self,
22
+ unet,
23
+ mode="write",
24
+ do_classifier_free_guidance=False,
25
+ attention_auto_machine_weight=float("inf"),
26
+ gn_auto_machine_weight=1.0,
27
+ style_fidelity=1.0,
28
+ reference_attn=True,
29
+ reference_adain=False,
30
+ fusion_blocks="midup",
31
+ batch_size=1,
32
+ ) -> None:
33
+ # 10. Modify self attention and group norm
34
+ self.unet = unet
35
+ assert mode in ["read", "write"]
36
+ assert fusion_blocks in ["midup", "full"]
37
+ self.reference_attn = reference_attn
38
+ self.reference_adain = reference_adain
39
+ self.fusion_blocks = fusion_blocks
40
+ self.register_reference_hooks(
41
+ mode,
42
+ do_classifier_free_guidance,
43
+ attention_auto_machine_weight,
44
+ gn_auto_machine_weight,
45
+ style_fidelity,
46
+ reference_attn,
47
+ reference_adain,
48
+ fusion_blocks,
49
+ batch_size=batch_size,
50
+ )
51
+
52
+ def register_reference_hooks(
53
+ self,
54
+ mode,
55
+ do_classifier_free_guidance,
56
+ attention_auto_machine_weight,
57
+ gn_auto_machine_weight,
58
+ style_fidelity,
59
+ reference_attn,
60
+ reference_adain,
61
+ dtype=torch.float16,
62
+ batch_size=1,
63
+ num_images_per_prompt=1,
64
+ device=torch.device("cpu"),
65
+ fusion_blocks="midup",
66
+ ):
67
+ MODE = mode
68
+ do_classifier_free_guidance = do_classifier_free_guidance
69
+ attention_auto_machine_weight = attention_auto_machine_weight
70
+ gn_auto_machine_weight = gn_auto_machine_weight
71
+ style_fidelity = style_fidelity
72
+ reference_attn = reference_attn
73
+ reference_adain = reference_adain
74
+ fusion_blocks = fusion_blocks
75
+ num_images_per_prompt = num_images_per_prompt
76
+ dtype = dtype
77
+ if do_classifier_free_guidance:
78
+ uc_mask = (
79
+ torch.Tensor(
80
+ [1] * batch_size * num_images_per_prompt * 16
81
+ + [0] * batch_size * num_images_per_prompt * 16
82
+ )
83
+ .to(device)
84
+ .bool()
85
+ )
86
+ else:
87
+ uc_mask = (
88
+ torch.Tensor([0] * batch_size * num_images_per_prompt * 2)
89
+ .to(device)
90
+ .bool()
91
+ )
92
+
93
+ def hacked_basic_transformer_inner_forward(
94
+ self,
95
+ hidden_states: torch.FloatTensor,
96
+ attention_mask: Optional[torch.FloatTensor] = None,
97
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
98
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
99
+ timestep: Optional[torch.LongTensor] = None,
100
+ cross_attention_kwargs: Dict[str, Any] = None,
101
+ class_labels: Optional[torch.LongTensor] = None,
102
+ video_length=None,
103
+ self_attention_additional_feats=None,
104
+ mode=None,
105
+ ):
106
+ if self.use_ada_layer_norm: # False
107
+ norm_hidden_states = self.norm1(hidden_states, timestep)
108
+ elif self.use_ada_layer_norm_zero:
109
+ (
110
+ norm_hidden_states,
111
+ gate_msa,
112
+ shift_mlp,
113
+ scale_mlp,
114
+ gate_mlp,
115
+ ) = self.norm1(
116
+ hidden_states,
117
+ timestep,
118
+ class_labels,
119
+ hidden_dtype=hidden_states.dtype,
120
+ )
121
+ else:
122
+ norm_hidden_states = self.norm1(hidden_states)
123
+
124
+ # 1. Self-Attention
125
+ # self.only_cross_attention = False
126
+ cross_attention_kwargs = (
127
+ cross_attention_kwargs if cross_attention_kwargs is not None else {}
128
+ )
129
+ if self.only_cross_attention:
130
+ attn_output = self.attn1(
131
+ norm_hidden_states,
132
+ encoder_hidden_states=encoder_hidden_states
133
+ if self.only_cross_attention
134
+ else None,
135
+ attention_mask=attention_mask,
136
+ **cross_attention_kwargs,
137
+ )
138
+ else:
139
+ if MODE == "write":
140
+ self.bank.append(norm_hidden_states.clone())
141
+ attn_output = self.attn1(
142
+ norm_hidden_states,
143
+ encoder_hidden_states=encoder_hidden_states
144
+ if self.only_cross_attention
145
+ else None,
146
+ attention_mask=attention_mask,
147
+ **cross_attention_kwargs,
148
+ )
149
+ if MODE == "read":
150
+ bank_fea = [
151
+ rearrange(
152
+ d.unsqueeze(1).repeat(1, video_length, 1, 1),
153
+ "b t l c -> (b t) l c",
154
+ )
155
+ for d in self.bank
156
+ ]
157
+ modify_norm_hidden_states = torch.cat(
158
+ [norm_hidden_states] + bank_fea, dim=1
159
+ )
160
+ hidden_states_uc = (
161
+ self.attn1(
162
+ norm_hidden_states,
163
+ encoder_hidden_states=modify_norm_hidden_states,
164
+ attention_mask=attention_mask,
165
+ )
166
+ + hidden_states
167
+ )
168
+ if do_classifier_free_guidance:
169
+ hidden_states_c = hidden_states_uc.clone()
170
+ _uc_mask = uc_mask.clone()
171
+ if hidden_states.shape[0] != _uc_mask.shape[0]:
172
+ _uc_mask = (
173
+ torch.Tensor(
174
+ [1] * (hidden_states.shape[0] // 2)
175
+ + [0] * (hidden_states.shape[0] // 2)
176
+ )
177
+ .to(device)
178
+ .bool()
179
+ )
180
+ hidden_states_c[_uc_mask] = (
181
+ self.attn1(
182
+ norm_hidden_states[_uc_mask],
183
+ encoder_hidden_states=norm_hidden_states[_uc_mask],
184
+ attention_mask=attention_mask,
185
+ )
186
+ + hidden_states[_uc_mask]
187
+ )
188
+ hidden_states = hidden_states_c.clone()
189
+ else:
190
+ hidden_states = hidden_states_uc
191
+
192
+ # self.bank.clear()
193
+ if self.attn2 is not None:
194
+ # Cross-Attention
195
+ norm_hidden_states = (
196
+ self.norm2(hidden_states, timestep)
197
+ if self.use_ada_layer_norm
198
+ else self.norm2(hidden_states)
199
+ )
200
+ hidden_states = (
201
+ self.attn2(
202
+ norm_hidden_states,
203
+ encoder_hidden_states=encoder_hidden_states,
204
+ attention_mask=attention_mask,
205
+ )
206
+ + hidden_states
207
+ )
208
+
209
+ # Feed-forward
210
+ hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states
211
+
212
+ # Temporal-Attention
213
+ if self.unet_use_temporal_attention:
214
+ d = hidden_states.shape[1]
215
+ hidden_states = rearrange(
216
+ hidden_states, "(b f) d c -> (b d) f c", f=video_length
217
+ )
218
+ norm_hidden_states = (
219
+ self.norm_temp(hidden_states, timestep)
220
+ if self.use_ada_layer_norm
221
+ else self.norm_temp(hidden_states)
222
+ )
223
+ hidden_states = (
224
+ self.attn_temp(norm_hidden_states) + hidden_states
225
+ )
226
+ hidden_states = rearrange(
227
+ hidden_states, "(b d) f c -> (b f) d c", d=d
228
+ )
229
+
230
+ return hidden_states
231
+
232
+ if self.use_ada_layer_norm_zero:
233
+ attn_output = gate_msa.unsqueeze(1) * attn_output
234
+ hidden_states = attn_output + hidden_states
235
+
236
+ if self.attn2 is not None:
237
+ norm_hidden_states = (
238
+ self.norm2(hidden_states, timestep)
239
+ if self.use_ada_layer_norm
240
+ else self.norm2(hidden_states)
241
+ )
242
+
243
+ # 2. Cross-Attention
244
+ attn_output = self.attn2(
245
+ norm_hidden_states,
246
+ encoder_hidden_states=encoder_hidden_states,
247
+ attention_mask=encoder_attention_mask,
248
+ **cross_attention_kwargs,
249
+ )
250
+ hidden_states = attn_output + hidden_states
251
+
252
+ # 3. Feed-forward
253
+ norm_hidden_states = self.norm3(hidden_states)
254
+
255
+ if self.use_ada_layer_norm_zero:
256
+ norm_hidden_states = (
257
+ norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
258
+ )
259
+
260
+ ff_output = self.ff(norm_hidden_states)
261
+
262
+ if self.use_ada_layer_norm_zero:
263
+ ff_output = gate_mlp.unsqueeze(1) * ff_output
264
+
265
+ hidden_states = ff_output + hidden_states
266
+
267
+ return hidden_states
268
+
269
+ if self.reference_attn:
270
+ if self.fusion_blocks == "midup":
271
+ attn_modules = [
272
+ module
273
+ for module in (
274
+ torch_dfs(self.unet.mid_block) + torch_dfs(self.unet.up_blocks)
275
+ )
276
+ if isinstance(module, BasicTransformerBlock)
277
+ or isinstance(module, TemporalBasicTransformerBlock)
278
+ ]
279
+ elif self.fusion_blocks == "full":
280
+ attn_modules = [
281
+ module
282
+ for module in torch_dfs(self.unet)
283
+ if isinstance(module, BasicTransformerBlock)
284
+ or isinstance(module, TemporalBasicTransformerBlock)
285
+ ]
286
+ attn_modules = sorted(
287
+ attn_modules, key=lambda x: -x.norm1.normalized_shape[0]
288
+ )
289
+
290
+ for i, module in enumerate(attn_modules):
291
+ module._original_inner_forward = module.forward
292
+ if isinstance(module, BasicTransformerBlock):
293
+ module.forward = hacked_basic_transformer_inner_forward.__get__(
294
+ module, BasicTransformerBlock
295
+ )
296
+ if isinstance(module, TemporalBasicTransformerBlock):
297
+ module.forward = hacked_basic_transformer_inner_forward.__get__(
298
+ module, TemporalBasicTransformerBlock
299
+ )
300
+
301
+ module.bank = []
302
+ module.attn_weight = float(i) / float(len(attn_modules))
303
+
304
+ def update(self, writer, dtype=torch.float16):
305
+ if self.reference_attn:
306
+ if self.fusion_blocks == "midup":
307
+ reader_attn_modules = [
308
+ module
309
+ for module in (
310
+ torch_dfs(self.unet.mid_block) + torch_dfs(self.unet.up_blocks)
311
+ )
312
+ if isinstance(module, TemporalBasicTransformerBlock)
313
+ ]
314
+ writer_attn_modules = [
315
+ module
316
+ for module in (
317
+ torch_dfs(writer.unet.mid_block)
318
+ + torch_dfs(writer.unet.up_blocks)
319
+ )
320
+ if isinstance(module, BasicTransformerBlock)
321
+ ]
322
+ elif self.fusion_blocks == "full":
323
+ reader_attn_modules = [
324
+ module
325
+ for module in torch_dfs(self.unet)
326
+ if isinstance(module, TemporalBasicTransformerBlock)
327
+ ]
328
+ writer_attn_modules = [
329
+ module
330
+ for module in torch_dfs(writer.unet)
331
+ if isinstance(module, BasicTransformerBlock)
332
+ ]
333
+ reader_attn_modules = sorted(
334
+ reader_attn_modules, key=lambda x: -x.norm1.normalized_shape[0]
335
+ )
336
+ writer_attn_modules = sorted(
337
+ writer_attn_modules, key=lambda x: -x.norm1.normalized_shape[0]
338
+ )
339
+ for r, w in zip(reader_attn_modules, writer_attn_modules):
340
+ r.bank = [v.clone().to(dtype) for v in w.bank]
341
+ # w.bank.clear()
342
+
343
+ def clear(self):
344
+ if self.reference_attn:
345
+ if self.fusion_blocks == "midup":
346
+ reader_attn_modules = [
347
+ module
348
+ for module in (
349
+ torch_dfs(self.unet.mid_block) + torch_dfs(self.unet.up_blocks)
350
+ )
351
+ if isinstance(module, BasicTransformerBlock)
352
+ or isinstance(module, TemporalBasicTransformerBlock)
353
+ ]
354
+ elif self.fusion_blocks == "full":
355
+ reader_attn_modules = [
356
+ module
357
+ for module in torch_dfs(self.unet)
358
+ if isinstance(module, BasicTransformerBlock)
359
+ or isinstance(module, TemporalBasicTransformerBlock)
360
+ ]
361
+ reader_attn_modules = sorted(
362
+ reader_attn_modules, key=lambda x: -x.norm1.normalized_shape[0]
363
+ )
364
+ for r in reader_attn_modules:
365
+ r.bank.clear()
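+ 
+ (Illustration, not part of the uploaded file.) A runnable sketch of the `torch_dfs` traversal used above to collect the transformer blocks that get hooked; the toy classes stand in for `BasicTransformerBlock` and the tiny `Sequential` stands in for a UNet, purely for illustration.
+ 
+ ```python
+ import torch
+ 
+ def torch_dfs(model: torch.nn.Module):
+     # Same traversal as above: the module itself plus all descendants, depth first.
+     result = [model]
+     for child in model.children():
+         result += torch_dfs(child)
+     return result
+ 
+ class ToyBlock(torch.nn.Module):        # stand-in for BasicTransformerBlock
+     def __init__(self, dim):
+         super().__init__()
+         self.proj = torch.nn.Linear(dim, dim)
+ 
+ toy_unet = torch.nn.Sequential(
+     torch.nn.Sequential(ToyBlock(32), torch.nn.ReLU()),
+     ToyBlock(64),
+ )
+ 
+ blocks = [m for m in torch_dfs(toy_unet) if isinstance(m, ToyBlock)]
+ assert len(blocks) == 2                 # the hooks above patch exactly these modules
+ ```
+ 
+ In the control flow implied by this file, the writer UNet runs once in `"write"` mode to fill each hooked block's `bank`, `update()` copies those banks into the reader UNet's blocks, and `clear()` empties them again after the denoising step.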
src/models/pose_guider.py ADDED
@@ -0,0 +1,57 @@
1
+ from typing import Tuple
2
+
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ import torch.nn.init as init
6
+ from diffusers.models.modeling_utils import ModelMixin
7
+
8
+ from src.models.motion_module import zero_module
9
+ from src.models.resnet import InflatedConv3d
10
+
11
+
12
+ class PoseGuider(ModelMixin):
13
+ def __init__(
14
+ self,
15
+ conditioning_embedding_channels: int,
16
+ conditioning_channels: int = 3,
17
+ block_out_channels: Tuple[int] = (16, 32, 64, 128),
18
+ ):
19
+ super().__init__()
20
+ self.conv_in = InflatedConv3d(
21
+ conditioning_channels, block_out_channels[0], kernel_size=3, padding=1
22
+ )
23
+
24
+ self.blocks = nn.ModuleList([])
25
+
26
+ for i in range(len(block_out_channels) - 1):
27
+ channel_in = block_out_channels[i]
28
+ channel_out = block_out_channels[i + 1]
29
+ self.blocks.append(
30
+ InflatedConv3d(channel_in, channel_in, kernel_size=3, padding=1)
31
+ )
32
+ self.blocks.append(
33
+ InflatedConv3d(
34
+ channel_in, channel_out, kernel_size=3, padding=1, stride=2
35
+ )
36
+ )
37
+
38
+ self.conv_out = zero_module(
39
+ InflatedConv3d(
40
+ block_out_channels[-1],
41
+ conditioning_embedding_channels,
42
+ kernel_size=3,
43
+ padding=1,
44
+ )
45
+ )
46
+
47
+ def forward(self, conditioning):
48
+ embedding = self.conv_in(conditioning)
49
+ embedding = F.silu(embedding)
50
+
51
+ for block in self.blocks:
52
+ embedding = block(embedding)
53
+ embedding = F.silu(embedding)
54
+
55
+ embedding = self.conv_out(embedding)
56
+
57
+ return embedding
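+ 
+ (Illustration, not part of the uploaded file.) A sketch of the spatial contract implied by the three stride-2 convolutions above: with the default `block_out_channels=(16, 32, 64, 128)` the pose map is downsampled by 8. The sketch uses plain `nn.Conv2d` on a frame-flattened tensor with made-up sizes (the real module uses `InflatedConv3d` on 5-D input and interleaves SiLU activations); it only checks shapes.
+ 
+ ```python
+ import torch
+ import torch.nn as nn
+ 
+ channels = (16, 32, 64, 128)             # default block_out_channels
+ layers = [nn.Conv2d(3, channels[0], 3, padding=1)]
+ for c_in, c_out in zip(channels[:-1], channels[1:]):
+     layers += [nn.Conv2d(c_in, c_in, 3, padding=1),
+                nn.Conv2d(c_in, c_out, 3, padding=1, stride=2)]   # three stride-2 convs total
+ layers += [nn.Conv2d(channels[-1], 320, 3, padding=1)]           # 320 = hypothetical embedding channels
+ 
+ x = torch.randn(2, 3, 512, 512)          # (batch*frames, RGB pose map, H, W)
+ for layer in layers:
+     x = layer(x)
+ assert x.shape == (2, 320, 64, 64)       # H and W shrink by 2**3 = 8
+ ```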
src/models/resnet.py ADDED
@@ -0,0 +1,252 @@
1
+ # Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/resnet.py
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+ from einops import rearrange
7
+
8
+
9
+ class InflatedConv3d(nn.Conv2d):
10
+ def forward(self, x):
11
+ video_length = x.shape[2]
12
+
13
+ x = rearrange(x, "b c f h w -> (b f) c h w")
14
+ x = super().forward(x)
15
+ x = rearrange(x, "(b f) c h w -> b c f h w", f=video_length)
16
+
17
+ return x
18
+
19
+
20
+ class InflatedGroupNorm(nn.GroupNorm):
21
+ def forward(self, x):
22
+ video_length = x.shape[2]
23
+
24
+ x = rearrange(x, "b c f h w -> (b f) c h w")
25
+ x = super().forward(x)
26
+ x = rearrange(x, "(b f) c h w -> b c f h w", f=video_length)
27
+
28
+ return x
29
+
30
+
31
+ class Upsample3D(nn.Module):
32
+ def __init__(
33
+ self,
34
+ channels,
35
+ use_conv=False,
36
+ use_conv_transpose=False,
37
+ out_channels=None,
38
+ name="conv",
39
+ ):
40
+ super().__init__()
41
+ self.channels = channels
42
+ self.out_channels = out_channels or channels
43
+ self.use_conv = use_conv
44
+ self.use_conv_transpose = use_conv_transpose
45
+ self.name = name
46
+
47
+ conv = None
48
+ if use_conv_transpose:
49
+ raise NotImplementedError
50
+ elif use_conv:
51
+ self.conv = InflatedConv3d(self.channels, self.out_channels, 3, padding=1)
52
+
53
+ def forward(self, hidden_states, output_size=None):
54
+ assert hidden_states.shape[1] == self.channels
55
+
56
+ if self.use_conv_transpose:
57
+ raise NotImplementedError
58
+
59
+ # Cast to float32 as the 'upsample_nearest2d_out_frame' op does not support bfloat16
60
+ dtype = hidden_states.dtype
61
+ if dtype == torch.bfloat16:
62
+ hidden_states = hidden_states.to(torch.float32)
63
+
64
+ # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
65
+ if hidden_states.shape[0] >= 64:
66
+ hidden_states = hidden_states.contiguous()
67
+
68
+ # if `output_size` is passed we force the interpolation output
69
+ # size and do not make use of `scale_factor=2`
70
+ if output_size is None:
71
+ hidden_states = F.interpolate(
72
+ hidden_states, scale_factor=[1.0, 2.0, 2.0], mode="nearest"
73
+ )
74
+ else:
75
+ hidden_states = F.interpolate(
76
+ hidden_states, size=output_size, mode="nearest"
77
+ )
78
+
79
+ # If the input is bfloat16, we cast back to bfloat16
80
+ if dtype == torch.bfloat16:
81
+ hidden_states = hidden_states.to(dtype)
82
+
83
+ # if self.use_conv:
84
+ # if self.name == "conv":
85
+ # hidden_states = self.conv(hidden_states)
86
+ # else:
87
+ # hidden_states = self.Conv2d_0(hidden_states)
88
+ hidden_states = self.conv(hidden_states)
89
+
90
+ return hidden_states
91
+
92
+
93
+ class Downsample3D(nn.Module):
94
+ def __init__(
95
+ self, channels, use_conv=False, out_channels=None, padding=1, name="conv"
96
+ ):
97
+ super().__init__()
98
+ self.channels = channels
99
+ self.out_channels = out_channels or channels
100
+ self.use_conv = use_conv
101
+ self.padding = padding
102
+ stride = 2
103
+ self.name = name
104
+
105
+ if use_conv:
106
+ self.conv = InflatedConv3d(
107
+ self.channels, self.out_channels, 3, stride=stride, padding=padding
108
+ )
109
+ else:
110
+ raise NotImplementedError
111
+
112
+ def forward(self, hidden_states):
113
+ assert hidden_states.shape[1] == self.channels
114
+ if self.use_conv and self.padding == 0:
115
+ raise NotImplementedError
116
+
117
+ assert hidden_states.shape[1] == self.channels
118
+ hidden_states = self.conv(hidden_states)
119
+
120
+ return hidden_states
121
+
122
+
123
+ class ResnetBlock3D(nn.Module):
124
+ def __init__(
125
+ self,
126
+ *,
127
+ in_channels,
128
+ out_channels=None,
129
+ conv_shortcut=False,
130
+ dropout=0.0,
131
+ temb_channels=512,
132
+ groups=32,
133
+ groups_out=None,
134
+ pre_norm=True,
135
+ eps=1e-6,
136
+ non_linearity="swish",
137
+ time_embedding_norm="default",
138
+ output_scale_factor=1.0,
139
+ use_in_shortcut=None,
140
+ use_inflated_groupnorm=None,
141
+ ):
142
+ super().__init__()
143
+ self.pre_norm = pre_norm
144
+ self.pre_norm = True
145
+ self.in_channels = in_channels
146
+ out_channels = in_channels if out_channels is None else out_channels
147
+ self.out_channels = out_channels
148
+ self.use_conv_shortcut = conv_shortcut
149
+ self.time_embedding_norm = time_embedding_norm
150
+ self.output_scale_factor = output_scale_factor
151
+
152
+ if groups_out is None:
153
+ groups_out = groups
154
+
155
+ assert use_inflated_groupnorm is not None
156
+ if use_inflated_groupnorm:
157
+ self.norm1 = InflatedGroupNorm(
158
+ num_groups=groups, num_channels=in_channels, eps=eps, affine=True
159
+ )
160
+ else:
161
+ self.norm1 = torch.nn.GroupNorm(
162
+ num_groups=groups, num_channels=in_channels, eps=eps, affine=True
163
+ )
164
+
165
+ self.conv1 = InflatedConv3d(
166
+ in_channels, out_channels, kernel_size=3, stride=1, padding=1
167
+ )
168
+
169
+ if temb_channels is not None:
170
+ if self.time_embedding_norm == "default":
171
+ time_emb_proj_out_channels = out_channels
172
+ elif self.time_embedding_norm == "scale_shift":
173
+ time_emb_proj_out_channels = out_channels * 2
174
+ else:
175
+ raise ValueError(
176
+ f"unknown time_embedding_norm : {self.time_embedding_norm} "
177
+ )
178
+
179
+ self.time_emb_proj = torch.nn.Linear(
180
+ temb_channels, time_emb_proj_out_channels
181
+ )
182
+ else:
183
+ self.time_emb_proj = None
184
+
185
+ if use_inflated_groupnorm:
186
+ self.norm2 = InflatedGroupNorm(
187
+ num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True
188
+ )
189
+ else:
190
+ self.norm2 = torch.nn.GroupNorm(
191
+ num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True
192
+ )
193
+ self.dropout = torch.nn.Dropout(dropout)
194
+ self.conv2 = InflatedConv3d(
195
+ out_channels, out_channels, kernel_size=3, stride=1, padding=1
196
+ )
197
+
198
+ if non_linearity == "swish":
199
+ self.nonlinearity = lambda x: F.silu(x)
200
+ elif non_linearity == "mish":
201
+ self.nonlinearity = Mish()
202
+ elif non_linearity == "silu":
203
+ self.nonlinearity = nn.SiLU()
204
+
205
+ self.use_in_shortcut = (
206
+ self.in_channels != self.out_channels
207
+ if use_in_shortcut is None
208
+ else use_in_shortcut
209
+ )
210
+
211
+ self.conv_shortcut = None
212
+ if self.use_in_shortcut:
213
+ self.conv_shortcut = InflatedConv3d(
214
+ in_channels, out_channels, kernel_size=1, stride=1, padding=0
215
+ )
216
+
217
+ def forward(self, input_tensor, temb):
218
+ hidden_states = input_tensor
219
+
220
+ hidden_states = self.norm1(hidden_states)
221
+ hidden_states = self.nonlinearity(hidden_states)
222
+
223
+ hidden_states = self.conv1(hidden_states)
224
+
225
+ if temb is not None:
226
+ temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None, None]
227
+
228
+ if temb is not None and self.time_embedding_norm == "default":
229
+ hidden_states = hidden_states + temb
230
+
231
+ hidden_states = self.norm2(hidden_states)
232
+
233
+ if temb is not None and self.time_embedding_norm == "scale_shift":
234
+ scale, shift = torch.chunk(temb, 2, dim=1)
235
+ hidden_states = hidden_states * (1 + scale) + shift
236
+
237
+ hidden_states = self.nonlinearity(hidden_states)
238
+
239
+ hidden_states = self.dropout(hidden_states)
240
+ hidden_states = self.conv2(hidden_states)
241
+
242
+ if self.conv_shortcut is not None:
243
+ input_tensor = self.conv_shortcut(input_tensor)
244
+
245
+ output_tensor = (input_tensor + hidden_states) / self.output_scale_factor
246
+
247
+ return output_tensor
248
+
249
+
250
+ class Mish(torch.nn.Module):
251
+ def forward(self, hidden_states):
252
+ return hidden_states * torch.tanh(torch.nn.functional.softplus(hidden_states))
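+ 
+ (Illustration, not part of the uploaded file.) The inflated layers above reuse 2D weights on video by folding frames into the batch. A self-contained sketch with hypothetical sizes (the small class is duplicated here so the snippet runs on its own):
+ 
+ ```python
+ import torch
+ import torch.nn as nn
+ from einops import rearrange
+ 
+ class InflatedConv3d(nn.Conv2d):
+     # Same idea as above: apply a 2D convolution to every frame independently.
+     def forward(self, x):
+         f = x.shape[2]
+         x = rearrange(x, "b c f h w -> (b f) c h w")
+         x = super().forward(x)
+         return rearrange(x, "(b f) c h w -> b c f h w", f=f)
+ 
+ conv = InflatedConv3d(4, 8, kernel_size=3, padding=1)
+ video = torch.randn(2, 4, 16, 32, 32)         # (batch, channels, frames, H, W)
+ out = conv(video)
+ assert out.shape == (2, 8, 16, 32, 32)        # spatial conv applied frame by frame
+ ```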
src/models/transformer_2d.py ADDED
@@ -0,0 +1,396 @@
1
+ # Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/transformer_2d.py
2
+ from dataclasses import dataclass
3
+ from typing import Any, Dict, Optional
4
+
5
+ import torch
6
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
7
+ from diffusers.models.embeddings import CaptionProjection
8
+ from diffusers.models.lora import LoRACompatibleConv, LoRACompatibleLinear
9
+ from diffusers.models.modeling_utils import ModelMixin
10
+ from diffusers.models.normalization import AdaLayerNormSingle
11
+ from diffusers.utils import USE_PEFT_BACKEND, BaseOutput, deprecate, is_torch_version
12
+ from torch import nn
13
+
14
+ from .attention import BasicTransformerBlock
15
+
16
+
17
+ @dataclass
18
+ class Transformer2DModelOutput(BaseOutput):
19
+ """
20
+ The output of [`Transformer2DModel`].
21
+
22
+ Args:
23
+ sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` or `(batch size, num_vector_embeds - 1, num_latent_pixels)` if [`Transformer2DModel`] is discrete):
24
+ The hidden states output conditioned on the `encoder_hidden_states` input. If discrete, returns probability
25
+ distributions for the unnoised latent pixels.
26
+ """
27
+
28
+ sample: torch.FloatTensor
29
+ ref_feature: torch.FloatTensor
30
+
31
+
32
+ class Transformer2DModel(ModelMixin, ConfigMixin):
33
+ """
34
+ A 2D Transformer model for image-like data.
35
+
36
+ Parameters:
37
+ num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
38
+ attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
39
+ in_channels (`int`, *optional*):
40
+ The number of channels in the input and output (specify if the input is **continuous**).
41
+ num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
42
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
43
+ cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
44
+ sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**).
45
+ This is fixed during training since it is used to learn a number of position embeddings.
46
+ num_vector_embeds (`int`, *optional*):
47
+ The number of classes of the vector embeddings of the latent pixels (specify if the input is **discrete**).
48
+ Includes the class for the masked latent pixel.
49
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to use in feed-forward.
50
+ num_embeds_ada_norm ( `int`, *optional*):
51
+ The number of diffusion steps used during training. Pass if at least one of the norm_layers is
52
+ `AdaLayerNorm`. This is fixed during training since it is used to learn a number of embeddings that are
53
+ added to the hidden states.
54
+
55
+ During inference, you can denoise for up to but not more steps than `num_embeds_ada_norm`.
56
+ attention_bias (`bool`, *optional*):
57
+ Configure if the `TransformerBlocks` attention should contain a bias parameter.
58
+ """
59
+
60
+ _supports_gradient_checkpointing = True
61
+
62
+ @register_to_config
63
+ def __init__(
64
+ self,
65
+ num_attention_heads: int = 16,
66
+ attention_head_dim: int = 88,
67
+ in_channels: Optional[int] = None,
68
+ out_channels: Optional[int] = None,
69
+ num_layers: int = 1,
70
+ dropout: float = 0.0,
71
+ norm_num_groups: int = 32,
72
+ cross_attention_dim: Optional[int] = None,
73
+ attention_bias: bool = False,
74
+ sample_size: Optional[int] = None,
75
+ num_vector_embeds: Optional[int] = None,
76
+ patch_size: Optional[int] = None,
77
+ activation_fn: str = "geglu",
78
+ num_embeds_ada_norm: Optional[int] = None,
79
+ use_linear_projection: bool = False,
80
+ only_cross_attention: bool = False,
81
+ double_self_attention: bool = False,
82
+ upcast_attention: bool = False,
83
+ norm_type: str = "layer_norm",
84
+ norm_elementwise_affine: bool = True,
85
+ norm_eps: float = 1e-5,
86
+ attention_type: str = "default",
87
+ caption_channels: int = None,
88
+ ):
89
+ super().__init__()
90
+ self.use_linear_projection = use_linear_projection
91
+ self.num_attention_heads = num_attention_heads
92
+ self.attention_head_dim = attention_head_dim
93
+ inner_dim = num_attention_heads * attention_head_dim
94
+
95
+ conv_cls = nn.Conv2d if USE_PEFT_BACKEND else LoRACompatibleConv
96
+ linear_cls = nn.Linear if USE_PEFT_BACKEND else LoRACompatibleLinear
97
+
98
+ # 1. Transformer2DModel can process both standard continuous images of shape `(batch_size, num_channels, width, height)` as well as quantized image embeddings of shape `(batch_size, num_image_vectors)`
99
+ # Define whether input is continuous or discrete depending on configuration
100
+ self.is_input_continuous = (in_channels is not None) and (patch_size is None)
101
+ self.is_input_vectorized = num_vector_embeds is not None
102
+ self.is_input_patches = in_channels is not None and patch_size is not None
103
+
104
+ if norm_type == "layer_norm" and num_embeds_ada_norm is not None:
105
+ deprecation_message = (
106
+ f"The configuration file of this model: {self.__class__} is outdated. `norm_type` is either not set or"
107
+ " incorrectly set to `'layer_norm'`.Make sure to set `norm_type` to `'ada_norm'` in the config."
108
+ " Please make sure to update the config accordingly as leaving `norm_type` might led to incorrect"
109
+ " results in future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it"
110
+ " would be very nice if you could open a Pull request for the `transformer/config.json` file"
111
+ )
112
+ deprecate(
113
+ "norm_type!=num_embeds_ada_norm",
114
+ "1.0.0",
115
+ deprecation_message,
116
+ standard_warn=False,
117
+ )
118
+ norm_type = "ada_norm"
119
+
120
+ if self.is_input_continuous and self.is_input_vectorized:
121
+ raise ValueError(
122
+ f"Cannot define both `in_channels`: {in_channels} and `num_vector_embeds`: {num_vector_embeds}. Make"
123
+ " sure that either `in_channels` or `num_vector_embeds` is None."
124
+ )
125
+ elif self.is_input_vectorized and self.is_input_patches:
126
+ raise ValueError(
127
+ f"Cannot define both `num_vector_embeds`: {num_vector_embeds} and `patch_size`: {patch_size}. Make"
128
+ " sure that either `num_vector_embeds` or `num_patches` is None."
129
+ )
130
+ elif (
131
+ not self.is_input_continuous
132
+ and not self.is_input_vectorized
133
+ and not self.is_input_patches
134
+ ):
135
+ raise ValueError(
136
+ f"Has to define `in_channels`: {in_channels}, `num_vector_embeds`: {num_vector_embeds}, or patch_size:"
137
+ f" {patch_size}. Make sure that `in_channels`, `num_vector_embeds` or `num_patches` is not None."
138
+ )
139
+
140
+ # 2. Define input layers
141
+ self.in_channels = in_channels
142
+
143
+ self.norm = torch.nn.GroupNorm(
144
+ num_groups=norm_num_groups,
145
+ num_channels=in_channels,
146
+ eps=1e-6,
147
+ affine=True,
148
+ )
149
+ if use_linear_projection:
150
+ self.proj_in = linear_cls(in_channels, inner_dim)
151
+ else:
152
+ self.proj_in = conv_cls(
153
+ in_channels, inner_dim, kernel_size=1, stride=1, padding=0
154
+ )
155
+
156
+ # 3. Define transformers blocks
157
+ self.transformer_blocks = nn.ModuleList(
158
+ [
159
+ BasicTransformerBlock(
160
+ inner_dim,
161
+ num_attention_heads,
162
+ attention_head_dim,
163
+ dropout=dropout,
164
+ cross_attention_dim=cross_attention_dim,
165
+ activation_fn=activation_fn,
166
+ num_embeds_ada_norm=num_embeds_ada_norm,
167
+ attention_bias=attention_bias,
168
+ only_cross_attention=only_cross_attention,
169
+ double_self_attention=double_self_attention,
170
+ upcast_attention=upcast_attention,
171
+ norm_type=norm_type,
172
+ norm_elementwise_affine=norm_elementwise_affine,
173
+ norm_eps=norm_eps,
174
+ attention_type=attention_type,
175
+ )
176
+ for d in range(num_layers)
177
+ ]
178
+ )
179
+
180
+ # 4. Define output layers
181
+ self.out_channels = in_channels if out_channels is None else out_channels
182
+ # TODO: should use out_channels for continuous projections
183
+ if use_linear_projection:
184
+ self.proj_out = linear_cls(inner_dim, in_channels)
185
+ else:
186
+ self.proj_out = conv_cls(
187
+ inner_dim, in_channels, kernel_size=1, stride=1, padding=0
188
+ )
189
+
190
+ # 5. PixArt-Alpha blocks.
191
+ self.adaln_single = None
192
+ self.use_additional_conditions = False
193
+ if norm_type == "ada_norm_single":
194
+ self.use_additional_conditions = self.config.sample_size == 128
195
+ # TODO(Sayak, PVP) clean this, for now we use sample size to determine whether to use
196
+ # additional conditions until we find better name
197
+ self.adaln_single = AdaLayerNormSingle(
198
+ inner_dim, use_additional_conditions=self.use_additional_conditions
199
+ )
200
+
201
+ self.caption_projection = None
202
+ if caption_channels is not None:
203
+ self.caption_projection = CaptionProjection(
204
+ in_features=caption_channels, hidden_size=inner_dim
205
+ )
206
+
207
+ self.gradient_checkpointing = False
208
+
209
+ def _set_gradient_checkpointing(self, module, value=False):
210
+ if hasattr(module, "gradient_checkpointing"):
211
+ module.gradient_checkpointing = value
212
+
213
+ def forward(
214
+ self,
215
+ hidden_states: torch.Tensor,
216
+ encoder_hidden_states: Optional[torch.Tensor] = None,
217
+ timestep: Optional[torch.LongTensor] = None,
218
+ added_cond_kwargs: Dict[str, torch.Tensor] = None,
219
+ class_labels: Optional[torch.LongTensor] = None,
220
+ cross_attention_kwargs: Dict[str, Any] = None,
221
+ attention_mask: Optional[torch.Tensor] = None,
222
+ encoder_attention_mask: Optional[torch.Tensor] = None,
223
+ return_dict: bool = True,
224
+ ):
225
+ """
226
+ The [`Transformer2DModel`] forward method.
227
+
228
+ Args:
229
+ hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.FloatTensor` of shape `(batch size, channel, height, width)` if continuous):
230
+ Input `hidden_states`.
231
+ encoder_hidden_states ( `torch.FloatTensor` of shape `(batch size, sequence len, embed dims)`, *optional*):
232
+ Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
233
+ self-attention.
234
+ timestep ( `torch.LongTensor`, *optional*):
235
+ Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`.
236
+ class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*):
237
+ Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in
238
+ `AdaLayerZeroNorm`.
239
+ cross_attention_kwargs ( `Dict[str, Any]`, *optional*):
240
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
241
+ `self.processor` in
242
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
243
+ attention_mask ( `torch.Tensor`, *optional*):
244
+ An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
245
+ is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
246
+ negative values to the attention scores corresponding to "discard" tokens.
247
+ encoder_attention_mask ( `torch.Tensor`, *optional*):
248
+ Cross-attention mask applied to `encoder_hidden_states`. Two formats supported:
249
+
250
+ * Mask `(batch, sequence_length)` True = keep, False = discard.
251
+ * Bias `(batch, 1, sequence_length)` 0 = keep, -10000 = discard.
252
+
253
+ If `ndim == 2`: will be interpreted as a mask, then converted into a bias consistent with the format
254
+ above. This bias will be added to the cross-attention scores.
255
+ return_dict (`bool`, *optional*, defaults to `True`):
256
+ Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
257
+ tuple.
258
+
259
+ Returns:
260
+ If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
261
+ `tuple` where the first element is the sample tensor.
262
+ """
263
+ # ensure attention_mask is a bias, and give it a singleton query_tokens dimension.
264
+ # we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward.
265
+ # we can tell by counting dims; if ndim == 2: it's a mask rather than a bias.
266
+ # expects mask of shape:
267
+ # [batch, key_tokens]
268
+ # adds singleton query_tokens dimension:
269
+ # [batch, 1, key_tokens]
270
+ # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
271
+ # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
272
+ # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
273
+ if attention_mask is not None and attention_mask.ndim == 2:
274
+ # assume that mask is expressed as:
275
+ # (1 = keep, 0 = discard)
276
+ # convert mask into a bias that can be added to attention scores:
277
+ # (keep = +0, discard = -10000.0)
278
+ attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0
279
+ attention_mask = attention_mask.unsqueeze(1)
280
+
281
+ # convert encoder_attention_mask to a bias the same way we do for attention_mask
282
+ if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
283
+ encoder_attention_mask = (
284
+ 1 - encoder_attention_mask.to(hidden_states.dtype)
285
+ ) * -10000.0
286
+ encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
287
+
288
+ # Retrieve lora scale.
289
+ lora_scale = (
290
+ cross_attention_kwargs.get("scale", 1.0)
291
+ if cross_attention_kwargs is not None
292
+ else 1.0
293
+ )
294
+
295
+ # 1. Input
296
+ batch, _, height, width = hidden_states.shape
297
+ residual = hidden_states
298
+
299
+ hidden_states = self.norm(hidden_states)
300
+ if not self.use_linear_projection:
301
+ hidden_states = (
302
+ self.proj_in(hidden_states, scale=lora_scale)
303
+ if not USE_PEFT_BACKEND
304
+ else self.proj_in(hidden_states)
305
+ )
306
+ inner_dim = hidden_states.shape[1]
307
+ hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(
308
+ batch, height * width, inner_dim
309
+ )
310
+ else:
311
+ inner_dim = hidden_states.shape[1]
312
+ hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(
313
+ batch, height * width, inner_dim
314
+ )
315
+ hidden_states = (
316
+ self.proj_in(hidden_states, scale=lora_scale)
317
+ if not USE_PEFT_BACKEND
318
+ else self.proj_in(hidden_states)
319
+ )
320
+
321
+ # 2. Blocks
322
+ if self.caption_projection is not None:
323
+ batch_size = hidden_states.shape[0]
324
+ encoder_hidden_states = self.caption_projection(encoder_hidden_states)
325
+ encoder_hidden_states = encoder_hidden_states.view(
326
+ batch_size, -1, hidden_states.shape[-1]
327
+ )
328
+
329
+ ref_feature = hidden_states.reshape(batch, height, width, inner_dim)
330
+ for block in self.transformer_blocks:
331
+ if self.training and self.gradient_checkpointing:
332
+
333
+ def create_custom_forward(module, return_dict=None):
334
+ def custom_forward(*inputs):
335
+ if return_dict is not None:
336
+ return module(*inputs, return_dict=return_dict)
337
+ else:
338
+ return module(*inputs)
339
+
340
+ return custom_forward
341
+
342
+ ckpt_kwargs: Dict[str, Any] = (
343
+ {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
344
+ )
345
+ hidden_states = torch.utils.checkpoint.checkpoint(
346
+ create_custom_forward(block),
347
+ hidden_states,
348
+ attention_mask,
349
+ encoder_hidden_states,
350
+ encoder_attention_mask,
351
+ timestep,
352
+ cross_attention_kwargs,
353
+ class_labels,
354
+ **ckpt_kwargs,
355
+ )
356
+ else:
357
+ hidden_states = block(
358
+ hidden_states,
359
+ attention_mask=attention_mask,
360
+ encoder_hidden_states=encoder_hidden_states,
361
+ encoder_attention_mask=encoder_attention_mask,
362
+ timestep=timestep,
363
+ cross_attention_kwargs=cross_attention_kwargs,
364
+ class_labels=class_labels,
365
+ )
366
+
367
+ # 3. Output
368
+ if self.is_input_continuous:
369
+ if not self.use_linear_projection:
370
+ hidden_states = (
371
+ hidden_states.reshape(batch, height, width, inner_dim)
372
+ .permute(0, 3, 1, 2)
373
+ .contiguous()
374
+ )
375
+ hidden_states = (
376
+ self.proj_out(hidden_states, scale=lora_scale)
377
+ if not USE_PEFT_BACKEND
378
+ else self.proj_out(hidden_states)
379
+ )
380
+ else:
381
+ hidden_states = (
382
+ self.proj_out(hidden_states, scale=lora_scale)
383
+ if not USE_PEFT_BACKEND
384
+ else self.proj_out(hidden_states)
385
+ )
386
+ hidden_states = (
387
+ hidden_states.reshape(batch, height, width, inner_dim)
388
+ .permute(0, 3, 1, 2)
389
+ .contiguous()
390
+ )
391
+
392
+ output = hidden_states + residual
393
+ if not return_dict:
394
+ return (output, ref_feature)
395
+
396
+ return Transformer2DModelOutput(sample=output, ref_feature=ref_feature)
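+ 
+ (Illustration, not part of the uploaded file.) A quick sketch, with a toy mask and made-up sizes, of the mask-to-bias conversion performed at the top of `forward`: a keep/discard mask becomes an additive bias of 0 / -10000 with a singleton query dimension, so discarded keys receive near-zero attention.
+ 
+ ```python
+ import torch
+ 
+ attention_mask = torch.tensor([[1, 1, 0, 0],
+                                [1, 1, 1, 0]])               # (batch, key_tokens), 1 = keep
+ bias = (1 - attention_mask.to(torch.float32)) * -10000.0    # keep -> 0, discard -> -10000
+ bias = bias.unsqueeze(1)                                     # (batch, 1, key_tokens)
+ 
+ scores = torch.zeros(2, 3, 4)                                # (batch, query_tokens, key_tokens)
+ masked = torch.softmax(scores + bias, dim=-1)
+ assert masked[0, 0, 2] < 1e-4                                # discarded keys get ~0 probability
+ ```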
src/models/transformer_3d.py ADDED
@@ -0,0 +1,202 @@
1
+ from dataclasses import dataclass
2
+ from typing import Optional
3
+
4
+ import torch
5
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
6
+ from diffusers.models import ModelMixin
7
+ from diffusers.utils import BaseOutput
8
+ from diffusers.utils.import_utils import is_xformers_available
9
+ from einops import rearrange, repeat
10
+ from torch import nn
11
+
12
+ from .attention import TemporalBasicTransformerBlock
13
+
14
+
15
+ @dataclass
16
+ class Transformer3DModelOutput(BaseOutput):
17
+ sample: torch.FloatTensor
18
+
19
+
20
+ if is_xformers_available():
21
+ import xformers
22
+ import xformers.ops
23
+ else:
24
+ xformers = None
25
+
26
+
27
+ class Transformer3DModel(ModelMixin, ConfigMixin):
28
+ _supports_gradient_checkpointing = True
29
+
30
+ @register_to_config
31
+ def __init__(
32
+ self,
33
+ num_attention_heads: int = 16,
34
+ attention_head_dim: int = 88,
35
+ in_channels: Optional[int] = None,
36
+ num_layers: int = 1,
37
+ dropout: float = 0.0,
38
+ norm_num_groups: int = 32,
39
+ cross_attention_dim: Optional[int] = None,
40
+ attention_bias: bool = False,
41
+ activation_fn: str = "geglu",
42
+ num_embeds_ada_norm: Optional[int] = None,
43
+ use_linear_projection: bool = False,
44
+ only_cross_attention: bool = False,
45
+ upcast_attention: bool = False,
46
+ unet_use_cross_frame_attention=None,
47
+ unet_use_temporal_attention=None,
48
+ name=None,
49
+ ):
50
+ super().__init__()
51
+ self.use_linear_projection = use_linear_projection
52
+ self.num_attention_heads = num_attention_heads
53
+ self.attention_head_dim = attention_head_dim
54
+ inner_dim = num_attention_heads * attention_head_dim
55
+
56
+ # Define input layers
57
+ self.in_channels = in_channels
58
+ self.name = name
59
+
60
+ self.norm = torch.nn.GroupNorm(
61
+ num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True
62
+ )
63
+ if use_linear_projection:
64
+ self.proj_in = nn.Linear(in_channels, inner_dim)
65
+ else:
66
+ self.proj_in = nn.Conv2d(
67
+ in_channels, inner_dim, kernel_size=1, stride=1, padding=0
68
+ )
69
+
70
+ # Define transformers blocks
71
+ self.transformer_blocks = nn.ModuleList(
72
+ [
73
+ TemporalBasicTransformerBlock(
74
+ inner_dim,
75
+ num_attention_heads,
76
+ attention_head_dim,
77
+ dropout=dropout,
78
+ cross_attention_dim=cross_attention_dim,
79
+ activation_fn=activation_fn,
80
+ num_embeds_ada_norm=num_embeds_ada_norm,
81
+ attention_bias=attention_bias,
82
+ only_cross_attention=only_cross_attention,
83
+ upcast_attention=upcast_attention,
84
+ unet_use_cross_frame_attention=unet_use_cross_frame_attention,
85
+ unet_use_temporal_attention=unet_use_temporal_attention,
86
+ name=f"{self.name}_{d}_TransformerBlock" if self.name else None,
87
+ )
88
+ for d in range(num_layers)
89
+ ]
90
+ )
91
+
92
+ # 4. Define output layers
93
+ if use_linear_projection:
94
+ self.proj_out = nn.Linear(in_channels, inner_dim)
95
+ else:
96
+ self.proj_out = nn.Conv2d(
97
+ inner_dim, in_channels, kernel_size=1, stride=1, padding=0
98
+ )
99
+
100
+ self.gradient_checkpointing = False
101
+
102
+ def _set_gradient_checkpointing(self, module, value=False):
103
+ if hasattr(module, "gradient_checkpointing"):
104
+ module.gradient_checkpointing = value
105
+
106
+ def forward(
107
+ self,
108
+ hidden_states,
109
+ encoder_hidden_states=None,
110
+ self_attention_additional_feats=None,
111
+ mode=None,
112
+ timestep=None,
113
+ return_dict: bool = True,
114
+ ):
115
+ # Input
116
+ assert (
117
+ hidden_states.dim() == 5
118
+ ), f"Expected hidden_states to have ndim=5, but got ndim={hidden_states.dim()}."
119
+ video_length = hidden_states.shape[2]
120
+ hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w")
121
+ if encoder_hidden_states.shape[0] != hidden_states.shape[0]:
122
+ encoder_hidden_states = repeat(
123
+ encoder_hidden_states, "b n c -> (b f) n c", f=video_length
124
+ )
125
+
126
+ batch, channel, height, weight = hidden_states.shape
127
+ residual = hidden_states
128
+
129
+ hidden_states = self.norm(hidden_states)
130
+ if not self.use_linear_projection:
131
+ hidden_states = self.proj_in(hidden_states)
132
+ inner_dim = hidden_states.shape[1]
133
+ hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(
134
+ batch, height * weight, inner_dim
135
+ )
136
+ else:
137
+ inner_dim = hidden_states.shape[1]
138
+ hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(
139
+ batch, height * weight, inner_dim
140
+ )
141
+ hidden_states = self.proj_in(hidden_states)
142
+
143
+ # Blocks
144
+ for i, block in enumerate(self.transformer_blocks):
145
+
146
+ if self.training and self.gradient_checkpointing:
147
+
148
+ def create_custom_forward(module, return_dict=None):
149
+ def custom_forward(*inputs):
150
+ if return_dict is not None:
151
+ return module(*inputs, return_dict=return_dict)
152
+ else:
153
+ return module(*inputs)
154
+
155
+ return custom_forward
156
+
157
+ # if hasattr(self.block, 'bank') and len(self.block.bank) > 0:
158
+ # hidden_states
159
+ hidden_states = torch.utils.checkpoint.checkpoint(
160
+ create_custom_forward(block),
161
+ hidden_states,
162
+ encoder_hidden_states=encoder_hidden_states,
163
+ timestep=timestep,
164
+ attention_mask=None,
165
+ video_length=video_length,
166
+ self_attention_additional_feats=self_attention_additional_feats,
167
+ mode=mode,
168
+ )
169
+ else:
170
+
171
+ hidden_states = block(
172
+ hidden_states,
173
+ encoder_hidden_states=encoder_hidden_states,
174
+ timestep=timestep,
175
+ self_attention_additional_feats=self_attention_additional_feats,
176
+ mode=mode,
177
+ video_length=video_length,
178
+ )
179
+
180
+ # Output
181
+ if not self.use_linear_projection:
182
+ hidden_states = (
183
+ hidden_states.reshape(batch, height, weight, inner_dim)
184
+ .permute(0, 3, 1, 2)
185
+ .contiguous()
186
+ )
187
+ hidden_states = self.proj_out(hidden_states)
188
+ else:
189
+ hidden_states = self.proj_out(hidden_states)
190
+ hidden_states = (
191
+ hidden_states.reshape(batch, height, weight, inner_dim)
192
+ .permute(0, 3, 1, 2)
193
+ .contiguous()
194
+ )
195
+
196
+ output = hidden_states + residual
197
+
198
+ output = rearrange(output, "(b f) c h w -> b c f h w", f=video_length)
199
+ if not return_dict:
200
+ return (output,)
201
+
202
+ return Transformer3DModelOutput(sample=output)
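+ 
+ (Illustration, not part of the uploaded file.) To make the shape handling in `forward` concrete, a sketch with hypothetical sizes of folding frames into the batch and repeating the conditioning embeddings once per frame, mirroring the `rearrange`/`repeat` calls above.
+ 
+ ```python
+ import torch
+ from einops import rearrange, repeat
+ 
+ b, c, f, h, w = 1, 320, 16, 32, 32
+ hidden_states = torch.randn(b, c, f, h, w)
+ encoder_hidden_states = torch.randn(b, 77, 768)          # (batch, tokens, dim), e.g. text embeddings
+ 
+ flat = rearrange(hidden_states, "b c f h w -> (b f) c h w")
+ cond = repeat(encoder_hidden_states, "b n c -> (b f) n c", f=f)
+ assert flat.shape == (b * f, c, h, w)
+ assert cond.shape == (b * f, 77, 768)                    # one copy of the conditioning per frame
+ 
+ video = rearrange(flat, "(b f) c h w -> b c f h w", f=f)
+ assert video.shape == hidden_states.shape                # the inverse rearrange restores the 5-D layout
+ ```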
src/models/unet_2d_blocks.py ADDED
@@ -0,0 +1,1074 @@
1
+ # Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/unet_2d_blocks.py
2
+ from typing import Any, Dict, Optional, Tuple, Union
3
+
4
+ import numpy as np
5
+ import torch
6
+ import torch.nn.functional as F
7
+ from diffusers.models.activations import get_activation
8
+ from diffusers.models.attention_processor import Attention
9
+ from diffusers.models.dual_transformer_2d import DualTransformer2DModel
10
+ from diffusers.models.resnet import Downsample2D, ResnetBlock2D, Upsample2D
11
+ from diffusers.utils import is_torch_version, logging
12
+ from diffusers.utils.torch_utils import apply_freeu
13
+ from torch import nn
14
+
15
+ from .transformer_2d import Transformer2DModel
16
+
17
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
18
+
19
+
20
+ def get_down_block(
21
+ down_block_type: str,
22
+ num_layers: int,
23
+ in_channels: int,
24
+ out_channels: int,
25
+ temb_channels: int,
26
+ add_downsample: bool,
27
+ resnet_eps: float,
28
+ resnet_act_fn: str,
29
+ transformer_layers_per_block: int = 1,
30
+ num_attention_heads: Optional[int] = None,
31
+ resnet_groups: Optional[int] = None,
32
+ cross_attention_dim: Optional[int] = None,
33
+ downsample_padding: Optional[int] = None,
34
+ dual_cross_attention: bool = False,
35
+ use_linear_projection: bool = False,
36
+ only_cross_attention: bool = False,
37
+ upcast_attention: bool = False,
38
+ resnet_time_scale_shift: str = "default",
39
+ attention_type: str = "default",
40
+ resnet_skip_time_act: bool = False,
41
+ resnet_out_scale_factor: float = 1.0,
42
+ cross_attention_norm: Optional[str] = None,
43
+ attention_head_dim: Optional[int] = None,
44
+ downsample_type: Optional[str] = None,
45
+ dropout: float = 0.0,
46
+ ):
47
+ # If attn head dim is not defined, we default it to the number of heads
48
+ if attention_head_dim is None:
49
+ logger.warn(
50
+ f"It is recommended to provide `attention_head_dim` when calling `get_down_block`. Defaulting `attention_head_dim` to {num_attention_heads}."
51
+ )
52
+ attention_head_dim = num_attention_heads
53
+
54
+ down_block_type = (
55
+ down_block_type[7:]
56
+ if down_block_type.startswith("UNetRes")
57
+ else down_block_type
58
+ )
59
+ if down_block_type == "DownBlock2D":
60
+ return DownBlock2D(
61
+ num_layers=num_layers,
62
+ in_channels=in_channels,
63
+ out_channels=out_channels,
64
+ temb_channels=temb_channels,
65
+ dropout=dropout,
66
+ add_downsample=add_downsample,
67
+ resnet_eps=resnet_eps,
68
+ resnet_act_fn=resnet_act_fn,
69
+ resnet_groups=resnet_groups,
70
+ downsample_padding=downsample_padding,
71
+ resnet_time_scale_shift=resnet_time_scale_shift,
72
+ )
73
+ elif down_block_type == "CrossAttnDownBlock2D":
74
+ if cross_attention_dim is None:
75
+ raise ValueError(
76
+ "cross_attention_dim must be specified for CrossAttnDownBlock2D"
77
+ )
78
+ return CrossAttnDownBlock2D(
79
+ num_layers=num_layers,
80
+ transformer_layers_per_block=transformer_layers_per_block,
81
+ in_channels=in_channels,
82
+ out_channels=out_channels,
83
+ temb_channels=temb_channels,
84
+ dropout=dropout,
85
+ add_downsample=add_downsample,
86
+ resnet_eps=resnet_eps,
87
+ resnet_act_fn=resnet_act_fn,
88
+ resnet_groups=resnet_groups,
89
+ downsample_padding=downsample_padding,
90
+ cross_attention_dim=cross_attention_dim,
91
+ num_attention_heads=num_attention_heads,
92
+ dual_cross_attention=dual_cross_attention,
93
+ use_linear_projection=use_linear_projection,
94
+ only_cross_attention=only_cross_attention,
95
+ upcast_attention=upcast_attention,
96
+ resnet_time_scale_shift=resnet_time_scale_shift,
97
+ attention_type=attention_type,
98
+ )
99
+ raise ValueError(f"{down_block_type} does not exist.")
100
+
101
+
102
+ def get_up_block(
103
+ up_block_type: str,
104
+ num_layers: int,
105
+ in_channels: int,
106
+ out_channels: int,
107
+ prev_output_channel: int,
108
+ temb_channels: int,
109
+ add_upsample: bool,
110
+ resnet_eps: float,
111
+ resnet_act_fn: str,
112
+ resolution_idx: Optional[int] = None,
113
+ transformer_layers_per_block: int = 1,
114
+ num_attention_heads: Optional[int] = None,
115
+ resnet_groups: Optional[int] = None,
116
+ cross_attention_dim: Optional[int] = None,
117
+ dual_cross_attention: bool = False,
118
+ use_linear_projection: bool = False,
119
+ only_cross_attention: bool = False,
120
+ upcast_attention: bool = False,
121
+ resnet_time_scale_shift: str = "default",
122
+ attention_type: str = "default",
123
+ resnet_skip_time_act: bool = False,
124
+ resnet_out_scale_factor: float = 1.0,
125
+ cross_attention_norm: Optional[str] = None,
126
+ attention_head_dim: Optional[int] = None,
127
+ upsample_type: Optional[str] = None,
128
+ dropout: float = 0.0,
129
+ ) -> nn.Module:
130
+ # If attn head dim is not defined, we default it to the number of heads
131
+ if attention_head_dim is None:
132
+ logger.warn(
133
+ f"It is recommended to provide `attention_head_dim` when calling `get_up_block`. Defaulting `attention_head_dim` to {num_attention_heads}."
134
+ )
135
+ attention_head_dim = num_attention_heads
136
+
137
+ up_block_type = (
138
+ up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type
139
+ )
140
+ if up_block_type == "UpBlock2D":
141
+ return UpBlock2D(
142
+ num_layers=num_layers,
143
+ in_channels=in_channels,
144
+ out_channels=out_channels,
145
+ prev_output_channel=prev_output_channel,
146
+ temb_channels=temb_channels,
147
+ resolution_idx=resolution_idx,
148
+ dropout=dropout,
149
+ add_upsample=add_upsample,
150
+ resnet_eps=resnet_eps,
151
+ resnet_act_fn=resnet_act_fn,
152
+ resnet_groups=resnet_groups,
153
+ resnet_time_scale_shift=resnet_time_scale_shift,
154
+ )
155
+ elif up_block_type == "CrossAttnUpBlock2D":
156
+ if cross_attention_dim is None:
157
+ raise ValueError(
158
+ "cross_attention_dim must be specified for CrossAttnUpBlock2D"
159
+ )
160
+ return CrossAttnUpBlock2D(
161
+ num_layers=num_layers,
162
+ transformer_layers_per_block=transformer_layers_per_block,
163
+ in_channels=in_channels,
164
+ out_channels=out_channels,
165
+ prev_output_channel=prev_output_channel,
166
+ temb_channels=temb_channels,
167
+ resolution_idx=resolution_idx,
168
+ dropout=dropout,
169
+ add_upsample=add_upsample,
170
+ resnet_eps=resnet_eps,
171
+ resnet_act_fn=resnet_act_fn,
172
+ resnet_groups=resnet_groups,
173
+ cross_attention_dim=cross_attention_dim,
174
+ num_attention_heads=num_attention_heads,
175
+ dual_cross_attention=dual_cross_attention,
176
+ use_linear_projection=use_linear_projection,
177
+ only_cross_attention=only_cross_attention,
178
+ upcast_attention=upcast_attention,
179
+ resnet_time_scale_shift=resnet_time_scale_shift,
180
+ attention_type=attention_type,
181
+ )
182
+
183
+ raise ValueError(f"{up_block_type} does not exist.")
184
+
185
+
186
+ class AutoencoderTinyBlock(nn.Module):
187
+ """
188
+ Tiny Autoencoder block used in [`AutoencoderTiny`]. It is a mini residual module consisting of plain conv + ReLU
189
+ blocks.
190
+
191
+ Args:
192
+ in_channels (`int`): The number of input channels.
193
+ out_channels (`int`): The number of output channels.
194
+ act_fn (`str`):
195
+ The activation function to use. Supported values are `"swish"`, `"mish"`, `"gelu"`, and `"relu"`.
196
+
197
+ Returns:
198
+ `torch.FloatTensor`: A tensor with the same shape as the input tensor, but with the number of channels equal to
199
+ `out_channels`.
200
+ """
201
+
202
+ def __init__(self, in_channels: int, out_channels: int, act_fn: str):
203
+ super().__init__()
204
+ act_fn = get_activation(act_fn)
205
+ self.conv = nn.Sequential(
206
+ nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
207
+ act_fn,
208
+ nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
209
+ act_fn,
210
+ nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
211
+ )
212
+ self.skip = (
213
+ nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False)
214
+ if in_channels != out_channels
215
+ else nn.Identity()
216
+ )
217
+ self.fuse = nn.ReLU()
218
+
219
+ def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
220
+ return self.fuse(self.conv(x) + self.skip(x))
221
+
222
+
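# --- Editorial usage sketch (not part of this diff) ---------------------------
# The tiny block preserves spatial size and only remaps channels; a quick shape
# check under assumed sizes:
import torch

from src.models.unet_2d_blocks import AutoencoderTinyBlock

block = AutoencoderTinyBlock(in_channels=64, out_channels=128, act_fn="relu")
x = torch.randn(1, 64, 32, 32)
assert block(x).shape == (1, 128, 32, 32)  # conv path + 1x1 skip, fused by ReLU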
223
+ class UNetMidBlock2D(nn.Module):
224
+ """
225
+ A 2D UNet mid-block [`UNetMidBlock2D`] with multiple residual blocks and optional attention blocks.
226
+
227
+ Args:
228
+ in_channels (`int`): The number of input channels.
229
+ temb_channels (`int`): The number of temporal embedding channels.
230
+ dropout (`float`, *optional*, defaults to 0.0): The dropout rate.
231
+ num_layers (`int`, *optional*, defaults to 1): The number of residual blocks.
232
+ resnet_eps (`float`, *optional*, defaults to 1e-6): The epsilon value for the resnet blocks.
233
+ resnet_time_scale_shift (`str`, *optional*, defaults to `default`):
234
+ The type of normalization to apply to the time embeddings. This can help to improve the performance of the
235
+ model on tasks with long-range temporal dependencies.
236
+ resnet_act_fn (`str`, *optional*, defaults to `swish`): The activation function for the resnet blocks.
237
+ resnet_groups (`int`, *optional*, defaults to 32):
238
+ The number of groups to use in the group normalization layers of the resnet blocks.
239
+ attn_groups (`Optional[int]`, *optional*, defaults to None): The number of groups for the attention blocks.
240
+ resnet_pre_norm (`bool`, *optional*, defaults to `True`):
241
+ Whether to use pre-normalization for the resnet blocks.
242
+ add_attention (`bool`, *optional*, defaults to `True`): Whether to add attention blocks.
243
+ attention_head_dim (`int`, *optional*, defaults to 1):
244
+ Dimension of a single attention head. The number of attention heads is determined based on this value and
245
+ the number of input channels.
246
+ output_scale_factor (`float`, *optional*, defaults to 1.0): The output scale factor.
247
+
248
+ Returns:
249
+ `torch.FloatTensor`: The output of the last residual block, which is a tensor of shape `(batch_size,
250
+ in_channels, height, width)`.
251
+
252
+ """
253
+
254
+ def __init__(
255
+ self,
256
+ in_channels: int,
257
+ temb_channels: int,
258
+ dropout: float = 0.0,
259
+ num_layers: int = 1,
260
+ resnet_eps: float = 1e-6,
261
+ resnet_time_scale_shift: str = "default", # default, spatial
262
+ resnet_act_fn: str = "swish",
263
+ resnet_groups: int = 32,
264
+ attn_groups: Optional[int] = None,
265
+ resnet_pre_norm: bool = True,
266
+ add_attention: bool = True,
267
+ attention_head_dim: int = 1,
268
+ output_scale_factor: float = 1.0,
269
+ ):
270
+ super().__init__()
271
+ resnet_groups = (
272
+ resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
273
+ )
274
+ self.add_attention = add_attention
275
+
276
+ if attn_groups is None:
277
+ attn_groups = (
278
+ resnet_groups if resnet_time_scale_shift == "default" else None
279
+ )
280
+
281
+ # there is always at least one resnet
282
+ resnets = [
283
+ ResnetBlock2D(
284
+ in_channels=in_channels,
285
+ out_channels=in_channels,
286
+ temb_channels=temb_channels,
287
+ eps=resnet_eps,
288
+ groups=resnet_groups,
289
+ dropout=dropout,
290
+ time_embedding_norm=resnet_time_scale_shift,
291
+ non_linearity=resnet_act_fn,
292
+ output_scale_factor=output_scale_factor,
293
+ pre_norm=resnet_pre_norm,
294
+ )
295
+ ]
296
+ attentions = []
297
+
298
+ if attention_head_dim is None:
299
+ logger.warn(
300
+ f"It is not recommend to pass `attention_head_dim=None`. Defaulting `attention_head_dim` to `in_channels`: {in_channels}."
301
+ )
302
+ attention_head_dim = in_channels
303
+
304
+ for _ in range(num_layers):
305
+ if self.add_attention:
306
+ attentions.append(
307
+ Attention(
308
+ in_channels,
309
+ heads=in_channels // attention_head_dim,
310
+ dim_head=attention_head_dim,
311
+ rescale_output_factor=output_scale_factor,
312
+ eps=resnet_eps,
313
+ norm_num_groups=attn_groups,
314
+ spatial_norm_dim=temb_channels
315
+ if resnet_time_scale_shift == "spatial"
316
+ else None,
317
+ residual_connection=True,
318
+ bias=True,
319
+ upcast_softmax=True,
320
+ _from_deprecated_attn_block=True,
321
+ )
322
+ )
323
+ else:
324
+ attentions.append(None)
325
+
326
+ resnets.append(
327
+ ResnetBlock2D(
328
+ in_channels=in_channels,
329
+ out_channels=in_channels,
330
+ temb_channels=temb_channels,
331
+ eps=resnet_eps,
332
+ groups=resnet_groups,
333
+ dropout=dropout,
334
+ time_embedding_norm=resnet_time_scale_shift,
335
+ non_linearity=resnet_act_fn,
336
+ output_scale_factor=output_scale_factor,
337
+ pre_norm=resnet_pre_norm,
338
+ )
339
+ )
340
+
341
+ self.attentions = nn.ModuleList(attentions)
342
+ self.resnets = nn.ModuleList(resnets)
343
+
344
+ def forward(
345
+ self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None
346
+ ) -> torch.FloatTensor:
347
+ hidden_states = self.resnets[0](hidden_states, temb)
348
+ for attn, resnet in zip(self.attentions, self.resnets[1:]):
349
+ if attn is not None:
350
+ hidden_states = attn(hidden_states, temb=temb)
351
+ hidden_states = resnet(hidden_states, temb)
352
+
353
+ return hidden_states
354
+
355
+
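# --- Editorial usage sketch (not part of this diff) ---------------------------
# Self-attention mid block; `attention_head_dim` should divide `in_channels`
# since heads = in_channels // attention_head_dim. Sizes below are illustrative.
import torch

from src.models.unet_2d_blocks import UNetMidBlock2D

mid = UNetMidBlock2D(in_channels=256, temb_channels=512, attention_head_dim=64)
h = torch.randn(2, 256, 16, 16)
t = torch.randn(2, 512)
out = mid(h, temb=t)  # shape is preserved: (2, 256, 16, 16)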
356
+ class UNetMidBlock2DCrossAttn(nn.Module):
357
+ def __init__(
358
+ self,
359
+ in_channels: int,
360
+ temb_channels: int,
361
+ dropout: float = 0.0,
362
+ num_layers: int = 1,
363
+ transformer_layers_per_block: Union[int, Tuple[int]] = 1,
364
+ resnet_eps: float = 1e-6,
365
+ resnet_time_scale_shift: str = "default",
366
+ resnet_act_fn: str = "swish",
367
+ resnet_groups: int = 32,
368
+ resnet_pre_norm: bool = True,
369
+ num_attention_heads: int = 1,
370
+ output_scale_factor: float = 1.0,
371
+ cross_attention_dim: int = 1280,
372
+ dual_cross_attention: bool = False,
373
+ use_linear_projection: bool = False,
374
+ upcast_attention: bool = False,
375
+ attention_type: str = "default",
376
+ ):
377
+ super().__init__()
378
+
379
+ self.has_cross_attention = True
380
+ self.num_attention_heads = num_attention_heads
381
+ resnet_groups = (
382
+ resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
383
+ )
384
+
385
+ # support for variable transformer layers per block
386
+ if isinstance(transformer_layers_per_block, int):
387
+ transformer_layers_per_block = [transformer_layers_per_block] * num_layers
388
+
389
+ # there is always at least one resnet
390
+ resnets = [
391
+ ResnetBlock2D(
392
+ in_channels=in_channels,
393
+ out_channels=in_channels,
394
+ temb_channels=temb_channels,
395
+ eps=resnet_eps,
396
+ groups=resnet_groups,
397
+ dropout=dropout,
398
+ time_embedding_norm=resnet_time_scale_shift,
399
+ non_linearity=resnet_act_fn,
400
+ output_scale_factor=output_scale_factor,
401
+ pre_norm=resnet_pre_norm,
402
+ )
403
+ ]
404
+ attentions = []
405
+
406
+ for i in range(num_layers):
407
+ if not dual_cross_attention:
408
+ attentions.append(
409
+ Transformer2DModel(
410
+ num_attention_heads,
411
+ in_channels // num_attention_heads,
412
+ in_channels=in_channels,
413
+ num_layers=transformer_layers_per_block[i],
414
+ cross_attention_dim=cross_attention_dim,
415
+ norm_num_groups=resnet_groups,
416
+ use_linear_projection=use_linear_projection,
417
+ upcast_attention=upcast_attention,
418
+ attention_type=attention_type,
419
+ )
420
+ )
421
+ else:
422
+ attentions.append(
423
+ DualTransformer2DModel(
424
+ num_attention_heads,
425
+ in_channels // num_attention_heads,
426
+ in_channels=in_channels,
427
+ num_layers=1,
428
+ cross_attention_dim=cross_attention_dim,
429
+ norm_num_groups=resnet_groups,
430
+ )
431
+ )
432
+ resnets.append(
433
+ ResnetBlock2D(
434
+ in_channels=in_channels,
435
+ out_channels=in_channels,
436
+ temb_channels=temb_channels,
437
+ eps=resnet_eps,
438
+ groups=resnet_groups,
439
+ dropout=dropout,
440
+ time_embedding_norm=resnet_time_scale_shift,
441
+ non_linearity=resnet_act_fn,
442
+ output_scale_factor=output_scale_factor,
443
+ pre_norm=resnet_pre_norm,
444
+ )
445
+ )
446
+
447
+ self.attentions = nn.ModuleList(attentions)
448
+ self.resnets = nn.ModuleList(resnets)
449
+
450
+ self.gradient_checkpointing = False
451
+
452
+ def forward(
453
+ self,
454
+ hidden_states: torch.FloatTensor,
455
+ temb: Optional[torch.FloatTensor] = None,
456
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
457
+ attention_mask: Optional[torch.FloatTensor] = None,
458
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
459
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
460
+ ) -> torch.FloatTensor:
461
+ lora_scale = (
462
+ cross_attention_kwargs.get("scale", 1.0)
463
+ if cross_attention_kwargs is not None
464
+ else 1.0
465
+ )
466
+ hidden_states = self.resnets[0](hidden_states, temb, scale=lora_scale)
467
+ for attn, resnet in zip(self.attentions, self.resnets[1:]):
468
+ if self.training and self.gradient_checkpointing:
469
+
470
+ def create_custom_forward(module, return_dict=None):
471
+ def custom_forward(*inputs):
472
+ if return_dict is not None:
473
+ return module(*inputs, return_dict=return_dict)
474
+ else:
475
+ return module(*inputs)
476
+
477
+ return custom_forward
478
+
479
+ ckpt_kwargs: Dict[str, Any] = (
480
+ {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
481
+ )
482
+ hidden_states, ref_feature = attn(
483
+ hidden_states,
484
+ encoder_hidden_states=encoder_hidden_states,
485
+ cross_attention_kwargs=cross_attention_kwargs,
486
+ attention_mask=attention_mask,
487
+ encoder_attention_mask=encoder_attention_mask,
488
+ return_dict=False,
489
+ )
490
+ hidden_states = torch.utils.checkpoint.checkpoint(
491
+ create_custom_forward(resnet),
492
+ hidden_states,
493
+ temb,
494
+ **ckpt_kwargs,
495
+ )
496
+ else:
497
+ hidden_states, ref_feature = attn(
498
+ hidden_states,
499
+ encoder_hidden_states=encoder_hidden_states,
500
+ cross_attention_kwargs=cross_attention_kwargs,
501
+ attention_mask=attention_mask,
502
+ encoder_attention_mask=encoder_attention_mask,
503
+ return_dict=False,
504
+ )
505
+ hidden_states = resnet(hidden_states, temb, scale=lora_scale)
506
+
507
+ return hidden_states
508
+
509
+
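# --- Editorial usage sketch (not part of this diff) ---------------------------
# Cross-attention mid block: conditioning tokens enter via `encoder_hidden_states`.
# Note that the inner Transformer2DModel also returns reference features, which
# this forward unpacks and drops. Shapes below are illustrative.
import torch

from src.models.unet_2d_blocks import UNetMidBlock2DCrossAttn

mid = UNetMidBlock2DCrossAttn(
    in_channels=1280, temb_channels=1280, cross_attention_dim=768, num_attention_heads=8
)
h = torch.randn(1, 1280, 8, 8)
t = torch.randn(1, 1280)
ctx = torch.randn(1, 77, 768)  # e.g. CLIP text tokens
out = mid(h, temb=t, encoder_hidden_states=ctx)  # (1, 1280, 8, 8)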
510
+ class CrossAttnDownBlock2D(nn.Module):
511
+ def __init__(
512
+ self,
513
+ in_channels: int,
514
+ out_channels: int,
515
+ temb_channels: int,
516
+ dropout: float = 0.0,
517
+ num_layers: int = 1,
518
+ transformer_layers_per_block: Union[int, Tuple[int]] = 1,
519
+ resnet_eps: float = 1e-6,
520
+ resnet_time_scale_shift: str = "default",
521
+ resnet_act_fn: str = "swish",
522
+ resnet_groups: int = 32,
523
+ resnet_pre_norm: bool = True,
524
+ num_attention_heads: int = 1,
525
+ cross_attention_dim: int = 1280,
526
+ output_scale_factor: float = 1.0,
527
+ downsample_padding: int = 1,
528
+ add_downsample: bool = True,
529
+ dual_cross_attention: bool = False,
530
+ use_linear_projection: bool = False,
531
+ only_cross_attention: bool = False,
532
+ upcast_attention: bool = False,
533
+ attention_type: str = "default",
534
+ ):
535
+ super().__init__()
536
+ resnets = []
537
+ attentions = []
538
+
539
+ self.has_cross_attention = True
540
+ self.num_attention_heads = num_attention_heads
541
+ if isinstance(transformer_layers_per_block, int):
542
+ transformer_layers_per_block = [transformer_layers_per_block] * num_layers
543
+
544
+ for i in range(num_layers):
545
+ in_channels = in_channels if i == 0 else out_channels
546
+ resnets.append(
547
+ ResnetBlock2D(
548
+ in_channels=in_channels,
549
+ out_channels=out_channels,
550
+ temb_channels=temb_channels,
551
+ eps=resnet_eps,
552
+ groups=resnet_groups,
553
+ dropout=dropout,
554
+ time_embedding_norm=resnet_time_scale_shift,
555
+ non_linearity=resnet_act_fn,
556
+ output_scale_factor=output_scale_factor,
557
+ pre_norm=resnet_pre_norm,
558
+ )
559
+ )
560
+ if not dual_cross_attention:
561
+ attentions.append(
562
+ Transformer2DModel(
563
+ num_attention_heads,
564
+ out_channels // num_attention_heads,
565
+ in_channels=out_channels,
566
+ num_layers=transformer_layers_per_block[i],
567
+ cross_attention_dim=cross_attention_dim,
568
+ norm_num_groups=resnet_groups,
569
+ use_linear_projection=use_linear_projection,
570
+ only_cross_attention=only_cross_attention,
571
+ upcast_attention=upcast_attention,
572
+ attention_type=attention_type,
573
+ )
574
+ )
575
+ else:
576
+ attentions.append(
577
+ DualTransformer2DModel(
578
+ num_attention_heads,
579
+ out_channels // num_attention_heads,
580
+ in_channels=out_channels,
581
+ num_layers=1,
582
+ cross_attention_dim=cross_attention_dim,
583
+ norm_num_groups=resnet_groups,
584
+ )
585
+ )
586
+ self.attentions = nn.ModuleList(attentions)
587
+ self.resnets = nn.ModuleList(resnets)
588
+
589
+ if add_downsample:
590
+ self.downsamplers = nn.ModuleList(
591
+ [
592
+ Downsample2D(
593
+ out_channels,
594
+ use_conv=True,
595
+ out_channels=out_channels,
596
+ padding=downsample_padding,
597
+ name="op",
598
+ )
599
+ ]
600
+ )
601
+ else:
602
+ self.downsamplers = None
603
+
604
+ self.gradient_checkpointing = False
605
+
606
+ def forward(
607
+ self,
608
+ hidden_states: torch.FloatTensor,
609
+ temb: Optional[torch.FloatTensor] = None,
610
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
611
+ attention_mask: Optional[torch.FloatTensor] = None,
612
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
613
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
614
+ additional_residuals: Optional[torch.FloatTensor] = None,
615
+ ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:
616
+ output_states = ()
617
+
618
+ lora_scale = (
619
+ cross_attention_kwargs.get("scale", 1.0)
620
+ if cross_attention_kwargs is not None
621
+ else 1.0
622
+ )
623
+
624
+ blocks = list(zip(self.resnets, self.attentions))
625
+
626
+ for i, (resnet, attn) in enumerate(blocks):
627
+ if self.training and self.gradient_checkpointing:
628
+
629
+ def create_custom_forward(module, return_dict=None):
630
+ def custom_forward(*inputs):
631
+ if return_dict is not None:
632
+ return module(*inputs, return_dict=return_dict)
633
+ else:
634
+ return module(*inputs)
635
+
636
+ return custom_forward
637
+
638
+ ckpt_kwargs: Dict[str, Any] = (
639
+ {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
640
+ )
641
+ hidden_states = torch.utils.checkpoint.checkpoint(
642
+ create_custom_forward(resnet),
643
+ hidden_states,
644
+ temb,
645
+ **ckpt_kwargs,
646
+ )
647
+ hidden_states, ref_feature = attn(
648
+ hidden_states,
649
+ encoder_hidden_states=encoder_hidden_states,
650
+ cross_attention_kwargs=cross_attention_kwargs,
651
+ attention_mask=attention_mask,
652
+ encoder_attention_mask=encoder_attention_mask,
653
+ return_dict=False,
654
+ )
655
+ else:
656
+ hidden_states = resnet(hidden_states, temb, scale=lora_scale)
657
+ hidden_states, ref_feature = attn(
658
+ hidden_states,
659
+ encoder_hidden_states=encoder_hidden_states,
660
+ cross_attention_kwargs=cross_attention_kwargs,
661
+ attention_mask=attention_mask,
662
+ encoder_attention_mask=encoder_attention_mask,
663
+ return_dict=False,
664
+ )
665
+
666
+ # apply additional residuals to the output of the last pair of resnet and attention blocks
667
+ if i == len(blocks) - 1 and additional_residuals is not None:
668
+ hidden_states = hidden_states + additional_residuals
669
+
670
+ output_states = output_states + (hidden_states,)
671
+
672
+ if self.downsamplers is not None:
673
+ for downsampler in self.downsamplers:
674
+ hidden_states = downsampler(hidden_states, scale=lora_scale)
675
+
676
+ output_states = output_states + (hidden_states,)
677
+
678
+ return hidden_states, output_states
679
+
680
+
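# --- Editorial usage sketch (not part of this diff) ---------------------------
# Encoder block: returns the downsampled features plus every intermediate state
# used later as skip connections. Sizes are illustrative.
import torch

from src.models.unet_2d_blocks import CrossAttnDownBlock2D

down = CrossAttnDownBlock2D(
    in_channels=320, out_channels=640, temb_channels=1280,
    num_attention_heads=8, cross_attention_dim=768,
)
h, skips = down(
    torch.randn(1, 320, 64, 64),
    temb=torch.randn(1, 1280),
    encoder_hidden_states=torch.randn(1, 77, 768),
)
# h: (1, 640, 32, 32); skips: one tensor per resnet/attn pair plus the downsampled output.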
681
+ class DownBlock2D(nn.Module):
682
+ def __init__(
683
+ self,
684
+ in_channels: int,
685
+ out_channels: int,
686
+ temb_channels: int,
687
+ dropout: float = 0.0,
688
+ num_layers: int = 1,
689
+ resnet_eps: float = 1e-6,
690
+ resnet_time_scale_shift: str = "default",
691
+ resnet_act_fn: str = "swish",
692
+ resnet_groups: int = 32,
693
+ resnet_pre_norm: bool = True,
694
+ output_scale_factor: float = 1.0,
695
+ add_downsample: bool = True,
696
+ downsample_padding: int = 1,
697
+ ):
698
+ super().__init__()
699
+ resnets = []
700
+
701
+ for i in range(num_layers):
702
+ in_channels = in_channels if i == 0 else out_channels
703
+ resnets.append(
704
+ ResnetBlock2D(
705
+ in_channels=in_channels,
706
+ out_channels=out_channels,
707
+ temb_channels=temb_channels,
708
+ eps=resnet_eps,
709
+ groups=resnet_groups,
710
+ dropout=dropout,
711
+ time_embedding_norm=resnet_time_scale_shift,
712
+ non_linearity=resnet_act_fn,
713
+ output_scale_factor=output_scale_factor,
714
+ pre_norm=resnet_pre_norm,
715
+ )
716
+ )
717
+
718
+ self.resnets = nn.ModuleList(resnets)
719
+
720
+ if add_downsample:
721
+ self.downsamplers = nn.ModuleList(
722
+ [
723
+ Downsample2D(
724
+ out_channels,
725
+ use_conv=True,
726
+ out_channels=out_channels,
727
+ padding=downsample_padding,
728
+ name="op",
729
+ )
730
+ ]
731
+ )
732
+ else:
733
+ self.downsamplers = None
734
+
735
+ self.gradient_checkpointing = False
736
+
737
+ def forward(
738
+ self,
739
+ hidden_states: torch.FloatTensor,
740
+ temb: Optional[torch.FloatTensor] = None,
741
+ scale: float = 1.0,
742
+ ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:
743
+ output_states = ()
744
+
745
+ for resnet in self.resnets:
746
+ if self.training and self.gradient_checkpointing:
747
+
748
+ def create_custom_forward(module):
749
+ def custom_forward(*inputs):
750
+ return module(*inputs)
751
+
752
+ return custom_forward
753
+
754
+ if is_torch_version(">=", "1.11.0"):
755
+ hidden_states = torch.utils.checkpoint.checkpoint(
756
+ create_custom_forward(resnet),
757
+ hidden_states,
758
+ temb,
759
+ use_reentrant=False,
760
+ )
761
+ else:
762
+ hidden_states = torch.utils.checkpoint.checkpoint(
763
+ create_custom_forward(resnet), hidden_states, temb
764
+ )
765
+ else:
766
+ hidden_states = resnet(hidden_states, temb, scale=scale)
767
+
768
+ output_states = output_states + (hidden_states,)
769
+
770
+ if self.downsamplers is not None:
771
+ for downsampler in self.downsamplers:
772
+ hidden_states = downsampler(hidden_states, scale=scale)
773
+
774
+ output_states = output_states + (hidden_states,)
775
+
776
+ return hidden_states, output_states
777
+
778
+
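# --- Editorial usage sketch (not part of this diff) ---------------------------
# The attention-free encoder block keeps the same (features, skips) return
# shape so the UNet can treat both block types uniformly. Sizes are illustrative.
import torch

from src.models.unet_2d_blocks import DownBlock2D

down = DownBlock2D(in_channels=320, out_channels=320, temb_channels=1280)
h, skips = down(torch.randn(1, 320, 64, 64), temb=torch.randn(1, 1280))
# h: (1, 320, 32, 32) after the strided "op" downsampler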
779
+ class CrossAttnUpBlock2D(nn.Module):
780
+ def __init__(
781
+ self,
782
+ in_channels: int,
783
+ out_channels: int,
784
+ prev_output_channel: int,
785
+ temb_channels: int,
786
+ resolution_idx: Optional[int] = None,
787
+ dropout: float = 0.0,
788
+ num_layers: int = 1,
789
+ transformer_layers_per_block: Union[int, Tuple[int]] = 1,
790
+ resnet_eps: float = 1e-6,
791
+ resnet_time_scale_shift: str = "default",
792
+ resnet_act_fn: str = "swish",
793
+ resnet_groups: int = 32,
794
+ resnet_pre_norm: bool = True,
795
+ num_attention_heads: int = 1,
796
+ cross_attention_dim: int = 1280,
797
+ output_scale_factor: float = 1.0,
798
+ add_upsample: bool = True,
799
+ dual_cross_attention: bool = False,
800
+ use_linear_projection: bool = False,
801
+ only_cross_attention: bool = False,
802
+ upcast_attention: bool = False,
803
+ attention_type: str = "default",
804
+ ):
805
+ super().__init__()
806
+ resnets = []
807
+ attentions = []
808
+
809
+ self.has_cross_attention = True
810
+ self.num_attention_heads = num_attention_heads
811
+
812
+ if isinstance(transformer_layers_per_block, int):
813
+ transformer_layers_per_block = [transformer_layers_per_block] * num_layers
814
+
815
+ for i in range(num_layers):
816
+ res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
817
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
818
+
819
+ resnets.append(
820
+ ResnetBlock2D(
821
+ in_channels=resnet_in_channels + res_skip_channels,
822
+ out_channels=out_channels,
823
+ temb_channels=temb_channels,
824
+ eps=resnet_eps,
825
+ groups=resnet_groups,
826
+ dropout=dropout,
827
+ time_embedding_norm=resnet_time_scale_shift,
828
+ non_linearity=resnet_act_fn,
829
+ output_scale_factor=output_scale_factor,
830
+ pre_norm=resnet_pre_norm,
831
+ )
832
+ )
833
+ if not dual_cross_attention:
834
+ attentions.append(
835
+ Transformer2DModel(
836
+ num_attention_heads,
837
+ out_channels // num_attention_heads,
838
+ in_channels=out_channels,
839
+ num_layers=transformer_layers_per_block[i],
840
+ cross_attention_dim=cross_attention_dim,
841
+ norm_num_groups=resnet_groups,
842
+ use_linear_projection=use_linear_projection,
843
+ only_cross_attention=only_cross_attention,
844
+ upcast_attention=upcast_attention,
845
+ attention_type=attention_type,
846
+ )
847
+ )
848
+ else:
849
+ attentions.append(
850
+ DualTransformer2DModel(
851
+ num_attention_heads,
852
+ out_channels // num_attention_heads,
853
+ in_channels=out_channels,
854
+ num_layers=1,
855
+ cross_attention_dim=cross_attention_dim,
856
+ norm_num_groups=resnet_groups,
857
+ )
858
+ )
859
+ self.attentions = nn.ModuleList(attentions)
860
+ self.resnets = nn.ModuleList(resnets)
861
+
862
+ if add_upsample:
863
+ self.upsamplers = nn.ModuleList(
864
+ [Upsample2D(out_channels, use_conv=True, out_channels=out_channels)]
865
+ )
866
+ else:
867
+ self.upsamplers = None
868
+
869
+ self.gradient_checkpointing = False
870
+ self.resolution_idx = resolution_idx
871
+
872
+ def forward(
873
+ self,
874
+ hidden_states: torch.FloatTensor,
875
+ res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
876
+ temb: Optional[torch.FloatTensor] = None,
877
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
878
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
879
+ upsample_size: Optional[int] = None,
880
+ attention_mask: Optional[torch.FloatTensor] = None,
881
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
882
+ ) -> torch.FloatTensor:
883
+ lora_scale = (
884
+ cross_attention_kwargs.get("scale", 1.0)
885
+ if cross_attention_kwargs is not None
886
+ else 1.0
887
+ )
888
+ is_freeu_enabled = (
889
+ getattr(self, "s1", None)
890
+ and getattr(self, "s2", None)
891
+ and getattr(self, "b1", None)
892
+ and getattr(self, "b2", None)
893
+ )
894
+
895
+ for resnet, attn in zip(self.resnets, self.attentions):
896
+ # pop res hidden states
897
+ res_hidden_states = res_hidden_states_tuple[-1]
898
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
899
+
900
+ # FreeU: Only operate on the first two stages
901
+ if is_freeu_enabled:
902
+ hidden_states, res_hidden_states = apply_freeu(
903
+ self.resolution_idx,
904
+ hidden_states,
905
+ res_hidden_states,
906
+ s1=self.s1,
907
+ s2=self.s2,
908
+ b1=self.b1,
909
+ b2=self.b2,
910
+ )
911
+
912
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
913
+
914
+ if self.training and self.gradient_checkpointing:
915
+
916
+ def create_custom_forward(module, return_dict=None):
917
+ def custom_forward(*inputs):
918
+ if return_dict is not None:
919
+ return module(*inputs, return_dict=return_dict)
920
+ else:
921
+ return module(*inputs)
922
+
923
+ return custom_forward
924
+
925
+ ckpt_kwargs: Dict[str, Any] = (
926
+ {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
927
+ )
928
+ hidden_states = torch.utils.checkpoint.checkpoint(
929
+ create_custom_forward(resnet),
930
+ hidden_states,
931
+ temb,
932
+ **ckpt_kwargs,
933
+ )
934
+ hidden_states, ref_feature = attn(
935
+ hidden_states,
936
+ encoder_hidden_states=encoder_hidden_states,
937
+ cross_attention_kwargs=cross_attention_kwargs,
938
+ attention_mask=attention_mask,
939
+ encoder_attention_mask=encoder_attention_mask,
940
+ return_dict=False,
941
+ )
942
+ else:
943
+ hidden_states = resnet(hidden_states, temb, scale=lora_scale)
944
+ hidden_states, ref_feature = attn(
945
+ hidden_states,
946
+ encoder_hidden_states=encoder_hidden_states,
947
+ cross_attention_kwargs=cross_attention_kwargs,
948
+ attention_mask=attention_mask,
949
+ encoder_attention_mask=encoder_attention_mask,
950
+ return_dict=False,
951
+ )
952
+
953
+ if self.upsamplers is not None:
954
+ for upsampler in self.upsamplers:
955
+ hidden_states = upsampler(
956
+ hidden_states, upsample_size, scale=lora_scale
957
+ )
958
+
959
+ return hidden_states
960
+
961
+
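# --- Editorial usage sketch (not part of this diff) ---------------------------
# Each resnet pops one skip tensor from the end of `res_hidden_states_tuple` and
# concatenates it on the channel axis before the optional upsampler doubles the
# spatial size. Channel sizes below are illustrative but self-consistent.
import torch

from src.models.unet_2d_blocks import CrossAttnUpBlock2D

up = CrossAttnUpBlock2D(
    in_channels=320, out_channels=640, prev_output_channel=1280,
    temb_channels=1280, resolution_idx=1, num_layers=3,
    num_attention_heads=8, cross_attention_dim=768,
)
skips = (
    torch.randn(1, 320, 16, 16),  # consumed last
    torch.randn(1, 640, 16, 16),
    torch.randn(1, 640, 16, 16),  # consumed first (popped from the end)
)
h = up(
    torch.randn(1, 1280, 16, 16),
    res_hidden_states_tuple=skips,
    temb=torch.randn(1, 1280),
    encoder_hidden_states=torch.randn(1, 77, 768),
)  # (1, 640, 32, 32)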
962
+ class UpBlock2D(nn.Module):
963
+ def __init__(
964
+ self,
965
+ in_channels: int,
966
+ prev_output_channel: int,
967
+ out_channels: int,
968
+ temb_channels: int,
969
+ resolution_idx: Optional[int] = None,
970
+ dropout: float = 0.0,
971
+ num_layers: int = 1,
972
+ resnet_eps: float = 1e-6,
973
+ resnet_time_scale_shift: str = "default",
974
+ resnet_act_fn: str = "swish",
975
+ resnet_groups: int = 32,
976
+ resnet_pre_norm: bool = True,
977
+ output_scale_factor: float = 1.0,
978
+ add_upsample: bool = True,
979
+ ):
980
+ super().__init__()
981
+ resnets = []
982
+
983
+ for i in range(num_layers):
984
+ res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
985
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
986
+
987
+ resnets.append(
988
+ ResnetBlock2D(
989
+ in_channels=resnet_in_channels + res_skip_channels,
990
+ out_channels=out_channels,
991
+ temb_channels=temb_channels,
992
+ eps=resnet_eps,
993
+ groups=resnet_groups,
994
+ dropout=dropout,
995
+ time_embedding_norm=resnet_time_scale_shift,
996
+ non_linearity=resnet_act_fn,
997
+ output_scale_factor=output_scale_factor,
998
+ pre_norm=resnet_pre_norm,
999
+ )
1000
+ )
1001
+
1002
+ self.resnets = nn.ModuleList(resnets)
1003
+
1004
+ if add_upsample:
1005
+ self.upsamplers = nn.ModuleList(
1006
+ [Upsample2D(out_channels, use_conv=True, out_channels=out_channels)]
1007
+ )
1008
+ else:
1009
+ self.upsamplers = None
1010
+
1011
+ self.gradient_checkpointing = False
1012
+ self.resolution_idx = resolution_idx
1013
+
1014
+ def forward(
1015
+ self,
1016
+ hidden_states: torch.FloatTensor,
1017
+ res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
1018
+ temb: Optional[torch.FloatTensor] = None,
1019
+ upsample_size: Optional[int] = None,
1020
+ scale: float = 1.0,
1021
+ ) -> torch.FloatTensor:
1022
+ is_freeu_enabled = (
1023
+ getattr(self, "s1", None)
1024
+ and getattr(self, "s2", None)
1025
+ and getattr(self, "b1", None)
1026
+ and getattr(self, "b2", None)
1027
+ )
1028
+
1029
+ for resnet in self.resnets:
1030
+ # pop res hidden states
1031
+ res_hidden_states = res_hidden_states_tuple[-1]
1032
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
1033
+
1034
+ # FreeU: Only operate on the first two stages
1035
+ if is_freeu_enabled:
1036
+ hidden_states, res_hidden_states = apply_freeu(
1037
+ self.resolution_idx,
1038
+ hidden_states,
1039
+ res_hidden_states,
1040
+ s1=self.s1,
1041
+ s2=self.s2,
1042
+ b1=self.b1,
1043
+ b2=self.b2,
1044
+ )
1045
+
1046
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
1047
+
1048
+ if self.training and self.gradient_checkpointing:
1049
+
1050
+ def create_custom_forward(module):
1051
+ def custom_forward(*inputs):
1052
+ return module(*inputs)
1053
+
1054
+ return custom_forward
1055
+
1056
+ if is_torch_version(">=", "1.11.0"):
1057
+ hidden_states = torch.utils.checkpoint.checkpoint(
1058
+ create_custom_forward(resnet),
1059
+ hidden_states,
1060
+ temb,
1061
+ use_reentrant=False,
1062
+ )
1063
+ else:
1064
+ hidden_states = torch.utils.checkpoint.checkpoint(
1065
+ create_custom_forward(resnet), hidden_states, temb
1066
+ )
1067
+ else:
1068
+ hidden_states = resnet(hidden_states, temb, scale=scale)
1069
+
1070
+ if self.upsamplers is not None:
1071
+ for upsampler in self.upsamplers:
1072
+ hidden_states = upsampler(hidden_states, upsample_size, scale=scale)
1073
+
1074
+ return hidden_states
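# --- Editorial usage sketch (not part of this diff) ---------------------------
# Attention-free decoder block; same skip-popping pattern as above, without the
# transformer. Sizes are illustrative.
import torch

from src.models.unet_2d_blocks import UpBlock2D

up = UpBlock2D(in_channels=320, prev_output_channel=640, out_channels=320,
               temb_channels=1280, num_layers=3)
skips = tuple(torch.randn(1, 320, 32, 32) for _ in range(3))
h = up(torch.randn(1, 640, 32, 32), res_hidden_states_tuple=skips,
       temb=torch.randn(1, 1280))  # (1, 320, 64, 64)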
src/models/unet_2d_condition.py ADDED
@@ -0,0 +1,1308 @@
1
+ # Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/unet_2d_condition.py
2
+ from dataclasses import dataclass
3
+ from typing import Any, Dict, List, Optional, Tuple, Union
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.utils.checkpoint
8
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
9
+ from diffusers.loaders import UNet2DConditionLoadersMixin
10
+ from diffusers.models.activations import get_activation
11
+ from diffusers.models.attention_processor import (
12
+ ADDED_KV_ATTENTION_PROCESSORS,
13
+ CROSS_ATTENTION_PROCESSORS,
14
+ AttentionProcessor,
15
+ AttnAddedKVProcessor,
16
+ AttnProcessor,
17
+ )
18
+ from diffusers.models.embeddings import (
19
+ GaussianFourierProjection,
20
+ ImageHintTimeEmbedding,
21
+ ImageProjection,
22
+ ImageTimeEmbedding,
23
+ PositionNet,
24
+ TextImageProjection,
25
+ TextImageTimeEmbedding,
26
+ TextTimeEmbedding,
27
+ TimestepEmbedding,
28
+ Timesteps,
29
+ )
30
+ from diffusers.models.modeling_utils import ModelMixin
31
+ from diffusers.utils import (
32
+ USE_PEFT_BACKEND,
33
+ BaseOutput,
34
+ deprecate,
35
+ logging,
36
+ scale_lora_layers,
37
+ unscale_lora_layers,
38
+ )
39
+
40
+ from .unet_2d_blocks import (
41
+ UNetMidBlock2D,
42
+ UNetMidBlock2DCrossAttn,
43
+ get_down_block,
44
+ get_up_block,
45
+ )
46
+
47
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
48
+
49
+
50
+ @dataclass
51
+ class UNet2DConditionOutput(BaseOutput):
52
+ """
53
+ The output of [`UNet2DConditionModel`].
54
+
55
+ Args:
56
+ sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
57
+ The hidden states output conditioned on `encoder_hidden_states` input. Output of last layer of model.
58
+ """
59
+
60
+ sample: torch.FloatTensor = None
61
+ ref_features: Tuple[torch.FloatTensor] = None
62
+
63
+
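# --- Editorial usage sketch (not part of this diff) ---------------------------
# Compared with stock diffusers, this output also carries per-block reference
# features for the reference/denoising UNet pair. Field shapes are illustrative.
import torch

from src.models.unet_2d_condition import UNet2DConditionOutput

out = UNet2DConditionOutput(
    sample=torch.zeros(1, 4, 64, 64),
    ref_features=(torch.zeros(1, 4096, 320),),
)
denoised, ref_feats = out.sample, out.ref_features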
64
+ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
65
+ r"""
66
+ A conditional 2D UNet model that takes a noisy sample, conditional state, and a timestep and returns a sample
67
+ shaped output.
68
+
69
+ This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented
70
+ for all models (such as downloading or saving).
71
+
72
+ Parameters:
73
+ sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):
74
+ Height and width of input/output sample.
75
+ in_channels (`int`, *optional*, defaults to 4): Number of channels in the input sample.
76
+ out_channels (`int`, *optional*, defaults to 4): Number of channels in the output.
77
+ center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample.
78
+ flip_sin_to_cos (`bool`, *optional*, defaults to `False`):
79
+ Whether to flip the sin to cos in the time embedding.
80
+ freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding.
81
+ down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
82
+ The tuple of downsample blocks to use.
83
+ mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2DCrossAttn"`):
84
+ Block type for middle of UNet, it can be one of `UNetMidBlock2DCrossAttn`, `UNetMidBlock2D`, or
85
+ `UNetMidBlock2DSimpleCrossAttn`. If `None`, the mid block layer is skipped.
86
+ up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D")`):
87
+ The tuple of upsample blocks to use.
88
+ only_cross_attention (`bool` or `Tuple[bool]`, *optional*, defaults to `False`):
89
+ Whether to include self-attention in the basic transformer blocks, see
90
+ [`~models.attention.BasicTransformerBlock`].
91
+ block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
92
+ The tuple of output channels for each block.
93
+ layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block.
94
+ downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution.
95
+ mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block.
96
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
97
+ act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
98
+ norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization.
99
+ If `None`, normalization and activation layers are skipped in post-processing.
100
+ norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization.
101
+ cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1280):
102
+ The dimension of the cross attention features.
103
+ transformer_layers_per_block (`int`, `Tuple[int]`, or `Tuple[Tuple]` , *optional*, defaults to 1):
104
+ The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for
105
+ [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`],
106
+ [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
107
+ reverse_transformer_layers_per_block (`Tuple[Tuple]`, *optional*, defaults to `None`):
108
+ The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`], in the upsampling
109
+ blocks of the U-Net. Only relevant if `transformer_layers_per_block` is of type `Tuple[Tuple]` and for
110
+ [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`],
111
+ [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
112
+ encoder_hid_dim (`int`, *optional*, defaults to None):
113
+ If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim`
114
+ dimension to `cross_attention_dim`.
115
+ encoder_hid_dim_type (`str`, *optional*, defaults to `None`):
116
+ If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text
117
+ embeddings of dimension `cross_attention_dim` according to `encoder_hid_dim_type`.
118
+ attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads.
119
+ num_attention_heads (`int`, *optional*):
120
+ The number of attention heads. If not defined, defaults to `attention_head_dim`
121
+ resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config
122
+ for ResNet blocks (see [`~models.resnet.ResnetBlock2D`]). Choose from `default` or `scale_shift`.
123
+ class_embed_type (`str`, *optional*, defaults to `None`):
124
+ The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`,
125
+ `"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`.
126
+ addition_embed_type (`str`, *optional*, defaults to `None`):
127
+ Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or
128
+ "text". "text" will use the `TextTimeEmbedding` layer.
129
+ addition_time_embed_dim (`int`, *optional*, defaults to `None`):
130
+ Dimension for the timestep embeddings.
131
+ num_class_embeds (`int`, *optional*, defaults to `None`):
132
+ Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
133
+ class conditioning with `class_embed_type` equal to `None`.
134
+ time_embedding_type (`str`, *optional*, defaults to `positional`):
135
+ The type of position embedding to use for timesteps. Choose from `positional` or `fourier`.
136
+ time_embedding_dim (`int`, *optional*, defaults to `None`):
137
+ An optional override for the dimension of the projected time embedding.
138
+ time_embedding_act_fn (`str`, *optional*, defaults to `None`):
139
+ Optional activation function to use only once on the time embeddings before they are passed to the rest of
140
+ the UNet. Choose from `silu`, `mish`, `gelu`, and `swish`.
141
+ timestep_post_act (`str`, *optional*, defaults to `None`):
142
+ The second activation function to use in timestep embedding. Choose from `silu`, `mish` and `gelu`.
143
+ time_cond_proj_dim (`int`, *optional*, defaults to `None`):
144
+ The dimension of `cond_proj` layer in the timestep embedding.
145
+ conv_in_kernel (`int`, *optional*, defaults to `3`): The kernel size of the `conv_in` layer.
+ conv_out_kernel (`int`, *optional*, defaults to `3`): The kernel size of the `conv_out` layer.
+ projection_class_embeddings_input_dim (`int`, *optional*):
+ The dimension of the `class_labels` input when `class_embed_type="projection"`. Required when
+ `class_embed_type="projection"`.
149
+ class_embeddings_concat (`bool`, *optional*, defaults to `False`): Whether to concatenate the time
150
+ embeddings with the class embeddings.
151
+ mid_block_only_cross_attention (`bool`, *optional*, defaults to `None`):
152
+ Whether to use cross attention with the mid block when using the `UNetMidBlock2DSimpleCrossAttn`. If
153
+ `only_cross_attention` is given as a single boolean and `mid_block_only_cross_attention` is `None`, the
154
+ `only_cross_attention` value is used as the value for `mid_block_only_cross_attention`. Default to `False`
155
+ otherwise.
156
+ """
157
+
158
+ _supports_gradient_checkpointing = True
159
+
160
+ @register_to_config
161
+ def __init__(
162
+ self,
163
+ sample_size: Optional[int] = None,
164
+ in_channels: int = 4,
165
+ out_channels: int = 4,
166
+ center_input_sample: bool = False,
167
+ flip_sin_to_cos: bool = True,
168
+ freq_shift: int = 0,
169
+ down_block_types: Tuple[str] = (
170
+ "CrossAttnDownBlock2D",
171
+ "CrossAttnDownBlock2D",
172
+ "CrossAttnDownBlock2D",
173
+ "DownBlock2D",
174
+ ),
175
+ mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn",
176
+ up_block_types: Tuple[str] = (
177
+ "UpBlock2D",
178
+ "CrossAttnUpBlock2D",
179
+ "CrossAttnUpBlock2D",
180
+ "CrossAttnUpBlock2D",
181
+ ),
182
+ only_cross_attention: Union[bool, Tuple[bool]] = False,
183
+ block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
184
+ layers_per_block: Union[int, Tuple[int]] = 2,
185
+ downsample_padding: int = 1,
186
+ mid_block_scale_factor: float = 1,
187
+ dropout: float = 0.0,
188
+ act_fn: str = "silu",
189
+ norm_num_groups: Optional[int] = 32,
190
+ norm_eps: float = 1e-5,
191
+ cross_attention_dim: Union[int, Tuple[int]] = 1280,
192
+ transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple]] = 1,
193
+ reverse_transformer_layers_per_block: Optional[Tuple[Tuple[int]]] = None,
194
+ encoder_hid_dim: Optional[int] = None,
195
+ encoder_hid_dim_type: Optional[str] = None,
196
+ attention_head_dim: Union[int, Tuple[int]] = 8,
197
+ num_attention_heads: Optional[Union[int, Tuple[int]]] = None,
198
+ dual_cross_attention: bool = False,
199
+ use_linear_projection: bool = False,
200
+ class_embed_type: Optional[str] = None,
201
+ addition_embed_type: Optional[str] = None,
202
+ addition_time_embed_dim: Optional[int] = None,
203
+ num_class_embeds: Optional[int] = None,
204
+ upcast_attention: bool = False,
205
+ resnet_time_scale_shift: str = "default",
206
+ resnet_skip_time_act: bool = False,
207
+ resnet_out_scale_factor: float = 1.0,
208
+ time_embedding_type: str = "positional",
209
+ time_embedding_dim: Optional[int] = None,
210
+ time_embedding_act_fn: Optional[str] = None,
211
+ timestep_post_act: Optional[str] = None,
212
+ time_cond_proj_dim: Optional[int] = None,
213
+ conv_in_kernel: int = 3,
214
+ conv_out_kernel: int = 3,
215
+ projection_class_embeddings_input_dim: Optional[int] = None,
216
+ attention_type: str = "default",
217
+ class_embeddings_concat: bool = False,
218
+ mid_block_only_cross_attention: Optional[bool] = None,
219
+ cross_attention_norm: Optional[str] = None,
220
+ addition_embed_type_num_heads=64,
221
+ ):
222
+ super().__init__()
223
+
224
+ self.sample_size = sample_size
225
+
226
+ if num_attention_heads is not None:
227
+ raise ValueError(
228
+ "At the moment it is not possible to define the number of attention heads via `num_attention_heads` because of a naming issue as described in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131. Passing `num_attention_heads` will only be supported in diffusers v0.19."
229
+ )
230
+
231
+ # If `num_attention_heads` is not defined (which is the case for most models)
232
+ # it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
233
+ # The reason for this behavior is to correct for incorrectly named variables that were introduced
234
+ # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
235
+ # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
236
+ # which is why we correct for the naming here.
237
+ num_attention_heads = num_attention_heads or attention_head_dim
238
+
239
+ # Check inputs
240
+ if len(down_block_types) != len(up_block_types):
241
+ raise ValueError(
242
+ f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}."
243
+ )
244
+
245
+ if len(block_out_channels) != len(down_block_types):
246
+ raise ValueError(
247
+ f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
248
+ )
249
+
250
+ if not isinstance(only_cross_attention, bool) and len(
251
+ only_cross_attention
252
+ ) != len(down_block_types):
253
+ raise ValueError(
254
+ f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
255
+ )
256
+
257
+ if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(
258
+ down_block_types
259
+ ):
260
+ raise ValueError(
261
+ f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
262
+ )
263
+
264
+ if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(
265
+ down_block_types
266
+ ):
267
+ raise ValueError(
268
+ f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}."
269
+ )
270
+
271
+ if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(
272
+ down_block_types
273
+ ):
274
+ raise ValueError(
275
+ f"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}."
276
+ )
277
+
278
+ if not isinstance(layers_per_block, int) and len(layers_per_block) != len(
279
+ down_block_types
280
+ ):
281
+ raise ValueError(
282
+ f"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}."
283
+ )
284
+ if (
285
+ isinstance(transformer_layers_per_block, list)
286
+ and reverse_transformer_layers_per_block is None
287
+ ):
288
+ for layer_number_per_block in transformer_layers_per_block:
289
+ if isinstance(layer_number_per_block, list):
290
+ raise ValueError(
291
+ "Must provide 'reverse_transformer_layers_per_block` if using asymmetrical UNet."
292
+ )
293
+
294
+ # input
295
+ conv_in_padding = (conv_in_kernel - 1) // 2
296
+ self.conv_in = nn.Conv2d(
297
+ in_channels,
298
+ block_out_channels[0],
299
+ kernel_size=conv_in_kernel,
300
+ padding=conv_in_padding,
301
+ )
302
+
303
+ # time
304
+ if time_embedding_type == "fourier":
305
+ time_embed_dim = time_embedding_dim or block_out_channels[0] * 2
306
+ if time_embed_dim % 2 != 0:
307
+ raise ValueError(
308
+ f"`time_embed_dim` should be divisible by 2, but is {time_embed_dim}."
309
+ )
310
+ self.time_proj = GaussianFourierProjection(
311
+ time_embed_dim // 2,
312
+ set_W_to_weight=False,
313
+ log=False,
314
+ flip_sin_to_cos=flip_sin_to_cos,
315
+ )
316
+ timestep_input_dim = time_embed_dim
317
+ elif time_embedding_type == "positional":
318
+ time_embed_dim = time_embedding_dim or block_out_channels[0] * 4
319
+
320
+ self.time_proj = Timesteps(
321
+ block_out_channels[0], flip_sin_to_cos, freq_shift
322
+ )
323
+ timestep_input_dim = block_out_channels[0]
324
+ else:
325
+ raise ValueError(
326
+ f"{time_embedding_type} does not exist. Please make sure to use one of `fourier` or `positional`."
327
+ )
328
+
329
+ self.time_embedding = TimestepEmbedding(
330
+ timestep_input_dim,
331
+ time_embed_dim,
332
+ act_fn=act_fn,
333
+ post_act_fn=timestep_post_act,
334
+ cond_proj_dim=time_cond_proj_dim,
335
+ )
336
+
337
+ if encoder_hid_dim_type is None and encoder_hid_dim is not None:
338
+ encoder_hid_dim_type = "text_proj"
339
+ self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type)
340
+ logger.info(
341
+ "encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined."
342
+ )
343
+
344
+ if encoder_hid_dim is None and encoder_hid_dim_type is not None:
345
+ raise ValueError(
346
+ f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}."
347
+ )
348
+
349
+ if encoder_hid_dim_type == "text_proj":
350
+ self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim)
351
+ elif encoder_hid_dim_type == "text_image_proj":
352
+ # image_embed_dim DOESN'T have to be `cross_attention_dim`. To not clutter the __init__ too much
353
+ # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
354
+ # case when `addition_embed_type == "text_image_proj"` (Kandinsky 2.1)
355
+ self.encoder_hid_proj = TextImageProjection(
356
+ text_embed_dim=encoder_hid_dim,
357
+ image_embed_dim=cross_attention_dim,
358
+ cross_attention_dim=cross_attention_dim,
359
+ )
360
+ elif encoder_hid_dim_type == "image_proj":
361
+ # Kandinsky 2.2
362
+ self.encoder_hid_proj = ImageProjection(
363
+ image_embed_dim=encoder_hid_dim,
364
+ cross_attention_dim=cross_attention_dim,
365
+ )
366
+ elif encoder_hid_dim_type is not None:
367
+ raise ValueError(
368
+ f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj' or 'text_image_proj'."
369
+ )
370
+ else:
371
+ self.encoder_hid_proj = None
372
+
373
+ # class embedding
374
+ if class_embed_type is None and num_class_embeds is not None:
375
+ self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
376
+ elif class_embed_type == "timestep":
377
+ self.class_embedding = TimestepEmbedding(
378
+ timestep_input_dim, time_embed_dim, act_fn=act_fn
379
+ )
380
+ elif class_embed_type == "identity":
381
+ self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
382
+ elif class_embed_type == "projection":
383
+ if projection_class_embeddings_input_dim is None:
384
+ raise ValueError(
385
+ "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set"
386
+ )
387
+ # The projection `class_embed_type` is the same as the timestep `class_embed_type` except
388
+ # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings
389
+ # 2. it projects from an arbitrary input dimension.
390
+ #
391
+ # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations.
392
+ # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings.
393
+ # As a result, `TimestepEmbedding` can be passed arbitrary vectors.
394
+ self.class_embedding = TimestepEmbedding(
395
+ projection_class_embeddings_input_dim, time_embed_dim
396
+ )
397
+ elif class_embed_type == "simple_projection":
398
+ if projection_class_embeddings_input_dim is None:
399
+ raise ValueError(
400
+ "`class_embed_type`: 'simple_projection' requires `projection_class_embeddings_input_dim` be set"
401
+ )
402
+ self.class_embedding = nn.Linear(
403
+ projection_class_embeddings_input_dim, time_embed_dim
404
+ )
405
+ else:
406
+ self.class_embedding = None
407
+
408
+ if addition_embed_type == "text":
409
+ if encoder_hid_dim is not None:
410
+ text_time_embedding_from_dim = encoder_hid_dim
411
+ else:
412
+ text_time_embedding_from_dim = cross_attention_dim
413
+
414
+ self.add_embedding = TextTimeEmbedding(
415
+ text_time_embedding_from_dim,
416
+ time_embed_dim,
417
+ num_heads=addition_embed_type_num_heads,
418
+ )
419
+ elif addition_embed_type == "text_image":
420
+ # text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the __init__ too much
421
+ # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
422
+ # case when `addition_embed_type == "text_image"` (Kandinsky 2.1)
423
+ self.add_embedding = TextImageTimeEmbedding(
424
+ text_embed_dim=cross_attention_dim,
425
+ image_embed_dim=cross_attention_dim,
426
+ time_embed_dim=time_embed_dim,
427
+ )
428
+ elif addition_embed_type == "text_time":
429
+ self.add_time_proj = Timesteps(
430
+ addition_time_embed_dim, flip_sin_to_cos, freq_shift
431
+ )
432
+ self.add_embedding = TimestepEmbedding(
433
+ projection_class_embeddings_input_dim, time_embed_dim
434
+ )
435
+ elif addition_embed_type == "image":
436
+ # Kandinsky 2.2
437
+ self.add_embedding = ImageTimeEmbedding(
438
+ image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim
439
+ )
440
+ elif addition_embed_type == "image_hint":
441
+ # Kandinsky 2.2 ControlNet
442
+ self.add_embedding = ImageHintTimeEmbedding(
443
+ image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim
444
+ )
445
+ elif addition_embed_type is not None:
446
+ raise ValueError(
447
+ f"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'."
448
+ )
449
+
450
+ if time_embedding_act_fn is None:
451
+ self.time_embed_act = None
452
+ else:
453
+ self.time_embed_act = get_activation(time_embedding_act_fn)
454
+
455
+ self.down_blocks = nn.ModuleList([])
456
+ self.up_blocks = nn.ModuleList([])
457
+
458
+ if isinstance(only_cross_attention, bool):
459
+ if mid_block_only_cross_attention is None:
460
+ mid_block_only_cross_attention = only_cross_attention
461
+
462
+ only_cross_attention = [only_cross_attention] * len(down_block_types)
463
+
464
+ if mid_block_only_cross_attention is None:
465
+ mid_block_only_cross_attention = False
466
+
467
+ if isinstance(num_attention_heads, int):
468
+ num_attention_heads = (num_attention_heads,) * len(down_block_types)
469
+
470
+ if isinstance(attention_head_dim, int):
471
+ attention_head_dim = (attention_head_dim,) * len(down_block_types)
472
+
473
+ if isinstance(cross_attention_dim, int):
474
+ cross_attention_dim = (cross_attention_dim,) * len(down_block_types)
475
+
476
+ if isinstance(layers_per_block, int):
477
+ layers_per_block = [layers_per_block] * len(down_block_types)
478
+
479
+ if isinstance(transformer_layers_per_block, int):
480
+ transformer_layers_per_block = [transformer_layers_per_block] * len(
481
+ down_block_types
482
+ )
483
+
484
+ if class_embeddings_concat:
485
+ # The time embeddings are concatenated with the class embeddings. The dimension of the
486
+ # time embeddings passed to the down, middle, and up blocks is twice the dimension of the
487
+ # regular time embeddings
488
+ blocks_time_embed_dim = time_embed_dim * 2
489
+ else:
490
+ blocks_time_embed_dim = time_embed_dim
491
+
492
+ # down
493
+ output_channel = block_out_channels[0]
494
+ for i, down_block_type in enumerate(down_block_types):
495
+ input_channel = output_channel
496
+ output_channel = block_out_channels[i]
497
+ is_final_block = i == len(block_out_channels) - 1
498
+
499
+ down_block = get_down_block(
500
+ down_block_type,
501
+ num_layers=layers_per_block[i],
502
+ transformer_layers_per_block=transformer_layers_per_block[i],
503
+ in_channels=input_channel,
504
+ out_channels=output_channel,
505
+ temb_channels=blocks_time_embed_dim,
506
+ add_downsample=not is_final_block,
507
+ resnet_eps=norm_eps,
508
+ resnet_act_fn=act_fn,
509
+ resnet_groups=norm_num_groups,
510
+ cross_attention_dim=cross_attention_dim[i],
511
+ num_attention_heads=num_attention_heads[i],
512
+ downsample_padding=downsample_padding,
513
+ dual_cross_attention=dual_cross_attention,
514
+ use_linear_projection=use_linear_projection,
515
+ only_cross_attention=only_cross_attention[i],
516
+ upcast_attention=upcast_attention,
517
+ resnet_time_scale_shift=resnet_time_scale_shift,
518
+ attention_type=attention_type,
519
+ resnet_skip_time_act=resnet_skip_time_act,
520
+ resnet_out_scale_factor=resnet_out_scale_factor,
521
+ cross_attention_norm=cross_attention_norm,
522
+ attention_head_dim=attention_head_dim[i]
523
+ if attention_head_dim[i] is not None
524
+ else output_channel,
525
+ dropout=dropout,
526
+ )
527
+ self.down_blocks.append(down_block)
528
+
529
+ # mid
530
+ if mid_block_type == "UNetMidBlock2DCrossAttn":
531
+ self.mid_block = UNetMidBlock2DCrossAttn(
532
+ transformer_layers_per_block=transformer_layers_per_block[-1],
533
+ in_channels=block_out_channels[-1],
534
+ temb_channels=blocks_time_embed_dim,
535
+ dropout=dropout,
536
+ resnet_eps=norm_eps,
537
+ resnet_act_fn=act_fn,
538
+ output_scale_factor=mid_block_scale_factor,
539
+ resnet_time_scale_shift=resnet_time_scale_shift,
540
+ cross_attention_dim=cross_attention_dim[-1],
541
+ num_attention_heads=num_attention_heads[-1],
542
+ resnet_groups=norm_num_groups,
543
+ dual_cross_attention=dual_cross_attention,
544
+ use_linear_projection=use_linear_projection,
545
+ upcast_attention=upcast_attention,
546
+ attention_type=attention_type,
547
+ )
548
+ elif mid_block_type == "UNetMidBlock2DSimpleCrossAttn":
549
+ raise NotImplementedError(f"Unsupported mid_block_type: {mid_block_type}")
550
+ elif mid_block_type == "UNetMidBlock2D":
551
+ self.mid_block = UNetMidBlock2D(
552
+ in_channels=block_out_channels[-1],
553
+ temb_channels=blocks_time_embed_dim,
554
+ dropout=dropout,
555
+ num_layers=0,
556
+ resnet_eps=norm_eps,
557
+ resnet_act_fn=act_fn,
558
+ output_scale_factor=mid_block_scale_factor,
559
+ resnet_groups=norm_num_groups,
560
+ resnet_time_scale_shift=resnet_time_scale_shift,
561
+ add_attention=False,
562
+ )
563
+ elif mid_block_type is None:
564
+ self.mid_block = None
565
+ else:
566
+ raise ValueError(f"unknown mid_block_type : {mid_block_type}")
567
+
568
+ # count how many layers upsample the images
569
+ self.num_upsamplers = 0
570
+
571
+ # up
572
+ reversed_block_out_channels = list(reversed(block_out_channels))
573
+ reversed_num_attention_heads = list(reversed(num_attention_heads))
574
+ reversed_layers_per_block = list(reversed(layers_per_block))
575
+ reversed_cross_attention_dim = list(reversed(cross_attention_dim))
576
+ reversed_transformer_layers_per_block = (
577
+ list(reversed(transformer_layers_per_block))
578
+ if reverse_transformer_layers_per_block is None
579
+ else reverse_transformer_layers_per_block
580
+ )
581
+ only_cross_attention = list(reversed(only_cross_attention))
582
+
583
+ output_channel = reversed_block_out_channels[0]
584
+ for i, up_block_type in enumerate(up_block_types):
585
+ is_final_block = i == len(block_out_channels) - 1
586
+
587
+ prev_output_channel = output_channel
588
+ output_channel = reversed_block_out_channels[i]
589
+ input_channel = reversed_block_out_channels[
590
+ min(i + 1, len(block_out_channels) - 1)
591
+ ]
592
+
593
+ # add upsample block for all BUT final layer
594
+ if not is_final_block:
595
+ add_upsample = True
596
+ self.num_upsamplers += 1
597
+ else:
598
+ add_upsample = False
599
+
600
+ up_block = get_up_block(
601
+ up_block_type,
602
+ num_layers=reversed_layers_per_block[i] + 1,
603
+ transformer_layers_per_block=reversed_transformer_layers_per_block[i],
604
+ in_channels=input_channel,
605
+ out_channels=output_channel,
606
+ prev_output_channel=prev_output_channel,
607
+ temb_channels=blocks_time_embed_dim,
608
+ add_upsample=add_upsample,
609
+ resnet_eps=norm_eps,
610
+ resnet_act_fn=act_fn,
611
+ resolution_idx=i,
612
+ resnet_groups=norm_num_groups,
613
+ cross_attention_dim=reversed_cross_attention_dim[i],
614
+ num_attention_heads=reversed_num_attention_heads[i],
615
+ dual_cross_attention=dual_cross_attention,
616
+ use_linear_projection=use_linear_projection,
617
+ only_cross_attention=only_cross_attention[i],
618
+ upcast_attention=upcast_attention,
619
+ resnet_time_scale_shift=resnet_time_scale_shift,
620
+ attention_type=attention_type,
621
+ resnet_skip_time_act=resnet_skip_time_act,
622
+ resnet_out_scale_factor=resnet_out_scale_factor,
623
+ cross_attention_norm=cross_attention_norm,
624
+ attention_head_dim=attention_head_dim[i]
625
+ if attention_head_dim[i] is not None
626
+ else output_channel,
627
+ dropout=dropout,
628
+ )
629
+ self.up_blocks.append(up_block)
630
+ prev_output_channel = output_channel
631
+
632
+ # out
633
+ if norm_num_groups is not None:
634
+ self.conv_norm_out = nn.GroupNorm(
635
+ num_channels=block_out_channels[0],
636
+ num_groups=norm_num_groups,
637
+ eps=norm_eps,
638
+ )
639
+
640
+ self.conv_act = get_activation(act_fn)
641
+
642
+ else:
643
+ self.conv_norm_out = None
644
+ self.conv_act = None
645
646
+
647
+ conv_out_padding = (conv_out_kernel - 1) // 2
648
+ # self.conv_out = nn.Conv2d(
649
+ # block_out_channels[0],
650
+ # out_channels,
651
+ # kernel_size=conv_out_kernel,
652
+ # padding=conv_out_padding,
653
+ # )
654
+
655
+ if attention_type in ["gated", "gated-text-image"]:
656
+ positive_len = 768
657
+ if isinstance(cross_attention_dim, int):
658
+ positive_len = cross_attention_dim
659
+ elif isinstance(cross_attention_dim, tuple) or isinstance(
660
+ cross_attention_dim, list
661
+ ):
662
+ positive_len = cross_attention_dim[0]
663
+
664
+ feature_type = "text-only" if attention_type == "gated" else "text-image"
665
+ self.position_net = PositionNet(
666
+ positive_len=positive_len,
667
+ out_dim=cross_attention_dim,
668
+ feature_type=feature_type,
669
+ )
670
+
671
+ @property
672
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
673
+ r"""
674
+ Returns:
675
+ `dict` of attention processors: A dictionary containing all attention processors used in the model,
676
+ indexed by its weight name.
677
+ """
678
+ # set recursively
679
+ processors = {}
680
+
681
+ def fn_recursive_add_processors(
682
+ name: str,
683
+ module: torch.nn.Module,
684
+ processors: Dict[str, AttentionProcessor],
685
+ ):
686
+ if hasattr(module, "get_processor"):
687
+ processors[f"{name}.processor"] = module.get_processor(
688
+ return_deprecated_lora=True
689
+ )
690
+
691
+ for sub_name, child in module.named_children():
692
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
693
+
694
+ return processors
695
+
696
+ for name, module in self.named_children():
697
+ fn_recursive_add_processors(name, module, processors)
698
+
699
+ return processors
700
+
701
+ def set_attn_processor(
702
+ self,
703
+ processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]],
704
+ _remove_lora=False,
705
+ ):
706
+ r"""
707
+ Sets the attention processor to use to compute attention.
708
+
709
+ Parameters:
710
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
711
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
712
+ for **all** `Attention` layers.
713
+
714
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
715
+ processor. This is strongly recommended when setting trainable attention processors.
716
+
717
+ """
718
+ count = len(self.attn_processors.keys())
719
+
720
+ if isinstance(processor, dict) and len(processor) != count:
721
+ raise ValueError(
722
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
723
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
724
+ )
725
+
726
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
727
+ if hasattr(module, "set_processor"):
728
+ if not isinstance(processor, dict):
729
+ module.set_processor(processor, _remove_lora=_remove_lora)
730
+ else:
731
+ module.set_processor(
732
+ processor.pop(f"{name}.processor"), _remove_lora=_remove_lora
733
+ )
734
+
735
+ for sub_name, child in module.named_children():
736
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
737
+
738
+ for name, module in self.named_children():
739
+ fn_recursive_attn_processor(name, module, processor)
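+ # Usage sketch (assumption: a diffusers version exposing AttnProcessor2_0, i.e. PyTorch 2.x):
+ # from diffusers.models.attention_processor import AttnProcessor2_0
+ # unet.set_attn_processor(AttnProcessor2_0())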
740
+
741
+ def set_default_attn_processor(self):
742
+ """
743
+ Disables custom attention processors and sets the default attention implementation.
744
+ """
745
+ if all(
746
+ proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS
747
+ for proc in self.attn_processors.values()
748
+ ):
749
+ processor = AttnAddedKVProcessor()
750
+ elif all(
751
+ proc.__class__ in CROSS_ATTENTION_PROCESSORS
752
+ for proc in self.attn_processors.values()
753
+ ):
754
+ processor = AttnProcessor()
755
+ else:
756
+ raise ValueError(
757
+ f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
758
+ )
759
+
760
+ self.set_attn_processor(processor, _remove_lora=True)
761
+
762
+ def set_attention_slice(self, slice_size):
763
+ r"""
764
+ Enable sliced attention computation.
765
+
766
+ When this option is enabled, the attention module splits the input tensor in slices to compute attention in
767
+ several steps. This is useful for saving some memory in exchange for a small decrease in speed.
768
+
769
+ Args:
770
+ slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
771
+ When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If
772
+ `"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is
773
+ provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
774
+ must be a multiple of `slice_size`.
775
+ """
776
+ sliceable_head_dims = []
777
+
778
+ def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module):
779
+ if hasattr(module, "set_attention_slice"):
780
+ sliceable_head_dims.append(module.sliceable_head_dim)
781
+
782
+ for child in module.children():
783
+ fn_recursive_retrieve_sliceable_dims(child)
784
+
785
+ # retrieve number of attention layers
786
+ for module in self.children():
787
+ fn_recursive_retrieve_sliceable_dims(module)
788
+
789
+ num_sliceable_layers = len(sliceable_head_dims)
790
+
791
+ if slice_size == "auto":
792
+ # half the attention head size is usually a good trade-off between
793
+ # speed and memory
794
+ slice_size = [dim // 2 for dim in sliceable_head_dims]
795
+ elif slice_size == "max":
796
+ # make smallest slice possible
797
+ slice_size = num_sliceable_layers * [1]
798
+
799
+ slice_size = (
800
+ num_sliceable_layers * [slice_size]
801
+ if not isinstance(slice_size, list)
802
+ else slice_size
803
+ )
804
+
805
+ if len(slice_size) != len(sliceable_head_dims):
806
+ raise ValueError(
807
+ f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
808
+ f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
809
+ )
810
+
811
+ for i in range(len(slice_size)):
812
+ size = slice_size[i]
813
+ dim = sliceable_head_dims[i]
814
+ if size is not None and size > dim:
815
+ raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
816
+
817
+ # Recursively walk through all the children.
818
+ # Any children which exposes the set_attention_slice method
819
+ # gets the message
820
+ def fn_recursive_set_attention_slice(
821
+ module: torch.nn.Module, slice_size: List[int]
822
+ ):
823
+ if hasattr(module, "set_attention_slice"):
824
+ module.set_attention_slice(slice_size.pop())
825
+
826
+ for child in module.children():
827
+ fn_recursive_set_attention_slice(child, slice_size)
828
+
829
+ reversed_slice_size = list(reversed(slice_size))
830
+ for module in self.children():
831
+ fn_recursive_set_attention_slice(module, reversed_slice_size)
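+ # Usage sketch (argument forms follow the docstring above; values are illustrative):
+ # unet.set_attention_slice("auto")  # halve each sliceable head dim
+ # unet.set_attention_slice(2)       # or a fixed slice size / per-layer list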
832
+
833
+ def _set_gradient_checkpointing(self, module, value=False):
834
+ if hasattr(module, "gradient_checkpointing"):
835
+ module.gradient_checkpointing = value
836
+
837
+ def enable_freeu(self, s1, s2, b1, b2):
838
+ r"""Enables the FreeU mechanism from https://arxiv.org/abs/2309.11497.
839
+
840
+ The suffixes after the scaling factors represent the stage blocks where they are being applied.
841
+
842
+ Please refer to the [official repository](https://github.com/ChenyangSi/FreeU) for combinations of values that
843
+ are known to work well for different pipelines such as Stable Diffusion v1, v2, and Stable Diffusion XL.
844
+
845
+ Args:
846
+ s1 (`float`):
847
+ Scaling factor for stage 1 to attenuate the contributions of the skip features. This is done to
848
+ mitigate the "oversmoothing effect" in the enhanced denoising process.
849
+ s2 (`float`):
850
+ Scaling factor for stage 2 to attenuate the contributions of the skip features. This is done to
851
+ mitigate the "oversmoothing effect" in the enhanced denoising process.
852
+ b1 (`float`): Scaling factor for stage 1 to amplify the contributions of backbone features.
853
+ b2 (`float`): Scaling factor for stage 2 to amplify the contributions of backbone features.
854
+ """
855
+ for i, upsample_block in enumerate(self.up_blocks):
856
+ setattr(upsample_block, "s1", s1)
857
+ setattr(upsample_block, "s2", s2)
858
+ setattr(upsample_block, "b1", b1)
859
+ setattr(upsample_block, "b2", b2)
860
+
861
+ def disable_freeu(self):
862
+ """Disables the FreeU mechanism."""
863
+ freeu_keys = {"s1", "s2", "b1", "b2"}
864
+ for i, upsample_block in enumerate(self.up_blocks):
865
+ for k in freeu_keys:
866
+ if (
867
+ hasattr(upsample_block, k)
868
+ or getattr(upsample_block, k, None) is not None
869
+ ):
870
+ setattr(upsample_block, k, None)
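+ # Usage sketch: the scale values below are illustrative assumptions; see the official
+ # FreeU repository for values tuned to a specific base model.
+ # unet.enable_freeu(s1=0.9, s2=0.2, b1=1.3, b2=1.4)
+ # ... run the denoising loop ...
+ # unet.disable_freeu()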
871
+
872
+ def forward(
873
+ self,
874
+ sample: torch.FloatTensor,
875
+ timestep: Union[torch.Tensor, float, int],
876
+ encoder_hidden_states: torch.Tensor,
877
+ class_labels: Optional[torch.Tensor] = None,
878
+ timestep_cond: Optional[torch.Tensor] = None,
879
+ attention_mask: Optional[torch.Tensor] = None,
880
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
881
+ added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
882
+ down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
883
+ mid_block_additional_residual: Optional[torch.Tensor] = None,
884
+ down_intrablock_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
885
+ encoder_attention_mask: Optional[torch.Tensor] = None,
886
+ return_dict: bool = True,
887
+ ) -> Union[UNet2DConditionOutput, Tuple]:
888
+ r"""
889
+ The [`UNet2DConditionModel`] forward method.
890
+
891
+ Args:
892
+ sample (`torch.FloatTensor`):
893
+ The noisy input tensor with the following shape `(batch, channel, height, width)`.
894
+ timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input.
895
+ encoder_hidden_states (`torch.FloatTensor`):
896
+ The encoder hidden states with shape `(batch, sequence_length, feature_dim)`.
897
+ class_labels (`torch.Tensor`, *optional*, defaults to `None`):
898
+ Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings.
899
+ timestep_cond: (`torch.Tensor`, *optional*, defaults to `None`):
900
+ Conditional embeddings for timestep. If provided, the embeddings will be summed with the samples passed
901
+ through the `self.time_embedding` layer to obtain the timestep embeddings.
902
+ attention_mask (`torch.Tensor`, *optional*, defaults to `None`):
903
+ An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
904
+ is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
905
+ negative values to the attention scores corresponding to "discard" tokens.
906
+ cross_attention_kwargs (`dict`, *optional*):
907
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
908
+ `self.processor` in
909
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
910
+ added_cond_kwargs: (`dict`, *optional*):
911
+ A kwargs dictionary containing additional embeddings that if specified are added to the embeddings that
912
+ are passed along to the UNet blocks.
913
+ down_block_additional_residuals: (`tuple` of `torch.Tensor`, *optional*):
914
+ A tuple of tensors that if specified are added to the residuals of down unet blocks.
915
+ mid_block_additional_residual: (`torch.Tensor`, *optional*):
916
+ A tensor that if specified is added to the residual of the middle unet block.
917
+ encoder_attention_mask (`torch.Tensor`):
918
+ A cross-attention mask of shape `(batch, sequence_length)` is applied to `encoder_hidden_states`. If
919
+ `True` the mask is kept, otherwise if `False` it is discarded. Mask will be converted into a bias,
920
+ which adds large negative values to the attention scores corresponding to "discard" tokens.
921
+ return_dict (`bool`, *optional*, defaults to `True`):
922
+ Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
923
+ tuple.
924
+ cross_attention_kwargs (`dict`, *optional*):
925
+ A kwargs dictionary that if specified is passed along to the [`AttnProcessor`].
926
+ added_cond_kwargs: (`dict`, *optional*):
927
+ A kwargs dictionary containing additional embeddings that if specified are added to the embeddings that
928
+ are passed along to the UNet blocks.
929
+ down_block_additional_residuals (`tuple` of `torch.Tensor`, *optional*):
930
+ additional residuals to be added to UNet long skip connections from down blocks to up blocks for
931
+ example from ControlNet side model(s)
932
+ mid_block_additional_residual (`torch.Tensor`, *optional*):
933
+ additional residual to be added to UNet mid block output, for example from ControlNet side model
934
+ down_intrablock_additional_residuals (`tuple` of `torch.Tensor`, *optional*):
935
+ additional residuals to be added within UNet down blocks, for example from T2I-Adapter side model(s)
936
+
937
+ Returns:
938
+ [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
939
+ If `return_dict` is True, an [`~models.unet_2d_condition.UNet2DConditionOutput`] is returned, otherwise
940
+ a `tuple` is returned where the first element is the sample tensor.
941
+ """
942
+ # By default samples have to be at least a multiple of the overall upsampling factor.
943
+ # The overall upsampling factor is equal to 2 ** (# num of upsampling layers).
944
+ # However, the upsampling interpolation output size can be forced to fit any upsampling size
945
+ # on the fly if necessary.
946
+ default_overall_up_factor = 2**self.num_upsamplers
947
+
948
+ # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
949
+ forward_upsample_size = False
950
+ upsample_size = None
951
+
952
+ for dim in sample.shape[-2:]:
953
+ if dim % default_overall_up_factor != 0:
954
+ # Forward upsample size to force interpolation output size.
955
+ forward_upsample_size = True
956
+ break
957
+
958
+ # ensure attention_mask is a bias, and give it a singleton query_tokens dimension
959
+ # expects mask of shape:
960
+ # [batch, key_tokens]
961
+ # adds singleton query_tokens dimension:
962
+ # [batch, 1, key_tokens]
963
+ # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
964
+ # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
965
+ # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
966
+ if attention_mask is not None:
967
+ # assume that mask is expressed as:
968
+ # (1 = keep, 0 = discard)
969
+ # convert mask into a bias that can be added to attention scores:
970
+ # (keep = +0, discard = -10000.0)
971
+ attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
972
+ attention_mask = attention_mask.unsqueeze(1)
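+ # Worked example of the conversion above: a key mask of [1, 1, 0] becomes the additive
+ # bias [0.0, 0.0, -10000.0], so "discard" tokens receive ~zero attention weight after softmax.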
973
+
974
+ # convert encoder_attention_mask to a bias the same way we do for attention_mask
975
+ if encoder_attention_mask is not None:
976
+ encoder_attention_mask = (
977
+ 1 - encoder_attention_mask.to(sample.dtype)
978
+ ) * -10000.0
979
+ encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
980
+
981
+ # 0. center input if necessary
982
+ if self.config.center_input_sample:
983
+ sample = 2 * sample - 1.0
984
+
985
+ # 1. time
986
+ timesteps = timestep
987
+ if not torch.is_tensor(timesteps):
988
+ # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
989
+ # This would be a good case for the `match` statement (Python 3.10+)
990
+ is_mps = sample.device.type == "mps"
991
+ if isinstance(timestep, float):
992
+ dtype = torch.float32 if is_mps else torch.float64
993
+ else:
994
+ dtype = torch.int32 if is_mps else torch.int64
995
+ timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
996
+ elif len(timesteps.shape) == 0:
997
+ timesteps = timesteps[None].to(sample.device)
998
+
999
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
1000
+ timesteps = timesteps.expand(sample.shape[0])
1001
+
1002
+ t_emb = self.time_proj(timesteps)
1003
+
1004
+ # `Timesteps` does not contain any weights and will always return f32 tensors
1005
+ # but time_embedding might actually be running in fp16. so we need to cast here.
1006
+ # there might be better ways to encapsulate this.
1007
+ t_emb = t_emb.to(dtype=sample.dtype)
1008
+
1009
+ emb = self.time_embedding(t_emb, timestep_cond)
1010
+ aug_emb = None
1011
+
1012
+ if self.class_embedding is not None:
1013
+ if class_labels is None:
1014
+ raise ValueError(
1015
+ "class_labels should be provided when num_class_embeds > 0"
1016
+ )
1017
+
1018
+ if self.config.class_embed_type == "timestep":
1019
+ class_labels = self.time_proj(class_labels)
1020
+
1021
+ # `Timesteps` does not contain any weights and will always return f32 tensors
1022
+ # there might be better ways to encapsulate this.
1023
+ class_labels = class_labels.to(dtype=sample.dtype)
1024
+
1025
+ class_emb = self.class_embedding(class_labels).to(dtype=sample.dtype)
1026
+
1027
+ if self.config.class_embeddings_concat:
1028
+ emb = torch.cat([emb, class_emb], dim=-1)
1029
+ else:
1030
+ emb = emb + class_emb
1031
+
1032
+ if self.config.addition_embed_type == "text":
1033
+ aug_emb = self.add_embedding(encoder_hidden_states)
1034
+ elif self.config.addition_embed_type == "text_image":
1035
+ # Kandinsky 2.1 - style
1036
+ if "image_embeds" not in added_cond_kwargs:
1037
+ raise ValueError(
1038
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
1039
+ )
1040
+
1041
+ image_embs = added_cond_kwargs.get("image_embeds")
1042
+ text_embs = added_cond_kwargs.get("text_embeds", encoder_hidden_states)
1043
+ aug_emb = self.add_embedding(text_embs, image_embs)
1044
+ elif self.config.addition_embed_type == "text_time":
1045
+ # SDXL - style
1046
+ if "text_embeds" not in added_cond_kwargs:
1047
+ raise ValueError(
1048
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
1049
+ )
1050
+ text_embeds = added_cond_kwargs.get("text_embeds")
1051
+ if "time_ids" not in added_cond_kwargs:
1052
+ raise ValueError(
1053
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
1054
+ )
1055
+ time_ids = added_cond_kwargs.get("time_ids")
1056
+ time_embeds = self.add_time_proj(time_ids.flatten())
1057
+ time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))
1058
+ add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
1059
+ add_embeds = add_embeds.to(emb.dtype)
1060
+ aug_emb = self.add_embedding(add_embeds)
1061
+ elif self.config.addition_embed_type == "image":
1062
+ # Kandinsky 2.2 - style
1063
+ if "image_embeds" not in added_cond_kwargs:
1064
+ raise ValueError(
1065
+ f"{self.__class__} has the config param `addition_embed_type` set to 'image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
1066
+ )
1067
+ image_embs = added_cond_kwargs.get("image_embeds")
1068
+ aug_emb = self.add_embedding(image_embs)
1069
+ elif self.config.addition_embed_type == "image_hint":
1070
+ # Kandinsky 2.2 - style
1071
+ if (
1072
+ "image_embeds" not in added_cond_kwargs
1073
+ or "hint" not in added_cond_kwargs
1074
+ ):
1075
+ raise ValueError(
1076
+ f"{self.__class__} has the config param `addition_embed_type` set to 'image_hint' which requires the keyword arguments `image_embeds` and `hint` to be passed in `added_cond_kwargs`"
1077
+ )
1078
+ image_embs = added_cond_kwargs.get("image_embeds")
1079
+ hint = added_cond_kwargs.get("hint")
1080
+ aug_emb, hint = self.add_embedding(image_embs, hint)
1081
+ sample = torch.cat([sample, hint], dim=1)
1082
+
1083
+ emb = emb + aug_emb if aug_emb is not None else emb
1084
+
1085
+ if self.time_embed_act is not None:
1086
+ emb = self.time_embed_act(emb)
1087
+
1088
+ if (
1089
+ self.encoder_hid_proj is not None
1090
+ and self.config.encoder_hid_dim_type == "text_proj"
1091
+ ):
1092
+ encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states)
1093
+ elif (
1094
+ self.encoder_hid_proj is not None
1095
+ and self.config.encoder_hid_dim_type == "text_image_proj"
1096
+ ):
1097
+ # Kandinsky 2.1 - style
1098
+ if "image_embeds" not in added_cond_kwargs:
1099
+ raise ValueError(
1100
+ f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'text_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
1101
+ )
1102
+
1103
+ image_embeds = added_cond_kwargs.get("image_embeds")
1104
+ encoder_hidden_states = self.encoder_hid_proj(
1105
+ encoder_hidden_states, image_embeds
1106
+ )
1107
+ elif (
1108
+ self.encoder_hid_proj is not None
1109
+ and self.config.encoder_hid_dim_type == "image_proj"
1110
+ ):
1111
+ # Kandinsky 2.2 - style
1112
+ if "image_embeds" not in added_cond_kwargs:
1113
+ raise ValueError(
1114
+ f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
1115
+ )
1116
+ image_embeds = added_cond_kwargs.get("image_embeds")
1117
+ encoder_hidden_states = self.encoder_hid_proj(image_embeds)
1118
+ elif (
1119
+ self.encoder_hid_proj is not None
1120
+ and self.config.encoder_hid_dim_type == "ip_image_proj"
1121
+ ):
1122
+ if "image_embeds" not in added_cond_kwargs:
1123
+ raise ValueError(
1124
+ f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'ip_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
1125
+ )
1126
+ image_embeds = added_cond_kwargs.get("image_embeds")
1127
+ image_embeds = self.encoder_hid_proj(image_embeds).to(
1128
+ encoder_hidden_states.dtype
1129
+ )
1130
+ encoder_hidden_states = torch.cat(
1131
+ [encoder_hidden_states, image_embeds], dim=1
1132
+ )
1133
+
1134
+ # 2. pre-process
1135
+ sample = self.conv_in(sample)
1136
+
1137
+ # 2.5 GLIGEN position net
1138
+ if (
1139
+ cross_attention_kwargs is not None
1140
+ and cross_attention_kwargs.get("gligen", None) is not None
1141
+ ):
1142
+ cross_attention_kwargs = cross_attention_kwargs.copy()
1143
+ gligen_args = cross_attention_kwargs.pop("gligen")
1144
+ cross_attention_kwargs["gligen"] = {
1145
+ "objs": self.position_net(**gligen_args)
1146
+ }
1147
+
1148
+ # 3. down
1149
+ lora_scale = (
1150
+ cross_attention_kwargs.get("scale", 1.0)
1151
+ if cross_attention_kwargs is not None
1152
+ else 1.0
1153
+ )
1154
+ if USE_PEFT_BACKEND:
1155
+ # weight the lora layers by setting `lora_scale` for each PEFT layer
1156
+ scale_lora_layers(self, lora_scale)
1157
+
1158
+ is_controlnet = (
1159
+ mid_block_additional_residual is not None
1160
+ and down_block_additional_residuals is not None
1161
+ )
1162
+ # using new arg down_intrablock_additional_residuals for T2I-Adapters, to distinguish from controlnets
1163
+ is_adapter = down_intrablock_additional_residuals is not None
1164
+ # maintain backward compatibility for legacy usage, where
1165
+ # T2I-Adapter and ControlNet both use down_block_additional_residuals arg
1166
+ # but can only use one or the other
1167
+ if (
1168
+ not is_adapter
1169
+ and mid_block_additional_residual is None
1170
+ and down_block_additional_residuals is not None
1171
+ ):
1172
+ deprecate(
1173
+ "T2I should not use down_block_additional_residuals",
1174
+ "1.3.0",
1175
+ "Passing intrablock residual connections with `down_block_additional_residuals` is deprecated \
1176
+ and will be removed in diffusers 1.3.0. `down_block_additional_residuals` should only be used \
1177
+ for ControlNet. Please make sure to use `down_intrablock_additional_residuals` instead. ",
1178
+ standard_warn=False,
1179
+ )
1180
+ down_intrablock_additional_residuals = down_block_additional_residuals
1181
+ is_adapter = True
1182
+
1183
+ down_block_res_samples = (sample,)
1184
+ tot_reference_features = ()  # reserved for reference features; not used in this forward pass
1185
+ for downsample_block in self.down_blocks:
1186
+ if (
1187
+ hasattr(downsample_block, "has_cross_attention")
1188
+ and downsample_block.has_cross_attention
1189
+ ):
1190
+ # For t2i-adapter CrossAttnDownBlock2D
1191
+ additional_residuals = {}
1192
+ if is_adapter and len(down_intrablock_additional_residuals) > 0:
1193
+ additional_residuals[
1194
+ "additional_residuals"
1195
+ ] = down_intrablock_additional_residuals.pop(0)
1196
+
1197
+ sample, res_samples = downsample_block(
1198
+ hidden_states=sample,
1199
+ temb=emb,
1200
+ encoder_hidden_states=encoder_hidden_states,
1201
+ attention_mask=attention_mask,
1202
+ cross_attention_kwargs=cross_attention_kwargs,
1203
+ encoder_attention_mask=encoder_attention_mask,
1204
+ **additional_residuals,
1205
+ )
1206
+ else:
1207
+ sample, res_samples = downsample_block(
1208
+ hidden_states=sample, temb=emb, scale=lora_scale
1209
+ )
1210
+ if is_adapter and len(down_intrablock_additional_residuals) > 0:
1211
+ sample += down_intrablock_additional_residuals.pop(0)
1212
+
1213
+ down_block_res_samples += res_samples
1214
+
1215
+ if is_controlnet:
1216
+ new_down_block_res_samples = ()
1217
+
1218
+ for down_block_res_sample, down_block_additional_residual in zip(
1219
+ down_block_res_samples, down_block_additional_residuals
1220
+ ):
1221
+ down_block_res_sample = (
1222
+ down_block_res_sample + down_block_additional_residual
1223
+ )
1224
+ new_down_block_res_samples = new_down_block_res_samples + (
1225
+ down_block_res_sample,
1226
+ )
1227
+
1228
+ down_block_res_samples = new_down_block_res_samples
1229
+
1230
+ # 4. mid
1231
+ if self.mid_block is not None:
1232
+ if (
1233
+ hasattr(self.mid_block, "has_cross_attention")
1234
+ and self.mid_block.has_cross_attention
1235
+ ):
1236
+ sample = self.mid_block(
1237
+ sample,
1238
+ emb,
1239
+ encoder_hidden_states=encoder_hidden_states,
1240
+ attention_mask=attention_mask,
1241
+ cross_attention_kwargs=cross_attention_kwargs,
1242
+ encoder_attention_mask=encoder_attention_mask,
1243
+ )
1244
+ else:
1245
+ sample = self.mid_block(sample, emb)
1246
+
1247
+ # To support T2I-Adapter-XL
1248
+ if (
1249
+ is_adapter
1250
+ and len(down_intrablock_additional_residuals) > 0
1251
+ and sample.shape == down_intrablock_additional_residuals[0].shape
1252
+ ):
1253
+ sample += down_intrablock_additional_residuals.pop(0)
1254
+
1255
+ if is_controlnet:
1256
+ sample = sample + mid_block_additional_residual
1257
+
1258
+ # 5. up
1259
+ for i, upsample_block in enumerate(self.up_blocks):
1260
+ is_final_block = i == len(self.up_blocks) - 1
1261
+
1262
+ res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
1263
+ down_block_res_samples = down_block_res_samples[
1264
+ : -len(upsample_block.resnets)
1265
+ ]
1266
+
1267
+ # if we have not reached the final block and need to forward the
1268
+ # upsample size, we do it here
1269
+ if not is_final_block and forward_upsample_size:
1270
+ upsample_size = down_block_res_samples[-1].shape[2:]
1271
+
1272
+ if (
1273
+ hasattr(upsample_block, "has_cross_attention")
1274
+ and upsample_block.has_cross_attention
1275
+ ):
1276
+ sample = upsample_block(
1277
+ hidden_states=sample,
1278
+ temb=emb,
1279
+ res_hidden_states_tuple=res_samples,
1280
+ encoder_hidden_states=encoder_hidden_states,
1281
+ cross_attention_kwargs=cross_attention_kwargs,
1282
+ upsample_size=upsample_size,
1283
+ attention_mask=attention_mask,
1284
+ encoder_attention_mask=encoder_attention_mask,
1285
+ )
1286
+ else:
1287
+ sample = upsample_block(
1288
+ hidden_states=sample,
1289
+ temb=emb,
1290
+ res_hidden_states_tuple=res_samples,
1291
+ upsample_size=upsample_size,
1292
+ scale=lora_scale,
1293
+ )
1294
+
1295
+ # 6. post-process
1296
+ # if self.conv_norm_out:
1297
+ # sample = self.conv_norm_out(sample)
1298
+ # sample = self.conv_act(sample)
1299
+ # sample = self.conv_out(sample)
1300
+
1301
+ if USE_PEFT_BACKEND:
1302
+ # remove `lora_scale` from each PEFT layer
1303
+ unscale_lora_layers(self, lora_scale)
1304
+
1305
+ if not return_dict:
1306
+ return (sample,)
1307
+
1308
+ return UNet2DConditionOutput(sample=sample)
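A minimal smoke-test sketch for the modified 2D UNet above. Everything specific here is an assumption for illustration: the import path, the constructor keeping upstream-style diffusers defaults, and the printed shape, which reflects that the final norm/activation/conv_out post-processing is commented out so `sample` carries up-block features rather than predicted noise.

import torch
from src.models.unet_2d_condition import UNet2DConditionModel  # import path is an assumption

unet = UNet2DConditionModel(cross_attention_dim=768)  # other arguments assumed to keep upstream defaults
latents = torch.randn(1, 4, 64, 64)                   # (batch, channel, height, width) noisy latents
timestep = torch.tensor([10])
context = torch.randn(1, 77, 768)                     # e.g. CLIP text embeddings
out = unet(latents, timestep, encoder_hidden_states=context)
print(out.sample.shape)                               # likely (1, 320, 64, 64): block_out_channels[0] features, no conv_out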
src/models/unet_3d.py ADDED
@@ -0,0 +1,707 @@
1
+ # Adapted from https://github.com/guoyww/AnimateDiff/blob/main/animatediff/models/unet_blocks.py
2
+
3
+ from collections import OrderedDict
4
+ from dataclasses import dataclass
5
+ from os import PathLike
6
+ from pathlib import Path
7
+ from typing import Dict, List, Optional, Tuple, Union
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.utils.checkpoint
12
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
13
+ from diffusers.models.attention_processor import AttentionProcessor
14
+ from diffusers.models.embeddings import TimestepEmbedding, Timesteps
15
+ try:
16
+ from diffusers.modeling_utils import ModelMixin
17
+ except ImportError:
18
+ from diffusers.models.modeling_utils import ModelMixin
19
+ from diffusers.utils import SAFETENSORS_WEIGHTS_NAME, WEIGHTS_NAME, BaseOutput, logging
20
+ from safetensors.torch import load_file
21
+
22
+ from .resnet import InflatedConv3d, InflatedGroupNorm
23
+ from .unet_3d_blocks import UNetMidBlock3DCrossAttn, get_down_block, get_up_block
24
+
25
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
26
+ from einops import rearrange
27
+ from scipy.io import loadmat, savemat
28
+
29
+
30
+ @dataclass
31
+ class UNet3DConditionOutput(BaseOutput):
32
+ sample: torch.FloatTensor
33
+
34
+
35
+ class UNet3DConditionModel(ModelMixin, ConfigMixin):
36
+ _supports_gradient_checkpointing = True
37
+
38
+ @register_to_config
39
+ def __init__(
40
+ self,
41
+ sample_size: Optional[int] = None,
42
+ in_channels: int = 4,
43
+ out_channels: int = 4,
44
+ center_input_sample: bool = False,
45
+ flip_sin_to_cos: bool = True,
46
+ freq_shift: int = 0,
47
+ down_block_types: Tuple[str] = (
48
+ "CrossAttnDownBlock3D",
49
+ "CrossAttnDownBlock3D",
50
+ "CrossAttnDownBlock3D",
51
+ "DownBlock3D",
52
+ ),
53
+ mid_block_type: str = "UNetMidBlock3DCrossAttn",
54
+ up_block_types: Tuple[str] = (
55
+ "UpBlock3D",
56
+ "CrossAttnUpBlock3D",
57
+ "CrossAttnUpBlock3D",
58
+ "CrossAttnUpBlock3D",
59
+ ),
60
+ only_cross_attention: Union[bool, Tuple[bool]] = False,
61
+ block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
62
+ layers_per_block: int = 2,
63
+ downsample_padding: int = 1,
64
+ mid_block_scale_factor: float = 1,
65
+ act_fn: str = "silu",
66
+ norm_num_groups: int = 32,
67
+ norm_eps: float = 1e-5,
68
+ cross_attention_dim: int = 1280,
69
+ attention_head_dim: Union[int, Tuple[int]] = 8,
70
+ dual_cross_attention: bool = False,
71
+ use_linear_projection: bool = False,
72
+ class_embed_type: Optional[str] = None,
73
+ num_class_embeds: Optional[int] = None,
74
+ upcast_attention: bool = False,
75
+ resnet_time_scale_shift: str = "default",
76
+ use_inflated_groupnorm=False,
77
+ # Additional
78
+ use_motion_module=False, ######
79
+ motion_module_resolutions=(1, 2, 4, 8), ####
80
+ motion_module_mid_block=False, #####
81
+ motion_module_decoder_only=False, #####
82
+ motion_module_type=None, #####
83
+ motion_module_kwargs={}, #####
84
+ unet_use_cross_frame_attention=None, #####
85
+ unet_use_temporal_attention=None, #####
86
+ mode=None, #####
87
+ task_type="action", #####
88
+ ):
89
+ super().__init__()
90
+
91
+ self.sample_size = sample_size
92
+ time_embed_dim = block_out_channels[0] * 4
93
+
94
+ # input
95
+ self.conv_in = InflatedConv3d(
96
+ in_channels, block_out_channels[0], kernel_size=3, padding=(1, 1)
97
+ )
98
+
99
+ # time
100
+ self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
101
+ timestep_input_dim = block_out_channels[0]
102
+
103
+ self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
104
+
105
+ # class embedding
106
+ if class_embed_type is None and num_class_embeds is not None:
107
+ self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
108
+ elif class_embed_type == "timestep":
109
+ self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
110
+ elif class_embed_type == "identity":
111
+ self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
112
+ else:
113
+ self.class_embedding = None
114
+
115
+ self.down_blocks = nn.ModuleList([])
116
+ self.mid_block = None
117
+ self.up_blocks = nn.ModuleList([])
118
+
119
+ if isinstance(only_cross_attention, bool):
120
+ only_cross_attention = [only_cross_attention] * len(down_block_types)
121
+
122
+ if isinstance(attention_head_dim, int):
123
+ attention_head_dim = (attention_head_dim,) * len(down_block_types)
124
+
125
+ # down
126
+ output_channel = block_out_channels[0]
127
+ for i, down_block_type in enumerate(down_block_types):
128
+ if task_type == "action":
129
+ name_index, mid_name = None, None
130
+ else:
131
+ name_index, mid_name = i, "MidBlock"
132
+ res = 2**i
133
+ input_channel = output_channel
134
+ output_channel = block_out_channels[i]
135
+ is_final_block = i == len(block_out_channels) - 1
136
+
137
+ down_block = get_down_block(
138
+ down_block_type,
139
+ num_layers=layers_per_block,
140
+ in_channels=input_channel,
141
+ out_channels=output_channel,
142
+ temb_channels=time_embed_dim,
143
+ add_downsample=not is_final_block,
144
+ resnet_eps=norm_eps,
145
+ resnet_act_fn=act_fn,
146
+ resnet_groups=norm_num_groups,
147
+ cross_attention_dim=cross_attention_dim,
148
+ attn_num_head_channels=attention_head_dim[i],
149
+ downsample_padding=downsample_padding,
150
+ dual_cross_attention=dual_cross_attention,
151
+ use_linear_projection=use_linear_projection,
152
+ only_cross_attention=only_cross_attention[i],
153
+ upcast_attention=upcast_attention,
154
+ resnet_time_scale_shift=resnet_time_scale_shift,
155
+ unet_use_cross_frame_attention=unet_use_cross_frame_attention,
156
+ unet_use_temporal_attention=unet_use_temporal_attention,
157
+ use_inflated_groupnorm=use_inflated_groupnorm,
158
+ use_motion_module=use_motion_module
159
+ and (res in motion_module_resolutions)
160
+ and (not motion_module_decoder_only),
161
+ motion_module_type=motion_module_type,
162
+ motion_module_kwargs=motion_module_kwargs,
163
+ name_index=name_index, #####
164
+ )
165
+ self.down_blocks.append(down_block)
166
+
167
+ # mid
168
+
169
+ if mid_block_type == "UNetMidBlock3DCrossAttn":
170
+ self.mid_block = UNetMidBlock3DCrossAttn(
171
+ in_channels=block_out_channels[-1],
172
+ temb_channels=time_embed_dim,
173
+ resnet_eps=norm_eps,
174
+ resnet_act_fn=act_fn,
175
+ output_scale_factor=mid_block_scale_factor,
176
+ resnet_time_scale_shift=resnet_time_scale_shift,
177
+ cross_attention_dim=cross_attention_dim,
178
+ attn_num_head_channels=attention_head_dim[-1],
179
+ resnet_groups=norm_num_groups,
180
+ dual_cross_attention=dual_cross_attention,
181
+ use_linear_projection=use_linear_projection,
182
+ upcast_attention=upcast_attention,
183
+ unet_use_cross_frame_attention=unet_use_cross_frame_attention,
184
+ unet_use_temporal_attention=unet_use_temporal_attention,
185
+ use_inflated_groupnorm=use_inflated_groupnorm,
186
+ use_motion_module=use_motion_module and motion_module_mid_block,
187
+ motion_module_type=motion_module_type,
188
+ motion_module_kwargs=motion_module_kwargs,
189
+ name=mid_name,
190
+ )
191
+ else:
192
+ raise ValueError(f"unknown mid_block_type : {mid_block_type}")
193
+
194
+ # count how many layers upsample the videos
195
+ self.num_upsamplers = 0
196
+
197
+ # up
198
+ reversed_block_out_channels = list(reversed(block_out_channels))
199
+ reversed_attention_head_dim = list(reversed(attention_head_dim))
200
+ only_cross_attention = list(reversed(only_cross_attention))
201
+ output_channel = reversed_block_out_channels[0]
202
+ for i, up_block_type in enumerate(up_block_types):
203
+ res = 2 ** (3 - i)
204
+ is_final_block = i == len(block_out_channels) - 1
205
+
206
+ if task_type == "action":
207
+ name_index = None
208
+ else:
209
+ name_index = i
210
+
211
+ prev_output_channel = output_channel
212
+ output_channel = reversed_block_out_channels[i]
213
+ input_channel = reversed_block_out_channels[
214
+ min(i + 1, len(block_out_channels) - 1)
215
+ ]
216
+
217
+ # add upsample block for all BUT final layer
218
+ if not is_final_block:
219
+ add_upsample = True
220
+ self.num_upsamplers += 1
221
+ else:
222
+ add_upsample = False
223
+
224
+ up_block = get_up_block(
225
+ up_block_type,
226
+ num_layers=layers_per_block + 1,
227
+ in_channels=input_channel,
228
+ out_channels=output_channel,
229
+ prev_output_channel=prev_output_channel,
230
+ temb_channels=time_embed_dim,
231
+ add_upsample=add_upsample,
232
+ resnet_eps=norm_eps,
233
+ resnet_act_fn=act_fn,
234
+ resnet_groups=norm_num_groups,
235
+ cross_attention_dim=cross_attention_dim,
236
+ attn_num_head_channels=reversed_attention_head_dim[i],
237
+ dual_cross_attention=dual_cross_attention,
238
+ use_linear_projection=use_linear_projection,
239
+ only_cross_attention=only_cross_attention[i],
240
+ upcast_attention=upcast_attention,
241
+ resnet_time_scale_shift=resnet_time_scale_shift,
242
+ unet_use_cross_frame_attention=unet_use_cross_frame_attention,
243
+ unet_use_temporal_attention=unet_use_temporal_attention,
244
+ use_inflated_groupnorm=use_inflated_groupnorm,
245
+ use_motion_module=use_motion_module
246
+ and (res in motion_module_resolutions),
247
+ motion_module_type=motion_module_type,
248
+ motion_module_kwargs=motion_module_kwargs,
249
+ name_index=name_index,
250
+ )
251
+ self.up_blocks.append(up_block)
252
+ prev_output_channel = output_channel
253
+
254
+ # out
255
+ if use_inflated_groupnorm:
256
+ self.conv_norm_out = InflatedGroupNorm(
257
+ num_channels=block_out_channels[0],
258
+ num_groups=norm_num_groups,
259
+ eps=norm_eps,
260
+ )
261
+ else:
262
+ self.conv_norm_out = nn.GroupNorm(
263
+ num_channels=block_out_channels[0],
264
+ num_groups=norm_num_groups,
265
+ eps=norm_eps,
266
+ )
267
+ self.conv_act = nn.SiLU()
268
+ self.conv_out = InflatedConv3d(
269
+ block_out_channels[0], out_channels, kernel_size=3, padding=1
270
+ )
271
+
272
+ self.mode = mode
273
+
274
+ @property
275
+ # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors
276
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
277
+ r"""
278
+ Returns:
279
+ `dict` of attention processors: A dictionary containing all attention processors used in the model,
280
+ indexed by its weight name.
281
+ """
282
+ # set recursively
283
+ processors = {}
284
+
285
+ def fn_recursive_add_processors(
286
+ name: str,
287
+ module: torch.nn.Module,
288
+ processors: Dict[str, AttentionProcessor],
289
+ ):
290
+ if hasattr(module, "set_processor"):
291
+ processors[f"{name}.processor"] = module.processor
292
+
293
+ for sub_name, child in module.named_children():
294
+ if "temporal_transformer" not in sub_name:
295
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
296
+
297
+ return processors
298
+
299
+ for name, module in self.named_children():
300
+ if "temporal_transformer" not in name:
301
+ fn_recursive_add_processors(name, module, processors)
302
+
303
+ return processors
304
+
305
+ def set_attention_slice(self, slice_size):
306
+ r"""
307
+ Enable sliced attention computation.
308
+
309
+ When this option is enabled, the attention module splits the input tensor into slices to compute attention
310
+ in several steps. This is useful for saving some memory in exchange for a small decrease in speed.
311
+
312
+ Args:
313
+ slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
314
+ When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
315
+ `"max"`, maximum amount of memory will be saved by running only one slice at a time. If a number is
316
+ provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
317
+ must be a multiple of `slice_size`.
318
+ """
319
+ sliceable_head_dims = []
320
+
321
+ def fn_recursive_retrieve_slicable_dims(module: torch.nn.Module):
322
+ if hasattr(module, "set_attention_slice"):
323
+ sliceable_head_dims.append(module.sliceable_head_dim)
324
+
325
+ for child in module.children():
326
+ fn_recursive_retrieve_slicable_dims(child)
327
+
328
+ # retrieve number of attention layers
329
+ for module in self.children():
330
+ fn_recursive_retrieve_slicable_dims(module)
331
+
332
+ num_slicable_layers = len(sliceable_head_dims)
333
+
334
+ if slice_size == "auto":
335
+ # half the attention head size is usually a good trade-off between
336
+ # speed and memory
337
+ slice_size = [dim // 2 for dim in sliceable_head_dims]
338
+ elif slice_size == "max":
339
+ # make smallest slice possible
340
+ slice_size = num_slicable_layers * [1]
341
+
342
+ slice_size = (
343
+ num_slicable_layers * [slice_size]
344
+ if not isinstance(slice_size, list)
345
+ else slice_size
346
+ )
347
+
348
+ if len(slice_size) != len(sliceable_head_dims):
349
+ raise ValueError(
350
+ f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
351
+ f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
352
+ )
353
+
354
+ for i in range(len(slice_size)):
355
+ size = slice_size[i]
356
+ dim = sliceable_head_dims[i]
357
+ if size is not None and size > dim:
358
+ raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
359
+
360
+ # Recursively walk through all the children.
361
+ # Any children which exposes the set_attention_slice method
362
+ # gets the message
363
+ def fn_recursive_set_attention_slice(
364
+ module: torch.nn.Module, slice_size: List[int]
365
+ ):
366
+ if hasattr(module, "set_attention_slice"):
367
+ module.set_attention_slice(slice_size.pop())
368
+
369
+ for child in module.children():
370
+ fn_recursive_set_attention_slice(child, slice_size)
371
+
372
+ reversed_slice_size = list(reversed(slice_size))
373
+ for module in self.children():
374
+ fn_recursive_set_attention_slice(module, reversed_slice_size)
375
+
376
+ def _set_gradient_checkpointing(self, module, value=False):
377
+ if hasattr(module, "gradient_checkpointing"):
378
+ module.gradient_checkpointing = value
379
+
380
+ # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
381
+ def set_attn_processor(
382
+ self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]
383
+ ):
384
+ r"""
385
+ Sets the attention processor to use to compute attention.
386
+
387
+ Parameters:
388
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
389
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
390
+ for **all** `Attention` layers.
391
+
392
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
393
+ processor. This is strongly recommended when setting trainable attention processors.
394
+
395
+ """
396
+ count = len(self.attn_processors.keys())
397
+
398
+ if isinstance(processor, dict) and len(processor) != count:
399
+ raise ValueError(
400
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
401
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
402
+ )
403
+
404
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
405
+ if hasattr(module, "set_processor"):
406
+ if not isinstance(processor, dict):
407
+ module.set_processor(processor)
408
+ else:
409
+ module.set_processor(processor.pop(f"{name}.processor"))
410
+
411
+ for sub_name, child in module.named_children():
412
+ if "temporal_transformer" not in sub_name:
413
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
414
+
415
+ for name, module in self.named_children():
416
+ if "temporal_transformer" not in name:
417
+ fn_recursive_attn_processor(name, module, processor)
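+ # Note: as in `attn_processors` above, children named "temporal_transformer" are skipped,
+ # presumably so the motion-module (temporal) attention keeps its own processors when the
+ # spatial attention processors are swapped out.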
418
+
419
+ def forward(
420
+ self,
421
+ sample: torch.FloatTensor,
422
+ timestep: Union[torch.Tensor, float, int],
423
+ encoder_hidden_states: torch.Tensor,
424
+ class_labels: Optional[torch.Tensor] = None,
425
+ pose_cond_fea: Optional[torch.Tensor] = None,
426
+ attention_mask: Optional[torch.Tensor] = None,
427
+ down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
428
+ mid_block_additional_residual: Optional[torch.Tensor] = None,
429
+ return_dict: bool = True,
430
+ self_attention_additional_feats = None,
431
+ ) -> Union[UNet3DConditionOutput, Tuple]:
432
+ r"""
433
+ Args:
434
+ sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor
435
+ timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps
436
+ encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states
437
+ return_dict (`bool`, *optional*, defaults to `True`):
438
+ Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple.
439
+
440
+ Returns:
441
+ [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
442
+ [`~models.unet_2d_condition.UNet2DConditionOutput`] if `return_dict` is True, otherwise a `tuple`. When
443
+ returning a tuple, the first element is the sample tensor.
444
+ """
445
+ # By default samples have to be at least a multiple of the overall upsampling factor.
446
+ # The overall upsampling factor is equal to 2 ** (# num of upsampling layers).
447
+ # However, the upsampling interpolation output size can be forced to fit any upsampling size
448
+ # on the fly if necessary.
449
+ default_overall_up_factor = 2**self.num_upsamplers
450
+
451
+ # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
452
+ forward_upsample_size = False
453
+ upsample_size = None
454
+
455
+ if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
456
+ logger.info("Forward upsample size to force interpolation output size.")
457
+ forward_upsample_size = True
458
+
459
+ # prepare attention_mask
460
+ if attention_mask is not None:
461
+ attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
462
+ attention_mask = attention_mask.unsqueeze(1)
463
+
464
+ # center input if necessary
465
+ if self.config.center_input_sample:
466
+ sample = 2 * sample - 1.0
467
+
468
+ # time
469
+ timesteps = timestep
470
+ if not torch.is_tensor(timesteps):
471
+ # This would be a good case for the `match` statement (Python 3.10+)
472
+ is_mps = sample.device.type == "mps"
473
+ if isinstance(timestep, float):
474
+ dtype = torch.float32 if is_mps else torch.float64
475
+ else:
476
+ dtype = torch.int32 if is_mps else torch.int64
477
+ timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
478
+ elif len(timesteps.shape) == 0:
479
+ timesteps = timesteps[None].to(sample.device)
480
+
481
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
482
+ timesteps = timesteps.expand(sample.shape[0])
483
+
484
+ t_emb = self.time_proj(timesteps)
485
+
486
+ # timesteps does not contain any weights and will always return f32 tensors
487
+ # but time_embedding might actually be running in fp16. so we need to cast here.
488
+ # there might be better ways to encapsulate this.
489
+ t_emb = t_emb.to(dtype=self.dtype)
490
+ emb = self.time_embedding(t_emb)
491
+
492
+ if self.class_embedding is not None:
493
+ if class_labels is None:
494
+ raise ValueError(
495
+ "class_labels should be provided when num_class_embeds > 0"
496
+ )
497
+
498
+ if self.config.class_embed_type == "timestep":
499
+ class_labels = self.time_proj(class_labels)
500
+
501
+ class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
502
+ emb = emb + class_emb
503
+
504
+ # pre-process
505
+ sample = self.conv_in(sample)
506
+ if pose_cond_fea is not None:
507
+ sample = sample + pose_cond_fea
508
+
509
+ # down
510
+ down_block_res_samples = (sample,)
511
+ for downsample_block in self.down_blocks:
512
+ if (
513
+ hasattr(downsample_block, "has_cross_attention")
514
+ and downsample_block.has_cross_attention
515
+ ):
516
+ sample, res_samples = downsample_block(
517
+ hidden_states=sample,
518
+ temb=emb,
519
+ encoder_hidden_states=encoder_hidden_states,
520
+ attention_mask=attention_mask,
521
+ self_attention_additional_feats=self_attention_additional_feats,
522
+ mode=self.mode,
523
+ )
524
+ else:
525
+ sample, res_samples = downsample_block(
526
+ hidden_states=sample,
527
+ temb=emb,
528
+ encoder_hidden_states=encoder_hidden_states,
529
+ )
530
+
531
+ down_block_res_samples += res_samples
532
+
533
+ if down_block_additional_residuals is not None:
534
+ new_down_block_res_samples = ()
535
+
536
+ for down_block_res_sample, down_block_additional_residual in zip(
537
+ down_block_res_samples, down_block_additional_residuals
538
+ ):
539
+ down_block_additional_residual = rearrange(
540
+ down_block_additional_residual, "(b f) c h w -> b c f h w", f=down_block_res_sample.shape[2]
541
+ )
542
+ down_block_res_sample = (
543
+ down_block_res_sample + down_block_additional_residual
544
+ )
545
+ new_down_block_res_samples += (down_block_res_sample,)
546
+
547
+ down_block_res_samples = new_down_block_res_samples
548
+
549
+ # mid
550
+ sample = self.mid_block(
551
+ sample,
552
+ emb,
553
+ encoder_hidden_states=encoder_hidden_states,
554
+ attention_mask=attention_mask,
555
+ self_attention_additional_feats=self_attention_additional_feats,
556
+ mode=self.mode,
557
+ )
558
+
559
+ if mid_block_additional_residual is not None:
560
+ mid_block_additional_residual = rearrange(
561
+ mid_block_additional_residual, "(b f) c h w -> b c f h w", f=sample.shape[2]
562
+ )
563
+ sample = sample + mid_block_additional_residual
564
+
565
+ # up
566
+ for i, upsample_block in enumerate(self.up_blocks):
567
+ is_final_block = i == len(self.up_blocks) - 1
568
+
569
+ res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
570
+ down_block_res_samples = down_block_res_samples[
571
+ : -len(upsample_block.resnets)
572
+ ]
573
+
574
+ # if we have not reached the final block and need to forward the
575
+ # upsample size, we do it here
576
+ if not is_final_block and forward_upsample_size:
577
+ upsample_size = down_block_res_samples[-1].shape[2:]
578
+
579
+ if (
580
+ hasattr(upsample_block, "has_cross_attention")
581
+ and upsample_block.has_cross_attention
582
+ ):
583
+ sample = upsample_block(
584
+ hidden_states=sample,
585
+ temb=emb,
586
+ res_hidden_states_tuple=res_samples,
587
+ encoder_hidden_states=encoder_hidden_states,
588
+ upsample_size=upsample_size,
589
+ attention_mask=attention_mask,
590
+ self_attention_additional_feats=self_attention_additional_feats,
591
+ mode=self.mode,
592
+ )
593
+ else:
594
+ sample = upsample_block(
595
+ hidden_states=sample,
596
+ temb=emb,
597
+ res_hidden_states_tuple=res_samples,
598
+ upsample_size=upsample_size,
599
+ encoder_hidden_states=encoder_hidden_states,
600
+ )
601
+
602
+ # post-process
603
+ sample = self.conv_norm_out(sample)
604
+ sample = self.conv_act(sample)
605
+ sample = self.conv_out(sample)
606
+
607
+ if not return_dict:
608
+ return (sample,)
609
+
610
+ return UNet3DConditionOutput(sample=sample)
611
+
612
+ @classmethod
613
+ def from_pretrained_2d(
614
+ cls,
615
+ pretrained_model_path: PathLike,
616
+ motion_module_path: PathLike,
617
+ subfolder=None,
618
+ unet_additional_kwargs=None,
619
+ mm_zero_proj_out=False,
620
+ ):
621
+ pretrained_model_path = Path(pretrained_model_path)
622
+ motion_module_path = Path(motion_module_path)
623
+ if subfolder is not None:
624
+ pretrained_model_path = pretrained_model_path.joinpath(subfolder)
625
+ logger.info(
626
+ f"loaded temporal unet's pretrained weights from {pretrained_model_path} ..."
627
+ )
628
+
629
+ config_file = pretrained_model_path / "config.json"
630
+ if not (config_file.exists() and config_file.is_file()):
631
+ raise RuntimeError(f"{config_file} does not exist or is not a file")
632
+
633
+ unet_config = cls.load_config(config_file)
634
+ unet_config["_class_name"] = cls.__name__
635
+ unet_config["down_block_types"] = [
636
+ "CrossAttnDownBlock3D",
637
+ "CrossAttnDownBlock3D",
638
+ "CrossAttnDownBlock3D",
639
+ "DownBlock3D",
640
+ ]
641
+ unet_config["up_block_types"] = [
642
+ "UpBlock3D",
643
+ "CrossAttnUpBlock3D",
644
+ "CrossAttnUpBlock3D",
645
+ "CrossAttnUpBlock3D",
646
+ ]
647
+ unet_config["mid_block_type"] = "UNetMidBlock3DCrossAttn"
648
+
649
+ model = cls.from_config(unet_config, **unet_additional_kwargs)
650
+ # load the vanilla weights
651
+ # for key in unet_additional_kwargs['motion_module_kwargs'].keys():
652
+ # print(key)
653
+ # print(unet_additional_kwargs['motion_module_kwargs'][key])
654
+ # print('__________________')
655
+ if pretrained_model_path.joinpath(SAFETENSORS_WEIGHTS_NAME).exists():
656
+ logger.debug(
657
+ f"loading safeTensors weights from {pretrained_model_path} ..."
658
+ )
659
+ state_dict = load_file(
660
+ pretrained_model_path.joinpath(SAFETENSORS_WEIGHTS_NAME), device="cpu"
661
+ )
662
+
663
+ elif pretrained_model_path.joinpath(WEIGHTS_NAME).exists():
664
+ logger.debug(f"loading weights from {pretrained_model_path} ...")
665
+ state_dict = torch.load(
666
+ pretrained_model_path.joinpath(WEIGHTS_NAME),
667
+ map_location="cpu",
668
+ weights_only=True,
669
+ )
670
+ else:
671
+ raise FileNotFoundError(f"no weights file found in {pretrained_model_path}")
672
+
673
+ # load the motion module weights
674
+ if motion_module_path.exists() and motion_module_path.is_file():
675
+ if motion_module_path.suffix.lower() in [".pth", ".pt", ".ckpt"]:
676
+ logger.info(f"Load motion module params from {motion_module_path}")
677
+ motion_state_dict = torch.load(
678
+ motion_module_path, map_location="cpu", weights_only=True
679
+ )
680
+ elif motion_module_path.suffix.lower() == ".safetensors":
681
+ motion_state_dict = load_file(motion_module_path, device="cpu")
682
+ else:
683
+ raise RuntimeError(
684
+ f"unknown file format for motion module weights: {motion_module_path.suffix}"
685
+ )
686
+ if mm_zero_proj_out:
687
+ logger.info(f"Zero initialize proj_out layers in motion module...")
688
+ new_motion_state_dict = OrderedDict()
689
+ for k in motion_state_dict:
690
+ if "proj_out" in k:
691
+ continue
692
+ new_motion_state_dict[k] = motion_state_dict[k]
693
+ motion_state_dict = new_motion_state_dict
694
+
695
+ # merge the state dicts
696
+ state_dict.update(motion_state_dict)
697
+
698
+ # load the weights into the model
699
+ m, u = model.load_state_dict(state_dict, strict=False)
700
+ logger.debug(f"### missing keys: {len(m)}; \n### unexpected keys: {len(u)};")
701
+
702
+ params = [
703
+ p.numel() if "temporal" in n else 0 for n, p in model.named_parameters()
704
+ ]
705
+ logger.info(f"Loaded {sum(params) / 1e6}M-parameter motion module")
706
+
707
+ return model
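
A hedged usage sketch of the `from_pretrained_2d` loader above. The class name `UNet3DConditionModel`, the checkpoint paths, and the contents of `unet_additional_kwargs` are placeholders chosen for illustration; only the call shape follows the method signature in this diff.

import torch
from src.models.unet_3d import UNet3DConditionModel  # assumed class name for this module

denoising_unet = UNet3DConditionModel.from_pretrained_2d(
    pretrained_model_path="./pretrained_weights/stable-diffusion-v1-5",  # hypothetical path
    motion_module_path="./pretrained_weights/mm_sd_v15_v2.ckpt",         # hypothetical path
    subfolder="unet",
    unet_additional_kwargs={},  # must be a dict: it is unpacked into from_config, so the default None would fail
).to("cuda", dtype=torch.float16)

Passing a dict (even an empty one) for unet_additional_kwargs matters because the method unpacks it with `**` when calling `from_config`.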
src/models/unet_3d_blocks.py ADDED
@@ -0,0 +1,906 @@
1
+ # Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/unet_2d_blocks.py
2
+
3
+ import pdb
4
+
5
+ import torch
6
+ from torch import nn
7
+
8
+ from .motion_module import get_motion_module
9
+
10
+ # from .motion_module import get_motion_module
11
+ from .resnet import Downsample3D, ResnetBlock3D, Upsample3D
12
+ from .transformer_3d import Transformer3DModel
13
+
14
+
15
+ def get_down_block(
16
+ down_block_type,
17
+ num_layers,
18
+ in_channels,
19
+ out_channels,
20
+ temb_channels,
21
+ add_downsample,
22
+ resnet_eps,
23
+ resnet_act_fn,
24
+ attn_num_head_channels,
25
+ resnet_groups=None,
26
+ cross_attention_dim=None,
27
+ downsample_padding=None,
28
+ dual_cross_attention=False,
29
+ use_linear_projection=False,
30
+ only_cross_attention=False,
31
+ upcast_attention=False,
32
+ resnet_time_scale_shift="default",
33
+ unet_use_cross_frame_attention=None,
34
+ unet_use_temporal_attention=None,
35
+ use_inflated_groupnorm=None,
36
+ use_motion_module=None,
37
+ motion_module_type=None,
38
+ motion_module_kwargs=None,
39
+ name_index=None,
40
+ ):
41
+ down_block_type = (
42
+ down_block_type[7:]
43
+ if down_block_type.startswith("UNetRes")
44
+ else down_block_type
45
+ )
46
+ if down_block_type == "DownBlock3D":
47
+ return DownBlock3D(
48
+ num_layers=num_layers,
49
+ in_channels=in_channels,
50
+ out_channels=out_channels,
51
+ temb_channels=temb_channels,
52
+ add_downsample=add_downsample,
53
+ resnet_eps=resnet_eps,
54
+ resnet_act_fn=resnet_act_fn,
55
+ resnet_groups=resnet_groups,
56
+ downsample_padding=downsample_padding,
57
+ resnet_time_scale_shift=resnet_time_scale_shift,
58
+ use_inflated_groupnorm=use_inflated_groupnorm,
59
+ use_motion_module=use_motion_module,
60
+ motion_module_type=motion_module_type,
61
+ motion_module_kwargs=motion_module_kwargs,
62
+ )
63
+ elif down_block_type == "CrossAttnDownBlock3D":
64
+ if cross_attention_dim is None:
65
+ raise ValueError(
66
+ "cross_attention_dim must be specified for CrossAttnDownBlock3D"
67
+ )
68
+ if name_index is not None:
69
+ name_index = f"CrossAttnDownBlock_{name_index}_"
70
+ return CrossAttnDownBlock3D(
71
+ num_layers=num_layers,
72
+ in_channels=in_channels,
73
+ out_channels=out_channels,
74
+ temb_channels=temb_channels,
75
+ add_downsample=add_downsample,
76
+ resnet_eps=resnet_eps,
77
+ resnet_act_fn=resnet_act_fn,
78
+ resnet_groups=resnet_groups,
79
+ downsample_padding=downsample_padding,
80
+ cross_attention_dim=cross_attention_dim,
81
+ attn_num_head_channels=attn_num_head_channels,
82
+ dual_cross_attention=dual_cross_attention,
83
+ use_linear_projection=use_linear_projection,
84
+ only_cross_attention=only_cross_attention,
85
+ upcast_attention=upcast_attention,
86
+ resnet_time_scale_shift=resnet_time_scale_shift,
87
+ unet_use_cross_frame_attention=unet_use_cross_frame_attention,
88
+ unet_use_temporal_attention=unet_use_temporal_attention,
89
+ use_inflated_groupnorm=use_inflated_groupnorm,
90
+ use_motion_module=use_motion_module,
91
+ motion_module_type=motion_module_type,
92
+ motion_module_kwargs=motion_module_kwargs,
93
+ name=name_index,
94
+ )
95
+ raise ValueError(f"{down_block_type} does not exist.")
96
+
97
+
98
+ def get_up_block(
99
+ up_block_type,
100
+ num_layers,
101
+ in_channels,
102
+ out_channels,
103
+ prev_output_channel,
104
+ temb_channels,
105
+ add_upsample,
106
+ resnet_eps,
107
+ resnet_act_fn,
108
+ attn_num_head_channels,
109
+ resnet_groups=None,
110
+ cross_attention_dim=None,
111
+ dual_cross_attention=False,
112
+ use_linear_projection=False,
113
+ only_cross_attention=False,
114
+ upcast_attention=False,
115
+ resnet_time_scale_shift="default",
116
+ unet_use_cross_frame_attention=None,
117
+ unet_use_temporal_attention=None,
118
+ use_inflated_groupnorm=None,
119
+ use_motion_module=None,
120
+ motion_module_type=None,
121
+ motion_module_kwargs=None,
122
+ name_index=None,
123
+ ):
124
+ up_block_type = (
125
+ up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type
126
+ )
127
+ if up_block_type == "UpBlock3D":
128
+ return UpBlock3D(
129
+ num_layers=num_layers,
130
+ in_channels=in_channels,
131
+ out_channels=out_channels,
132
+ prev_output_channel=prev_output_channel,
133
+ temb_channels=temb_channels,
134
+ add_upsample=add_upsample,
135
+ resnet_eps=resnet_eps,
136
+ resnet_act_fn=resnet_act_fn,
137
+ resnet_groups=resnet_groups,
138
+ resnet_time_scale_shift=resnet_time_scale_shift,
139
+ use_inflated_groupnorm=use_inflated_groupnorm,
140
+ use_motion_module=use_motion_module,
141
+ motion_module_type=motion_module_type,
142
+ motion_module_kwargs=motion_module_kwargs,
143
+ )
144
+ elif up_block_type == "CrossAttnUpBlock3D":
145
+ if cross_attention_dim is None:
146
+ raise ValueError(
147
+ "cross_attention_dim must be specified for CrossAttnUpBlock3D"
148
+ )
149
+ if name_index is not None:
150
+ name_index = f"CrossAttnUpBlock_{name_index}_"
151
+ return CrossAttnUpBlock3D(
152
+ num_layers=num_layers,
153
+ in_channels=in_channels,
154
+ out_channels=out_channels,
155
+ prev_output_channel=prev_output_channel,
156
+ temb_channels=temb_channels,
157
+ add_upsample=add_upsample,
158
+ resnet_eps=resnet_eps,
159
+ resnet_act_fn=resnet_act_fn,
160
+ resnet_groups=resnet_groups,
161
+ cross_attention_dim=cross_attention_dim,
162
+ attn_num_head_channels=attn_num_head_channels,
163
+ dual_cross_attention=dual_cross_attention,
164
+ use_linear_projection=use_linear_projection,
165
+ only_cross_attention=only_cross_attention,
166
+ upcast_attention=upcast_attention,
167
+ resnet_time_scale_shift=resnet_time_scale_shift,
168
+ unet_use_cross_frame_attention=unet_use_cross_frame_attention,
169
+ unet_use_temporal_attention=unet_use_temporal_attention,
170
+ use_inflated_groupnorm=use_inflated_groupnorm,
171
+ use_motion_module=use_motion_module,
172
+ motion_module_type=motion_module_type,
173
+ motion_module_kwargs=motion_module_kwargs,
174
+ name=name_index,
175
+ )
176
+ raise ValueError(f"{up_block_type} does not exist.")
177
+
178
+
179
+ class UNetMidBlock3DCrossAttn(nn.Module):
180
+ def __init__(
181
+ self,
182
+ in_channels: int,
183
+ temb_channels: int,
184
+ dropout: float = 0.0,
185
+ num_layers: int = 1,
186
+ resnet_eps: float = 1e-6,
187
+ resnet_time_scale_shift: str = "default",
188
+ resnet_act_fn: str = "swish",
189
+ resnet_groups: int = 32,
190
+ resnet_pre_norm: bool = True,
191
+ attn_num_head_channels=1,
192
+ output_scale_factor=1.0,
193
+ cross_attention_dim=1280,
194
+ dual_cross_attention=False,
195
+ use_linear_projection=False,
196
+ upcast_attention=False,
197
+ unet_use_cross_frame_attention=None,
198
+ unet_use_temporal_attention=None,
199
+ use_inflated_groupnorm=None,
200
+ use_motion_module=None,
201
+ motion_module_type=None,
202
+ motion_module_kwargs=None,
203
+ name=None
204
+ ):
205
+ super().__init__()
206
+
207
+ self.has_cross_attention = True
208
+ self.attn_num_head_channels = attn_num_head_channels
209
+ resnet_groups = (
210
+ resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
211
+ )
212
+ self.name = name
213
+ # there is always at least one resnet
214
+ resnets = [
215
+ ResnetBlock3D(
216
+ in_channels=in_channels,
217
+ out_channels=in_channels,
218
+ temb_channels=temb_channels,
219
+ eps=resnet_eps,
220
+ groups=resnet_groups,
221
+ dropout=dropout,
222
+ time_embedding_norm=resnet_time_scale_shift,
223
+ non_linearity=resnet_act_fn,
224
+ output_scale_factor=output_scale_factor,
225
+ pre_norm=resnet_pre_norm,
226
+ use_inflated_groupnorm=use_inflated_groupnorm,
227
+ )
228
+ ]
229
+ attentions = []
230
+ motion_modules = []
231
+ for i in range(num_layers):
232
+ if dual_cross_attention:
233
+ raise NotImplementedError
234
+ if self.name is not None:
235
+ attn_name = f"{self.name}_{i}_TransformerModel"
236
+ else:
237
+ attn_name = None
238
+ attentions.append(
239
+ Transformer3DModel(
240
+ attn_num_head_channels,
241
+ in_channels // attn_num_head_channels,
242
+ in_channels=in_channels,
243
+ num_layers=1,
244
+ cross_attention_dim=cross_attention_dim,
245
+ norm_num_groups=resnet_groups,
246
+ use_linear_projection=use_linear_projection,
247
+ upcast_attention=upcast_attention,
248
+ unet_use_cross_frame_attention=unet_use_cross_frame_attention,
249
+ unet_use_temporal_attention=unet_use_temporal_attention,
250
+ name=attn_name,
251
+ )
252
+ )
253
+ motion_modules.append(
254
+ get_motion_module(
255
+ in_channels=in_channels,
256
+ motion_module_type=motion_module_type,
257
+ motion_module_kwargs=motion_module_kwargs,
258
+ )
259
+ if use_motion_module
260
+ else None
261
+ )
262
+ resnets.append(
263
+ ResnetBlock3D(
264
+ in_channels=in_channels,
265
+ out_channels=in_channels,
266
+ temb_channels=temb_channels,
267
+ eps=resnet_eps,
268
+ groups=resnet_groups,
269
+ dropout=dropout,
270
+ time_embedding_norm=resnet_time_scale_shift,
271
+ non_linearity=resnet_act_fn,
272
+ output_scale_factor=output_scale_factor,
273
+ pre_norm=resnet_pre_norm,
274
+ use_inflated_groupnorm=use_inflated_groupnorm,
275
+ )
276
+ )
277
+
278
+ self.attentions = nn.ModuleList(attentions)
279
+ self.resnets = nn.ModuleList(resnets)
280
+ self.motion_modules = nn.ModuleList(motion_modules)
281
+
282
+ def forward(
283
+ self,
284
+ hidden_states,
285
+ temb=None,
286
+ encoder_hidden_states=None,
287
+ attention_mask=None,
288
+ self_attention_additional_feats=None,
289
+ mode=None,
290
+ ):
291
+ hidden_states = self.resnets[0](hidden_states, temb)
292
+ for attn, resnet, motion_module in zip(
293
+ self.attentions, self.resnets[1:], self.motion_modules
294
+ ):
295
+ hidden_states = attn(
296
+ hidden_states,
297
+ encoder_hidden_states=encoder_hidden_states,
298
+ self_attention_additional_feats=self_attention_additional_feats,
299
+ mode=mode,
300
+ ).sample
301
+ hidden_states = (
302
+ motion_module(
303
+ hidden_states, temb, encoder_hidden_states=encoder_hidden_states
304
+ )
305
+ if motion_module is not None
306
+ else hidden_states
307
+ )
308
+ hidden_states = resnet(hidden_states, temb)
309
+
310
+ return hidden_states
311
+
312
+
313
+ class CrossAttnDownBlock3D(nn.Module):
314
+ def __init__(
315
+ self,
316
+ in_channels: int,
317
+ out_channels: int,
318
+ temb_channels: int,
319
+ dropout: float = 0.0,
320
+ num_layers: int = 1,
321
+ resnet_eps: float = 1e-6,
322
+ resnet_time_scale_shift: str = "default",
323
+ resnet_act_fn: str = "swish",
324
+ resnet_groups: int = 32,
325
+ resnet_pre_norm: bool = True,
326
+ attn_num_head_channels=1,
327
+ cross_attention_dim=1280,
328
+ output_scale_factor=1.0,
329
+ downsample_padding=1,
330
+ add_downsample=True,
331
+ dual_cross_attention=False,
332
+ use_linear_projection=False,
333
+ only_cross_attention=False,
334
+ upcast_attention=False,
335
+ unet_use_cross_frame_attention=None,
336
+ unet_use_temporal_attention=None,
337
+ use_inflated_groupnorm=None,
338
+ use_motion_module=None,
339
+ motion_module_type=None,
340
+ motion_module_kwargs=None,
341
+ name=None,
342
+ ):
343
+ super().__init__()
344
+ resnets = []
345
+ attentions = []
346
+ motion_modules = []
347
+
348
+ self.has_cross_attention = True
349
+ self.attn_num_head_channels = attn_num_head_channels
350
+ self.name=name
351
+
352
+ for i in range(num_layers):
353
+ in_channels = in_channels if i == 0 else out_channels
354
+ resnets.append(
355
+ ResnetBlock3D(
356
+ in_channels=in_channels,
357
+ out_channels=out_channels,
358
+ temb_channels=temb_channels,
359
+ eps=resnet_eps,
360
+ groups=resnet_groups,
361
+ dropout=dropout,
362
+ time_embedding_norm=resnet_time_scale_shift,
363
+ non_linearity=resnet_act_fn,
364
+ output_scale_factor=output_scale_factor,
365
+ pre_norm=resnet_pre_norm,
366
+ use_inflated_groupnorm=use_inflated_groupnorm,
367
+ )
368
+ )
369
+ if dual_cross_attention:
370
+ raise NotImplementedError
371
+ if self.name is not None:
372
+ attn_name = f"{self.name}_{i}_TransformerModel"
373
+ else:
374
+ attn_name = None
375
+ attentions.append(
376
+ Transformer3DModel(
377
+ attn_num_head_channels,
378
+ out_channels // attn_num_head_channels,
379
+ in_channels=out_channels,
380
+ num_layers=1,
381
+ cross_attention_dim=cross_attention_dim,
382
+ norm_num_groups=resnet_groups,
383
+ use_linear_projection=use_linear_projection,
384
+ only_cross_attention=only_cross_attention,
385
+ upcast_attention=upcast_attention,
386
+ unet_use_cross_frame_attention=unet_use_cross_frame_attention,
387
+ unet_use_temporal_attention=unet_use_temporal_attention,
388
+ name=attn_name,
389
+ )
390
+ )
391
+ motion_modules.append(
392
+ get_motion_module(
393
+ in_channels=out_channels,
394
+ motion_module_type=motion_module_type,
395
+ motion_module_kwargs=motion_module_kwargs,
396
+ )
397
+ if use_motion_module
398
+ else None
399
+ )
400
+
401
+ self.attentions = nn.ModuleList(attentions)
402
+ self.resnets = nn.ModuleList(resnets)
403
+ self.motion_modules = nn.ModuleList(motion_modules)
404
+
405
+ if add_downsample:
406
+ self.downsamplers = nn.ModuleList(
407
+ [
408
+ Downsample3D(
409
+ out_channels,
410
+ use_conv=True,
411
+ out_channels=out_channels,
412
+ padding=downsample_padding,
413
+ name="op",
414
+ )
415
+ ]
416
+ )
417
+ else:
418
+ self.downsamplers = None
419
+
420
+ self.gradient_checkpointing = False
421
+
422
+ def forward(
423
+ self,
424
+ hidden_states,
425
+ temb=None,
426
+ encoder_hidden_states=None,
427
+ attention_mask=None,
428
+ self_attention_additional_feats=None,
429
+ mode=None,
430
+ ):
431
+ output_states = ()
432
+
433
+ for i, (resnet, attn, motion_module) in enumerate(
434
+ zip(self.resnets, self.attentions, self.motion_modules)
435
+ ):
436
+ # self.gradient_checkpointing = False
437
+ if self.training and self.gradient_checkpointing:
438
+
439
+ def create_custom_forward(module, return_dict=None):
440
+ def custom_forward(*inputs):
441
+ if return_dict is not None:
442
+ return module(*inputs, return_dict=return_dict)
443
+ else:
444
+ return module(*inputs)
445
+
446
+ return custom_forward
447
+
448
+ hidden_states = torch.utils.checkpoint.checkpoint(
449
+ create_custom_forward(resnet), hidden_states, temb
450
+ )
451
+ hidden_states = torch.utils.checkpoint.checkpoint(
452
+ create_custom_forward(attn, return_dict=False),
453
+ hidden_states,
454
+ encoder_hidden_states,
455
+ self_attention_additional_feats,
456
+ mode,
457
+ )[0]
458
+
459
+ # add motion module
460
+ if motion_module is not None:
461
+ hidden_states = torch.utils.checkpoint.checkpoint(
462
+ create_custom_forward(motion_module),
463
+ hidden_states.requires_grad_(),
464
+ temb,
465
+ encoder_hidden_states,
466
+ )
467
+
468
+ else:
469
+ hidden_states = resnet(hidden_states, temb)
470
+ hidden_states = attn(
471
+ hidden_states,
472
+ encoder_hidden_states=encoder_hidden_states,
473
+ self_attention_additional_feats=self_attention_additional_feats,
474
+ mode=mode,
475
+ ).sample
476
+
477
+ # add motion module
478
+ hidden_states = (
479
+ motion_module(
480
+ hidden_states, temb, encoder_hidden_states=encoder_hidden_states
481
+ )
482
+ if motion_module is not None
483
+ else hidden_states
484
+ )
485
+
486
+ output_states += (hidden_states,)
487
+
488
+ if self.downsamplers is not None:
489
+ for downsampler in self.downsamplers:
490
+ hidden_states = downsampler(hidden_states)
491
+
492
+ output_states += (hidden_states,)
493
+
494
+ return hidden_states, output_states
495
+
496
+
497
+ class DownBlock3D(nn.Module):
498
+ def __init__(
499
+ self,
500
+ in_channels: int,
501
+ out_channels: int,
502
+ temb_channels: int,
503
+ dropout: float = 0.0,
504
+ num_layers: int = 1,
505
+ resnet_eps: float = 1e-6,
506
+ resnet_time_scale_shift: str = "default",
507
+ resnet_act_fn: str = "swish",
508
+ resnet_groups: int = 32,
509
+ resnet_pre_norm: bool = True,
510
+ output_scale_factor=1.0,
511
+ add_downsample=True,
512
+ downsample_padding=1,
513
+ use_inflated_groupnorm=None,
514
+ use_motion_module=None,
515
+ motion_module_type=None,
516
+ motion_module_kwargs=None,
517
+ ):
518
+ super().__init__()
519
+ resnets = []
520
+ motion_modules = []
521
+
522
+ # use_motion_module = False
523
+ for i in range(num_layers):
524
+ in_channels = in_channels if i == 0 else out_channels
525
+ resnets.append(
526
+ ResnetBlock3D(
527
+ in_channels=in_channels,
528
+ out_channels=out_channels,
529
+ temb_channels=temb_channels,
530
+ eps=resnet_eps,
531
+ groups=resnet_groups,
532
+ dropout=dropout,
533
+ time_embedding_norm=resnet_time_scale_shift,
534
+ non_linearity=resnet_act_fn,
535
+ output_scale_factor=output_scale_factor,
536
+ pre_norm=resnet_pre_norm,
537
+ use_inflated_groupnorm=use_inflated_groupnorm,
538
+ )
539
+ )
540
+ motion_modules.append(
541
+ get_motion_module(
542
+ in_channels=out_channels,
543
+ motion_module_type=motion_module_type,
544
+ motion_module_kwargs=motion_module_kwargs,
545
+ )
546
+ if use_motion_module
547
+ else None
548
+ )
549
+
550
+ self.resnets = nn.ModuleList(resnets)
551
+ self.motion_modules = nn.ModuleList(motion_modules)
552
+
553
+ if add_downsample:
554
+ self.downsamplers = nn.ModuleList(
555
+ [
556
+ Downsample3D(
557
+ out_channels,
558
+ use_conv=True,
559
+ out_channels=out_channels,
560
+ padding=downsample_padding,
561
+ name="op",
562
+ )
563
+ ]
564
+ )
565
+ else:
566
+ self.downsamplers = None
567
+
568
+ self.gradient_checkpointing = False
569
+
570
+ def forward(self, hidden_states, temb=None, encoder_hidden_states=None):
571
+ output_states = ()
572
+
573
+ for resnet, motion_module in zip(self.resnets, self.motion_modules):
574
+ # print(f"DownBlock3D {self.gradient_checkpointing = }")
575
+ if self.training and self.gradient_checkpointing:
576
+
577
+ def create_custom_forward(module):
578
+ def custom_forward(*inputs):
579
+ return module(*inputs)
580
+
581
+ return custom_forward
582
+
583
+ hidden_states = torch.utils.checkpoint.checkpoint(
584
+ create_custom_forward(resnet), hidden_states, temb
585
+ )
586
+ if motion_module is not None:
587
+ hidden_states = torch.utils.checkpoint.checkpoint(
588
+ create_custom_forward(motion_module),
589
+ hidden_states.requires_grad_(),
590
+ temb,
591
+ encoder_hidden_states,
592
+ )
593
+ else:
594
+ hidden_states = resnet(hidden_states, temb)
595
+
596
+ # add motion module
597
+ hidden_states = (
598
+ motion_module(
599
+ hidden_states, temb, encoder_hidden_states=encoder_hidden_states
600
+ )
601
+ if motion_module is not None
602
+ else hidden_states
603
+ )
604
+
605
+ output_states += (hidden_states,)
606
+
607
+ if self.downsamplers is not None:
608
+ for downsampler in self.downsamplers:
609
+ hidden_states = downsampler(hidden_states)
610
+
611
+ output_states += (hidden_states,)
612
+
613
+ return hidden_states, output_states
614
+
615
+
616
+ class CrossAttnUpBlock3D(nn.Module):
617
+ def __init__(
618
+ self,
619
+ in_channels: int,
620
+ out_channels: int,
621
+ prev_output_channel: int,
622
+ temb_channels: int,
623
+ dropout: float = 0.0,
624
+ num_layers: int = 1,
625
+ resnet_eps: float = 1e-6,
626
+ resnet_time_scale_shift: str = "default",
627
+ resnet_act_fn: str = "swish",
628
+ resnet_groups: int = 32,
629
+ resnet_pre_norm: bool = True,
630
+ attn_num_head_channels=1,
631
+ cross_attention_dim=1280,
632
+ output_scale_factor=1.0,
633
+ add_upsample=True,
634
+ dual_cross_attention=False,
635
+ use_linear_projection=False,
636
+ only_cross_attention=False,
637
+ upcast_attention=False,
638
+ unet_use_cross_frame_attention=None,
639
+ unet_use_temporal_attention=None,
640
+ use_motion_module=None,
641
+ use_inflated_groupnorm=None,
642
+ motion_module_type=None,
643
+ motion_module_kwargs=None,
644
+ name=None
645
+ ):
646
+ super().__init__()
647
+ resnets = []
648
+ attentions = []
649
+ motion_modules = []
650
+
651
+ self.has_cross_attention = True
652
+ self.attn_num_head_channels = attn_num_head_channels
653
+ self.name = name
654
+
655
+ for i in range(num_layers):
656
+ res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
657
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
658
+
659
+ resnets.append(
660
+ ResnetBlock3D(
661
+ in_channels=resnet_in_channels + res_skip_channels,
662
+ out_channels=out_channels,
663
+ temb_channels=temb_channels,
664
+ eps=resnet_eps,
665
+ groups=resnet_groups,
666
+ dropout=dropout,
667
+ time_embedding_norm=resnet_time_scale_shift,
668
+ non_linearity=resnet_act_fn,
669
+ output_scale_factor=output_scale_factor,
670
+ pre_norm=resnet_pre_norm,
671
+ use_inflated_groupnorm=use_inflated_groupnorm,
672
+ )
673
+ )
674
+ if dual_cross_attention:
675
+ raise NotImplementedError
676
+ if self.name is not None:
677
+ attn_name = f"{self.name}_{i}_TransformerModel"
678
+ else:
679
+ attn_name = None
680
+ attentions.append(
681
+ Transformer3DModel(
682
+ attn_num_head_channels,
683
+ out_channels // attn_num_head_channels,
684
+ in_channels=out_channels,
685
+ num_layers=1,
686
+ cross_attention_dim=cross_attention_dim,
687
+ norm_num_groups=resnet_groups,
688
+ use_linear_projection=use_linear_projection,
689
+ only_cross_attention=only_cross_attention,
690
+ upcast_attention=upcast_attention,
691
+ unet_use_cross_frame_attention=unet_use_cross_frame_attention,
692
+ unet_use_temporal_attention=unet_use_temporal_attention,
693
+ name=attn_name,
694
+ )
695
+ )
696
+ motion_modules.append(
697
+ get_motion_module(
698
+ in_channels=out_channels,
699
+ motion_module_type=motion_module_type,
700
+ motion_module_kwargs=motion_module_kwargs,
701
+ )
702
+ if use_motion_module
703
+ else None
704
+ )
705
+
706
+ self.attentions = nn.ModuleList(attentions)
707
+ self.resnets = nn.ModuleList(resnets)
708
+ self.motion_modules = nn.ModuleList(motion_modules)
709
+
710
+ if add_upsample:
711
+ self.upsamplers = nn.ModuleList(
712
+ [Upsample3D(out_channels, use_conv=True, out_channels=out_channels)]
713
+ )
714
+ else:
715
+ self.upsamplers = None
716
+
717
+ self.gradient_checkpointing = False
718
+
719
+ def forward(
720
+ self,
721
+ hidden_states,
722
+ res_hidden_states_tuple,
723
+ temb=None,
724
+ encoder_hidden_states=None,
725
+ upsample_size=None,
726
+ attention_mask=None,
727
+ self_attention_additional_feats=None,
728
+ mode=None,
729
+ ):
730
+ for i, (resnet, attn, motion_module) in enumerate(
731
+ zip(self.resnets, self.attentions, self.motion_modules)
732
+ ):
733
+ # pop res hidden states
734
+ res_hidden_states = res_hidden_states_tuple[-1]
735
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
736
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
737
+
738
+ if self.training and self.gradient_checkpointing:
739
+
740
+ def create_custom_forward(module, return_dict=None):
741
+ def custom_forward(*inputs):
742
+ if return_dict is not None:
743
+ return module(*inputs, return_dict=return_dict)
744
+ else:
745
+ return module(*inputs)
746
+
747
+ return custom_forward
748
+
749
+ hidden_states = torch.utils.checkpoint.checkpoint(
750
+ create_custom_forward(resnet), hidden_states, temb
751
+ )
752
+ hidden_states = torch.utils.checkpoint.checkpoint(
753
+ create_custom_forward(attn, return_dict=False),
754
+ hidden_states,
755
+ encoder_hidden_states,
756
+ self_attention_additional_feats,
757
+ mode,
758
+ )[0]
759
+ if motion_module is not None:
760
+ hidden_states = torch.utils.checkpoint.checkpoint(
761
+ create_custom_forward(motion_module),
762
+ hidden_states.requires_grad_(),
763
+ temb,
764
+ encoder_hidden_states,
765
+ )
766
+
767
+ else:
768
+ hidden_states = resnet(hidden_states, temb)
769
+ hidden_states = attn(
770
+ hidden_states,
771
+ encoder_hidden_states=encoder_hidden_states,
772
+ self_attention_additional_feats=self_attention_additional_feats,
773
+ mode=mode,
774
+ ).sample
775
+
776
+ # add motion module
777
+ hidden_states = (
778
+ motion_module(
779
+ hidden_states, temb, encoder_hidden_states=encoder_hidden_states
780
+ )
781
+ if motion_module is not None
782
+ else hidden_states
783
+ )
784
+
785
+ if self.upsamplers is not None:
786
+ for upsampler in self.upsamplers:
787
+ hidden_states = upsampler(hidden_states, upsample_size)
788
+
789
+ return hidden_states
790
+
791
+
792
+ class UpBlock3D(nn.Module):
793
+ def __init__(
794
+ self,
795
+ in_channels: int,
796
+ prev_output_channel: int,
797
+ out_channels: int,
798
+ temb_channels: int,
799
+ dropout: float = 0.0,
800
+ num_layers: int = 1,
801
+ resnet_eps: float = 1e-6,
802
+ resnet_time_scale_shift: str = "default",
803
+ resnet_act_fn: str = "swish",
804
+ resnet_groups: int = 32,
805
+ resnet_pre_norm: bool = True,
806
+ output_scale_factor=1.0,
807
+ add_upsample=True,
808
+ use_inflated_groupnorm=None,
809
+ use_motion_module=None,
810
+ motion_module_type=None,
811
+ motion_module_kwargs=None,
812
+ ):
813
+ super().__init__()
814
+ resnets = []
815
+ motion_modules = []
816
+
817
+ # use_motion_module = False
818
+ for i in range(num_layers):
819
+ res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
820
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
821
+
822
+ resnets.append(
823
+ ResnetBlock3D(
824
+ in_channels=resnet_in_channels + res_skip_channels,
825
+ out_channels=out_channels,
826
+ temb_channels=temb_channels,
827
+ eps=resnet_eps,
828
+ groups=resnet_groups,
829
+ dropout=dropout,
830
+ time_embedding_norm=resnet_time_scale_shift,
831
+ non_linearity=resnet_act_fn,
832
+ output_scale_factor=output_scale_factor,
833
+ pre_norm=resnet_pre_norm,
834
+ use_inflated_groupnorm=use_inflated_groupnorm,
835
+ )
836
+ )
837
+ motion_modules.append(
838
+ get_motion_module(
839
+ in_channels=out_channels,
840
+ motion_module_type=motion_module_type,
841
+ motion_module_kwargs=motion_module_kwargs,
842
+ )
843
+ if use_motion_module
844
+ else None
845
+ )
846
+
847
+ self.resnets = nn.ModuleList(resnets)
848
+ self.motion_modules = nn.ModuleList(motion_modules)
849
+
850
+ if add_upsample:
851
+ self.upsamplers = nn.ModuleList(
852
+ [Upsample3D(out_channels, use_conv=True, out_channels=out_channels)]
853
+ )
854
+ else:
855
+ self.upsamplers = None
856
+
857
+ self.gradient_checkpointing = False
858
+
859
+ def forward(
860
+ self,
861
+ hidden_states,
862
+ res_hidden_states_tuple,
863
+ temb=None,
864
+ upsample_size=None,
865
+ encoder_hidden_states=None,
866
+ ):
867
+ for resnet, motion_module in zip(self.resnets, self.motion_modules):
868
+ # pop res hidden states
869
+ res_hidden_states = res_hidden_states_tuple[-1]
870
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
871
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
872
+
873
+ # print(f"UpBlock3D {self.gradient_checkpointing = }")
874
+ if self.training and self.gradient_checkpointing:
875
+
876
+ def create_custom_forward(module):
877
+ def custom_forward(*inputs):
878
+ return module(*inputs)
879
+
880
+ return custom_forward
881
+
882
+ hidden_states = torch.utils.checkpoint.checkpoint(
883
+ create_custom_forward(resnet), hidden_states, temb
884
+ )
885
+ if motion_module is not None:
886
+ hidden_states = torch.utils.checkpoint.checkpoint(
887
+ create_custom_forward(motion_module),
888
+ hidden_states.requires_grad_(),
889
+ temb,
890
+ encoder_hidden_states,
891
+ )
892
+ else:
893
+ hidden_states = resnet(hidden_states, temb)
894
+ hidden_states = (
895
+ motion_module(
896
+ hidden_states, temb, encoder_hidden_states=encoder_hidden_states
897
+ )
898
+ if motion_module is not None
899
+ else hidden_states
900
+ )
901
+
902
+ if self.upsamplers is not None:
903
+ for upsampler in self.upsamplers:
904
+ hidden_states = upsampler(hidden_states, upsample_size)
905
+
906
+ return hidden_states
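
A hedged shape check for the blocks above, using `DownBlock3D` because it has no cross-attention dependencies. The channel counts, group count, and the `use_inflated_groupnorm` / `use_motion_module` flags are assumptions picked to satisfy the constructors, not values read from any config in this repo.

import torch
from src.models.unet_3d_blocks import DownBlock3D

block = DownBlock3D(
    in_channels=320,
    out_channels=320,
    temb_channels=1280,
    num_layers=2,
    resnet_eps=1e-5,
    resnet_act_fn="silu",
    resnet_groups=32,
    add_downsample=True,
    use_inflated_groupnorm=True,  # assumption: the resnet blocks expect an explicit value here
    use_motion_module=False,      # skip the temporal module for this sketch
)

x = torch.randn(1, 320, 8, 64, 64)   # (batch, channels, frames, height, width)
temb = torch.randn(1, 1280)          # timestep embedding
hidden, skips = block(x, temb)
print(hidden.shape)                  # expected: torch.Size([1, 320, 8, 32, 32]); the frame axis is untouched
print(len(skips))                    # 3: one tensor per resnet layer plus the downsampled output

The point of the sketch is that the 3D blocks only change channels and spatial size; the frame dimension passes through unchanged and only the motion modules (disabled here) mix information across frames.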
src/pipelines/__init__.py ADDED
File without changes
src/pipelines/context.py ADDED
@@ -0,0 +1,76 @@
1
+ # TODO: Adapted from cli
2
+ from typing import Callable, List, Optional
3
+
4
+ import numpy as np
5
+
6
+
7
+ def ordered_halving(val):
8
+ bin_str = f"{val:064b}"
9
+ bin_flip = bin_str[::-1]
10
+ as_int = int(bin_flip, 2)
11
+
12
+ return as_int / (1 << 64)
13
+
14
+
15
+ def uniform(
16
+ step: int = ...,
17
+ num_steps: Optional[int] = None,
18
+ num_frames: int = ...,
19
+ context_size: Optional[int] = None,
20
+ context_stride: int = 3,
21
+ context_overlap: int = 4,
22
+ closed_loop: bool = True,
23
+ ):
24
+ if num_frames <= context_size:
25
+ yield list(range(num_frames))
26
+ return
27
+
28
+ context_stride = min(
29
+ context_stride, int(np.ceil(np.log2(num_frames / context_size))) + 1
30
+ )
31
+
32
+ for context_step in 1 << np.arange(context_stride):
33
+ pad = int(round(num_frames * ordered_halving(step)))
34
+ for j in range(
35
+ int(ordered_halving(step) * context_step) + pad,
36
+ num_frames + pad + (0 if closed_loop else -context_overlap),
37
+ (context_size * context_step - context_overlap),
38
+ ):
39
+ yield [
40
+ e % num_frames
41
+ for e in range(j, j + context_size * context_step, context_step)
42
+ ]
43
+
44
+
45
+ def get_context_scheduler(name: str) -> Callable:
46
+ if name == "uniform":
47
+ return uniform
48
+ else:
49
+ raise ValueError(f"Unknown context_overlap policy {name}")
50
+
51
+
52
+ def get_total_steps(
53
+ scheduler,
54
+ timesteps: List[int],
55
+ num_steps: Optional[int] = None,
56
+ num_frames: int = ...,
57
+ context_size: Optional[int] = None,
58
+ context_stride: int = 3,
59
+ context_overlap: int = 4,
60
+ closed_loop: bool = True,
61
+ ):
62
+ return sum(
63
+ len(
64
+ list(
65
+ scheduler(
66
+ i,
67
+ num_steps,
68
+ num_frames,
69
+ context_size,
70
+ context_stride,
71
+ context_overlap,
72
+ )
73
+ )
74
+ )
75
+ for i in range(len(timesteps))
76
+ )
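
A hedged sketch of what the "uniform" scheduler yields. The 48-frame clip length is hypothetical; the window size, stride, and overlap match the values hard-coded later in `pipeline_lmks2vid_long.py` (24 / 1 / 4).

from src.pipelines.context import get_context_scheduler

scheduler = get_context_scheduler("uniform")
# step=0, num_steps unused by this scheduler, 48 latent frames, 24-frame windows, stride 1, overlap 4
windows = list(scheduler(0, None, 48, 24, 1, 4))
for w in windows:
    print(w[0], "->", w[-1], f"({len(w)} frames)")
# With closed_loop=True (the default) this gives three overlapping windows,
# the last one wrapping back to the start of the clip:
#   frames 0-23, frames 20-43, and frames 40-47 followed by 0-15.

Because the windows overlap (and wrap), the pipeline later averages the per-window noise predictions with a counter tensor rather than stitching them directly.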
src/pipelines/pipeline_lmks2vid_long.py ADDED
@@ -0,0 +1,622 @@
1
+ import inspect
2
+ import math
3
+ from dataclasses import dataclass
4
+ from typing import Callable, List, Optional, Union
5
+
6
+ import numpy as np
7
+ import torch
8
+ from diffusers import DiffusionPipeline
9
+ from diffusers.image_processor import VaeImageProcessor
10
+ from diffusers.schedulers import (
11
+ DDIMScheduler,
12
+ DPMSolverMultistepScheduler,
13
+ EulerAncestralDiscreteScheduler,
14
+ EulerDiscreteScheduler,
15
+ LMSDiscreteScheduler,
16
+ PNDMScheduler,
17
+ )
18
+ from diffusers.utils import BaseOutput, deprecate, is_accelerate_available, logging
19
+ from diffusers.utils.torch_utils import randn_tensor
20
+ from einops import rearrange
21
+ from tqdm import tqdm
22
+ from transformers import CLIPImageProcessor
23
+
24
+ from src.models.mutual_self_attention import ReferenceAttentionControl
25
+ from src.pipelines.context import get_context_scheduler
26
+ from src.pipelines.utils import get_tensor_interpolation_method
27
+
28
+
29
+ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
30
+ """
31
+ Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
32
+ Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
33
+ """
34
+ std_text = noise_pred_text.std(
35
+ dim=list(range(1, noise_pred_text.ndim)), keepdim=True
36
+ )
37
+ std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
38
+ # rescale the results from guidance (fixes overexposure)
39
+ noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
40
+ # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
41
+ noise_cfg = (
42
+ guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
43
+ )
44
+ print(f"{(std_text / std_cfg) = }")
45
+ return noise_cfg
46
+
47
+
48
+ @dataclass
49
+ class Pose2VideoPipelineOutput(BaseOutput):
50
+ videos: Union[torch.Tensor, np.ndarray]
51
+
52
+
53
+ class Pose2VideoPipeline(DiffusionPipeline):
54
+ _optional_components = []
55
+
56
+ def __init__(
57
+ self,
58
+ vae,
59
+ image_encoder,
60
+ reference_unet,
61
+ denoising_unet,
62
+ pose_guider1,
63
+ pose_guider2,
64
+ scheduler: Union[
65
+ DDIMScheduler,
66
+ PNDMScheduler,
67
+ LMSDiscreteScheduler,
68
+ EulerDiscreteScheduler,
69
+ EulerAncestralDiscreteScheduler,
70
+ DPMSolverMultistepScheduler,
71
+ ],
72
+ audio_guider=None,
73
+ image_proj_model=None,
74
+ tokenizer=None,
75
+ text_encoder=None,
76
+ ):
77
+ super().__init__()
78
+
79
+ self.register_modules(
80
+ vae=vae,
81
+ image_encoder=image_encoder,
82
+ reference_unet=reference_unet,
83
+ denoising_unet=denoising_unet,
84
+ pose_guider1=pose_guider1,
85
+ pose_guider2=pose_guider2,
86
+ scheduler=scheduler,
87
+ audio_guider=audio_guider,
88
+ image_proj_model=image_proj_model,
89
+ tokenizer=tokenizer,
90
+ text_encoder=text_encoder,
91
+ )
92
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
93
+ self.clip_image_processor = CLIPImageProcessor()
94
+ self.ref_image_processor = VaeImageProcessor(
95
+ vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True
96
+ )
97
+ self.cond_image_processor = VaeImageProcessor(
98
+ vae_scale_factor=self.vae_scale_factor,
99
+ do_convert_rgb=True,
100
+ do_normalize=False,
101
+ )
102
+
103
+ def enable_vae_slicing(self):
104
+ self.vae.enable_slicing()
105
+
106
+ def disable_vae_slicing(self):
107
+ self.vae.disable_slicing()
108
+
109
+ def enable_sequential_cpu_offload(self, gpu_id=0):
110
+ if is_accelerate_available():
111
+ from accelerate import cpu_offload
112
+ else:
113
+ raise ImportError("Please install accelerate via `pip install accelerate`")
114
+
115
+ device = torch.device(f"cuda:{gpu_id}")
116
+
117
+ for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]:
118
+ if cpu_offloaded_model is not None:
119
+ cpu_offload(cpu_offloaded_model, device)
120
+
121
+ @property
122
+ def _execution_device(self):
123
+ if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"):
124
+ return self.device
125
+ for module in self.unet.modules():
126
+ if (
127
+ hasattr(module, "_hf_hook")
128
+ and hasattr(module._hf_hook, "execution_device")
129
+ and module._hf_hook.execution_device is not None
130
+ ):
131
+ return torch.device(module._hf_hook.execution_device)
132
+ return self.device
133
+
134
+ def decode_latents(self, latents):
135
+ video_length = latents.shape[2]
136
+ latents = 1 / 0.18215 * latents
137
+ latents = rearrange(latents, "b c f h w -> (b f) c h w")
138
+ # video = self.vae.decode(latents).sample
139
+ video = []
140
+ for frame_idx in tqdm(range(latents.shape[0])):
141
+ video.append(self.vae.decode(latents[frame_idx : frame_idx + 1]).sample)
142
+ video = torch.cat(video)
143
+ video = rearrange(video, "(b f) c h w -> b c f h w", f=video_length)
144
+ video = (video / 2 + 0.5).clamp(0, 1)
145
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
146
+ video = video.cpu().float().numpy()
147
+ return video
148
+
149
+ def prepare_extra_step_kwargs(self, generator, eta):
150
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
151
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
152
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
153
+ # and should be between [0, 1]
154
+
155
+ accepts_eta = "eta" in set(
156
+ inspect.signature(self.scheduler.step).parameters.keys()
157
+ )
158
+ extra_step_kwargs = {}
159
+ if accepts_eta:
160
+ extra_step_kwargs["eta"] = eta
161
+
162
+ # check if the scheduler accepts generator
163
+ accepts_generator = "generator" in set(
164
+ inspect.signature(self.scheduler.step).parameters.keys()
165
+ )
166
+ if accepts_generator:
167
+ extra_step_kwargs["generator"] = generator
168
+ return extra_step_kwargs
169
+
170
+ def prepare_latents(
171
+ self,
172
+ batch_size,
173
+ num_channels_latents,
174
+ width,
175
+ height,
176
+ video_length,
177
+ dtype,
178
+ device,
179
+ generator,
180
+ latents=None,
181
+ ):
182
+ shape = (
183
+ batch_size,
184
+ num_channels_latents,
185
+ video_length,
186
+ height // self.vae_scale_factor,
187
+ width // self.vae_scale_factor,
188
+ )
189
+ if isinstance(generator, list) and len(generator) != batch_size:
190
+ raise ValueError(
191
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
192
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
193
+ )
194
+
195
+ if latents is None:
196
+ latents = randn_tensor(
197
+ shape, generator=generator, device=device, dtype=dtype
198
+ )
199
+ else:
200
+ latents = latents.to(device)
201
+
202
+ # scale the initial noise by the standard deviation required by the scheduler
203
+ latents = latents * self.scheduler.init_noise_sigma
204
+ return latents
205
+
206
+ def _encode_prompt(
207
+ self,
208
+ prompt,
209
+ device,
210
+ num_videos_per_prompt,
211
+ do_classifier_free_guidance,
212
+ negative_prompt,
213
+ ):
214
+ batch_size = len(prompt) if isinstance(prompt, list) else 1
215
+
216
+ text_inputs = self.tokenizer(
217
+ prompt,
218
+ padding="max_length",
219
+ max_length=self.tokenizer.model_max_length,
220
+ truncation=True,
221
+ return_tensors="pt",
222
+ )
223
+ text_input_ids = text_inputs.input_ids
224
+ untruncated_ids = self.tokenizer(
225
+ prompt, padding="longest", return_tensors="pt"
226
+ ).input_ids
227
+
228
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
229
+ text_input_ids, untruncated_ids
230
+ ):
231
+ removed_text = self.tokenizer.batch_decode(
232
+ untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
233
+ )
234
+
235
+ if (
236
+ hasattr(self.text_encoder.config, "use_attention_mask")
237
+ and self.text_encoder.config.use_attention_mask
238
+ ):
239
+ attention_mask = text_inputs.attention_mask.to(device)
240
+ else:
241
+ attention_mask = None
242
+
243
+ text_embeddings = self.text_encoder(
244
+ text_input_ids.to(device),
245
+ attention_mask=attention_mask,
246
+ )
247
+ text_embeddings = text_embeddings[0]
248
+
249
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
250
+ bs_embed, seq_len, _ = text_embeddings.shape
251
+ text_embeddings = text_embeddings.repeat(1, num_videos_per_prompt, 1)
252
+ text_embeddings = text_embeddings.view(
253
+ bs_embed * num_videos_per_prompt, seq_len, -1
254
+ )
255
+
256
+ # get unconditional embeddings for classifier free guidance
257
+ if do_classifier_free_guidance:
258
+ uncond_tokens: List[str]
259
+ if negative_prompt is None:
260
+ uncond_tokens = [""] * batch_size
261
+ elif type(prompt) is not type(negative_prompt):
262
+ raise TypeError(
263
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
264
+ f" {type(prompt)}."
265
+ )
266
+ elif isinstance(negative_prompt, str):
267
+ uncond_tokens = [negative_prompt]
268
+ elif batch_size != len(negative_prompt):
269
+ raise ValueError(
270
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
271
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
272
+ " the batch size of `prompt`."
273
+ )
274
+ else:
275
+ uncond_tokens = negative_prompt
276
+
277
+ max_length = text_input_ids.shape[-1]
278
+ uncond_input = self.tokenizer(
279
+ uncond_tokens,
280
+ padding="max_length",
281
+ max_length=max_length,
282
+ truncation=True,
283
+ return_tensors="pt",
284
+ )
285
+
286
+ if (
287
+ hasattr(self.text_encoder.config, "use_attention_mask")
288
+ and self.text_encoder.config.use_attention_mask
289
+ ):
290
+ attention_mask = uncond_input.attention_mask.to(device)
291
+ else:
292
+ attention_mask = None
293
+
294
+ uncond_embeddings = self.text_encoder(
295
+ uncond_input.input_ids.to(device),
296
+ attention_mask=attention_mask,
297
+ )
298
+ uncond_embeddings = uncond_embeddings[0]
299
+
300
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
301
+ seq_len = uncond_embeddings.shape[1]
302
+ uncond_embeddings = uncond_embeddings.repeat(1, num_videos_per_prompt, 1)
303
+ uncond_embeddings = uncond_embeddings.view(
304
+ batch_size * num_videos_per_prompt, seq_len, -1
305
+ )
306
+
307
+ # For classifier free guidance, we need to do two forward passes.
308
+ # Here we concatenate the unconditional and text embeddings into a single batch
309
+ # to avoid doing two forward passes
310
+ text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
311
+
312
+ return text_embeddings
313
+
314
+ def interpolate_latents(
315
+ self, latents: torch.Tensor, interpolation_factor: int, device
316
+ ):
317
+ if interpolation_factor < 2:
318
+ return latents
319
+
320
+ new_latents = torch.zeros(
321
+ (
322
+ latents.shape[0],
323
+ latents.shape[1],
324
+ ((latents.shape[2] - 1) * interpolation_factor) + 1,
325
+ latents.shape[3],
326
+ latents.shape[4],
327
+ ),
328
+ device=latents.device,
329
+ dtype=latents.dtype,
330
+ )
331
+
332
+ org_video_length = latents.shape[2]
333
+ rate = [i / interpolation_factor for i in range(interpolation_factor)][1:]
334
+
335
+ new_index = 0
336
+
337
+ v0 = None
338
+ v1 = None
339
+
340
+ for i0, i1 in zip(range(org_video_length), range(org_video_length)[1:]):
341
+ v0 = latents[:, :, i0, :, :]
342
+ v1 = latents[:, :, i1, :, :]
343
+
344
+ new_latents[:, :, new_index, :, :] = v0
345
+ new_index += 1
346
+
347
+ for f in rate:
348
+ v = get_tensor_interpolation_method()(
349
+ v0.to(device=device), v1.to(device=device), f
350
+ )
351
+ new_latents[:, :, new_index, :, :] = v.to(latents.device)
352
+ new_index += 1
353
+
354
+ new_latents[:, :, new_index, :, :] = v1
355
+ new_index += 1
356
+
357
+ return new_latents
358
+
359
+ @torch.no_grad()
360
+ def __call__(
361
+ self,
362
+ ref_image,
363
+ pose_up_images,
364
+ pose_down_images,
365
+ width,
366
+ height,
367
+ video_length,
368
+ num_inference_steps,
369
+ guidance_scale,
370
+ audio_features=None,
371
+ num_images_per_prompt=1,
372
+ eta: float = 0.0,
373
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
374
+ output_type: Optional[str] = "tensor",
375
+ return_dict: bool = True,
376
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
377
+ callback_steps: Optional[int] = 1,
378
+ guidance_rescale=0., # 0.7
379
+ **kwargs,
380
+ ):
381
+ # Default height and width to unet
382
+ height = height or self.unet.config.sample_size * self.vae_scale_factor
383
+ width = width or self.unet.config.sample_size * self.vae_scale_factor
384
+
385
+ device = self._execution_device
386
+
387
+ do_classifier_free_guidance = guidance_scale > 1.0
388
+
389
+ # Prepare timesteps
390
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
391
+ timesteps = self.scheduler.timesteps
392
+
393
+ batch_size = 1
394
+
395
+ # Prepare clip image embeds
396
+ # NOTE: whether the resize to (224, 224) is needed here still needs to be verified
397
+ clip_image = self.clip_image_processor.preprocess(
398
+ ref_image.resize((224, 224)), return_tensors="pt"
399
+ ).pixel_values
400
+ # If image_proj_model is not None, means enable ip-adapter
401
+ if self.image_proj_model is not None:
402
+ clip_image_embeds = self.image_encoder(
403
+ clip_image.to(device, dtype=self.image_encoder.dtype),
404
+ output_hidden_states=True,
405
+ ).hidden_states[-2]
406
+ image_prompt_embeds = self.image_proj_model(clip_image_embeds)
407
+ uncond_image_prompt_embeds = self.image_proj_model(
408
+ torch.zeros_like(clip_image_embeds)
409
+ )
410
+ text_prompt_embeds = self._encode_prompt(
411
+ "best quality, high quality",
412
+ device,
413
+ num_images_per_prompt,
414
+ do_classifier_free_guidance,
415
+ negative_prompt="monochrome, lowres, bad anatomy, worst quality, low quality",
416
+ )
417
+ # Concat image and text embeddings
418
+ if do_classifier_free_guidance:
419
+ (
420
+ uncond_text_prompt_embeds,
421
+ text_prompt_embeds,
422
+ ) = text_prompt_embeds.chunk(2)
423
+ uncond_encoder_hidden_states = torch.cat(
424
+ [uncond_text_prompt_embeds, uncond_image_prompt_embeds], dim=1
425
+ )
426
+ encoder_hidden_states = torch.cat(
427
+ [text_prompt_embeds, image_prompt_embeds], dim=1
428
+ )
429
+
430
+ else:
431
+ clip_image_embeds = self.image_encoder(
432
+ clip_image.to(device, dtype=self.image_encoder.dtype)
433
+ ).image_embeds
434
+ encoder_hidden_states = clip_image_embeds.unsqueeze(1)
435
+ uncond_encoder_hidden_states = torch.zeros_like(encoder_hidden_states)
436
+
437
+ if do_classifier_free_guidance:
438
+ encoder_hidden_states = torch.cat(
439
+ [uncond_encoder_hidden_states, encoder_hidden_states], dim=0
440
+ )
441
+
442
+ num_channels_latents = self.denoising_unet.in_channels
443
+
444
+ latents = self.prepare_latents(
445
+ batch_size * num_images_per_prompt,
446
+ num_channels_latents,
447
+ width,
448
+ height,
449
+ video_length,
450
+ clip_image_embeds.dtype,
451
+ device,
452
+ generator,
453
+ )
454
+
455
+ # Prepare extra step kwargs.
456
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
457
+
458
+ # Prepare ref image latents
459
+ ref_image_tensor = self.ref_image_processor.preprocess(
460
+ ref_image, height=height, width=width
461
+ ) # (bs, c, width, height)
462
+ ref_image_tensor = ref_image_tensor.to(
463
+ dtype=self.vae.dtype, device=self.vae.device
464
+ )
465
+ ref_image_latents = self.vae.encode(ref_image_tensor).latent_dist.mean
466
+ ref_image_latents = ref_image_latents * 0.18215 # (b, 4, h, w)
467
+ ref_image_latents = ref_image_latents.unsqueeze(2)
468
+
469
+ # Prepare a list of pose condition images
470
+ pose_up_cond_tensor_list, pose_down_cond_tensor_list = [], []
471
+ for i, pose_up_image in enumerate(pose_up_images):
472
+ pose_up_cond_tensor = self.cond_image_processor.preprocess(
473
+ pose_up_image, height=height, width=width
474
+ )
475
+ pose_up_cond_tensor = pose_up_cond_tensor.unsqueeze(2) # (bs, c, 1, h, w)
476
+ pose_up_cond_tensor_list.append(pose_up_cond_tensor)
477
+ pose_down_cond_tensor = self.cond_image_processor.preprocess(
478
+ pose_down_images[i], height=height, width=width
479
+ )
480
+ pose_down_cond_tensor = pose_down_cond_tensor.unsqueeze(2) # (bs, c, 1, h, w)
481
+ pose_down_cond_tensor_list.append(pose_down_cond_tensor)
482
+ pose_up_cond_tensor = torch.cat(pose_up_cond_tensor_list, dim=2) # (bs, c, t, h, w)
483
+ pose_up_cond_tensor = pose_up_cond_tensor.to(
484
+ device=device, dtype=self.pose_guider1.dtype
485
+ )
486
+ pose_down_cond_tensor = torch.cat(pose_down_cond_tensor_list, dim=2) # (bs, c, t, h, w)
487
+ pose_up_fea = self.pose_guider1(pose_up_cond_tensor)
488
+ pose_down_cond_tensor = pose_down_cond_tensor.to(device=device, dtype=self.pose_guider2.dtype)
489
+ pose_down_fea = self.pose_guider2(pose_down_cond_tensor)
490
+ pose_fea = pose_up_fea + pose_down_fea
491
+
492
+ context_schedule = "uniform"
493
+ context_frames = 24
494
+ context_stride = 1
495
+ context_overlap = 4 # 4
496
+ context_batch_size = 1
497
+ context_scheduler = get_context_scheduler(context_schedule)
498
+
499
+ # denoising loop
500
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
501
+ middle_results = []
502
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
503
+ self_attention_additional_feats = {}
504
+ for i, t in enumerate(timesteps):
505
+ noise_pred = torch.zeros(
506
+ (
507
+ latents.shape[0] * (2 if do_classifier_free_guidance else 1),
508
+ *latents.shape[1:],
509
+ ),
510
+ device=latents.device,
511
+ dtype=latents.dtype,
512
+ )
513
+ counter = torch.zeros(
514
+ (1, 1, latents.shape[2], 1, 1),
515
+ device=latents.device,
516
+ dtype=latents.dtype,
517
+ )
518
+
519
+ # 1. Forward reference image
520
+ if i == 0:
521
+ self.reference_unet(
522
+ # ref_image_latents.repeat(
523
+ # (2 if do_classifier_free_guidance else 1), 1, 1, 1
524
+ # ),
525
+ torch.cat([torch.zeros_like(ref_image_latents), ref_image_latents]) \
526
+ if do_classifier_free_guidance else \
527
+ ref_image_latents,
528
+ torch.zeros_like(t),
529
+ # t,
530
+ encoder_hidden_states=encoder_hidden_states,
531
+ self_attention_additional_feats=self_attention_additional_feats,
532
+ return_dict=False,
533
+ )
534
+ context_queue = list(
535
+ context_scheduler(
536
+ 0,
537
+ num_inference_steps,
538
+ latents.shape[2],
539
+ context_frames,
540
+ context_stride,
541
+ context_overlap,
542
+ )
543
+ )
544
+
545
+ num_context_batches = math.ceil(len(context_queue) / context_batch_size)
546
+ global_context = []
547
+ for i in range(num_context_batches):
548
+ global_context.append(
549
+ context_queue[
550
+ i * context_batch_size : (i + 1) * context_batch_size
551
+ ]
552
+ )
553
+
554
+ for context in global_context:
555
+ # 3.1 expand the latents if we are doing classifier free guidance
556
+ latent_model_input = (
557
+ torch.cat([latents[:, :, c] for c in context])
558
+ .to(device)
559
+ .repeat(2 if do_classifier_free_guidance else 1, 1, 1, 1, 1)
560
+ )
561
+ latent_model_input = self.scheduler.scale_model_input(
562
+ latent_model_input, t
563
+ )
564
+ b, c, f, h, w = latent_model_input.shape
565
+ latent_pose_input = torch.cat(
566
+ [pose_fea[:, :, c] for c in context]
567
+ ).repeat(2 if do_classifier_free_guidance else 1, 1, 1, 1, 1)
568
+
569
+ pred = self.denoising_unet(
570
+ latent_model_input,
571
+ t,
572
+ encoder_hidden_states=encoder_hidden_states[:b],
573
+ pose_cond_fea=latent_pose_input,
574
+ self_attention_additional_feats=self_attention_additional_feats,
575
+ return_dict=False,
576
+ )[0]
577
+
578
+ for j, c in enumerate(context):
579
+ noise_pred[:, :, c] = noise_pred[:, :, c] + pred
580
+ counter[:, :, c] = counter[:, :, c] + 1
581
+
582
+ # perform guidance
583
+ if do_classifier_free_guidance:
584
+ noise_pred_uncond, noise_pred_text = (noise_pred / counter).chunk(2)
585
+ noise_pred = noise_pred_uncond + guidance_scale * (
586
+ noise_pred_text - noise_pred_uncond
587
+ )
588
+
589
+ if do_classifier_free_guidance and guidance_rescale > 0.0:
590
+ # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
591
+ noise_pred = rescale_noise_cfg(
592
+ noise_pred,
593
+ noise_pred_text,
594
+ guidance_rescale=guidance_rescale,
595
+ )
596
+
597
+ latents = self.scheduler.step(
598
+ noise_pred, t, latents, **extra_step_kwargs
599
+ ).prev_sample
600
+
601
+ if i == len(timesteps) - 1 or (
602
+ (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
603
+ ):
604
+ progress_bar.update()
605
+ if callback is not None and i % callback_steps == 0:
606
+ step_idx = i // getattr(self.scheduler, "order", 1)
607
+ callback(step_idx, t, latents)
608
+
609
+
610
+ interpolation_factor = 1
611
+ latents = self.interpolate_latents(latents, interpolation_factor, device)
612
+ # Post-processing
613
+ images = self.decode_latents(latents) # (b, c, f, h, w)
614
+
615
+ # Convert to tensor
616
+ if output_type == "tensor":
617
+ images = torch.from_numpy(images)
618
+
619
+ if not return_dict:
620
+ return images
621
+
622
+ return Pose2VideoPipelineOutput(videos=images)
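
A hedged end-to-end sketch of calling the pipeline defined above. The sub-modules (`vae`, `image_encoder`, the two UNets, the two pose guiders, `scheduler`) and the two pose image lists are assumed to be built and loaded elsewhere; only the call into the pipeline is shown.

import torch
from PIL import Image
from src.pipelines.pipeline_lmks2vid_long import Pose2VideoPipeline

pipe = Pose2VideoPipeline(
    vae=vae,                        # assumed: an AutoencoderKL
    image_encoder=image_encoder,    # assumed: a CLIP vision encoder
    reference_unet=reference_unet,
    denoising_unet=denoising_unet,
    pose_guider1=pose_guider1,
    pose_guider2=pose_guider2,
    scheduler=scheduler,            # e.g. a DDIMScheduler
).to("cuda", torch.float16)

out = pipe(
    ref_image=Image.open("test_imgs/ref2.jpg").convert("RGB"),
    pose_up_images=pose_up_images,      # one PIL image per frame
    pose_down_images=pose_down_images,  # one PIL image per frame
    width=512,
    height=512,
    video_length=len(pose_up_images),
    num_inference_steps=25,
    guidance_scale=3.5,
)
video = out.videos  # (batch, channels, frames, height, width) tensor in [0, 1]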
src/pipelines/pipeline_pose2img.py ADDED
@@ -0,0 +1,360 @@
1
+ import inspect
2
+ from dataclasses import dataclass
3
+ from typing import Callable, List, Optional, Union
4
+
5
+ import numpy as np
6
+ import torch
7
+ from diffusers import DiffusionPipeline
8
+ from diffusers.image_processor import VaeImageProcessor
9
+ from diffusers.schedulers import (
10
+ DDIMScheduler,
11
+ DPMSolverMultistepScheduler,
12
+ EulerAncestralDiscreteScheduler,
13
+ EulerDiscreteScheduler,
14
+ LMSDiscreteScheduler,
15
+ PNDMScheduler,
16
+ )
17
+ from diffusers.utils import BaseOutput, is_accelerate_available
18
+ from diffusers.utils.torch_utils import randn_tensor
19
+ from einops import rearrange
20
+ from tqdm import tqdm
21
+ from transformers import CLIPImageProcessor
22
+
23
+ from src.models.mutual_self_attention import ReferenceAttentionControl
24
+
25
+
26
+ @dataclass
27
+ class Pose2ImagePipelineOutput(BaseOutput):
28
+ images: Union[torch.Tensor, np.ndarray]
29
+
30
+
31
+ class Pose2ImagePipeline(DiffusionPipeline):
32
+ _optional_components = []
33
+
34
+ def __init__(
35
+ self,
36
+ vae,
37
+ image_encoder,
38
+ reference_unet,
39
+ denoising_unet,
40
+ pose_guider,
41
+ scheduler: Union[
42
+ DDIMScheduler,
43
+ PNDMScheduler,
44
+ LMSDiscreteScheduler,
45
+ EulerDiscreteScheduler,
46
+ EulerAncestralDiscreteScheduler,
47
+ DPMSolverMultistepScheduler,
48
+ ],
49
+ ):
50
+ super().__init__()
51
+
52
+ self.register_modules(
53
+ vae=vae,
54
+ image_encoder=image_encoder,
55
+ reference_unet=reference_unet,
56
+ denoising_unet=denoising_unet,
57
+ pose_guider=pose_guider,
58
+ scheduler=scheduler,
59
+ )
60
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
61
+ self.clip_image_processor = CLIPImageProcessor()
62
+ self.ref_image_processor = VaeImageProcessor(
63
+ vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True
64
+ )
65
+ self.cond_image_processor = VaeImageProcessor(
66
+ vae_scale_factor=self.vae_scale_factor,
67
+ do_convert_rgb=True,
68
+ do_normalize=False,
69
+ )
70
+
71
+ def enable_vae_slicing(self):
72
+ self.vae.enable_slicing()
73
+
74
+ def disable_vae_slicing(self):
75
+ self.vae.disable_slicing()
76
+
77
+ def enable_sequential_cpu_offload(self, gpu_id=0):
78
+ if is_accelerate_available():
79
+ from accelerate import cpu_offload
80
+ else:
81
+ raise ImportError("Please install accelerate via `pip install accelerate`")
82
+
83
+ device = torch.device(f"cuda:{gpu_id}")
84
+
85
+ for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]:
86
+ if cpu_offloaded_model is not None:
87
+ cpu_offload(cpu_offloaded_model, device)
88
+
89
+ @property
90
+ def _execution_device(self):
91
+ if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"):
92
+ return self.device
93
+ for module in self.unet.modules():
94
+ if (
95
+ hasattr(module, "_hf_hook")
96
+ and hasattr(module._hf_hook, "execution_device")
97
+ and module._hf_hook.execution_device is not None
98
+ ):
99
+ return torch.device(module._hf_hook.execution_device)
100
+ return self.device
101
+
102
+ def decode_latents(self, latents):
103
+ video_length = latents.shape[2]
104
+ latents = 1 / 0.18215 * latents
105
+ latents = rearrange(latents, "b c f h w -> (b f) c h w")
106
+ # video = self.vae.decode(latents).sample
107
+ video = []
108
+ for frame_idx in tqdm(range(latents.shape[0])):
109
+ video.append(self.vae.decode(latents[frame_idx : frame_idx + 1]).sample)
110
+ video = torch.cat(video)
111
+ video = rearrange(video, "(b f) c h w -> b c f h w", f=video_length)
112
+ video = (video / 2 + 0.5).clamp(0, 1)
113
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
114
+ video = video.cpu().float().numpy()
115
+ return video
116
+
117
+ def prepare_extra_step_kwargs(self, generator, eta):
118
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
119
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
120
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
121
+ # and should be between [0, 1]
122
+
123
+ accepts_eta = "eta" in set(
124
+ inspect.signature(self.scheduler.step).parameters.keys()
125
+ )
126
+ extra_step_kwargs = {}
127
+ if accepts_eta:
128
+ extra_step_kwargs["eta"] = eta
129
+
130
+ # check if the scheduler accepts generator
131
+ accepts_generator = "generator" in set(
132
+ inspect.signature(self.scheduler.step).parameters.keys()
133
+ )
134
+ if accepts_generator:
135
+ extra_step_kwargs["generator"] = generator
136
+ return extra_step_kwargs
137
+
138
+ def prepare_latents(
139
+ self,
140
+ batch_size,
141
+ num_channels_latents,
142
+ width,
143
+ height,
144
+ dtype,
145
+ device,
146
+ generator,
147
+ latents=None,
148
+ ):
149
+ shape = (
150
+ batch_size,
151
+ num_channels_latents,
152
+ height // self.vae_scale_factor,
153
+ width // self.vae_scale_factor,
154
+ )
155
+ if isinstance(generator, list) and len(generator) != batch_size:
156
+ raise ValueError(
157
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
158
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
159
+ )
160
+
161
+ if latents is None:
162
+ latents = randn_tensor(
163
+ shape, generator=generator, device=device, dtype=dtype
164
+ )
165
+ else:
166
+ latents = latents.to(device)
167
+
168
+ # scale the initial noise by the standard deviation required by the scheduler
169
+ latents = latents * self.scheduler.init_noise_sigma
170
+ return latents
171
+
172
+ def prepare_condition(
173
+ self,
174
+ cond_image,
175
+ width,
176
+ height,
177
+ device,
178
+ dtype,
179
+ do_classifier_free_guidance=False,
180
+ ):
181
+ image = self.cond_image_processor.preprocess(
182
+ cond_image, height=height, width=width
183
+ ).to(dtype=torch.float32)
184
+
185
+ image = image.to(device=device, dtype=dtype)
186
+
187
+ if do_classifier_free_guidance:
188
+ image = torch.cat([image] * 2)
189
+
190
+ return image
191
+
192
+ @torch.no_grad()
193
+ def __call__(
194
+ self,
195
+ ref_image,
196
+ pose_image,
197
+ width,
198
+ height,
199
+ num_inference_steps,
200
+ guidance_scale,
201
+ num_images_per_prompt=1,
202
+ eta: float = 0.0,
203
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
204
+ output_type: Optional[str] = "tensor",
205
+ return_dict: bool = True,
206
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
207
+ callback_steps: Optional[int] = 1,
208
+ **kwargs,
209
+ ):
210
+ # Default height and width to unet
211
+ height = height or self.unet.config.sample_size * self.vae_scale_factor
212
+ width = width or self.unet.config.sample_size * self.vae_scale_factor
213
+
214
+ device = self._execution_device
215
+
216
+ do_classifier_free_guidance = guidance_scale > 1.0
217
+
218
+ # Prepare timesteps
219
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
220
+ timesteps = self.scheduler.timesteps
221
+
222
+ batch_size = 1
223
+
224
+ # Prepare clip image embeds
225
+ clip_image = self.clip_image_processor.preprocess(
226
+ ref_image.resize((224, 224)), return_tensors="pt"
227
+ ).pixel_values
228
+ clip_image_embeds = self.image_encoder(
229
+ clip_image.to(device, dtype=self.image_encoder.dtype)
230
+ ).image_embeds
231
+ image_prompt_embeds = clip_image_embeds.unsqueeze(1)
232
+ uncond_image_prompt_embeds = torch.zeros_like(image_prompt_embeds)
233
+
234
+ if do_classifier_free_guidance:
235
+ image_prompt_embeds = torch.cat(
236
+ [uncond_image_prompt_embeds, image_prompt_embeds], dim=0
237
+ )
238
+
239
+ reference_control_writer = ReferenceAttentionControl(
240
+ self.reference_unet,
241
+ do_classifier_free_guidance=do_classifier_free_guidance,
242
+ mode="write",
243
+ batch_size=batch_size,
244
+ fusion_blocks="full",
245
+ )
246
+ reference_control_reader = ReferenceAttentionControl(
247
+ self.denoising_unet,
248
+ do_classifier_free_guidance=do_classifier_free_guidance,
249
+ mode="read",
250
+ batch_size=batch_size,
251
+ fusion_blocks="full",
252
+ )
253
+
254
+ num_channels_latents = self.denoising_unet.in_channels
255
+ latents = self.prepare_latents(
256
+ batch_size * num_images_per_prompt,
257
+ num_channels_latents,
258
+ width,
259
+ height,
260
+ clip_image_embeds.dtype,
261
+ device,
262
+ generator,
263
+ )
264
+ latents = latents.unsqueeze(2) # (bs, c, 1, h', w')
265
+ latents_dtype = latents.dtype
266
+
267
+ # Prepare extra step kwargs.
268
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
269
+
270
+ # Prepare ref image latents
271
+ ref_image_tensor = self.ref_image_processor.preprocess(
272
+ ref_image, height=height, width=width
273
+ ) # (bs, c, height, width)
274
+ ref_image_tensor = ref_image_tensor.to(
275
+ dtype=self.vae.dtype, device=self.vae.device
276
+ )
277
+ ref_image_latents = self.vae.encode(ref_image_tensor).latent_dist.mean
278
+ ref_image_latents = ref_image_latents * 0.18215 # (b, 4, h, w)
279
+
280
+ # Prepare pose condition image
281
+ pose_cond_tensor = self.cond_image_processor.preprocess(
282
+ pose_image, height=height, width=width
283
+ )
284
+ pose_cond_tensor = pose_cond_tensor.unsqueeze(2) # (bs, c, 1, h, w)
285
+ pose_cond_tensor = pose_cond_tensor.to(
286
+ device=device, dtype=self.pose_guider.dtype
287
+ )
288
+ pose_fea = self.pose_guider(pose_cond_tensor)
289
+ pose_fea = (
290
+ torch.cat([pose_fea] * 2) if do_classifier_free_guidance else pose_fea
291
+ )
292
+
293
+ # denoising loop
294
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
295
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
296
+ for i, t in enumerate(timesteps):
297
+ # 1. Forward reference image
298
+ if i == 0:
299
+ self.reference_unet(
300
+ ref_image_latents.repeat(
301
+ (2 if do_classifier_free_guidance else 1), 1, 1, 1
302
+ ),
303
+ torch.zeros_like(t),
304
+ encoder_hidden_states=image_prompt_embeds,
305
+ return_dict=False,
306
+ )
307
+
308
+ # 2. Update reference unet features into the denoising unet
309
+ reference_control_reader.update(reference_control_writer)
310
+
311
+ # 3.1 expand the latents if we are doing classifier free guidance
312
+ latent_model_input = (
313
+ torch.cat([latents] * 2) if do_classifier_free_guidance else latents
314
+ )
315
+ latent_model_input = self.scheduler.scale_model_input(
316
+ latent_model_input, t
317
+ )
318
+
319
+ noise_pred = self.denoising_unet(
320
+ latent_model_input,
321
+ t,
322
+ encoder_hidden_states=image_prompt_embeds,
323
+ pose_cond_fea=pose_fea,
324
+ return_dict=False,
325
+ )[0]
326
+
327
+ # perform guidance
328
+ if do_classifier_free_guidance:
329
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
330
+ noise_pred = noise_pred_uncond + guidance_scale * (
331
+ noise_pred_text - noise_pred_uncond
332
+ )
333
+
334
+ # compute the previous noisy sample x_t -> x_t-1
335
+ latents = self.scheduler.step(
336
+ noise_pred, t, latents, **extra_step_kwargs, return_dict=False
337
+ )[0]
338
+
339
+ # call the callback, if provided
340
+ if i == len(timesteps) - 1 or (
341
+ (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
342
+ ):
343
+ progress_bar.update()
344
+ if callback is not None and i % callback_steps == 0:
345
+ step_idx = i // getattr(self.scheduler, "order", 1)
346
+ callback(step_idx, t, latents)
347
+ reference_control_reader.clear()
348
+ reference_control_writer.clear()
349
+
350
+ # Post-processing
351
+ image = self.decode_latents(latents) # (b, c, 1, h, w)
352
+
353
+ # Convert to tensor
354
+ if output_type == "tensor":
355
+ image = torch.from_numpy(image)
356
+
357
+ if not return_dict:
358
+ return image
359
+
360
+ return Pose2ImagePipelineOutput(images=image)
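
With output_type="tensor", Pose2ImagePipeline returns a Pose2ImagePipelineOutput whose images field is a (b, c, 1, h, w) tensor in [0, 1] (see decode_latents above). The snippet below is a small usage-side sketch for turning that tensor into a PIL image; the pipe call, its argument values, and the output filename are assumptions for illustration only.

import numpy as np
from PIL import Image


def tensor_frame_to_pil(images):
    # images: (b, c, f, h, w) with values in [0, 1], as produced by decode_latents.
    frame = np.asarray(images[0, :, 0])               # first sample, single frame -> (c, h, w)
    frame = (frame * 255.0).round().astype("uint8")
    return Image.fromarray(frame.transpose(1, 2, 0))  # HWC layout for PIL


# result = pipe(ref_image, pose_image, width=512, height=512,
#               num_inference_steps=25, guidance_scale=3.5)   # hypothetical call
# tensor_frame_to_pil(result.images).save("pose2img_result.png")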
src/pipelines/pipeline_pose2vid.py ADDED
@@ -0,0 +1,458 @@
1
+ import inspect
2
+ from dataclasses import dataclass
3
+ from typing import Callable, List, Optional, Union
4
+
5
+ import numpy as np
6
+ import torch
7
+ from diffusers import DiffusionPipeline
8
+ from diffusers.image_processor import VaeImageProcessor
9
+ from diffusers.schedulers import (DDIMScheduler, DPMSolverMultistepScheduler,
10
+ EulerAncestralDiscreteScheduler,
11
+ EulerDiscreteScheduler, LMSDiscreteScheduler,
12
+ PNDMScheduler)
13
+ from diffusers.utils import BaseOutput, is_accelerate_available
14
+ from diffusers.utils.torch_utils import randn_tensor
15
+ from einops import rearrange
16
+ from tqdm import tqdm
17
+ from transformers import CLIPImageProcessor
18
+
19
+ from src.models.mutual_self_attention import ReferenceAttentionControl
20
+
21
+
22
+ @dataclass
23
+ class Pose2VideoPipelineOutput(BaseOutput):
24
+ videos: Union[torch.Tensor, np.ndarray]
25
+
26
+
27
+ class Pose2VideoPipeline(DiffusionPipeline):
28
+ _optional_components = []
29
+
30
+ def __init__(
31
+ self,
32
+ vae,
33
+ image_encoder,
34
+ reference_unet,
35
+ denoising_unet,
36
+ pose_guider,
37
+ scheduler: Union[
38
+ DDIMScheduler,
39
+ PNDMScheduler,
40
+ LMSDiscreteScheduler,
41
+ EulerDiscreteScheduler,
42
+ EulerAncestralDiscreteScheduler,
43
+ DPMSolverMultistepScheduler,
44
+ ],
45
+ image_proj_model=None,
46
+ tokenizer=None,
47
+ text_encoder=None,
48
+ ):
49
+ super().__init__()
50
+
51
+ self.register_modules(
52
+ vae=vae,
53
+ image_encoder=image_encoder,
54
+ reference_unet=reference_unet,
55
+ denoising_unet=denoising_unet,
56
+ pose_guider=pose_guider,
57
+ scheduler=scheduler,
58
+ image_proj_model=image_proj_model,
59
+ tokenizer=tokenizer,
60
+ text_encoder=text_encoder,
61
+ )
62
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
63
+ self.clip_image_processor = CLIPImageProcessor()
64
+ self.ref_image_processor = VaeImageProcessor(
65
+ vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True
66
+ )
67
+ self.cond_image_processor = VaeImageProcessor(
68
+ vae_scale_factor=self.vae_scale_factor,
69
+ do_convert_rgb=True,
70
+ do_normalize=False,
71
+ )
72
+
73
+ def enable_vae_slicing(self):
74
+ self.vae.enable_slicing()
75
+
76
+ def disable_vae_slicing(self):
77
+ self.vae.disable_slicing()
78
+
79
+ def enable_sequential_cpu_offload(self, gpu_id=0):
80
+ if is_accelerate_available():
81
+ from accelerate import cpu_offload
82
+ else:
83
+ raise ImportError("Please install accelerate via `pip install accelerate`")
84
+
85
+ device = torch.device(f"cuda:{gpu_id}")
86
+
87
+ for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]:
88
+ if cpu_offloaded_model is not None:
89
+ cpu_offload(cpu_offloaded_model, device)
90
+
91
+ @property
92
+ def _execution_device(self):
93
+ if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"):
94
+ return self.device
95
+ for module in self.unet.modules():
96
+ if (
97
+ hasattr(module, "_hf_hook")
98
+ and hasattr(module._hf_hook, "execution_device")
99
+ and module._hf_hook.execution_device is not None
100
+ ):
101
+ return torch.device(module._hf_hook.execution_device)
102
+ return self.device
103
+
104
+ def decode_latents(self, latents):
105
+ video_length = latents.shape[2]
106
+ latents = 1 / 0.18215 * latents
107
+ latents = rearrange(latents, "b c f h w -> (b f) c h w")
108
+ # video = self.vae.decode(latents).sample
109
+ video = []
110
+ for frame_idx in tqdm(range(latents.shape[0])):
111
+ video.append(self.vae.decode(latents[frame_idx : frame_idx + 1]).sample)
112
+ video = torch.cat(video)
113
+ video = rearrange(video, "(b f) c h w -> b c f h w", f=video_length)
114
+ video = (video / 2 + 0.5).clamp(0, 1)
115
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16
116
+ video = video.cpu().float().numpy()
117
+ return video
118
+
119
+ def prepare_extra_step_kwargs(self, generator, eta):
120
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
121
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
122
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
123
+ # and should be between [0, 1]
124
+
125
+ accepts_eta = "eta" in set(
126
+ inspect.signature(self.scheduler.step).parameters.keys()
127
+ )
128
+ extra_step_kwargs = {}
129
+ if accepts_eta:
130
+ extra_step_kwargs["eta"] = eta
131
+
132
+ # check if the scheduler accepts generator
133
+ accepts_generator = "generator" in set(
134
+ inspect.signature(self.scheduler.step).parameters.keys()
135
+ )
136
+ if accepts_generator:
137
+ extra_step_kwargs["generator"] = generator
138
+ return extra_step_kwargs
139
+
140
+ def prepare_latents(
141
+ self,
142
+ batch_size,
143
+ num_channels_latents,
144
+ width,
145
+ height,
146
+ video_length,
147
+ dtype,
148
+ device,
149
+ generator,
150
+ latents=None,
151
+ ):
152
+ shape = (
153
+ batch_size,
154
+ num_channels_latents,
155
+ video_length,
156
+ height // self.vae_scale_factor,
157
+ width // self.vae_scale_factor,
158
+ )
159
+ if isinstance(generator, list) and len(generator) != batch_size:
160
+ raise ValueError(
161
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
162
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
163
+ )
164
+
165
+ if latents is None:
166
+ latents = randn_tensor(
167
+ shape, generator=generator, device=device, dtype=dtype
168
+ )
169
+ else:
170
+ latents = latents.to(device)
171
+
172
+ # scale the initial noise by the standard deviation required by the scheduler
173
+ latents = latents * self.scheduler.init_noise_sigma
174
+ return latents
175
+
176
+ def _encode_prompt(
177
+ self,
178
+ prompt,
179
+ device,
180
+ num_videos_per_prompt,
181
+ do_classifier_free_guidance,
182
+ negative_prompt,
183
+ ):
184
+ batch_size = len(prompt) if isinstance(prompt, list) else 1
185
+
186
+ text_inputs = self.tokenizer(
187
+ prompt,
188
+ padding="max_length",
189
+ max_length=self.tokenizer.model_max_length,
190
+ truncation=True,
191
+ return_tensors="pt",
192
+ )
193
+ text_input_ids = text_inputs.input_ids
194
+ untruncated_ids = self.tokenizer(
195
+ prompt, padding="longest", return_tensors="pt"
196
+ ).input_ids
197
+
198
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
199
+ text_input_ids, untruncated_ids
200
+ ):
201
+ removed_text = self.tokenizer.batch_decode(
202
+ untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
203
+ )
204
+
205
+ if (
206
+ hasattr(self.text_encoder.config, "use_attention_mask")
207
+ and self.text_encoder.config.use_attention_mask
208
+ ):
209
+ attention_mask = text_inputs.attention_mask.to(device)
210
+ else:
211
+ attention_mask = None
212
+
213
+ text_embeddings = self.text_encoder(
214
+ text_input_ids.to(device),
215
+ attention_mask=attention_mask,
216
+ )
217
+ text_embeddings = text_embeddings[0]
218
+
219
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
220
+ bs_embed, seq_len, _ = text_embeddings.shape
221
+ text_embeddings = text_embeddings.repeat(1, num_videos_per_prompt, 1)
222
+ text_embeddings = text_embeddings.view(
223
+ bs_embed * num_videos_per_prompt, seq_len, -1
224
+ )
225
+
226
+ # get unconditional embeddings for classifier free guidance
227
+ if do_classifier_free_guidance:
228
+ uncond_tokens: List[str]
229
+ if negative_prompt is None:
230
+ uncond_tokens = [""] * batch_size
231
+ elif type(prompt) is not type(negative_prompt):
232
+ raise TypeError(
233
+ f"`negative_prompt` should be the same type as `prompt`, but got {type(negative_prompt)} !="
234
+ f" {type(prompt)}."
235
+ )
236
+ elif isinstance(negative_prompt, str):
237
+ uncond_tokens = [negative_prompt]
238
+ elif batch_size != len(negative_prompt):
239
+ raise ValueError(
240
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
241
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
242
+ " the batch size of `prompt`."
243
+ )
244
+ else:
245
+ uncond_tokens = negative_prompt
246
+
247
+ max_length = text_input_ids.shape[-1]
248
+ uncond_input = self.tokenizer(
249
+ uncond_tokens,
250
+ padding="max_length",
251
+ max_length=max_length,
252
+ truncation=True,
253
+ return_tensors="pt",
254
+ )
255
+
256
+ if (
257
+ hasattr(self.text_encoder.config, "use_attention_mask")
258
+ and self.text_encoder.config.use_attention_mask
259
+ ):
260
+ attention_mask = uncond_input.attention_mask.to(device)
261
+ else:
262
+ attention_mask = None
263
+
264
+ uncond_embeddings = self.text_encoder(
265
+ uncond_input.input_ids.to(device),
266
+ attention_mask=attention_mask,
267
+ )
268
+ uncond_embeddings = uncond_embeddings[0]
269
+
270
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
271
+ seq_len = uncond_embeddings.shape[1]
272
+ uncond_embeddings = uncond_embeddings.repeat(1, num_videos_per_prompt, 1)
273
+ uncond_embeddings = uncond_embeddings.view(
274
+ batch_size * num_videos_per_prompt, seq_len, -1
275
+ )
276
+
277
+ # For classifier free guidance, we need to do two forward passes.
278
+ # Here we concatenate the unconditional and text embeddings into a single batch
279
+ # to avoid doing two forward passes
280
+ text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
281
+
282
+ return text_embeddings
283
+
284
+ @torch.no_grad()
285
+ def __call__(
286
+ self,
287
+ ref_image,
288
+ pose_images,
289
+ width,
290
+ height,
291
+ video_length,
292
+ num_inference_steps,
293
+ guidance_scale,
294
+ num_images_per_prompt=1,
295
+ eta: float = 0.0,
296
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
297
+ output_type: Optional[str] = "tensor",
298
+ return_dict: bool = True,
299
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
300
+ callback_steps: Optional[int] = 1,
301
+ **kwargs,
302
+ ):
303
+ # Default height and width to unet
304
+ height = height or self.unet.config.sample_size * self.vae_scale_factor
305
+ width = width or self.unet.config.sample_size * self.vae_scale_factor
306
+
307
+ device = self._execution_device
308
+
309
+ do_classifier_free_guidance = guidance_scale > 1.0
310
+
311
+ # Prepare timesteps
312
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
313
+ timesteps = self.scheduler.timesteps
314
+
315
+ batch_size = 1
316
+
317
+ # Prepare clip image embeds
318
+ clip_image = self.clip_image_processor.preprocess(
319
+ ref_image, return_tensors="pt"
320
+ ).pixel_values
321
+ clip_image_embeds = self.image_encoder(
322
+ clip_image.to(device, dtype=self.image_encoder.dtype)
323
+ ).image_embeds
324
+ encoder_hidden_states = clip_image_embeds.unsqueeze(1)
325
+ uncond_encoder_hidden_states = torch.zeros_like(encoder_hidden_states)
326
+
327
+ if do_classifier_free_guidance:
328
+ encoder_hidden_states = torch.cat(
329
+ [uncond_encoder_hidden_states, encoder_hidden_states], dim=0
330
+ )
331
+ reference_control_writer = ReferenceAttentionControl(
332
+ self.reference_unet,
333
+ do_classifier_free_guidance=do_classifier_free_guidance,
334
+ mode="write",
335
+ batch_size=batch_size,
336
+ fusion_blocks="full",
337
+ )
338
+ reference_control_reader = ReferenceAttentionControl(
339
+ self.denoising_unet,
340
+ do_classifier_free_guidance=do_classifier_free_guidance,
341
+ mode="read",
342
+ batch_size=batch_size,
343
+ fusion_blocks="full",
344
+ )
345
+
346
+ num_channels_latents = self.denoising_unet.in_channels
347
+ latents = self.prepare_latents(
348
+ batch_size * num_images_per_prompt,
349
+ num_channels_latents,
350
+ width,
351
+ height,
352
+ video_length,
353
+ clip_image_embeds.dtype,
354
+ device,
355
+ generator,
356
+ )
357
+
358
+ # Prepare extra step kwargs.
359
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
360
+
361
+ # Prepare ref image latents
362
+ ref_image_tensor = self.ref_image_processor.preprocess(
363
+ ref_image, height=height, width=width
364
+ ) # (bs, c, width, height)
365
+ ref_image_tensor = ref_image_tensor.to(
366
+ dtype=self.vae.dtype, device=self.vae.device
367
+ )
368
+ ref_image_latents = self.vae.encode(ref_image_tensor).latent_dist.mean
369
+ ref_image_latents = ref_image_latents * 0.18215 # (b, 4, h, w)
370
+
371
+ # Prepare a list of pose condition images
372
+ pose_cond_tensor_list = []
373
+ for pose_image in pose_images:
374
+ pose_cond_tensor = (
375
+ torch.from_numpy(np.array(pose_image.resize((width, height)))) / 255.0
376
+ )
377
+ pose_cond_tensor = pose_cond_tensor.permute(2, 0, 1).unsqueeze(
378
+ 1
379
+ ) # (c, 1, h, w)
380
+ pose_cond_tensor_list.append(pose_cond_tensor)
381
+ pose_cond_tensor = torch.cat(pose_cond_tensor_list, dim=1) # (c, t, h, w)
382
+ pose_cond_tensor = pose_cond_tensor.unsqueeze(0)
383
+ pose_cond_tensor = pose_cond_tensor.to(
384
+ device=device, dtype=self.pose_guider.dtype
385
+ )
386
+ pose_fea = self.pose_guider(pose_cond_tensor)
387
+ pose_fea = (
388
+ torch.cat([pose_fea] * 2) if do_classifier_free_guidance else pose_fea
389
+ )
390
+
391
+ # denoising loop
392
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
393
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
394
+ for i, t in enumerate(timesteps):
395
+ # 1. Forward reference image
396
+ if i == 0:
397
+ self.reference_unet(
398
+ ref_image_latents.repeat(
399
+ (2 if do_classifier_free_guidance else 1), 1, 1, 1
400
+ ),
401
+ torch.zeros_like(t),
402
+ # t,
403
+ encoder_hidden_states=encoder_hidden_states,
404
+ return_dict=False,
405
+ )
406
+ reference_control_reader.update(reference_control_writer)
407
+
408
+ # 3.1 expand the latents if we are doing classifier free guidance
409
+ latent_model_input = (
410
+ torch.cat([latents] * 2) if do_classifier_free_guidance else latents
411
+ )
412
+ latent_model_input = self.scheduler.scale_model_input(
413
+ latent_model_input, t
414
+ )
415
+
416
+ noise_pred = self.denoising_unet(
417
+ latent_model_input,
418
+ t,
419
+ encoder_hidden_states=encoder_hidden_states,
420
+ pose_cond_fea=pose_fea,
421
+ return_dict=False,
422
+ )[0]
423
+
424
+ # perform guidance
425
+ if do_classifier_free_guidance:
426
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
427
+ noise_pred = noise_pred_uncond + guidance_scale * (
428
+ noise_pred_text - noise_pred_uncond
429
+ )
430
+
431
+ # compute the previous noisy sample x_t -> x_t-1
432
+ latents = self.scheduler.step(
433
+ noise_pred, t, latents, **extra_step_kwargs, return_dict=False
434
+ )[0]
435
+
436
+ # call the callback, if provided
437
+ if i == len(timesteps) - 1 or (
438
+ (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
439
+ ):
440
+ progress_bar.update()
441
+ if callback is not None and i % callback_steps == 0:
442
+ step_idx = i // getattr(self.scheduler, "order", 1)
443
+ callback(step_idx, t, latents)
444
+
445
+ reference_control_reader.clear()
446
+ reference_control_writer.clear()
447
+
448
+ # Post-processing
449
+ images = self.decode_latents(latents) # (b, c, f, h, w)
450
+
451
+ # Convert to tensor
452
+ if output_type == "tensor":
453
+ images = torch.from_numpy(images)
454
+
455
+ if not return_dict:
456
+ return images
457
+
458
+ return Pose2VideoPipelineOutput(videos=images)
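
With output_type="tensor", Pose2VideoPipeline returns a Pose2VideoPipelineOutput whose videos field is a (b, c, f, h, w) tensor in [0, 1]. The sketch below shows one way to write that tensor out as an animated GIF; the pipe call, its argument values, the fps, and the output path are assumptions, not part of the repository.

import numpy as np
from PIL import Image


def save_video_tensor_as_gif(videos, path, fps=8):
    # videos: (b, c, f, h, w) in [0, 1]; writes the first sample as an animated GIF.
    frames = []
    for f_idx in range(videos.shape[2]):
        frame = np.asarray(videos[0, :, f_idx])          # (c, h, w)
        frame = (frame * 255.0).round().astype("uint8").transpose(1, 2, 0)
        frames.append(Image.fromarray(frame))
    frames[0].save(path, save_all=True, append_images=frames[1:],
                   duration=int(1000 / fps), loop=0)


# result = pipe(ref_image, pose_images, width=512, height=512, video_length=len(pose_images),
#               num_inference_steps=25, guidance_scale=3.5)   # hypothetical call
# save_video_tensor_as_gif(result.videos, "pose2vid_result.gif", fps=8)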