| | import torch |
| | import torch.nn.functional as F |
| | from diffusers.models.attention_processor import ( |
| | Attention, |
| | AttnProcessor2_0, |
| | SlicedAttnProcessor, |
| | XFormersAttnProcessor |
| | ) |
| |
|
try:
    import xformers.ops
except Exception:
    # xformers is optional: a missing or broken install just disables the
    # xformers path. Catch Exception rather than using a bare ``except`` so
    # KeyboardInterrupt / SystemExit raised during import are not swallowed.
    xformers = None

# Hypernetworks applied (in list order) by apply_hypernetworks(). Each entry
# must provide ``forward(context_k, context_v)`` returning a 2-item pair.
loaded_networks = []
| |
|
| |
|
def apply_single_hypernetwork(
    hypernetwork, hidden_states, encoder_hidden_states
):
    """Run one hypernetwork and return its two outputs as a tuple.

    Args:
        hypernetwork: Object exposing ``forward(hidden_states, encoder_hidden_states)``.
        hidden_states: First argument forwarded to ``hypernetwork.forward``.
        encoder_hidden_states: Second argument forwarded to ``hypernetwork.forward``.

    Returns:
        ``(context_k, context_v)``, taken from the forward result.

    Raises:
        ValueError: if ``forward`` does not produce exactly two items.
    """
    outputs = hypernetwork.forward(hidden_states, encoder_hidden_states)
    # Unpacking enforces the two-item contract before re-packing as a tuple.
    key_context, value_context = outputs
    return key_context, value_context
| |
|
| |
|
def apply_hypernetworks(context_k, context_v, layer=None):
    """Run every loaded hypernetwork over the key/value context tensors.

    Args:
        context_k: First argument given to the first hypernetwork. The
            attention processors in this module pass the query-side
            ``hidden_states`` here.
        context_v: Second argument given to the first hypernetwork. The
            processors pass the (possibly cross-normalized)
            ``encoder_hidden_states`` here.
        layer: Unused; kept so existing call sites keep working.

    Returns:
        ``(context_k, context_v)`` to feed into ``attn.to_k`` / ``attn.to_v``.
        When ``loaded_networks`` is empty this is ``(context_v, context_v)``.
        Otherwise it is the last network's outputs cast back to the dtype
        ``context_v`` had on entry.
    """
    if not loaded_networks:
        # Deliberately returns context_v twice: callers pass hidden_states as
        # context_k, while plain attention must build both keys and values
        # from the key/value source (context_v).
        return context_v, context_v

    # The no-network path feeds context_v's tensor (and dtype) to to_k/to_v.
    # Cast the outputs back to that dtype so to_k/to_v see the same dtype
    # even if a hypernetwork computes in a different precision.
    target_dtype = context_v.dtype

    for hypernetwork in loaded_networks:
        # NOTE(review): the first network receives (hidden_states,
        # encoder_hidden_states), later ones receive the previous network's
        # (context_k, context_v); confirm forward()'s contract supports chaining.
        context_k, context_v = hypernetwork.forward(context_k, context_v)

    return context_k.to(dtype=target_dtype), context_v.to(dtype=target_dtype)
