Add support for repacking AWQ weights for GPTQ-Marlin (#2278)
* Add support for repacking AWQ weights for GPTQ-Marlin

  So far we couldn't support AWQ because virtually all AWQ models use asymmetric quantization, which GPTQ-Marlin did not support. GPTQ-Marlin has recently added support for AWQ repacking and AWQ asymmetric quantization (zero_point=True). This change updates all GPTQ-Marlin kernels from upstream and wires up AWQ support. For now, enabling AWQ using Marlin requires running TGI with `--quantize gptq`.

* Enable Marlin for supported AWQ configurations by default

  This makes the AWQ -> GPTQ repack test redundant, since we are now testing this with the regular AWQ test.
parent 5fca30ee15
commit 9935720c87
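For orientation before the diff (my sketch, not part of the commit): the new `awq_marlin_repack` op consumes an AWQ-packed weight of shape `[size_k, size_n / pack_factor]` (int32, `pack_factor = 32 / num_bits`) and returns a Marlin-tiled int32 tensor of shape `[size_k / tile_size, size_n * tile_size / pack_factor]`; those shapes come from the TORCH_CHECKs and the `torch::empty` call in the new file below. The wrapper name `repack_awq_qweight` and the value `tile_size == 16` (the 16x64 Marlin tile mentioned in the code comment) are assumptions on my part.

#include <torch/torch.h>

// Declaration as added to ext.hh in this commit.
torch::Tensor awq_marlin_repack(torch::Tensor& b_q_weight, int64_t size_k,
                                int64_t size_n, int64_t num_bits);

// Hypothetical helper: validates the AWQ layout and calls the new op.
torch::Tensor repack_awq_qweight(torch::Tensor qweight,  // int32, [K, N / pack_factor]
                                 int64_t size_k, int64_t size_n,
                                 int64_t num_bits = 4) {
  const int64_t pack_factor = 32 / num_bits;
  TORCH_CHECK(qweight.is_cuda() && qweight.dtype() == torch::kInt32);
  TORCH_CHECK(qweight.size(0) == size_k && qweight.size(1) == size_n / pack_factor);
  // Expected output layout (assuming tile_size == 16):
  //   [size_k / 16, size_n * 16 / pack_factor]
  return awq_marlin_repack(qweight, size_k, size_n, num_bits);
}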
@@ -0,0 +1,269 @@
#include "marlin.cuh"

#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800

namespace marlin {

template <int const num_threads, int const num_bits, bool const has_perm>
__global__ void awq_marlin_repack_kernel(
    uint32_t const* __restrict__ b_q_weight_ptr, uint32_t* __restrict__ out_ptr,
    int size_k, int size_n) {}

}  // namespace marlin

torch::Tensor awq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm,
                                int64_t size_k, int64_t size_n,
                                int64_t num_bits) {
  TORCH_CHECK_NOT_IMPLEMENTED(
      false, "marlin_repack_from_gptq(..) requires CUDA_ARCH >= 8.0");
  return torch::empty({1, 1});
}

#else

namespace marlin {

template <int const num_threads, int const num_bits>
__global__ void awq_marlin_repack_kernel(
    uint32_t const* __restrict__ b_q_weight_ptr, uint32_t* __restrict__ out_ptr,
    int size_k, int size_n) {
  constexpr int pack_factor = 32 / num_bits;

  int k_tiles = size_k / tile_k_size;
  int n_tiles = size_n / tile_n_size;
  int block_k_tiles = div_ceil(k_tiles, gridDim.x);

  int start_k_tile = blockIdx.x * block_k_tiles;
  if (start_k_tile >= k_tiles) {
    return;
  }

  int finish_k_tile = min(start_k_tile + block_k_tiles, k_tiles);

  // Wait until the next thread tile has been loaded to shared memory.
  auto wait_for_stage = [&]() {
    // We only have `stages - 2` active fetches since we are double buffering
    // and can only issue the next fetch when it is guaranteed that the previous
    // shared memory load is fully complete (as it may otherwise be
    // overwritten).
    cp_async_wait<repack_stages - 2>();
    __syncthreads();
  };

  extern __shared__ int4 sh[];

  constexpr int tile_n_ints = tile_n_size / pack_factor;

  constexpr int stage_n_threads = tile_n_ints / 4;
  constexpr int stage_k_threads = tile_k_size;
  constexpr int stage_size = stage_k_threads * stage_n_threads;

  auto fetch_to_shared = [&](int pipe, int k_tile_id, int n_tile_id) {
    if (n_tile_id >= n_tiles) {
      cp_async_fence();
      return;
    }

    int first_n = n_tile_id * tile_n_size;
    int first_n_packed = first_n / pack_factor;

    int4* sh_ptr = sh + stage_size * pipe;

    if (threadIdx.x < stage_size) {
      int k_id = threadIdx.x / stage_n_threads;
      int n_id = threadIdx.x % stage_n_threads;

      int first_k = k_tile_id * tile_k_size;

      cp_async4(&sh_ptr[k_id * stage_n_threads + n_id],
                reinterpret_cast<int4 const*>(
                    &(b_q_weight_ptr[(first_k + k_id) * (size_n / pack_factor) +
                                     first_n_packed + (n_id * 4)])));
    }

    cp_async_fence();
  };

  auto repack_tile = [&](int pipe, int k_tile_id, int n_tile_id) {
    if (n_tile_id >= n_tiles) {
      return;
    }

    int warp_id = threadIdx.x / 32;
    int th_id = threadIdx.x % 32;

    if (warp_id >= 4) {
      return;
    }

    int tc_col = th_id / 4;
    int tc_row = (th_id % 4) * 2;

    constexpr int tc_offsets[4] = {0, 1, 8, 9};

    int cur_n = warp_id * 16 + tc_col;
    int cur_n_packed = cur_n / pack_factor;
    int cur_n_pos = cur_n % pack_factor;

    constexpr int sh_stride = tile_n_ints;
    constexpr uint32_t mask = (1 << num_bits) - 1;

    int4* sh_stage_ptr = sh + stage_size * pipe;
    uint32_t* sh_stage_int_ptr = reinterpret_cast<uint32_t*>(sh_stage_ptr);

    // Undo interleaving
    int cur_n_pos_unpacked;
    if constexpr (num_bits == 4) {
      constexpr int undo_pack[8] = {0, 4, 1, 5, 2, 6, 3, 7};
      cur_n_pos_unpacked = undo_pack[cur_n_pos];
    } else {
      constexpr int undo_pack[4] = {0, 2, 1, 3};
      cur_n_pos_unpacked = undo_pack[cur_n_pos];
    }

    uint32_t vals[8];
#pragma unroll
    for (int i = 0; i < 4; i++) {
      int cur_elem = tc_row + tc_offsets[i];

      int packed_src_0 = sh_stage_int_ptr[cur_n_packed + sh_stride * cur_elem];
      int packed_src_1 = sh_stage_int_ptr[cur_n_packed + (8 / pack_factor) +
                                          sh_stride * cur_elem];

      vals[i] = (packed_src_0 >> (cur_n_pos_unpacked * num_bits)) & mask;
      vals[4 + i] = (packed_src_1 >> (cur_n_pos_unpacked * num_bits)) & mask;
    }

    constexpr int tile_size = tile_k_size * tile_n_size / pack_factor;
    int out_offset = (k_tile_id * n_tiles + n_tile_id) * tile_size;

    // Result of:
    // https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h
    if constexpr (num_bits == 4) {
      constexpr int pack_idx[8] = {0, 2, 4, 6, 1, 3, 5, 7};

      uint32_t res = 0;
#pragma unroll
      for (int i = 0; i < 8; i++) {
        res |= vals[pack_idx[i]] << (i * 4);
      }

      out_ptr[out_offset + th_id * 4 + warp_id] = res;

    } else {
      constexpr int pack_idx[4] = {0, 2, 1, 3};

      uint32_t res1 = 0;
      uint32_t res2 = 0;
#pragma unroll
      for (int i = 0; i < 4; i++) {
        res1 |= vals[pack_idx[i]] << (i * 8);
        res2 |= vals[4 + pack_idx[i]] << (i * 8);
      }

      out_ptr[out_offset + th_id * 8 + (warp_id * 2) + 0] = res1;
      out_ptr[out_offset + th_id * 8 + (warp_id * 2) + 1] = res2;
    }
  };

  auto start_pipes = [&](int k_tile_id, int n_tile_id) {
#pragma unroll
    for (int pipe = 0; pipe < repack_stages - 1; pipe++) {
      fetch_to_shared(pipe, k_tile_id, n_tile_id + pipe);
    }

    wait_for_stage();
  };
#pragma unroll
  for (int k_tile_id = start_k_tile; k_tile_id < finish_k_tile; k_tile_id++) {
    int n_tile_id = 0;

    start_pipes(k_tile_id, n_tile_id);

    while (n_tile_id < n_tiles) {
#pragma unroll
      for (int pipe = 0; pipe < repack_stages; pipe++) {
        fetch_to_shared((pipe + repack_stages - 1) % repack_stages, k_tile_id,
                        n_tile_id + pipe + repack_stages - 1);
        repack_tile(pipe, k_tile_id, n_tile_id + pipe);
        wait_for_stage();
      }
      n_tile_id += repack_stages;
    }
  }
}

}  // namespace marlin

#define CALL_IF(NUM_BITS)                                                    \
  else if (num_bits == NUM_BITS) {                                           \
    cudaFuncSetAttribute(                                                    \
        marlin::awq_marlin_repack_kernel<marlin::repack_threads, NUM_BITS>,  \
        cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem);        \
    marlin::awq_marlin_repack_kernel<marlin::repack_threads, NUM_BITS>       \
        <<<blocks, marlin::repack_threads, max_shared_mem, stream>>>(        \
            b_q_weight_ptr, out_ptr, size_k, size_n);                        \
  }

torch::Tensor awq_marlin_repack(torch::Tensor& b_q_weight, int64_t size_k,
                                int64_t size_n, int64_t num_bits) {
  // Verify compatibility with marlin tile of 16x64
  TORCH_CHECK(size_k % marlin::tile_k_size == 0, "size_k = ", size_k,
              " is not divisible by tile_k_size = ", marlin::tile_k_size);
  TORCH_CHECK(size_n % marlin::tile_n_size == 0, "size_n = ", size_n,
              " is not divisible by tile_n_size = ", marlin::tile_n_size);

  TORCH_CHECK(num_bits == 4 || num_bits == 8,
              "num_bits must be 4 or 8. Got = ", num_bits);
  int const pack_factor = 32 / num_bits;

  // Verify B
  TORCH_CHECK(b_q_weight.size(0) == size_k,
              "b_q_weight.size(0) = ", b_q_weight.size(0),
              " is not size_k = ", size_k);
  TORCH_CHECK((size_n / pack_factor) == b_q_weight.size(1),
              "Shape mismatch: b_q_weight.size(1) = ", b_q_weight.size(1),
              ", size_n = ", size_n, ", pack_factor = ", pack_factor);

  // Verify device and strides
  TORCH_CHECK(b_q_weight.device().is_cuda(), "b_q_weight is not on GPU");
  TORCH_CHECK(b_q_weight.is_contiguous(), "b_q_weight is not contiguous");
  TORCH_CHECK(b_q_weight.dtype() == at::kInt, "b_q_weight type is not kInt");

  // Alloc buffers
  const at::cuda::OptionalCUDAGuard device_guard(device_of(b_q_weight));
  auto options = torch::TensorOptions()
                     .dtype(b_q_weight.dtype())
                     .device(b_q_weight.device());
  torch::Tensor out = torch::empty(
      {size_k / marlin::tile_size, size_n * marlin::tile_size / pack_factor},
      options);

  // Get ptrs
  uint32_t const* b_q_weight_ptr =
      reinterpret_cast<uint32_t const*>(b_q_weight.data_ptr());
  uint32_t* out_ptr = reinterpret_cast<uint32_t*>(out.data_ptr());

  // Get dev info
  int dev = b_q_weight.get_device();
  cudaStream_t stream = at::cuda::getCurrentCUDAStream(dev);
  int blocks;
  cudaDeviceGetAttribute(&blocks, cudaDevAttrMultiProcessorCount, dev);

  int max_shared_mem = 0;
  cudaDeviceGetAttribute(&max_shared_mem,
                         cudaDevAttrMaxSharedMemoryPerBlockOptin, dev);
  TORCH_CHECK(max_shared_mem > 0);

  if (false) {
  }
  CALL_IF(4)
  CALL_IF(8)
  else {
    TORCH_CHECK(false, "Unsupported repack config: num_bits = ", num_bits);
  }

  return out;
}

#endif
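A small host-side illustration of the two per-word permutations used in `repack_tile` above, isolated from the tile/tensor-core indexing (my sketch; `pack_awq_word` and `read_logical_column` are illustrative helpers, not from the commit). Since `undo_pack` is applied as an index lookup when reading, it implies AWQ stores eight 4-bit columns per uint32 in the order {0, 2, 4, 6, 1, 3, 5, 7}; `pack_idx` plays the mirror role on the write side, placing a thread's eight values into the interleaved layout expected by Marlin's FasterTransformer-style dequantizer.

#include <cassert>
#include <cstdint>

// Pack eight logical 4-bit values into one word using AWQ's storage order:
// nibble i holds logical column awq_order[i].
static uint32_t pack_awq_word(const uint32_t vals[8]) {
  const int awq_order[8] = {0, 2, 4, 6, 1, 3, 5, 7};
  uint32_t w = 0;
  for (int i = 0; i < 8; i++) w |= (vals[awq_order[i]] & 0xF) << (i * 4);
  return w;
}

// Read logical column c back out: it lives at nibble undo_pack[c],
// the inverse permutation used by the repack kernel.
static uint32_t read_logical_column(uint32_t awq_word, int c) {
  const int undo_pack[8] = {0, 4, 1, 5, 2, 6, 3, 7};
  return (awq_word >> (undo_pack[c] * 4)) & 0xF;
}

int main() {
  const uint32_t vals[8] = {1, 2, 3, 4, 5, 6, 7, 8};
  uint32_t w = pack_awq_word(vals);
  for (int c = 0; c < 8; c++) assert(read_logical_column(w, c) == vals[c]);
  return 0;
}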
@@ -3,6 +3,8 @@
 #include "ext.hh"
 
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+  m.def("awq_marlin_repack", &awq_marlin_repack,
+        "Repack AWQ parameters for Marlin");
   m.def("gptq_marlin_gemm", &gptq_marlin_gemm,
         "Marlin gemm with GPTQ compatibility");
   m.def("gptq_marlin_24_gemm", &gptq_marlin_24_gemm, "Marlin sparse 2:4 gemm");

@@ -6,11 +6,15 @@
 // No support for async
 #else
 
+torch::Tensor awq_marlin_repack(torch::Tensor &b_q_weight, int64_t size_k,
+                                int64_t size_n, int64_t num_bits);
+
 torch::Tensor gptq_marlin_gemm(torch::Tensor &a, torch::Tensor &b_q_weight,
-                               torch::Tensor &b_scales, torch::Tensor &g_idx,
-                               torch::Tensor &perm, torch::Tensor &workspace,
-                               int64_t num_bits, int64_t size_m, int64_t size_n,
-                               int64_t size_k, bool is_k_full);
+                               torch::Tensor &b_scales, torch::Tensor &b_zeros,
+                               torch::Tensor &g_idx, torch::Tensor &perm,
+                               torch::Tensor &workspace, int64_t num_bits,
+                               int64_t size_m, int64_t size_n, int64_t size_k,
+                               bool is_k_full, bool has_zp);
 
 torch::Tensor gptq_marlin_24_gemm(torch::Tensor &a, torch::Tensor &b_q_weight,
                                   torch::Tensor &b_meta,

@@ -27,8 +31,8 @@ torch::Tensor marlin_gemm(torch::Tensor &a, torch::Tensor &b_q_weight,
                           torch::Tensor &b_scales, torch::Tensor &workspace,
                           int64_t size_m, int64_t size_n, int64_t size_k);
 
-torch::Tensor fp8_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
-                              torch::Tensor& b_scales, torch::Tensor& workspace,
+torch::Tensor fp8_marlin_gemm(torch::Tensor &a, torch::Tensor &b_q_weight,
+                              torch::Tensor &b_scales, torch::Tensor &workspace,
                               int64_t num_bits, int64_t size_m, int64_t size_n,
                               int64_t size_k);
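A hedged sketch of how the updated `gptq_marlin_gemm` signature might be driven for an AWQ weight (the wrapper name `awq_marlin_forward` is illustrative and not from this commit). Shapes follow the checks added later in the diff: `b_scales` is `[num_groups, size_n]` and `b_zeros` is `[num_groups, size_n / pack_factor]`. Passing empty `g_idx`/`perm` tensors to disable act-order is an assumption based on the `has_act_order = g_idx.size(0) != 0` check further down.

#include <torch/torch.h>

// Declaration as added to ext.hh in this commit.
torch::Tensor gptq_marlin_gemm(torch::Tensor &a, torch::Tensor &b_q_weight,
                               torch::Tensor &b_scales, torch::Tensor &b_zeros,
                               torch::Tensor &g_idx, torch::Tensor &perm,
                               torch::Tensor &workspace, int64_t num_bits,
                               int64_t size_m, int64_t size_n, int64_t size_k,
                               bool is_k_full, bool has_zp);

// Hypothetical wrapper: AWQ has no act-order, so g_idx/perm stay empty and
// is_k_full is true; has_zp selects the zero-point kernel variants.
torch::Tensor awq_marlin_forward(torch::Tensor a, torch::Tensor b_q_weight,
                                 torch::Tensor b_scales, torch::Tensor b_zeros,
                                 torch::Tensor workspace, int64_t num_bits,
                                 int64_t m, int64_t n, int64_t k) {
  auto int_opts = torch::TensorOptions().dtype(torch::kInt32).device(a.device());
  torch::Tensor g_idx = torch::empty({0}, int_opts);  // no act-order
  torch::Tensor perm = torch::empty({0}, int_opts);   // no permutation
  return gptq_marlin_gemm(a, b_q_weight, b_scales, b_zeros, g_idx, perm,
                          workspace, num_bits, m, n, k,
                          /*is_k_full=*/true, /*has_zp=*/true);
}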
@@ -19,10 +19,10 @@
  * Adapted from https://github.com/IST-DASLab/marlin
  */
 
-#include "./gptq_marlin.cuh"
-#include "./gptq_marlin_dtypes.cuh"
+#include "marlin.cuh"
+#include "marlin_dtypes.cuh"
 
-using namespace gptq_marlin;
+using namespace marlin;
 
 #define STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t) \
   static_assert(std::is_same<scalar_t, half>::value || \

@@ -1224,16 +1224,15 @@ torch::Tensor fp8_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
               ", size_k = ", size_k);
 
   // Verify B
-  TORCH_CHECK(size_k % gptq_marlin::tile_size == 0, "size_k = ", size_k,
-              " is not divisible by tile_size = ", gptq_marlin::tile_size);
-  TORCH_CHECK((size_k / gptq_marlin::tile_size) == b_q_weight.size(0),
+  TORCH_CHECK(size_k % marlin::tile_size == 0, "size_k = ", size_k,
+              " is not divisible by tile_size = ", marlin::tile_size);
+  TORCH_CHECK((size_k / marlin::tile_size) == b_q_weight.size(0),
               "Shape mismatch: b_q_weight.size(0) = ", b_q_weight.size(0),
-              ", size_k = ", size_k, ", tile_size = ", gptq_marlin::tile_size);
-  TORCH_CHECK(b_q_weight.size(1) % gptq_marlin::tile_size == 0,
+              ", size_k = ", size_k, ", tile_size = ", marlin::tile_size);
+  TORCH_CHECK(b_q_weight.size(1) % marlin::tile_size == 0,
               "b_q_weight.size(1) = ", b_q_weight.size(1),
-              " is not divisible by tile_size = ", gptq_marlin::tile_size);
-  int actual_size_n =
-      (b_q_weight.size(1) / gptq_marlin::tile_size) * pack_factor;
+              " is not divisible by tile_size = ", marlin::tile_size);
+  int actual_size_n = (b_q_weight.size(1) / marlin::tile_size) * pack_factor;
   TORCH_CHECK(size_n == actual_size_n, "size_n = ", size_n,
               ", actual_size_n = ", actual_size_n);
 

@@ -1274,11 +1273,9 @@ torch::Tensor fp8_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
     num_groups = b_scales.size(0);
 
   // Verify workspace size
-  TORCH_CHECK(
-      size_n % gptq_marlin::min_thread_n == 0, "size_n = ", size_n,
-      ", is not divisible by min_thread_n = ", gptq_marlin::min_thread_n);
-  int min_workspace_size =
-      (size_n / gptq_marlin::min_thread_n) * gptq_marlin::max_par;
+  TORCH_CHECK(size_n % marlin::min_thread_n == 0, "size_n = ", size_n,
+              ", is not divisible by min_thread_n = ", marlin::min_thread_n);
+  int min_workspace_size = (size_n / marlin::min_thread_n) * marlin::max_par;
   TORCH_CHECK(workspace.numel() >= min_workspace_size,
               "workspace.numel = ", workspace.numel(),
               " is below min_workspace_size = ", min_workspace_size);

@@ -1290,14 +1287,14 @@ torch::Tensor fp8_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
         b_scales.data_ptr<at::Half>(), size_m, size_n, size_k,
         workspace.data_ptr(), num_bits, num_groups, group_size, dev,
         at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, sms,
-        gptq_marlin::max_par);
+        marlin::max_par);
   } else if (a.scalar_type() == at::ScalarType::BFloat16) {
     fp8_marlin::marlin_mm_f16i4<nv_bfloat16>(
         a.data_ptr<at::BFloat16>(), b_q_weight.data_ptr(),
         c.data_ptr<at::BFloat16>(), b_scales.data_ptr<at::BFloat16>(), size_m,
         size_n, size_k, workspace.data_ptr(), num_bits, num_groups, group_size,
         dev, at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, sms,
-        gptq_marlin::max_par);
+        marlin::max_par);
   } else {
     TORCH_CHECK(false, "fp8_marlin_gemm only supports bfloat16 and float16");
   }

@@ -19,8 +19,8 @@
  * Adapted from https://github.com/IST-DASLab/marlin
  */
 
-#include "gptq_marlin.cuh"
-#include "gptq_marlin_dtypes.cuh"
+#include "marlin.cuh"
+#include "marlin_dtypes.cuh"
 
 #define STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t) \
   static_assert(std::is_same<scalar_t, half>::value || \

@@ -32,7 +32,7 @@ inline std::string str(T x) {
   return std::to_string(x);
 }
 
-namespace gptq_marlin {
+namespace marlin {
 
 #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
 

@@ -72,10 +72,11 @@ __global__ void Marlin(
 }  // namespace gptq_marlin
 
 torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
-                               torch::Tensor& b_scales, torch::Tensor& g_idx,
-                               torch::Tensor& perm, torch::Tensor& workspace,
-                               int64_t num_bits, int64_t size_m, int64_t size_n,
-                               int64_t size_k, bool is_k_full) {
+                               torch::Tensor& b_scales, torch::Tensor& b_zeros,
+                               torch::Tensor& g_idx, torch::Tensor& perm,
+                               torch::Tensor& workspace, int64_t num_bits,
+                               int64_t size_m, int64_t size_n, int64_t size_k,
+                               bool is_k_full) {
   TORCH_CHECK_NOT_IMPLEMENTED(false,
                               "marlin_gemm(..) requires CUDA_ARCH >= 8.0");
   return torch::empty({1, 1});

@@ -264,6 +265,114 @@ dequant_8bit<nv_bfloat16>(int q) {
   return frag_b;
 }
 
+// Zero-point dequantizers
+
+template <typename scalar_t>
+__device__ inline typename ScalarType<scalar_t>::FragB dequant_4bit_zp(int q) {
+  STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t);
+}
+
+template <>
+__device__ inline typename ScalarType<half>::FragB dequant_4bit_zp<half>(
+    int q) {
+  const int LO = 0x000f000f;
+  const int HI = 0x00f000f0;
+  const int EX = 0x64006400;
+  // Guarantee that the `(a & b) | c` operations are LOP3s.
+  int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX);
+  int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX);
+
+  const int SUB = 0x64006400;
+  const int MUL = 0x2c002c00;
+  const int ADD = 0xd400d400;
+  typename ScalarType<half>::FragB frag_b;
+  frag_b[0] = __hsub2(*reinterpret_cast<half2*>(&lo),
+                      *reinterpret_cast<const half2*>(&SUB));
+  frag_b[1] = __hfma2(*reinterpret_cast<half2*>(&hi),
+                      *reinterpret_cast<const half2*>(&MUL),
+                      *reinterpret_cast<const half2*>(&ADD));
+  return frag_b;
+}
+
+template <>
+__device__ inline typename ScalarType<nv_bfloat16>::FragB
+dequant_4bit_zp<nv_bfloat16>(int q) {
+  static constexpr uint32_t MASK = 0x000f000f;
+  static constexpr uint32_t EX = 0x43004300;
+
+  // Guarantee that the `(a & b) | c` operations are LOP3s.
+
+  int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX);
+  q >>= 4;
+  int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX);
+
+  typename ScalarType<nv_bfloat16>::FragB frag_b;
+  static constexpr uint32_t MUL = 0x3F803F80;
+  static constexpr uint32_t ADD = 0xC300C300;
+
+  frag_b[0] = __hfma2(*reinterpret_cast<nv_bfloat162*>(&lo),
+                      *reinterpret_cast<const nv_bfloat162*>(&MUL),
+                      *reinterpret_cast<const nv_bfloat162*>(&ADD));
+  frag_b[1] = __hfma2(*reinterpret_cast<nv_bfloat162*>(&hi),
+                      *reinterpret_cast<const nv_bfloat162*>(&MUL),
+                      *reinterpret_cast<const nv_bfloat162*>(&ADD));
+  return frag_b;
+}
+
+template <typename scalar_t>
+__device__ inline typename ScalarType<scalar_t>::FragB dequant_8bit_zp(int q) {
+  STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t);
+}
+
+template <>
+__device__ inline typename ScalarType<half>::FragB dequant_8bit_zp<half>(
+    int q) {
+  static constexpr uint32_t mask_for_elt_01 = 0x5250;
+  static constexpr uint32_t mask_for_elt_23 = 0x5351;
+  static constexpr uint32_t start_byte_for_fp16 = 0x64646464;
+
+  uint32_t lo = prmt<start_byte_for_fp16, mask_for_elt_01>(q);
+  uint32_t hi = prmt<start_byte_for_fp16, mask_for_elt_23>(q);
+
+  static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64006400;
+
+  typename ScalarType<half>::FragB frag_b;
+  frag_b[0] = __hsub2(*reinterpret_cast<half2*>(&lo),
+                      *reinterpret_cast<const half2*>(&I8s_TO_F16s_MAGIC_NUM));
+  frag_b[1] = __hsub2(*reinterpret_cast<half2*>(&hi),
+                      *reinterpret_cast<const half2*>(&I8s_TO_F16s_MAGIC_NUM));
+  return frag_b;
+}
+
+template <>
+__device__ inline typename ScalarType<nv_bfloat16>::FragB
+dequant_8bit_zp<nv_bfloat16>(int q) {
+  typename ScalarType<nv_bfloat16>::FragB frag_b;
+
+  float fp32_intermediates[4];
+  uint32_t* fp32_intermediates_casted =
+      reinterpret_cast<uint32_t*>(fp32_intermediates);
+
+  static constexpr uint32_t fp32_base = 0x4B000000;
+  fp32_intermediates_casted[0] = __byte_perm(q, fp32_base, 0x7650);
+  fp32_intermediates_casted[1] = __byte_perm(q, fp32_base, 0x7652);
+  fp32_intermediates_casted[2] = __byte_perm(q, fp32_base, 0x7651);
+  fp32_intermediates_casted[3] = __byte_perm(q, fp32_base, 0x7653);
+
+  fp32_intermediates[0] -= 8388608.f;
+  fp32_intermediates[1] -= 8388608.f;
+  fp32_intermediates[2] -= 8388608.f;
+  fp32_intermediates[3] -= 8388608.f;
+
+  uint32_t* bf16_result_ptr = reinterpret_cast<uint32_t*>(&frag_b);
+  bf16_result_ptr[0] = __byte_perm(fp32_intermediates_casted[0],
+                                   fp32_intermediates_casted[1], 0x7632);
+  bf16_result_ptr[1] = __byte_perm(fp32_intermediates_casted[2],
+                                   fp32_intermediates_casted[3], 0x7632);
+
+  return frag_b;
+}
+
 // Multiply dequantized values by the corresponding quantization scale; used
 // only for grouped quantization.
 template <typename scalar_t>
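A note on the magic constants in `dequant_4bit_zp<half>` above (my reading of the code, not text from the commit). ORing a nibble $q \in \{0,\dots,15\}$ into the low mantissa bits of the fp16 pattern `0x6400` (which encodes 1024) gives

$$\mathrm{fp16}(\texttt{0x6400} \,|\, q) = 1024 + q, \qquad (1024 + q) - \underbrace{1024}_{\texttt{SUB} = \texttt{0x6400}} = q.$$

The high nibbles land four bits higher in the mantissa, so the same trick produces $1024 + 16q$, and the fused multiply-add with $\texttt{MUL} = \texttt{0x2c00} = \tfrac{1}{16}$ and $\texttt{ADD} = \texttt{0xd400} = -64$ recovers

$$(1024 + 16q)\cdot\tfrac{1}{16} - 64 = q.$$

No constant offset is folded in here: the dequantized values stay in $[0, 15]$, and the AWQ zero-point is subtracted separately by `sub_zp` (added in the next hunk) before the group scale is applied.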
@@ -277,6 +386,17 @@ __device__ inline void scale(typename ScalarType<scalar_t>::FragB& frag_b,
   frag_b[1] = __hmul2(frag_b[1], s);
 }
 
+template <typename scalar_t>
+__device__ inline void sub_zp(typename ScalarType<scalar_t>::FragB& frag_b,
+                              typename ScalarType<scalar_t>::scalar_t2& frag_zp,
+                              int i) {
+  using scalar_t2 = typename ScalarType<scalar_t>::scalar_t2;
+  scalar_t2 zp =
+      ScalarType<scalar_t>::num2num2(reinterpret_cast<scalar_t*>(&frag_zp)[i]);
+  frag_b[0] = __hsub2(frag_b[0], zp);
+  frag_b[1] = __hsub2(frag_b[1], zp);
+}
+
 // Same as above, but for act_order (each K is multiplied individually)
 template <typename scalar_t>
 __device__ inline void scale4(typename ScalarType<scalar_t>::FragB& frag_b,
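For orientation (my summary, not commit text): with zero-points enabled, a quantized weight $q$ in a group with scale $s$ and zero-point $z$ is reconstructed as

$$w = s \cdot (q - z),$$

and the matmul changes later in this diff apply exactly that order per fragment: dequantize $q$ with no fixed offset (`dequant_4bit_zp` / `dequant_8bit_zp`), subtract the group's zero-point with `sub_zp`, then multiply by the scale. The existing symmetric GPTQ path is untouched and is still selected when `has_zp` is false.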
@@ -404,6 +524,7 @@ template <typename scalar_t,  // compute dtype, half or nv_float16
           const int stages,  // number of stages for the async global->shared
                              // fetch pipeline
           const bool has_act_order,    // whether act_order is enabled
+          const bool has_zp,           // whether zero-points are enabled
           const int group_blocks = -1  // number of consecutive 16x16 blocks
                                        // with a separate quantization scale
           >

@@ -413,6 +534,8 @@ __global__ void Marlin(
     int4* __restrict__ C,                 // fp16 output buffer of shape mxn
     const int4* __restrict__ scales_ptr,  // fp16 quantization scales of shape
                                           // (k/groupsize)xn
+    const int4* __restrict__ zp_ptr,      // 4bit packed zero-points of shape
+                                          // (k/groupsize)x(n/pack_factor)
     const int* __restrict__ g_idx,        // int32 group indices of shape k
     int num_groups,  // number of scale groups per output channel
     int prob_m,      // batch dimension m

@@ -437,6 +560,7 @@ __global__ void Marlin(
   using FragB = typename ScalarType<scalar_t>::FragB;
   using FragC = typename ScalarType<scalar_t>::FragC;
   using FragS = typename ScalarType<scalar_t>::FragS;
+  using FragZP = typename ScalarType<scalar_t>::FragZP;
 
   constexpr int pack_factor = 32 / num_bits;
 

@@ -566,6 +690,13 @@ __global__ void Marlin(
   int tb_n_warps = thread_n_blocks / 4;
   int act_s_col_tb_stride = act_s_col_warp_stride * tb_n_warps;
 
+  // Zero-points sizes/strides
+  int zp_gl_stride = (prob_n / pack_factor) / 4;
+  constexpr int zp_sh_stride = ((16 * thread_n_blocks) / pack_factor) / 4;
+  constexpr int zp_tb_groups = s_tb_groups;
+  constexpr int zp_sh_stage = has_zp ? zp_tb_groups * zp_sh_stride : 0;
+  int zp_gl_rd_delta = zp_gl_stride;
+
   // Global A read index of current thread.
   int a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) +
                 (threadIdx.x % a_gl_rd_delta_o);

@@ -605,6 +736,19 @@ __global__ void Marlin(
   int s_sh_wr = threadIdx.x;
   bool s_sh_wr_pred = threadIdx.x < s_sh_stride;
 
+  // Zero-points
+  int zp_gl_rd;
+  if constexpr (has_zp) {
+    if constexpr (group_blocks == -1) {
+      zp_gl_rd = zp_sh_stride * slice_col + threadIdx.x;
+    } else {
+      zp_gl_rd = zp_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) +
+                 zp_sh_stride * slice_col + threadIdx.x;
+    }
+  }
+  int zp_sh_wr = threadIdx.x;
+  bool zp_sh_wr_pred = threadIdx.x < zp_sh_stride;
+
   // We use a different scale layout for grouped and column-wise quantization as
   // we scale a `half2` tile in column-major layout in the former and in
   // row-major in the latter case.

@@ -616,6 +760,18 @@ __global__ void Marlin(
     s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) +
               (threadIdx.x % 32) % 4;
 
+  // Zero-points have the same read layout as the scales
+  // (without column-wise case)
+  constexpr int num_col_threads = 8;
+  constexpr int num_row_threads = 4;
+  constexpr int num_ints_per_thread = 8 / pack_factor;
+  int zp_sh_rd;
+  if constexpr (has_zp) {
+    zp_sh_rd = num_ints_per_thread * num_col_threads *
+                   ((threadIdx.x / 32) % (thread_n_blocks / 4)) +
+               num_ints_per_thread * ((threadIdx.x % 32) / num_row_threads);
+  }
+
   // Precompute which thread should not read memory in which iterations; this is
   // needed if there are more threads than required for a certain tilesize or
   // when the batchsize is not a multiple of 16.

@@ -664,14 +820,17 @@ __global__ void Marlin(
   int4* sh_a = sh;
   int4* sh_b = sh_a + (stages * a_sh_stage);
   int4* sh_g_idx = sh_b + (stages * b_sh_stage);
-  int4* sh_s = sh_g_idx + (stages * g_idx_stage);
+  int4* sh_zp = sh_g_idx + (stages * g_idx_stage);
+  int4* sh_s = sh_zp + (stages * zp_sh_stage);
 
   // Register storage for double buffer of shared memory reads.
   FragA frag_a[2][thread_m_blocks];
   I4 frag_b_quant[2][b_thread_vecs];
   FragC frag_c[thread_m_blocks][4][2];
   FragS frag_s[2][4];                    // No act-order
   FragS act_frag_s[2][4][4];             // For act-order
+  int frag_qzp[2][num_ints_per_thread];  // Zero-points
+  FragZP frag_zp;                        // Zero-points in fp16
 
   // Zero accumulators.
   auto zero_accums = [&]() {

@@ -777,6 +936,28 @@ __global__ void Marlin(
         }
       }
     }
+
+    if constexpr (has_zp && group_blocks != -1) {
+      int4* sh_zp_stage = sh_zp + zp_sh_stage * pipe;
+
+      if constexpr (group_blocks >= thread_k_blocks) {
+        // Only fetch zero-points if this tile starts a new group
+        if (pipe % (group_blocks / thread_k_blocks) == 0) {
+          if (zp_sh_wr_pred) {
+            cp_async4(&sh_zp_stage[zp_sh_wr], &zp_ptr[zp_gl_rd]);
+          }
+          zp_gl_rd += zp_gl_rd_delta;
+        }
+      } else {
+        for (int i = 0; i < zp_tb_groups; i++) {
+          if (zp_sh_wr_pred) {
+            cp_async4(&sh_zp_stage[i * zp_sh_stride + zp_sh_wr],
+                      &zp_ptr[zp_gl_rd]);
+          }
+          zp_gl_rd += zp_gl_rd_delta;
+        }
+      }
+    }
   }
 }
 // Insert a fence even when we are winding down the pipeline to ensure that

@@ -784,6 +965,12 @@ __global__ void Marlin(
     cp_async_fence();
   };
 
+  auto fetch_zp_to_shared = [&]() {
+    if (zp_sh_wr_pred) {
+      cp_async4(&sh_zp[zp_sh_wr], &zp_ptr[zp_gl_rd]);
+    }
+  };
+
   // Wait until the next thread tile has been loaded to shared memory.
   auto wait_for_stage = [&]() {
     // We only have `stages - 2` active fetches since we are double buffering

@@ -932,8 +1119,73 @@ __global__ void Marlin(
     }
   };
 
+  auto fetch_zp_to_registers = [&](int k, int full_pipe) {
+    if constexpr (!has_zp) {
+      return;
+    }
+
+    int pipe = full_pipe % stages;
+
+    if constexpr (group_blocks == -1) {
+      for (int i = 0; i < num_ints_per_thread; i++) {
+        frag_qzp[k % 2][i] = (reinterpret_cast<int*>(sh_zp))[zp_sh_rd + i];
+      }
+
+    } else if constexpr (group_blocks >= thread_k_blocks) {
+      int4* sh_zp_stage =
+          sh_zp + zp_sh_stage * ((group_blocks / thread_k_blocks) *
+                                 (pipe / (group_blocks / thread_k_blocks)));
+      for (int i = 0; i < num_ints_per_thread; i++) {
+        frag_qzp[k % 2][i] =
+            (reinterpret_cast<int*>(sh_zp_stage))[zp_sh_rd + i];
+      }
+    } else {
+      int warp_id = threadIdx.x / 32;
+      int n_warps = thread_n_blocks / 4;
+
+      int warp_row = warp_id / n_warps;
+
+      int cur_k = warp_row * 16;
+      cur_k += k_iter_size * (k % b_sh_wr_iters);
+
+      int k_blocks = cur_k / 16;
+      int cur_group_id = k_blocks / group_blocks;
+
+      int4* sh_zp_stage = sh_zp + zp_sh_stage * pipe;
+
+      sh_zp_stage += cur_group_id * zp_sh_stride;
+
+      for (int i = 0; i < num_ints_per_thread; i++) {
+        frag_qzp[k % 2][i] =
+            (reinterpret_cast<int*>(sh_zp_stage))[zp_sh_rd + i];
+      }
+    }
+  };
+
   // Execute the actual tensor core matmul of a sub-tile.
   auto matmul = [&](int k) {
+    if constexpr (has_zp) {
+      FragB frag_zp_0;
+      FragB frag_zp_1;
+      if constexpr (num_bits == 4) {
+        int zp_quant = frag_qzp[k % 2][0];
+        int zp_quant_shift = zp_quant >> 8;
+        frag_zp_0 = dequant_4bit_zp<scalar_t>(zp_quant);
+        frag_zp_1 = dequant_4bit_zp<scalar_t>(zp_quant_shift);
+
+      } else {
+        int zp_quant_0 = frag_qzp[k % 2][0];
+        int zp_quant_1 = frag_qzp[k % 2][1];
+        frag_zp_0 = dequant_8bit_zp<scalar_t>(zp_quant_0);
+        frag_zp_1 = dequant_8bit_zp<scalar_t>(zp_quant_1);
+      }
+
+      frag_zp[0] = frag_zp_0[0];
+      frag_zp[1] = frag_zp_0[1];
+      frag_zp[2] = frag_zp_1[0];
+      frag_zp[3] = frag_zp_1[1];
+    }
+
     // We have the m dimension as the inner loop in order to encourage overlapping
     // dequantization and matmul operations.
 #pragma unroll

@@ -944,16 +1196,32 @@ __global__ void Marlin(
         int b_quant = frag_b_quant[k % 2][0][j];
         int b_quant_shift = b_quant >> 8;
 
-        frag_b0 = dequant_4bit<scalar_t>(b_quant);
-        frag_b1 = dequant_4bit<scalar_t>(b_quant_shift);
+        if constexpr (has_zp) {
+          frag_b0 = dequant_4bit_zp<scalar_t>(b_quant);
+          frag_b1 = dequant_4bit_zp<scalar_t>(b_quant_shift);
+
+        } else {
+          frag_b0 = dequant_4bit<scalar_t>(b_quant);
+          frag_b1 = dequant_4bit<scalar_t>(b_quant_shift);
+        }
 
       } else {
         int* frag_b_quant_ptr = reinterpret_cast<int*>(frag_b_quant[k % 2]);
         int b_quant_0 = frag_b_quant_ptr[j * 2 + 0];
         int b_quant_1 = frag_b_quant_ptr[j * 2 + 1];
 
-        frag_b0 = dequant_8bit<scalar_t>(b_quant_0);
-        frag_b1 = dequant_8bit<scalar_t>(b_quant_1);
+        if constexpr (has_zp) {
+          frag_b0 = dequant_8bit_zp<scalar_t>(b_quant_0);
+          frag_b1 = dequant_8bit_zp<scalar_t>(b_quant_1);
+        } else {
+          frag_b0 = dequant_8bit<scalar_t>(b_quant_0);
+          frag_b1 = dequant_8bit<scalar_t>(b_quant_1);
+        }
+      }
+
+      // Apply zero-point to frag_b0
+      if constexpr (has_zp) {
+        sub_zp<scalar_t>(frag_b0, frag_zp[j], 0);
       }
 
       // Apply scale to frag_b0

@@ -967,6 +1235,11 @@ __global__ void Marlin(
         }
       }
 
+      // Apply zero-point to frag_b1
+      if constexpr (has_zp) {
+        sub_zp<scalar_t>(frag_b1, frag_zp[j], 1);
+      }
+
       // Apply scale to frag_b1
       if constexpr (has_act_order) {
         scale4<scalar_t>(frag_b1, act_frag_s[k % 2][0][j],

@@ -1189,6 +1462,12 @@ __global__ void Marlin(
       }
       fetch_scales_to_shared(true, g_idx[slice_k_start], g_idx[last_g_idx]);
     }
+
+    if constexpr (has_zp && group_blocks == -1) {
+      if (i == 0) {
+        fetch_zp_to_shared();
+      }
+    }
     fetch_to_shared(i, i, i < slice_iters);
   }
 

@@ -1197,6 +1476,7 @@ __global__ void Marlin(
     init_same_group(0);
     fetch_to_registers(0, 0);
     fetch_scales_to_registers(0, 0);
+    fetch_zp_to_registers(0, 0);
     a_gl_rd += a_gl_rd_delta_o * (stages - 1);
     slice_k_start_shared_fetch += tb_k * (stages - 1);
   };

@@ -1217,6 +1497,7 @@ __global__ void Marlin(
     for (int k = 0; k < b_sh_wr_iters; k++) {
       fetch_to_registers(k + 1, pipe % stages);
       fetch_scales_to_registers(k + 1, pipe);
+      fetch_zp_to_registers(k + 1, pipe);
       if (k == b_sh_wr_iters - 2) {
         fetch_to_shared((pipe + stages - 1) % stages, pipe,
                         slice_iters >= stages);

@@ -1354,6 +1635,7 @@ __global__ void Marlin(
 
       } else {
         s_gl_rd = s_sh_stride * slice_col + threadIdx.x;
+        zp_gl_rd = zp_sh_stride * slice_col + threadIdx.x;
       }
 
       start_pipes();
@@ -1363,22 +1645,24 @@
 }
 
 #define __CALL_IF(NUM_BITS, THREAD_M_BLOCKS, THREAD_N_BLOCKS, \
-                  THREAD_K_BLOCKS, HAS_ACT_ORDER, GROUP_BLOCKS, NUM_THREADS) \
+                  THREAD_K_BLOCKS, HAS_ACT_ORDER, HAS_ZP, GROUP_BLOCKS, \
+                  NUM_THREADS) \
   else if (num_bits == NUM_BITS && thread_m_blocks == THREAD_M_BLOCKS && \
            thread_n_blocks == THREAD_N_BLOCKS && \
            thread_k_blocks == THREAD_K_BLOCKS && \
-           has_act_order == HAS_ACT_ORDER && group_blocks == GROUP_BLOCKS && \
-           num_threads == NUM_THREADS) { \
+           has_act_order == HAS_ACT_ORDER && has_zp == HAS_ZP && \
+           group_blocks == GROUP_BLOCKS && num_threads == NUM_THREADS) { \
     cudaFuncSetAttribute( \
         Marlin<scalar_t, NUM_BITS, NUM_THREADS, THREAD_M_BLOCKS, \
                THREAD_N_BLOCKS, THREAD_K_BLOCKS, pipe_stages, HAS_ACT_ORDER, \
-               GROUP_BLOCKS>, \
+               HAS_ZP, GROUP_BLOCKS>, \
         cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \
     Marlin<scalar_t, NUM_BITS, NUM_THREADS, THREAD_M_BLOCKS, \
            THREAD_N_BLOCKS, THREAD_K_BLOCKS, pipe_stages, HAS_ACT_ORDER, \
-           GROUP_BLOCKS><<<blocks, NUM_THREADS, max_shared_mem, stream>>>( \
-        A_ptr, B_ptr, C_ptr, s_ptr, g_idx_ptr, num_groups, prob_m, prob_n, \
-        prob_k, locks); \
+           HAS_ZP, GROUP_BLOCKS> \
+        <<<blocks, NUM_THREADS, max_shared_mem, stream>>>( \
+            A_ptr, B_ptr, C_ptr, s_ptr, zp_ptr, g_idx_ptr, num_groups, \
+            prob_m, prob_n, prob_k, locks); \
   }
 
 typedef struct {

@@ -1548,39 +1832,61 @@ exec_config_t determine_thread_config(int prob_m, int prob_n, int prob_k,
   return exec_config_t{0, {-1, -1, -1}};
 }
 
-#define CALL_IF(NUM_BITS, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
-  __CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \
-  __CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \
-  __CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \
-  __CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \
-  \
-  __CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \
-  __CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \
-  __CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \
-  __CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS) \
-  \
-  __CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \
-  __CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \
-  __CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \
-  __CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS) \
-  \
-  __CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \
-  __CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \
-  __CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \
-  __CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS) \
-  \
-  __CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \
-  __CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \
-  __CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \
-  __CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS)
+#define GPTQ_CALL_IF(NUM_BITS, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
+  __CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, true, false, 0, NUM_THREADS) \
+  __CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, true, false, 0, NUM_THREADS) \
+  __CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, true, false, 0, NUM_THREADS) \
+  __CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, true, false, 0, NUM_THREADS) \
+  \
+  __CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, false, -1, NUM_THREADS) \
+  __CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, false, 2, NUM_THREADS) \
+  __CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, false, 4, NUM_THREADS) \
+  __CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, false, 8, NUM_THREADS) \
+  \
+  __CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, false, -1, NUM_THREADS) \
+  __CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, false, 2, NUM_THREADS) \
+  __CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, false, 4, NUM_THREADS) \
+  __CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, false, 8, NUM_THREADS) \
+  \
+  __CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, false, -1, NUM_THREADS) \
+  __CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, false, 2, NUM_THREADS) \
+  __CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, false, 4, NUM_THREADS) \
+  __CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, false, 8, NUM_THREADS) \
+  \
+  __CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, false, -1, NUM_THREADS) \
+  __CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, false, 2, NUM_THREADS) \
+  __CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, false, 4, NUM_THREADS) \
+  __CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, false, 8, NUM_THREADS)
+
+#define AWQ_CALL_IF(NUM_BITS, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
+  __CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS) \
+  __CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS) \
+  __CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS) \
+  __CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS) \
+  \
+  __CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS) \
+  __CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS) \
+  __CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS) \
+  __CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS) \
+  \
+  __CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS) \
+  __CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS) \
+  __CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS) \
+  __CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS) \
+  \
+  __CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS) \
+  __CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS) \
+  __CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS) \
+  __CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS)
 
 template <typename scalar_t>
-void marlin_mm_f16i4(const void* A, const void* B, void* C, void* s,
+void marlin_mm_f16i4(const void* A, const void* B, void* C, void* s, void* zp,
                      void* g_idx, void* perm, void* a_tmp, int prob_m,
                      int prob_n, int prob_k, void* workspace, int num_bits,
-                     bool has_act_order, bool is_k_full, int num_groups,
-                     int group_size, int dev, cudaStream_t stream, int thread_k,
-                     int thread_n, int sms, int max_par) {
+                     bool has_act_order, bool is_k_full, bool has_zp,
+                     int num_groups, int group_size, int dev,
+                     cudaStream_t stream, int thread_k, int thread_n, int sms,
+                     int max_par) {
   TORCH_CHECK(num_bits == 4 || num_bits == 8,
               "num_bits must be 4 or 8. Got = ", num_bits);
   TORCH_CHECK(prob_m > 0 && prob_n > 0 && prob_k > 0, "Invalid MNK = [", prob_m,
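A hedged sketch of which AWQ configurations the dispatch above can serve (the helper name `awq_marlin_supported` is illustrative, not from the commit). The supported `group_blocks` values {-1, 2, 4, 8} are taken from the GPTQ_CALL_IF/AWQ_CALL_IF lists; mapping `group_size` to `group_blocks` as `group_size / 16` is an assumption based on the upstream Marlin convention of 16-row groups.

#include <cstdint>

bool awq_marlin_supported(int64_t num_bits, int64_t group_size) {
  if (num_bits != 4 && num_bits != 8) return false;  // checked by the kernels
  if (group_size == -1) return true;                 // per-column scales/zeros
  if (group_size % 16 != 0) return false;
  const int64_t group_blocks = group_size / 16;      // assumption: 16-row blocks
  return group_blocks == 2 || group_blocks == 4 || group_blocks == 8;
}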
@@ -1665,6 +1971,7 @@ void marlin_mm_f16i4(const void* A, const void* B, void* C, void* s,
   const int4* B_ptr = (const int4*)B;
   int4* C_ptr = (int4*)C;
   const int4* s_ptr = (const int4*)s;
+  const int4* zp_ptr = (const int4*)zp;
   const int* g_idx_ptr = (const int*)g_idx;
   const int* perm_ptr = (const int*)perm;
   int4* a_tmp_ptr = (int4*)a_tmp;

@@ -1701,28 +2008,33 @@ void marlin_mm_f16i4(const void* A, const void* B, void* C, void* s,
     thread_m_blocks = exec_cfg.max_m_blocks;
   }
 
-  // Define kernel configurations
   if (false) {
   }
-  CALL_IF(4, 32, 2, 256)
-  CALL_IF(4, 16, 4, 256)
-  CALL_IF(4, 8, 8, 256)
-  CALL_IF(4, 8, 4, 128)
-  CALL_IF(4, 4, 8, 128)
-  CALL_IF(8, 32, 2, 256)
-  CALL_IF(8, 16, 4, 256)
-  CALL_IF(8, 8, 8, 256)
-  CALL_IF(8, 8, 4, 128)
-  CALL_IF(8, 4, 8, 128)
+  GPTQ_CALL_IF(4, 16, 4, 256)
+  GPTQ_CALL_IF(4, 8, 8, 256)
+  GPTQ_CALL_IF(4, 8, 4, 128)
+  GPTQ_CALL_IF(4, 4, 8, 128)
+  GPTQ_CALL_IF(8, 16, 4, 256)
+  GPTQ_CALL_IF(8, 8, 8, 256)
+  GPTQ_CALL_IF(8, 8, 4, 128)
+  GPTQ_CALL_IF(8, 4, 8, 128)
+
+  AWQ_CALL_IF(4, 16, 4, 256)
+  AWQ_CALL_IF(4, 8, 8, 256)
+  AWQ_CALL_IF(4, 8, 4, 128)
+  AWQ_CALL_IF(4, 4, 8, 128)
+  AWQ_CALL_IF(8, 16, 4, 256)
+  AWQ_CALL_IF(8, 8, 8, 256)
+  AWQ_CALL_IF(8, 8, 4, 128)
+  AWQ_CALL_IF(8, 4, 8, 128)
   else {
-    TORCH_CHECK(false, "Unsupported shapes: MNK = [" + str(prob_m) + ", " +
-                           str(prob_n) + ", " + str(prob_k) + "]" +
-                           ", has_act_order = " + str(has_act_order) +
-                           ", num_groups = " + str(num_groups) +
-                           ", group_size = " + str(group_size) +
-                           ", thread_m_blocks = " + str(thread_m_blocks) +
-                           ", thread_n_blocks = " + str(thread_n_blocks) +
-                           ", thread_k_blocks = " + str(thread_k_blocks));
+    TORCH_CHECK(false, "Unsupported shapes: MNK = [", prob_m, ", ", prob_n,
+                ", ", prob_k, "]", ", has_act_order = ", has_act_order,
+                ", num_groups = ", num_groups, ", group_size = ", group_size,
+                ", thread_m_blocks = ", thread_m_blocks,
+                ", thread_n_blocks = ", thread_n_blocks,
+                ", thread_k_blocks = ", thread_k_blocks,
+                ", num_bits = ", num_bits);
   }
 
   A_ptr += 16 * thread_m_blocks * (prob_k / 8) * par;

@@ -1733,10 +2045,11 @@ void marlin_mm_f16i4(const void* A, const void* B, void* C, void* s,
 }  // namespace gptq_marlin
 
 torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
-                               torch::Tensor& b_scales, torch::Tensor& g_idx,
-                               torch::Tensor& perm, torch::Tensor& workspace,
-                               int64_t num_bits, int64_t size_m, int64_t size_n,
-                               int64_t size_k, bool is_k_full) {
+                               torch::Tensor& b_scales, torch::Tensor& b_zeros,
+                               torch::Tensor& g_idx, torch::Tensor& perm,
+                               torch::Tensor& workspace, int64_t num_bits,
+                               int64_t size_m, int64_t size_n, int64_t size_k,
+                               bool is_k_full, bool has_zp) {
   // Verify num_bits
   TORCH_CHECK(num_bits == 4 || num_bits == 8,
               "num_bits must be 4 or 8. Got = ", num_bits);

@@ -1749,16 +2062,15 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
               ", size_k = ", size_k);
 
   // Verify B
-  TORCH_CHECK(size_k % gptq_marlin::tile_size == 0, "size_k = ", size_k,
-              " is not divisible by tile_size = ", gptq_marlin::tile_size);
-  TORCH_CHECK((size_k / gptq_marlin::tile_size) == b_q_weight.size(0),
+  TORCH_CHECK(size_k % marlin::tile_size == 0, "size_k = ", size_k,
+              " is not divisible by tile_size = ", marlin::tile_size);
+  TORCH_CHECK((size_k / marlin::tile_size) == b_q_weight.size(0),
               "Shape mismatch: b_q_weight.size(0) = ", b_q_weight.size(0),
-              ", size_k = ", size_k, ", tile_size = ", gptq_marlin::tile_size);
-  TORCH_CHECK(b_q_weight.size(1) % gptq_marlin::tile_size == 0,
+              ", size_k = ", size_k, ", tile_size = ", marlin::tile_size);
+  TORCH_CHECK(b_q_weight.size(1) % marlin::tile_size == 0,
               "b_q_weight.size(1) = ", b_q_weight.size(1),
-              " is not divisible by tile_size = ", gptq_marlin::tile_size);
-  int actual_size_n =
-      (b_q_weight.size(1) / gptq_marlin::tile_size) * pack_factor;
+              " is not divisible by tile_size = ", marlin::tile_size);
+  int actual_size_n = (b_q_weight.size(1) / marlin::tile_size) * pack_factor;
   TORCH_CHECK(size_n == actual_size_n, "size_n = ", size_n,
               ", actual_size_n = ", actual_size_n);
 

@@ -1772,6 +2084,9 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
   TORCH_CHECK(b_scales.device().is_cuda(), "b_scales is not on GPU");
   TORCH_CHECK(b_scales.is_contiguous(), "b_scales is not contiguous");
 
+  TORCH_CHECK(b_zeros.device().is_cuda(), "b_zeros is not on GPU");
+  TORCH_CHECK(b_zeros.is_contiguous(), "b_zeros is not contiguous");
+
   TORCH_CHECK(g_idx.device().is_cuda(), "g_idx is not on GPU");
   TORCH_CHECK(g_idx.is_contiguous(), "g_idx is not contiguous");
 

@@ -1805,8 +2120,8 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
   int group_size = -1;
   bool has_act_order = g_idx.size(0) != 0;
 
-  int b_rank = b_scales.sizes().size();
-  TORCH_CHECK(b_rank == 2, "b_scales rank = ", b_rank, " is not 2");
+  int rank = b_scales.sizes().size();
+  TORCH_CHECK(rank == 2, "b_scales rank = ", rank, " is not 2");
   TORCH_CHECK(b_scales.size(1) == size_n, "b_scales dim 1 = ", b_scales.size(1),
               " is not size_n = ", size_n);
   num_groups = b_scales.size(0);

@@ -1832,34 +2147,44 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
     }
   }
 
+  // Verify b_zeros
+  if (has_zp) {
+    int rank = b_zeros.sizes().size();
+    TORCH_CHECK(rank == 2, "b_zeros rank = ", rank, " is not 2");
+    TORCH_CHECK(b_zeros.size(0) == num_groups,
+                "b_zeros dim 0 = ", b_zeros.size(0),
+                " is not num_groups = ", num_groups);
+    TORCH_CHECK(b_zeros.size(1) == size_n / pack_factor,
+                "b_zeros dim 1 = ", b_scales.size(1),
+                " is not size_n / pack_factor = ", size_n / pack_factor);
+  }
+
   // Verify workspace size
-  TORCH_CHECK(
-      size_n % gptq_marlin::min_thread_n == 0, "size_n = ", size_n,
-      ", is not divisible by min_thread_n = ", gptq_marlin::min_thread_n);
-  int min_workspace_size =
-      (size_n / gptq_marlin::min_thread_n) * gptq_marlin::max_par;
+  TORCH_CHECK(size_n % marlin::min_thread_n == 0, "size_n = ", size_n,
+              ", is not divisible by min_thread_n = ", marlin::min_thread_n);
+  int min_workspace_size = (size_n / marlin::min_thread_n) * marlin::max_par;
   TORCH_CHECK(workspace.numel() >= min_workspace_size,
               "workspace.numel = ", workspace.numel(),
               " is below min_workspace_size = ", min_workspace_size);
 
   int dev = a.get_device();
   if (a.scalar_type() == at::ScalarType::Half) {
-    gptq_marlin::marlin_mm_f16i4<half>(
+    marlin::marlin_mm_f16i4<half>(
         a.data_ptr<at::Half>(), b_q_weight.data_ptr(), c.data_ptr<at::Half>(),
-        b_scales.data_ptr<at::Half>(), g_idx.data_ptr(), perm.data_ptr(),
-        a_tmp.data_ptr<at::Half>(), size_m, size_n, size_k,
+        b_scales.data_ptr<at::Half>(), b_zeros.data_ptr(), g_idx.data_ptr(),
+        perm.data_ptr(), a_tmp.data_ptr<at::Half>(), size_m, size_n, size_k,
         workspace.data_ptr(), num_bits, has_act_order, is_k_full, num_groups,
|
workspace.data_ptr(), num_bits, has_act_order, is_k_full, has_zp,
|
||||||
group_size, dev, at::cuda::getCurrentCUDAStream(dev), thread_k,
|
num_groups, group_size, dev, at::cuda::getCurrentCUDAStream(dev),
|
||||||
thread_n, sms, gptq_marlin::max_par);
|
thread_k, thread_n, sms, marlin::max_par);
|
||||||
} else if (a.scalar_type() == at::ScalarType::BFloat16) {
|
} else if (a.scalar_type() == at::ScalarType::BFloat16) {
|
||||||
gptq_marlin::marlin_mm_f16i4<nv_bfloat16>(
|
marlin::marlin_mm_f16i4<nv_bfloat16>(
|
||||||
a.data_ptr<at::BFloat16>(), b_q_weight.data_ptr(),
|
a.data_ptr<at::BFloat16>(), b_q_weight.data_ptr(),
|
||||||
c.data_ptr<at::BFloat16>(), b_scales.data_ptr<at::BFloat16>(),
|
c.data_ptr<at::BFloat16>(), b_scales.data_ptr<at::BFloat16>(),
|
||||||
g_idx.data_ptr(), perm.data_ptr(), a_tmp.data_ptr<at::BFloat16>(),
|
b_zeros.data_ptr(), g_idx.data_ptr(), perm.data_ptr(),
|
||||||
size_m, size_n, size_k, workspace.data_ptr(), num_bits, has_act_order,
|
a_tmp.data_ptr<at::BFloat16>(), size_m, size_n, size_k,
|
||||||
is_k_full, num_groups, group_size, dev,
|
workspace.data_ptr(), num_bits, has_act_order, is_k_full, has_zp,
|
||||||
at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, sms,
|
num_groups, group_size, dev, at::cuda::getCurrentCUDAStream(dev),
|
||||||
gptq_marlin::max_par);
|
thread_k, thread_n, sms, marlin::max_par);
|
||||||
} else {
|
} else {
|
||||||
TORCH_CHECK(false, "gpt_marlin_gemm only supports bfloat16 and float16");
|
TORCH_CHECK(false, "gpt_marlin_gemm only supports bfloat16 and float16");
|
||||||
}
|
}
|
||||||
|
|
|
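The shape checks above follow directly from the Marlin tile layout; a minimal sketch of the arithmetic, assuming 4-bit weights and hypothetical layer sizes:

    # Hypothetical sizes; mirrors the TORCH_CHECKs in gptq_marlin_gemm above.
    num_bits = 4
    pack_factor = 32 // num_bits       # 8 quantized values per int32
    tile_size = 16                     # marlin::tile_size

    size_k, size_n = 4096, 11008       # input/output features of the layer
    assert size_k % tile_size == 0

    # Marlin-repacked b_q_weight: (size_k / tile_size, size_n * tile_size / pack_factor)
    b_q_weight_shape = (size_k // tile_size, size_n * tile_size // pack_factor)
    actual_size_n = (b_q_weight_shape[1] // tile_size) * pack_factor
    assert actual_size_n == size_n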
@@ -1,23 +1,16 @@
-#include "gptq_marlin.cuh"
+#include "marlin.cuh"
 
-namespace gptq_marlin {
-
-static constexpr int repack_stages = 8;
-
-static constexpr int repack_threads = 256;
-
-static constexpr int tile_k_size = tile_size;
-static constexpr int tile_n_size = tile_k_size * 4;
-
 #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
 
+namespace marlin {
+
 template <int const num_threads, int const num_bits, bool const has_perm>
-__global__ void marlin_repack_kernel(
+__global__ void gptq_marlin_repack_kernel(
     uint32_t const* __restrict__ b_q_weight_ptr,
     uint32_t const* __restrict__ perm_ptr, uint32_t* __restrict__ out_ptr,
     int size_k, int size_n) {}
 
-} // namespace gptq_marlin
+} // namespace marlin
 
 torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm,
                                  int64_t size_k, int64_t size_n,
@@ -29,8 +22,10 @@ torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm,
 
 #else
 
+namespace marlin {
+
 template <int const num_threads, int const num_bits, bool const has_perm>
-__global__ void marlin_repack_kernel(
+__global__ void gptq_marlin_repack_kernel(
     uint32_t const* __restrict__ b_q_weight_ptr,
     uint32_t const* __restrict__ perm_ptr, uint32_t* __restrict__ out_ptr,
     int size_k, int size_n) {
@@ -259,28 +254,28 @@ __global__ void marlin_repack_kernel(
   }
 }
 
-} // namespace gptq_marlin
+} // namespace marlin
 
 #define CALL_IF(NUM_BITS, HAS_PERM) \
   else if (num_bits == NUM_BITS && has_perm == HAS_PERM) { \
     cudaFuncSetAttribute( \
-        gptq_marlin::marlin_repack_kernel<gptq_marlin::repack_threads, \
-                                          NUM_BITS, HAS_PERM>, \
+        marlin::gptq_marlin_repack_kernel<marlin::repack_threads, NUM_BITS, \
+                                          HAS_PERM>, \
         cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \
-    gptq_marlin::marlin_repack_kernel<gptq_marlin::repack_threads, NUM_BITS, \
+    marlin::gptq_marlin_repack_kernel<marlin::repack_threads, NUM_BITS, \
                                       HAS_PERM> \
-        <<<blocks, gptq_marlin::repack_threads, max_shared_mem, stream>>>( \
+        <<<blocks, marlin::repack_threads, max_shared_mem, stream>>>( \
             b_q_weight_ptr, perm_ptr, out_ptr, size_k, size_n); \
   }
 
 torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm,
                                  int64_t size_k, int64_t size_n,
                                  int64_t num_bits) {
   // Verify compatibility with marlin tile of 16x64
-  TORCH_CHECK(size_k % gptq_marlin::tile_k_size == 0, "size_k = ", size_k,
-              " is not divisible by tile_k_size = ", gptq_marlin::tile_k_size);
-  TORCH_CHECK(size_n % gptq_marlin::tile_n_size == 0, "size_n = ", size_n,
-              " is not divisible by tile_n_size = ", gptq_marlin::tile_n_size);
+  TORCH_CHECK(size_k % marlin::tile_k_size == 0, "size_k = ", size_k,
+              " is not divisible by tile_k_size = ", marlin::tile_k_size);
+  TORCH_CHECK(size_n % marlin::tile_n_size == 0, "size_n = ", size_n,
+              " is not divisible by tile_n_size = ", marlin::tile_n_size);
 
   TORCH_CHECK(num_bits == 4 || num_bits == 8,
               "num_bits must be 4 or 8. Got = ", num_bits);
@@ -308,10 +303,9 @@ torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm,
   auto options = torch::TensorOptions()
                      .dtype(b_q_weight.dtype())
                      .device(b_q_weight.device());
-  torch::Tensor out =
-      torch::empty({size_k / gptq_marlin::tile_size,
-                    size_n * gptq_marlin::tile_size / pack_factor},
-                   options);
+  torch::Tensor out = torch::empty(
+      {size_k / marlin::tile_size, size_n * marlin::tile_size / pack_factor},
+      options);
 
   // Detect if there is act_order
   bool has_perm = perm.size(0) != 0;
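From the Python side this kernel is reached through the marlin_kernels extension, exactly as the weight-loading code later in this diff does; a minimal sketch (shapes, dtypes and the CUDA device are illustrative assumptions):

    import torch
    import marlin_kernels  # extension built from setup.py below

    bits = 4
    pack_factor = 32 // bits
    in_features, out_features = 4096, 4096

    # GPTQ qweight is row-packed: (in_features / pack_factor, out_features).
    qweight = torch.zeros(
        (in_features // pack_factor, out_features), dtype=torch.int32, device="cuda"
    )
    # An empty perm means no act_order; the kernel skips the permutation path.
    perm = torch.empty(0, dtype=torch.int, device="cuda")

    repacked = marlin_kernels.gptq_marlin_repack(
        qweight, perm, in_features, out_features, bits
    )
    # Marlin layout: (in_features / tile_size, out_features * tile_size / pack_factor)
    assert repacked.shape == (in_features // 16, out_features * 16 // pack_factor)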
@@ -9,7 +9,9 @@
 #include <cuda_runtime.h>
 #include <iostream>
 
-namespace gptq_marlin {
+namespace marlin {
 
+// Marlin params
+
 // 8 warps are a good choice since every SM has 4 schedulers and having more
 // than 1 warp per schedule allows some more latency hiding. At the same time,
@@ -25,6 +27,15 @@ static constexpr int min_thread_k = 64;
 static constexpr int tile_size = 16;
 static constexpr int max_par = 16;
 
+// Repack params
+static constexpr int repack_stages = 8;
+
+static constexpr int repack_threads = 256;
+
+static constexpr int tile_k_size = tile_size;
+static constexpr int tile_n_size = tile_k_size * 4;
+
+// Helpers
 template <typename T, int n>
 struct Vec {
   T elems[n];
@@ -73,4 +84,4 @@ __device__ inline void cp_async_wait() {
 
 #endif
 
-} // namespace gptq_marlin
+} // namespace marlin
@@ -30,7 +30,7 @@ inline std::string str(T x) {
   return std::to_string(x);
 }
 
-namespace marlin {
+namespace marlin_dense {
 
 constexpr int ceildiv(int a, int b) { return (a + b - 1) / b; }
 
@@ -1040,7 +1040,7 @@ void marlin_cuda(const void* A, const void* B, void* C, void* s, int prob_m,
   }
 }
 
-} // namespace marlin
+} // namespace marlin_dense
 
 torch::Tensor marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
                           torch::Tensor& b_scales, torch::Tensor& workspace,
@@ -1054,24 +1054,25 @@ torch::Tensor marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
   TORCH_CHECK(size_k == a.size(1),
               "Shape mismatch: a.size(1) = " + str(a.size(1)) +
                   ", size_k = " + str(size_k));
-  TORCH_CHECK(size_k % marlin::tile_size == 0,
-              "size_k = " + str(size_k) +
-                  " is not divisible by tile_size = " + str(marlin::tile_size));
-  TORCH_CHECK((size_k / marlin::tile_size) == b_q_weight.size(0),
+  TORCH_CHECK(size_k % marlin_dense::tile_size == 0,
+              "size_k = " + str(size_k) + " is not divisible by tile_size = " +
+                  str(marlin_dense::tile_size));
+  TORCH_CHECK((size_k / marlin_dense::tile_size) == b_q_weight.size(0),
               "Shape mismatch: b_q_weight.size(0) = " +
                   str(b_q_weight.size(0)) + ", size_k = " + str(size_k) +
-                  ", tile_size = " + str(marlin::tile_size));
+                  ", tile_size = " + str(marlin_dense::tile_size));
 
   // Verify N
   TORCH_CHECK(b_scales.size(1) == size_n,
               "b_scales.size(1) = " + str(b_scales.size(1)) +
                   ", size_n = " + str(size_n));
-  TORCH_CHECK(b_q_weight.size(1) % marlin::tile_size == 0,
-              "b_q_weight.size(1) = " + str(b_q_weight.size(1)) +
-                  " is not divisible by tile_size = " + str(marlin::tile_size));
+  TORCH_CHECK(
+      b_q_weight.size(1) % marlin_dense::tile_size == 0,
+      "b_q_weight.size(1) = " + str(b_q_weight.size(1)) +
+          " is not divisible by tile_size = " + str(marlin_dense::tile_size));
 
-  int actual_size_n =
-      (b_q_weight.size(1) / marlin::tile_size) * marlin::pack_factor_4bit;
+  int actual_size_n = (b_q_weight.size(1) / marlin_dense::tile_size) *
+                      marlin_dense::pack_factor_4bit;
   TORCH_CHECK(
       size_n == actual_size_n,
       "size_n = " + str(size_n) + ", actual_size_n = " + str(actual_size_n));
@@ -1116,21 +1117,22 @@ torch::Tensor marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
               "Unexpected groupsize = " + str(groupsize));
 
   // Verify workspace size
-  TORCH_CHECK(
-      size_n % marlin::min_thread_n == 0,
-      "size_n = " + str(size_n) +
-          ", is not divisible by min_thread_n = " + str(marlin::min_thread_n));
-  int min_workspace_size = (size_n / marlin::min_thread_n) * marlin::max_par;
+  TORCH_CHECK(size_n % marlin_dense::min_thread_n == 0,
+              "size_n = " + str(size_n) +
+                  ", is not divisible by min_thread_n = " +
+                  str(marlin_dense::min_thread_n));
+  int min_workspace_size =
+      (size_n / marlin_dense::min_thread_n) * marlin_dense::max_par;
   TORCH_CHECK(workspace.numel() >= min_workspace_size,
               "workspace.numel = " + str(workspace.numel()) +
                   " is below min_workspace_size = " + str(min_workspace_size));
 
   int dev = a.get_device();
-  marlin::marlin_cuda(a.data_ptr(), b_q_weight.data_ptr(), c.data_ptr(),
+  marlin_dense::marlin_cuda(a.data_ptr(), b_q_weight.data_ptr(), c.data_ptr(),
                       b_scales.data_ptr(), size_m, size_n, size_k,
                       workspace.data_ptr(), groupsize, dev,
-                      at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n,
-                      sms, marlin::max_par);
+                      at::cuda::getCurrentCUDAStream(dev), thread_k,
+                      thread_n, sms, marlin_dense::max_par);
 
   return c;
 }
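The workspace check above only depends on two compile-time kernel constants; a small sketch of the arithmetic, with values that are illustrative assumptions rather than the header definitions:

    # Illustrative only: min_thread_n and max_par are compile-time constants
    # in the Marlin kernel headers; the exact values here are assumptions.
    min_thread_n = 64
    max_par = 16

    size_n = 11008
    assert size_n % min_thread_n == 0
    min_workspace_size = (size_n // min_thread_n) * max_par
    # The workspace tensor passed to marlin_gemm must satisfy
    # workspace.numel() >= min_workspace_size.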
@@ -1,11 +1,11 @@
 
 #ifndef _data_types_cuh
 #define _data_types_cuh
-#include "gptq_marlin.cuh"
+#include "marlin.cuh"
 #include <cuda_fp16.h>
 #include <cuda_bf16.h>
 
-namespace gptq_marlin {
+namespace marlin {
 
 template <typename scalar_t>
 class ScalarType {};
@@ -23,6 +23,7 @@ class ScalarType<half> {
   using FragB = Vec<half2, 2>;
   using FragC = Vec<float, 4>;
   using FragS = Vec<half2, 1>;
+  using FragZP = Vec<half2, 4>;
 
   static __device__ float inline num2float(const half x) {
     return __half2float(x);
@@ -51,6 +52,7 @@ class ScalarType<nv_bfloat16> {
   using FragB = Vec<nv_bfloat162, 2>;
   using FragC = Vec<float, 4>;
   using FragS = Vec<nv_bfloat162, 1>;
+  using FragZP = Vec<nv_bfloat162, 4>;
 
 #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
   static __device__ float inline num2float(const nv_bfloat16 x) {
@@ -72,6 +74,6 @@ class ScalarType<nv_bfloat16> {
 #endif
 };
 
-} // namespace gptq_marlin
+} // namespace marlin
 
 #endif
@@ -9,6 +9,7 @@ setup(
         CUDAExtension(
             name="marlin_kernels",
             sources=[
+                "marlin_kernels/awq_marlin_repack.cu",
                 "marlin_kernels/fp8_marlin.cu",
                 "marlin_kernels/gptq_marlin.cu",
                 "marlin_kernels/gptq_marlin_repack.cu",
@@ -156,16 +156,26 @@ class GPTQWeightsLoader(WeightsLoader):
                 f"Cannot load `{self.quantize}` weight for GPTQ -> Marlin repacking, make sure the model is already quantized"
             )
 
-            g_idx = weights.get_tensor(f"{prefix}.g_idx")
+            if not self.sym:
+                qzeros = weights.get_tensor(f"{prefix}.qzeros")
+            else:
+                qzeros = None
+
+            if self.quant_method == "awq":
+                g_idx = None
+            else:
+                g_idx = weights.get_tensor(f"{prefix}.g_idx")
             scales = weights.get_tensor(f"{prefix}.scales")
 
             return repack_gptq_for_marlin(
                 qweight=qweight,
                 scales=scales,
+                qzeros=qzeros,
                 g_idx=g_idx,
                 bits=self.bits,
                 desc_act=self.desc_act,
                 groupsize=self.groupsize,
+                quant_method=self.quant_method,
                 sym=self.sym,
                 sharded_infeatures=False,
             )
@@ -275,14 +285,26 @@ class GPTQWeightsLoader(WeightsLoader):
             quantize=self.quantize,
             sym=self.sym,
         ):
-            g_idx = weights.get_tensor(f"{prefix}.g_idx")
+            if not self.sym:
+                qzeros = weights.get_packed_sharded(
+                    f"{prefix}.qzeros", dim=1, block_sizes=block_sizes
+                )
+            else:
+                qzeros = None
+
+            if self.quant_method == "awq":
+                g_idx = None
+            else:
+                g_idx = weights.get_tensor(f"{prefix}.g_idx")
             return repack_gptq_for_marlin(
                 qweight=qweight,
                 scales=scales,
+                qzeros=qzeros,
                 g_idx=g_idx,
                 bits=self.bits,
                 desc_act=self.desc_act,
                 groupsize=self.groupsize,
+                quant_method=self.quant_method,
                 sym=self.sym,
                 sharded_infeatures=False,
             )
@@ -349,18 +371,31 @@ class GPTQWeightsLoader(WeightsLoader):
             quantize=self.quantize,
             sym=self.sym,
         ):
-            w = [weights.get_tensor(f"{p}.g_idx") for p in prefixes]
-            for w2 in w[1:]:
-                torch.testing.assert_close(w2, w[0])
-            g_idx = w[0]
+            if not self.sym:
+                qzeros = torch.cat(
+                    [weights.get_sharded(f"{p}.qzeros", dim=1) for p in prefixes], dim=1
+                )
+            else:
+                qzeros = None
+
+            if self.quant_method == "awq":
+                g_idx = None
+            else:
+                w = [weights.get_tensor(f"{p}.g_idx") for p in prefixes]
+                for w2 in w[1:]:
+                    torch.testing.assert_close(w2, w[0])
+                g_idx = w[0]
 
             return repack_gptq_for_marlin(
                 qweight=qweight,
                 scales=scales,
+                qzeros=qzeros,
                 g_idx=g_idx,
                 bits=self.bits,
                 desc_act=self.desc_act,
                 groupsize=self.groupsize,
+                quant_method=self.quant_method,
                 sym=self.sym,
                 sharded_infeatures=False,
             )
@@ -438,7 +473,19 @@ class GPTQWeightsLoader(WeightsLoader):
                 f"Cannot load `{self.quantize}` weight for GPTQ -> Marlin repacking, make sure the model is already quantized"
             )
 
-            g_idx = weights.get_sharded(f"{prefix}.g_idx", dim=0)
+            if not self.sym:
+                if self.desc_act or self.groupsize == -1:
+                    qzeros = weights.get_tensor(f"{prefix}.qzeros")
+                else:
+                    qzeros = weights.get_sharded(f"{prefix}.qzeros", dim=0)
+            else:
+                qzeros = None
+
+            if self.quant_method == "awq":
+                g_idx = None
+            else:
+                g_idx = weights.get_sharded(f"{prefix}.g_idx", dim=0)
 
             if self.desc_act or self.groupsize == -1:
                 scales = weights.get_tensor(f"{prefix}.scales")
             else:
@@ -449,10 +496,12 @@ class GPTQWeightsLoader(WeightsLoader):
             return repack_gptq_for_marlin(
                 qweight=qweight,
                 scales=scales,
+                qzeros=qzeros,
                 g_idx=g_idx,
                 bits=self.bits,
                 desc_act=self.desc_act,
                 groupsize=self.groupsize,
+                quant_method=self.quant_method,
                 sym=self.sym,
                 sharded_infeatures=sharded_in_features,
             )
@@ -1,6 +1,7 @@
 from dataclasses import dataclass
 from typing import List, Optional, Tuple, Union
 
+import numpy
 import torch
 import torch.nn as nn
 from loguru import logger
@@ -174,11 +175,12 @@ def can_use_gptq_marlin(
         SYSTEM == "cuda"
         and marlin_kernels is not None
         and has_sm_8_0
-        and quantize == "gptq"
-        and quant_method == "gptq"
+        and quantize in {"awq", "gptq"}
+        and quant_method in {"awq", "gptq"}
        and bits in GPTQ_MARLIN_BITS
         and groupsize in GPTQ_MARLIN_GROUP_SIZES
-        and sym
+        # We only support asymmetric quantization for AWQ.
+        and (sym or quant_method == "awq")
     )
@@ -234,6 +236,7 @@ class GPTQMarlinWeight(Weight):
     """
 
     qweight: torch.Tensor
+    qzeros: torch.Tensor
     scales: torch.Tensor
     g_idx: torch.Tensor
     perm: torch.Tensor
@@ -256,11 +259,13 @@ class GPTQMarlinWeight(Weight):
 def repack_gptq_for_marlin(
     *,
     qweight: torch.Tensor,
+    qzeros: Optional[torch.Tensor],
     scales: torch.Tensor,
-    g_idx: torch.Tensor,
+    g_idx: Optional[torch.Tensor],
     bits: int,
     desc_act: bool,
     groupsize: int,
+    quant_method: str,
     sym: bool,
     sharded_infeatures: bool,
 ) -> GPTQMarlinWeight:
@@ -279,30 +284,54 @@ def repack_gptq_for_marlin(
         raise RuntimeError(
             f"Repacking GPTQ weights with group size {groupsize} as Marlin is not supported, must be one of: {supported_sizes}"
         )
-    if not sym:
+    if not (sym or quant_method == "awq"):
         raise RuntimeError(
             "Repacking GPTQ weights with asymmetric quantization as Marlin is not supported."
         )
 
+    log_once(logger.info, f"Converting {quant_method} model to Marlin packing format.")
+
     weights_per_int = 32 // bits
-    in_features = qweight.shape[0] * weights_per_int
+    in_features = qweight.shape[0]
     out_features = qweight.shape[1]
 
+    # AWQ uses column packing, GPTQ uses row packing
+    if quant_method == "awq":
+        out_features *= weights_per_int
+    else:
+        in_features *= weights_per_int
+
     if in_features % groupsize != 0:
         raise ValueError(
             f"Number of input features ({in_features}) not divisible by group size ({groupsize})"
         )
 
-    if desc_act and groupsize != -1:
+    if g_idx is not None and desc_act and groupsize != -1:
         perm = torch.argsort(g_idx).to(torch.int)
         g_idx = g_idx[perm]
     else:
         perm = torch.empty(0, dtype=torch.int, device=qweight.device)
         g_idx = torch.empty(0, dtype=torch.int, device=qweight.device)
 
-    repacked = marlin_kernels.gptq_marlin_repack(
-        qweight, perm, in_features, out_features, bits
-    )
+    if quant_method == "awq":
+        repacked = marlin_kernels.awq_marlin_repack(
+            qweight, in_features, out_features, bits
+        )
+        if qzeros is not None:
+            qzeros = awq_to_marlin_zero_points(
+                qzeros,
+                in_features // groupsize,
+                out_features,
+                bits,
+            )
+
+    else:
+        repacked = marlin_kernels.gptq_marlin_repack(
+            qweight, perm, in_features, out_features, bits
+        )
+
+    if qzeros is None:
+        qzeros = torch.empty(0, dtype=torch.int, device=qweight.device)
+
     scales = permute_scales(scales)
 
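To make the AWQ/GPTQ packing difference above concrete, here is a small shape sketch for 4-bit weights (layer sizes are hypothetical):

    # 4-bit quantization: 8 quantized values per int32.
    bits = 4
    weights_per_int = 32 // bits

    # GPTQ packs along the rows (input dimension):
    #   qweight shape = (in_features // 8, out_features)
    gptq_qweight_shape = (4096 // weights_per_int, 11008)

    # AWQ packs along the columns (output dimension):
    #   qweight shape = (in_features, out_features // 8)
    awq_qweight_shape = (4096, 11008 // weights_per_int)

    # Both represent the same logical (4096, 11008) weight, which is what the
    # in_features/out_features adjustment in repack_gptq_for_marlin recovers.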
@@ -310,6 +339,7 @@ def repack_gptq_for_marlin(
 
     return GPTQMarlinWeight(
         qweight=repacked,
+        qzeros=qzeros,
         scales=scales,
         g_idx=g_idx,
         perm=perm,
@@ -343,6 +373,7 @@ class GPTQMarlinLinear(nn.Module):
         self.is_full_k = weight.is_full_k
 
         self.qweight = weight.qweight
+        self.qzeros = weight.qzeros
         self.scales = weight.scales
         self.g_idx = weight.g_idx
         self.perm = weight.perm
@@ -363,6 +394,7 @@ class GPTQMarlinLinear(nn.Module):
             A_flat,
             self.qweight,
             self.scales,
+            self.qzeros,
             self.g_idx,
             self.perm,
             self.workspace,
@@ -371,6 +403,7 @@ class GPTQMarlinLinear(nn.Module):
             self.scales.shape[1],
             A_flat.shape[1],
             self.is_full_k,
+            self.qzeros.numel() > 0,
         )
         C = C.reshape(A.shape[:-1] + (self.scales.shape[1],))
 
@@ -688,3 +721,116 @@ class MarlinLinear(nn.Module):
             C += self.bias
 
         return C
+
+
+# Functions below are from vLLM
+
+
+def get_pack_factor(bits: int) -> int:
+    if 32 % bits != 0:
+        raise ValueError(f"Cannot {bits} bit values into uint32")
+    return 32 // bits
+
+
+def pack_cols(
+    q_w: torch.Tensor,
+    num_bits: int,
+    size_k: int,
+    size_n: int,
+):
+    assert q_w.shape == (size_k, size_n)
+
+    pack_factor = get_pack_factor(num_bits)
+    assert size_n % pack_factor == 0
+
+    orig_device = q_w.device
+
+    q_w = q_w.cpu().numpy().astype(numpy.uint32)
+
+    q_res = numpy.zeros((size_k, size_n // pack_factor), dtype=numpy.uint32)
+
+    for i in range(pack_factor):
+        q_res |= q_w[:, i::pack_factor] << num_bits * i
+
+    q_res = torch.from_numpy(q_res.astype(numpy.int32)).to(orig_device)
+    q_res = q_res.contiguous()
+
+    return q_res
+
+
+def unpack_cols(
+    packed_q_w: torch.Tensor,
+    num_bits: int,
+    size_k: int,
+    size_n: int,
+):
+    pack_factor = get_pack_factor(num_bits)
+    assert size_n % pack_factor == 0
+    assert packed_q_w.shape == (
+        size_k,
+        size_n // pack_factor,
+    ), "packed_q_w.shape = {} size_k = {}, size_n = {} pack_Factor = {}".format(
+        packed_q_w.shape, size_k, size_n, pack_factor
+    )
+
+    orig_device = packed_q_w.device
+
+    packed_q_w_cpu = packed_q_w.cpu().numpy().astype(numpy.uint32)
+    q_res = numpy.zeros((size_k, size_n), dtype=numpy.uint32)
+
+    mask = (1 << num_bits) - 1
+    for i in range(pack_factor):
+        vals = packed_q_w_cpu & mask
+        packed_q_w_cpu >>= num_bits
+        q_res[:, i::pack_factor] = vals
+
+    q_res = torch.from_numpy(q_res.astype(numpy.int32)).to(orig_device)
+    q_res = q_res.contiguous()
+
+    return q_res
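pack_cols and unpack_cols are inverses over the column dimension; a quick round-trip sketch, assuming the two helpers above are in scope and the sizes are arbitrary:

    import torch

    num_bits, size_k, size_n = 4, 8, 16
    # Values must fit in num_bits, hence the upper bound of 2**num_bits.
    q_w = torch.randint(0, 2**num_bits, (size_k, size_n), dtype=torch.int32)

    packed = pack_cols(q_w, num_bits, size_k, size_n)       # shape (8, 16 // 8)
    unpacked = unpack_cols(packed, num_bits, size_k, size_n)
    assert torch.equal(unpacked, q_w)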
+
+
+def marlin_zero_points(
+    zp: torch.Tensor, size_k: int, size_n: int, num_bits: int
+) -> torch.Tensor:
+    # Permute zero-points in a similar way to scales, but do not use the
+    # "single" permutation, since zero-points are applied on every MMA
+    zp = zp.reshape((-1, len(_scale_perm)))[:, _scale_perm]
+
+    # Interleave column dim (for the dequantize code) and pack it to int32
+    if num_bits == 4:
+        interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7])
+    elif num_bits == 8:
+        interleave = numpy.array([0, 2, 1, 3])
+    else:
+        raise Exception("num_bits must be 4 or 8, got {}".format(num_bits))
+
+    zp = zp.reshape((-1, len(interleave)))[:, interleave].ravel()
+    zp = zp.reshape((-1, size_n)).contiguous()
+    zp = pack_cols(zp, num_bits, size_k, size_n)
+
+    return zp
+
+
+def awq_to_marlin_zero_points(
+    q_zp_packed: torch.Tensor, size_k: int, size_n: int, num_bits: int
+) -> torch.Tensor:
+    # AWQ zero-points are quantized and packed on the column dim.
+    # In addition, the values are permuted based on dequantizer.
+    # Here we undo both of these, and then apply marlin permutation
+    # and pack it back.
+    q_zp = unpack_cols(q_zp_packed, num_bits, size_k, size_n)
+
+    # Undo interleaving (use argsort(..) to get inverse perm)
+    if num_bits == 4:
+        undo_interleave = numpy.argsort(numpy.array([0, 2, 4, 6, 1, 3, 5, 7]))
+    elif num_bits == 8:
+        undo_interleave = numpy.argsort(numpy.array([0, 2, 1, 3]))
+    else:
+        raise Exception("num_bits must be 4 or 8, got {}".format(num_bits))
+
+    q_zp = q_zp.reshape((-1, len(undo_interleave)))[:, undo_interleave].ravel()
+    q_zp = q_zp.reshape((-1, size_n)).contiguous()
+
+    marlin_zp = marlin_zero_points(q_zp, size_k, size_n, num_bits)
+    return marlin_zp
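In repack_gptq_for_marlin above, the AWQ zero points are converted with sizes derived from the layer geometry; a minimal sketch with hypothetical layer sizes (qzeros stands for the packed AWQ zero-point tensor loaded from the checkpoint):

    # Hypothetical 4-bit AWQ layer: group size 128, 4096 -> 11008 projection.
    bits, groupsize = 4, 128
    in_features, out_features = 4096, 11008
    pack_factor = 32 // bits

    num_groups = in_features // groupsize   # 32
    # AWQ qzeros are column-packed: (num_groups, out_features // pack_factor).
    marlin_zp = awq_to_marlin_zero_points(qzeros, num_groups, out_features, bits)
    # The conversion unpacks, undoes the AWQ interleave, applies the Marlin
    # scale permutation and repacks, so the output keeps the same
    # (num_groups, out_features // pack_factor) shape expected by gptq_marlin_gemm.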
@@ -34,7 +34,7 @@ def _get_quantizer_config(model_id, revision):
     groupsize = -1
     quant_method = "gptq"
     checkpoint_format = None
-    sym = True
+    sym = False
     desc_act = False
 
     filename = "config.json"
@@ -52,12 +52,17 @@ def _get_quantizer_config(model_id, revision):
                 activation_scale_ub=data["quantization_config"]["activation_scale_ub"]
             )
 
+        if "zero_point" in data["quantization_config"]:
+            sym = not data["quantization_config"]["zero_point"]
+            quant_method = "awq"
+        elif "sym" in data["quantization_config"]:
+            sym = data["quantization_config"]["sym"]
+
         bits = data["quantization_config"]["bits"]
         groupsize = data["quantization_config"]["group_size"]
         # Order is important here, desc_act is missing on some real models
         quant_method = data["quantization_config"]["quant_method"]
         checkpoint_format = data["quantization_config"].get("checkpoint_format")
-        sym = data["quantization_config"]["sym"]
         desc_act = data["quantization_config"]["desc_act"]
     except Exception:
         filename = "quantize_config.json"
@@ -72,7 +77,13 @@ def _get_quantizer_config(model_id, revision):
             data = json.load(f)
             bits = data["bits"]
             groupsize = data["group_size"]
-            sym = data["sym"]
+
+            if "zero_point" in data:
+                sym = not data["zero_point"]
+                quant_method = "awq"
+            elif "sym" in data:
+                sym = data["sym"]
+
             desc_act = data["desc_act"]
             if "version" in data and data["version"] == "GEMM":
                 quant_method = "awq"
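For reference, the kind of AWQ quantization section this parser now recognizes looks roughly like the following (field values are illustrative; "zero_point": True is what maps to sym = False and quant_method = "awq" above):

    # Illustrative AWQ quantization_config, written here as a Python dict.
    quantization_config = {
        "quant_method": "awq",
        "bits": 4,
        "group_size": 128,
        "zero_point": True,
        "version": "GEMM",
    }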