import math

import torch
import torch.nn as nn
import torch.nn.functional as F

from models.helpers import DropPath, drop_path


# this file only provides the 3 blocks used in VAR transformer
__all__ = ['FFN', 'AdaLNSelfAttn', 'AdaLNBeforeHead']


# automatically import fused operators
dropout_add_layer_norm = fused_mlp_func = memory_efficient_attention = flash_attn_func = None
try:
    from flash_attn.ops.layer_norm import dropout_add_layer_norm
    from flash_attn.ops.fused_dense import fused_mlp_func
except ImportError: pass
# automatically import faster attention implementations
try: from xformers.ops import memory_efficient_attention
except ImportError: pass
try: from flash_attn import flash_attn_func                # qkv: BLHc, ret: BLHc
except ImportError: pass
try: from torch.nn.functional import scaled_dot_product_attention as slow_attn  # q, k, v: BHLc
except ImportError:
    # pure-PyTorch fallback with the same semantics as F.scaled_dot_product_attention
    def slow_attn(query, key, value, scale: float, attn_mask=None, dropout_p=0.0):
        attn = query.mul(scale) @ key.transpose(-2, -1)     # BHLc @ BHcL => BHLL
        if attn_mask is not None: attn.add_(attn_mask)
        return (F.dropout(attn.softmax(dim=-1), p=dropout_p, inplace=True) if dropout_p > 0 else attn.softmax(dim=-1)) @ value
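

# Minimal usage sketch for the attention fallback above (illustrative only; the
# shapes and the 1/sqrt(c) scale are assumptions, not values used elsewhere in this
# file). Note that passing `scale=` to the native scaled_dot_product_attention
# requires a recent PyTorch version that supports that keyword.
def _example_slow_attn():
    B, H, L, c = 2, 12, 16, 64
    q, k, v = (torch.randn(B, H, L, c) for _ in range(3))   # q, k, v: BHLc
    return slow_attn(q, k, v, scale=1 / math.sqrt(c))       # output: BHLc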


class FFN(nn.Module):
    def __init__(self, in_features, hidden_features=None, out_features=None, drop=0., fused_if_available=True):
        super().__init__()
        self.fused_mlp_func = fused_mlp_func if fused_if_available else None
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = nn.GELU(approximate='tanh')
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop, inplace=True) if drop > 0 else nn.Identity()
    
    def forward(self, x):
        if self.fused_mlp_func is not None:
            return self.drop(self.fused_mlp_func(
                x=x, weight1=self.fc1.weight, weight2=self.fc2.weight, bias1=self.fc1.bias, bias2=self.fc2.bias,
                activation='gelu_approx', save_pre_act=self.training, return_residual=False, checkpoint_lvl=0,
                heuristic=0, process_group=None,
            ))
        else:
            return self.drop(self.fc2( self.act(self.fc1(x)) ))
    
    def extra_repr(self) -> str:
        return f'fused_mlp_func={self.fused_mlp_func is not None}'
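

# Minimal usage sketch for FFN (shapes are illustrative; fused_if_available=False
# is passed here only to force the plain Linear-GELU-Linear path even when
# flash_attn's fused_mlp_func is importable, so the sketch runs on CPU/fp32).
def _example_ffn():
    ffn = FFN(in_features=768, hidden_features=3072, fused_if_available=False)
    x = torch.randn(2, 16, 768)   # BLC
    return ffn(x)                 # same shape as x: (2, 16, 768)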


class SelfAttention(nn.Module):
    def __init__(
        self, block_idx, embed_dim=768, num_heads=12,
        attn_drop=0., proj_drop=0., attn_l2_norm=False, flash_if_available=True,
    ):
        super().__init__()
        assert embed_dim % num_heads == 0
        self.block_idx, self.num_heads, self.head_dim = block_idx, num_heads, embed_dim // num_heads  # =64
        self.attn_l2_norm = attn_l2_norm
        if self.attn_l2_norm:
            self.scale = 1
            self.scale_mul_1H11 = nn.Parameter(torch.full(size=(1, self.num_heads, 1, 1), fill_value=4.0).log(), requires_grad=True)
            self.max_scale_mul = torch.log(torch.tensor(100)).item()
        else:
            self.scale = 0.25 / math.sqrt(self.head_dim)
        
        self.mat_qkv = nn.Linear(embed_dim, embed_dim * 3, bias=False)
        self.q_bias, self.v_bias = nn.Parameter(torch.zeros(embed_dim)), nn.Parameter(torch.zeros(embed_dim))
        self.register_buffer('zero_k_bias', torch.zeros(embed_dim))
        
        self.proj = nn.Linear(embed_dim, embed_dim)
        self.proj_drop = nn.Dropout(proj_drop, inplace=True) if proj_drop > 0 else nn.Identity()
        self.attn_drop: float = attn_drop
        self.using_flash = flash_if_available and flash_attn_func is not None
        self.using_xform = flash_if_available and memory_efficient_attention is not None
        
        # only used during inference
        self.caching, self.cached_k, self.cached_v = False, None, None
    
    def kv_caching(self, enable: bool): self.caching, self.cached_k, self.cached_v = enable, None, None
    
    # NOTE: attn_bias is None during inference because kv cache is enabled
    def forward(self, x, attn_bias):
        B, L, C = x.shape
        
        qkv = F.linear(input=x, weight=self.mat_qkv.weight, bias=torch.cat((self.q_bias, self.zero_k_bias, self.v_bias))).view(B, L, 3, self.num_heads, self.head_dim)
        main_type = qkv.dtype
        # qkv: BL3Hc
        
        using_flash = self.using_flash and attn_bias is None and qkv.dtype != torch.float32
        if using_flash or self.using_xform: q, k, v = qkv.unbind(dim=2); dim_cat = 1   # q or k or v: BLHc
        else: q, k, v = qkv.permute(2, 0, 3, 1, 4).unbind(dim=0); dim_cat = 2          # q or k or v: BHLc
        
        if self.attn_l2_norm:
            scale_mul = self.scale_mul_1H11.clamp_max(self.max_scale_mul).exp()
            if using_flash or self.using_xform: scale_mul = scale_mul.transpose(1, 2)  # 1H11 to 11H1
            q = F.normalize(q, dim=-1).mul(scale_mul)
            k = F.normalize(k, dim=-1)
        
        if self.caching:
            if self.cached_k is None: self.cached_k = k; self.cached_v = v
            else: k = self.cached_k = torch.cat((self.cached_k, k), dim=dim_cat); v = self.cached_v = torch.cat((self.cached_v, v), dim=dim_cat)
        
        dropout_p = self.attn_drop if self.training else 0.0
        if using_flash:
            oup = flash_attn_func(q.to(dtype=main_type), k.to(dtype=main_type), v.to(dtype=main_type), dropout_p=dropout_p, softmax_scale=self.scale).view(B, L, C)
        elif self.using_xform:
            oup = memory_efficient_attention(q.to(dtype=main_type), k.to(dtype=main_type), v.to(dtype=main_type), attn_bias=None if attn_bias is None else attn_bias.to(dtype=main_type).expand(B, self.num_heads, -1, -1), p=dropout_p, scale=self.scale).view(B, L, C)
        else:
            oup = slow_attn(query=q, key=k, value=v, scale=self.scale, attn_mask=attn_bias, dropout_p=dropout_p).transpose(1, 2).reshape(B, L, C)
        
        return self.proj_drop(self.proj(oup))
        # attn = (q @ k.transpose(-2, -1)).add_(attn_bias + self.local_rpb())  # BHLc @ BHcL => BHLL
        # attn = self.attn_drop(attn.softmax(dim=-1))
        # oup = (attn @ v).transpose_(1, 2).reshape(B, L, -1)                  # BHLL @ BHLc = BHLc => BLHc => BLC
    
    def extra_repr(self) -> str:
        return f'using_flash={self.using_flash}, using_xform={self.using_xform}, attn_l2_norm={self.attn_l2_norm}'
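

# Minimal usage sketch for SelfAttention (illustrative shapes; flash_if_available=False
# forces the pure-PyTorch slow_attn path so the sketch runs on CPU/fp32).
def _example_self_attention():
    attn = SelfAttention(block_idx=0, embed_dim=768, num_heads=12, flash_if_available=False)
    x = torch.randn(2, 16, 768)   # BLC
    
    # training-style call with an additive causal bias (0 on allowed positions, -inf above the diagonal)
    causal = torch.zeros(1, 1, 16, 16)
    causal.masked_fill_(torch.ones(16, 16, dtype=torch.bool).triu(1), float('-inf'))
    y = attn(x, attn_bias=causal)   # (2, 16, 768)
    
    # inference-style call with kv caching: feed the sequence in two chunks;
    # attn_bias is None once the cache is enabled, matching the NOTE above forward()
    attn.eval()
    attn.kv_caching(True)
    with torch.no_grad():
        y1 = attn(x[:, :8], attn_bias=None)   # caches k/v of the first 8 tokens
        y2 = attn(x[:, 8:], attn_bias=None)   # the next 8 tokens also attend to the cached k/v
    attn.kv_caching(False)
    return y, y1, y2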


class AdaLNSelfAttn(nn.Module):
    def __init__(
        self, block_idx, last_drop_p, embed_dim, cond_dim, shared_aln: bool, norm_layer,
        num_heads, mlp_ratio=4., drop=0., attn_drop=0., drop_path=0., attn_l2_norm=False,
        flash_if_available=False, fused_if_available=True,
    ):
        super(AdaLNSelfAttn, self).__init__()
        self.block_idx, self.last_drop_p, self.C = block_idx, last_drop_p, embed_dim
        self.C, self.D = embed_dim, cond_dim
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.attn = SelfAttention(block_idx=block_idx, embed_dim=embed_dim, num_heads=num_heads, attn_drop=attn_drop, proj_drop=drop, attn_l2_norm=attn_l2_norm, flash_if_available=flash_if_available)
        self.ffn = FFN(in_features=embed_dim, hidden_features=round(embed_dim * mlp_ratio), drop=drop, fused_if_available=fused_if_available)
        
        self.ln_wo_grad = norm_layer(embed_dim, elementwise_affine=False)
        self.shared_aln = shared_aln
        if self.shared_aln:
            self.ada_gss = nn.Parameter(torch.randn(1, 1, 6, embed_dim) / embed_dim**0.5)
        else:
            lin = nn.Linear(cond_dim, 6*embed_dim)
            self.ada_lin = nn.Sequential(nn.SiLU(inplace=False), lin)
        
        self.fused_add_norm_fn = None
    
    # NOTE: attn_bias is None during inference because kv cache is enabled
    def forward(self, x, cond_BD, attn_bias):   # C: embed_dim, D: cond_dim
        if self.shared_aln:
            gamma1, gamma2, scale1, scale2, shift1, shift2 = (self.ada_gss + cond_BD).unbind(2)  # 116C + B16C =unbind(2)=> 6 B1C
        else:
            gamma1, gamma2, scale1, scale2, shift1, shift2 = self.ada_lin(cond_BD).view(-1, 1, 6, self.C).unbind(2)
        x = x + self.drop_path(self.attn( self.ln_wo_grad(x).mul(scale1.add(1)).add_(shift1), attn_bias=attn_bias ).mul_(gamma1))
        x = x + self.drop_path(self.ffn( self.ln_wo_grad(x).mul(scale2.add(1)).add_(shift2) ).mul(gamma2))  # this mul(gamma2) cannot be in-placed when FusedMLP is used
        return x
    
    def extra_repr(self) -> str:
        return f'shared_aln={self.shared_aln}'
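

# Minimal usage sketch for AdaLNSelfAttn (illustrative sizes; the parent model
# supplies its own embed_dim/cond_dim and norm_layer, e.g. a LayerNorm partial;
# the flash/fused paths are disabled here so the sketch runs on CPU/fp32).
def _example_adaln_block():
    blk = AdaLNSelfAttn(
        block_idx=0, last_drop_p=0., embed_dim=768, cond_dim=768, shared_aln=False,
        norm_layer=nn.LayerNorm, num_heads=12, flash_if_available=False, fused_if_available=False,
    )
    x = torch.randn(2, 16, 768)    # token feature map, BLC
    cond_BD = torch.randn(2, 768)  # per-sample condition, BD
    return blk(x, cond_BD=cond_BD, attn_bias=None)   # (2, 16, 768)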


class AdaLNBeforeHead(nn.Module):
    def __init__(self, C, D, norm_layer):   # C: embed_dim, D: cond_dim
        super().__init__()
        self.C, self.D = C, D
        self.ln_wo_grad = norm_layer(C, elementwise_affine=False)
        self.ada_lin = nn.Sequential(nn.SiLU(inplace=False), nn.Linear(D, 2*C))
    
    def forward(self, x_BLC: torch.Tensor, cond_BD: torch.Tensor):
        scale, shift = self.ada_lin(cond_BD).view(-1, 1, 2, self.C).unbind(2)
        return self.ln_wo_grad(x_BLC).mul(scale.add(1)).add_(shift)
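

# Minimal usage sketch for AdaLNBeforeHead (illustrative sizes): the adaptive
# LayerNorm applied to the token features right before the output head.
def _example_adaln_before_head():
    head_norm = AdaLNBeforeHead(C=768, D=768, norm_layer=nn.LayerNorm)
    x_BLC = torch.randn(2, 16, 768)
    cond_BD = torch.randn(2, 768)
    return head_norm(x_BLC, cond_BD)   # (2, 16, 768)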