| |
|
| |
|
| |
|
def xformers_forward(
    self: XFormersAttnProcessor,
    attn: Attention,
    hidden_states: torch.Tensor,
    encoder_hidden_states: torch.Tensor = None,
    attention_mask: torch.Tensor = None,
):
    """xformers memory-efficient attention with hypernetwork-modified key/value inputs.

    Installed as ``XFormersAttnProcessor.__call__`` by
    ``replace_attentions_for_hypernetwork``. The key/value source is
    ``hidden_states`` for self-attention, or ``encoder_hidden_states``
    (normalized when ``attn.norm_cross`` is set). It is routed through
    ``apply_hypernetworks`` before the ``to_k`` / ``to_v`` projections.

    Args:
        self: The patched processor; only ``self.attention_op`` is read.
        attn: The diffusers ``Attention`` module.
        hidden_states: Query-side input, unpacked as (batch, tokens, features).
        encoder_hidden_states: Optional key/value-side input; ``None`` selects
            self-attention.
        attention_mask: Optional mask, passed through
            ``attn.prepare_attention_mask`` and used as xformers' ``attn_bias``.

    Returns:
        ``attn.to_out[1](attn.to_out[0](...))`` applied to the attention result.
    """
    # Mask preparation is sized by whichever tensor supplies the keys.
    kv_source_shape = (
        hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
    )
    batch_size, sequence_length, _ = kv_source_shape
    attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

    query = attn.to_q(hidden_states)

    if encoder_hidden_states is None:
        kv_input = hidden_states
    elif attn.norm_cross:
        kv_input = attn.norm_encoder_hidden_states(encoder_hidden_states)
    else:
        kv_input = encoder_hidden_states

    k_context, v_context = apply_hypernetworks(hidden_states, kv_input)
    key = attn.to_k(k_context)
    value = attn.to_v(v_context)

    # Per-head layout via head_to_batch_dim, made contiguous for xformers.
    query, key, value = [
        attn.head_to_batch_dim(tensor).contiguous() for tensor in (query, key, value)
    ]

    attended = xformers.ops.memory_efficient_attention(
        query,
        key,
        value,
        attn_bias=attention_mask,
        op=self.attention_op,
        scale=attn.scale,
    )
    merged = attn.batch_to_head_dim(attended.to(query.dtype))

    first_out, second_out = attn.to_out[0], attn.to_out[1]
    return second_out(first_out(merged))
| |
|
| |
|
def sliced_attn_forward(
    self: SlicedAttnProcessor,
    attn: Attention,
    hidden_states: torch.Tensor,
    encoder_hidden_states: torch.Tensor = None,
    attention_mask: torch.Tensor = None,
):
    """Sliced attention with hypernetwork-modified key/value inputs.

    Installed as ``SlicedAttnProcessor.__call__`` by
    ``replace_attentions_for_hypernetwork``. Attention is computed
    ``self.slice_size`` rows at a time along the leading axis of the
    ``head_to_batch_dim`` outputs, presumably to bound peak memory. The
    key/value source is routed through ``apply_hypernetworks`` before
    ``to_k`` / ``to_v``.

    Args:
        self: The patched processor; only ``self.slice_size`` is read.
        attn: The diffusers ``Attention`` module.
        hidden_states: Query-side input, unpacked as (batch, tokens, features).
        encoder_hidden_states: Optional key/value-side input; ``None`` selects
            self-attention.
        attention_mask: Optional mask, prepared by ``attn.prepare_attention_mask``
            and sliced along its leading axis in step with the query.

    Returns:
        ``attn.to_out[1](attn.to_out[0](...))`` applied to the attention result.
    """
    batch_size, sequence_length, _ = (
        hidden_states.shape
        if encoder_hidden_states is None
        else encoder_hidden_states.shape
    )
    attention_mask = attn.prepare_attention_mask(
        attention_mask, sequence_length, batch_size
    )

    query = attn.to_q(hidden_states)
    dim = query.shape[-1]
    query = attn.head_to_batch_dim(query)

    if encoder_hidden_states is None:
        encoder_hidden_states = hidden_states
    elif attn.norm_cross:
        encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

    context_k, context_v = apply_hypernetworks(hidden_states, encoder_hidden_states)

    key = attn.to_k(context_k)
    value = attn.to_v(context_v)
    key = attn.head_to_batch_dim(key)
    value = attn.head_to_batch_dim(value)

    # batch_size_attention is the leading axis of the head_to_batch_dim output
    # (presumably batch * heads).
    batch_size_attention, query_tokens, _ = query.shape
    hidden_states = torch.zeros(
        (batch_size_attention, query_tokens, dim // attn.heads),
        device=query.device,
        dtype=query.dtype,
    )

    # Walk the whole leading axis in steps of slice_size. The last slice may
    # be shorter; tensor slicing clamps end_idx to the axis length. A loop over
    # range(batch_size_attention // slice_size) would skip that remainder and
    # leave its rows as zeros.
    for start_idx in range(0, batch_size_attention, self.slice_size):
        end_idx = start_idx + self.slice_size

        query_slice = query[start_idx:end_idx]
        key_slice = key[start_idx:end_idx]
        attn_mask_slice = (
            attention_mask[start_idx:end_idx] if attention_mask is not None else None
        )

        attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice)
        attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx])

        hidden_states[start_idx:end_idx] = attn_slice

    hidden_states = attn.batch_to_head_dim(hidden_states)

    hidden_states = attn.to_out[0](hidden_states)
    hidden_states = attn.to_out[1](hidden_states)

    return hidden_states
