danieldk's picture
danieldk HF Staff
Build uploaded using `kernels`.
7be134c verified
raw
history blame
495 Bytes
import torch
from ._ops import ops
def gemm_int4_forward(
    input: torch.Tensor,
    weight: torch.Tensor,
    zeros: torch.Tensor,
    absmax: torch.Tensor,
    blocksize: int,
) -> torch.Tensor:
    """Call the compiled ``gemm_int4_forward`` op with bfloat16 activations.

    Whatever dtype ``input`` arrives in, the underlying op always receives
    ``input`` in ``torch.bfloat16``. When a cast was needed, the op's result
    is cast back to the caller's original dtype. When ``input`` was already
    bfloat16, the op's result is returned unchanged.

    Args:
        input: Activation tensor; cast to bfloat16 before the op if needed.
        weight: Passed through unchanged to the op. Presumably packed int4
            weights, judging by the op name; TODO confirm against the kernel.
        zeros: Passed through unchanged. Presumably per-block zero points
            for dequantization; TODO confirm.
        absmax: Passed through unchanged. Presumably per-block scaling
            factors; TODO confirm.
        blocksize: Passed through unchanged. Presumably the quantization
            block size that ``zeros``/``absmax`` are defined over; TODO confirm.

    Returns:
        The op's output. It is converted to ``input``'s original dtype if
        that dtype was not bfloat16.
    """
    caller_dtype = input.dtype
    # NOTE(review): the op is only ever invoked with bfloat16 activations,
    # presumably because the kernel supports only that dtype. Confirm.
    needs_cast = caller_dtype != torch.bfloat16

    activations = input.to(torch.bfloat16) if needs_cast else input
    result = ops.gemm_int4_forward(activations, weight, zeros, absmax, blocksize)

    # Restore the caller's dtype only when we changed it on the way in.
    return result.to(caller_dtype) if needs_cast else result