Add support for repacking AWQ weights for GPTQ-Marlin (#2278)

* Add support for repacking AWQ weights for GPTQ-Marlin

So far we couldn't support AWQ because virtually all AWQ models use
symmetric quantization, which GPTQ-Marlin did not suppors. GPTQ-Marlin
has recently added support AWQ repacking and AWQ asymmetric quantization
(zero_point=True).

This change updates all GPTQ-Marlin kernels from upstream and wires up
AWQ support. For now enabling AWQ using Marlin requires running TGI with
`--quantize gptq`.

* Enable Marlin for supported AWQ configurations by default

This makes the AWQ -> GPTQ repack test redundant, since we are now
testing this with the regular AWQ test.
This commit is contained in:
Daniël de Kok 2024-07-23 13:08:20 +02:00 committed by GitHub
parent 5fca30ee15
commit 9935720c87
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
13 changed files with 1016 additions and 203 deletions

View File

@ -0,0 +1,269 @@
#include "marlin.cuh"
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
namespace marlin {
template <int const num_threads, int const num_bits, bool const has_perm>
__global__ void awq_marlin_repack_kernel(
uint32_t const* __restrict__ b_q_weight_ptr, uint32_t* __restrict__ out_ptr,
int size_k, int size_n) {}
} // namespace marlin
torch::Tensor awq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm,
int64_t size_k, int64_t size_n,
int64_t num_bits) {
TORCH_CHECK_NOT_IMPLEMENTED(
false, "marlin_repack_from_gptq(..) requires CUDA_ARCH >= 8.0");
return torch::empty({1, 1});
}
#else
namespace marlin {
template <int const num_threads, int const num_bits>
__global__ void awq_marlin_repack_kernel(
uint32_t const* __restrict__ b_q_weight_ptr, uint32_t* __restrict__ out_ptr,
int size_k, int size_n) {
constexpr int pack_factor = 32 / num_bits;
int k_tiles = size_k / tile_k_size;
int n_tiles = size_n / tile_n_size;
int block_k_tiles = div_ceil(k_tiles, gridDim.x);
int start_k_tile = blockIdx.x * block_k_tiles;
if (start_k_tile >= k_tiles) {
return;
}
int finish_k_tile = min(start_k_tile + block_k_tiles, k_tiles);
// Wait until the next thread tile has been loaded to shared memory.
auto wait_for_stage = [&]() {
// We only have `stages - 2` active fetches since we are double buffering
// and can only issue the next fetch when it is guaranteed that the previous
// shared memory load is fully complete (as it may otherwise be
// overwritten).
cp_async_wait<repack_stages - 2>();
__syncthreads();
};
extern __shared__ int4 sh[];
constexpr int tile_n_ints = tile_n_size / pack_factor;
constexpr int stage_n_threads = tile_n_ints / 4;
constexpr int stage_k_threads = tile_k_size;
constexpr int stage_size = stage_k_threads * stage_n_threads;
auto fetch_to_shared = [&](int pipe, int k_tile_id, int n_tile_id) {
if (n_tile_id >= n_tiles) {
cp_async_fence();
return;
}
int first_n = n_tile_id * tile_n_size;
int first_n_packed = first_n / pack_factor;
int4* sh_ptr = sh + stage_size * pipe;
if (threadIdx.x < stage_size) {
int k_id = threadIdx.x / stage_n_threads;
int n_id = threadIdx.x % stage_n_threads;
int first_k = k_tile_id * tile_k_size;
cp_async4(&sh_ptr[k_id * stage_n_threads + n_id],
reinterpret_cast<int4 const*>(
&(b_q_weight_ptr[(first_k + k_id) * (size_n / pack_factor) +
first_n_packed + (n_id * 4)])));
}
cp_async_fence();
};
auto repack_tile = [&](int pipe, int k_tile_id, int n_tile_id) {
if (n_tile_id >= n_tiles) {
return;
}
int warp_id = threadIdx.x / 32;
int th_id = threadIdx.x % 32;
if (warp_id >= 4) {
return;
}
int tc_col = th_id / 4;
int tc_row = (th_id % 4) * 2;
constexpr int tc_offsets[4] = {0, 1, 8, 9};
int cur_n = warp_id * 16 + tc_col;
int cur_n_packed = cur_n / pack_factor;
int cur_n_pos = cur_n % pack_factor;
constexpr int sh_stride = tile_n_ints;
constexpr uint32_t mask = (1 << num_bits) - 1;
int4* sh_stage_ptr = sh + stage_size * pipe;
uint32_t* sh_stage_int_ptr = reinterpret_cast<uint32_t*>(sh_stage_ptr);
// Undo interleaving
int cur_n_pos_unpacked;
if constexpr (num_bits == 4) {
constexpr int undo_pack[8] = {0, 4, 1, 5, 2, 6, 3, 7};
cur_n_pos_unpacked = undo_pack[cur_n_pos];
} else {
constexpr int undo_pack[4] = {0, 2, 1, 3};
cur_n_pos_unpacked = undo_pack[cur_n_pos];
}
uint32_t vals[8];
#pragma unroll
for (int i = 0; i < 4; i++) {
int cur_elem = tc_row + tc_offsets[i];
int packed_src_0 = sh_stage_int_ptr[cur_n_packed + sh_stride * cur_elem];
int packed_src_1 = sh_stage_int_ptr[cur_n_packed + (8 / pack_factor) +
sh_stride * cur_elem];
vals[i] = (packed_src_0 >> (cur_n_pos_unpacked * num_bits)) & mask;
vals[4 + i] = (packed_src_1 >> (cur_n_pos_unpacked * num_bits)) & mask;
}
constexpr int tile_size = tile_k_size * tile_n_size / pack_factor;
int out_offset = (k_tile_id * n_tiles + n_tile_id) * tile_size;
// Result of:
// https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h
if constexpr (num_bits == 4) {
constexpr int pack_idx[8] = {0, 2, 4, 6, 1, 3, 5, 7};
uint32_t res = 0;
#pragma unroll
for (int i = 0; i < 8; i++) {
res |= vals[pack_idx[i]] << (i * 4);
}
out_ptr[out_offset + th_id * 4 + warp_id] = res;
} else {
constexpr int pack_idx[4] = {0, 2, 1, 3};
uint32_t res1 = 0;
uint32_t res2 = 0;
#pragma unroll
for (int i = 0; i < 4; i++) {
res1 |= vals[pack_idx[i]] << (i * 8);
res2 |= vals[4 + pack_idx[i]] << (i * 8);
}
out_ptr[out_offset + th_id * 8 + (warp_id * 2) + 0] = res1;
out_ptr[out_offset + th_id * 8 + (warp_id * 2) + 1] = res2;
}
};
auto start_pipes = [&](int k_tile_id, int n_tile_id) {
#pragma unroll
for (int pipe = 0; pipe < repack_stages - 1; pipe++) {
fetch_to_shared(pipe, k_tile_id, n_tile_id + pipe);
}
wait_for_stage();
};
#pragma unroll
for (int k_tile_id = start_k_tile; k_tile_id < finish_k_tile; k_tile_id++) {
int n_tile_id = 0;
start_pipes(k_tile_id, n_tile_id);
while (n_tile_id < n_tiles) {
#pragma unroll
for (int pipe = 0; pipe < repack_stages; pipe++) {
fetch_to_shared((pipe + repack_stages - 1) % repack_stages, k_tile_id,
n_tile_id + pipe + repack_stages - 1);
repack_tile(pipe, k_tile_id, n_tile_id + pipe);
wait_for_stage();
}
n_tile_id += repack_stages;
}
}
}
} // namespace marlin
#define CALL_IF(NUM_BITS) \
else if (num_bits == NUM_BITS) { \
cudaFuncSetAttribute( \
marlin::awq_marlin_repack_kernel<marlin::repack_threads, NUM_BITS>, \
cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \
marlin::awq_marlin_repack_kernel<marlin::repack_threads, NUM_BITS> \
<<<blocks, marlin::repack_threads, max_shared_mem, stream>>>( \
b_q_weight_ptr, out_ptr, size_k, size_n); \
}
torch::Tensor awq_marlin_repack(torch::Tensor& b_q_weight, int64_t size_k,
int64_t size_n, int64_t num_bits) {
// Verify compatibility with marlin tile of 16x64
TORCH_CHECK(size_k % marlin::tile_k_size == 0, "size_k = ", size_k,
" is not divisible by tile_k_size = ", marlin::tile_k_size);
TORCH_CHECK(size_n % marlin::tile_n_size == 0, "size_n = ", size_n,
" is not divisible by tile_n_size = ", marlin::tile_n_size);
TORCH_CHECK(num_bits == 4 || num_bits == 8,
"num_bits must be 4 or 8. Got = ", num_bits);
int const pack_factor = 32 / num_bits;
// Verify B
TORCH_CHECK(b_q_weight.size(0) == size_k,
"b_q_weight.size(0) = ", b_q_weight.size(0),
" is not size_k = ", size_k);
TORCH_CHECK((size_n / pack_factor) == b_q_weight.size(1),
"Shape mismatch: b_q_weight.size(1) = ", b_q_weight.size(1),
", size_n = ", size_n, ", pack_factor = ", pack_factor);
// Verify device and strides
TORCH_CHECK(b_q_weight.device().is_cuda(), "b_q_weight is not on GPU");
TORCH_CHECK(b_q_weight.is_contiguous(), "b_q_weight is not contiguous");
TORCH_CHECK(b_q_weight.dtype() == at::kInt, "b_q_weight type is not kInt");
// Alloc buffers
const at::cuda::OptionalCUDAGuard device_guard(device_of(b_q_weight));
auto options = torch::TensorOptions()
.dtype(b_q_weight.dtype())
.device(b_q_weight.device());
torch::Tensor out = torch::empty(
{size_k / marlin::tile_size, size_n * marlin::tile_size / pack_factor},
options);
// Get ptrs
uint32_t const* b_q_weight_ptr =
reinterpret_cast<uint32_t const*>(b_q_weight.data_ptr());
uint32_t* out_ptr = reinterpret_cast<uint32_t*>(out.data_ptr());
// Get dev info
int dev = b_q_weight.get_device();
cudaStream_t stream = at::cuda::getCurrentCUDAStream(dev);
int blocks;
cudaDeviceGetAttribute(&blocks, cudaDevAttrMultiProcessorCount, dev);
int max_shared_mem = 0;
cudaDeviceGetAttribute(&max_shared_mem,
cudaDevAttrMaxSharedMemoryPerBlockOptin, dev);
TORCH_CHECK(max_shared_mem > 0);
if (false) {
}
CALL_IF(4)
CALL_IF(8)
else {
TORCH_CHECK(false, "Unsupported repack config: num_bits = ", num_bits);
}
return out;
}
#endif

View File

@ -3,6 +3,8 @@
#include "ext.hh"
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("awq_marlin_repack", &awq_marlin_repack,
"Repack AWQ parameters for Marlin");
m.def("gptq_marlin_gemm", &gptq_marlin_gemm,
"Marlin gemm with GPTQ compatibility");
m.def("gptq_marlin_24_gemm", &gptq_marlin_24_gemm, "Marlin sparse 2:4 gemm");

View File

@ -6,11 +6,15 @@
// No support for async
#else
torch::Tensor awq_marlin_repack(torch::Tensor &b_q_weight, int64_t size_k,
int64_t size_n, int64_t num_bits);
torch::Tensor gptq_marlin_gemm(torch::Tensor &a, torch::Tensor &b_q_weight,
torch::Tensor &b_scales, torch::Tensor &g_idx,
torch::Tensor &perm, torch::Tensor &workspace,
int64_t num_bits, int64_t size_m, int64_t size_n,
int64_t size_k, bool is_k_full);
torch::Tensor &b_scales, torch::Tensor &b_zeros,
torch::Tensor &g_idx, torch::Tensor &perm,
torch::Tensor &workspace, int64_t num_bits,
int64_t size_m, int64_t size_n, int64_t size_k,
bool is_k_full, bool has_zp);
torch::Tensor gptq_marlin_24_gemm(torch::Tensor &a, torch::Tensor &b_q_weight,
torch::Tensor &b_meta,
@ -27,8 +31,8 @@ torch::Tensor marlin_gemm(torch::Tensor &a, torch::Tensor &b_q_weight,
torch::Tensor &b_scales, torch::Tensor &workspace,
int64_t size_m, int64_t size_n, int64_t size_k);
torch::Tensor fp8_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
torch::Tensor& b_scales, torch::Tensor& workspace,
torch::Tensor fp8_marlin_gemm(torch::Tensor &a, torch::Tensor &b_q_weight,
torch::Tensor &b_scales, torch::Tensor &workspace,
int64_t num_bits, int64_t size_m, int64_t size_n,
int64_t size_k);

View File

@ -19,10 +19,10 @@
* Adapted from https://github.com/IST-DASLab/marlin
*/
#include "./gptq_marlin.cuh"
#include "./gptq_marlin_dtypes.cuh"
#include "marlin.cuh"
#include "marlin_dtypes.cuh"
using namespace gptq_marlin;
using namespace marlin;
#define STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t) \
static_assert(std::is_same<scalar_t, half>::value || \
@ -1224,16 +1224,15 @@ torch::Tensor fp8_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
", size_k = ", size_k);
// Verify B
TORCH_CHECK(size_k % gptq_marlin::tile_size == 0, "size_k = ", size_k,
" is not divisible by tile_size = ", gptq_marlin::tile_size);
TORCH_CHECK((size_k / gptq_marlin::tile_size) == b_q_weight.size(0),
TORCH_CHECK(size_k % marlin::tile_size == 0, "size_k = ", size_k,
" is not divisible by tile_size = ", marlin::tile_size);
TORCH_CHECK((size_k / marlin::tile_size) == b_q_weight.size(0),
"Shape mismatch: b_q_weight.size(0) = ", b_q_weight.size(0),
", size_k = ", size_k, ", tile_size = ", gptq_marlin::tile_size);
TORCH_CHECK(b_q_weight.size(1) % gptq_marlin::tile_size == 0,
", size_k = ", size_k, ", tile_size = ", marlin::tile_size);
TORCH_CHECK(b_q_weight.size(1) % marlin::tile_size == 0,
"b_q_weight.size(1) = ", b_q_weight.size(1),
" is not divisible by tile_size = ", gptq_marlin::tile_size);
int actual_size_n =
(b_q_weight.size(1) / gptq_marlin::tile_size) * pack_factor;
" is not divisible by tile_size = ", marlin::tile_size);
int actual_size_n = (b_q_weight.size(1) / marlin::tile_size) * pack_factor;
TORCH_CHECK(size_n == actual_size_n, "size_n = ", size_n,
", actual_size_n = ", actual_size_n);
@ -1274,11 +1273,9 @@ torch::Tensor fp8_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
num_groups = b_scales.size(0);
// Verify workspace size
TORCH_CHECK(
size_n % gptq_marlin::min_thread_n == 0, "size_n = ", size_n,
", is not divisible by min_thread_n = ", gptq_marlin::min_thread_n);
int min_workspace_size =
(size_n / gptq_marlin::min_thread_n) * gptq_marlin::max_par;
TORCH_CHECK(size_n % marlin::min_thread_n == 0, "size_n = ", size_n,
", is not divisible by min_thread_n = ", marlin::min_thread_n);
int min_workspace_size = (size_n / marlin::min_thread_n) * marlin::max_par;
TORCH_CHECK(workspace.numel() >= min_workspace_size,
"workspace.numel = ", workspace.numel(),
" is below min_workspace_size = ", min_workspace_size);
@ -1290,14 +1287,14 @@ torch::Tensor fp8_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
b_scales.data_ptr<at::Half>(), size_m, size_n, size_k,
workspace.data_ptr(), num_bits, num_groups, group_size, dev,
at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, sms,
gptq_marlin::max_par);
marlin::max_par);
} else if (a.scalar_type() == at::ScalarType::BFloat16) {
fp8_marlin::marlin_mm_f16i4<nv_bfloat16>(
a.data_ptr<at::BFloat16>(), b_q_weight.data_ptr(),
c.data_ptr<at::BFloat16>(), b_scales.data_ptr<at::BFloat16>(), size_m,
size_n, size_k, workspace.data_ptr(), num_bits, num_groups, group_size,
dev, at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, sms,
gptq_marlin::max_par);
marlin::max_par);
} else {
TORCH_CHECK(false, "fp8_marlin_gemm only supports bfloat16 and float16");
}

