Add support for repacking AWQ weights for GPTQ-Marlin (#2278)

* Add support for repacking AWQ weights for GPTQ-Marlin

So far we couldn't support AWQ because virtually all AWQ models use
asymmetric quantization, which GPTQ-Marlin did not support. GPTQ-Marlin
has recently added support for AWQ repacking and AWQ asymmetric quantization
(zero_point=True).

This change updates all GPTQ-Marlin kernels from upstream and wires up
AWQ support. For now, enabling AWQ with Marlin requires running TGI with
`--quantize gptq`.

* Enable Marlin for supported AWQ configurations by default

This makes the AWQ -> GPTQ repack test redundant, since we are now
testing this with the regular AWQ test.
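
For reference, the distinction this change targets: AWQ's zero_point=True format dequantizes a weight as w = (q - z) * s with an explicit per-group zero-point z, while the symmetric GPTQ path Marlin already handled uses a fixed implicit zero (8 for 4-bit values). A minimal host-side sketch of the two schemes, illustrative only and not part of the kernels in this diff:

#include <cstdint>
#include <cstdio>

// Symmetric 4-bit (GPTQ-style): implicit zero-point of 8.
static float dequant_sym_4bit(uint32_t q, float scale) {
  return (static_cast<int>(q & 0xf) - 8) * scale;
}

// Asymmetric 4-bit (AWQ-style, zero_point=True): explicit zero-point z.
static float dequant_asym_4bit(uint32_t q, uint32_t z, float scale) {
  return (static_cast<int>(q & 0xf) - static_cast<int>(z & 0xf)) * scale;
}

int main() {
  // q = 11, z = 3, s = 0.5 -> (11 - 3) * 0.5 = 4.0
  std::printf("%f\n", dequant_asym_4bit(11, 3, 0.5f));
  return 0;
}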
Daniël de Kok 2024-07-23 13:08:20 +02:00 committed by GitHub
parent 5fca30ee15
commit 9935720c87
13 changed files with 1016 additions and 203 deletions

View File

@@ -0,0 +1,269 @@
#include "marlin.cuh"
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
namespace marlin {
template <int const num_threads, int const num_bits, bool const has_perm>
__global__ void awq_marlin_repack_kernel(
uint32_t const* __restrict__ b_q_weight_ptr, uint32_t* __restrict__ out_ptr,
int size_k, int size_n) {}
} // namespace marlin
torch::Tensor awq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm,
int64_t size_k, int64_t size_n,
int64_t num_bits) {
TORCH_CHECK_NOT_IMPLEMENTED(
false, "marlin_repack_from_gptq(..) requires CUDA_ARCH >= 8.0");
return torch::empty({1, 1});
}
#else
namespace marlin {
template <int const num_threads, int const num_bits>
__global__ void awq_marlin_repack_kernel(
uint32_t const* __restrict__ b_q_weight_ptr, uint32_t* __restrict__ out_ptr,
int size_k, int size_n) {
constexpr int pack_factor = 32 / num_bits;
int k_tiles = size_k / tile_k_size;
int n_tiles = size_n / tile_n_size;
int block_k_tiles = div_ceil(k_tiles, gridDim.x);
int start_k_tile = blockIdx.x * block_k_tiles;
if (start_k_tile >= k_tiles) {
return;
}
int finish_k_tile = min(start_k_tile + block_k_tiles, k_tiles);
// Wait until the next thread tile has been loaded to shared memory.
auto wait_for_stage = [&]() {
// We only have `stages - 2` active fetches since we are double buffering
// and can only issue the next fetch when it is guaranteed that the previous
// shared memory load is fully complete (as it may otherwise be
// overwritten).
cp_async_wait<repack_stages - 2>();
__syncthreads();
};
extern __shared__ int4 sh[];
constexpr int tile_n_ints = tile_n_size / pack_factor;
constexpr int stage_n_threads = tile_n_ints / 4;
constexpr int stage_k_threads = tile_k_size;
constexpr int stage_size = stage_k_threads * stage_n_threads;
auto fetch_to_shared = [&](int pipe, int k_tile_id, int n_tile_id) {
if (n_tile_id >= n_tiles) {
cp_async_fence();
return;
}
int first_n = n_tile_id * tile_n_size;
int first_n_packed = first_n / pack_factor;
int4* sh_ptr = sh + stage_size * pipe;
if (threadIdx.x < stage_size) {
int k_id = threadIdx.x / stage_n_threads;
int n_id = threadIdx.x % stage_n_threads;
int first_k = k_tile_id * tile_k_size;
cp_async4(&sh_ptr[k_id * stage_n_threads + n_id],
reinterpret_cast<int4 const*>(
&(b_q_weight_ptr[(first_k + k_id) * (size_n / pack_factor) +
first_n_packed + (n_id * 4)])));
}
cp_async_fence();
};
auto repack_tile = [&](int pipe, int k_tile_id, int n_tile_id) {
if (n_tile_id >= n_tiles) {
return;
}
int warp_id = threadIdx.x / 32;
int th_id = threadIdx.x % 32;
if (warp_id >= 4) {
return;
}
int tc_col = th_id / 4;
int tc_row = (th_id % 4) * 2;
constexpr int tc_offsets[4] = {0, 1, 8, 9};
int cur_n = warp_id * 16 + tc_col;
int cur_n_packed = cur_n / pack_factor;
int cur_n_pos = cur_n % pack_factor;
constexpr int sh_stride = tile_n_ints;
constexpr uint32_t mask = (1 << num_bits) - 1;
int4* sh_stage_ptr = sh + stage_size * pipe;
uint32_t* sh_stage_int_ptr = reinterpret_cast<uint32_t*>(sh_stage_ptr);
// Undo interleaving
int cur_n_pos_unpacked;
if constexpr (num_bits == 4) {
constexpr int undo_pack[8] = {0, 4, 1, 5, 2, 6, 3, 7};
cur_n_pos_unpacked = undo_pack[cur_n_pos];
} else {
constexpr int undo_pack[4] = {0, 2, 1, 3};
cur_n_pos_unpacked = undo_pack[cur_n_pos];
}
uint32_t vals[8];
#pragma unroll
for (int i = 0; i < 4; i++) {
int cur_elem = tc_row + tc_offsets[i];
int packed_src_0 = sh_stage_int_ptr[cur_n_packed + sh_stride * cur_elem];
int packed_src_1 = sh_stage_int_ptr[cur_n_packed + (8 / pack_factor) +
sh_stride * cur_elem];
vals[i] = (packed_src_0 >> (cur_n_pos_unpacked * num_bits)) & mask;
vals[4 + i] = (packed_src_1 >> (cur_n_pos_unpacked * num_bits)) & mask;
}
constexpr int tile_size = tile_k_size * tile_n_size / pack_factor;
int out_offset = (k_tile_id * n_tiles + n_tile_id) * tile_size;
// Result of:
// https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h
if constexpr (num_bits == 4) {
constexpr int pack_idx[8] = {0, 2, 4, 6, 1, 3, 5, 7};
uint32_t res = 0;
#pragma unroll
for (int i = 0; i < 8; i++) {
res |= vals[pack_idx[i]] << (i * 4);
}
out_ptr[out_offset + th_id * 4 + warp_id] = res;
} else {
constexpr int pack_idx[4] = {0, 2, 1, 3};
uint32_t res1 = 0;
uint32_t res2 = 0;
#pragma unroll
for (int i = 0; i < 4; i++) {
res1 |= vals[pack_idx[i]] << (i * 8);
res2 |= vals[4 + pack_idx[i]] << (i * 8);
}
out_ptr[out_offset + th_id * 8 + (warp_id * 2) + 0] = res1;
out_ptr[out_offset + th_id * 8 + (warp_id * 2) + 1] = res2;
}
};
auto start_pipes = [&](int k_tile_id, int n_tile_id) {
#pragma unroll
for (int pipe = 0; pipe < repack_stages - 1; pipe++) {
fetch_to_shared(pipe, k_tile_id, n_tile_id + pipe);
}
wait_for_stage();
};
#pragma unroll
for (int k_tile_id = start_k_tile; k_tile_id < finish_k_tile; k_tile_id++) {
int n_tile_id = 0;
start_pipes(k_tile_id, n_tile_id);
while (n_tile_id < n_tiles) {
#pragma unroll
for (int pipe = 0; pipe < repack_stages; pipe++) {
fetch_to_shared((pipe + repack_stages - 1) % repack_stages, k_tile_id,
n_tile_id + pipe + repack_stages - 1);
repack_tile(pipe, k_tile_id, n_tile_id + pipe);
wait_for_stage();
}
n_tile_id += repack_stages;
}
}
}
} // namespace marlin
#define CALL_IF(NUM_BITS) \
else if (num_bits == NUM_BITS) { \
cudaFuncSetAttribute( \
marlin::awq_marlin_repack_kernel<marlin::repack_threads, NUM_BITS>, \
cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \
marlin::awq_marlin_repack_kernel<marlin::repack_threads, NUM_BITS> \
<<<blocks, marlin::repack_threads, max_shared_mem, stream>>>( \
b_q_weight_ptr, out_ptr, size_k, size_n); \
}
torch::Tensor awq_marlin_repack(torch::Tensor& b_q_weight, int64_t size_k,
int64_t size_n, int64_t num_bits) {
// Verify compatibility with marlin tile of 16x64
TORCH_CHECK(size_k % marlin::tile_k_size == 0, "size_k = ", size_k,
" is not divisible by tile_k_size = ", marlin::tile_k_size);
TORCH_CHECK(size_n % marlin::tile_n_size == 0, "size_n = ", size_n,
" is not divisible by tile_n_size = ", marlin::tile_n_size);
TORCH_CHECK(num_bits == 4 || num_bits == 8,
"num_bits must be 4 or 8. Got = ", num_bits);
int const pack_factor = 32 / num_bits;
// Verify B
TORCH_CHECK(b_q_weight.size(0) == size_k,
"b_q_weight.size(0) = ", b_q_weight.size(0),
" is not size_k = ", size_k);
TORCH_CHECK((size_n / pack_factor) == b_q_weight.size(1),
"Shape mismatch: b_q_weight.size(1) = ", b_q_weight.size(1),
", size_n = ", size_n, ", pack_factor = ", pack_factor);
// Verify device and strides
TORCH_CHECK(b_q_weight.device().is_cuda(), "b_q_weight is not on GPU");
TORCH_CHECK(b_q_weight.is_contiguous(), "b_q_weight is not contiguous");
TORCH_CHECK(b_q_weight.dtype() == at::kInt, "b_q_weight type is not kInt");
// Alloc buffers
const at::cuda::OptionalCUDAGuard device_guard(device_of(b_q_weight));
auto options = torch::TensorOptions()
.dtype(b_q_weight.dtype())
.device(b_q_weight.device());
torch::Tensor out = torch::empty(
{size_k / marlin::tile_size, size_n * marlin::tile_size / pack_factor},
options);
// Get ptrs
uint32_t const* b_q_weight_ptr =
reinterpret_cast<uint32_t const*>(b_q_weight.data_ptr());
uint32_t* out_ptr = reinterpret_cast<uint32_t*>(out.data_ptr());
// Get dev info
int dev = b_q_weight.get_device();
cudaStream_t stream = at::cuda::getCurrentCUDAStream(dev);
int blocks;
cudaDeviceGetAttribute(&blocks, cudaDevAttrMultiProcessorCount, dev);
int max_shared_mem = 0;
cudaDeviceGetAttribute(&max_shared_mem,
cudaDevAttrMaxSharedMemoryPerBlockOptin, dev);
TORCH_CHECK(max_shared_mem > 0);
if (false) {
}
CALL_IF(4)
CALL_IF(8)
else {
TORCH_CHECK(false, "Unsupported repack config: num_bits = ", num_bits);
}
return out;
}
#endif
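
A hedged usage sketch of the entry point above, illustrative only: it assumes the extension is built with libtorch and that the declaration from the extension header is visible. The input layout follows the checks above: b_q_weight has shape (size_k, size_n / pack_factor), dtype int32, contiguous, and lives on the GPU; the result uses the Marlin tile layout (size_k / 16, size_n * 16 / pack_factor).

#include <torch/torch.h>

// Declaration as exposed by the extension (see ext.hh below).
torch::Tensor awq_marlin_repack(torch::Tensor& b_q_weight, int64_t size_k,
                                int64_t size_n, int64_t num_bits);

// Illustrative helper (hypothetical name): repack an AWQ qweight into the
// Marlin tile layout checked by awq_marlin_repack above.
torch::Tensor repack_awq_example(int64_t size_k, int64_t size_n,
                                 int64_t num_bits) {
  int64_t pack_factor = 32 / num_bits;
  torch::Tensor qweight =
      torch::zeros({size_k, size_n / pack_factor},
                   torch::dtype(torch::kInt).device(torch::kCUDA));
  return awq_marlin_repack(qweight, size_k, size_n, num_bits);
}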

View File

@@ -3,6 +3,8 @@
#include "ext.hh"
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("awq_marlin_repack", &awq_marlin_repack,
"Repack AWQ parameters for Marlin");
m.def("gptq_marlin_gemm", &gptq_marlin_gemm,
"Marlin gemm with GPTQ compatibility");
m.def("gptq_marlin_24_gemm", &gptq_marlin_24_gemm, "Marlin sparse 2:4 gemm");

View File

@@ -6,11 +6,15 @@
// No support for async
#else
torch::Tensor awq_marlin_repack(torch::Tensor &b_q_weight, int64_t size_k,
int64_t size_n, int64_t num_bits);
torch::Tensor gptq_marlin_gemm(torch::Tensor &a, torch::Tensor &b_q_weight,
-torch::Tensor &b_scales, torch::Tensor &g_idx,
-torch::Tensor &perm, torch::Tensor &workspace,
-int64_t num_bits, int64_t size_m, int64_t size_n,
-int64_t size_k, bool is_k_full);
torch::Tensor &b_scales, torch::Tensor &b_zeros,
torch::Tensor &g_idx, torch::Tensor &perm,
torch::Tensor &workspace, int64_t num_bits,
int64_t size_m, int64_t size_n, int64_t size_k,
bool is_k_full, bool has_zp);
torch::Tensor gptq_marlin_24_gemm(torch::Tensor &a, torch::Tensor &b_q_weight,
torch::Tensor &b_meta,
@@ -27,8 +31,8 @@ torch::Tensor marlin_gemm(torch::Tensor &a, torch::Tensor &b_q_weight,
torch::Tensor &b_scales, torch::Tensor &workspace,
int64_t size_m, int64_t size_n, int64_t size_k);
-torch::Tensor fp8_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
-torch::Tensor& b_scales, torch::Tensor& workspace,
torch::Tensor fp8_marlin_gemm(torch::Tensor &a, torch::Tensor &b_q_weight,
torch::Tensor &b_scales, torch::Tensor &workspace,
int64_t num_bits, int64_t size_m, int64_t size_n,
int64_t size_k);

View File

@@ -19,10 +19,10 @@
* Adapted from https://github.com/IST-DASLab/marlin
*/
-#include "./gptq_marlin.cuh"
-#include "./gptq_marlin_dtypes.cuh"
-using namespace gptq_marlin;
#include "marlin.cuh"
#include "marlin_dtypes.cuh"
using namespace marlin;
#define STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t) \
static_assert(std::is_same<scalar_t, half>::value || \
@@ -1224,16 +1224,15 @@ torch::Tensor fp8_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
", size_k = ", size_k);
// Verify B
-TORCH_CHECK(size_k % gptq_marlin::tile_size == 0, "size_k = ", size_k,
-" is not divisible by tile_size = ", gptq_marlin::tile_size);
-TORCH_CHECK((size_k / gptq_marlin::tile_size) == b_q_weight.size(0),
TORCH_CHECK(size_k % marlin::tile_size == 0, "size_k = ", size_k,
" is not divisible by tile_size = ", marlin::tile_size);
TORCH_CHECK((size_k / marlin::tile_size) == b_q_weight.size(0),
"Shape mismatch: b_q_weight.size(0) = ", b_q_weight.size(0),
-", size_k = ", size_k, ", tile_size = ", gptq_marlin::tile_size);
-TORCH_CHECK(b_q_weight.size(1) % gptq_marlin::tile_size == 0,
", size_k = ", size_k, ", tile_size = ", marlin::tile_size);
TORCH_CHECK(b_q_weight.size(1) % marlin::tile_size == 0,
"b_q_weight.size(1) = ", b_q_weight.size(1),
-" is not divisible by tile_size = ", gptq_marlin::tile_size);
-int actual_size_n =
-(b_q_weight.size(1) / gptq_marlin::tile_size) * pack_factor;
" is not divisible by tile_size = ", marlin::tile_size);
int actual_size_n = (b_q_weight.size(1) / marlin::tile_size) * pack_factor;
TORCH_CHECK(size_n == actual_size_n, "size_n = ", size_n,
", actual_size_n = ", actual_size_n);
@@ -1274,11 +1273,9 @@ torch::Tensor fp8_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
num_groups = b_scales.size(0);
// Verify workspace size
-TORCH_CHECK(
-size_n % gptq_marlin::min_thread_n == 0, "size_n = ", size_n,
-", is not divisible by min_thread_n = ", gptq_marlin::min_thread_n);
-int min_workspace_size =
-(size_n / gptq_marlin::min_thread_n) * gptq_marlin::max_par;
TORCH_CHECK(size_n % marlin::min_thread_n == 0, "size_n = ", size_n,
", is not divisible by min_thread_n = ", marlin::min_thread_n);
int min_workspace_size = (size_n / marlin::min_thread_n) * marlin::max_par;
TORCH_CHECK(workspace.numel() >= min_workspace_size,
"workspace.numel = ", workspace.numel(),
" is below min_workspace_size = ", min_workspace_size);
@@ -1290,14 +1287,14 @@ torch::Tensor fp8_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
b_scales.data_ptr<at::Half>(), size_m, size_n, size_k,
workspace.data_ptr(), num_bits, num_groups, group_size, dev,
at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, sms,
-gptq_marlin::max_par);
marlin::max_par);
} else if (a.scalar_type() == at::ScalarType::BFloat16) {
fp8_marlin::marlin_mm_f16i4<nv_bfloat16>(
a.data_ptr<at::BFloat16>(), b_q_weight.data_ptr(),
c.data_ptr<at::BFloat16>(), b_scales.data_ptr<at::BFloat16>(), size_m,
size_n, size_k, workspace.data_ptr(), num_bits, num_groups, group_size,
dev, at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, sms,
-gptq_marlin::max_par);
marlin::max_par);
} else {
TORCH_CHECK(false, "fp8_marlin_gemm only supports bfloat16 and float16");
}

View File

@@ -19,8 +19,8 @@
* Adapted from https://github.com/IST-DASLab/marlin
*/
-#include "gptq_marlin.cuh"
-#include "gptq_marlin_dtypes.cuh"
#include "marlin.cuh"
#include "marlin_dtypes.cuh"
#define STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t) \
static_assert(std::is_same<scalar_t, half>::value || \
@@ -32,7 +32,7 @@ inline std::string str(T x) {
return std::to_string(x);
}
-namespace gptq_marlin {
namespace marlin {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
@@ -72,10 +72,11 @@ __global__ void Marlin(
} // namespace gptq_marlin
torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
-torch::Tensor& b_scales, torch::Tensor& g_idx,
-torch::Tensor& perm, torch::Tensor& workspace,
-int64_t num_bits, int64_t size_m, int64_t size_n,
-int64_t size_k, bool is_k_full) {
torch::Tensor& b_scales, torch::Tensor& b_zeros,
torch::Tensor& g_idx, torch::Tensor& perm,
torch::Tensor& workspace, int64_t num_bits,
int64_t size_m, int64_t size_n, int64_t size_k,
bool is_k_full) {
TORCH_CHECK_NOT_IMPLEMENTED(false,
"marlin_gemm(..) requires CUDA_ARCH >= 8.0");
return torch::empty({1, 1});
@@ -264,6 +265,114 @@ dequant_8bit<nv_bfloat16>(int q) {
return frag_b;
}
// Zero-point dequantizers
template <typename scalar_t>
__device__ inline typename ScalarType<scalar_t>::FragB dequant_4bit_zp(int q) {
STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t);
}
template <>
__device__ inline typename ScalarType<half>::FragB dequant_4bit_zp<half>(
int q) {
const int LO = 0x000f000f;
const int HI = 0x00f000f0;
const int EX = 0x64006400;
// Guarantee that the `(a & b) | c` operations are LOP3s.
int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX);
int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX);
const int SUB = 0x64006400;
const int MUL = 0x2c002c00;
const int ADD = 0xd400d400;
typename ScalarType<half>::FragB frag_b;
frag_b[0] = __hsub2(*reinterpret_cast<half2*>(&lo),
*reinterpret_cast<const half2*>(&SUB));
frag_b[1] = __hfma2(*reinterpret_cast<half2*>(&hi),
*reinterpret_cast<const half2*>(&MUL),
*reinterpret_cast<const half2*>(&ADD));
return frag_b;
}
template <>
__device__ inline typename ScalarType<nv_bfloat16>::FragB
dequant_4bit_zp<nv_bfloat16>(int q) {
static constexpr uint32_t MASK = 0x000f000f;
static constexpr uint32_t EX = 0x43004300;
// Guarantee that the `(a & b) | c` operations are LOP3s.
int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX);
q >>= 4;
int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX);
typename ScalarType<nv_bfloat16>::FragB frag_b;
static constexpr uint32_t MUL = 0x3F803F80;
static constexpr uint32_t ADD = 0xC300C300;
frag_b[0] = __hfma2(*reinterpret_cast<nv_bfloat162*>(&lo),
*reinterpret_cast<const nv_bfloat162*>(&MUL),
*reinterpret_cast<const nv_bfloat162*>(&ADD));
frag_b[1] = __hfma2(*reinterpret_cast<nv_bfloat162*>(&hi),
*reinterpret_cast<const nv_bfloat162*>(&MUL),
*reinterpret_cast<const nv_bfloat162*>(&ADD));
return frag_b;
}
template <typename scalar_t>
__device__ inline typename ScalarType<scalar_t>::FragB dequant_8bit_zp(int q) {
STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t);
}
template <>
__device__ inline typename ScalarType<half>::FragB dequant_8bit_zp<half>(
int q) {
static constexpr uint32_t mask_for_elt_01 = 0x5250;
static constexpr uint32_t mask_for_elt_23 = 0x5351;
static constexpr uint32_t start_byte_for_fp16 = 0x64646464;
uint32_t lo = prmt<start_byte_for_fp16, mask_for_elt_01>(q);
uint32_t hi = prmt<start_byte_for_fp16, mask_for_elt_23>(q);
static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64006400;
typename ScalarType<half>::FragB frag_b;
frag_b[0] = __hsub2(*reinterpret_cast<half2*>(&lo),
*reinterpret_cast<const half2*>(&I8s_TO_F16s_MAGIC_NUM));
frag_b[1] = __hsub2(*reinterpret_cast<half2*>(&hi),
*reinterpret_cast<const half2*>(&I8s_TO_F16s_MAGIC_NUM));
return frag_b;
}
template <>
__device__ inline typename ScalarType<nv_bfloat16>::FragB
dequant_8bit_zp<nv_bfloat16>(int q) {
typename ScalarType<nv_bfloat16>::FragB frag_b;
float fp32_intermediates[4];
uint32_t* fp32_intermediates_casted =
reinterpret_cast<uint32_t*>(fp32_intermediates);
static constexpr uint32_t fp32_base = 0x4B000000;
fp32_intermediates_casted[0] = __byte_perm(q, fp32_base, 0x7650);
fp32_intermediates_casted[1] = __byte_perm(q, fp32_base, 0x7652);
fp32_intermediates_casted[2] = __byte_perm(q, fp32_base, 0x7651);
fp32_intermediates_casted[3] = __byte_perm(q, fp32_base, 0x7653);
fp32_intermediates[0] -= 8388608.f;
fp32_intermediates[1] -= 8388608.f;
fp32_intermediates[2] -= 8388608.f;
fp32_intermediates[3] -= 8388608.f;
uint32_t* bf16_result_ptr = reinterpret_cast<uint32_t*>(&frag_b);
bf16_result_ptr[0] = __byte_perm(fp32_intermediates_casted[0],
fp32_intermediates_casted[1], 0x7632);
bf16_result_ptr[1] = __byte_perm(fp32_intermediates_casted[2],
fp32_intermediates_casted[3], 0x7632);
return frag_b;
}
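
A standalone check of the fp32 magic-number trick used by the bf16 8-bit path above: OR-ing a byte into the mantissa of 2^23 (bit pattern 0x4B000000) yields the float 8388608 + byte, so subtracting 8388608.f recovers the byte exactly before it is narrowed to bf16. Host-side sketch, illustrative only:

#include <cassert>
#include <cstdint>
#include <cstring>

int main() {
  for (uint32_t byte = 0; byte < 256; ++byte) {
    // Same construction as __byte_perm(q, fp32_base, ...) in the kernel code.
    uint32_t bits = 0x4B000000u | byte;
    float f;
    std::memcpy(&f, &bits, sizeof(f));  // bit-cast to float
    assert(f - 8388608.f == static_cast<float>(byte));
  }
  return 0;
}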
// Multiply dequantized values by the corresponding quantization scale; used
// only for grouped quantization.
template <typename scalar_t>
@@ -277,6 +386,17 @@ __device__ inline void scale(typename ScalarType<scalar_t>::FragB& frag_b,
frag_b[1] = __hmul2(frag_b[1], s);
}
template <typename scalar_t>
__device__ inline void sub_zp(typename ScalarType<scalar_t>::FragB& frag_b,
typename ScalarType<scalar_t>::scalar_t2& frag_zp,
int i) {
using scalar_t2 = typename ScalarType<scalar_t>::scalar_t2;
scalar_t2 zp =
ScalarType<scalar_t>::num2num2(reinterpret_cast<scalar_t*>(&frag_zp)[i]);
frag_b[0] = __hsub2(frag_b[0], zp);
frag_b[1] = __hsub2(frag_b[1], zp);
}
// Same as above, but for act_order (each K is multiplied individually)
template <typename scalar_t>
__device__ inline void scale4(typename ScalarType<scalar_t>::FragB& frag_b,
@@ -404,6 +524,7 @@ template <typename scalar_t, // compute dtype, half or nv_float16
const int stages, // number of stages for the async global->shared
// fetch pipeline
const bool has_act_order, // whether act_order is enabled
const bool has_zp, // whether zero-points are enabled
const int group_blocks = -1 // number of consecutive 16x16 blocks
// with a separate quantization scale
>
@@ -413,6 +534,8 @@ __global__ void Marlin(
int4* __restrict__ C, // fp16 output buffer of shape mxn
const int4* __restrict__ scales_ptr, // fp16 quantization scales of shape
// (k/groupsize)xn
const int4* __restrict__ zp_ptr, // 4bit packed zero-points of shape
// (k/groupsize)x(n/pack_factor)
const int* __restrict__ g_idx, // int32 group indices of shape k
int num_groups, // number of scale groups per output channel
int prob_m, // batch dimension m
@@ -437,6 +560,7 @@ __global__ void Marlin(
using FragB = typename ScalarType<scalar_t>::FragB;
using FragC = typename ScalarType<scalar_t>::FragC;
using FragS = typename ScalarType<scalar_t>::FragS;
using FragZP = typename ScalarType<scalar_t>::FragZP;
constexpr int pack_factor = 32 / num_bits;
@@ -566,6 +690,13 @@ __global__ void Marlin(
int tb_n_warps = thread_n_blocks / 4;
int act_s_col_tb_stride = act_s_col_warp_stride * tb_n_warps;
// Zero-points sizes/strides
int zp_gl_stride = (prob_n / pack_factor) / 4;
constexpr int zp_sh_stride = ((16 * thread_n_blocks) / pack_factor) / 4;
constexpr int zp_tb_groups = s_tb_groups;
constexpr int zp_sh_stage = has_zp ? zp_tb_groups * zp_sh_stride : 0;
int zp_gl_rd_delta = zp_gl_stride;
// Global A read index of current thread.
int a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) +
(threadIdx.x % a_gl_rd_delta_o);
@@ -605,6 +736,19 @@ __global__ void Marlin(
int s_sh_wr = threadIdx.x;
bool s_sh_wr_pred = threadIdx.x < s_sh_stride;
// Zero-points
int zp_gl_rd;
if constexpr (has_zp) {
if constexpr (group_blocks == -1) {
zp_gl_rd = zp_sh_stride * slice_col + threadIdx.x;
} else {
zp_gl_rd = zp_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) +
zp_sh_stride * slice_col + threadIdx.x;
}
}
int zp_sh_wr = threadIdx.x;
bool zp_sh_wr_pred = threadIdx.x < zp_sh_stride;
// We use a different scale layout for grouped and column-wise quantization as
// we scale a `half2` tile in column-major layout in the former and in
// row-major in the latter case.
@@ -616,6 +760,18 @@ __global__ void Marlin(
s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) +
(threadIdx.x % 32) % 4;
// Zero-points have the same read layout as the scales
// (without column-wise case)
constexpr int num_col_threads = 8;
constexpr int num_row_threads = 4;
constexpr int num_ints_per_thread = 8 / pack_factor;
int zp_sh_rd;
if constexpr (has_zp) {
zp_sh_rd = num_ints_per_thread * num_col_threads *
((threadIdx.x / 32) % (thread_n_blocks / 4)) +
num_ints_per_thread * ((threadIdx.x % 32) / num_row_threads);
}
// Precompute which thread should not read memory in which iterations; this is
// needed if there are more threads than required for a certain tilesize or
// when the batchsize is not a multiple of 16.
@@ -664,14 +820,17 @@ __global__ void Marlin(
int4* sh_a = sh;
int4* sh_b = sh_a + (stages * a_sh_stage);
int4* sh_g_idx = sh_b + (stages * b_sh_stage);
-int4* sh_s = sh_g_idx + (stages * g_idx_stage);
int4* sh_zp = sh_g_idx + (stages * g_idx_stage);
int4* sh_s = sh_zp + (stages * zp_sh_stage);
// Register storage for double buffer of shared memory reads.
FragA frag_a[2][thread_m_blocks];
I4 frag_b_quant[2][b_thread_vecs];
FragC frag_c[thread_m_blocks][4][2];
FragS frag_s[2][4]; // No act-order
FragS act_frag_s[2][4][4]; // For act-order
int frag_qzp[2][num_ints_per_thread]; // Zero-points
FragZP frag_zp; // Zero-points in fp16
// Zero accumulators.
auto zero_accums = [&]() {
@@ -777,6 +936,28 @@ __global__ void Marlin(
}
}
}
if constexpr (has_zp && group_blocks != -1) {
int4* sh_zp_stage = sh_zp + zp_sh_stage * pipe;
if constexpr (group_blocks >= thread_k_blocks) {
// Only fetch zero-points if this tile starts a new group
if (pipe % (group_blocks / thread_k_blocks) == 0) {
if (zp_sh_wr_pred) {
cp_async4(&sh_zp_stage[zp_sh_wr], &zp_ptr[zp_gl_rd]);
}
zp_gl_rd += zp_gl_rd_delta;
}
} else {
for (int i = 0; i < zp_tb_groups; i++) {
if (zp_sh_wr_pred) {
cp_async4(&sh_zp_stage[i * zp_sh_stride + zp_sh_wr],
&zp_ptr[zp_gl_rd]);
}
zp_gl_rd += zp_gl_rd_delta;
}
}
}
}
}
// Insert a fence even when we are winding down the pipeline to ensure that
@@ -784,6 +965,12 @@ __global__ void Marlin(
cp_async_fence();
};
auto fetch_zp_to_shared = [&]() {
if (zp_sh_wr_pred) {
cp_async4(&sh_zp[zp_sh_wr], &zp_ptr[zp_gl_rd]);
}
};
// Wait until the next thread tile has been loaded to shared memory.
auto wait_for_stage = [&]() {
// We only have `stages - 2` active fetches since we are double buffering
@@ -932,8 +1119,73 @@ __global__ void Marlin(
}
};
auto fetch_zp_to_registers = [&](int k, int full_pipe) {
if constexpr (!has_zp) {
return;
}
int pipe = full_pipe % stages;
if constexpr (group_blocks == -1) {
for (int i = 0; i < num_ints_per_thread; i++) {
frag_qzp[k % 2][i] = (reinterpret_cast<int*>(sh_zp))[zp_sh_rd + i];
}
} else if constexpr (group_blocks >= thread_k_blocks) {
int4* sh_zp_stage =
sh_zp + zp_sh_stage * ((group_blocks / thread_k_blocks) *
(pipe / (group_blocks / thread_k_blocks)));
for (int i = 0; i < num_ints_per_thread; i++) {
frag_qzp[k % 2][i] =
(reinterpret_cast<int*>(sh_zp_stage))[zp_sh_rd + i];
}
} else {
int warp_id = threadIdx.x / 32;
int n_warps = thread_n_blocks / 4;
int warp_row = warp_id / n_warps;
int cur_k = warp_row * 16;
cur_k += k_iter_size * (k % b_sh_wr_iters);
int k_blocks = cur_k / 16;
int cur_group_id = k_blocks / group_blocks;
int4* sh_zp_stage = sh_zp + zp_sh_stage * pipe;
sh_zp_stage += cur_group_id * zp_sh_stride;
for (int i = 0; i < num_ints_per_thread; i++) {
frag_qzp[k % 2][i] =
(reinterpret_cast<int*>(sh_zp_stage))[zp_sh_rd + i];
}
}
};
// Execute the actual tensor core matmul of a sub-tile.
auto matmul = [&](int k) {
if constexpr (has_zp) {
FragB frag_zp_0;
FragB frag_zp_1;
if constexpr (num_bits == 4) {
int zp_quant = frag_qzp[k % 2][0];
int zp_quant_shift = zp_quant >> 8;
frag_zp_0 = dequant_4bit_zp<scalar_t>(zp_quant);
frag_zp_1 = dequant_4bit_zp<scalar_t>(zp_quant_shift);
} else {
int zp_quant_0 = frag_qzp[k % 2][0];
int zp_quant_1 = frag_qzp[k % 2][1];
frag_zp_0 = dequant_8bit_zp<scalar_t>(zp_quant_0);
frag_zp_1 = dequant_8bit_zp<scalar_t>(zp_quant_1);
}
frag_zp[0] = frag_zp_0[0];
frag_zp[1] = frag_zp_0[1];
frag_zp[2] = frag_zp_1[0];
frag_zp[3] = frag_zp_1[1];
}
// We have the m dimension as the inner loop in order to encourage overlapping
// dequantization and matmul operations.
#pragma unroll
@@ -944,16 +1196,32 @@
int b_quant = frag_b_quant[k % 2][0][j];
int b_quant_shift = b_quant >> 8;
-frag_b0 = dequant_4bit<scalar_t>(b_quant);
-frag_b1 = dequant_4bit<scalar_t>(b_quant_shift);
if constexpr (has_zp) {
frag_b0 = dequant_4bit_zp<scalar_t>(b_quant);
frag_b1 = dequant_4bit_zp<scalar_t>(b_quant_shift);
} else {
frag_b0 = dequant_4bit<scalar_t>(b_quant);
frag_b1 = dequant_4bit<scalar_t>(b_quant_shift);
}
} else {
int* frag_b_quant_ptr = reinterpret_cast<int*>(frag_b_quant[k % 2]);
int b_quant_0 = frag_b_quant_ptr[j * 2 + 0];
int b_quant_1 = frag_b_quant_ptr[j * 2 + 1];
-frag_b0 = dequant_8bit<scalar_t>(b_quant_0);
-frag_b1 = dequant_8bit<scalar_t>(b_quant_1);
if constexpr (has_zp) {
frag_b0 = dequant_8bit_zp<scalar_t>(b_quant_0);
frag_b1 = dequant_8bit_zp<scalar_t>(b_quant_1);
} else {
frag_b0 = dequant_8bit<scalar_t>(b_quant_0);
frag_b1 = dequant_8bit<scalar_t>(b_quant_1);
}
}
// Apply zero-point to frag_b0
if constexpr (has_zp) {
sub_zp<scalar_t>(frag_b0, frag_zp[j], 0);
}
// Apply scale to frag_b0
@@ -967,6 +1235,11 @@
}
}
// Apply zero-point to frag_b1
if constexpr (has_zp) {
sub_zp<scalar_t>(frag_b1, frag_zp[j], 1);
}
// Apply scale to frag_b1
if constexpr (has_act_order) {
scale4<scalar_t>(frag_b1, act_frag_s[k % 2][0][j],
@@ -1189,6 +1462,12 @@
}
fetch_scales_to_shared(true, g_idx[slice_k_start], g_idx[last_g_idx]);
}
if constexpr (has_zp && group_blocks == -1) {
if (i == 0) {
fetch_zp_to_shared();
}
}
fetch_to_shared(i, i, i < slice_iters);
}
@@ -1197,6 +1476,7 @@
init_same_group(0);
fetch_to_registers(0, 0);
fetch_scales_to_registers(0, 0);
fetch_zp_to_registers(0, 0);
a_gl_rd += a_gl_rd_delta_o * (stages - 1);
slice_k_start_shared_fetch += tb_k * (stages - 1);
};
@@ -1217,6 +1497,7 @@
for (int k = 0; k < b_sh_wr_iters; k++) {
fetch_to_registers(k + 1, pipe % stages);
fetch_scales_to_registers(k + 1, pipe);
fetch_zp_to_registers(k + 1, pipe);
if (k == b_sh_wr_iters - 2) {
fetch_to_shared((pipe + stages - 1) % stages, pipe,
slice_iters >= stages);
@@ -1354,6 +1635,7 @@
} else {
s_gl_rd = s_sh_stride * slice_col + threadIdx.x;
zp_gl_rd = zp_sh_stride * slice_col + threadIdx.x;
}
start_pipes();
@@ -1363,22 +1645,24 @@
}
#define __CALL_IF(NUM_BITS, THREAD_M_BLOCKS, THREAD_N_BLOCKS, \
-THREAD_K_BLOCKS, HAS_ACT_ORDER, GROUP_BLOCKS, NUM_THREADS) \
THREAD_K_BLOCKS, HAS_ACT_ORDER, HAS_ZP, GROUP_BLOCKS, \
NUM_THREADS) \
else if (num_bits == NUM_BITS && thread_m_blocks == THREAD_M_BLOCKS && \
thread_n_blocks == THREAD_N_BLOCKS && \
thread_k_blocks == THREAD_K_BLOCKS && \
-has_act_order == HAS_ACT_ORDER && group_blocks == GROUP_BLOCKS && \
-num_threads == NUM_THREADS) { \
has_act_order == HAS_ACT_ORDER && has_zp == HAS_ZP && \
group_blocks == GROUP_BLOCKS && num_threads == NUM_THREADS) { \
cudaFuncSetAttribute( \
Marlin<scalar_t, NUM_BITS, NUM_THREADS, THREAD_M_BLOCKS, \
THREAD_N_BLOCKS, THREAD_K_BLOCKS, pipe_stages, HAS_ACT_ORDER, \
-GROUP_BLOCKS>, \
HAS_ZP, GROUP_BLOCKS>, \
cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \
Marlin<scalar_t, NUM_BITS, NUM_THREADS, THREAD_M_BLOCKS, \
THREAD_N_BLOCKS, THREAD_K_BLOCKS, pipe_stages, HAS_ACT_ORDER, \
-GROUP_BLOCKS><<<blocks, NUM_THREADS, max_shared_mem, stream>>>( \
-A_ptr, B_ptr, C_ptr, s_ptr, g_idx_ptr, num_groups, prob_m, prob_n, \
-prob_k, locks); \
HAS_ZP, GROUP_BLOCKS> \
<<<blocks, NUM_THREADS, max_shared_mem, stream>>>( \
A_ptr, B_ptr, C_ptr, s_ptr, zp_ptr, g_idx_ptr, num_groups, \
prob_m, prob_n, prob_k, locks); \
}
typedef struct {
@@ -1548,39 +1832,61 @@ exec_config_t determine_thread_config(int prob_m, int prob_n, int prob_k,
return exec_config_t{0, {-1, -1, -1}};
}
-#define CALL_IF(NUM_BITS, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
-__CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \
-__CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \
-__CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \
-__CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \
-\
-__CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \
-__CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \
-__CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \
-__CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS) \
-\
-__CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \
-__CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \
-__CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \
-__CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS) \
-\
-__CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \
-__CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \
-__CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \
-__CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS) \
-\
-__CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \
-__CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \
-__CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \
-__CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS)
#define GPTQ_CALL_IF(NUM_BITS, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
__CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, true, false, 0, NUM_THREADS) \
__CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, true, false, 0, NUM_THREADS) \
__CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, true, false, 0, NUM_THREADS) \
__CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, true, false, 0, NUM_THREADS) \
\
__CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, false, -1, NUM_THREADS) \
__CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, false, 2, NUM_THREADS) \
__CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, false, 4, NUM_THREADS) \
__CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, false, 8, NUM_THREADS) \
\
__CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, false, -1, NUM_THREADS) \
__CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, false, 2, NUM_THREADS) \
__CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, false, 4, NUM_THREADS) \
__CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, false, 8, NUM_THREADS) \
\
__CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, false, -1, NUM_THREADS) \
__CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, false, 2, NUM_THREADS) \
__CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, false, 4, NUM_THREADS) \
__CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, false, 8, NUM_THREADS) \
\
__CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, false, -1, NUM_THREADS) \
__CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, false, 2, NUM_THREADS) \
__CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, false, 4, NUM_THREADS) \
__CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, false, 8, NUM_THREADS)
#define AWQ_CALL_IF(NUM_BITS, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
__CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS) \
__CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS) \
__CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS) \
__CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS) \
\
__CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS) \
__CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS) \
__CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS) \
__CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS) \
\
__CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS) \
__CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS) \
__CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS) \
__CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS) \
\
__CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS) \
__CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS) \
__CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS) \
__CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS)
template <typename scalar_t>
-void marlin_mm_f16i4(const void* A, const void* B, void* C, void* s,
void marlin_mm_f16i4(const void* A, const void* B, void* C, void* s, void* zp,
void* g_idx, void* perm, void* a_tmp, int prob_m,
int prob_n, int prob_k, void* workspace, int num_bits,
-bool has_act_order, bool is_k_full, int num_groups,
-int group_size, int dev, cudaStream_t stream, int thread_k,
-int thread_n, int sms, int max_par) {
bool has_act_order, bool is_k_full, bool has_zp,
int num_groups, int group_size, int dev,
cudaStream_t stream, int thread_k, int thread_n, int sms,
int max_par) {
TORCH_CHECK(num_bits == 4 || num_bits == 8,
"num_bits must be 4 or 8. Got = ", num_bits);
TORCH_CHECK(prob_m > 0 && prob_n > 0 && prob_k > 0, "Invalid MNK = [", prob_m,
@@ -1665,6 +1971,7 @@ void marlin_mm_f16i4(const void* A, const void* B, void* C, void* s,
const int4* B_ptr = (const int4*)B;
int4* C_ptr = (int4*)C;
const int4* s_ptr = (const int4*)s;
const int4* zp_ptr = (const int4*)zp;
const int* g_idx_ptr = (const int*)g_idx;
const int* perm_ptr = (const int*)perm;
int4* a_tmp_ptr = (int4*)a_tmp;
@@ -1701,28 +2008,33 @@ void marlin_mm_f16i4(const void* A, const void* B, void* C, void* s,
thread_m_blocks = exec_cfg.max_m_blocks;
}
// Define kernel configurations
if (false) {
}
-CALL_IF(4, 32, 2, 256)
-CALL_IF(4, 16, 4, 256)
-CALL_IF(4, 8, 8, 256)
-CALL_IF(4, 8, 4, 128)
-CALL_IF(4, 4, 8, 128)
-CALL_IF(8, 32, 2, 256)
-CALL_IF(8, 16, 4, 256)
-CALL_IF(8, 8, 8, 256)
-CALL_IF(8, 8, 4, 128)
-CALL_IF(8, 4, 8, 128)
GPTQ_CALL_IF(4, 16, 4, 256)
GPTQ_CALL_IF(4, 8, 8, 256)
GPTQ_CALL_IF(4, 8, 4, 128)
GPTQ_CALL_IF(4, 4, 8, 128)
GPTQ_CALL_IF(8, 16, 4, 256)
GPTQ_CALL_IF(8, 8, 8, 256)
GPTQ_CALL_IF(8, 8, 4, 128)
GPTQ_CALL_IF(8, 4, 8, 128)
AWQ_CALL_IF(4, 16, 4, 256)
AWQ_CALL_IF(4, 8, 8, 256)
AWQ_CALL_IF(4, 8, 4, 128)
AWQ_CALL_IF(4, 4, 8, 128)
AWQ_CALL_IF(8, 16, 4, 256)
AWQ_CALL_IF(8, 8, 8, 256)
AWQ_CALL_IF(8, 8, 4, 128)
AWQ_CALL_IF(8, 4, 8, 128)
else {
-TORCH_CHECK(false, "Unsupported shapes: MNK = [" + str(prob_m) + ", " +
-str(prob_n) + ", " + str(prob_k) + "]" +
-", has_act_order = " + str(has_act_order) +
-", num_groups = " + str(num_groups) +
-", group_size = " + str(group_size) +
-", thread_m_blocks = " + str(thread_m_blocks) +
-", thread_n_blocks = " + str(thread_n_blocks) +
-", thread_k_blocks = " + str(thread_k_blocks));
TORCH_CHECK(false, "Unsupported shapes: MNK = [", prob_m, ", ", prob_n,
", ", prob_k, "]", ", has_act_order = ", has_act_order,
", num_groups = ", num_groups, ", group_size = ", group_size,
", thread_m_blocks = ", thread_m_blocks,
", thread_n_blocks = ", thread_n_blocks,
", thread_k_blocks = ", thread_k_blocks,
", num_bits = ", num_bits);
}
A_ptr += 16 * thread_m_blocks * (prob_k / 8) * par;
@@ -1733,10 +2045,11 @@ void marlin_mm_f16i4(const void* A, const void* B, void* C, void* s,
} // namespace gptq_marlin
torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
-torch::Tensor& b_scales, torch::Tensor& g_idx,
-torch::Tensor& perm, torch::Tensor& workspace,
-int64_t num_bits, int64_t size_m, int64_t size_n,
-int64_t size_k, bool is_k_full) {
torch::Tensor& b_scales, torch::Tensor& b_zeros,
torch::Tensor& g_idx, torch::Tensor& perm,
torch::Tensor& workspace, int64_t num_bits,
int64_t size_m, int64_t size_n, int64_t size_k,
bool is_k_full, bool has_zp) {
// Verify num_bits
TORCH_CHECK(num_bits == 4 || num_bits == 8,
"num_bits must be 4 or 8. Got = ", num_bits);
@@ -1749,16 +2062,15 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
", size_k = ", size_k);
// Verify B
-TORCH_CHECK(size_k % gptq_marlin::tile_size == 0, "size_k = ", size_k,
-" is not divisible by tile_size = ", gptq_marlin::tile_size);
-TORCH_CHECK((size_k / gptq_marlin::tile_size) == b_q_weight.size(0),
TORCH_CHECK(size_k % marlin::tile_size == 0, "size_k = ", size_k,
" is not divisible by tile_size = ", marlin::tile_size);
TORCH_CHECK((size_k / marlin::tile_size) == b_q_weight.size(0),
"Shape mismatch: b_q_weight.size(0) = ", b_q_weight.size(0),
-", size_k = ", size_k, ", tile_size = ", gptq_marlin::tile_size);
-TORCH_CHECK(b_q_weight.size(1) % gptq_marlin::tile_size == 0,
", size_k = ", size_k, ", tile_size = ", marlin::tile_size);
TORCH_CHECK(b_q_weight.size(1) % marlin::tile_size == 0,
"b_q_weight.size(1) = ", b_q_weight.size(1),
-" is not divisible by tile_size = ", gptq_marlin::tile_size);
-int actual_size_n =
-(b_q_weight.size(1) / gptq_marlin::tile_size) * pack_factor;
" is not divisible by tile_size = ", marlin::tile_size);
int actual_size_n = (b_q_weight.size(1) / marlin::tile_size) * pack_factor;
TORCH_CHECK(size_n == actual_size_n, "size_n = ", size_n,
", actual_size_n = ", actual_size_n);
@@ -1772,6 +2084,9 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
TORCH_CHECK(b_scales.device().is_cuda(), "b_scales is not on GPU");
TORCH_CHECK(b_scales.is_contiguous(), "b_scales is not contiguous");
TORCH_CHECK(b_zeros.device().is_cuda(), "b_zeros is not on GPU");
TORCH_CHECK(b_zeros.is_contiguous(), "b_zeros is not contiguous");
TORCH_CHECK(g_idx.device().is_cuda(), "g_idx is not on GPU");
TORCH_CHECK(g_idx.is_contiguous(), "g_idx is not contiguous");
@@ -1805,8 +2120,8 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
int group_size = -1;
bool has_act_order = g_idx.size(0) != 0;
-int b_rank = b_scales.sizes().size();
-TORCH_CHECK(b_rank == 2, "b_scales rank = ", b_rank, " is not 2");
int rank = b_scales.sizes().size();
TORCH_CHECK(rank == 2, "b_scales rank = ", rank, " is not 2");
TORCH_CHECK(b_scales.size(1) == size_n, "b_scales dim 1 = ", b_scales.size(1),
" is not size_n = ", size_n);
num_groups = b_scales.size(0);
@@ -1832,34 +2147,44 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
}
}
// Verify b_zeros
if (has_zp) {
int rank = b_zeros.sizes().size();
TORCH_CHECK(rank == 2, "b_zeros rank = ", rank, " is not 2");
TORCH_CHECK(b_zeros.size(0) == num_groups,
"b_zeros dim 0 = ", b_zeros.size(0),
" is not num_groups = ", num_groups);
TORCH_CHECK(b_zeros.size(1) == size_n / pack_factor,
"b_zeros dim 1 = ", b_scales.size(1),
" is not size_n / pack_factor = ", size_n / pack_factor);
}
// Verify workspace size
-TORCH_CHECK(
-size_n % gptq_marlin::min_thread_n == 0, "size_n = ", size_n,
-", is not divisible by min_thread_n = ", gptq_marlin::min_thread_n);
-int min_workspace_size =
-(size_n / gptq_marlin::min_thread_n) * gptq_marlin::max_par;
TORCH_CHECK(size_n % marlin::min_thread_n == 0, "size_n = ", size_n,
", is not divisible by min_thread_n = ", marlin::min_thread_n);
int min_workspace_size = (size_n / marlin::min_thread_n) * marlin::max_par;
TORCH_CHECK(workspace.numel() >= min_workspace_size,
"workspace.numel = ", workspace.numel(),
" is below min_workspace_size = ", min_workspace_size);
int dev = a.get_device();
if (a.scalar_type() == at::ScalarType::Half) {
-gptq_marlin::marlin_mm_f16i4<half>(
marlin::marlin_mm_f16i4<half>(
a.data_ptr<at::Half>(), b_q_weight.data_ptr(), c.data_ptr<at::Half>(),
-b_scales.data_ptr<at::Half>(), g_idx.data_ptr(), perm.data_ptr(),
-a_tmp.data_ptr<at::Half>(), size_m, size_n, size_k,
-workspace.data_ptr(), num_bits, has_act_order, is_k_full, num_groups,
-group_size, dev, at::cuda::getCurrentCUDAStream(dev), thread_k,
-thread_n, sms, gptq_marlin::max_par);
b_scales.data_ptr<at::Half>(), b_zeros.data_ptr(), g_idx.data_ptr(),
perm.data_ptr(), a_tmp.data_ptr<at::Half>(), size_m, size_n, size_k,
workspace.data_ptr(), num_bits, has_act_order, is_k_full, has_zp,
num_groups, group_size, dev, at::cuda::getCurrentCUDAStream(dev),
thread_k, thread_n, sms, marlin::max_par);
} else if (a.scalar_type() == at::ScalarType::BFloat16) {
-gptq_marlin::marlin_mm_f16i4<nv_bfloat16>(
marlin::marlin_mm_f16i4<nv_bfloat16>(
a.data_ptr<at::BFloat16>(), b_q_weight.data_ptr(),
c.data_ptr<at::BFloat16>(), b_scales.data_ptr<at::BFloat16>(),
-g_idx.data_ptr(), perm.data_ptr(), a_tmp.data_ptr<at::BFloat16>(),
-size_m, size_n, size_k, workspace.data_ptr(), num_bits, has_act_order,
-is_k_full, num_groups, group_size, dev,
-at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, sms,
-gptq_marlin::max_par);
b_zeros.data_ptr(), g_idx.data_ptr(), perm.data_ptr(),
a_tmp.data_ptr<at::BFloat16>(), size_m, size_n, size_k,
workspace.data_ptr(), num_bits, has_act_order, is_k_full, has_zp,
num_groups, group_size, dev, at::cuda::getCurrentCUDAStream(dev),
thread_k, thread_n, sms, marlin::max_par);
} else {
TORCH_CHECK(false, "gpt_marlin_gemm only supports bfloat16 and float16");
}
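
A hedged sketch of how the updated entry point might be invoked for an AWQ-style layer (has_zp = true, no act_order), assuming the extension is built with libtorch; all names and sizes below are illustrative. Shapes follow the checks above: weights already repacked to the Marlin layout, scales of shape (num_groups, size_n), zero-points of shape (num_groups, size_n / pack_factor), empty g_idx and perm, and a workspace of at least (size_n / 64) * 16 int32 entries.

#include <torch/torch.h>

// Declaration as exposed by the extension (see ext.hh above).
torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
                               torch::Tensor& b_scales, torch::Tensor& b_zeros,
                               torch::Tensor& g_idx, torch::Tensor& perm,
                               torch::Tensor& workspace, int64_t num_bits,
                               int64_t size_m, int64_t size_n, int64_t size_k,
                               bool is_k_full, bool has_zp);

// Illustrative helper (hypothetical name).
torch::Tensor awq_marlin_gemm_example() {
  // Illustrative sizes: 4-bit weights, group size 128.
  int64_t size_m = 16, size_n = 4096, size_k = 4096, num_bits = 4;
  int64_t pack_factor = 32 / num_bits;  // 8
  int64_t num_groups = size_k / 128;    // 32
  auto f16 = torch::dtype(torch::kHalf).device(torch::kCUDA);
  auto i32 = torch::dtype(torch::kInt).device(torch::kCUDA);

  auto a = torch::randn({size_m, size_k}, f16);
  // Output of awq_marlin_repack: (size_k / 16, size_n * 16 / pack_factor).
  auto b_q_weight = torch::zeros({size_k / 16, size_n * 16 / pack_factor}, i32);
  auto b_scales = torch::ones({num_groups, size_n}, f16);
  auto b_zeros = torch::zeros({num_groups, size_n / pack_factor}, i32);
  auto g_idx = torch::empty({0}, i32);  // no act_order for AWQ
  auto perm = torch::empty({0}, i32);
  auto workspace = torch::zeros({(size_n / 64) * 16}, i32);

  return gptq_marlin_gemm(a, b_q_weight, b_scales, b_zeros, g_idx, perm,
                          workspace, num_bits, size_m, size_n, size_k,
                          /*is_k_full=*/true, /*has_zp=*/true);
}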

View File

@@ -1,23 +1,16 @@
-#include "gptq_marlin.cuh"
#include "marlin.cuh"
-namespace gptq_marlin {
-static constexpr int repack_stages = 8;
-static constexpr int repack_threads = 256;
-static constexpr int tile_k_size = tile_size;
-static constexpr int tile_n_size = tile_k_size * 4;
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
namespace marlin {
template <int const num_threads, int const num_bits, bool const has_perm>
-__global__ void marlin_repack_kernel(
__global__ void gptq_marlin_repack_kernel(
uint32_t const* __restrict__ b_q_weight_ptr,
uint32_t const* __restrict__ perm_ptr, uint32_t* __restrict__ out_ptr,
int size_k, int size_n) {}
-} // namespace gptq_marlin
} // namespace marlin
torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm,
int64_t size_k, int64_t size_n,
@@ -29,8 +22,10 @@ torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm,
#else
namespace marlin {
template <int const num_threads, int const num_bits, bool const has_perm>
-__global__ void marlin_repack_kernel(
__global__ void gptq_marlin_repack_kernel(
uint32_t const* __restrict__ b_q_weight_ptr,
uint32_t const* __restrict__ perm_ptr, uint32_t* __restrict__ out_ptr,
int size_k, int size_n) {
@@ -259,28 +254,28 @@ __global__ void marlin_repack_kernel(
}
}
-} // namespace gptq_marlin
} // namespace marlin
#define CALL_IF(NUM_BITS, HAS_PERM) \
else if (num_bits == NUM_BITS && has_perm == HAS_PERM) { \
cudaFuncSetAttribute( \
-gptq_marlin::marlin_repack_kernel<gptq_marlin::repack_threads, \
-NUM_BITS, HAS_PERM>, \
marlin::gptq_marlin_repack_kernel<marlin::repack_threads, NUM_BITS, \
HAS_PERM>, \
cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \
-gptq_marlin::marlin_repack_kernel<gptq_marlin::repack_threads, NUM_BITS, \
marlin::gptq_marlin_repack_kernel<marlin::repack_threads, NUM_BITS, \
HAS_PERM> \
-<<<blocks, gptq_marlin::repack_threads, max_shared_mem, stream>>>( \
<<<blocks, marlin::repack_threads, max_shared_mem, stream>>>( \
b_q_weight_ptr, perm_ptr, out_ptr, size_k, size_n); \
}
torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm,
int64_t size_k, int64_t size_n,
int64_t num_bits) {
// Verify compatibility with marlin tile of 16x64
-TORCH_CHECK(size_k % gptq_marlin::tile_k_size == 0, "size_k = ", size_k,
-" is not divisible by tile_k_size = ", gptq_marlin::tile_k_size);
-TORCH_CHECK(size_n % gptq_marlin::tile_n_size == 0, "size_n = ", size_n,
-" is not divisible by tile_n_size = ", gptq_marlin::tile_n_size);
TORCH_CHECK(size_k % marlin::tile_k_size == 0, "size_k = ", size_k,
" is not divisible by tile_k_size = ", marlin::tile_k_size);
TORCH_CHECK(size_n % marlin::tile_n_size == 0, "size_n = ", size_n,
" is not divisible by tile_n_size = ", marlin::tile_n_size);
TORCH_CHECK(num_bits == 4 || num_bits == 8,
"num_bits must be 4 or 8. Got = ", num_bits);
@@ -308,10 +303,9 @@ torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm,
auto options = torch::TensorOptions()
.dtype(b_q_weight.dtype())
.device(b_q_weight.device());
-torch::Tensor out =
-torch::empty({size_k / gptq_marlin::tile_size,
-size_n * gptq_marlin::tile_size / pack_factor},
-options);
torch::Tensor out = torch::empty(
{size_k / marlin::tile_size, size_n * marlin::tile_size / pack_factor},
options);
// Detect if there is act_order
bool has_perm = perm.size(0) != 0;

View File

@@ -9,7 +9,9 @@
#include <cuda_runtime.h>
#include <iostream>
-namespace gptq_marlin {
namespace marlin {
// Marlin params
// 8 warps are a good choice since every SM has 4 schedulers and having more
// than 1 warp per schedule allows some more latency hiding. At the same time,
@@ -25,6 +27,15 @@ static constexpr int min_thread_k = 64;
static constexpr int tile_size = 16;
static constexpr int max_par = 16;
// Repack params
static constexpr int repack_stages = 8;
static constexpr int repack_threads = 256;
static constexpr int tile_k_size = tile_size;
static constexpr int tile_n_size = tile_k_size * 4;
// Helpers
template <typename T, int n>
struct Vec {
T elems[n];
@@ -73,4 +84,4 @@ __device__ inline void cp_async_wait() {
#endif
-} // namespace gptq_marlin
} // namespace marlin

View File

@@ -30,7 +30,7 @@ inline std::string str(T x) {
  return std::to_string(x);
}

namespace marlin_dense {

constexpr int ceildiv(int a, int b) { return (a + b - 1) / b; }
@@ -1040,7 +1040,7 @@ void marlin_cuda(const void* A, const void* B, void* C, void* s, int prob_m,
  }
}

} // namespace marlin_dense

torch::Tensor marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
                          torch::Tensor& b_scales, torch::Tensor& workspace,
@@ -1054,24 +1054,25 @@ torch::Tensor marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
  TORCH_CHECK(size_k == a.size(1),
              "Shape mismatch: a.size(1) = " + str(a.size(1)) +
                  ", size_k = " + str(size_k));
  TORCH_CHECK(size_k % marlin_dense::tile_size == 0,
              "size_k = " + str(size_k) + " is not divisible by tile_size = " +
                  str(marlin_dense::tile_size));
  TORCH_CHECK((size_k / marlin_dense::tile_size) == b_q_weight.size(0),
              "Shape mismatch: b_q_weight.size(0) = " +
                  str(b_q_weight.size(0)) + ", size_k = " + str(size_k) +
                  ", tile_size = " + str(marlin_dense::tile_size));

  // Verify N
  TORCH_CHECK(b_scales.size(1) == size_n,
              "b_scales.size(1) = " + str(b_scales.size(1)) +
                  ", size_n = " + str(size_n));
  TORCH_CHECK(
      b_q_weight.size(1) % marlin_dense::tile_size == 0,
      "b_q_weight.size(1) = " + str(b_q_weight.size(1)) +
          " is not divisible by tile_size = " + str(marlin_dense::tile_size));
  int actual_size_n = (b_q_weight.size(1) / marlin_dense::tile_size) *
                      marlin_dense::pack_factor_4bit;
  TORCH_CHECK(
      size_n == actual_size_n,
      "size_n = " + str(size_n) + ", actual_size_n = " + str(actual_size_n));
@@ -1116,21 +1117,22 @@ torch::Tensor marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
              "Unexpected groupsize = " + str(groupsize));

  // Verify workspace size
  TORCH_CHECK(size_n % marlin_dense::min_thread_n == 0,
              "size_n = " + str(size_n) +
                  ", is not divisible by min_thread_n = " +
                  str(marlin_dense::min_thread_n));
  int min_workspace_size =
      (size_n / marlin_dense::min_thread_n) * marlin_dense::max_par;
  TORCH_CHECK(workspace.numel() >= min_workspace_size,
              "workspace.numel = " + str(workspace.numel()) +
                  " is below min_workspace_size = " + str(min_workspace_size));

  int dev = a.get_device();
  marlin_dense::marlin_cuda(a.data_ptr(), b_q_weight.data_ptr(), c.data_ptr(),
                            b_scales.data_ptr(), size_m, size_n, size_k,
                            workspace.data_ptr(), groupsize, dev,
                            at::cuda::getCurrentCUDAStream(dev), thread_k,
                            thread_n, sms, marlin_dense::max_par);

  return c;
}


@@ -1,11 +1,11 @@
#ifndef _data_types_cuh
#define _data_types_cuh
#include "marlin.cuh"
#include <cuda_fp16.h>
#include <cuda_bf16.h>

namespace marlin {

template <typename scalar_t>
class ScalarType {};
@@ -23,6 +23,7 @@ class ScalarType<half> {
  using FragB = Vec<half2, 2>;
  using FragC = Vec<float, 4>;
  using FragS = Vec<half2, 1>;
  using FragZP = Vec<half2, 4>;

  static __device__ float inline num2float(const half x) {
    return __half2float(x);
@@ -51,6 +52,7 @@ class ScalarType<nv_bfloat16> {
  using FragB = Vec<nv_bfloat162, 2>;
  using FragC = Vec<float, 4>;
  using FragS = Vec<nv_bfloat162, 1>;
  using FragZP = Vec<nv_bfloat162, 4>;

#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
  static __device__ float inline num2float(const nv_bfloat16 x) {
@@ -72,6 +74,6 @@ class ScalarType<nv_bfloat16> {
#endif
};

} // namespace marlin

#endif


@@ -9,6 +9,7 @@ setup(
    CUDAExtension(
        name="marlin_kernels",
        sources=[
            "marlin_kernels/awq_marlin_repack.cu",
            "marlin_kernels/fp8_marlin.cu",
            "marlin_kernels/gptq_marlin.cu",
            "marlin_kernels/gptq_marlin_repack.cu",


@@ -156,16 +156,26 @@ class GPTQWeightsLoader(WeightsLoader):
                    f"Cannot load `{self.quantize}` weight for GPTQ -> Marlin repacking, make sure the model is already quantized"
                )

            if not self.sym:
                qzeros = weights.get_tensor(f"{prefix}.qzeros")
            else:
                qzeros = None

            if self.quant_method == "awq":
                g_idx = None
            else:
                g_idx = weights.get_tensor(f"{prefix}.g_idx")
            scales = weights.get_tensor(f"{prefix}.scales")

            return repack_gptq_for_marlin(
                qweight=qweight,
                scales=scales,
                qzeros=qzeros,
                g_idx=g_idx,
                bits=self.bits,
                desc_act=self.desc_act,
                groupsize=self.groupsize,
                quant_method=self.quant_method,
                sym=self.sym,
                sharded_infeatures=False,
            )
@@ -275,14 +285,26 @@ class GPTQWeightsLoader(WeightsLoader):
            quantize=self.quantize,
            sym=self.sym,
        ):
            if not self.sym:
                qzeros = weights.get_packed_sharded(
                    f"{prefix}.qzeros", dim=1, block_sizes=block_sizes
                )
            else:
                qzeros = None

            if self.quant_method == "awq":
                g_idx = None
            else:
                g_idx = weights.get_tensor(f"{prefix}.g_idx")

            return repack_gptq_for_marlin(
                qweight=qweight,
                scales=scales,
                qzeros=qzeros,
                g_idx=g_idx,
                bits=self.bits,
                desc_act=self.desc_act,
                groupsize=self.groupsize,
                quant_method=self.quant_method,
                sym=self.sym,
                sharded_infeatures=False,
            )
@@ -349,18 +371,31 @@ class GPTQWeightsLoader(WeightsLoader):
            quantize=self.quantize,
            sym=self.sym,
        ):
            if not self.sym:
                qzeros = torch.cat(
                    [weights.get_sharded(f"{p}.qzeros", dim=1) for p in prefixes], dim=1
                )
            else:
                qzeros = None

            if self.quant_method == "awq":
                g_idx = None
            else:
                w = [weights.get_tensor(f"{p}.g_idx") for p in prefixes]
                for w2 in w[1:]:
                    torch.testing.assert_close(w2, w[0])
                g_idx = w[0]

            return repack_gptq_for_marlin(
                qweight=qweight,
                scales=scales,
                qzeros=qzeros,
                g_idx=g_idx,
                bits=self.bits,
                desc_act=self.desc_act,
                groupsize=self.groupsize,
                quant_method=self.quant_method,
                sym=self.sym,
                sharded_infeatures=False,
            )
@@ -438,7 +473,19 @@ class GPTQWeightsLoader(WeightsLoader):
                    f"Cannot load `{self.quantize}` weight for GPTQ -> Marlin repacking, make sure the model is already quantized"
                )

            if not self.sym:
                if self.desc_act or self.groupsize == -1:
                    qzeros = weights.get_tensor(f"{prefix}.qzeros")
                else:
                    qzeros = weights.get_sharded(f"{prefix}.qzeros", dim=0)
            else:
                qzeros = None

            if self.quant_method == "awq":
                g_idx = None
            else:
                g_idx = weights.get_sharded(f"{prefix}.g_idx", dim=0)

            if self.desc_act or self.groupsize == -1:
                scales = weights.get_tensor(f"{prefix}.scales")
            else:
@@ -449,10 +496,12 @@ class GPTQWeightsLoader(WeightsLoader):
            return repack_gptq_for_marlin(
                qweight=qweight,
                scales=scales,
                qzeros=qzeros,
                g_idx=g_idx,
                bits=self.bits,
                desc_act=self.desc_act,
                groupsize=self.groupsize,
                quant_method=self.quant_method,
                sym=self.sym,
                sharded_infeatures=sharded_in_features,
            )
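
The qzeros/g_idx selection repeated in the hunks above reduces to a small decision table; a hedged standalone sketch (the helper name and inputs are illustrative only, not part of the loader):

# Illustrative only: mirrors the qzeros/g_idx selection above for a single shard.
# The quant_method/sym values are example inputs, not read from a real checkpoint.
def needed_tensors(quant_method: str, sym: bool) -> list:
    tensors = ["qweight", "scales"]
    if not sym:  # asymmetric checkpoints store zero points
        tensors.append("qzeros")
    if quant_method != "awq":  # AWQ checkpoints carry no activation-order index
        tensors.append("g_idx")
    return tensors

assert needed_tensors("awq", sym=False) == ["qweight", "scales", "qzeros"]
assert needed_tensors("gptq", sym=True) == ["qweight", "scales", "g_idx"]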


@@ -1,6 +1,7 @@
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union

import numpy
import torch
import torch.nn as nn
from loguru import logger
@@ -174,11 +175,12 @@ def can_use_gptq_marlin(
        SYSTEM == "cuda"
        and marlin_kernels is not None
        and has_sm_8_0
        and quantize in {"awq", "gptq"}
        and quant_method in {"awq", "gptq"}
        and bits in GPTQ_MARLIN_BITS
        and groupsize in GPTQ_MARLIN_GROUP_SIZES
        # We only support asymmetric quantization for AWQ.
        and (sym or quant_method == "awq")
    )
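
As a sanity check on the relaxed condition, a minimal sketch of just the quantization-scheme part of the predicate (the hardware checks are assumed to pass; `scheme_ok` is an illustrative name, not an API in the code):

# Hypothetical standalone version of the quantization-scheme part of the check;
# CUDA, SM >= 8.0 and marlin_kernels availability are assumed to hold.
def scheme_ok(quantize: str, quant_method: str, sym: bool) -> bool:
    return (
        quantize in {"awq", "gptq"}
        and quant_method in {"awq", "gptq"}
        # Asymmetric quantization is only accepted for AWQ checkpoints.
        and (sym or quant_method == "awq")
    )

assert scheme_ok("gptq", "gptq", sym=True)       # symmetric GPTQ
assert scheme_ok("awq", "awq", sym=False)        # zero_point=True AWQ
assert not scheme_ok("gptq", "gptq", sym=False)  # asymmetric GPTQ still unsupported
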
@@ -234,6 +236,7 @@ class GPTQMarlinWeight(Weight):
    """

    qweight: torch.Tensor
    qzeros: torch.Tensor
    scales: torch.Tensor
    g_idx: torch.Tensor
    perm: torch.Tensor
@@ -256,11 +259,13 @@
def repack_gptq_for_marlin(
    *,
    qweight: torch.Tensor,
    qzeros: Optional[torch.Tensor],
    scales: torch.Tensor,
    g_idx: Optional[torch.Tensor],
    bits: int,
    desc_act: bool,
    groupsize: int,
    quant_method: str,
    sym: bool,
    sharded_infeatures: bool,
) -> GPTQMarlinWeight:
@@ -279,30 +284,54 @@ def repack_gptq_for_marlin(
        raise RuntimeError(
            f"Repacking GPTQ weights with group size {groupsize} as Marlin is not supported, must be one of: {supported_sizes}"
        )
    if not (sym or quant_method == "awq"):
        raise RuntimeError(
            "Repacking GPTQ weights with asymmetric quantization as Marlin is not supported."
        )

    log_once(logger.info, f"Converting {quant_method} model to Marlin packing format.")

    weights_per_int = 32 // bits
    in_features = qweight.shape[0]
    out_features = qweight.shape[1]

    # AWQ uses column packing, GPTQ uses row packing
    if quant_method == "awq":
        out_features *= weights_per_int
    else:
        in_features *= weights_per_int

    if in_features % groupsize != 0:
        raise ValueError(
            f"Number of input features ({in_features}) not divisible by group size ({groupsize})"
        )

    if g_idx is not None and desc_act and groupsize != -1:
        perm = torch.argsort(g_idx).to(torch.int)
        g_idx = g_idx[perm]
    else:
        perm = torch.empty(0, dtype=torch.int, device=qweight.device)
        g_idx = torch.empty(0, dtype=torch.int, device=qweight.device)

    if quant_method == "awq":
        repacked = marlin_kernels.awq_marlin_repack(
            qweight, in_features, out_features, bits
        )
        if qzeros is not None:
            qzeros = awq_to_marlin_zero_points(
                qzeros,
                in_features // groupsize,
                out_features,
                bits,
            )
    else:
        repacked = marlin_kernels.gptq_marlin_repack(
            qweight, perm, in_features, out_features, bits
        )

    if qzeros is None:
        qzeros = torch.empty(0, dtype=torch.int, device=qweight.device)

    scales = permute_scales(scales)
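
The row versus column packing handled above is easiest to see in the qweight shapes; a small illustrative sketch for 4-bit weights (example dimensions only, not from a real model):

# For 4-bit quantization, 8 weights fit in one int32 (weights_per_int = 8).
# GPTQ packs along rows (the in_features dim), AWQ along columns (out_features).
weights_per_int = 32 // 4
in_features, out_features = 4096, 11008  # example sizes only

gptq_qweight_shape = (in_features // weights_per_int, out_features)  # (512, 11008)
awq_qweight_shape = (in_features, out_features // weights_per_int)   # (4096, 1376)

# Recovering the logical sizes from qweight.shape, as the code above does:
assert gptq_qweight_shape[0] * weights_per_int == in_features
assert awq_qweight_shape[1] * weights_per_int == out_features
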
@@ -310,6 +339,7 @@ def repack_gptq_for_marlin(
    return GPTQMarlinWeight(
        qweight=repacked,
        qzeros=qzeros,
        scales=scales,
        g_idx=g_idx,
        perm=perm,
@@ -343,6 +373,7 @@ class GPTQMarlinLinear(nn.Module):
        self.is_full_k = weight.is_full_k

        self.qweight = weight.qweight
        self.qzeros = weight.qzeros
        self.scales = weight.scales
        self.g_idx = weight.g_idx
        self.perm = weight.perm
@@ -363,6 +394,7 @@ class GPTQMarlinLinear(nn.Module):
            A_flat,
            self.qweight,
            self.scales,
            self.qzeros,
            self.g_idx,
            self.perm,
            self.workspace,
@@ -371,6 +403,7 @@ class GPTQMarlinLinear(nn.Module):
            self.scales.shape[1],
            A_flat.shape[1],
            self.is_full_k,
            self.qzeros.numel() > 0,
        )

        C = C.reshape(A.shape[:-1] + (self.scales.shape[1],))
@@ -688,3 +721,116 @@ class MarlinLinear(nn.Module):
            C += self.bias

        return C


# Functions below are from vLLM


def get_pack_factor(bits: int) -> int:
    if 32 % bits != 0:
        raise ValueError(f"Cannot pack {bits} bit values into uint32")
    return 32 // bits


def pack_cols(
    q_w: torch.Tensor,
    num_bits: int,
    size_k: int,
    size_n: int,
):
    assert q_w.shape == (size_k, size_n)

    pack_factor = get_pack_factor(num_bits)
    assert size_n % pack_factor == 0

    orig_device = q_w.device

    q_w = q_w.cpu().numpy().astype(numpy.uint32)

    q_res = numpy.zeros((size_k, size_n // pack_factor), dtype=numpy.uint32)

    for i in range(pack_factor):
        q_res |= q_w[:, i::pack_factor] << num_bits * i

    q_res = torch.from_numpy(q_res.astype(numpy.int32)).to(orig_device)
    q_res = q_res.contiguous()

    return q_res


def unpack_cols(
    packed_q_w: torch.Tensor,
    num_bits: int,
    size_k: int,
    size_n: int,
):
    pack_factor = get_pack_factor(num_bits)
    assert size_n % pack_factor == 0
    assert packed_q_w.shape == (
        size_k,
        size_n // pack_factor,
    ), "packed_q_w.shape = {} size_k = {}, size_n = {} pack_Factor = {}".format(
        packed_q_w.shape, size_k, size_n, pack_factor
    )

    orig_device = packed_q_w.device

    packed_q_w_cpu = packed_q_w.cpu().numpy().astype(numpy.uint32)
    q_res = numpy.zeros((size_k, size_n), dtype=numpy.uint32)

    mask = (1 << num_bits) - 1
    for i in range(pack_factor):
        vals = packed_q_w_cpu & mask
        packed_q_w_cpu >>= num_bits
        q_res[:, i::pack_factor] = vals

    q_res = torch.from_numpy(q_res.astype(numpy.int32)).to(orig_device)
    q_res = q_res.contiguous()

    return q_res


def marlin_zero_points(
    zp: torch.Tensor, size_k: int, size_n: int, num_bits: int
) -> torch.Tensor:
    # Permute zero-points in a similar way to scales, but do not use the
    # "single" permutation, since zero-points are applied on every MMA
    zp = zp.reshape((-1, len(_scale_perm)))[:, _scale_perm]

    # Interleave column dim (for the dequantize code) and pack it to int32
    if num_bits == 4:
        interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7])
    elif num_bits == 8:
        interleave = numpy.array([0, 2, 1, 3])
    else:
        raise Exception("num_bits must be 4 or 8, got {}".format(num_bits))

    zp = zp.reshape((-1, len(interleave)))[:, interleave].ravel()
    zp = zp.reshape((-1, size_n)).contiguous()
    zp = pack_cols(zp, num_bits, size_k, size_n)

    return zp


def awq_to_marlin_zero_points(
    q_zp_packed: torch.Tensor, size_k: int, size_n: int, num_bits: int
) -> torch.Tensor:
    # AWQ zero-points are quantized and packed on the column dim.
    # In addition, the values are permuted based on dequantizer.
    # Here we undo both of these, and then apply marlin permutation
    # and pack it back.
    q_zp = unpack_cols(q_zp_packed, num_bits, size_k, size_n)

    # Undo interleaving (use argsort(..) to get inverse perm)
    if num_bits == 4:
        undo_interleave = numpy.argsort(numpy.array([0, 2, 4, 6, 1, 3, 5, 7]))
    elif num_bits == 8:
        undo_interleave = numpy.argsort(numpy.array([0, 2, 1, 3]))
    else:
        raise Exception("num_bits must be 4 or 8, got {}".format(num_bits))

    q_zp = q_zp.reshape((-1, len(undo_interleave)))[:, undo_interleave].ravel()
    q_zp = q_zp.reshape((-1, size_n)).contiguous()

    marlin_zp = marlin_zero_points(q_zp, size_k, size_n, num_bits)
    return marlin_zp
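
A quick way to convince yourself that `pack_cols`/`unpack_cols` round-trip and that `argsort` really inverts the AWQ interleave; a minimal check that relies on the helpers defined above and uses arbitrary example sizes:

# Round-trip and permutation checks for the helpers above; sizes are examples only.
size_k, size_n, num_bits = 16, 64, 4
q = torch.randint(0, 1 << num_bits, (size_k, size_n), dtype=torch.int32)

packed = pack_cols(q, num_bits, size_k, size_n)        # shape (16, 64 // 8)
assert torch.equal(unpack_cols(packed, num_bits, size_k, size_n), q)

interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7])     # 4-bit dequant interleave
undo = numpy.argsort(interleave)                       # inverse permutation
assert (interleave[undo] == numpy.arange(8)).all()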


@@ -34,7 +34,7 @@ def _get_quantizer_config(model_id, revision):
    groupsize = -1
    quant_method = "gptq"
    checkpoint_format = None
    sym = False
    desc_act = False

    filename = "config.json"
@@ -52,12 +52,17 @@ def _get_quantizer_config(model_id, revision):
                activation_scale_ub=data["quantization_config"]["activation_scale_ub"]
            )

        if "zero_point" in data["quantization_config"]:
            sym = not data["quantization_config"]["zero_point"]
            quant_method = "awq"
        elif "sym" in data["quantization_config"]:
            sym = data["quantization_config"]["sym"]

        bits = data["quantization_config"]["bits"]
        groupsize = data["quantization_config"]["group_size"]
        # Order is important here, desc_act is missing on some real models
        quant_method = data["quantization_config"]["quant_method"]
        checkpoint_format = data["quantization_config"].get("checkpoint_format")
        desc_act = data["quantization_config"]["desc_act"]
    except Exception:
        filename = "quantize_config.json"
@@ -72,7 +77,13 @@ def _get_quantizer_config(model_id, revision):
                data = json.load(f)
            bits = data["bits"]
            groupsize = data["group_size"]

            if "zero_point" in data:
                sym = not data["zero_point"]
                quant_method = "awq"
            elif "sym" in data:
                sym = data["sym"]

            desc_act = data["desc_act"]
            if "version" in data and data["version"] == "GEMM":
                quant_method = "awq"
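
For context, the kind of AWQ quantization config this detection branch targets looks roughly like the following; a hypothetical example (field values are illustrative, not taken from a specific model):

# Example AWQ section of a config.json; AWQ configs carry "zero_point" instead of "sym".
quantization_config = {
    "quant_method": "awq",
    "bits": 4,
    "group_size": 128,
    "zero_point": True,  # asymmetric quantization
    "version": "GEMM",
}

# Mirrors the detection logic above: zero_point=True maps to sym=False.
sym = not quantization_config["zero_point"]
assert sym is False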