File size: 3,596 Bytes
ccf7c47
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
import math
from typing import Any

import torch
from torch import nn
from torch.nn.functional import scaled_dot_product_attention
from torch.nn.attention import SDPBackend, sdpa_kernel
from .model_config import CoDAConfig



def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """Duplicate every key/value head ``n_rep`` times along the head axis.

    Produces the same result as
    ``torch.repeat_interleave(hidden_states, repeats=n_rep, dim=1)``.

    Args:
        hidden_states: 4-D tensor laid out as
            (batch, num_key_value_heads, seqlen, head_dim).
        n_rep: Number of copies of each head to emit.

    Returns:
        A tensor of shape (batch, num_key_value_heads * n_rep, seqlen,
        head_dim). When ``n_rep`` is 1 the input tensor itself is returned.
    """
    # Unpacking also enforces that the input is exactly 4-D.
    bsz, kv_heads, seq_len, dim = hidden_states.shape
    if n_rep != 1:
        # Add a singleton "repeat" axis right after the head axis, broadcast it
        # to n_rep, then merge it into the head axis so copies of one head end
        # up adjacent to each other.
        tiled = hidden_states.unsqueeze(2).expand(bsz, kv_heads, n_rep, seq_len, dim)
        hidden_states = tiled.reshape(bsz, kv_heads * n_rep, seq_len, dim)
    return hidden_states


class AttentionModule(nn.Module):
    """Bi-directional (non-causal) attention built on PyTorch SDPA.

    The module has no parameters of its own: it receives already-projected
    query/key/value heads and dispatches to
    ``torch.nn.functional.scaled_dot_product_attention`` with the backend
    selected by ``config.attention_kernel``.
    """

    def __init__(self, config: CoDAConfig, kernel_config: dict[str, Any] | None = None):
        """Store the configuration.

        Args:
            config: Model configuration. ``forward`` reads
                ``attention_kernel``, ``num_attention_heads``,
                ``num_key_value_heads`` and ``attention_dropout`` from it.
            kernel_config: Optional kernel settings. Stored on the instance;
                not read by ``forward``.
        """
        super().__init__()
        self.config = config
        self.kernel_config = kernel_config
        # Kept as an attribute; nothing in this class reads it.
        self.partition_spec = None

    def forward(
        self,
        query_states: torch.Tensor,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        attention_mask: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Compute non-causal attention (GPU-oriented PyTorch implementation).

        Args:
            query_states: 4-D tensor unpacked as
                (batch, num_heads, q_len, head_dim).
            key_states: Keys laid out as
                (batch, num_key_value_heads, kv_len, head_dim); the heads are
                repeated to match ``config.num_attention_heads``.
            value_states: Values, same layout as ``key_states``.
            attention_mask: Accepted for interface compatibility but NOT
                applied to the attention computation.
                NOTE(review): confirm with callers whether padding should be
                masked here; the expected mask layout is not visible in this
                file, so it is deliberately not forwarded to SDPA.

        Returns:
            Tensor of shape (batch, num_heads, q_len, value_head_dim).

        Raises:
            NotImplementedError: If ``config.attention_kernel`` is
                ``"splash_attention"``.
            ValueError: If the SDPA output does not have the expected shape.
        """
        kernel = self.config.attention_kernel

        # Reject the unsupported kernel before doing any tensor work.
        if kernel == "splash_attention":
            raise NotImplementedError(
                "Splash Attention is not supported in GPU environment"
            )

        # Grouped-query attention: expand K/V so every query head has a
        # matching key/value head.
        num_key_value_groups = (
            self.config.num_attention_heads // self.config.num_key_value_heads
        )
        key_states = repeat_kv(key_states, num_key_value_groups)
        value_states = repeat_kv(value_states, num_key_value_groups)

        bsz, num_heads, q_len, _ = query_states.size()
        # The output's last dimension follows the values, not the queries.
        head_dim = value_states.shape[-1]

        if kernel == "flash_attention":
            # Prefer the flash kernel but keep the math kernel enabled so SDPA
            # can fall back when flash cannot handle these inputs (unsupported
            # dtype, device or head size). Enabling flash alone makes SDPA
            # raise instead of falling back.
            backends = [SDPBackend.FLASH_ATTENTION, SDPBackend.MATH]
        else:
            # Default implementation - math backend for compatibility.
            backends = [SDPBackend.MATH]

        # Dropout is only active in training mode.
        dropout_p = self.config.attention_dropout if self.training else 0.0

        with sdpa_kernel(backends):
            attn_output = scaled_dot_product_attention(
                query_states,
                key_states,
                value_states,
                dropout_p=dropout_p,
                is_causal=False,  # bi-directional attention
            )

        expected_size = (bsz, num_heads, q_len, head_dim)
        if attn_output.size() != expected_size:
            raise ValueError(
                f"`attn_output` should be of size {expected_size}, but is"
                f" {attn_output.size()}"
            )
        return attn_output