From 9935720c874055ac77897b22ca4c3c65185bd7e1 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Dani=C3=ABl=20de=20Kok?=
Date: Tue, 23 Jul 2024 13:08:20 +0200
Subject: [PATCH] Add support for repacking AWQ weights for GPTQ-Marlin (#2278)

* Add support for repacking AWQ weights for GPTQ-Marlin

So far we couldn't support AWQ because virtually all AWQ models use
asymmetric quantization, which GPTQ-Marlin did not support. GPTQ-Marlin
has recently added support for AWQ repacking and AWQ asymmetric
quantization (zero_point=True).

This change updates all GPTQ-Marlin kernels from upstream and wires up
AWQ support. For now enabling AWQ using Marlin requires running TGI with
`--quantize gptq`.

* Enable Marlin for supported AWQ configurations by default

This makes the AWQ -> GPTQ repack test redundant, since we are now
testing this with the regular AWQ test.
---
 .../marlin_kernels/awq_marlin_repack.cu       | 269 +++++++++
 server/marlin/marlin_kernels/ext.cpp          |   2 +
 server/marlin/marlin_kernels/ext.hh           |  16 +-
 server/marlin/marlin_kernels/fp8_marlin.cu    |  33 +-
 server/marlin/marlin_kernels/gptq_marlin.cu   | 525 ++++++++++++++----
 .../marlin_kernels/gptq_marlin_repack.cu      |  58 +-
 .../{gptq_marlin.cuh => marlin.cuh}           |  15 +-
 .../marlin_kernels/marlin_cuda_kernel.cu      |  46 +-
 ...tq_marlin_dtypes.cuh => marlin_dtypes.cuh} |   8 +-
 server/marlin/setup.py                        |   1 +
 .../layers/gptq/__init__.py                   |  63 ++-
 .../text_generation_server/layers/marlin.py   | 166 +++++-
 .../utils/quantization.py                     |  17 +-
 13 files changed, 1016 insertions(+), 203 deletions(-)
 create mode 100644 server/marlin/marlin_kernels/awq_marlin_repack.cu
 rename server/marlin/marlin_kernels/{gptq_marlin.cuh => marlin.cuh} (88%)
 rename server/marlin/marlin_kernels/{gptq_marlin_dtypes.cuh => marlin_dtypes.cuh} (93%)

diff --git a/server/marlin/marlin_kernels/awq_marlin_repack.cu b/server/marlin/marlin_kernels/awq_marlin_repack.cu
new file mode 100644
index 00000000..c58216d8
--- /dev/null
+++ b/server/marlin/marlin_kernels/awq_marlin_repack.cu
@@ -0,0 +1,269 @@
+#include "marlin.cuh"
+
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
+
+namespace marlin {
+
+template
+__global__ void awq_marlin_repack_kernel(
+    uint32_t const* __restrict__ b_q_weight_ptr, uint32_t* __restrict__ out_ptr,
+    int size_k, int size_n) {}
+
+}  // namespace marlin
+
+torch::Tensor awq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm,
+                                int64_t size_k, int64_t size_n,
+                                int64_t num_bits) {
+  TORCH_CHECK_NOT_IMPLEMENTED(
+      false, "marlin_repack_from_gptq(..) requires CUDA_ARCH >= 8.0");
+  return torch::empty({1, 1});
+}
+
+#else
+
+namespace marlin {
+
+template
+__global__ void awq_marlin_repack_kernel(
+    uint32_t const* __restrict__ b_q_weight_ptr, uint32_t* __restrict__ out_ptr,
+    int size_k, int size_n) {
+  constexpr int pack_factor = 32 / num_bits;
+
+  int k_tiles = size_k / tile_k_size;
+  int n_tiles = size_n / tile_n_size;
+  int block_k_tiles = div_ceil(k_tiles, gridDim.x);
+
+  int start_k_tile = blockIdx.x * block_k_tiles;
+  if (start_k_tile >= k_tiles) {
+    return;
+  }
+
+  int finish_k_tile = min(start_k_tile + block_k_tiles, k_tiles);
+
+  // Wait until the next thread tile has been loaded to shared memory.
+  auto wait_for_stage = [&]() {
+    // We only have `stages - 2` active fetches since we are double buffering
+    // and can only issue the next fetch when it is guaranteed that the previous
+    // shared memory load is fully complete (as it may otherwise be
+    // overwritten).
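+    // The wait below returns once at most `repack_stages - 2` of the
+    // previously committed asynchronous copy groups are still in flight,
+    // at which point the oldest pipeline stage's shared-memory buffer can
+    // safely be reused.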
+ cp_async_wait(); + __syncthreads(); + }; + + extern __shared__ int4 sh[]; + + constexpr int tile_n_ints = tile_n_size / pack_factor; + + constexpr int stage_n_threads = tile_n_ints / 4; + constexpr int stage_k_threads = tile_k_size; + constexpr int stage_size = stage_k_threads * stage_n_threads; + + auto fetch_to_shared = [&](int pipe, int k_tile_id, int n_tile_id) { + if (n_tile_id >= n_tiles) { + cp_async_fence(); + return; + } + + int first_n = n_tile_id * tile_n_size; + int first_n_packed = first_n / pack_factor; + + int4* sh_ptr = sh + stage_size * pipe; + + if (threadIdx.x < stage_size) { + int k_id = threadIdx.x / stage_n_threads; + int n_id = threadIdx.x % stage_n_threads; + + int first_k = k_tile_id * tile_k_size; + + cp_async4(&sh_ptr[k_id * stage_n_threads + n_id], + reinterpret_cast( + &(b_q_weight_ptr[(first_k + k_id) * (size_n / pack_factor) + + first_n_packed + (n_id * 4)]))); + } + + cp_async_fence(); + }; + + auto repack_tile = [&](int pipe, int k_tile_id, int n_tile_id) { + if (n_tile_id >= n_tiles) { + return; + } + + int warp_id = threadIdx.x / 32; + int th_id = threadIdx.x % 32; + + if (warp_id >= 4) { + return; + } + + int tc_col = th_id / 4; + int tc_row = (th_id % 4) * 2; + + constexpr int tc_offsets[4] = {0, 1, 8, 9}; + + int cur_n = warp_id * 16 + tc_col; + int cur_n_packed = cur_n / pack_factor; + int cur_n_pos = cur_n % pack_factor; + + constexpr int sh_stride = tile_n_ints; + constexpr uint32_t mask = (1 << num_bits) - 1; + + int4* sh_stage_ptr = sh + stage_size * pipe; + uint32_t* sh_stage_int_ptr = reinterpret_cast(sh_stage_ptr); + + // Undo interleaving + int cur_n_pos_unpacked; + if constexpr (num_bits == 4) { + constexpr int undo_pack[8] = {0, 4, 1, 5, 2, 6, 3, 7}; + cur_n_pos_unpacked = undo_pack[cur_n_pos]; + } else { + constexpr int undo_pack[4] = {0, 2, 1, 3}; + cur_n_pos_unpacked = undo_pack[cur_n_pos]; + } + + uint32_t vals[8]; + #pragma unroll + for (int i = 0; i < 4; i++) { + int cur_elem = tc_row + tc_offsets[i]; + + int packed_src_0 = sh_stage_int_ptr[cur_n_packed + sh_stride * cur_elem]; + int packed_src_1 = sh_stage_int_ptr[cur_n_packed + (8 / pack_factor) + + sh_stride * cur_elem]; + + vals[i] = (packed_src_0 >> (cur_n_pos_unpacked * num_bits)) & mask; + vals[4 + i] = (packed_src_1 >> (cur_n_pos_unpacked * num_bits)) & mask; + } + + constexpr int tile_size = tile_k_size * tile_n_size / pack_factor; + int out_offset = (k_tile_id * n_tiles + n_tile_id) * tile_size; + + // Result of: + // https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h + if constexpr (num_bits == 4) { + constexpr int pack_idx[8] = {0, 2, 4, 6, 1, 3, 5, 7}; + + uint32_t res = 0; + #pragma unroll + for (int i = 0; i < 8; i++) { + res |= vals[pack_idx[i]] << (i * 4); + } + + out_ptr[out_offset + th_id * 4 + warp_id] = res; + + } else { + constexpr int pack_idx[4] = {0, 2, 1, 3}; + + uint32_t res1 = 0; + uint32_t res2 = 0; + #pragma unroll + for (int i = 0; i < 4; i++) { + res1 |= vals[pack_idx[i]] << (i * 8); + res2 |= vals[4 + pack_idx[i]] << (i * 8); + } + + out_ptr[out_offset + th_id * 8 + (warp_id * 2) + 0] = res1; + out_ptr[out_offset + th_id * 8 + (warp_id * 2) + 1] = res2; + } + }; + + auto start_pipes = [&](int k_tile_id, int n_tile_id) { + #pragma unroll + for (int pipe = 0; pipe < repack_stages - 1; pipe++) { + fetch_to_shared(pipe, k_tile_id, n_tile_id + pipe); + } + + wait_for_stage(); + }; + #pragma unroll + for (int k_tile_id = start_k_tile; k_tile_id < 
finish_k_tile; k_tile_id++) { + int n_tile_id = 0; + + start_pipes(k_tile_id, n_tile_id); + + while (n_tile_id < n_tiles) { + #pragma unroll + for (int pipe = 0; pipe < repack_stages; pipe++) { + fetch_to_shared((pipe + repack_stages - 1) % repack_stages, k_tile_id, + n_tile_id + pipe + repack_stages - 1); + repack_tile(pipe, k_tile_id, n_tile_id + pipe); + wait_for_stage(); + } + n_tile_id += repack_stages; + } + } +} + +} // namespace marlin + + #define CALL_IF(NUM_BITS) \ + else if (num_bits == NUM_BITS) { \ + cudaFuncSetAttribute( \ + marlin::awq_marlin_repack_kernel, \ + cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \ + marlin::awq_marlin_repack_kernel \ + <<>>( \ + b_q_weight_ptr, out_ptr, size_k, size_n); \ + } + +torch::Tensor awq_marlin_repack(torch::Tensor& b_q_weight, int64_t size_k, + int64_t size_n, int64_t num_bits) { + // Verify compatibility with marlin tile of 16x64 + TORCH_CHECK(size_k % marlin::tile_k_size == 0, "size_k = ", size_k, + " is not divisible by tile_k_size = ", marlin::tile_k_size); + TORCH_CHECK(size_n % marlin::tile_n_size == 0, "size_n = ", size_n, + " is not divisible by tile_n_size = ", marlin::tile_n_size); + + TORCH_CHECK(num_bits == 4 || num_bits == 8, + "num_bits must be 4 or 8. Got = ", num_bits); + int const pack_factor = 32 / num_bits; + + // Verify B + TORCH_CHECK(b_q_weight.size(0) == size_k, + "b_q_weight.size(0) = ", b_q_weight.size(0), + " is not size_k = ", size_k); + TORCH_CHECK((size_n / pack_factor) == b_q_weight.size(1), + "Shape mismatch: b_q_weight.size(1) = ", b_q_weight.size(1), + ", size_n = ", size_n, ", pack_factor = ", pack_factor); + + // Verify device and strides + TORCH_CHECK(b_q_weight.device().is_cuda(), "b_q_weight is not on GPU"); + TORCH_CHECK(b_q_weight.is_contiguous(), "b_q_weight is not contiguous"); + TORCH_CHECK(b_q_weight.dtype() == at::kInt, "b_q_weight type is not kInt"); + + // Alloc buffers + const at::cuda::OptionalCUDAGuard device_guard(device_of(b_q_weight)); + auto options = torch::TensorOptions() + .dtype(b_q_weight.dtype()) + .device(b_q_weight.device()); + torch::Tensor out = torch::empty( + {size_k / marlin::tile_size, size_n * marlin::tile_size / pack_factor}, + options); + + // Get ptrs + uint32_t const* b_q_weight_ptr = + reinterpret_cast(b_q_weight.data_ptr()); + uint32_t* out_ptr = reinterpret_cast(out.data_ptr()); + + // Get dev info + int dev = b_q_weight.get_device(); + cudaStream_t stream = at::cuda::getCurrentCUDAStream(dev); + int blocks; + cudaDeviceGetAttribute(&blocks, cudaDevAttrMultiProcessorCount, dev); + + int max_shared_mem = 0; + cudaDeviceGetAttribute(&max_shared_mem, + cudaDevAttrMaxSharedMemoryPerBlockOptin, dev); + TORCH_CHECK(max_shared_mem > 0); + + if (false) { + } + CALL_IF(4) + CALL_IF(8) + else { + TORCH_CHECK(false, "Unsupported repack config: num_bits = ", num_bits); + } + + return out; +} + +#endif diff --git a/server/marlin/marlin_kernels/ext.cpp b/server/marlin/marlin_kernels/ext.cpp index 04e1530f..8dcd2ff9 100644 --- a/server/marlin/marlin_kernels/ext.cpp +++ b/server/marlin/marlin_kernels/ext.cpp @@ -3,6 +3,8 @@ #include "ext.hh" PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("awq_marlin_repack", &awq_marlin_repack, + "Repack AWQ parameters for Marlin"); m.def("gptq_marlin_gemm", &gptq_marlin_gemm, "Marlin gemm with GPTQ compatibility"); m.def("gptq_marlin_24_gemm", &gptq_marlin_24_gemm, "Marlin sparse 2:4 gemm"); diff --git a/server/marlin/marlin_kernels/ext.hh b/server/marlin/marlin_kernels/ext.hh index 102c058e..ec76b5f4 100644 --- 
a/server/marlin/marlin_kernels/ext.hh +++ b/server/marlin/marlin_kernels/ext.hh @@ -6,11 +6,15 @@ // No support for async #else +torch::Tensor awq_marlin_repack(torch::Tensor &b_q_weight, int64_t size_k, + int64_t size_n, int64_t num_bits); + torch::Tensor gptq_marlin_gemm(torch::Tensor &a, torch::Tensor &b_q_weight, - torch::Tensor &b_scales, torch::Tensor &g_idx, - torch::Tensor &perm, torch::Tensor &workspace, - int64_t num_bits, int64_t size_m, int64_t size_n, - int64_t size_k, bool is_k_full); + torch::Tensor &b_scales, torch::Tensor &b_zeros, + torch::Tensor &g_idx, torch::Tensor &perm, + torch::Tensor &workspace, int64_t num_bits, + int64_t size_m, int64_t size_n, int64_t size_k, + bool is_k_full, bool has_zp); torch::Tensor gptq_marlin_24_gemm(torch::Tensor &a, torch::Tensor &b_q_weight, torch::Tensor &b_meta, @@ -27,8 +31,8 @@ torch::Tensor marlin_gemm(torch::Tensor &a, torch::Tensor &b_q_weight, torch::Tensor &b_scales, torch::Tensor &workspace, int64_t size_m, int64_t size_n, int64_t size_k); -torch::Tensor fp8_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, - torch::Tensor& b_scales, torch::Tensor& workspace, +torch::Tensor fp8_marlin_gemm(torch::Tensor &a, torch::Tensor &b_q_weight, + torch::Tensor &b_scales, torch::Tensor &workspace, int64_t num_bits, int64_t size_m, int64_t size_n, int64_t size_k); diff --git a/server/marlin/marlin_kernels/fp8_marlin.cu b/server/marlin/marlin_kernels/fp8_marlin.cu index aaef67e5..f6458ca6 100644 --- a/server/marlin/marlin_kernels/fp8_marlin.cu +++ b/server/marlin/marlin_kernels/fp8_marlin.cu @@ -19,10 +19,10 @@ * Adapted from https://github.com/IST-DASLab/marlin */ -#include "./gptq_marlin.cuh" -#include "./gptq_marlin_dtypes.cuh" +#include "marlin.cuh" +#include "marlin_dtypes.cuh" -using namespace gptq_marlin; +using namespace marlin; #define STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t) \ static_assert(std::is_same::value || \ @@ -1224,16 +1224,15 @@ torch::Tensor fp8_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, ", size_k = ", size_k); // Verify B - TORCH_CHECK(size_k % gptq_marlin::tile_size == 0, "size_k = ", size_k, - " is not divisible by tile_size = ", gptq_marlin::tile_size); - TORCH_CHECK((size_k / gptq_marlin::tile_size) == b_q_weight.size(0), + TORCH_CHECK(size_k % marlin::tile_size == 0, "size_k = ", size_k, + " is not divisible by tile_size = ", marlin::tile_size); + TORCH_CHECK((size_k / marlin::tile_size) == b_q_weight.size(0), "Shape mismatch: b_q_weight.size(0) = ", b_q_weight.size(0), - ", size_k = ", size_k, ", tile_size = ", gptq_marlin::tile_size); - TORCH_CHECK(b_q_weight.size(1) % gptq_marlin::tile_size == 0, + ", size_k = ", size_k, ", tile_size = ", marlin::tile_size); + TORCH_CHECK(b_q_weight.size(1) % marlin::tile_size == 0, "b_q_weight.size(1) = ", b_q_weight.size(1), - " is not divisible by tile_size = ", gptq_marlin::tile_size); - int actual_size_n = - (b_q_weight.size(1) / gptq_marlin::tile_size) * pack_factor; + " is not divisible by tile_size = ", marlin::tile_size); + int actual_size_n = (b_q_weight.size(1) / marlin::tile_size) * pack_factor; TORCH_CHECK(size_n == actual_size_n, "size_n = ", size_n, ", actual_size_n = ", actual_size_n); @@ -1274,11 +1273,9 @@ torch::Tensor fp8_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, num_groups = b_scales.size(0); // Verify workspace size - TORCH_CHECK( - size_n % gptq_marlin::min_thread_n == 0, "size_n = ", size_n, - ", is not divisible by min_thread_n = ", gptq_marlin::min_thread_n); - int min_workspace_size = - (size_n / 
gptq_marlin::min_thread_n) * gptq_marlin::max_par; + TORCH_CHECK(size_n % marlin::min_thread_n == 0, "size_n = ", size_n, + ", is not divisible by min_thread_n = ", marlin::min_thread_n); + int min_workspace_size = (size_n / marlin::min_thread_n) * marlin::max_par; TORCH_CHECK(workspace.numel() >= min_workspace_size, "workspace.numel = ", workspace.numel(), " is below min_workspace_size = ", min_workspace_size); @@ -1290,14 +1287,14 @@ torch::Tensor fp8_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, b_scales.data_ptr(), size_m, size_n, size_k, workspace.data_ptr(), num_bits, num_groups, group_size, dev, at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, sms, - gptq_marlin::max_par); + marlin::max_par); } else if (a.scalar_type() == at::ScalarType::BFloat16) { fp8_marlin::marlin_mm_f16i4( a.data_ptr(), b_q_weight.data_ptr(), c.data_ptr(), b_scales.data_ptr(), size_m, size_n, size_k, workspace.data_ptr(), num_bits, num_groups, group_size, dev, at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, sms, - gptq_marlin::max_par); + marlin::max_par); } else { TORCH_CHECK(false, "fp8_marlin_gemm only supports bfloat16 and float16"); } diff --git a/server/marlin/marlin_kernels/gptq_marlin.cu b/server/marlin/marlin_kernels/gptq_marlin.cu index 0beb9de1..122c5c16 100644 --- a/server/marlin/marlin_kernels/gptq_marlin.cu +++ b/server/marlin/marlin_kernels/gptq_marlin.cu @@ -19,8 +19,8 @@ * Adapted from https://github.com/IST-DASLab/marlin */ -#include "gptq_marlin.cuh" -#include "gptq_marlin_dtypes.cuh" +#include "marlin.cuh" +#include "marlin_dtypes.cuh" #define STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t) \ static_assert(std::is_same::value || \ @@ -32,7 +32,7 @@ inline std::string str(T x) { return std::to_string(x); } -namespace gptq_marlin { +namespace marlin { #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 @@ -72,10 +72,11 @@ __global__ void Marlin( } // namespace gptq_marlin torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, - torch::Tensor& b_scales, torch::Tensor& g_idx, - torch::Tensor& perm, torch::Tensor& workspace, - int64_t num_bits, int64_t size_m, int64_t size_n, - int64_t size_k, bool is_k_full) { + torch::Tensor& b_scales, torch::Tensor& b_zeros, + torch::Tensor& g_idx, torch::Tensor& perm, + torch::Tensor& workspace, int64_t num_bits, + int64_t size_m, int64_t size_n, int64_t size_k, + bool is_k_full) { TORCH_CHECK_NOT_IMPLEMENTED(false, "marlin_gemm(..) requires CUDA_ARCH >= 8.0"); return torch::empty({1, 1}); @@ -264,6 +265,114 @@ dequant_8bit(int q) { return frag_b; } +// Zero-point dequantizers + +template +__device__ inline typename ScalarType::FragB dequant_4bit_zp(int q) { + STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t); +} + +template <> +__device__ inline typename ScalarType::FragB dequant_4bit_zp( + int q) { + const int LO = 0x000f000f; + const int HI = 0x00f000f0; + const int EX = 0x64006400; + // Guarantee that the `(a & b) | c` operations are LOP3s. 
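+  // (0xf0 & 0xcc) | 0xaa evaluates to 0xEA, the LOP3 truth table for
+  // (a & b) | c. OR-ing the masked nibbles with the fp16 bit pattern 0x6400
+  // (1024.0) yields 1024 + n for the low nibbles and 1024 + 16 * n for the
+  // high nibbles; the sub/fma below strip that bias again.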
+ int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX); + int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX); + + const int SUB = 0x64006400; + const int MUL = 0x2c002c00; + const int ADD = 0xd400d400; + typename ScalarType::FragB frag_b; + frag_b[0] = __hsub2(*reinterpret_cast(&lo), + *reinterpret_cast(&SUB)); + frag_b[1] = __hfma2(*reinterpret_cast(&hi), + *reinterpret_cast(&MUL), + *reinterpret_cast(&ADD)); + return frag_b; +} + +template <> +__device__ inline typename ScalarType::FragB +dequant_4bit_zp(int q) { + static constexpr uint32_t MASK = 0x000f000f; + static constexpr uint32_t EX = 0x43004300; + + // Guarantee that the `(a & b) | c` operations are LOP3s. + + int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX); + q >>= 4; + int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX); + + typename ScalarType::FragB frag_b; + static constexpr uint32_t MUL = 0x3F803F80; + static constexpr uint32_t ADD = 0xC300C300; + + frag_b[0] = __hfma2(*reinterpret_cast(&lo), + *reinterpret_cast(&MUL), + *reinterpret_cast(&ADD)); + frag_b[1] = __hfma2(*reinterpret_cast(&hi), + *reinterpret_cast(&MUL), + *reinterpret_cast(&ADD)); + return frag_b; +} + +template +__device__ inline typename ScalarType::FragB dequant_8bit_zp(int q) { + STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t); +} + +template <> +__device__ inline typename ScalarType::FragB dequant_8bit_zp( + int q) { + static constexpr uint32_t mask_for_elt_01 = 0x5250; + static constexpr uint32_t mask_for_elt_23 = 0x5351; + static constexpr uint32_t start_byte_for_fp16 = 0x64646464; + + uint32_t lo = prmt(q); + uint32_t hi = prmt(q); + + static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64006400; + + typename ScalarType::FragB frag_b; + frag_b[0] = __hsub2(*reinterpret_cast(&lo), + *reinterpret_cast(&I8s_TO_F16s_MAGIC_NUM)); + frag_b[1] = __hsub2(*reinterpret_cast(&hi), + *reinterpret_cast(&I8s_TO_F16s_MAGIC_NUM)); + return frag_b; +} + +template <> +__device__ inline typename ScalarType::FragB +dequant_8bit_zp(int q) { + typename ScalarType::FragB frag_b; + + float fp32_intermediates[4]; + uint32_t* fp32_intermediates_casted = + reinterpret_cast(fp32_intermediates); + + static constexpr uint32_t fp32_base = 0x4B000000; + fp32_intermediates_casted[0] = __byte_perm(q, fp32_base, 0x7650); + fp32_intermediates_casted[1] = __byte_perm(q, fp32_base, 0x7652); + fp32_intermediates_casted[2] = __byte_perm(q, fp32_base, 0x7651); + fp32_intermediates_casted[3] = __byte_perm(q, fp32_base, 0x7653); + + fp32_intermediates[0] -= 8388608.f; + fp32_intermediates[1] -= 8388608.f; + fp32_intermediates[2] -= 8388608.f; + fp32_intermediates[3] -= 8388608.f; + + uint32_t* bf16_result_ptr = reinterpret_cast(&frag_b); + bf16_result_ptr[0] = __byte_perm(fp32_intermediates_casted[0], + fp32_intermediates_casted[1], 0x7632); + bf16_result_ptr[1] = __byte_perm(fp32_intermediates_casted[2], + fp32_intermediates_casted[3], 0x7632); + + return frag_b; +} + // Multiply dequantized values by the corresponding quantization scale; used // only for grouped quantization. 
template @@ -277,6 +386,17 @@ __device__ inline void scale(typename ScalarType::FragB& frag_b, frag_b[1] = __hmul2(frag_b[1], s); } +template +__device__ inline void sub_zp(typename ScalarType::FragB& frag_b, + typename ScalarType::scalar_t2& frag_zp, + int i) { + using scalar_t2 = typename ScalarType::scalar_t2; + scalar_t2 zp = + ScalarType::num2num2(reinterpret_cast(&frag_zp)[i]); + frag_b[0] = __hsub2(frag_b[0], zp); + frag_b[1] = __hsub2(frag_b[1], zp); +} + // Same as above, but for act_order (each K is multiplied individually) template __device__ inline void scale4(typename ScalarType::FragB& frag_b, @@ -404,6 +524,7 @@ template shared // fetch pipeline const bool has_act_order, // whether act_order is enabled + const bool has_zp, // whether zero-points are enabled const int group_blocks = -1 // number of consecutive 16x16 blocks // with a separate quantization scale > @@ -413,6 +534,8 @@ __global__ void Marlin( int4* __restrict__ C, // fp16 output buffer of shape mxn const int4* __restrict__ scales_ptr, // fp16 quantization scales of shape // (k/groupsize)xn + const int4* __restrict__ zp_ptr, // 4bit packed zero-points of shape + // (k/groupsize)x(n/pack_factor) const int* __restrict__ g_idx, // int32 group indices of shape k int num_groups, // number of scale groups per output channel int prob_m, // batch dimension m @@ -437,6 +560,7 @@ __global__ void Marlin( using FragB = typename ScalarType::FragB; using FragC = typename ScalarType::FragC; using FragS = typename ScalarType::FragS; + using FragZP = typename ScalarType::FragZP; constexpr int pack_factor = 32 / num_bits; @@ -566,6 +690,13 @@ __global__ void Marlin( int tb_n_warps = thread_n_blocks / 4; int act_s_col_tb_stride = act_s_col_warp_stride * tb_n_warps; + // Zero-points sizes/strides + int zp_gl_stride = (prob_n / pack_factor) / 4; + constexpr int zp_sh_stride = ((16 * thread_n_blocks) / pack_factor) / 4; + constexpr int zp_tb_groups = s_tb_groups; + constexpr int zp_sh_stage = has_zp ? zp_tb_groups * zp_sh_stride : 0; + int zp_gl_rd_delta = zp_gl_stride; + // Global A read index of current thread. int a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) + (threadIdx.x % a_gl_rd_delta_o); @@ -605,6 +736,19 @@ __global__ void Marlin( int s_sh_wr = threadIdx.x; bool s_sh_wr_pred = threadIdx.x < s_sh_stride; + // Zero-points + int zp_gl_rd; + if constexpr (has_zp) { + if constexpr (group_blocks == -1) { + zp_gl_rd = zp_sh_stride * slice_col + threadIdx.x; + } else { + zp_gl_rd = zp_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) + + zp_sh_stride * slice_col + threadIdx.x; + } + } + int zp_sh_wr = threadIdx.x; + bool zp_sh_wr_pred = threadIdx.x < zp_sh_stride; + // We use a different scale layout for grouped and column-wise quantization as // we scale a `half2` tile in column-major layout in the former and in // row-major in the latter case. 
@@ -616,6 +760,18 @@ __global__ void Marlin( s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + (threadIdx.x % 32) % 4; + // Zero-points have the same read layout as the scales + // (without column-wise case) + constexpr int num_col_threads = 8; + constexpr int num_row_threads = 4; + constexpr int num_ints_per_thread = 8 / pack_factor; + int zp_sh_rd; + if constexpr (has_zp) { + zp_sh_rd = num_ints_per_thread * num_col_threads * + ((threadIdx.x / 32) % (thread_n_blocks / 4)) + + num_ints_per_thread * ((threadIdx.x % 32) / num_row_threads); + } + // Precompute which thread should not read memory in which iterations; this is // needed if there are more threads than required for a certain tilesize or // when the batchsize is not a multiple of 16. @@ -664,14 +820,17 @@ __global__ void Marlin( int4* sh_a = sh; int4* sh_b = sh_a + (stages * a_sh_stage); int4* sh_g_idx = sh_b + (stages * b_sh_stage); - int4* sh_s = sh_g_idx + (stages * g_idx_stage); + int4* sh_zp = sh_g_idx + (stages * g_idx_stage); + int4* sh_s = sh_zp + (stages * zp_sh_stage); // Register storage for double buffer of shared memory reads. FragA frag_a[2][thread_m_blocks]; I4 frag_b_quant[2][b_thread_vecs]; FragC frag_c[thread_m_blocks][4][2]; - FragS frag_s[2][4]; // No act-order - FragS act_frag_s[2][4][4]; // For act-order + FragS frag_s[2][4]; // No act-order + FragS act_frag_s[2][4][4]; // For act-order + int frag_qzp[2][num_ints_per_thread]; // Zero-points + FragZP frag_zp; // Zero-points in fp16 // Zero accumulators. auto zero_accums = [&]() { @@ -777,6 +936,28 @@ __global__ void Marlin( } } } + + if constexpr (has_zp && group_blocks != -1) { + int4* sh_zp_stage = sh_zp + zp_sh_stage * pipe; + + if constexpr (group_blocks >= thread_k_blocks) { + // Only fetch zero-points if this tile starts a new group + if (pipe % (group_blocks / thread_k_blocks) == 0) { + if (zp_sh_wr_pred) { + cp_async4(&sh_zp_stage[zp_sh_wr], &zp_ptr[zp_gl_rd]); + } + zp_gl_rd += zp_gl_rd_delta; + } + } else { + for (int i = 0; i < zp_tb_groups; i++) { + if (zp_sh_wr_pred) { + cp_async4(&sh_zp_stage[i * zp_sh_stride + zp_sh_wr], + &zp_ptr[zp_gl_rd]); + } + zp_gl_rd += zp_gl_rd_delta; + } + } + } } } // Insert a fence even when we are winding down the pipeline to ensure that @@ -784,6 +965,12 @@ __global__ void Marlin( cp_async_fence(); }; + auto fetch_zp_to_shared = [&]() { + if (zp_sh_wr_pred) { + cp_async4(&sh_zp[zp_sh_wr], &zp_ptr[zp_gl_rd]); + } + }; + // Wait until the next thread tile has been loaded to shared memory. 
auto wait_for_stage = [&]() { // We only have `stages - 2` active fetches since we are double buffering @@ -932,8 +1119,73 @@ __global__ void Marlin( } }; + auto fetch_zp_to_registers = [&](int k, int full_pipe) { + if constexpr (!has_zp) { + return; + } + + int pipe = full_pipe % stages; + + if constexpr (group_blocks == -1) { + for (int i = 0; i < num_ints_per_thread; i++) { + frag_qzp[k % 2][i] = (reinterpret_cast(sh_zp))[zp_sh_rd + i]; + } + + } else if constexpr (group_blocks >= thread_k_blocks) { + int4* sh_zp_stage = + sh_zp + zp_sh_stage * ((group_blocks / thread_k_blocks) * + (pipe / (group_blocks / thread_k_blocks))); + for (int i = 0; i < num_ints_per_thread; i++) { + frag_qzp[k % 2][i] = + (reinterpret_cast(sh_zp_stage))[zp_sh_rd + i]; + } + } else { + int warp_id = threadIdx.x / 32; + int n_warps = thread_n_blocks / 4; + + int warp_row = warp_id / n_warps; + + int cur_k = warp_row * 16; + cur_k += k_iter_size * (k % b_sh_wr_iters); + + int k_blocks = cur_k / 16; + int cur_group_id = k_blocks / group_blocks; + + int4* sh_zp_stage = sh_zp + zp_sh_stage * pipe; + + sh_zp_stage += cur_group_id * zp_sh_stride; + + for (int i = 0; i < num_ints_per_thread; i++) { + frag_qzp[k % 2][i] = + (reinterpret_cast(sh_zp_stage))[zp_sh_rd + i]; + } + } + }; + // Execute the actual tensor core matmul of a sub-tile. auto matmul = [&](int k) { + if constexpr (has_zp) { + FragB frag_zp_0; + FragB frag_zp_1; + if constexpr (num_bits == 4) { + int zp_quant = frag_qzp[k % 2][0]; + int zp_quant_shift = zp_quant >> 8; + frag_zp_0 = dequant_4bit_zp(zp_quant); + frag_zp_1 = dequant_4bit_zp(zp_quant_shift); + + } else { + int zp_quant_0 = frag_qzp[k % 2][0]; + int zp_quant_1 = frag_qzp[k % 2][1]; + frag_zp_0 = dequant_8bit_zp(zp_quant_0); + frag_zp_1 = dequant_8bit_zp(zp_quant_1); + } + + frag_zp[0] = frag_zp_0[0]; + frag_zp[1] = frag_zp_0[1]; + frag_zp[2] = frag_zp_1[0]; + frag_zp[3] = frag_zp_1[1]; + } + // We have the m dimension as the inner loop in order to encourage overlapping // dequantization and matmul operations. 
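+    // With zero-points enabled, frag_zp now holds one fp16 zero-point pair
+    // per column fragment; the loop below subtracts it from each dequantized
+    // weight fragment (sub_zp) before the scales are applied.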
#pragma unroll @@ -944,16 +1196,32 @@ __global__ void Marlin( int b_quant = frag_b_quant[k % 2][0][j]; int b_quant_shift = b_quant >> 8; - frag_b0 = dequant_4bit(b_quant); - frag_b1 = dequant_4bit(b_quant_shift); + if constexpr (has_zp) { + frag_b0 = dequant_4bit_zp(b_quant); + frag_b1 = dequant_4bit_zp(b_quant_shift); + + } else { + frag_b0 = dequant_4bit(b_quant); + frag_b1 = dequant_4bit(b_quant_shift); + } } else { int* frag_b_quant_ptr = reinterpret_cast(frag_b_quant[k % 2]); int b_quant_0 = frag_b_quant_ptr[j * 2 + 0]; int b_quant_1 = frag_b_quant_ptr[j * 2 + 1]; - frag_b0 = dequant_8bit(b_quant_0); - frag_b1 = dequant_8bit(b_quant_1); + if constexpr (has_zp) { + frag_b0 = dequant_8bit_zp(b_quant_0); + frag_b1 = dequant_8bit_zp(b_quant_1); + } else { + frag_b0 = dequant_8bit(b_quant_0); + frag_b1 = dequant_8bit(b_quant_1); + } + } + + // Apply zero-point to frag_b0 + if constexpr (has_zp) { + sub_zp(frag_b0, frag_zp[j], 0); } // Apply scale to frag_b0 @@ -967,6 +1235,11 @@ __global__ void Marlin( } } + // Apply zero-point to frag_b1 + if constexpr (has_zp) { + sub_zp(frag_b1, frag_zp[j], 1); + } + // Apply scale to frag_b1 if constexpr (has_act_order) { scale4(frag_b1, act_frag_s[k % 2][0][j], @@ -1189,6 +1462,12 @@ __global__ void Marlin( } fetch_scales_to_shared(true, g_idx[slice_k_start], g_idx[last_g_idx]); } + + if constexpr (has_zp && group_blocks == -1) { + if (i == 0) { + fetch_zp_to_shared(); + } + } fetch_to_shared(i, i, i < slice_iters); } @@ -1197,6 +1476,7 @@ __global__ void Marlin( init_same_group(0); fetch_to_registers(0, 0); fetch_scales_to_registers(0, 0); + fetch_zp_to_registers(0, 0); a_gl_rd += a_gl_rd_delta_o * (stages - 1); slice_k_start_shared_fetch += tb_k * (stages - 1); }; @@ -1217,6 +1497,7 @@ __global__ void Marlin( for (int k = 0; k < b_sh_wr_iters; k++) { fetch_to_registers(k + 1, pipe % stages); fetch_scales_to_registers(k + 1, pipe); + fetch_zp_to_registers(k + 1, pipe); if (k == b_sh_wr_iters - 2) { fetch_to_shared((pipe + stages - 1) % stages, pipe, slice_iters >= stages); @@ -1354,6 +1635,7 @@ __global__ void Marlin( } else { s_gl_rd = s_sh_stride * slice_col + threadIdx.x; + zp_gl_rd = zp_sh_stride * slice_col + threadIdx.x; } start_pipes(); @@ -1363,22 +1645,24 @@ __global__ void Marlin( } #define __CALL_IF(NUM_BITS, THREAD_M_BLOCKS, THREAD_N_BLOCKS, \ - THREAD_K_BLOCKS, HAS_ACT_ORDER, GROUP_BLOCKS, NUM_THREADS) \ + THREAD_K_BLOCKS, HAS_ACT_ORDER, HAS_ZP, GROUP_BLOCKS, \ + NUM_THREADS) \ else if (num_bits == NUM_BITS && thread_m_blocks == THREAD_M_BLOCKS && \ thread_n_blocks == THREAD_N_BLOCKS && \ thread_k_blocks == THREAD_K_BLOCKS && \ - has_act_order == HAS_ACT_ORDER && group_blocks == GROUP_BLOCKS && \ - num_threads == NUM_THREADS) { \ + has_act_order == HAS_ACT_ORDER && has_zp == HAS_ZP && \ + group_blocks == GROUP_BLOCKS && num_threads == NUM_THREADS) { \ cudaFuncSetAttribute( \ Marlin, \ + HAS_ZP, GROUP_BLOCKS>, \ cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \ Marlin<<>>( \ - A_ptr, B_ptr, C_ptr, s_ptr, g_idx_ptr, num_groups, prob_m, prob_n, \ - prob_k, locks); \ + HAS_ZP, GROUP_BLOCKS> \ + <<>>( \ + A_ptr, B_ptr, C_ptr, s_ptr, zp_ptr, g_idx_ptr, num_groups, \ + prob_m, prob_n, prob_k, locks); \ } typedef struct { @@ -1548,39 +1832,61 @@ exec_config_t determine_thread_config(int prob_m, int prob_n, int prob_k, return exec_config_t{0, {-1, -1, -1}}; } - #define CALL_IF(NUM_BITS, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ - __CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \ - __CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, 
true, 0, NUM_THREADS) \ - __CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \ - __CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \ - \ - __CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \ - __CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \ - __CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \ - __CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS) \ - \ - __CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \ - __CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \ - __CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \ - __CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS) \ - \ - __CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \ - __CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \ - __CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \ - __CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS) \ - \ - __CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \ - __CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \ - __CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \ - __CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS) + #define GPTQ_CALL_IF(NUM_BITS, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + __CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, true, false, 0, NUM_THREADS) \ + __CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, true, false, 0, NUM_THREADS) \ + __CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, true, false, 0, NUM_THREADS) \ + __CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, true, false, 0, NUM_THREADS) \ + \ + __CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, false, -1, NUM_THREADS) \ + __CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, false, 2, NUM_THREADS) \ + __CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, false, 4, NUM_THREADS) \ + __CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, false, 8, NUM_THREADS) \ + \ + __CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, false, -1, NUM_THREADS) \ + __CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, false, 2, NUM_THREADS) \ + __CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, false, 4, NUM_THREADS) \ + __CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, false, 8, NUM_THREADS) \ + \ + __CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, false, -1, NUM_THREADS) \ + __CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, false, 2, NUM_THREADS) \ + __CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, false, 4, NUM_THREADS) \ + __CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, false, 8, NUM_THREADS) \ + \ + __CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, false, -1, NUM_THREADS) \ + __CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, false, 2, NUM_THREADS) \ + __CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, false, 4, NUM_THREADS) \ + __CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, false, 8, NUM_THREADS) + + #define AWQ_CALL_IF(NUM_BITS, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + __CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS) \ + __CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS) \ + __CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS) \ + __CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS) \ + \ + __CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS) \ + __CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS) \ + __CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS) \ + __CALL_IF(NUM_BITS, 2, 
N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS) \ + \ + __CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS) \ + __CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS) \ + __CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS) \ + __CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS) \ + \ + __CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS) \ + __CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS) \ + __CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS) \ + __CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS) template -void marlin_mm_f16i4(const void* A, const void* B, void* C, void* s, +void marlin_mm_f16i4(const void* A, const void* B, void* C, void* s, void* zp, void* g_idx, void* perm, void* a_tmp, int prob_m, int prob_n, int prob_k, void* workspace, int num_bits, - bool has_act_order, bool is_k_full, int num_groups, - int group_size, int dev, cudaStream_t stream, int thread_k, - int thread_n, int sms, int max_par) { + bool has_act_order, bool is_k_full, bool has_zp, + int num_groups, int group_size, int dev, + cudaStream_t stream, int thread_k, int thread_n, int sms, + int max_par) { TORCH_CHECK(num_bits == 4 || num_bits == 8, "num_bits must be 4 or 8. Got = ", num_bits); TORCH_CHECK(prob_m > 0 && prob_n > 0 && prob_k > 0, "Invalid MNK = [", prob_m, @@ -1665,6 +1971,7 @@ void marlin_mm_f16i4(const void* A, const void* B, void* C, void* s, const int4* B_ptr = (const int4*)B; int4* C_ptr = (int4*)C; const int4* s_ptr = (const int4*)s; + const int4* zp_ptr = (const int4*)zp; const int* g_idx_ptr = (const int*)g_idx; const int* perm_ptr = (const int*)perm; int4* a_tmp_ptr = (int4*)a_tmp; @@ -1701,28 +2008,33 @@ void marlin_mm_f16i4(const void* A, const void* B, void* C, void* s, thread_m_blocks = exec_cfg.max_m_blocks; } - // Define kernel configurations if (false) { } - CALL_IF(4, 32, 2, 256) - CALL_IF(4, 16, 4, 256) - CALL_IF(4, 8, 8, 256) - CALL_IF(4, 8, 4, 128) - CALL_IF(4, 4, 8, 128) - CALL_IF(8, 32, 2, 256) - CALL_IF(8, 16, 4, 256) - CALL_IF(8, 8, 8, 256) - CALL_IF(8, 8, 4, 128) - CALL_IF(8, 4, 8, 128) + GPTQ_CALL_IF(4, 16, 4, 256) + GPTQ_CALL_IF(4, 8, 8, 256) + GPTQ_CALL_IF(4, 8, 4, 128) + GPTQ_CALL_IF(4, 4, 8, 128) + GPTQ_CALL_IF(8, 16, 4, 256) + GPTQ_CALL_IF(8, 8, 8, 256) + GPTQ_CALL_IF(8, 8, 4, 128) + GPTQ_CALL_IF(8, 4, 8, 128) + + AWQ_CALL_IF(4, 16, 4, 256) + AWQ_CALL_IF(4, 8, 8, 256) + AWQ_CALL_IF(4, 8, 4, 128) + AWQ_CALL_IF(4, 4, 8, 128) + AWQ_CALL_IF(8, 16, 4, 256) + AWQ_CALL_IF(8, 8, 8, 256) + AWQ_CALL_IF(8, 8, 4, 128) + AWQ_CALL_IF(8, 4, 8, 128) else { - TORCH_CHECK(false, "Unsupported shapes: MNK = [" + str(prob_m) + ", " + - str(prob_n) + ", " + str(prob_k) + "]" + - ", has_act_order = " + str(has_act_order) + - ", num_groups = " + str(num_groups) + - ", group_size = " + str(group_size) + - ", thread_m_blocks = " + str(thread_m_blocks) + - ", thread_n_blocks = " + str(thread_n_blocks) + - ", thread_k_blocks = " + str(thread_k_blocks)); + TORCH_CHECK(false, "Unsupported shapes: MNK = [", prob_m, ", ", prob_n, + ", ", prob_k, "]", ", has_act_order = ", has_act_order, + ", num_groups = ", num_groups, ", group_size = ", group_size, + ", thread_m_blocks = ", thread_m_blocks, + ", thread_n_blocks = ", thread_n_blocks, + ", thread_k_blocks = ", thread_k_blocks, + ", num_bits = ", num_bits); } A_ptr += 16 * thread_m_blocks * (prob_k / 8) * par; @@ -1733,10 +2045,11 @@ void marlin_mm_f16i4(const void* A, const void* B, void* C, 
void* s, } // namespace gptq_marlin torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, - torch::Tensor& b_scales, torch::Tensor& g_idx, - torch::Tensor& perm, torch::Tensor& workspace, - int64_t num_bits, int64_t size_m, int64_t size_n, - int64_t size_k, bool is_k_full) { + torch::Tensor& b_scales, torch::Tensor& b_zeros, + torch::Tensor& g_idx, torch::Tensor& perm, + torch::Tensor& workspace, int64_t num_bits, + int64_t size_m, int64_t size_n, int64_t size_k, + bool is_k_full, bool has_zp) { // Verify num_bits TORCH_CHECK(num_bits == 4 || num_bits == 8, "num_bits must be 4 or 8. Got = ", num_bits); @@ -1749,16 +2062,15 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, ", size_k = ", size_k); // Verify B - TORCH_CHECK(size_k % gptq_marlin::tile_size == 0, "size_k = ", size_k, - " is not divisible by tile_size = ", gptq_marlin::tile_size); - TORCH_CHECK((size_k / gptq_marlin::tile_size) == b_q_weight.size(0), + TORCH_CHECK(size_k % marlin::tile_size == 0, "size_k = ", size_k, + " is not divisible by tile_size = ", marlin::tile_size); + TORCH_CHECK((size_k / marlin::tile_size) == b_q_weight.size(0), "Shape mismatch: b_q_weight.size(0) = ", b_q_weight.size(0), - ", size_k = ", size_k, ", tile_size = ", gptq_marlin::tile_size); - TORCH_CHECK(b_q_weight.size(1) % gptq_marlin::tile_size == 0, + ", size_k = ", size_k, ", tile_size = ", marlin::tile_size); + TORCH_CHECK(b_q_weight.size(1) % marlin::tile_size == 0, "b_q_weight.size(1) = ", b_q_weight.size(1), - " is not divisible by tile_size = ", gptq_marlin::tile_size); - int actual_size_n = - (b_q_weight.size(1) / gptq_marlin::tile_size) * pack_factor; + " is not divisible by tile_size = ", marlin::tile_size); + int actual_size_n = (b_q_weight.size(1) / marlin::tile_size) * pack_factor; TORCH_CHECK(size_n == actual_size_n, "size_n = ", size_n, ", actual_size_n = ", actual_size_n); @@ -1772,6 +2084,9 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, TORCH_CHECK(b_scales.device().is_cuda(), "b_scales is not on GPU"); TORCH_CHECK(b_scales.is_contiguous(), "b_scales is not contiguous"); + TORCH_CHECK(b_zeros.device().is_cuda(), "b_zeros is not on GPU"); + TORCH_CHECK(b_zeros.is_contiguous(), "b_zeros is not contiguous"); + TORCH_CHECK(g_idx.device().is_cuda(), "g_idx is not on GPU"); TORCH_CHECK(g_idx.is_contiguous(), "g_idx is not contiguous"); @@ -1805,8 +2120,8 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, int group_size = -1; bool has_act_order = g_idx.size(0) != 0; - int b_rank = b_scales.sizes().size(); - TORCH_CHECK(b_rank == 2, "b_scales rank = ", b_rank, " is not 2"); + int rank = b_scales.sizes().size(); + TORCH_CHECK(rank == 2, "b_scales rank = ", rank, " is not 2"); TORCH_CHECK(b_scales.size(1) == size_n, "b_scales dim 1 = ", b_scales.size(1), " is not size_n = ", size_n); num_groups = b_scales.size(0); @@ -1832,34 +2147,44 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, } } + // Verify b_zeros + if (has_zp) { + int rank = b_zeros.sizes().size(); + TORCH_CHECK(rank == 2, "b_zeros rank = ", rank, " is not 2"); + TORCH_CHECK(b_zeros.size(0) == num_groups, + "b_zeros dim 0 = ", b_zeros.size(0), + " is not num_groups = ", num_groups); + TORCH_CHECK(b_zeros.size(1) == size_n / pack_factor, + "b_zeros dim 1 = ", b_scales.size(1), + " is not size_n / pack_factor = ", size_n / pack_factor); + } + // Verify workspace size - TORCH_CHECK( - size_n % gptq_marlin::min_thread_n == 0, "size_n = ", size_n, - 
", is not divisible by min_thread_n = ", gptq_marlin::min_thread_n); - int min_workspace_size = - (size_n / gptq_marlin::min_thread_n) * gptq_marlin::max_par; + TORCH_CHECK(size_n % marlin::min_thread_n == 0, "size_n = ", size_n, + ", is not divisible by min_thread_n = ", marlin::min_thread_n); + int min_workspace_size = (size_n / marlin::min_thread_n) * marlin::max_par; TORCH_CHECK(workspace.numel() >= min_workspace_size, "workspace.numel = ", workspace.numel(), " is below min_workspace_size = ", min_workspace_size); int dev = a.get_device(); if (a.scalar_type() == at::ScalarType::Half) { - gptq_marlin::marlin_mm_f16i4( + marlin::marlin_mm_f16i4( a.data_ptr(), b_q_weight.data_ptr(), c.data_ptr(), - b_scales.data_ptr(), g_idx.data_ptr(), perm.data_ptr(), - a_tmp.data_ptr(), size_m, size_n, size_k, - workspace.data_ptr(), num_bits, has_act_order, is_k_full, num_groups, - group_size, dev, at::cuda::getCurrentCUDAStream(dev), thread_k, - thread_n, sms, gptq_marlin::max_par); + b_scales.data_ptr(), b_zeros.data_ptr(), g_idx.data_ptr(), + perm.data_ptr(), a_tmp.data_ptr(), size_m, size_n, size_k, + workspace.data_ptr(), num_bits, has_act_order, is_k_full, has_zp, + num_groups, group_size, dev, at::cuda::getCurrentCUDAStream(dev), + thread_k, thread_n, sms, marlin::max_par); } else if (a.scalar_type() == at::ScalarType::BFloat16) { - gptq_marlin::marlin_mm_f16i4( + marlin::marlin_mm_f16i4( a.data_ptr(), b_q_weight.data_ptr(), c.data_ptr(), b_scales.data_ptr(), - g_idx.data_ptr(), perm.data_ptr(), a_tmp.data_ptr(), - size_m, size_n, size_k, workspace.data_ptr(), num_bits, has_act_order, - is_k_full, num_groups, group_size, dev, - at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, sms, - gptq_marlin::max_par); + b_zeros.data_ptr(), g_idx.data_ptr(), perm.data_ptr(), + a_tmp.data_ptr(), size_m, size_n, size_k, + workspace.data_ptr(), num_bits, has_act_order, is_k_full, has_zp, + num_groups, group_size, dev, at::cuda::getCurrentCUDAStream(dev), + thread_k, thread_n, sms, marlin::max_par); } else { TORCH_CHECK(false, "gpt_marlin_gemm only supports bfloat16 and float16"); } diff --git a/server/marlin/marlin_kernels/gptq_marlin_repack.cu b/server/marlin/marlin_kernels/gptq_marlin_repack.cu index 4adc158e..c71b1bf5 100644 --- a/server/marlin/marlin_kernels/gptq_marlin_repack.cu +++ b/server/marlin/marlin_kernels/gptq_marlin_repack.cu @@ -1,23 +1,16 @@ -#include "gptq_marlin.cuh" - -namespace gptq_marlin { - -static constexpr int repack_stages = 8; - -static constexpr int repack_threads = 256; - -static constexpr int tile_k_size = tile_size; -static constexpr int tile_n_size = tile_k_size * 4; +#include "marlin.cuh" #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 +namespace marlin { + template -__global__ void marlin_repack_kernel( +__global__ void gptq_marlin_repack_kernel( uint32_t const* __restrict__ b_q_weight_ptr, uint32_t const* __restrict__ perm_ptr, uint32_t* __restrict__ out_ptr, int size_k, int size_n) {} -} // namespace gptq_marlin +} // namespace marlin torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm, int64_t size_k, int64_t size_n, @@ -29,8 +22,10 @@ torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm, #else +namespace marlin { + template -__global__ void marlin_repack_kernel( +__global__ void gptq_marlin_repack_kernel( uint32_t const* __restrict__ b_q_weight_ptr, uint32_t const* __restrict__ perm_ptr, uint32_t* __restrict__ out_ptr, int size_k, int size_n) { @@ -259,28 +254,28 @@ __global__ void marlin_repack_kernel( } } -} 
// namespace gptq_marlin +} // namespace marlin - #define CALL_IF(NUM_BITS, HAS_PERM) \ - else if (num_bits == NUM_BITS && has_perm == HAS_PERM) { \ - cudaFuncSetAttribute( \ - gptq_marlin::marlin_repack_kernel, \ - cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \ - gptq_marlin::marlin_repack_kernel \ - <<>>( \ - b_q_weight_ptr, perm_ptr, out_ptr, size_k, size_n); \ + #define CALL_IF(NUM_BITS, HAS_PERM) \ + else if (num_bits == NUM_BITS && has_perm == HAS_PERM) { \ + cudaFuncSetAttribute( \ + marlin::gptq_marlin_repack_kernel, \ + cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \ + marlin::gptq_marlin_repack_kernel \ + <<>>( \ + b_q_weight_ptr, perm_ptr, out_ptr, size_k, size_n); \ } torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm, int64_t size_k, int64_t size_n, int64_t num_bits) { // Verify compatibility with marlin tile of 16x64 - TORCH_CHECK(size_k % gptq_marlin::tile_k_size == 0, "size_k = ", size_k, - " is not divisible by tile_k_size = ", gptq_marlin::tile_k_size); - TORCH_CHECK(size_n % gptq_marlin::tile_n_size == 0, "size_n = ", size_n, - " is not divisible by tile_n_size = ", gptq_marlin::tile_n_size); + TORCH_CHECK(size_k % marlin::tile_k_size == 0, "size_k = ", size_k, + " is not divisible by tile_k_size = ", marlin::tile_k_size); + TORCH_CHECK(size_n % marlin::tile_n_size == 0, "size_n = ", size_n, + " is not divisible by tile_n_size = ", marlin::tile_n_size); TORCH_CHECK(num_bits == 4 || num_bits == 8, "num_bits must be 4 or 8. Got = ", num_bits); @@ -308,10 +303,9 @@ torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm, auto options = torch::TensorOptions() .dtype(b_q_weight.dtype()) .device(b_q_weight.device()); - torch::Tensor out = - torch::empty({size_k / gptq_marlin::tile_size, - size_n * gptq_marlin::tile_size / pack_factor}, - options); + torch::Tensor out = torch::empty( + {size_k / marlin::tile_size, size_n * marlin::tile_size / pack_factor}, + options); // Detect if there is act_order bool has_perm = perm.size(0) != 0; diff --git a/server/marlin/marlin_kernels/gptq_marlin.cuh b/server/marlin/marlin_kernels/marlin.cuh similarity index 88% rename from server/marlin/marlin_kernels/gptq_marlin.cuh rename to server/marlin/marlin_kernels/marlin.cuh index 42af4495..74ccbac5 100644 --- a/server/marlin/marlin_kernels/gptq_marlin.cuh +++ b/server/marlin/marlin_kernels/marlin.cuh @@ -9,7 +9,9 @@ #include #include -namespace gptq_marlin { +namespace marlin { + +// Marlin params // 8 warps are a good choice since every SM has 4 schedulers and having more // than 1 warp per schedule allows some more latency hiding. 
At the same time, @@ -25,6 +27,15 @@ static constexpr int min_thread_k = 64; static constexpr int tile_size = 16; static constexpr int max_par = 16; +// Repack params +static constexpr int repack_stages = 8; + +static constexpr int repack_threads = 256; + +static constexpr int tile_k_size = tile_size; +static constexpr int tile_n_size = tile_k_size * 4; + +// Helpers template struct Vec { T elems[n]; @@ -73,4 +84,4 @@ __device__ inline void cp_async_wait() { #endif -} // namespace gptq_marlin +} // namespace marlin diff --git a/server/marlin/marlin_kernels/marlin_cuda_kernel.cu b/server/marlin/marlin_kernels/marlin_cuda_kernel.cu index d124c014..37339b84 100644 --- a/server/marlin/marlin_kernels/marlin_cuda_kernel.cu +++ b/server/marlin/marlin_kernels/marlin_cuda_kernel.cu @@ -30,7 +30,7 @@ inline std::string str(T x) { return std::to_string(x); } -namespace marlin { +namespace marlin_dense { constexpr int ceildiv(int a, int b) { return (a + b - 1) / b; } @@ -1040,7 +1040,7 @@ void marlin_cuda(const void* A, const void* B, void* C, void* s, int prob_m, } } -} // namespace marlin +} // namespace marlin_dense torch::Tensor marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, torch::Tensor& b_scales, torch::Tensor& workspace, @@ -1054,24 +1054,25 @@ torch::Tensor marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, TORCH_CHECK(size_k == a.size(1), "Shape mismatch: a.size(1) = " + str(a.size(1)) + ", size_k = " + str(size_k)); - TORCH_CHECK(size_k % marlin::tile_size == 0, - "size_k = " + str(size_k) + - " is not divisible by tile_size = " + str(marlin::tile_size)); - TORCH_CHECK((size_k / marlin::tile_size) == b_q_weight.size(0), + TORCH_CHECK(size_k % marlin_dense::tile_size == 0, + "size_k = " + str(size_k) + " is not divisible by tile_size = " + + str(marlin_dense::tile_size)); + TORCH_CHECK((size_k / marlin_dense::tile_size) == b_q_weight.size(0), "Shape mismatch: b_q_weight.size(0) = " + str(b_q_weight.size(0)) + ", size_k = " + str(size_k) + - ", tile_size = " + str(marlin::tile_size)); + ", tile_size = " + str(marlin_dense::tile_size)); // Verify N TORCH_CHECK(b_scales.size(1) == size_n, "b_scales.size(1) = " + str(b_scales.size(1)) + ", size_n = " + str(size_n)); - TORCH_CHECK(b_q_weight.size(1) % marlin::tile_size == 0, - "b_q_weight.size(1) = " + str(b_q_weight.size(1)) + - " is not divisible by tile_size = " + str(marlin::tile_size)); + TORCH_CHECK( + b_q_weight.size(1) % marlin_dense::tile_size == 0, + "b_q_weight.size(1) = " + str(b_q_weight.size(1)) + + " is not divisible by tile_size = " + str(marlin_dense::tile_size)); - int actual_size_n = - (b_q_weight.size(1) / marlin::tile_size) * marlin::pack_factor_4bit; + int actual_size_n = (b_q_weight.size(1) / marlin_dense::tile_size) * + marlin_dense::pack_factor_4bit; TORCH_CHECK( size_n == actual_size_n, "size_n = " + str(size_n) + ", actual_size_n = " + str(actual_size_n)); @@ -1116,21 +1117,22 @@ torch::Tensor marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, "Unexpected groupsize = " + str(groupsize)); // Verify workspace size - TORCH_CHECK( - size_n % marlin::min_thread_n == 0, - "size_n = " + str(size_n) + - ", is not divisible by min_thread_n = " + str(marlin::min_thread_n)); - int min_workspace_size = (size_n / marlin::min_thread_n) * marlin::max_par; + TORCH_CHECK(size_n % marlin_dense::min_thread_n == 0, + "size_n = " + str(size_n) + + ", is not divisible by min_thread_n = " + + str(marlin_dense::min_thread_n)); + int min_workspace_size = + (size_n / marlin_dense::min_thread_n) * marlin_dense::max_par; 
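+  // One synchronization slot per `min_thread_n`-wide column slice, for each
+  // of up to `max_par` batched sub-problems; Marlin expects this buffer to
+  // be zero-initialized by the caller.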
TORCH_CHECK(workspace.numel() >= min_workspace_size, "workspace.numel = " + str(workspace.numel()) + " is below min_workspace_size = " + str(min_workspace_size)); int dev = a.get_device(); - marlin::marlin_cuda(a.data_ptr(), b_q_weight.data_ptr(), c.data_ptr(), - b_scales.data_ptr(), size_m, size_n, size_k, - workspace.data_ptr(), groupsize, dev, - at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, - sms, marlin::max_par); + marlin_dense::marlin_cuda(a.data_ptr(), b_q_weight.data_ptr(), c.data_ptr(), + b_scales.data_ptr(), size_m, size_n, size_k, + workspace.data_ptr(), groupsize, dev, + at::cuda::getCurrentCUDAStream(dev), thread_k, + thread_n, sms, marlin_dense::max_par); return c; } diff --git a/server/marlin/marlin_kernels/gptq_marlin_dtypes.cuh b/server/marlin/marlin_kernels/marlin_dtypes.cuh similarity index 93% rename from server/marlin/marlin_kernels/gptq_marlin_dtypes.cuh rename to server/marlin/marlin_kernels/marlin_dtypes.cuh index ca1b7099..be06c09b 100644 --- a/server/marlin/marlin_kernels/gptq_marlin_dtypes.cuh +++ b/server/marlin/marlin_kernels/marlin_dtypes.cuh @@ -1,11 +1,11 @@ #ifndef _data_types_cuh #define _data_types_cuh -#include "gptq_marlin.cuh" +#include "marlin.cuh" #include #include -namespace gptq_marlin { +namespace marlin { template class ScalarType {}; @@ -23,6 +23,7 @@ class ScalarType { using FragB = Vec; using FragC = Vec; using FragS = Vec; + using FragZP = Vec; static __device__ float inline num2float(const half x) { return __half2float(x); @@ -51,6 +52,7 @@ class ScalarType { using FragB = Vec; using FragC = Vec; using FragS = Vec; + using FragZP = Vec; #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 static __device__ float inline num2float(const nv_bfloat16 x) { @@ -72,6 +74,6 @@ class ScalarType { #endif }; -} // namespace gptq_marlin +} // namespace marlin #endif diff --git a/server/marlin/setup.py b/server/marlin/setup.py index cc38bccf..f92eef04 100644 --- a/server/marlin/setup.py +++ b/server/marlin/setup.py @@ -9,6 +9,7 @@ setup( CUDAExtension( name="marlin_kernels", sources=[ + "marlin_kernels/awq_marlin_repack.cu", "marlin_kernels/fp8_marlin.cu", "marlin_kernels/gptq_marlin.cu", "marlin_kernels/gptq_marlin_repack.cu", diff --git a/server/text_generation_server/layers/gptq/__init__.py b/server/text_generation_server/layers/gptq/__init__.py index 6feca275..6d7495b7 100644 --- a/server/text_generation_server/layers/gptq/__init__.py +++ b/server/text_generation_server/layers/gptq/__init__.py @@ -156,16 +156,26 @@ class GPTQWeightsLoader(WeightsLoader): f"Cannot load `{self.quantize}` weight for GPTQ -> Marlin repacking, make sure the model is already quantized" ) - g_idx = weights.get_tensor(f"{prefix}.g_idx") + if not self.sym: + qzeros = weights.get_tensor(f"{prefix}.qzeros") + else: + qzeros = None + + if self.quant_method == "awq": + g_idx = None + else: + g_idx = weights.get_tensor(f"{prefix}.g_idx") scales = weights.get_tensor(f"{prefix}.scales") return repack_gptq_for_marlin( qweight=qweight, scales=scales, + qzeros=qzeros, g_idx=g_idx, bits=self.bits, desc_act=self.desc_act, groupsize=self.groupsize, + quant_method=self.quant_method, sym=self.sym, sharded_infeatures=False, ) @@ -275,14 +285,26 @@ class GPTQWeightsLoader(WeightsLoader): quantize=self.quantize, sym=self.sym, ): - g_idx = weights.get_tensor(f"{prefix}.g_idx") + if not self.sym: + qzeros = weights.get_packed_sharded( + f"{prefix}.qzeros", dim=1, block_sizes=block_sizes + ) + else: + qzeros = None + + if self.quant_method == "awq": + g_idx = None + else: + g_idx = 
weights.get_tensor(f"{prefix}.g_idx") return repack_gptq_for_marlin( qweight=qweight, scales=scales, + qzeros=qzeros, g_idx=g_idx, bits=self.bits, desc_act=self.desc_act, groupsize=self.groupsize, + quant_method=self.quant_method, sym=self.sym, sharded_infeatures=False, ) @@ -349,18 +371,31 @@ class GPTQWeightsLoader(WeightsLoader): quantize=self.quantize, sym=self.sym, ): - w = [weights.get_tensor(f"{p}.g_idx") for p in prefixes] - for w2 in w[1:]: - torch.testing.assert_close(w2, w[0]) - g_idx = w[0] + + if not self.sym: + qzeros = torch.cat( + [weights.get_sharded(f"{p}.qzeros", dim=1) for p in prefixes], dim=1 + ) + else: + qzeros = None + + if self.quant_method == "awq": + g_idx = None + else: + w = [weights.get_tensor(f"{p}.g_idx") for p in prefixes] + for w2 in w[1:]: + torch.testing.assert_close(w2, w[0]) + g_idx = w[0] return repack_gptq_for_marlin( qweight=qweight, scales=scales, + qzeros=qzeros, g_idx=g_idx, bits=self.bits, desc_act=self.desc_act, groupsize=self.groupsize, + quant_method=self.quant_method, sym=self.sym, sharded_infeatures=False, ) @@ -438,7 +473,19 @@ class GPTQWeightsLoader(WeightsLoader): f"Cannot load `{self.quantize}` weight for GPTQ -> Marlin repacking, make sure the model is already quantized" ) - g_idx = weights.get_sharded(f"{prefix}.g_idx", dim=0) + if not self.sym: + if self.desc_act or self.groupsize == -1: + qzeros = weights.get_tensor(f"{prefix}.qzeros") + else: + qzeros = weights.get_sharded(f"{prefix}.qzeros", dim=0) + else: + qzeros = None + + if self.quant_method == "awq": + g_idx = None + else: + g_idx = weights.get_sharded(f"{prefix}.g_idx", dim=0) + if self.desc_act or self.groupsize == -1: scales = weights.get_tensor(f"{prefix}.scales") else: @@ -449,10 +496,12 @@ class GPTQWeightsLoader(WeightsLoader): return repack_gptq_for_marlin( qweight=qweight, scales=scales, + qzeros=qzeros, g_idx=g_idx, bits=self.bits, desc_act=self.desc_act, groupsize=self.groupsize, + quant_method=self.quant_method, sym=self.sym, sharded_infeatures=sharded_in_features, ) diff --git a/server/text_generation_server/layers/marlin.py b/server/text_generation_server/layers/marlin.py index a28012da..cf622cc2 100644 --- a/server/text_generation_server/layers/marlin.py +++ b/server/text_generation_server/layers/marlin.py @@ -1,6 +1,7 @@ from dataclasses import dataclass from typing import List, Optional, Tuple, Union +import numpy import torch import torch.nn as nn from loguru import logger @@ -174,11 +175,12 @@ def can_use_gptq_marlin( SYSTEM == "cuda" and marlin_kernels is not None and has_sm_8_0 - and quantize == "gptq" - and quant_method == "gptq" + and quantize in {"awq", "gptq"} + and quant_method in {"awq", "gptq"} and bits in GPTQ_MARLIN_BITS and groupsize in GPTQ_MARLIN_GROUP_SIZES - and sym + # We only suppord asymmetric quantization for AWQ. 
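+        # Virtually all AWQ checkpoints carry zero-points (zero_point=True),
+        # which the repacked Marlin kernels handle through `has_zp`; GPTQ
+        # checkpoints still have to be symmetrically quantized.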
@@ -234,6 +236,7 @@ class GPTQMarlinWeight(Weight):
     """

     qweight: torch.Tensor
+    qzeros: torch.Tensor
     scales: torch.Tensor
     g_idx: torch.Tensor
     perm: torch.Tensor
@@ -256,11 +259,13 @@
 def repack_gptq_for_marlin(
     *,
     qweight: torch.Tensor,
+    qzeros: Optional[torch.Tensor],
     scales: torch.Tensor,
-    g_idx: torch.Tensor,
+    g_idx: Optional[torch.Tensor],
     bits: int,
     desc_act: bool,
     groupsize: int,
+    quant_method: str,
     sym: bool,
     sharded_infeatures: bool,
 ) -> GPTQMarlinWeight:
@@ -279,30 +284,54 @@ def repack_gptq_for_marlin(
         raise RuntimeError(
             f"Repacking GPTQ weights with group size {groupsize} as Marlin is not supported, must be one of: {supported_sizes}"
         )
-    if not sym:
+    if not (sym or quant_method == "awq"):
         raise RuntimeError(
             "Repacking GPTQ weights with asymmetric quantization as Marlin is not supported."
         )
+
+    log_once(logger.info, f"Converting {quant_method} model to Marlin packing format.")
+
     weights_per_int = 32 // bits
-    in_features = qweight.shape[0] * weights_per_int
+    in_features = qweight.shape[0]
     out_features = qweight.shape[1]

+    # AWQ uses column packing, GPTQ uses row packing
+    if quant_method == "awq":
+        out_features *= weights_per_int
+    else:
+        in_features *= weights_per_int
+
     if in_features % groupsize != 0:
         raise ValueError(
             f"Number of input features ({in_features}) not divisible by group size ({groupsize})"
         )

-    if desc_act and groupsize != -1:
+    if g_idx is not None and desc_act and groupsize != -1:
         perm = torch.argsort(g_idx).to(torch.int)
         g_idx = g_idx[perm]
     else:
         perm = torch.empty(0, dtype=torch.int, device=qweight.device)
         g_idx = torch.empty(0, dtype=torch.int, device=qweight.device)

-    repacked = marlin_kernels.gptq_marlin_repack(
-        qweight, perm, in_features, out_features, bits
-    )
+    if quant_method == "awq":
+        repacked = marlin_kernels.awq_marlin_repack(
+            qweight, in_features, out_features, bits
+        )
+        if qzeros is not None:
+            qzeros = awq_to_marlin_zero_points(
+                qzeros,
+                in_features // groupsize,
+                out_features,
+                bits,
+            )
+
+    else:
+        repacked = marlin_kernels.gptq_marlin_repack(
+            qweight, perm, in_features, out_features, bits
+        )
+
+    if qzeros is None:
+        qzeros = torch.empty(0, dtype=torch.int, device=qweight.device)

     scales = permute_scales(scales)
@@ -310,6 +339,7 @@ def repack_gptq_for_marlin(

     return GPTQMarlinWeight(
         qweight=repacked,
+        qzeros=qzeros,
         scales=scales,
         g_idx=g_idx,
         perm=perm,
@@ -343,6 +373,7 @@ class GPTQMarlinLinear(nn.Module):
         self.is_full_k = weight.is_full_k

         self.qweight = weight.qweight
+        self.qzeros = weight.qzeros
         self.scales = weight.scales
         self.g_idx = weight.g_idx
         self.perm = weight.perm
@@ -363,6 +394,7 @@ class GPTQMarlinLinear(nn.Module):
             A_flat,
             self.qweight,
             self.scales,
+            self.qzeros,
             self.g_idx,
             self.perm,
             self.workspace,
@@ -371,6 +403,7 @@ class GPTQMarlinLinear(nn.Module):
             self.scales.shape[1],
             A_flat.shape[1],
             self.is_full_k,
+            self.qzeros.numel() > 0,
         )

         C = C.reshape(A.shape[:-1] + (self.scales.shape[1],))
@@ -688,3 +721,116 @@ class MarlinLinear(nn.Module):
             C += self.bias

         return C
+
+
+# Functions below are from vLLM
+
+
+def get_pack_factor(bits: int) -> int:
+    if 32 % bits != 0:
+        raise ValueError(f"Cannot pack {bits} bit values into uint32")
+    return 32 // bits
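The packing-direction branch in repack_gptq_for_marlin above can be checked with a little shape arithmetic; this is a sketch with made-up layer sizes, not code from the patch. For 4-bit weights, eight values share one int32, so AWQ divides out_features by the pack factor while GPTQ divides in_features.

    bits = 4
    weights_per_int = 32 // bits  # 8 quantized values per int32

    in_features, out_features = 4096, 11008  # made-up layer sizes

    # AWQ packs along columns: qweight is (in_features, out_features // 8).
    awq_shape = (in_features, out_features // weights_per_int)
    # GPTQ packs along rows: qweight is (in_features // 8, out_features).
    gptq_shape = (in_features // weights_per_int, out_features)

    # Recovering the logical sizes the way the function above does:
    assert awq_shape[0] == in_features and awq_shape[1] * weights_per_int == out_features
    assert gptq_shape[0] * weights_per_int == in_features and gptq_shape[1] == out_features
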
+
+
+def pack_cols(
+    q_w: torch.Tensor,
+    num_bits: int,
+    size_k: int,
+    size_n: int,
+):
+    assert q_w.shape == (size_k, size_n)
+
+    pack_factor = get_pack_factor(num_bits)
+    assert size_n % pack_factor == 0
+
+    orig_device = q_w.device
+
+    q_w = q_w.cpu().numpy().astype(numpy.uint32)
+
+    q_res = numpy.zeros((size_k, size_n // pack_factor), dtype=numpy.uint32)
+
+    for i in range(pack_factor):
+        q_res |= q_w[:, i::pack_factor] << num_bits * i
+
+    q_res = torch.from_numpy(q_res.astype(numpy.int32)).to(orig_device)
+    q_res = q_res.contiguous()
+
+    return q_res
+
+
+def unpack_cols(
+    packed_q_w: torch.Tensor,
+    num_bits: int,
+    size_k: int,
+    size_n: int,
+):
+    pack_factor = get_pack_factor(num_bits)
+    assert size_n % pack_factor == 0
+    assert packed_q_w.shape == (
+        size_k,
+        size_n // pack_factor,
+    ), "packed_q_w.shape = {} size_k = {}, size_n = {} pack_factor = {}".format(
+        packed_q_w.shape, size_k, size_n, pack_factor
+    )
+
+    orig_device = packed_q_w.device
+
+    packed_q_w_cpu = packed_q_w.cpu().numpy().astype(numpy.uint32)
+    q_res = numpy.zeros((size_k, size_n), dtype=numpy.uint32)
+
+    mask = (1 << num_bits) - 1
+    for i in range(pack_factor):
+        vals = packed_q_w_cpu & mask
+        packed_q_w_cpu >>= num_bits
+        q_res[:, i::pack_factor] = vals
+
+    q_res = torch.from_numpy(q_res.astype(numpy.int32)).to(orig_device)
+    q_res = q_res.contiguous()
+
+    return q_res
+
+
+def marlin_zero_points(
+    zp: torch.Tensor, size_k: int, size_n: int, num_bits: int
+) -> torch.Tensor:
+    # Permute zero-points in a similar way to scales, but do not use the
+    # "single" permutation, since zero-points are applied on every MMA
+    zp = zp.reshape((-1, len(_scale_perm)))[:, _scale_perm]
+
+    # Interleave column dim (for the dequantize code) and pack it to int32
+    if num_bits == 4:
+        interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7])
+    elif num_bits == 8:
+        interleave = numpy.array([0, 2, 1, 3])
+    else:
+        raise Exception("num_bits must be 4 or 8, got {}".format(num_bits))
+
+    zp = zp.reshape((-1, len(interleave)))[:, interleave].ravel()
+    zp = zp.reshape((-1, size_n)).contiguous()
+    zp = pack_cols(zp, num_bits, size_k, size_n)
+
+    return zp
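Two properties the zero-point conversion below relies on, written out as a hedged sanity check that assumes the pack_cols/unpack_cols helpers above are importable: column packing round-trips for values below 2**num_bits, and numpy.argsort of the interleave pattern is exactly the permutation that undoes it.

    import numpy
    import torch

    num_bits, size_k, size_n = 4, 2, 16
    q = torch.randint(0, 2**num_bits, (size_k, size_n), dtype=torch.int32)

    # pack_cols and unpack_cols are inverses on the column dimension.
    packed = pack_cols(q, num_bits, size_k, size_n)
    assert torch.equal(unpack_cols(packed, num_bits, size_k, size_n), q)

    # argsort of a permutation yields its inverse.
    interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7])
    undo = numpy.argsort(interleave)
    x = numpy.arange(8)
    assert (x[interleave][undo] == x).all()
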
+
+
+def awq_to_marlin_zero_points(
+    q_zp_packed: torch.Tensor, size_k: int, size_n: int, num_bits: int
+) -> torch.Tensor:
+    # AWQ zero-points are quantized and packed on the column dim.
+    # In addition, the values are permuted based on the dequantizer.
+    # Here we undo both of these, and then apply the Marlin permutation
+    # and pack it back.
+    q_zp = unpack_cols(q_zp_packed, num_bits, size_k, size_n)
+
+    # Undo interleaving (use argsort(..) to get inverse perm)
+    if num_bits == 4:
+        undo_interleave = numpy.argsort(numpy.array([0, 2, 4, 6, 1, 3, 5, 7]))
+    elif num_bits == 8:
+        undo_interleave = numpy.argsort(numpy.array([0, 2, 1, 3]))
+    else:
+        raise Exception("num_bits must be 4 or 8, got {}".format(num_bits))
+
+    q_zp = q_zp.reshape((-1, len(undo_interleave)))[:, undo_interleave].ravel()
+    q_zp = q_zp.reshape((-1, size_n)).contiguous()
+
+    marlin_zp = marlin_zero_points(q_zp, size_k, size_n, num_bits)
+    return marlin_zp
diff --git a/server/text_generation_server/utils/quantization.py b/server/text_generation_server/utils/quantization.py
index c3c038fe..98dae079 100644
--- a/server/text_generation_server/utils/quantization.py
+++ b/server/text_generation_server/utils/quantization.py
@@ -34,7 +34,7 @@ def _get_quantizer_config(model_id, revision):
     groupsize = -1
     quant_method = "gptq"
     checkpoint_format = None
-    sym = True
+    sym = False
     desc_act = False

     filename = "config.json"
@@ -52,12 +52,17 @@ def _get_quantizer_config(model_id, revision):
                 activation_scale_ub=data["quantization_config"]["activation_scale_ub"]
             )

+        if "zero_point" in data["quantization_config"]:
+            sym = not data["quantization_config"]["zero_point"]
+            quant_method = "awq"
+        elif "sym" in data["quantization_config"]:
+            sym = data["quantization_config"]["sym"]
+
         bits = data["quantization_config"]["bits"]
         groupsize = data["quantization_config"]["group_size"]
         # Order is important here, desc_act is missing on some real models
         quant_method = data["quantization_config"]["quant_method"]
         checkpoint_format = data["quantization_config"].get("checkpoint_format")
-        sym = data["quantization_config"]["sym"]
         desc_act = data["quantization_config"]["desc_act"]
     except Exception:
         filename = "quantize_config.json"
@@ -72,7 +77,13 @@ def _get_quantizer_config(model_id, revision):
             data = json.load(f)
             bits = data["bits"]
             groupsize = data["group_size"]
-            sym = data["sym"]
+
+            if "zero_point" in data:
+                sym = not data["zero_point"]
+                quant_method = "awq"
+            elif "sym" in data:
+                sym = data["sym"]
+
             desc_act = data["desc_act"]
             if "version" in data and data["version"] == "GEMM":
                 quant_method = "awq"
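For reference, a hedged sketch of what the detection above does with a typical AWQ quantization_config (example values, not taken from a specific model): zero_point=True marks the checkpoint as asymmetric, so sym becomes False and quant_method is set to "awq", which together with the earlier gate enables Marlin repacking by default.

    quantization_config = {
        "bits": 4,
        "group_size": 128,
        "zero_point": True,  # AWQ-style asymmetric quantization
        "version": "GEMM",
    }

    # Mirrors the branch added above: zero_point=True means asymmetric, i.e. sym=False.
    if "zero_point" in quantization_config:
        sym = not quantization_config["zero_point"]
        quant_method = "awq"
    elif "sym" in quantization_config:
        sym = quantization_config["sym"]

    assert sym is False and quant_method == "awq"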