import torch


def gptq_marlin_gemm(
    a: torch.Tensor,
    b_q_weight: torch.Tensor,
    b_scales: torch.Tensor,
    g_idx: torch.Tensor,
    perm: torch.Tensor,
    workspace: torch.Tensor,
    num_bits: int,
    size_m: int,
    size_n: int,
    size_k: int,
    is_k_full: bool,
) -> torch.Tensor:
    """
    Matrix multiplication using Marlin kernels. This is an extension of
    `marlin_gemm` that supports GPTQ weights converted to the Marlin
    format (see `gptq_marlin_repack`).
    """
    ...
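

# Illustrative usage sketch (not part of the original API surface): how a
# caller might drive `gptq_marlin_gemm` after converting a GPTQ checkpoint
# with `gptq_marlin_repack`. The attribute names on `layer`, the shape
# assumption on `b_scales`, and the choice of `num_bits=4` are assumptions
# made for this example only.
def _example_gptq_marlin_forward(layer, a: torch.Tensor) -> torch.Tensor:
    size_m, size_k = a.shape          # activations: (M, K)
    size_n = layer.b_scales.shape[1]  # assume scales are (num_groups, N)
    return gptq_marlin_gemm(
        a,
        layer.b_q_weight,   # weights repacked via `gptq_marlin_repack`
        layer.b_scales,
        layer.g_idx,
        layer.perm,
        layer.workspace,    # zero-initialized scratch buffer for the kernel
        num_bits=4,         # assumed 4-bit GPTQ quantization
        size_m=size_m,
        size_n=size_n,
        size_k=size_k,
        is_k_full=True,
    )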


def gptq_marlin_24_gemm(
    a: torch.Tensor,
    b_q_weight: torch.Tensor,
    b_meta: torch.Tensor,
    b_scales: torch.Tensor,
    workspace: torch.Tensor,
    num_bits: int,
    size_m: int,
    size_n: int,
    size_k: int,
) -> torch.Tensor:
    """
    Matrix multiplication using Marlin kernels. This is an extension of
    `marlin_gemm` that supports 2:4 sparsity.
    """
    ...
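

# A note on 2:4 sparsity, illustrated on a plain dense tensor: in every group
# of four consecutive weights along the reduction dimension K, at most two are
# non-zero. The check below is only a sketch of that invariant (and assumes
# size_k is divisible by 4); the kernel itself consumes a packed
# representation (`b_q_weight` plus `b_meta`), not a dense tensor.
def _check_2_4_sparsity(dense_w: torch.Tensor) -> bool:
    # dense_w: (size_k, size_n); group rows along K in blocks of four.
    size_k, size_n = dense_w.shape
    groups = dense_w.reshape(size_k // 4, 4, size_n)
    nonzeros_per_group = (groups != 0).sum(dim=1)  # (size_k // 4, size_n)
    return bool((nonzeros_per_group <= 2).all())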


def gptq_marlin_repack(
    b_q_weight: torch.Tensor,
    perm: torch.Tensor,
    size_k: int,
    size_n: int,
    num_bits: int,
) -> torch.Tensor:
    """Repack GPTQ parameters for Marlin kernels."""
    ...
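

# Illustrative workflow sketch: repacking is a one-time, load-time step, after
# which the repacked weight is reused for every `gptq_marlin_gemm` call.
# `gptq_qweight` stands in for the packed weight from a GPTQ checkpoint; its
# exact layout, and the `num_bits=4` choice, are assumptions of this example.
def _example_convert_gptq_weight(gptq_qweight: torch.Tensor,
                                 perm: torch.Tensor,
                                 size_k: int,
                                 size_n: int) -> torch.Tensor:
    # Returns the weight repacked into the layout expected by the Marlin
    # kernels; store the result and drop the original GPTQ packing.
    return gptq_marlin_repack(gptq_qweight, perm, size_k, size_n, num_bits=4)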


def marlin_gemm(
    a: torch.Tensor,
    b_q_weight: torch.Tensor,
    b_scales: torch.Tensor,
    workspace: torch.Tensor,
    size_m: int,
    size_n: int,
    size_k: int,
) -> torch.Tensor:
    """
    Matrix multiplication using Marlin kernels.
    """
    ...
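

# Illustrative correctness sketch: `marlin_gemm` computes the same product as
# a dense matmul against the dequantized weight. `dequantize_b` is a
# hypothetical helper standing in for whatever reconstructs the dense
# (size_k, size_n) weight from `b_q_weight` and `b_scales`; the loose
# tolerances reflect that quantized kernels are not bit-exact.
def _example_marlin_reference_check(a, b_q_weight, b_scales, workspace,
                                    dequantize_b) -> bool:
    size_m, size_k = a.shape
    b_dense = dequantize_b(b_q_weight, b_scales)  # (size_k, size_n)
    size_n = b_dense.shape[1]
    out = marlin_gemm(a, b_q_weight, b_scales, workspace,
                      size_m, size_n, size_k)
    return torch.allclose(out, a @ b_dense, rtol=1e-2, atol=1e-2)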


# fp8 marlin
def fp8_marlin_gemm(
    a: torch.Tensor,
    b_q_weight: torch.Tensor,
    b_scales: torch.Tensor,
    workspace: torch.Tensor,
    num_bits: int,
    size_m: int,
    size_n: int,
    size_k: int,
) -> torch.Tensor:
    """Matrix multiplication using Marlin kernels with fp8 weights."""
    return torch.ops._C.fp8_marlin_gemm(
        a, b_q_weight, b_scales, workspace, num_bits, size_m, size_n, size_k
    )
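

# Illustrative sketch for the fp8 path: the call mirrors `marlin_gemm` with an
# explicit `num_bits`, which would be 8 for fp8 weights. The attribute names
# on `layer` and the scale-shape assumption are for the example only; the
# packed layout of the fp8 weights is defined by the kernel and its repacking
# step, not here.
def _example_fp8_marlin_forward(layer, a: torch.Tensor) -> torch.Tensor:
    size_m, size_k = a.shape
    size_n = layer.b_scales.shape[1]  # assume scales carry the N dimension
    return fp8_marlin_gemm(
        a,
        layer.b_q_weight,   # fp8 weights packed for the Marlin kernel
        layer.b_scales,
        layer.workspace,
        num_bits=8,         # fp8 weights
        size_m=size_m,
        size_n=size_n,
        size_k=size_k,
    )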