File size: 495 Bytes
7be134c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
import torch
from ._ops import ops

def gemm_int4_forward(
    input: torch.Tensor,
    weight: torch.Tensor,
    zeros: torch.Tensor,
    absmax: torch.Tensor,
    blocksize: int,
) -> torch.Tensor:
    original_dtype = input.dtype
    if original_dtype != torch.bfloat16:
        input = input.to(torch.bfloat16)

    output = ops.gemm_int4_forward(input, weight, zeros, absmax, blocksize)
    if original_dtype != torch.bfloat16:
        output = output.to(original_dtype)

    return output