import math
from typing import Any
import torch
from torch import nn
from torch.nn.functional import scaled_dot_product_attention
from torch.nn.attention import SDPBackend, sdpa_kernel
from .model_config import CoDAConfig


def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
    """
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(
        batch, num_key_value_heads, n_rep, slen, head_dim
    )
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
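
# Illustrative shape sketch (not in the original file): with num_key_value_heads=2
# and n_rep=4, repeat_kv duplicates each KV head 4 times so the KV tensor lines up
# with 8 query heads, e.g.
#
#   kv = torch.randn(1, 2, 16, 64)   # (batch, num_key_value_heads, seqlen, head_dim)
#   repeat_kv(kv, 4).shape           # torch.Size([1, 8, 16, 64])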


class AttentionModule(nn.Module):
    def __init__(self, config: CoDAConfig, kernel_config: dict[str, Any] | None = None):
        super().__init__()
        self.config = config
        self.kernel_config = kernel_config
        self.partition_spec = None

    def forward(
        self,
        query_states: torch.Tensor,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        attention_mask: torch.Tensor | None = None,
    ):
        """GPU-optimized PyTorch implementation"""
        if self.config.attention_kernel != "splash_attention":
            num_key_value_groups = (
                self.config.num_attention_heads // self.config.num_key_value_heads
            )
            key_states = repeat_kv(key_states, num_key_value_groups)
            value_states = repeat_kv(value_states, num_key_value_groups)

        bsz, num_heads, q_len, head_dim = query_states.size()
        head_dim = value_states.shape[-1]
        kv_seq_len = key_states.shape[-2]
        # Use SDPA with the backend selected by config.attention_kernel
        match self.config.attention_kernel:
            case "splash_attention":
                raise NotImplementedError(
                    "Splash Attention is not supported in GPU environment"
                )
            case "flash_attention":
                # Restrict SDPA to the flash attention backend; SDPA raises a
                # RuntimeError if the inputs (device/dtype/shape) are unsupported
                with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
                    attn_output = scaled_dot_product_attention(
                        query_states,
                        key_states,
                        value_states,
                        dropout_p=(
                            self.config.attention_dropout if self.training else 0.0
                        ),
                        is_causal=False,  # weiran: causal=False for bi-directional attention
                    )
            case _:
                # Default implementation - use math backend for compatibility
                with sdpa_kernel(SDPBackend.MATH):
                    attn_output = scaled_dot_product_attention(
                        query_states,
                        key_states,
                        value_states,
                        dropout_p=(
                            self.config.attention_dropout if self.training else 0.0
                        ),
                        is_causal=False,  # weiran: causal=False for bi-directional attention
                    )

        if attn_output.size() != (bsz, num_heads, q_len, head_dim):
            raise ValueError(
                f"`attn_output` should be of size {(bsz, num_heads, q_len, head_dim)}, but is"
                f" {attn_output.size()}"
            )

        return attn_output
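

if __name__ == "__main__":
    # Illustrative usage sketch, not part of the original module. A SimpleNamespace
    # stands in for CoDAConfig (whose real constructor lives in .model_config);
    # only the fields read by AttentionModule are provided, with assumed values.
    from types import SimpleNamespace

    cfg = SimpleNamespace(
        attention_kernel="sdpa_math",  # any value other than splash/flash selects the MATH backend
        num_attention_heads=8,
        num_key_value_heads=2,  # GQA: each KV head is shared by 4 query heads
        attention_dropout=0.0,
    )
    attn = AttentionModule(cfg).eval()

    bsz, q_len, head_dim = 2, 16, 64
    query = torch.randn(bsz, cfg.num_attention_heads, q_len, head_dim)
    key = torch.randn(bsz, cfg.num_key_value_heads, q_len, head_dim)
    value = torch.randn(bsz, cfg.num_key_value_heads, q_len, head_dim)

    # Note: attention_mask is accepted by forward() but not passed to SDPA;
    # attention is always bi-directional (is_causal=False).
    out = attn(query, key, value)
    print(out.shape)  # torch.Size([2, 8, 16, 64])
    # "flash_attention" would instead require a CUDA device and fp16/bf16 tensors.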