| |
|
| |
|
def v2_0_forward(
    self: AttnProcessor2_0,
    attn: Attention,
    hidden_states,
    encoder_hidden_states=None,
    attention_mask=None,
):
    """PyTorch SDPA attention with hypernetwork-modified key/value inputs.

    Installed as ``AttnProcessor2_0.__call__`` by
    ``replace_attentions_for_hypernetwork``. The key/value source is routed
    through ``apply_hypernetworks`` before ``to_k`` / ``to_v``. Attention is
    then computed with ``torch.nn.functional.scaled_dot_product_attention``.

    Args:
        self: The patched processor instance (not read).
        attn: The diffusers ``Attention`` module.
        hidden_states: Query-side input, unpacked as (batch, tokens, features).
        encoder_hidden_states: Optional key/value-side input; ``None`` selects
            self-attention.
        attention_mask: Optional mask; when given it is prepared by
            ``attn.prepare_attention_mask`` and viewed as
            (batch, heads, -1, key_tokens) for SDPA.

    Returns:
        ``attn.to_out[1](attn.to_out[0](...))`` applied to the attention result.
    """
    batch_size, sequence_length, _ = (
        hidden_states.shape
        if encoder_hidden_states is None
        else encoder_hidden_states.shape
    )

    if attention_mask is not None:
        attention_mask = attn.prepare_attention_mask(
            attention_mask, sequence_length, batch_size
        )
        # SDPA takes a per-head mask; -1 lets view infer the query-token axis.
        attention_mask = attention_mask.view(
            batch_size, attn.heads, -1, attention_mask.shape[-1]
        )

    query = attn.to_q(hidden_states)

    if encoder_hidden_states is None:
        encoder_hidden_states = hidden_states
    elif attn.norm_cross:
        encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

    context_k, context_v = apply_hypernetworks(hidden_states, encoder_hidden_states)

    key = attn.to_k(context_k)
    value = attn.to_v(context_v)

    # Derive the head size from the projected key, i.e. the width the views
    # below actually reshape. hidden_states' feature width only coincides with
    # it when the projections preserve the feature size.
    inner_dim = key.shape[-1]
    head_dim = inner_dim // attn.heads

    # (batch, tokens, heads * head_dim) -> (batch, heads, tokens, head_dim)
    query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
    key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
    value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

    hidden_states = F.scaled_dot_product_attention(
        query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
    )

    # (batch, heads, tokens, head_dim) -> (batch, tokens, heads * head_dim)
    hidden_states = hidden_states.transpose(1, 2).reshape(
        batch_size, -1, attn.heads * head_dim
    )
    hidden_states = hidden_states.to(query.dtype)

    hidden_states = attn.to_out[0](hidden_states)
    hidden_states = attn.to_out[1](hidden_states)
    return hidden_states
| |
|
| |
|
def replace_attentions_for_hypernetwork():
    """Monkey-patch diffusers' attention processors to apply hypernetworks.

    Overwrites ``__call__`` on ``XFormersAttnProcessor``,
    ``SlicedAttnProcessor`` and ``AttnProcessor2_0`` in
    ``diffusers.models.attention_processor`` with ``xformers_forward``,
    ``sliced_attn_forward`` and ``v2_0_forward`` respectively. The
    assignment is made on the classes, so it affects every instance of those
    processors in the process.
    """
    from diffusers.models import attention_processor as processors_module

    patches = (
        ("XFormersAttnProcessor", xformers_forward),
        ("SlicedAttnProcessor", sliced_attn_forward),
        ("AttnProcessor2_0", v2_0_forward),
    )
    # Patch one class at a time, in the same order as before.
    for class_name, replacement in patches:
        getattr(processors_module, class_name).__call__ = replacement
| |
|