"""

Quantization utilities for ternary weight representation.



This module implements the core quantization functions for converting

dense weights to ternary ({-1, 0, +1}) representation with appropriate

scaling factors.

"""

import torch
from typing import Tuple, Optional


def absmax_scale(tensor: torch.Tensor, dim: Optional[int] = None) -> torch.Tensor:
    """

    Compute absmax scaling factor for quantization.

    

    The absmax scale is:

        scale = max(abs(tensor)) / Q_max

    

    where Q_max is the maximum quantized value (e.g., 1 for ternary).

    

    Args:

        tensor: Input tensor to compute scale for

        dim: Dimension to compute scale along (None = global, int = per-dim)

    

    Returns:

        Scaling factor(s)

    

    Examples:

        >>> W = torch.randn(512, 512)

        >>> scale = absmax_scale(W, dim=0)  # Per output channel

        >>> scale.shape

        torch.Size([512])

    """
    if dim is None:
        # Global absmax
        scale = torch.max(torch.abs(tensor))
    else:
        # Per-dimension absmax
        scale = torch.max(torch.abs(tensor), dim=dim, keepdim=True)[0]
        # Remove keepdim for cleaner output
        scale = scale.squeeze(dim)
    
    # Clamp to a small positive value to avoid division by zero
    scale = torch.clamp(scale, min=1e-5)
    
    return scale


def ternary_quantize(
    tensor: torch.Tensor,
    scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """

    Quantize tensor to ternary values {-1, 0, +1}.

    

    Uses a threshold-based approach:

        - Values > threshold → +1

        - Values < -threshold → -1

        - Values in [-threshold, threshold] → 0

    

    The threshold is typically computed as a fraction of the scale.

    

    Args:

        tensor: Input tensor to quantize

        scale: Optional pre-computed scale (if None, compute from tensor)

    

    Returns:

        Ternary tensor with values in {-1, 0, +1}

    

    Notes:

        - The threshold determines sparsity (more zeros)

        - Common thresholds: 0.33 * scale or 0.5 * scale

        - Inspired by BitNet's weight quantization scheme

    """
    # Compute scale if not provided
    if scale is None:
        scale = absmax_scale(tensor, dim=None)
    
    # Compute threshold (using 0.5 as a reasonable default)
    # This can be tuned: a larger threshold zeroes out more values (more sparse)
    threshold = 0.5 * scale
    
    # Ensure scale and threshold have proper shape for broadcasting
    if scale.dim() > 0:
        # Add dimensions to match tensor shape for broadcasting
        while threshold.dim() < tensor.dim():
            threshold = threshold.unsqueeze(-1)
    
    # Initialize ternary tensor with zeros
    ternary = torch.zeros_like(tensor)
    
    # Assign +1 and -1 based on threshold
    ternary[tensor > threshold] = 1
    ternary[tensor < -threshold] = -1
    
    return ternary
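
# Worked example for ternary_quantize (illustrative only, not part of the
# module's API surface): with the default threshold of 0.5 * absmax,
#
#     >>> w = torch.tensor([0.9, -0.2, 0.6, -0.8])
#     >>> ternary_quantize(w)  # absmax = 0.9, threshold = 0.45
#     tensor([ 1.,  0.,  1., -1.])
#
# entries whose magnitude exceeds the threshold keep their sign, and the
# rest (here only -0.2) are zeroed out.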


def weight_to_ternary(
    W: torch.Tensor,
    per_channel: bool = True,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """

    Convert dense weights to ternary representation with scaling.

    

    This is the main quantization function that combines:

        1. Scale computation (absmax per channel or global)

        2. Ternary quantization

        3. Return both quantized weights and scales

    

    Args:

        W: Dense weight matrix of shape [out_features, in_features]

        per_channel: If True, use per-output-channel scaling (recommended)

    

    Returns:

        W_ternary: Ternary weight matrix (values in {-1, 0, +1})

        gamma: Scaling factors (shape [out_features] or scalar)

    

    Examples:

        >>> W = torch.randn(512, 768)

        >>> W_t, gamma = weight_to_ternary(W, per_channel=True)

        >>> W_reconstructed = W_t * gamma.unsqueeze(1)

        >>> error = torch.norm(W - W_reconstructed)

    

    Notes:

        - Per-channel scaling preserves output scale better

        - The scaling factor gamma compensates for quantization

        - This function is used during layer initialization/conversion

    """
    if per_channel:
        # Compute scale per output channel (along dimension 1)
        # W is [out_features, in_features], so dim=1 gives scale per output
        gamma = absmax_scale(W, dim=1)
    else:
        # Global scale for entire weight matrix
        gamma = absmax_scale(W, dim=None)
    
    # Quantize to ternary using the computed scale
    W_ternary = ternary_quantize(W, gamma)
    
    return W_ternary, gamma


def quantize_activations_absmax(
    x: torch.Tensor,
    bits: int = 8,
    per_token: bool = True,
) -> torch.Tensor:
    """

    Quantize activations using absmax scaling.

    

    BitNet quantizes both weights (ternary) and activations (8-bit).

    This function implements activation quantization with per-token scaling.

    

    Args:

        x: Input activations of shape [batch, seq_len, features]

        bits: Number of bits for quantization (default: 8)

        per_token: If True, scale per token; if False, global scaling

    

    Returns:

        Quantized activations (as float, simulating INT8)

    

    Notes:

        - Per-token scaling is important for handling outliers

        - Returns float for autograd compatibility

        - Simulates quantization without actual int8 storage

    """
    # Calculate quantization range based on bits
    Q_max = 2 ** (bits - 1) - 1  # For 8-bit: 127
    Q_min = -Q_max  # -127
    
    if per_token:
        # Compute scale per token (across feature dimension)
        # x shape: [batch, seq_len, features]
        # Scale along last dimension, keeping dims for broadcasting
        scale = torch.max(torch.abs(x), dim=-1, keepdim=True)[0]
        scale = torch.clamp(scale, min=1e-5)  # Avoid division by zero
    else:
        # Global scale for entire tensor
        scale = torch.max(torch.abs(x))
        scale = torch.clamp(scale, min=1e-5)
    
    # Quantize: scale to [-Q_max, Q_max], round, and scale back
    x_scaled = x / scale * Q_max
    x_quant = torch.clamp(x_scaled, Q_min, Q_max)
    x_quant = torch.round(x_quant)
    
    # Dequantize back to float (simulating int8 → float32 for autograd)
    x_dequant = x_quant * scale / Q_max
    
    return x_dequant
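
# Illustrative round trip for quantize_activations_absmax (numbers chosen
# for the example, not taken from any reference): for a token whose absmax
# is 2.0, a value of 0.37 is scaled to 0.37 / 2.0 * 127 = 23.495, rounded
# to 23, and dequantized to 23 * 2.0 / 127 ≈ 0.362, so the simulated
# quantization error for that entry is about 0.008.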


def dequantize_scale(
    x_quant: torch.Tensor,
    scale: torch.Tensor,
) -> torch.Tensor:
    """

    Dequantize tensor back to float using scale.

    

    Simple helper for:

        x_float = x_quant * scale

    

    Args:

        x_quant: Quantized tensor (ternary or int8)

        scale: Scaling factors

    

    Returns:

        Dequantized float tensor

    """
    # Ensure scale has proper shape for broadcasting
    if scale.dim() > 0 and scale.dim() < x_quant.dim():
        # Add dimensions to the right to match x_quant shape
        while scale.dim() < x_quant.dim():
            scale = scale.unsqueeze(-1)
    
    return x_quant * scale
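

# Minimal usage sketch (illustrative only): exercises the functions above on
# random data and prints reconstruction/quantization error. The shapes and
# the seed are arbitrary choices for demonstration, not part of any
# reference configuration.
if __name__ == "__main__":
    torch.manual_seed(0)

    # Weight quantization: dense -> ternary values plus per-channel scale.
    W = torch.randn(512, 768)
    W_t, gamma = weight_to_ternary(W, per_channel=True)
    W_hat = dequantize_scale(W_t, gamma)  # equivalent to W_t * gamma.unsqueeze(1)
    rel_err = (torch.norm(W - W_hat) / torch.norm(W)).item()
    sparsity = (W_t == 0).float().mean().item()
    print(f"weight relative error: {rel_err:.3f}, ternary sparsity: {sparsity:.2%}")

    # Activation quantization: simulated 8-bit absmax with per-token scales.
    x = torch.randn(2, 16, 768)
    x_q = quantize_activations_absmax(x, bits=8, per_token=True)
    act_err = (torch.norm(x - x_q) / torch.norm(x)).item()
    print(f"activation relative error: {act_err:.4f}")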