From 564199bab3d5151bb9b02c0a6fb08f4e1d58ff1c Mon Sep 17 00:00:00 2001 From: OlivierDehaene Date: Thu, 21 Dec 2023 17:25:22 +0100 Subject: [PATCH] feat: update exllamav2 kernels (#1370) Co-authored-by: Nicolas Patry --- .../exllamav2_kernels/config.h | 2 + .../exllamav2_kernels/cuda/q_gemm.cu | 98 +++--- .../exllamav2_kernels/cuda/q_gemm.cuh | 5 +- .../exllamav2_kernels/cuda/q_gemm_kernel.cuh | 317 +++++++++++------- .../cuda/q_gemm_kernel_gptq.cuh | 142 +++++--- .../exllamav2_kernels/cuda/q_matrix.cu | 70 ++-- .../exllamav2_kernels/cuda/q_matrix.cuh | 4 +- .../exllamav2_kernels/cuda/quant/qdq_util.cuh | 2 + .../exllamav2_kernels/cuda/util.cuh | 12 + .../exllamav2_kernels/ext.cpp | 5 + server/tests/utils/test_hub.py | 8 +- .../flash_santacoder_modeling.py | 2 +- .../utils/gptq/exllamav2.py | 33 ++ server/text_generation_server/utils/hub.py | 32 +- server/text_generation_server/utils/layers.py | 6 +- server/text_generation_server/utils/log.py | 6 + .../text_generation_server/utils/weights.py | 36 +- 17 files changed, 525 insertions(+), 255 deletions(-) create mode 100644 server/text_generation_server/utils/log.py diff --git a/server/exllamav2_kernels/exllamav2_kernels/config.h b/server/exllamav2_kernels/exllamav2_kernels/config.h index 86baaf41..32a1a37d 100644 --- a/server/exllamav2_kernels/exllamav2_kernels/config.h +++ b/server/exllamav2_kernels/exllamav2_kernels/config.h @@ -2,6 +2,7 @@ #define _config_h #define MAX_Q_GEMM_ROWS 50 +#define MAX_Q_GEMM_WEIGHTS 4 // must be <= MAX_Q_GEMM_ROWS #define QMODE_2BIT 1 #define QMODE_3BIT 1 @@ -10,4 +11,5 @@ #define QMODE_6BIT 0 #define QMODE_8BIT 0 + #endif diff --git a/server/exllamav2_kernels/exllamav2_kernels/cuda/q_gemm.cu b/server/exllamav2_kernels/exllamav2_kernels/cuda/q_gemm.cu index 351b9cd5..b4e4cf22 100644 --- a/server/exllamav2_kernels/exllamav2_kernels/cuda/q_gemm.cu +++ b/server/exllamav2_kernels/exllamav2_kernels/cuda/q_gemm.cu @@ -10,16 +10,19 @@ #include "quant/qdq_6.cuh" #include "quant/qdq_8.cuh" -#define BLOCK_KN_SIZE 128 -#define BLOCK_M_SIZE_MAX 8 -#define MAX_GROUPS_IN_BLOCK (BLOCK_KN_SIZE / 32) +#define GPTQ_BLOCK_KN_SIZE 128 +#define GPTQ_BLOCK_M_SIZE_MAX 8 +#define GPTQ_MAX_GROUPS_IN_BLOCK (GPTQ_BLOCK_KN_SIZE / 32) + +#define EXL2_BLOCK_KN_SIZE 64 +#define EXL2_BLOCK_M_SIZE_MAX 8 +#define EXL2_MAX_GROUPS_IN_BLOCK (EXL2_BLOCK_KN_SIZE / 32) + #define CLEAR_N_SIZE 256 #include "q_gemm_kernel.cuh" #include "q_gemm_kernel_gptq.cuh" -#include "compat_gemm.cuh" - void gemm_half_q_half_cuda_part ( const half* a, @@ -29,20 +32,23 @@ void gemm_half_q_half_cuda_part int size_n, int size_k, int m_count, - bool clear + bool clear, + const half* r_weights, + int r_weights_stride, + bool mul_r_weights ) { if (!b->is_gptq) { dim3 blockDim, gridDim; - blockDim.x = BLOCK_KN_SIZE; + blockDim.x = EXL2_BLOCK_KN_SIZE; blockDim.y = 1; blockDim.z = 1; - gridDim.x = DIVIDE(size_n, BLOCK_KN_SIZE * 4); + gridDim.x = DIVIDE(size_n, EXL2_BLOCK_KN_SIZE * 4); gridDim.y = DIVIDE(size_m, m_count); - gridDim.z = DIVIDE(size_k, BLOCK_KN_SIZE); + gridDim.z = DIVIDE(size_k, EXL2_BLOCK_KN_SIZE); - fp_gemm_half_q_half_kernel kernel = pick_gemm_half_q_half_kernel(true, m_count); + fp_gemm_half_q_half_kernel kernel = pick_gemm_half_q_half_kernel(m_count, r_weights != NULL, mul_r_weights); kernel<<>> ( @@ -55,7 +61,7 @@ void gemm_half_q_half_cuda_part size_n, size_k, b->groups, - b->groupsize, + b->cuda_q_group_map, b->cuda_q_perm, b->rows_8, b->rows_6, @@ -63,24 +69,27 @@ void gemm_half_q_half_cuda_part b->rows_4, b->rows_3, b->rows_2, - clear + clear, + r_weights, + r_weights_stride ); } else { dim3 blockDim, gridDim; - blockDim.x = BLOCK_KN_SIZE; + blockDim.x = GPTQ_BLOCK_KN_SIZE; blockDim.y = 1; blockDim.z = 1; - gridDim.x = DIVIDE(size_n, BLOCK_KN_SIZE * 4); + gridDim.x = DIVIDE(size_n, GPTQ_BLOCK_KN_SIZE * 4); gridDim.y = DIVIDE(size_m, m_count); - gridDim.z = DIVIDE(size_k, BLOCK_KN_SIZE); + gridDim.z = DIVIDE(size_k, GPTQ_BLOCK_KN_SIZE); - fp_gemm_half_q_half_gptq_kernel kernel = pick_gemm_half_q_half_gptq_kernel(true, m_count); + fp_gemm_half_q_half_gptq_kernel kernel = pick_gemm_half_q_half_gptq_kernel(m_count, r_weights != NULL, mul_r_weights); -// DBGX((uint64_t) b->cuda_q_perm); -// DBGI(b->rows_4); -// DBGI(b->height); +// DBGX((uint64_t) r_weights); +// if (r_weights) +// print_global_mem(r_weights, 1, 1, 1); +// DBGI(r_weights_stride); kernel<<>> ( @@ -93,10 +102,12 @@ void gemm_half_q_half_cuda_part size_n, size_k, b->groups, - b->groupsize, + b->gptq_groupsize, b->cuda_q_perm, b->rows_4, - clear + clear, + r_weights, + r_weights_stride ); } } @@ -112,13 +123,14 @@ void gemm_half_q_half_cuda int size_k, bool clear, half* temp_dq, - bool force_cuda + bool force_cuda, + const half* r_weights, + const int r_weights_stride, + bool mul_r_weights ) { if (size_m > MAX_Q_GEMM_ROWS && !force_cuda) { - //printf("cublas\n"); - // Reconstruct FP16 matrix, then cuBLAS if (!temp_dq) temp_dq = b->temp_dq; @@ -139,12 +151,12 @@ void gemm_half_q_half_cuda //const float alpha = 1.0f; //const float beta = clear ? 0.0f : 1.0f; //cublasSgemmEx(cublas_handle, - // CUBLAS_OP_N, - // CUBLAS_OP_N, - // size_n, size_m, size_k, - // &alpha, temp_dq, CUDA_R_16F, size_n, - // a, CUDA_R_16F, size_k, - // &beta, c, CUDA_R_16F, size_n); + // CUBLAS_OP_N, + // CUBLAS_OP_N, + // size_n, size_m, size_k, + // &alpha, temp_dq, CUDA_R_16F, size_n, + // a, CUDA_R_16F, size_k, + // &beta, c, CUDA_R_16F, size_n); //const float alpha = 1.0f; //const float beta = clear ? 0.0f : 1.0f; @@ -158,24 +170,21 @@ void gemm_half_q_half_cuda } else { - //printf("cuda\n"); - // Quantized matmul - //if (clear) clear_tensor_cuda(c, size_m, size_n); - - int max_chunks = size_m / BLOCK_M_SIZE_MAX; - int last_chunk = max_chunks * BLOCK_M_SIZE_MAX; + int block_m_size_max = b->is_gptq ? GPTQ_BLOCK_M_SIZE_MAX : EXL2_BLOCK_M_SIZE_MAX; + int max_chunks = size_m / block_m_size_max; + int last_chunk = max_chunks * block_m_size_max; int last_chunk_size = size_m - last_chunk; if (max_chunks) { - gemm_half_q_half_cuda_part(a, b, c, last_chunk, size_n, size_k, BLOCK_M_SIZE_MAX, clear); + gemm_half_q_half_cuda_part(a, b, c, last_chunk, size_n, size_k, block_m_size_max, clear, r_weights, r_weights_stride, mul_r_weights); } if (last_chunk_size) { - gemm_half_q_half_cuda_part(a + last_chunk * size_k, b, c + last_chunk * size_n, last_chunk_size, size_n, size_k, last_chunk_size, clear); + gemm_half_q_half_cuda_part(a + last_chunk * size_k, b, c + last_chunk * size_n, last_chunk_size, size_n, size_k, last_chunk_size, clear, r_weights, r_weights_stride, mul_r_weights); } } } @@ -201,11 +210,10 @@ void clear_tensor_cuda int size_n ) { - return; - dim3 blockDim, gridDim; - blockDim.x = CLEAR_N_SIZE; - blockDim.y = 1; - gridDim.x = DIVIDE(size_n / 8, CLEAR_N_SIZE); - gridDim.y = size_m; - clear_kernel<<>>(c, size_m, size_n); +// dim3 blockDim, gridDim; +// blockDim.x = CLEAR_N_SIZE; +// blockDim.y = 1; +// gridDim.x = DIVIDE(size_n / 8, CLEAR_N_SIZE); +// gridDim.y = size_m; +// clear_kernel<<>>(c, size_m, size_n); } diff --git a/server/exllamav2_kernels/exllamav2_kernels/cuda/q_gemm.cuh b/server/exllamav2_kernels/exllamav2_kernels/cuda/q_gemm.cuh index c69f1a70..b643f915 100644 --- a/server/exllamav2_kernels/exllamav2_kernels/cuda/q_gemm.cuh +++ b/server/exllamav2_kernels/exllamav2_kernels/cuda/q_gemm.cuh @@ -20,7 +20,10 @@ void gemm_half_q_half_cuda int size_k, bool clear = false, half* reconstruct = NULL, - bool force_cuda = false + bool force_cuda = false, + const half* r_weights = NULL, + const int r_weights_stride = 0, + bool mul_r_weights = false ); void clear_tensor_cuda diff --git a/server/exllamav2_kernels/exllamav2_kernels/cuda/q_gemm_kernel.cuh b/server/exllamav2_kernels/exllamav2_kernels/cuda/q_gemm_kernel.cuh index 0b899a84..9cd2ba01 100644 --- a/server/exllamav2_kernels/exllamav2_kernels/cuda/q_gemm_kernel.cuh +++ b/server/exllamav2_kernels/exllamav2_kernels/cuda/q_gemm_kernel.cuh @@ -1,8 +1,5 @@ #include "compat.cuh" -#include -#include - __forceinline__ __device__ half2 dot22_8(half2(&dq)[4], const half* a_ptr, const half2 g_result, const half qs_h) { half2 result = {}; @@ -60,6 +57,47 @@ __forceinline__ __device__ float dot22_32_f(half2(&dq)[16], const half* a_ptr, c return fma(result_f, qs_f, g_result); } +__forceinline__ __device__ half dot22_8_h(half2(&dq)[4], const half* a_ptr, const half g_result, const half qs_h) +{ + // Use FP32 accumulator to avoid potential overflow since unscaled weights are in the range -128..127 + + float result = {}; + #pragma unroll + for (int i = 0; i < 4; i++) + { + half2 w01 = dq[i]; + float w0 = __low2float(w01); + float w1 = __high2float(w01); + float x0 = __half2float(*a_ptr++); + float x1 = __half2float(*a_ptr++); + result = fma(w0, x0, result); + result = fma(w1, x1, result); + } + float qs = __half2float(qs_h); + result *= qs; + half result_h = __float2half_rn(result); + return __hadd(result_h, g_result); +} + +__forceinline__ __device__ half dot22_16_h(half2(&dq)[8], const half* a_ptr, const half g_result, const half qs_h) +{ + half2 result = {}; + const half2* a2_ptr = (const half2*)a_ptr; + #pragma unroll + for (int i = 0; i < 8; i++) result = __hfma2(dq[i], *a2_ptr++, result); + half result_h = __hadd(__low2half(result), __high2half(result)); + return __hfma(result_h, qs_h, g_result); +} + +__forceinline__ __device__ half dot22_32_h(half2(&dq)[16], const half* a_ptr, const half g_result, const half qs_h) +{ + half2 result = {}; + const half2* a2_ptr = (const half2*)a_ptr; + #pragma unroll + for (int i = 0; i < 16; i += 1) result = __hfma2(dq[i], *a2_ptr++, result); + half result_h = __hadd(__low2half(result), __high2half(result)); + return __hfma(result_h, qs_h, g_result); +} typedef void (*fp_gemm_half_q_half_kernel) @@ -73,7 +111,7 @@ typedef void (*fp_gemm_half_q_half_kernel) const int, const int, const int, - const int, + const uint16_t*, const uint16_t*, const int, const int, @@ -81,10 +119,12 @@ typedef void (*fp_gemm_half_q_half_kernel) const int, const int, const int, - const bool + const bool, + const half*, + const int ); -template +template __global__ void gemm_half_q_half_kernel ( const half* __restrict__ a, @@ -96,7 +136,7 @@ __global__ void gemm_half_q_half_kernel const int size_n, const int size_k, const int groups, - const int groupsize, + const uint16_t* __restrict__ b_q_group_map, const uint16_t* __restrict__ b_q_perm, const int rows_8, const int rows_6, @@ -104,7 +144,9 @@ __global__ void gemm_half_q_half_kernel const int rows_4, const int rows_3, const int rows_2, - const bool clear + const bool clear, + const half* r_weights, + const int r_weights_stride ) { MatrixView_half a_(a, size_m, size_k); @@ -115,18 +157,34 @@ __global__ void gemm_half_q_half_kernel // Block - int offset_n = blockIdx.x * BLOCK_KN_SIZE * 4; + int offset_n = blockIdx.x * EXL2_BLOCK_KN_SIZE * 4; int offset_m = blockIdx.y * m_count; - int offset_k = blockIdx.z * BLOCK_KN_SIZE; + int offset_k = blockIdx.z * EXL2_BLOCK_KN_SIZE; - int end_n = min(offset_n + BLOCK_KN_SIZE * 4, size_n); + int end_n = min(offset_n + EXL2_BLOCK_KN_SIZE * 4, size_n); int end_m = min(offset_m + m_count, size_m); - int end_k = min(offset_k + BLOCK_KN_SIZE, size_k); + int end_k = min(offset_k + EXL2_BLOCK_KN_SIZE, size_k); int n = offset_n + t * 4; + // Read weights + + half_uint16 weights[MAX_Q_GEMM_WEIGHTS]; + if constexpr (use_r_weights) + { + uint16_t any_w = 0; + const half* w_ptr = r_weights; + for (int m = 0; m < m_count; ++m) + { + weights[m].as_half = *w_ptr; + w_ptr += r_weights_stride; + any_w |= weights[m].as_uint16; + } + if (!any_w) return; // Early exit if all weights are zero -- does not zero output (!!!) + } + // Preload block_a - __shared__ half block_a[m_count][BLOCK_KN_SIZE]; + __shared__ half block_a[m_count][EXL2_BLOCK_KN_SIZE]; if (offset_k + t < end_k) { @@ -135,6 +193,7 @@ __global__ void gemm_half_q_half_kernel const half* a_ptr = a_.item_ptr(offset_m + m, 0); half* block_a_ptr = block_a[m]; half a0 = a_ptr[b_q_perm[offset_k + t]]; +// half a0 = a_ptr[offset_k + t]; block_a_ptr[t] = a0; } } @@ -153,14 +212,19 @@ __global__ void gemm_half_q_half_kernel // Find initial group - int group = offset_k / groupsize; + //int group = offset_k / groupsize; + int group = b_q_group_map[offset_k * 2]; + +// if (offset_m == 0 && t == 0) +// DBGI2(offset_k, group); // Preload scales - float scales[MAX_GROUPS_IN_BLOCK][4]; + half scales[EXL2_MAX_GROUPS_IN_BLOCK][4]; - int groups_in_block = DIVIDE((end_k - offset_k), groupsize); - for (int g = 0; g < groups_in_block; g++) + //int groups_in_block = DIVIDE((end_k - offset_k), groupsize); + int temp_k = offset_k; + for (int g = 0; temp_k < end_k; g++) { int qscales[4]; b_q_scale_.item4(qscales, group + g, n); @@ -168,11 +232,12 @@ __global__ void gemm_half_q_half_kernel qscales[1]++; qscales[2]++; qscales[3]++; - float maxscale = __half2float(b_q_scale_max[group + g]); - scales[g][0] = __int2float_rn(qscales[0] * qscales[0]) * maxscale; - scales[g][1] = __int2float_rn(qscales[1] * qscales[1]) * maxscale; - scales[g][2] = __int2float_rn(qscales[2] * qscales[2]) * maxscale; - scales[g][3] = __int2float_rn(qscales[3] * qscales[3]) * maxscale; + half maxscale = b_q_scale_max[group + g]; + scales[g][0] = __hmul(__int2half_rn(qscales[0] * qscales[0]), maxscale); + scales[g][1] = __hmul(__int2half_rn(qscales[1] * qscales[1]), maxscale); + scales[g][2] = __hmul(__int2half_rn(qscales[2] * qscales[2]), maxscale); + scales[g][3] = __hmul(__int2half_rn(qscales[3] * qscales[3]), maxscale); + temp_k += b_q_group_map[temp_k * 2 + 1]; } // a, b offset @@ -193,20 +258,20 @@ __global__ void gemm_half_q_half_kernel const uint32_t* b_ptr = b_q_weight + qk * size_n + n; const half* a_ptr = &block_a[0][0]; - int a_stride = BLOCK_KN_SIZE; + int a_stride = EXL2_BLOCK_KN_SIZE; // Initial group int scales_idx = 0; - float qs_f0 = scales[scales_idx][0]; - float qs_f1 = scales[scales_idx][1]; - float qs_f2 = scales[scales_idx][2]; - float qs_f3 = scales[scales_idx][3]; - int nextgroup = offset_k + groupsize; + half qs_h0 = scales[scales_idx][0]; + half qs_h1 = scales[scales_idx][1]; + half qs_h2 = scales[scales_idx][2]; + half qs_h3 = scales[scales_idx][3]; + int nextgroup = offset_k + b_q_group_map[offset_k * 2 + 1]; // Column result - float block_c[m_count][4] = {}; + half block_c[m_count][4] = {}; // Dequantize groups @@ -218,11 +283,11 @@ __global__ void gemm_half_q_half_kernel { group++; scales_idx++; - qs_f0 = scales[scales_idx][0]; - qs_f1 = scales[scales_idx][1]; - qs_f2 = scales[scales_idx][2]; - qs_f3 = scales[scales_idx][3]; - nextgroup += groupsize; + qs_h0 = scales[scales_idx][0]; + qs_h1 = scales[scales_idx][1]; + qs_h2 = scales[scales_idx][2]; + qs_h3 = scales[scales_idx][3]; + nextgroup += b_q_group_map[k * 2 + 1]; } #pragma unroll @@ -240,10 +305,11 @@ __global__ void gemm_half_q_half_kernel for (int m = 0; m < m_count; m++) { - block_c[m][0] = dot22_8_f(dq[0], a_ptr + m * a_stride, block_c[m][0], qs_f0); - block_c[m][1] = dot22_8_f(dq[1], a_ptr + m * a_stride, block_c[m][1], qs_f1); - block_c[m][2] = dot22_8_f(dq[2], a_ptr + m * a_stride, block_c[m][2], qs_f2); - block_c[m][3] = dot22_8_f(dq[3], a_ptr + m * a_stride, block_c[m][3], qs_f3); + if constexpr (use_r_weights) { if (!weights[m].as_uint16) continue; } + block_c[m][0] = dot22_8_h(dq[0], a_ptr + m * a_stride, block_c[m][0], qs_h0); + block_c[m][1] = dot22_8_h(dq[1], a_ptr + m * a_stride, block_c[m][1], qs_h1); + block_c[m][2] = dot22_8_h(dq[2], a_ptr + m * a_stride, block_c[m][2], qs_h2); + block_c[m][3] = dot22_8_h(dq[3], a_ptr + m * a_stride, block_c[m][3], qs_h3); } a_ptr += 8; } @@ -256,11 +322,11 @@ __global__ void gemm_half_q_half_kernel { group++; scales_idx++; - qs_f0 = scales[scales_idx][0]; - qs_f1 = scales[scales_idx][1]; - qs_f2 = scales[scales_idx][2]; - qs_f3 = scales[scales_idx][3]; - nextgroup += groupsize; + qs_h0 = scales[scales_idx][0]; + qs_h1 = scales[scales_idx][1]; + qs_h2 = scales[scales_idx][2]; + qs_h3 = scales[scales_idx][3]; + nextgroup += b_q_group_map[k * 2 + 1]; } #pragma unroll @@ -279,10 +345,11 @@ __global__ void gemm_half_q_half_kernel for (int m = 0; m < m_count; m++) { - block_c[m][0] = dot22_16_f(dq[0], a_ptr + m * a_stride, block_c[m][0], qs_f0); - block_c[m][1] = dot22_16_f(dq[1], a_ptr + m * a_stride, block_c[m][1], qs_f1); - block_c[m][2] = dot22_16_f(dq[2], a_ptr + m * a_stride, block_c[m][2], qs_f2); - block_c[m][3] = dot22_16_f(dq[3], a_ptr + m * a_stride, block_c[m][3], qs_f3); + if constexpr (use_r_weights) { if (!weights[m].as_uint16) continue; } + block_c[m][0] = dot22_16_h(dq[0], a_ptr + m * a_stride, block_c[m][0], qs_h0); + block_c[m][1] = dot22_16_h(dq[1], a_ptr + m * a_stride, block_c[m][1], qs_h1); + block_c[m][2] = dot22_16_h(dq[2], a_ptr + m * a_stride, block_c[m][2], qs_h2); + block_c[m][3] = dot22_16_h(dq[3], a_ptr + m * a_stride, block_c[m][3], qs_h3); } a_ptr += 16; } @@ -295,11 +362,11 @@ __global__ void gemm_half_q_half_kernel { group++; scales_idx++; - qs_f0 = scales[scales_idx][0]; - qs_f1 = scales[scales_idx][1]; - qs_f2 = scales[scales_idx][2]; - qs_f3 = scales[scales_idx][3]; - nextgroup += groupsize; + qs_h0 = scales[scales_idx][0]; + qs_h1 = scales[scales_idx][1]; + qs_h2 = scales[scales_idx][2]; + qs_h3 = scales[scales_idx][3]; + nextgroup += b_q_group_map[k * 2 + 1]; } #pragma unroll @@ -320,10 +387,11 @@ __global__ void gemm_half_q_half_kernel for (int m = 0; m < m_count; m++) { - block_c[m][0] = dot22_32_f(dq[0], a_ptr + m * a_stride, block_c[m][0], qs_f0); - block_c[m][1] = dot22_32_f(dq[1], a_ptr + m * a_stride, block_c[m][1], qs_f1); - block_c[m][2] = dot22_32_f(dq[2], a_ptr + m * a_stride, block_c[m][2], qs_f2); - block_c[m][3] = dot22_32_f(dq[3], a_ptr + m * a_stride, block_c[m][3], qs_f3); + if constexpr (use_r_weights) { if (!weights[m].as_uint16) continue; } + block_c[m][0] = dot22_32_h(dq[0], a_ptr + m * a_stride, block_c[m][0], qs_h0); + block_c[m][1] = dot22_32_h(dq[1], a_ptr + m * a_stride, block_c[m][1], qs_h1); + block_c[m][2] = dot22_32_h(dq[2], a_ptr + m * a_stride, block_c[m][2], qs_h2); + block_c[m][3] = dot22_32_h(dq[3], a_ptr + m * a_stride, block_c[m][3], qs_h3); } a_ptr += 32; } @@ -337,11 +405,11 @@ __global__ void gemm_half_q_half_kernel { group++; scales_idx++; - qs_f0 = scales[scales_idx][0]; - qs_f1 = scales[scales_idx][1]; - qs_f2 = scales[scales_idx][2]; - qs_f3 = scales[scales_idx][3]; - nextgroup += groupsize; + qs_h0 = scales[scales_idx][0]; + qs_h1 = scales[scales_idx][1]; + qs_h2 = scales[scales_idx][2]; + qs_h3 = scales[scales_idx][3]; + nextgroup += b_q_group_map[k * 2 + 1]; } #pragma unroll @@ -358,10 +426,11 @@ __global__ void gemm_half_q_half_kernel for (int m = 0; m < m_count; m++) { - block_c[m][0] = dot22_8_f(dq[0], a_ptr + m * a_stride, block_c[m][0], qs_f0); - block_c[m][1] = dot22_8_f(dq[1], a_ptr + m * a_stride, block_c[m][1], qs_f1); - block_c[m][2] = dot22_8_f(dq[2], a_ptr + m * a_stride, block_c[m][2], qs_f2); - block_c[m][3] = dot22_8_f(dq[3], a_ptr + m * a_stride, block_c[m][3], qs_f3); + if constexpr (use_r_weights) { if (!weights[m].as_uint16) continue; } + block_c[m][0] = dot22_8_h(dq[0], a_ptr + m * a_stride, block_c[m][0], qs_h0); + block_c[m][1] = dot22_8_h(dq[1], a_ptr + m * a_stride, block_c[m][1], qs_h1); + block_c[m][2] = dot22_8_h(dq[2], a_ptr + m * a_stride, block_c[m][2], qs_h2); + block_c[m][3] = dot22_8_h(dq[3], a_ptr + m * a_stride, block_c[m][3], qs_h3); } a_ptr += 8; } @@ -374,11 +443,11 @@ __global__ void gemm_half_q_half_kernel { group++; scales_idx++; - qs_f0 = scales[scales_idx][0]; - qs_f1 = scales[scales_idx][1]; - qs_f2 = scales[scales_idx][2]; - qs_f3 = scales[scales_idx][3]; - nextgroup += groupsize; + qs_h0 = scales[scales_idx][0]; + qs_h1 = scales[scales_idx][1]; + qs_h2 = scales[scales_idx][2]; + qs_h3 = scales[scales_idx][3]; + nextgroup += b_q_group_map[k * 2 + 1]; } #pragma unroll @@ -397,10 +466,11 @@ __global__ void gemm_half_q_half_kernel for (int m = 0; m < m_count; m++) { - block_c[m][0] = dot22_32_f(dq[0], a_ptr + m * a_stride, block_c[m][0], qs_f0); - block_c[m][1] = dot22_32_f(dq[1], a_ptr + m * a_stride, block_c[m][1], qs_f1); - block_c[m][2] = dot22_32_f(dq[2], a_ptr + m * a_stride, block_c[m][2], qs_f2); - block_c[m][3] = dot22_32_f(dq[3], a_ptr + m * a_stride, block_c[m][3], qs_f3); + if constexpr (use_r_weights) { if (!weights[m].as_uint16) continue; } + block_c[m][0] = dot22_32_h(dq[0], a_ptr + m * a_stride, block_c[m][0], qs_h0); + block_c[m][1] = dot22_32_h(dq[1], a_ptr + m * a_stride, block_c[m][1], qs_h1); + block_c[m][2] = dot22_32_h(dq[2], a_ptr + m * a_stride, block_c[m][2], qs_h2); + block_c[m][3] = dot22_32_h(dq[3], a_ptr + m * a_stride, block_c[m][3], qs_h3); } a_ptr += 32; } @@ -413,15 +483,15 @@ __global__ void gemm_half_q_half_kernel { group++; scales_idx++; - qs_f0 = scales[scales_idx][0]; - qs_f1 = scales[scales_idx][1]; - qs_f2 = scales[scales_idx][2]; - qs_f3 = scales[scales_idx][3]; - nextgroup += groupsize; + qs_h0 = scales[scales_idx][0]; + qs_h1 = scales[scales_idx][1]; + qs_h2 = scales[scales_idx][2]; + qs_h3 = scales[scales_idx][3]; + nextgroup += b_q_group_map[k * 2 + 1]; } #pragma unroll - for (int j = 0; j < 2; j++) + for (int j = 0; j < 1; j++) { int4 load_int4[1]; load_int4[0] = *((int4*) b_ptr); b_ptr += size_n; @@ -434,15 +504,16 @@ __global__ void gemm_half_q_half_kernel for (int m = 0; m < m_count; m++) { - block_c[m][0] = dot22_16_f(dq[0], a_ptr + m * a_stride, block_c[m][0], qs_f0); - block_c[m][1] = dot22_16_f(dq[1], a_ptr + m * a_stride, block_c[m][1], qs_f1); - block_c[m][2] = dot22_16_f(dq[2], a_ptr + m * a_stride, block_c[m][2], qs_f2); - block_c[m][3] = dot22_16_f(dq[3], a_ptr + m * a_stride, block_c[m][3], qs_f3); + if constexpr (use_r_weights) { if (!weights[m].as_uint16) continue; } + block_c[m][0] = dot22_16_h(dq[0], a_ptr + m * a_stride, block_c[m][0], qs_h0); + block_c[m][1] = dot22_16_h(dq[1], a_ptr + m * a_stride, block_c[m][1], qs_h1); + block_c[m][2] = dot22_16_h(dq[2], a_ptr + m * a_stride, block_c[m][2], qs_h2); + block_c[m][3] = dot22_16_h(dq[3], a_ptr + m * a_stride, block_c[m][3], qs_h3); } a_ptr += 16; } - k += 32; + k += 16; } // Accumulate column sums in c @@ -450,38 +521,60 @@ __global__ void gemm_half_q_half_kernel for (int m = 0; m < m_count; m++) { half2* out = (half2*)c_.item_ptr(offset_m + m, n); - half2 result01 = __halves2half2(__float2half_rn(block_c[m][0]), __float2half_rn(block_c[m][1])); - half2 result23 = __halves2half2(__float2half_rn(block_c[m][2]), __float2half_rn(block_c[m][3])); + half2 result01 = __halves2half2(block_c[m][0], block_c[m][1]); + half2 result23 = __halves2half2(block_c[m][2], block_c[m][3]); + + if constexpr (mul_r_weights) + { + half2 w_mul2 = __half2half2(weights[m].as_half); + result01 = __hmul2(result01, w_mul2); + result23 = __hmul2(result23, w_mul2); + } + atomicAdd(out , result01); atomicAdd(out + 1, result23); +// *out = result01; +// *(out + 1) = result23; } } -fp_gemm_half_q_half_kernel pick_gemm_half_q_half_kernel(bool first_block, const int m_count) +template +struct map_m_count_exl2 { + static constexpr fp_gemm_half_q_half_kernel pick_gemm_half_q_half_kernel(const int m_count) + { + #if EXL2_BLOCK_M_SIZE_MAX >= 1 + if (m_count == 1) return gemm_half_q_half_kernel<1, use_r_weights, mul_r_weights>; + #endif + #if EXL2_BLOCK_M_SIZE_MAX >= 2 + if (m_count == 2) return gemm_half_q_half_kernel<2, use_r_weights, mul_r_weights>; + #endif + #if EXL2_BLOCK_M_SIZE_MAX >= 3 + if (m_count == 3) return gemm_half_q_half_kernel<3, use_r_weights, mul_r_weights>; + #endif + #if EXL2_BLOCK_M_SIZE_MAX >= 4 + if (m_count == 4) return gemm_half_q_half_kernel<4, use_r_weights, mul_r_weights>; + #endif + #if EXL2_BLOCK_M_SIZE_MAX >= 5 + if (m_count == 5) return gemm_half_q_half_kernel<5, use_r_weights, mul_r_weights>; + #endif + #if EXL2_BLOCK_M_SIZE_MAX >= 6 + if (m_count == 6) return gemm_half_q_half_kernel<6, use_r_weights, mul_r_weights>; + #endif + #if EXL2_BLOCK_M_SIZE_MAX >= 7 + if (m_count == 7) return gemm_half_q_half_kernel<7, use_r_weights, mul_r_weights>; + #endif + #if EXL2_BLOCK_M_SIZE_MAX >= 8 + if (m_count == 8) return gemm_half_q_half_kernel<8, use_r_weights, mul_r_weights>; + #endif + return NULL; + } +}; + +fp_gemm_half_q_half_kernel pick_gemm_half_q_half_kernel(const int m_count, bool r_weights, bool mul_r_weights) { - #if BLOCK_M_SIZE_MAX >= 1 - if (m_count == 1) return gemm_half_q_half_kernel; - #endif - #if BLOCK_M_SIZE_MAX >= 2 - if (m_count == 2) return gemm_half_q_half_kernel; - #endif - #if BLOCK_M_SIZE_MAX >= 3 - if (m_count == 3) return gemm_half_q_half_kernel; - #endif - #if BLOCK_M_SIZE_MAX >= 4 - if (m_count == 4) return gemm_half_q_half_kernel; - #endif - #if BLOCK_M_SIZE_MAX >= 5 - if (m_count == 5) return gemm_half_q_half_kernel; - #endif - #if BLOCK_M_SIZE_MAX >= 6 - if (m_count == 6) return gemm_half_q_half_kernel; - #endif - #if BLOCK_M_SIZE_MAX >= 7 - if (m_count == 7) return gemm_half_q_half_kernel; - #endif - #if BLOCK_M_SIZE_MAX >= 8 - if (m_count == 8) return gemm_half_q_half_kernel; - #endif + if (!r_weights && !mul_r_weights) return map_m_count_exl2::pick_gemm_half_q_half_kernel(m_count); + if (!r_weights && mul_r_weights) return map_m_count_exl2::pick_gemm_half_q_half_kernel(m_count); + if ( r_weights && !mul_r_weights) return map_m_count_exl2< true, false>::pick_gemm_half_q_half_kernel(m_count); + if ( r_weights && mul_r_weights) return map_m_count_exl2< true, true>::pick_gemm_half_q_half_kernel(m_count); return NULL; } diff --git a/server/exllamav2_kernels/exllamav2_kernels/cuda/q_gemm_kernel_gptq.cuh b/server/exllamav2_kernels/exllamav2_kernels/cuda/q_gemm_kernel_gptq.cuh index ebaa42d0..74b0db2b 100644 --- a/server/exllamav2_kernels/exllamav2_kernels/cuda/q_gemm_kernel_gptq.cuh +++ b/server/exllamav2_kernels/exllamav2_kernels/cuda/q_gemm_kernel_gptq.cuh @@ -18,6 +18,15 @@ __forceinline__ __device__ float dot22_8_f(half2(&dq)[4], const half* a_ptr) return __half2float(__low2half(result)) + __half2float(__high2half(result)); } +__forceinline__ __device__ half2 dot22_8_h2(half2(&dq)[4], const half* a_ptr) +{ + half2 result = {}; + const half2* a2_ptr = (const half2*)a_ptr; + #pragma unroll + for (int i = 0; i < 4; i++) result = __hfma2(dq[i], *a2_ptr++, result); + return result; +} + typedef void (*fp_gemm_half_q_half_gptq_kernel) ( const half*, @@ -32,10 +41,12 @@ typedef void (*fp_gemm_half_q_half_gptq_kernel) const int, const uint16_t*, const int, - const bool + const bool, + const half*, + const int ); -template +template __global__ void gemm_half_q_half_gptq_kernel ( const half* __restrict__ a, @@ -50,7 +61,9 @@ __global__ void gemm_half_q_half_gptq_kernel const int groupsize, const uint16_t* __restrict__ b_q_perm, const int rows_4, - const bool clear + const bool clear, + const half* r_weights, + const int r_weights_stride ) { MatrixView_half a_(a, size_m, size_k); @@ -62,19 +75,35 @@ __global__ void gemm_half_q_half_gptq_kernel // Block - int offset_n = blockIdx.x * BLOCK_KN_SIZE * 4; + int offset_n = blockIdx.x * GPTQ_BLOCK_KN_SIZE * 4; int offset_m = blockIdx.y * m_count; - int offset_k = blockIdx.z * BLOCK_KN_SIZE; + int offset_k = blockIdx.z * GPTQ_BLOCK_KN_SIZE; - int end_n = min(offset_n + BLOCK_KN_SIZE * 4, size_n); + int end_n = min(offset_n + GPTQ_BLOCK_KN_SIZE * 4, size_n); int end_m = min(offset_m + m_count, size_m); - int end_k = min(offset_k + BLOCK_KN_SIZE, size_k); + int end_k = min(offset_k + GPTQ_BLOCK_KN_SIZE, size_k); int n = offset_n + t * 4; + // Read weights + + half_uint16 weights[MAX_Q_GEMM_WEIGHTS]; + if constexpr (use_r_weights) + { + uint16_t any_w = 0; + const half* w_ptr = r_weights; + for (int m = 0; m < m_count; ++m) + { + weights[m].as_half = *w_ptr; + w_ptr += r_weights_stride; + any_w |= weights[m].as_uint16; + } + if (!any_w) return; // Early exit if all weights are zero -- does not zero output (!!!) + } + // Preload block_a - __shared__ half block_a[m_count][BLOCK_KN_SIZE]; + __shared__ half block_a[m_count][GPTQ_BLOCK_KN_SIZE]; if (offset_k + t < end_k) { @@ -113,16 +142,16 @@ __global__ void gemm_half_q_half_gptq_kernel const uint32_t* b_ptr = b_q_weight + qk * size_n + n; const half* a_ptr = &block_a[0][0]; - int a_stride = BLOCK_KN_SIZE; + int a_stride = GPTQ_BLOCK_KN_SIZE; // Initial group int zeros[4]; - float scales[4]; + half2 scales[4]; half2 z1z16[4][2]; half2 y1y16[4][2]; b_gptq_qzeros_.item4(zeros, group, n); - b_gptq_scales_.item4_f(scales, group, n); + b_gptq_scales_.item4_h2(scales, group, n); dequant_4bit_8_prep_zero(zeros[0] + 1, z1z16[0], y1y16[0]); dequant_4bit_8_prep_zero(zeros[1] + 1, z1z16[1], y1y16[1]); dequant_4bit_8_prep_zero(zeros[2] + 1, z1z16[2], y1y16[2]); @@ -132,7 +161,7 @@ __global__ void gemm_half_q_half_gptq_kernel // Column result - float block_c[m_count][4] = {}; + half2 block_c[m_count][4] = {}; // Dequantize and multiply @@ -144,7 +173,7 @@ __global__ void gemm_half_q_half_gptq_kernel group++; nextgroup += groupsize; b_gptq_qzeros_.item4(zeros, group, n); - b_gptq_scales_.item4_f(scales, group, n); + b_gptq_scales_.item4_h2(scales, group, n); dequant_4bit_8_prep_zero(zeros[0] + 1, z1z16[0], y1y16[0]); dequant_4bit_8_prep_zero(zeros[1] + 1, z1z16[1], y1y16[1]); dequant_4bit_8_prep_zero(zeros[2] + 1, z1z16[2], y1y16[2]); @@ -166,10 +195,11 @@ __global__ void gemm_half_q_half_gptq_kernel #pragma unroll for (int m = 0; m < m_count; m++) { - block_c[m][0] = fma(dot22_8_f(dq[0], a_ptr + m * a_stride), scales[0], block_c[m][0]); - block_c[m][1] = fma(dot22_8_f(dq[1], a_ptr + m * a_stride), scales[1], block_c[m][1]); - block_c[m][2] = fma(dot22_8_f(dq[2], a_ptr + m * a_stride), scales[2], block_c[m][2]); - block_c[m][3] = fma(dot22_8_f(dq[3], a_ptr + m * a_stride), scales[3], block_c[m][3]); + if constexpr (use_r_weights) { if (!weights[m].as_uint16) continue; } + block_c[m][0] = __hfma2(dot22_8_h2(dq[0], a_ptr + m * a_stride), scales[0], block_c[m][0]); + block_c[m][1] = __hfma2(dot22_8_h2(dq[1], a_ptr + m * a_stride), scales[1], block_c[m][1]); + block_c[m][2] = __hfma2(dot22_8_h2(dq[2], a_ptr + m * a_stride), scales[2], block_c[m][2]); + block_c[m][3] = __hfma2(dot22_8_h2(dq[3], a_ptr + m * a_stride), scales[3], block_c[m][3]); } b_ptr += size_n; @@ -182,38 +212,62 @@ __global__ void gemm_half_q_half_gptq_kernel for (int m = 0; m < m_count; m++) { half2 *out = (half2*) c_.item_ptr(offset_m + m, n); - half2 result01 = __halves2half2(__float2half_rn(block_c[m][0]), __float2half_rn(block_c[m][1])); - half2 result23 = __halves2half2(__float2half_rn(block_c[m][2]), __float2half_rn(block_c[m][3])); + half result0 = __hadd(__low2half(block_c[m][0]), __high2half(block_c[m][0])); + half result1 = __hadd(__low2half(block_c[m][1]), __high2half(block_c[m][1])); + half result2 = __hadd(__low2half(block_c[m][2]), __high2half(block_c[m][2])); + half result3 = __hadd(__low2half(block_c[m][3]), __high2half(block_c[m][3])); + half2 result01 = __halves2half2(result0, result1); + half2 result23 = __halves2half2(result2, result3); + + if constexpr (mul_r_weights) + { + half2 w_mul2 = __half2half2(weights[m].as_half); + result01 = __hmul2(result01, w_mul2); + result23 = __hmul2(result23, w_mul2); + } + atomicAdd(out , result01); atomicAdd(out + 1, result23); } } -fp_gemm_half_q_half_gptq_kernel pick_gemm_half_q_half_gptq_kernel(bool first_block, const int m_count) +template +struct map_m_count_gptq { + static constexpr fp_gemm_half_q_half_gptq_kernel pick_gemm_half_q_half_gptq_kernel(int m_count) + { + #if GPTQ_BLOCK_M_SIZE_MAX >= 1 + if (m_count == 1) return gemm_half_q_half_gptq_kernel<1, use_r_weights, mul_r_weights>; + #endif + #if GPTQ_BLOCK_M_SIZE_MAX >= 2 + if (m_count == 2) return gemm_half_q_half_gptq_kernel<2, use_r_weights, mul_r_weights>; + #endif + #if GPTQ_BLOCK_M_SIZE_MAX >= 3 + if (m_count == 3) return gemm_half_q_half_gptq_kernel<3, use_r_weights, mul_r_weights>; + #endif + #if GPTQ_BLOCK_M_SIZE_MAX >= 4 + if (m_count == 4) return gemm_half_q_half_gptq_kernel<4, use_r_weights, mul_r_weights>; + #endif + #if GPTQ_BLOCK_M_SIZE_MAX >= 5 + if (m_count == 5) return gemm_half_q_half_gptq_kernel<5, use_r_weights, mul_r_weights>; + #endif + #if GPTQ_BLOCK_M_SIZE_MAX >= 6 + if (m_count == 6) return gemm_half_q_half_gptq_kernel<6, use_r_weights, mul_r_weights>; + #endif + #if GPTQ_BLOCK_M_SIZE_MAX >= 7 + if (m_count == 7) return gemm_half_q_half_gptq_kernel<7, use_r_weights, mul_r_weights>; + #endif + #if GPTQ_BLOCK_M_SIZE_MAX >= 8 + if (m_count == 8) return gemm_half_q_half_gptq_kernel<8, use_r_weights, mul_r_weights>; + #endif + return NULL; + } +}; + +fp_gemm_half_q_half_gptq_kernel pick_gemm_half_q_half_gptq_kernel(const int m_count, bool r_weights, bool mul_r_weights) { - #if BLOCK_M_SIZE_MAX >= 1 - if (m_count == 1) return gemm_half_q_half_gptq_kernel; - #endif - #if BLOCK_M_SIZE_MAX >= 2 - if (m_count == 2) return gemm_half_q_half_gptq_kernel; - #endif - #if BLOCK_M_SIZE_MAX >= 3 - if (m_count == 3) return gemm_half_q_half_gptq_kernel; - #endif - #if BLOCK_M_SIZE_MAX >= 4 - if (m_count == 4) return gemm_half_q_half_gptq_kernel; - #endif - #if BLOCK_M_SIZE_MAX >= 5 - if (m_count == 5) return gemm_half_q_half_gptq_kernel; - #endif - #if BLOCK_M_SIZE_MAX >= 6 - if (m_count == 6) return gemm_half_q_half_gptq_kernel; - #endif - #if BLOCK_M_SIZE_MAX >= 7 - if (m_count == 7) return gemm_half_q_half_gptq_kernel; - #endif - #if BLOCK_M_SIZE_MAX >= 8 - if (m_count == 8) return gemm_half_q_half_gptq_kernel; - #endif + if (!r_weights && !mul_r_weights) return map_m_count_gptq::pick_gemm_half_q_half_gptq_kernel(m_count); + if (!r_weights && mul_r_weights) return map_m_count_gptq::pick_gemm_half_q_half_gptq_kernel(m_count); + if ( r_weights && !mul_r_weights) return map_m_count_gptq< true, false>::pick_gemm_half_q_half_gptq_kernel(m_count); + if ( r_weights && mul_r_weights) return map_m_count_gptq< true, true>::pick_gemm_half_q_half_gptq_kernel(m_count); return NULL; } diff --git a/server/exllamav2_kernels/exllamav2_kernels/cuda/q_matrix.cu b/server/exllamav2_kernels/exllamav2_kernels/cuda/q_matrix.cu index 6aed7470..ae08cc1f 100644 --- a/server/exllamav2_kernels/exllamav2_kernels/cuda/q_matrix.cu +++ b/server/exllamav2_kernels/exllamav2_kernels/cuda/q_matrix.cu @@ -57,6 +57,7 @@ QMatrix::QMatrix uint32_t* _q_scale, half* _q_scale_max, uint16_t* _q_groups, + uint16_t* _q_group_map, uint32_t* _gptq_qzeros, half* _gptq_scales, @@ -80,13 +81,17 @@ QMatrix::QMatrix cuda_q_scale = _q_scale; cuda_q_scale_max = _q_scale_max; cuda_q_groups = _q_groups; + cuda_q_group_map = _q_group_map; cuda_gptq_qzeros = _gptq_qzeros; cuda_gptq_scales = _gptq_scales; is_gptq = (_gptq_qzeros != NULL); - groupsize = 1; - while (groupsize * groups < height) groupsize *= 2; + if (is_gptq) + { + gptq_groupsize = 1; + while (gptq_groupsize * groups < height) gptq_groupsize *= 2; + } // Create group map @@ -102,15 +107,26 @@ QMatrix::QMatrix uint16_t* cpu_q_groups = (uint16_t*)calloc(groups * 2, sizeof(uint16_t)); cudaMemcpy(cpu_q_groups, cuda_q_groups, groups * 2 * sizeof(uint16_t), cudaMemcpyDeviceToHost); + int row = 0; for (int i = 0; i < groups; i++) { int bits = cpu_q_groups[i * 2]; - if (bits == 8) rows_8 += groupsize; - if (bits == 6) rows_6 += groupsize; - if (bits == 5) rows_5 += groupsize; - if (bits == 4) rows_4 += groupsize; - if (bits == 3) rows_3 += groupsize; - if (bits == 2) rows_2 += groupsize; + + int rows; + if (i < groups - 1) + { + int qrows = cpu_q_groups[i * 2 + 3] - cpu_q_groups[i * 2 + 1]; + rows = qrows * 32 / bits; + } + else rows = height - row; + + if (bits == 8) rows_8 += rows; + if (bits == 6) rows_6 += rows; + if (bits == 5) rows_5 += rows; + if (bits == 4) rows_4 += rows; + if (bits == 3) rows_3 += rows; + if (bits == 2) rows_2 += rows; + row += rows; } free(cpu_q_groups); @@ -138,6 +154,13 @@ QMatrix::QMatrix } } +// DBGI(rows_8); +// DBGI(rows_6); +// DBGI(rows_5); +// DBGI(rows_4); +// DBGI(rows_3); +// DBGI(rows_2); + // Shuffle quantized data dim3 blockDim, gridDim; @@ -283,10 +306,10 @@ __global__ void reconstruct_kernel const uint16_t* __restrict__ b_q_perm, const uint32_t* __restrict__ b_q_scale, const half* __restrict__ b_q_scale_max, - //const uint16_t* __restrict__ b_q_groups, + const uint16_t* __restrict__ b_q_group_map, const int size_k, const int size_n, - const int groupsize, + //const int groupsize, const int groups, half* __restrict__ b, const int rows_8, @@ -317,7 +340,8 @@ __global__ void reconstruct_kernel // Find initial group - int group = offset_k / groupsize; + // int group = offset_k / groupsize; + int group = b_q_group_map[offset_k * 2]; int pre_rows_8 = min(rows_8, offset_k); int pre_rows_6 = offset_k > rows_8 ? min(rows_6, offset_k) - rows_8 : 0; @@ -337,7 +361,7 @@ __global__ void reconstruct_kernel half qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]); half2 qs_h2 = __halves2half2(qs_h, qs_h); - int nextgroup = offset_k + groupsize; + int nextgroup = offset_k + b_q_group_map[offset_k * 2 + 1]; int end_k = min(offset_k + BLOCK_KN_SIZE, size_k); int k = offset_k; @@ -347,7 +371,7 @@ __global__ void reconstruct_kernel while (k < rows_8 && k < end_k) { - if (k == nextgroup) { group++; qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]); nextgroup += groupsize; qs_h2 = __halves2half2(qs_h, qs_h); } + if (k == nextgroup) { group++; qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]); nextgroup += b_q_group_map[k * 2 + 1]; qs_h2 = __halves2half2(qs_h, qs_h); } for (int p = 0; p < 4; p++) { half2 dq[4]; @@ -363,7 +387,7 @@ __global__ void reconstruct_kernel while (k < rows_6 && k < end_k) { - if (k == nextgroup) { group++; qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]); nextgroup += groupsize; qs_h2 = __halves2half2(qs_h, qs_h); } + if (k == nextgroup) { group++; qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]); nextgroup += b_q_group_map[k * 2 + 1]; qs_h2 = __halves2half2(qs_h, qs_h); } for (int p = 0; p < 2; p++) { half2 dq[8]; @@ -380,7 +404,7 @@ __global__ void reconstruct_kernel while (k < rows_5 && k < end_k) { - if (k == nextgroup) { group++; qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]); nextgroup += groupsize; qs_h2 = __halves2half2(qs_h, qs_h); } + if (k == nextgroup) { group++; qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]); nextgroup += b_q_group_map[k * 2 + 1]; qs_h2 = __halves2half2(qs_h, qs_h); } for (int p = 0; p < 1; p++) { half2 dq[16]; @@ -399,7 +423,7 @@ __global__ void reconstruct_kernel while (k < rows_4 && k < end_k) { - if (k == nextgroup) { group++; qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]); nextgroup += groupsize; qs_h2 = __halves2half2(qs_h, qs_h); } + if (k == nextgroup) { group++; qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]); nextgroup += b_q_group_map[k * 2 + 1]; qs_h2 = __halves2half2(qs_h, qs_h); } for (int p = 0; p < 4; p++) { half2 dq[4]; @@ -414,7 +438,7 @@ __global__ void reconstruct_kernel while (k < rows_3 && k < end_k) { - if (k == nextgroup) { group++; qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]); nextgroup += groupsize; qs_h2 = __halves2half2(qs_h, qs_h); } + if (k == nextgroup) { group++; qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]); nextgroup += b_q_group_map[k * 2 + 1]; qs_h2 = __halves2half2(qs_h, qs_h); } for (int p = 0; p < 1; p++) { half2 dq[16]; @@ -431,8 +455,8 @@ __global__ void reconstruct_kernel while (k < rows_2 && k < end_k) { - if (k == nextgroup) { group++; qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]); nextgroup += groupsize; qs_h2 = __halves2half2(qs_h, qs_h); } - for (int p = 0; p < 2; p++) + if (k == nextgroup) { group++; qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]); nextgroup += b_q_group_map[k * 2 + 1]; qs_h2 = __halves2half2(qs_h, qs_h); } + for (int p = 0; p < 1; p++) { half2 dq[8]; uint32_t q_0 = *b_ptr; b_ptr += size_n; @@ -441,7 +465,7 @@ __global__ void reconstruct_kernel half* dqh = (half*) dq; for (int j = 0; j < 16; j++) b_.set(perm[lk++], n, dqh[j]); } - k += 32; + k += 16; } } @@ -461,10 +485,10 @@ void QMatrix::reconstruct(half* out) cuda_q_perm, cuda_q_scale, cuda_q_scale_max, - //cuda_q_groups, + cuda_q_group_map, height, width, - groupsize, + //groupsize, groups, out, rows_8, @@ -487,7 +511,7 @@ void QMatrix::reconstruct(half* out) //const uint16_t* __restrict__ b_q_groups, height, width, - groupsize, + gptq_groupsize, groups, out, rows_4 diff --git a/server/exllamav2_kernels/exllamav2_kernels/cuda/q_matrix.cuh b/server/exllamav2_kernels/exllamav2_kernels/cuda/q_matrix.cuh index dda83a4f..d36b8d66 100644 --- a/server/exllamav2_kernels/exllamav2_kernels/cuda/q_matrix.cuh +++ b/server/exllamav2_kernels/exllamav2_kernels/cuda/q_matrix.cuh @@ -18,7 +18,7 @@ public: int height; int width; int groups; - int groupsize; + int gptq_groupsize; int rows_8; int rows_6; @@ -33,6 +33,7 @@ public: uint32_t* cuda_q_scale = NULL; half* cuda_q_scale_max = NULL; uint16_t* cuda_q_groups = NULL; + uint16_t* cuda_q_group_map = NULL; uint32_t* cuda_gptq_qzeros = NULL; half* cuda_gptq_scales = NULL; @@ -53,6 +54,7 @@ public: uint32_t* _q_scale, half* _q_scale_max, uint16_t* _q_groups, + uint16_t* _q_group_map, uint32_t* _gptq_qzeros, half* _gptq_scales, diff --git a/server/exllamav2_kernels/exllamav2_kernels/cuda/quant/qdq_util.cuh b/server/exllamav2_kernels/exllamav2_kernels/cuda/quant/qdq_util.cuh index 71657191..cac9df9c 100644 --- a/server/exllamav2_kernels/exllamav2_kernels/cuda/quant/qdq_util.cuh +++ b/server/exllamav2_kernels/exllamav2_kernels/cuda/quant/qdq_util.cuh @@ -7,6 +7,7 @@ union half2_uint32 half2 as_half2; __device__ half2_uint32(uint32_t val) : as_uint32(val) {} __device__ half2_uint32(half2 val) : as_half2(val) {} + __device__ half2_uint32() : as_uint32(0) {} }; union half_uint16 @@ -15,6 +16,7 @@ union half_uint16 half as_half; __device__ half_uint16(uint16_t val) : as_uint16(val) {} __device__ half_uint16(half val) : as_half(val) {} + __device__ half_uint16() : as_uint16(0) {} }; // Max_scale premultiplied by 1/256 diff --git a/server/exllamav2_kernels/exllamav2_kernels/cuda/util.cuh b/server/exllamav2_kernels/exllamav2_kernels/cuda/util.cuh index 06a58d18..f56eda79 100644 --- a/server/exllamav2_kernels/exllamav2_kernels/cuda/util.cuh +++ b/server/exllamav2_kernels/exllamav2_kernels/cuda/util.cuh @@ -1,3 +1,11 @@ +#ifndef _util_cuh +#define _util_cuh + +#include +#include +#include +#include +#include #define DIVIDE(x, size) (((x) + (size) - 1) / (size)) @@ -40,3 +48,7 @@ inline void gpu_assert(cudaError_t code, const char *file, int line, bool abort= if (abort) exit(code); } } + +void print_global_mem(const half* ptr, int rows, int columns, int stride); + +#endif \ No newline at end of file diff --git a/server/exllamav2_kernels/exllamav2_kernels/ext.cpp b/server/exllamav2_kernels/exllamav2_kernels/ext.cpp index 5e52e6ab..ff4e1851 100644 --- a/server/exllamav2_kernels/exllamav2_kernels/ext.cpp +++ b/server/exllamav2_kernels/exllamav2_kernels/ext.cpp @@ -31,6 +31,7 @@ uintptr_t make_q_matrix torch::Tensor q_scale, torch::Tensor q_scale_max, torch::Tensor q_groups, + torch::Tensor q_group_map, torch::Tensor gptq_qzeros, torch::Tensor gptq_scales, torch::Tensor gptq_g_idx, @@ -43,6 +44,7 @@ uintptr_t make_q_matrix TORCH_CHECK_DTYPE_OPT(q_scale, kInt); TORCH_CHECK_DTYPE_OPT(q_scale_max, kHalf); TORCH_CHECK_DTYPE_OPT(q_groups, kShort); + TORCH_CHECK_DTYPE_OPT(q_group_map, kShort); TORCH_CHECK_DTYPE_OPT(gptq_qzeros, kInt); TORCH_CHECK_DTYPE_OPT(gptq_scales, kHalf); TORCH_CHECK_DTYPE_OPT(gptq_g_idx, kInt); @@ -83,12 +85,15 @@ uintptr_t make_q_matrix q_scale.device().is_meta() ? NULL : (uint32_t*) q_scale.data_ptr(), q_scale_max.device().is_meta() ? NULL : (half*) q_scale_max.data_ptr(), q_groups.device().is_meta() ? NULL : (uint16_t*) q_groups.data_ptr(), + q_group_map.device().is_meta() ? NULL : (uint16_t*) q_group_map.data_ptr(), gptq_qzeros.device().is_meta() ? NULL : (uint32_t*) gptq_qzeros.data_ptr(), gptq_scales.device().is_meta() ? NULL : (half*) gptq_scales.data_ptr(), gptq_g_idx.device().is_meta() ? NULL : (uint32_t*) gptq_g_idx.data_ptr(), (half*) temp_dq.data_ptr() ); + if (m->failed) throw std::runtime_error("CUDA out of memory"); + return reinterpret_cast (m); } diff --git a/server/tests/utils/test_hub.py b/server/tests/utils/test_hub.py index 5438c153..49549893 100644 --- a/server/tests/utils/test_hub.py +++ b/server/tests/utils/test_hub.py @@ -32,10 +32,10 @@ def fresh_cache(): current_value = huggingface_hub.constants.HUGGINGFACE_HUB_CACHE huggingface_hub.constants.HUGGINGFACE_HUB_CACHE = d text_generation_server.utils.hub.HUGGINGFACE_HUB_CACHE = d - os.environ['HUGGINGFACE_HUB_CACHE'] = d + os.environ["HUGGINGFACE_HUB_CACHE"] = d yield huggingface_hub.constants.HUGGINGFACE_HUB_CACHE = current_value - os.environ['HUGGINGFACE_HUB_CACHE'] = current_value + os.environ["HUGGINGFACE_HUB_CACHE"] = current_value text_generation_server.utils.hub.HUGGINGFACE_HUB_CACHE = current_value @@ -47,7 +47,7 @@ def prefetched(): revision="main", local_files_only=False, repo_type="model", - allow_patterns=["*.safetensors"] + allow_patterns=["*.safetensors"], ) yield model_id @@ -61,7 +61,7 @@ def test_weight_hub_files_offline_error(offline, fresh_cache): def test_weight_hub_files_offline_ok(prefetched, offline): # If the model is prefetched then we should be able to get the weight files from local cache filenames = weight_hub_files(prefetched) - assert filenames == ['model.safetensors'] + assert filenames == ["model.safetensors"] def test_weight_hub_files(): diff --git a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py index cd93d32a..22d03adf 100644 --- a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py @@ -71,7 +71,7 @@ def _load_multi_mqa_gptq( g_idx = weights.get_tensor(f"{prefix}.c_attn.g_idx") g_idx = g_idx.to(device=weights.device) - bits, groupsize = weights._get_gptq_params() + bits, groupsize, _ = weights._get_gptq_params() from text_generation_server.utils.layers import HAS_EXLLAMA diff --git a/server/text_generation_server/utils/gptq/exllamav2.py b/server/text_generation_server/utils/gptq/exllamav2.py index dd41b269..a24e834b 100644 --- a/server/text_generation_server/utils/gptq/exllamav2.py +++ b/server/text_generation_server/utils/gptq/exllamav2.py @@ -27,6 +27,32 @@ def ext_gemm_half_q_half(x, q_handle, q4_width, force_cuda): return output.view(output_shape) +# Group map needed for irregular group sizes + + +def make_group_map(q_groups, num_qrows): + + gr = q_groups.tolist() + group_map = [] + num_groups = len(gr) // 2 + + for i in range(num_groups): + bits = gr[i * 2] + if i < num_groups - 1: + qrows = gr[i * 2 + 3] - gr[i * 2 + 1] + else: + qrows = num_qrows - gr[i * 2 + 1] + rows = qrows * 32 // bits + for j in range(rows): + group_map += [i] + group_map += [rows - j] + + return torch.tensor(group_map, dtype=torch.short, device=q_groups.device) + + +# Create Q matrix + + def ext_make_q_matrix(w: dict, temp_dq, key: str = None): """ Create Q matrix @@ -37,6 +63,10 @@ def ext_make_q_matrix(w: dict, temp_dq, key: str = None): w["q_scale_max"] /= 256 w["q_perm"] = w["q_perm"].short() w["q_invperm"] = w["q_invperm"].short() + + if "q_group_map" not in w: + w["q_group_map"] = make_group_map(w["q_groups"], w["q_weight"].shape[0]) + return make_q_matrix( w["q_weight"], w["q_perm"], @@ -44,6 +74,7 @@ def ext_make_q_matrix(w: dict, temp_dq, key: str = None): w["q_scale"], w["q_scale_max"], w["q_groups"], + w["q_group_map"], none_tensor, none_tensor, none_tensor, @@ -70,6 +101,7 @@ def ext_make_q_matrix(w: dict, temp_dq, key: str = None): none_tensor, none_tensor, none_tensor, + none_tensor, w["qzeros"], w["scales"], w["g_idx"].cpu(), @@ -84,6 +116,7 @@ def ext_make_q_matrix(w: dict, temp_dq, key: str = None): none_tensor, none_tensor, none_tensor, + none_tensor, w["qzeros"], w["scales"], none_tensor, diff --git a/server/text_generation_server/utils/hub.py b/server/text_generation_server/utils/hub.py index 62afff0c..deb1a941 100644 --- a/server/text_generation_server/utils/hub.py +++ b/server/text_generation_server/utils/hub.py @@ -18,7 +18,9 @@ WEIGHTS_CACHE_OVERRIDE = os.getenv("WEIGHTS_CACHE_OVERRIDE", None) HF_HUB_OFFLINE = os.environ.get("HF_HUB_OFFLINE", "0").lower() in ["true", "1", "yes"] -def _cached_weight_files(model_id: str, revision: Optional[str], extension: str) -> List[str]: +def _cached_weight_files( + model_id: str, revision: Optional[str], extension: str +) -> List[str]: """Guess weight files from the cached revision snapshot directory""" d = _get_cached_revision_directory(model_id, revision) if not d: @@ -27,7 +29,9 @@ def _cached_weight_files(model_id: str, revision: Optional[str], extension: str) return filenames -def _weight_hub_files_from_model_info(info: hf_api.ModelInfo, extension: str) -> List[str]: +def _weight_hub_files_from_model_info( + info: hf_api.ModelInfo, extension: str +) -> List[str]: return [ s.rfilename for s in info.siblings @@ -44,21 +48,27 @@ def _weight_files_from_dir(d: Path, extension: str) -> List[str]: # see _weight_hub_files_from_model_info, that's also what is # done there with the len(s.rfilename.split("/")) == 1 condition root, _, files = next(os.walk(str(d))) - filenames = [f for f in files - if f.endswith(extension) - and "arguments" not in f - and "args" not in f - and "adapter" not in f - and "training" not in f] + filenames = [ + f + for f in files + if f.endswith(extension) + and "arguments" not in f + and "args" not in f + and "adapter" not in f + and "training" not in f + ] return filenames -def _get_cached_revision_directory(model_id: str, revision: Optional[str]) -> Optional[Path]: +def _get_cached_revision_directory( + model_id: str, revision: Optional[str] +) -> Optional[Path]: if revision is None: revision = "main" repo_cache = Path(HUGGINGFACE_HUB_CACHE) / Path( - file_download.repo_folder_name(repo_id=model_id, repo_type="model")) + file_download.repo_folder_name(repo_id=model_id, repo_type="model") + ) if not repo_cache.is_dir(): # No cache for this model @@ -86,7 +96,7 @@ def _get_cached_revision_directory(model_id: str, revision: Optional[str]) -> Op def weight_hub_files( - model_id: str, revision: Optional[str] = None, extension: str = ".safetensors" + model_id: str, revision: Optional[str] = None, extension: str = ".safetensors" ) -> List[str]: """Get the weights filenames on the hub""" api = HfApi() diff --git a/server/text_generation_server/utils/layers.py b/server/text_generation_server/utils/layers.py index 011a9382..6648b55a 100644 --- a/server/text_generation_server/utils/layers.py +++ b/server/text_generation_server/utils/layers.py @@ -19,6 +19,7 @@ from accelerate import init_empty_weights from text_generation_server.utils.gptq.quant_linear import QuantLinear from text_generation_server.utils.import_utils import IS_CUDA_SYSTEM, IS_ROCM_SYSTEM +from text_generation_server.utils.log import log_once HAS_AWQ = True try: @@ -35,10 +36,11 @@ HAS_EXLLAMA = False CAN_EXLLAMA = major >= 8 V2 = os.getenv("EXLLAMA_VERSION", "2") == "2" if V2 and int(os.getenv("WORLD_SIZE", "1")) > 1: - logger.warning( + V2 = False + log_once( + logger.warning, "Disabling exllama v2 and using v1 instead because there are issues when sharding" ) - V2 = False if os.getenv("DISABLE_EXLLAMA") == "True": HAS_EXLLAMA = False diff --git a/server/text_generation_server/utils/log.py b/server/text_generation_server/utils/log.py new file mode 100644 index 00000000..d831fa76 --- /dev/null +++ b/server/text_generation_server/utils/log.py @@ -0,0 +1,6 @@ +from functools import lru_cache + + +@lru_cache(10) +def log_once(log, msg:str): + log(msg) diff --git a/server/text_generation_server/utils/weights.py b/server/text_generation_server/utils/weights.py index a2cca2ea..ee1899ab 100644 --- a/server/text_generation_server/utils/weights.py +++ b/server/text_generation_server/utils/weights.py @@ -6,6 +6,7 @@ import torch from loguru import logger from huggingface_hub import hf_hub_download import json +from text_generation_server.utils.log import log_once class Weights: @@ -161,7 +162,7 @@ class Weights: else: g_idx = None - bits, groupsize = self._get_gptq_params() + bits, groupsize, _ = self._get_gptq_params() weight = (qweight, qzeros, scales, g_idx, bits, groupsize, False) else: slice_ = self._get_slice(f"{prefix}.weight") @@ -211,10 +212,10 @@ class Weights: else: g_idx = None - bits, groupsize = self._get_gptq_params() + bits, groupsize, desc_act = self._get_gptq_params() from text_generation_server.utils.layers import HAS_EXLLAMA - use_exllama = bits == 4 and HAS_EXLLAMA and quantize == "gptq" + use_exllama = bits == 4 and HAS_EXLLAMA and quantize == "gptq" and not desc_act weight = (qweight, qzeros, scales, g_idx, bits, groupsize, use_exllama) else: w = [self.get_sharded(f"{p}.weight", dim=0) for p in prefixes] @@ -240,11 +241,15 @@ class Weights: def get_multi_weights_row(self, prefix: str, quantize: str): if quantize == "gptq": use_exllama = True - bits, groupsize = self._get_gptq_params() + bits, groupsize, desc_act = self._get_gptq_params() if bits != 4: use_exllama = False + if desc_act: + log_once(logger.warning, "Disabling exllama because desc_act=True") + use_exllama = False + if self.process_group.size() > 1: g_idx = self.get_tensor(f"{prefix}.g_idx") if g_idx is not None: @@ -274,12 +279,18 @@ class Weights: if use_exllama: if not HAS_EXLLAMA: if CAN_EXLLAMA: - logger.warning( + log_once( + logger.warning, "Exllama GPTQ cuda kernels (which are faster) could have been used, but are not currently installed, try using BUILD_EXTENSIONS=True" ) use_exllama = False else: - logger.info(f"Using exllama kernels v{HAS_EXLLAMA}") + log_once( + logger.info, + f"Using exllama kernels v{HAS_EXLLAMA}" + ) + + g_idx = self.get_sharded(f"{prefix}.g_idx", dim=0) if use_exllama and groupsize != -1: qzeros = self.get_sharded(f"{prefix}.qzeros", dim=0) @@ -288,14 +299,12 @@ class Weights: qzeros = self.get_tensor(f"{prefix}.qzeros") scales = self.get_tensor(f"{prefix}.scales") - g_idx = self.get_sharded(f"{prefix}.g_idx", dim=0) - if use_exllama: g_idx = g_idx - g_idx[0] weight = (qweight, qzeros, scales, g_idx, bits, groupsize, use_exllama) elif quantize == "awq": - bits, groupsize = self._get_gptq_params() + bits, groupsize, _ = self._get_gptq_params() try: qweight = self.get_sharded(f"{prefix}.qweight", dim=0) @@ -314,18 +323,20 @@ class Weights: weight = self.get_sharded(f"{prefix}.weight", dim=1) return weight - def _get_gptq_params(self) -> Tuple[int, int]: + def _get_gptq_params(self) -> Tuple[int, int, int]: try: bits = self.get_tensor("gptq_bits").item() groupsize = self.get_tensor("gptq_groupsize").item() + desc_act = False except (SafetensorError, RuntimeError) as e: try: bits = self.gptq_bits groupsize = self.gptq_groupsize + desc_act = getattr(self, "gptq_desc_act", False) except Exception: raise e - return bits, groupsize + return bits, groupsize, desc_act def _set_gptq_params(self, model_id, revision): filename = "config.json" @@ -340,6 +351,7 @@ class Weights: data = json.load(f) self.gptq_bits = data["quantization_config"]["bits"] self.gptq_groupsize = data["quantization_config"]["group_size"] + self.gptq_desc_act = data["quantization_config"]["desc_act"] except Exception: filename = "quantize_config.json" try: @@ -353,6 +365,7 @@ class Weights: data = json.load(f) self.gptq_bits = data["bits"] self.gptq_groupsize = data["group_size"] + self.gptq_desc_act = data["desc_act"] except Exception: filename = "quant_config.json" try: @@ -366,5 +379,6 @@ class Weights: data = json.load(f) self.gptq_bits = data["w_bit"] self.gptq_groupsize = data["q_group_size"] + self.gptq_desc_act = data["desc_act"] except Exception: pass