View File

@ -19,8 +19,8 @@
* Adapted from https://github.com/IST-DASLab/marlin
*/
#include "gptq_marlin.cuh"
#include "gptq_marlin_dtypes.cuh"
#include "marlin.cuh"
#include "marlin_dtypes.cuh"
#define STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t) \
static_assert(std::is_same<scalar_t, half>::value || \
@ -32,7 +32,7 @@ inline std::string str(T x) {
return std::to_string(x);
}
namespace gptq_marlin {
namespace marlin {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
@ -72,10 +72,11 @@ __global__ void Marlin(
} // namespace gptq_marlin
torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
torch::Tensor& b_scales, torch::Tensor& g_idx,
torch::Tensor& perm, torch::Tensor& workspace,
int64_t num_bits, int64_t size_m, int64_t size_n,
int64_t size_k, bool is_k_full) {
torch::Tensor& b_scales, torch::Tensor& b_zeros,
torch::Tensor& g_idx, torch::Tensor& perm,
torch::Tensor& workspace, int64_t num_bits,
int64_t size_m, int64_t size_n, int64_t size_k,
bool is_k_full) {
TORCH_CHECK_NOT_IMPLEMENTED(false,
"marlin_gemm(..) requires CUDA_ARCH >= 8.0");
return torch::empty({1, 1});
@ -264,6 +265,114 @@ dequant_8bit<nv_bfloat16>(int q) {
return frag_b;
}
// Zero-point dequantizers
template <typename scalar_t>
__device__ inline typename ScalarType<scalar_t>::FragB dequant_4bit_zp(int q) {
STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t);
}
template <>
__device__ inline typename ScalarType<half>::FragB dequant_4bit_zp<half>(
int q) {
const int LO = 0x000f000f;
const int HI = 0x00f000f0;
const int EX = 0x64006400;
// Guarantee that the `(a & b) | c` operations are LOP3s.
int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX);
int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX);
const int SUB = 0x64006400;
const int MUL = 0x2c002c00;
const int ADD = 0xd400d400;
typename ScalarType<half>::FragB frag_b;
frag_b[0] = __hsub2(*reinterpret_cast<half2*>(&lo),
*reinterpret_cast<const half2*>(&SUB));
frag_b[1] = __hfma2(*reinterpret_cast<half2*>(&hi),
*reinterpret_cast<const half2*>(&MUL),
*reinterpret_cast<const half2*>(&ADD));
return frag_b;
}
template <>
__device__ inline typename ScalarType<nv_bfloat16>::FragB
dequant_4bit_zp<nv_bfloat16>(int q) {
static constexpr uint32_t MASK = 0x000f000f;
static constexpr uint32_t EX = 0x43004300;
// Guarantee that the `(a & b) | c` operations are LOP3s.
int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX);
q >>= 4;
int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX);
typename ScalarType<nv_bfloat16>::FragB frag_b;
static constexpr uint32_t MUL = 0x3F803F80;
static constexpr uint32_t ADD = 0xC300C300;
frag_b[0] = __hfma2(*reinterpret_cast<nv_bfloat162*>(&lo),
*reinterpret_cast<const nv_bfloat162*>(&MUL),
*reinterpret_cast<const nv_bfloat162*>(&ADD));
frag_b[1] = __hfma2(*reinterpret_cast<nv_bfloat162*>(&hi),
*reinterpret_cast<const nv_bfloat162*>(&MUL),
*reinterpret_cast<const nv_bfloat162*>(&ADD));
return frag_b;
}
template <typename scalar_t>
__device__ inline typename ScalarType<scalar_t>::FragB dequant_8bit_zp(int q) {
STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t);
}
template <>
__device__ inline typename ScalarType<half>::FragB dequant_8bit_zp<half>(
int q) {
static constexpr uint32_t mask_for_elt_01 = 0x5250;
static constexpr uint32_t mask_for_elt_23 = 0x5351;
static constexpr uint32_t start_byte_for_fp16 = 0x64646464;
uint32_t lo = prmt<start_byte_for_fp16, mask_for_elt_01>(q);
uint32_t hi = prmt<start_byte_for_fp16, mask_for_elt_23>(q);
static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64006400;
typename ScalarType<half>::FragB frag_b;
frag_b[0] = __hsub2(*reinterpret_cast<half2*>(&lo),
*reinterpret_cast<const half2*>(&I8s_TO_F16s_MAGIC_NUM));
frag_b[1] = __hsub2(*reinterpret_cast<half2*>(&hi),
*reinterpret_cast<const half2*>(&I8s_TO_F16s_MAGIC_NUM));
return frag_b;
}
template <>
__device__ inline typename ScalarType<nv_bfloat16>::FragB
dequant_8bit_zp<nv_bfloat16>(int q) {
typename ScalarType<nv_bfloat16>::FragB frag_b;
float fp32_intermediates[4];
uint32_t* fp32_intermediates_casted =
reinterpret_cast<uint32_t*>(fp32_intermediates);
static constexpr uint32_t fp32_base = 0x4B000000;
fp32_intermediates_casted[0] = __byte_perm(q, fp32_base, 0x7650);
fp32_intermediates_casted[1] = __byte_perm(q, fp32_base, 0x7652);
fp32_intermediates_casted[2] = __byte_perm(q, fp32_base, 0x7651);
fp32_intermediates_casted[3] = __byte_perm(q, fp32_base, 0x7653);
fp32_intermediates[0] -= 8388608.f;
fp32_intermediates[1] -= 8388608.f;
fp32_intermediates[2] -= 8388608.f;
fp32_intermediates[3] -= 8388608.f;
uint32_t* bf16_result_ptr = reinterpret_cast<uint32_t*>(&frag_b);
bf16_result_ptr[0] = __byte_perm(fp32_intermediates_casted[0],
fp32_intermediates_casted[1], 0x7632);
bf16_result_ptr[1] = __byte_perm(fp32_intermediates_casted[2],
fp32_intermediates_casted[3], 0x7632);
return frag_b;
}
// Multiply dequantized values by the corresponding quantization scale; used
// only for grouped quantization.
template <typename scalar_t>
@ -277,6 +386,17 @@ __device__ inline void scale(typename ScalarType<scalar_t>::FragB& frag_b,
frag_b[1] = __hmul2(frag_b[1], s);
}
template <typename scalar_t>
__device__ inline void sub_zp(typename ScalarType<scalar_t>::FragB& frag_b,
typename ScalarType<scalar_t>::scalar_t2& frag_zp,
int i) {
using scalar_t2 = typename ScalarType<scalar_t>::scalar_t2;
scalar_t2 zp =
ScalarType<scalar_t>::num2num2(reinterpret_cast<scalar_t*>(&frag_zp)[i]);
frag_b[0] = __hsub2(frag_b[0], zp);
frag_b[1] = __hsub2(frag_b[1], zp);
}
// Same as above, but for act_order (each K is multiplied individually)
template <typename scalar_t>
__device__ inline void scale4(typename ScalarType<scalar_t>::FragB& frag_b,
@ -404,6 +524,7 @@ template <typename scalar_t, // compute dtype, half or nv_float16
const int stages, // number of stages for the async global->shared
// fetch pipeline
const bool has_act_order, // whether act_order is enabled
const bool has_zp, // whether zero-points are enabled
const int group_blocks = -1 // number of consecutive 16x16 blocks
// with a separate quantization scale
>
@ -413,6 +534,8 @@ __global__ void Marlin(
int4* __restrict__ C, // fp16 output buffer of shape mxn
const int4* __restrict__ scales_ptr, // fp16 quantization scales of shape
// (k/groupsize)xn
const int4* __restrict__ zp_ptr, // 4bit packed zero-points of shape
// (k/groupsize)x(n/pack_factor)
const int* __restrict__ g_idx, // int32 group indices of shape k
int num_groups, // number of scale groups per output channel
int prob_m, // batch dimension m
@ -437,6 +560,7 @@ __global__ void Marlin(
using FragB = typename ScalarType<scalar_t>::FragB;
using FragC = typename ScalarType<scalar_t>::FragC;
using FragS = typename ScalarType<scalar_t>::FragS;
using FragZP = typename ScalarType<scalar_t>::FragZP;
constexpr int pack_factor = 32 / num_bits;
@ -566,6 +690,13 @@ __global__ void Marlin(
int tb_n_warps = thread_n_blocks / 4;
int act_s_col_tb_stride = act_s_col_warp_stride * tb_n_warps;
// Zero-points sizes/strides
int zp_gl_stride = (prob_n / pack_factor) / 4;
constexpr int zp_sh_stride = ((16 * thread_n_blocks) / pack_factor) / 4;
constexpr int zp_tb_groups = s_tb_groups;
constexpr int zp_sh_stage = has_zp ? zp_tb_groups * zp_sh_stride : 0;
int zp_gl_rd_delta = zp_gl_stride;
// Global A read index of current thread.
int a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) +
(threadIdx.x % a_gl_rd_delta_o);
@ -605,6 +736,19 @@ __global__ void Marlin(
int s_sh_wr = threadIdx.x;
bool s_sh_wr_pred = threadIdx.x < s_sh_stride;
// Zero-points
int zp_gl_rd;
if constexpr (has_zp) {
if constexpr (group_blocks == -1) {
zp_gl_rd = zp_sh_stride * slice_col + threadIdx.x;
} else {
zp_gl_rd = zp_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) +
zp_sh_stride * slice_col + threadIdx.x;
}
}
int zp_sh_wr = threadIdx.x;
bool zp_sh_wr_pred = threadIdx.x < zp_sh_stride;
// We use a different scale layout for grouped and column-wise quantization as
// we scale a `half2` tile in column-major layout in the former and in
// row-major in the latter case.
@ -616,6 +760,18 @@ __global__ void Marlin(
s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) +
(threadIdx.x % 32) % 4;
// Zero-points have the same read layout as the scales
// (without column-wise case)
constexpr int num_col_threads = 8;
constexpr int num_row_threads = 4;
constexpr int num_ints_per_thread = 8 / pack_factor;
int zp_sh_rd;
if constexpr (has_zp) {
zp_sh_rd = num_ints_per_thread * num_col_threads *
((threadIdx.x / 32) % (thread_n_blocks / 4)) +
num_ints_per_thread * ((threadIdx.x % 32) / num_row_threads);
}
// Precompute which thread should not read memory in which iterations; this is
// needed if there are more threads than required for a certain tilesize or
// when the batchsize is not a multiple of 16.
@ -664,14 +820,17 @@ __global__ void Marlin(
int4* sh_a = sh;
int4* sh_b = sh_a + (stages * a_sh_stage);
int4* sh_g_idx = sh_b + (stages * b_sh_stage);
int4* sh_s = sh_g_idx + (stages * g_idx_stage);
int4* sh_zp = sh_g_idx + (stages * g_idx_stage);
int4* sh_s = sh_zp + (stages * zp_sh_stage);
// Register storage for double buffer of shared memory reads.
FragA frag_a[2][thread_m_blocks];
I4 frag_b_quant[2][b_thread_vecs];
FragC frag_c[thread_m_blocks][4][2];
FragS frag_s[2][4]; // No act-order
FragS act_frag_s[2][4][4]; // For act-order
FragS frag_s[2][4]; // No act-order
FragS act_frag_s[2][4][4]; // For act-order
int frag_qzp[2][num_ints_per_thread]; // Zero-points
FragZP frag_zp; // Zero-points in fp16
// Zero accumulators.
auto zero_accums = [&]() {
@ -777,6 +936,28 @@ __global__ void Marlin(
}
}
}
if constexpr (has_zp && group_blocks != -1) {
int4* sh_zp_stage = sh_zp + zp_sh_stage * pipe;
if constexpr (group_blocks >= thread_k_blocks) {
// Only fetch zero-points if this tile starts a new group
if (pipe % (group_blocks / thread_k_blocks) == 0) {
if (zp_sh_wr_pred) {
cp_async4(&sh_zp_stage[zp_sh_wr], &zp_ptr[zp_gl_rd]);
}
zp_gl_rd += zp_gl_rd_delta;
}
} else {
for (int i = 0; i < zp_tb_groups; i++) {
if (zp_sh_wr_pred) {
cp_async4(&sh_zp_stage[i * zp_sh_stride + zp_sh_wr],
&zp_ptr[zp_gl_rd]);
}
zp_gl_rd += zp_gl_rd_delta;
}
}
}
}
}
// Insert a fence even when we are winding down the pipeline to ensure that
@ -784,6 +965,12 @@ __global__ void Marlin(
cp_async_fence();
};
auto fetch_zp_to_shared = [&]() {
if (zp_sh_wr_pred) {
cp_async4(&sh_zp[zp_sh_wr], &zp_ptr[zp_gl_rd]);
}
};
// Wait until the next thread tile has been loaded to shared memory.
auto wait_for_stage = [&]() {
// We only have `stages - 2` active fetches since we are double buffering
@ -932,8 +1119,73 @@ __global__ void Marlin(
}
};
auto fetch_zp_to_registers = [&](int k, int full_pipe) {
if constexpr (!has_zp) {
return;
}
int pipe = full_pipe % stages;
if constexpr (group_blocks == -1) {
for (int i = 0; i < num_ints_per_thread; i++) {
frag_qzp[k % 2][i] = (reinterpret_cast<int*>(sh_zp))[zp_sh_rd + i];
}
} else if constexpr (group_blocks >= thread_k_blocks) {
int4* sh_zp_stage =
sh_zp + zp_sh_stage * ((group_blocks / thread_k_blocks) *
(pipe / (group_blocks / thread_k_blocks)));
for (int i = 0; i < num_ints_per_thread; i++) {
frag_qzp[k % 2][i] =
(reinterpret_cast<int*>(sh_zp_stage))[zp_sh_rd + i];
}
} else {
int warp_id = threadIdx.x / 32;
int n_warps = thread_n_blocks / 4;
int warp_row = warp_id / n_warps;
int cur_k = warp_row * 16;
cur_k += k_iter_size * (k % b_sh_wr_iters);
int k_blocks = cur_k / 16;
int cur_group_id = k_blocks / group_blocks;
int4* sh_zp_stage = sh_zp + zp_sh_stage * pipe;
sh_zp_stage += cur_group_id * zp_sh_stride;
for (int i = 0; i < num_ints_per_thread; i++) {
frag_qzp[k % 2][i] =
(reinterpret_cast<int*>(sh_zp_stage))[zp_sh_rd + i];
}
}
};
// Execute the actual tensor core matmul of a sub-tile.
auto matmul = [&](int k) {
if constexpr (has_zp) {
FragB frag_zp_0;
FragB frag_zp_1;
if constexpr (num_bits == 4) {
int zp_quant = frag_qzp[k % 2][0];
int zp_quant_shift = zp_quant >> 8;
frag_zp_0 = dequant_4bit_zp<scalar_t>(zp_quant);
frag_zp_1 = dequant_4bit_zp<scalar_t>(zp_quant_shift);
} else {
int zp_quant_0 = frag_qzp[k % 2][0];
int zp_quant_1 = frag_qzp[k % 2][1];
frag_zp_0 = dequant_8bit_zp<scalar_t>(zp_quant_0);
frag_zp_1 = dequant_8bit_zp<scalar_t>(zp_quant_1);
}
frag_zp[0] = frag_zp_0[0];
frag_zp[1] = frag_zp_0[1];
frag_zp[2] = frag_zp_1[0];
frag_zp[3] = frag_zp_1[1];
}
// We have the m dimension as the inner loop in order to encourage overlapping
// dequantization and matmul operations.
#pragma unroll
@ -944,16 +1196,32 @@ __global__ void Marlin(
int b_quant = frag_b_quant[k % 2][0][j];
int b_quant_shift = b_quant >> 8;
frag_b0 = dequant_4bit<scalar_t>(b_quant);
frag_b1 = dequant_4bit<scalar_t>(b_quant_shift);
if constexpr (has_zp) {
frag_b0 = dequant_4bit_zp<scalar_t>(b_quant);
frag_b1 = dequant_4bit_zp<scalar_t>(b_quant_shift);
} else {
frag_b0 = dequant_4bit<scalar_t>(b_quant);
frag_b1 = dequant_4bit<scalar_t>(b_quant_shift);
}
} else {
int* frag_b_quant_ptr = reinterpret_cast<int*>(frag_b_quant[k % 2]);
int b_quant_0 = frag_b_quant_ptr[j * 2 + 0];
int b_quant_1 = frag_b_quant_ptr[j * 2 + 1];
frag_b0 = dequant_8bit<scalar_t>(b_quant_0);
frag_b1 = dequant_8bit<scalar_t>(b_quant_1);
if constexpr (has_zp) {
frag_b0 = dequant_8bit_zp<scalar_t>(b_quant_0);
frag_b1 = dequant_8bit_zp<scalar_t>(b_quant_1);
} else {
frag_b0 = dequant_8bit<scalar_t>(b_quant_0);
frag_b1 = dequant_8bit<scalar_t>(b_quant_1);
}
}
// Apply zero-point to frag_b0
if constexpr (has_zp) {
sub_zp<scalar_t>(frag_b0, frag_zp[j], 0);
}
// Apply scale to frag_b0
@ -967,6 +1235,11 @@ __global__ void Marlin(
}
}
// Apply zero-point to frag_b1
if constexpr (has_zp) {
sub_zp<scalar_t>(frag_b1, frag_zp[j], 1);
}
// Apply scale to frag_b1
if constexpr (has_act_order) {
scale4<scalar_t>(frag_b1, act_frag_s[k % 2][0][j],
@ -1189,6 +1462,12 @@ __global__ void Marlin(
}
fetch_scales_to_shared(true, g_idx[slice_k_start], g_idx[last_g_idx]);
}
if constexpr (has_zp && group_blocks == -1) {
if (i == 0) {
fetch_zp_to_shared();
}
}
fetch_to_shared(i, i, i < slice_iters);
}
@ -1197,6 +1476,7 @@ __global__ void Marlin(
init_same_group(0);
fetch_to_registers(0, 0);
fetch_scales_to_registers(0, 0);
fetch_zp_to_registers(0, 0);
a_gl_rd += a_gl_rd_delta_o * (stages - 1);
slice_k_start_shared_fetch += tb_k * (stages - 1);
};
@ -1217,6 +1497,7 @@ __global__ void Marlin(
for (int k = 0; k < b_sh_wr_iters; k++) {
fetch_to_registers(k + 1, pipe % stages);
fetch_scales_to_registers(k + 1, pipe);
fetch_zp_to_registers(k + 1, pipe);
if (k == b_sh_wr_iters - 2) {
fetch_to_shared((pipe + stages - 1) % stages, pipe,
slice_iters >= stages);
@ -1354,6 +1635,7 @@ __global__ void Marlin(
} else {
s_gl_rd = s_sh_stride * slice_col + threadIdx.x;
zp_gl_rd = zp_sh_stride * slice_col + threadIdx.x;
}
start_pipes();
@ -1363,22 +1645,24 @@ __global__ void Marlin(
}
#define __CALL_IF(NUM_BITS, THREAD_M_BLOCKS, THREAD_N_BLOCKS, \
THREAD_K_BLOCKS, HAS_ACT_ORDER, GROUP_BLOCKS, NUM_THREADS) \
THREAD_K_BLOCKS, HAS_ACT_ORDER, HAS_ZP, GROUP_BLOCKS, \
NUM_THREADS) \
else if (num_bits == NUM_BITS && thread_m_blocks == THREAD_M_BLOCKS && \
thread_n_blocks == THREAD_N_BLOCKS && \
thread_k_blocks == THREAD_K_BLOCKS && \
has_act_order == HAS_ACT_ORDER && group_blocks == GROUP_BLOCKS && \
num_threads == NUM_THREADS) { \
has_act_order == HAS_ACT_ORDER && has_zp == HAS_ZP && \
group_blocks == GROUP_BLOCKS && num_threads == NUM_THREADS) { \
cudaFuncSetAttribute( \
Marlin<scalar_t, NUM_BITS, NUM_THREADS, THREAD_M_BLOCKS, \
THREAD_N_BLOCKS, THREAD_K_BLOCKS, pipe_stages, HAS_ACT_ORDER, \
GROUP_BLOCKS>, \
HAS_ZP, GROUP_BLOCKS>, \
cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \
Marlin<scalar_t, NUM_BITS, NUM_THREADS, THREAD_M_BLOCKS, \
THREAD_N_BLOCKS, THREAD_K_BLOCKS, pipe_stages, HAS_ACT_ORDER, \
GROUP_BLOCKS><<<blocks, NUM_THREADS, max_shared_mem, stream>>>( \
A_ptr, B_ptr, C_ptr, s_ptr, g_idx_ptr, num_groups, prob_m, prob_n, \
prob_k, locks); \
HAS_ZP, GROUP_BLOCKS> \
<<<blocks, NUM_THREADS, max_shared_mem, stream>>>( \
A_ptr, B_ptr, C_ptr, s_ptr, zp_ptr, g_idx_ptr, num_groups, \
prob_m, prob_n, prob_k, locks); \
}
typedef struct {
@ -1548,39 +1832,61 @@ exec_config_t determine_thread_config(int prob_m, int prob_n, int prob_k,
return exec_config_t{0, {-1, -1, -1}};
}
#define CALL_IF(NUM_BITS, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
__CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \
__CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \
__CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \
__CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \
\
__CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \
__CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \
__CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \
__CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS) \
\
__CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \
__CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \
__CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \
__CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS) \
\
__CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \
__CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \
__CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \
__CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS) \
\
__CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \
__CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \
__CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \
__CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS)
#define GPTQ_CALL_IF(NUM_BITS, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
__CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, true, false, 0, NUM_THREADS) \
__CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, true, false, 0, NUM_THREADS) \
__CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, true, false, 0, NUM_THREADS) \
__CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, true, false, 0, NUM_THREADS) \
\
__CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, false, -1, NUM_THREADS) \
__CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, false, 2, NUM_THREADS) \
__CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, false, 4, NUM_THREADS) \
__CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, false, 8, NUM_THREADS) \
\
__CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, false, -1, NUM_THREADS) \
__CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, false, 2, NUM_THREADS) \
__CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, false, 4, NUM_THREADS) \
__CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, false, 8, NUM_THREADS) \
\
__CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, false, -1, NUM_THREADS) \
__CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, false, 2, NUM_THREADS) \
__CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, false, 4, NUM_THREADS) \
__CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, false, 8, NUM_THREADS) \
\
__CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, false, -1, NUM_THREADS) \
__CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, false, 2, NUM_THREADS) \
__CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, false, 4, NUM_THREADS) \
__CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, false, 8, NUM_THREADS)
#define AWQ_CALL_IF(NUM_BITS, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
__CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS) \
__CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS) \
__CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS) \
__CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS) \
\
__CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS) \
__CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS) \
__CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS) \
__CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS) \
\
__CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS) \
__CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS) \
__CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS) \
__CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS) \
\
__CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS) \
__CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS) \
__CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS) \
__CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS)
template <typename scalar_t>
void marlin_mm_f16i4(const void* A, const void* B, void* C, void* s,
void marlin_mm_f16i4(const void* A, const void* B, void* C, void* s, void* zp,
void* g_idx, void* perm, void* a_tmp, int prob_m,
int prob_n, int prob_k, void* workspace, int num_bits,
bool has_act_order, bool is_k_full, int num_groups,
int group_size, int dev, cudaStream_t stream, int thread_k,
int thread_n, int sms, int max_par) {
bool has_act_order, bool is_k_full, bool has_zp,
int num_groups, int group_size, int dev,
cudaStream_t stream, int thread_k, int thread_n, int sms,
int max_par) {
TORCH_CHECK(num_bits == 4 || num_bits == 8,
"num_bits must be 4 or 8. Got = ", num_bits);
TORCH_CHECK(prob_m > 0 && prob_n > 0 && prob_k > 0, "Invalid MNK = [", prob_m,
@ -1665,6 +1971,7 @@ void marlin_mm_f16i4(const void* A, const void* B, void* C, void* s,
const int4* B_ptr = (const int4*)B;
int4* C_ptr = (int4*)C;
const int4* s_ptr = (const int4*)s;
const int4* zp_ptr = (const int4*)zp;
const int* g_idx_ptr = (const int*)g_idx;
const int* perm_ptr = (const int*)perm;
int4* a_tmp_ptr = (int4*)a_tmp;
@ -1701,28 +2008,33 @@ void marlin_mm_f16i4(const void* A, const void* B, void* C, void* s,
thread_m_blocks = exec_cfg.max_m_blocks;
}
// Define kernel configurations
if (false) {
}
CALL_IF(4, 32, 2, 256)
CALL_IF(4, 16, 4, 256)
CALL_IF(4, 8, 8, 256)
CALL_IF(4, 8, 4, 128)
CALL_IF(4, 4, 8, 128)
CALL_IF(8, 32, 2, 256)
CALL_IF(8, 16, 4, 256)
CALL_IF(8, 8, 8, 256)
CALL_IF(8, 8, 4, 128)
CALL_IF(8, 4, 8, 128)
GPTQ_CALL_IF(4, 16, 4, 256)
GPTQ_CALL_IF(4, 8, 8, 256)
GPTQ_CALL_IF(4, 8, 4, 128)
GPTQ_CALL_IF(4, 4, 8, 128)
GPTQ_CALL_IF(8, 16, 4, 256)
GPTQ_CALL_IF(8, 8, 8, 256)
GPTQ_CALL_IF(8, 8, 4, 128)
GPTQ_CALL_IF(8, 4, 8, 128)
AWQ_CALL_IF(4, 16, 4, 256)
AWQ_CALL_IF(4, 8, 8, 256)
AWQ_CALL_IF(4, 8, 4, 128)
AWQ_CALL_IF(4, 4, 8, 128)
AWQ_CALL_IF(8, 16, 4, 256)
AWQ_CALL_IF(8, 8, 8, 256)
AWQ_CALL_IF(8, 8, 4, 128)
AWQ_CALL_IF(8, 4, 8, 128)
else {
TORCH_CHECK(false, "Unsupported shapes: MNK = [" + str(prob_m) + ", " +
str(prob_n) + ", " + str(prob_k) + "]" +
", has_act_order = " + str(has_act_order) +
", num_groups = " + str(num_groups) +
", group_size = " + str(group_size) +
", thread_m_blocks = " + str(thread_m_blocks) +
", thread_n_blocks = " + str(thread_n_blocks) +
", thread_k_blocks = " + str(thread_k_blocks));
TORCH_CHECK(false, "Unsupported shapes: MNK = [", prob_m, ", ", prob_n,
", ", prob_k, "]", ", has_act_order = ", has_act_order,
", num_groups = ", num_groups, ", group_size = ", group_size,
", thread_m_blocks = ", thread_m_blocks,
", thread_n_blocks = ", thread_n_blocks,
", thread_k_blocks = ", thread_k_blocks,
", num_bits = ", num_bits);
}
A_ptr += 16 * thread_m_blocks * (prob_k / 8) * par;
@ -1733,10 +2045,11 @@ void marlin_mm_f16i4(const void* A, const void* B, void* C, void* s,
} // namespace gptq_marlin
torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
torch::Tensor& b_scales, torch::Tensor& g_idx,
torch::Tensor& perm, torch::Tensor& workspace,
int64_t num_bits, int64_t size_m, int64_t size_n,
int64_t size_k, bool is_k_full) {
torch::Tensor& b_scales, torch::Tensor& b_zeros,
torch::Tensor& g_idx, torch::Tensor& perm,
torch::Tensor& workspace, int64_t num_bits,
int64_t size_m, int64_t size_n, int64_t size_k,
bool is_k_full, bool has_zp) {
// Verify num_bits
TORCH_CHECK(num_bits == 4 || num_bits == 8,
"num_bits must be 4 or 8. Got = ", num_bits);
@ -1749,16 +2062,15 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
", size_k = ", size_k);
// Verify B
TORCH_CHECK(size_k % gptq_marlin::tile_size == 0, "size_k = ", size_k,
" is not divisible by tile_size = ", gptq_marlin::tile_size);
TORCH_CHECK((size_k / gptq_marlin::tile_size) == b_q_weight.size(0),
TORCH_CHECK(size_k % marlin::tile_size == 0, "size_k = ", size_k,
" is not divisible by tile_size = ", marlin::tile_size);
TORCH_CHECK((size_k / marlin::tile_size) == b_q_weight.size(0),
"Shape mismatch: b_q_weight.size(0) = ", b_q_weight.size(0),
", size_k = ", size_k, ", tile_size = ", gptq_marlin::tile_size);
TORCH_CHECK(b_q_weight.size(1) % gptq_marlin::tile_size == 0,
", size_k = ", size_k, ", tile_size = ", marlin::tile_size);
TORCH_CHECK(b_q_weight.size(1) % marlin::tile_size == 0,
"b_q_weight.size(1) = ", b_q_weight.size(1),
" is not divisible by tile_size = ", gptq_marlin::tile_size);
int actual_size_n =
(b_q_weight.size(1) / gptq_marlin::tile_size) * pack_factor;
" is not divisible by tile_size = ", marlin::tile_size);
int actual_size_n = (b_q_weight.size(1) / marlin::tile_size) * pack_factor;
TORCH_CHECK(size_n == actual_size_n, "size_n = ", size_n,
", actual_size_n = ", actual_size_n);
@ -1772,6 +2084,9 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
TORCH_CHECK(b_scales.device().is_cuda(), "b_scales is not on GPU");
TORCH_CHECK(b_scales.is_contiguous(), "b_scales is not contiguous");
TORCH_CHECK(b_zeros.device().is_cuda(), "b_zeros is not on GPU");
TORCH_CHECK(b_zeros.is_contiguous(), "b_zeros is not contiguous");
TORCH_CHECK(g_idx.device().is_cuda(), "g_idx is not on GPU");
TORCH_CHECK(g_idx.is_contiguous(), "g_idx is not contiguous");
@ -1805,8 +2120,8 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
int group_size = -1;
bool has_act_order = g_idx.size(0) != 0;
int b_rank = b_scales.sizes().size();
TORCH_CHECK(b_rank == 2, "b_scales rank = ", b_rank, " is not 2");
int rank = b_scales.sizes().size();
TORCH_CHECK(rank == 2, "b_scales rank = ", rank, " is not 2");
TORCH_CHECK(b_scales.size(1) == size_n, "b_scales dim 1 = ", b_scales.size(1),
" is not size_n = ", size_n);
num_groups = b_scales.size(0);
@ -1832,34 +2147,44 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
}
}
// Verify b_zeros
if (has_zp) {
int rank = b_zeros.sizes().size();
TORCH_CHECK(rank == 2, "b_zeros rank = ", rank, " is not 2");
TORCH_CHECK(b_zeros.size(0) == num_groups,
"b_zeros dim 0 = ", b_zeros.size(0),
" is not num_groups = ", num_groups);
TORCH_CHECK(b_zeros.size(1) == size_n / pack_factor,
"b_zeros dim 1 = ", b_scales.size(1),
" is not size_n / pack_factor = ", size_n / pack_factor);
}
// Verify workspace size
TORCH_CHECK(
size_n % gptq_marlin::min_thread_n == 0, "size_n = ", size_n,
", is not divisible by min_thread_n = ", gptq_marlin::min_thread_n);
int min_workspace_size =
(size_n / gptq_marlin::min_thread_n) * gptq_marlin::max_par;
TORCH_CHECK(size_n % marlin::min_thread_n == 0, "size_n = ", size_n,
", is not divisible by min_thread_n = ", marlin::min_thread_n);
int min_workspace_size = (size_n / marlin::min_thread_n) * marlin::max_par;
TORCH_CHECK(workspace.numel() >= min_workspace_size,
"workspace.numel = ", workspace.numel(),
" is below min_workspace_size = ", min_workspace_size);
int dev = a.get_device();
if (a.scalar_type() == at::ScalarType::Half) {
gptq_marlin::marlin_mm_f16i4<half>(
marlin::marlin_mm_f16i4<half>(
a.data_ptr<at::Half>(), b_q_weight.data_ptr(), c.data_ptr<at::Half>(),
b_scales.data_ptr<at::Half>(), g_idx.data_ptr(), perm.data_ptr(),
a_tmp.data_ptr<at::Half>(), size_m, size_n, size_k,
workspace.data_ptr(), num_bits, has_act_order, is_k_full, num_groups,
group_size, dev, at::cuda::getCurrentCUDAStream(dev), thread_k,
thread_n, sms, gptq_marlin::max_par);
b_scales.data_ptr<at::Half>(), b_zeros.data_ptr(), g_idx.data_ptr(),
perm.data_ptr(), a_tmp.data_ptr<at::Half>(), size_m, size_n, size_k,
workspace.data_ptr(), num_bits, has_act_order, is_k_full, has_zp,
num_groups, group_size, dev, at::cuda::getCurrentCUDAStream(dev),
thread_k, thread_n, sms, marlin::max_par);
} else if (a.scalar_type() == at::ScalarType::BFloat16) {
gptq_marlin::marlin_mm_f16i4<nv_bfloat16>(
marlin::marlin_mm_f16i4<nv_bfloat16>(
a.data_ptr<at::BFloat16>(), b_q_weight.data_ptr(),
c.data_ptr<at::BFloat16>(), b_scales.data_ptr<at::BFloat16>(),
g_idx.data_ptr(), perm.data_ptr(), a_tmp.data_ptr<at::BFloat16>(),
size_m, size_n, size_k, workspace.data_ptr(), num_bits, has_act_order,
is_k_full, num_groups, group_size, dev,
at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, sms,
gptq_marlin::max_par);
b_zeros.data_ptr(), g_idx.data_ptr(), perm.data_ptr(),
a_tmp.data_ptr<at::BFloat16>(), size_m, size_n, size_k,
workspace.data_ptr(), num_bits, has_act_order, is_k_full, has_zp,
num_groups, group_size, dev, at::cuda::getCurrentCUDAStream(dev),
thread_k, thread_n, sms, marlin::max_par);
} else {
TORCH_CHECK(false, "gpt_marlin_gemm only supports bfloat16 and float16");
}

View File

@ -1,23 +1,16 @@
#include "gptq_marlin.cuh"
namespace gptq_marlin {
static constexpr int repack_stages = 8;
static constexpr int repack_threads = 256;
static constexpr int tile_k_size = tile_size;
static constexpr int tile_n_size = tile_k_size * 4;
#include "marlin.cuh"
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
namespace marlin {
template <int const num_threads, int const num_bits, bool const has_perm>
__global__ void marlin_repack_kernel(
__global__ void gptq_marlin_repack_kernel(
uint32_t const* __restrict__ b_q_weight_ptr,
uint32_t const* __restrict__ perm_ptr, uint32_t* __restrict__ out_ptr,
int size_k, int size_n) {}
} // namespace gptq_marlin
} // namespace marlin
torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm,
int64_t size_k, int64_t size_n,
@ -29,8 +22,10 @@ torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm,
#else
namespace marlin {
template <int const num_threads, int const num_bits, bool const has_perm>
__global__ void marlin_repack_kernel(
__global__ void gptq_marlin_repack_kernel(
uint32_t const* __restrict__ b_q_weight_ptr,
uint32_t const* __restrict__ perm_ptr, uint32_t* __restrict__ out_ptr,
int size_k, int size_n) {
@ -259,28 +254,28 @@ __global__ void marlin_repack_kernel(
}
}
} // namespace gptq_marlin
} // namespace marlin
#define CALL_IF(NUM_BITS, HAS_PERM) \
else if (num_bits == NUM_BITS && has_perm == HAS_PERM) { \
cudaFuncSetAttribute( \
gptq_marlin::marlin_repack_kernel<gptq_marlin::repack_threads, \
NUM_BITS, HAS_PERM>, \
cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \
gptq_marlin::marlin_repack_kernel<gptq_marlin::repack_threads, NUM_BITS, \
HAS_PERM> \
<<<blocks, gptq_marlin::repack_threads, max_shared_mem, stream>>>( \
b_q_weight_ptr, perm_ptr, out_ptr, size_k, size_n); \
#define CALL_IF(NUM_BITS, HAS_PERM) \
else if (num_bits == NUM_BITS && has_perm == HAS_PERM) { \
cudaFuncSetAttribute( \
marlin::gptq_marlin_repack_kernel<marlin::repack_threads, NUM_BITS, \
HAS_PERM>, \
cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \
marlin::gptq_marlin_repack_kernel<marlin::repack_threads, NUM_BITS, \
HAS_PERM> \
<<<blocks, marlin::repack_threads, max_shared_mem, stream>>>( \
b_q_weight_ptr, perm_ptr, out_ptr, size_k, size_n); \
}
torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm,
int64_t size_k, int64_t size_n,
int64_t num_bits) {
// Verify compatibility with marlin tile of 16x64
TORCH_CHECK(size_k % gptq_marlin::tile_k_size == 0, "size_k = ", size_k,
" is not divisible by tile_k_size = ", gptq_marlin::tile_k_size);
TORCH_CHECK(size_n % gptq_marlin::tile_n_size == 0, "size_n = ", size_n,
" is not divisible by tile_n_size = ", gptq_marlin::tile_n_size);
TORCH_CHECK(size_k % marlin::tile_k_size == 0, "size_k = ", size_k,
" is not divisible by tile_k_size = ", marlin::tile_k_size);
TORCH_CHECK(size_n % marlin::tile_n_size == 0, "size_n = ", size_n,
" is not divisible by tile_n_size = ", marlin::tile_n_size);
TORCH_CHECK(num_bits == 4 || num_bits == 8,
"num_bits must be 4 or 8. Got = ", num_bits);
@ -308,10 +303,9 @@ torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm,
auto options = torch::TensorOptions()
.dtype(b_q_weight.dtype())
.device(b_q_weight.device());
torch::Tensor out =
torch::empty({size_k / gptq_marlin::tile_size,
size_n * gptq_marlin::tile_size / pack_factor},
options);
torch::Tensor out = torch::empty(
{size_k / marlin::tile_size, size_n * marlin::tile_size / pack_factor},
options);
// Detect if there is act_order
bool has_perm = perm.size(0) != 0;

View File

@ -9,7 +9,9 @@
#include <cuda_runtime.h>
#include <iostream>
namespace gptq_marlin {
namespace marlin {
// Marlin params
// 8 warps are a good choice since every SM has 4 schedulers and having more
// than 1 warp per schedule allows some more latency hiding. At the same time,
@ -25,6 +27,15 @@ static constexpr int min_thread_k = 64;
static constexpr int tile_size = 16;
static constexpr int max_par = 16;
// Repack params
static constexpr int repack_stages = 8;
static constexpr int repack_threads = 256;
static constexpr int tile_k_size = tile_size;
static constexpr int tile_n_size = tile_k_size * 4;
// Helpers
template <typename T, int n>
struct Vec {
T elems[n];
@ -73,4 +84,4 @@ __device__ inline void cp_async_wait() {
#endif
} // namespace gptq_marlin
} // namespace marlin

View File

@ -30,7 +30,7 @@ inline std::string str(T x) {
return std::to_string(x);
}
namespace marlin {
namespace marlin_dense {
constexpr int ceildiv(int a, int b) { return (a + b - 1) / b; }
@ -1040,7 +1040,7 @@ void marlin_cuda(const void* A, const void* B, void* C, void* s, int prob_m,
}
}
} // namespace marlin
} // namespace marlin_dense
torch::Tensor marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
torch::Tensor& b_scales, torch::Tensor& workspace,
@ -1054,24 +1054,25 @@ torch::Tensor marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
TORCH_CHECK(size_k == a.size(1),
"Shape mismatch: a.size(1) = " + str(a.size(1)) +
", size_k = " + str(size_k));
TORCH_CHECK(size_k % marlin::tile_size == 0,
"size_k = " + str(size_k) +
" is not divisible by tile_size = " + str(marlin::tile_size));
TORCH_CHECK((size_k / marlin::tile_size) == b_q_weight.size(0),
TORCH_CHECK(size_k % marlin_dense::tile_size == 0,
"size_k = " + str(size_k) + " is not divisible by tile_size = " +
str(marlin_dense::tile_size));
TORCH_CHECK((size_k / marlin_dense::tile_size) == b_q_weight.size(0),
"Shape mismatch: b_q_weight.size(0) = " +
str(b_q_weight.size(0)) + ", size_k = " + str(size_k) +
", tile_size = " + str(marlin::tile_size));
", tile_size = " + str(marlin_dense::tile_size));
// Verify N
TORCH_CHECK(b_scales.size(1) == size_n,
"b_scales.size(1) = " + str(b_scales.size(1)) +
", size_n = " + str(size_n));
TORCH_CHECK(b_q_weight.size(1) % marlin::tile_size == 0,
"b_q_weight.size(1) = " + str(b_q_weight.size(1)) +
" is not divisible by tile_size = " + str(marlin::tile_size));
TORCH_CHECK(
b_q_weight.size(1) % marlin_dense::tile_size == 0,
"b_q_weight.size(1) = " + str(b_q_weight.size(1)) +
" is not divisible by tile_size = " + str(marlin_dense::tile_size));
int actual_size_n =
(b_q_weight.size(1) / marlin::tile_size) * marlin::pack_factor_4bit;
int actual_size_n = (b_q_weight.size(1) / marlin_dense::tile_size) *
marlin_dense::pack_factor_4bit;
TORCH_CHECK(
size_n == actual_size_n,
"size_n = " + str(size_n) + ", actual_size_n = " + str(actual_size_n));
@ -1116,21 +1117,22 @@ torch::Tensor marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
"Unexpected groupsize = " + str(groupsize));
// Verify workspace size
TORCH_CHECK(
size_n % marlin::min_thread_n == 0,
"size_n = " + str(size_n) +
", is not divisible by min_thread_n = " + str(marlin::min_thread_n));
int min_workspace_size = (size_n / marlin::min_thread_n) * marlin::max_par;
TORCH_CHECK(size_n % marlin_dense::min_thread_n == 0,
"size_n = " + str(size_n) +
", is not divisible by min_thread_n = " +
str(marlin_dense::min_thread_n));
int min_workspace_size =
(size_n / marlin_dense::min_thread_n) * marlin_dense::max_par;
TORCH_CHECK(workspace.numel() >= min_workspace_size,
"workspace.numel = " + str(workspace.numel()) +
" is below min_workspace_size = " + str(min_workspace_size));
int dev = a.get_device();
marlin::marlin_cuda(a.data_ptr(), b_q_weight.data_ptr(), c.data_ptr(),
b_scales.data_ptr(), size_m, size_n, size_k,
workspace.data_ptr(), groupsize, dev,
at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n,
sms, marlin::max_par);
marlin_dense::marlin_cuda(a.data_ptr(), b_q_weight.data_ptr(), c.data_ptr(),
b_scales.data_ptr(), size_m, size_n, size_k,
workspace.data_ptr(), groupsize, dev,
at::cuda::getCurrentCUDAStream(dev), thread_k,
thread_n, sms, marlin_dense::max_par);
return c;
}

View File

@ -1,11 +1,11 @@
#ifndef _data_types_cuh
#define _data_types_cuh
#include "gptq_marlin.cuh"
#include "marlin.cuh"
#include <cuda_fp16.h>
#include <cuda_bf16.h>
namespace gptq_marlin {
namespace marlin {
template <typename scalar_t>
class ScalarType {};
@ -23,6 +23,7 @@ class ScalarType<half> {
using FragB = Vec<half2, 2>;
using FragC = Vec<float, 4>;
using FragS = Vec<half2, 1>;
using FragZP = Vec<half2, 4>;
static __device__ float inline num2float(const half x) {
return __half2float(x);
@ -51,6 +52,7 @@ class ScalarType<nv_bfloat16> {
using FragB = Vec<nv_bfloat162, 2>;
using FragC = Vec<float, 4>;
using FragS = Vec<nv_bfloat162, 1>;
using FragZP = Vec<nv_bfloat162, 4>;
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
static __device__ float inline num2float(const nv_bfloat16 x) {
@ -72,6 +74,6 @@ class ScalarType<nv_bfloat16> {
#endif
};
} // namespace gptq_marlin
} // namespace marlin
#endif

View File

@ -9,6 +9,7 @@ setup(
CUDAExtension(
name="marlin_kernels",
sources=[
"marlin_kernels/awq_marlin_repack.cu",
"marlin_kernels/fp8_marlin.cu",
"marlin_kernels/gptq_marlin.cu",
"marlin_kernels/gptq_marlin_repack.cu",

View File

@ -156,16 +156,26 @@ class GPTQWeightsLoader(WeightsLoader):
f"Cannot load `{self.quantize}` weight for GPTQ -> Marlin repacking, make sure the model is already quantized"
)
g_idx = weights.get_tensor(f"{prefix}.g_idx")
if not self.sym:
qzeros = weights.get_tensor(f"{prefix}.qzeros")
else:
qzeros = None
if self.quant_method == "awq":
g_idx = None
else:
g_idx = weights.get_tensor(f"{prefix}.g_idx")
scales = weights.get_tensor(f"{prefix}.scales")
return repack_gptq_for_marlin(
qweight=qweight,
scales=scales,
qzeros=qzeros,
g_idx=g_idx,
bits=self.bits,
desc_act=self.desc_act,
groupsize=self.groupsize,
quant_method=self.quant_method,
sym=self.sym,
sharded_infeatures=False,
)
@ -275,14 +285,26 @@ class GPTQWeightsLoader(WeightsLoader):
quantize=self.quantize,
sym=self.sym,
):
g_idx = weights.get_tensor(f"{prefix}.g_idx")
if not self.sym:
qzeros = weights.get_packed_sharded(
f"{prefix}.qzeros", dim=1, block_sizes=block_sizes
)
else:
qzeros = None
if self.quant_method == "awq":
g_idx = None
else:
g_idx = weights.get_tensor(f"{prefix}.g_idx")
return repack_gptq_for_marlin(
qweight=qweight,
scales=scales,
qzeros=qzeros,
g_idx=g_idx,
bits=self.bits,
desc_act=self.desc_act,
groupsize=self.groupsize,
quant_method=self.quant_method,
sym=self.sym,
sharded_infeatures=False,
)
@ -349,18 +371,31 @@ class GPTQWeightsLoader(WeightsLoader):
quantize=self.quantize,
sym=self.sym,
):
w = [weights.get_tensor(f"{p}.g_idx") for p in prefixes]
for w2 in w[1:]:
torch.testing.assert_close(w2, w[0])
g_idx = w[0]
if not self.sym:
qzeros = torch.cat(
[weights.get_sharded(f"{p}.qzeros", dim=1) for p in prefixes], dim=1
)
else:
qzeros = None
if self.quant_method == "awq":
g_idx = None
else:
w = [weights.get_tensor(f"{p}.g_idx") for p in prefixes]
for w2 in w[1:]:
torch.testing.assert_close(w2, w[0])
g_idx = w[0]
return repack_gptq_for_marlin(
qweight=qweight,
scales=scales,
qzeros=qzeros,
g_idx=g_idx,
bits=self.bits,
desc_act=self.desc_act,
groupsize=self.groupsize,
quant_method=self.quant_method,
sym=self.sym,
sharded_infeatures=False,
)
@ -438,7 +473,19 @@ class GPTQWeightsLoader(WeightsLoader):
f"Cannot load `{self.quantize}` weight for GPTQ -> Marlin repacking, make sure the model is already quantized"
)
g_idx = weights.get_sharded(f"{prefix}.g_idx", dim=0)
if not self.sym:
if self.desc_act or self.groupsize == -1:
qzeros = weights.get_tensor(f"{prefix}.qzeros")
else:
qzeros = weights.get_sharded(f"{prefix}.qzeros", dim=0)
else:
qzeros = None
if self.quant_method == "awq":
g_idx = None
else:
g_idx = weights.get_sharded(f"{prefix}.g_idx", dim=0)
if self.desc_act or self.groupsize == -1:
scales = weights.get_tensor(f"{prefix}.scales")
else:
@ -449,10 +496,12 @@ class GPTQWeightsLoader(WeightsLoader):
return repack_gptq_for_marlin(
qweight=qweight,
scales=scales,
qzeros=qzeros,
g_idx=g_idx,
bits=self.bits,
desc_act=self.desc_act,
groupsize=self.groupsize,
quant_method=self.quant_method,
sym=self.sym,
sharded_infeatures=sharded_in_features,
)

View File

@ -1,6 +1,7 @@
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union
import numpy
import torch
import torch.nn as nn
from loguru import logger
@ -174,11 +175,12 @@ def can_use_gptq_marlin(
SYSTEM == "cuda"
and marlin_kernels is not None
and has_sm_8_0
and quantize == "gptq"
and quant_method == "gptq"
and quantize in {"awq", "gptq"}
and quant_method in {"awq", "gptq"}
and bits in GPTQ_MARLIN_BITS
and groupsize in GPTQ_MARLIN_GROUP_SIZES
and sym
# We only suppord asymmetric quantization for AWQ.
and (sym or quant_method == "awq")
)
@ -234,6 +236,7 @@ class GPTQMarlinWeight(Weight):
"""
qweight: torch.Tensor
qzeros: torch.Tensor
scales: torch.Tensor
g_idx: torch.Tensor
perm: torch.Tensor
@ -256,11 +259,13 @@ class GPTQMarlinWeight(Weight):
def repack_gptq_for_marlin(
*,
qweight: torch.Tensor,
qzeros: Optional[torch.Tensor],
scales: torch.Tensor,
g_idx: torch.Tensor,
g_idx: Optional[torch.Tensor],
bits: int,
desc_act: bool,
groupsize: int,
quant_method: str,
sym: bool,
sharded_infeatures: bool,
) -> GPTQMarlinWeight:
@ -279,30 +284,54 @@ def repack_gptq_for_marlin(
raise RuntimeError(
f"Repacking GPTQ weights with group size {groupsize} as Marlin is not supported, must be one of: {supported_sizes}"
)
if not sym:
if not (sym or quant_method == "awq"):
raise RuntimeError(
"Repacking GPTQ weights with asymmetric quantization as Marlin is not supported."
)
log_once(logger.info, f"Converting {quant_method} model to Marlin packing format.")
weights_per_int = 32 // bits
in_features = qweight.shape[0] * weights_per_int
in_features = qweight.shape[0]
out_features = qweight.shape[1]
# AWQ uses column packing, GPTQ uses row packing
if quant_method == "awq":
out_features *= weights_per_int
else:
in_features *= weights_per_int
if in_features % groupsize != 0:
raise ValueError(
f"Number of input features ({in_features}) not divisible by group size ({groupsize})"
)
if desc_act and groupsize != -1:
if g_idx is not None and desc_act and groupsize != -1:
perm = torch.argsort(g_idx).to(torch.int)
g_idx = g_idx[perm]
else:
perm = torch.empty(0, dtype=torch.int, device=qweight.device)
g_idx = torch.empty(0, dtype=torch.int, device=qweight.device)
repacked = marlin_kernels.gptq_marlin_repack(
qweight, perm, in_features, out_features, bits
)
if quant_method == "awq":
repacked = marlin_kernels.awq_marlin_repack(
qweight, in_features, out_features, bits
)
if qzeros is not None:
qzeros = awq_to_marlin_zero_points(
qzeros,
in_features // groupsize,
out_features,
bits,
)
else:
repacked = marlin_kernels.gptq_marlin_repack(
qweight, perm, in_features, out_features, bits
)
if qzeros is None:
qzeros = torch.empty(0, dtype=torch.int, device=qweight.device)
scales = permute_scales(scales)
@ -310,6 +339,7 @@ def repack_gptq_for_marlin(
return GPTQMarlinWeight(
qweight=repacked,
qzeros=qzeros,
scales=scales,
g_idx=g_idx,
perm=perm,
@ -343,6 +373,7 @@ class GPTQMarlinLinear(nn.Module):
self.is_full_k = weight.is_full_k
self.qweight = weight.qweight
self.qzeros = weight.qzeros
self.scales = weight.scales
self.g_idx = weight.g_idx
self.perm = weight.perm
@ -363,6 +394,7 @@ class GPTQMarlinLinear(nn.Module):
A_flat,
self.qweight,
self.scales,
self.qzeros,
self.g_idx,
self.perm,
self.workspace,
@ -371,6 +403,7 @@ class GPTQMarlinLinear(nn.Module):
self.scales.shape[1],
A_flat.shape[1],
self.is_full_k,
self.qzeros.numel() > 0,
)
C = C.reshape(A.shape[:-1] + (self.scales.shape[1],))
@ -688,3 +721,116 @@ class MarlinLinear(nn.Module):
C += self.bias
return C
# Functions below are from vLLM
def get_pack_factor(bits: int) -> int:
if 32 % bits != 0:
raise ValueError(f"Cannot {bits} bit values into uint32")
return 32 // bits
def pack_cols(
q_w: torch.Tensor,
num_bits: int,
size_k: int,
size_n: int,
):
assert q_w.shape == (size_k, size_n)
pack_factor = get_pack_factor(num_bits)
assert size_n % pack_factor == 0
orig_device = q_w.device
q_w = q_w.cpu().numpy().astype(numpy.uint32)
q_res = numpy.zeros((size_k, size_n // pack_factor), dtype=numpy.uint32)
for i in range(pack_factor):
q_res |= q_w[:, i::pack_factor] << num_bits * i
q_res = torch.from_numpy(q_res.astype(numpy.int32)).to(orig_device)
q_res = q_res.contiguous()
return q_res
def unpack_cols(
packed_q_w: torch.Tensor,
num_bits: int,
size_k: int,
size_n: int,
):
pack_factor = get_pack_factor(num_bits)
assert size_n % pack_factor == 0
assert packed_q_w.shape == (
size_k,
size_n // pack_factor,
), "packed_q_w.shape = {} size_k = {}, size_n = {} pack_Factor = {}".format(
packed_q_w.shape, size_k, size_n, pack_factor
)
orig_device = packed_q_w.device
packed_q_w_cpu = packed_q_w.cpu().numpy().astype(numpy.uint32)
q_res = numpy.zeros((size_k, size_n), dtype=numpy.uint32)
mask = (1 << num_bits) - 1
for i in range(pack_factor):
vals = packed_q_w_cpu & mask
packed_q_w_cpu >>= num_bits
q_res[:, i::pack_factor] = vals
q_res = torch.from_numpy(q_res.astype(numpy.int32)).to(orig_device)
q_res = q_res.contiguous()
return q_res
def marlin_zero_points(
zp: torch.Tensor, size_k: int, size_n: int, num_bits: int
) -> torch.Tensor:
# Permute zero-points in a similar way to scales, but do not use the
# "single" permutation, since zero-points are applied on every MMA
zp = zp.reshape((-1, len(_scale_perm)))[:, _scale_perm]
# Interleave column dim (for the dequantize code) and pack it to int32
if num_bits == 4:
interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7])
elif num_bits == 8:
interleave = numpy.array([0, 2, 1, 3])
else:
raise Exception("num_bits must be 4 or 8, got {}".format(num_bits))
zp = zp.reshape((-1, len(interleave)))[:, interleave].ravel()
zp = zp.reshape((-1, size_n)).contiguous()
zp = pack_cols(zp, num_bits, size_k, size_n)
return zp
def awq_to_marlin_zero_points(
q_zp_packed: torch.Tensor, size_k: int, size_n: int, num_bits: int
) -> torch.Tensor:
# AWQ zero-points are quantized and packed on the column dim.
# In addition, the values are permuted based on dequantizer.
# Here we undo both of these, and then apply marlin permutation
# and pack it back.
q_zp = unpack_cols(q_zp_packed, num_bits, size_k, size_n)
# Undo interleaving (use argsort(..) to get inverse perm)
if num_bits == 4:
undo_interleave = numpy.argsort(numpy.array([0, 2, 4, 6, 1, 3, 5, 7]))
elif num_bits == 8:
undo_interleave = numpy.argsort(numpy.array([0, 2, 1, 3]))
else:
raise Exception("num_bits must be 4 or 8, got {}".format(num_bits))
q_zp = q_zp.reshape((-1, len(undo_interleave)))[:, undo_interleave].ravel()
q_zp = q_zp.reshape((-1, size_n)).contiguous()
marlin_zp = marlin_zero_points(q_zp, size_k, size_n, num_bits)
return marlin_zp

View File

@ -34,7 +34,7 @@ def _get_quantizer_config(model_id, revision):
groupsize = -1
quant_method = "gptq"
checkpoint_format = None
sym = True
sym = False
desc_act = False
filename = "config.json"
@ -52,12 +52,17 @@ def _get_quantizer_config(model_id, revision):
activation_scale_ub=data["quantization_config"]["activation_scale_ub"]
)
if "zero_point" in data["quantization_config"]:
sym = not data["quantization_config"]["zero_point"]
quant_method = "awq"
elif "sym" in data["quantization_config"]:
sym = data["quantization_config"]["sym"]
bits = data["quantization_config"]["bits"]
groupsize = data["quantization_config"]["group_size"]
# Order is important here, desc_act is missing on some real models
quant_method = data["quantization_config"]["quant_method"]
checkpoint_format = data["quantization_config"].get("checkpoint_format")
sym = data["quantization_config"]["sym"]
desc_act = data["quantization_config"]["desc_act"]
except Exception:
filename = "quantize_config.json"
@ -72,7 +77,13 @@ def _get_quantizer_config(model_id, revision):
data = json.load(f)
bits = data["bits"]
groupsize = data["group_size"]
sym = data["sym"]
if "zero_point" in data:
sym = not data["zero_point"]
quant_method = "awq"
elif "sym" in data:
sym = data["sym"]
desc_act = data["desc_act"]
if "version" in data and data["version"] == "GEMM":
quant_method = "awq"