File size: 495 Bytes
7be134c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 |
import torch
from ._ops import ops
def gemm_int4_forward(
    input: torch.Tensor,
    weight: torch.Tensor,
    zeros: torch.Tensor,
    absmax: torch.Tensor,
    blocksize: int,
) -> torch.Tensor:
    """Run the compiled ``gemm_int4_forward`` op with a bfloat16 activation.

    The extension kernel is always handed a bfloat16 ``input``. If the
    caller's tensor has a different dtype, it is converted to bfloat16
    before the call. The kernel's result is then converted back to the
    caller's original dtype. When ``input`` is already bfloat16, no
    conversion happens in either direction.

    Args:
        input: Activation tensor. Its dtype decides whether round-trip
            casting is performed.
        weight: Passed through unchanged to the kernel. Presumably packed
            int4 weights; TODO confirm against the kernel.
        zeros: Passed through unchanged to the kernel. Presumably the
            per-block zero points; TODO confirm.
        absmax: Passed through unchanged to the kernel. Presumably the
            per-block scales; TODO confirm.
        blocksize: Passed through unchanged to the kernel. Presumably the
            quantization block size; TODO confirm.

    Returns:
        The kernel output. It is cast to ``input``'s original dtype when that
        dtype was not bfloat16, and returned as produced otherwise.
    """
    kernel_dtype = torch.bfloat16
    caller_dtype = input.dtype
    # Compute the cast decision once; it governs both the inbound and
    # outbound conversion.
    needs_cast = caller_dtype != kernel_dtype

    activations = input.to(kernel_dtype) if needs_cast else input
    result = ops.gemm_int4_forward(activations, weight, zeros, absmax, blocksize)
    return result.to(caller_dtype) if needs_cast else result
|