diff --git a/Dockerfile b/Dockerfile index cf5e0ed6..d45aaec5 100644 --- a/Dockerfile +++ b/Dockerfile @@ -112,6 +112,14 @@ RUN make build-flash-attention-v2 FROM kernel-builder as exllama-kernels-builder WORKDIR /usr/src COPY server/exllama_kernels/ . + +RUN TORCH_CUDA_ARCH_LIST="8.0;8.6+PTX" python setup.py build + +# Build Transformers exllama kernels +FROM kernel-builder as exllamav2-kernels-builder +WORKDIR /usr/src +COPY server/exllamav2_kernels/ . + # Build specific version of transformers RUN TORCH_CUDA_ARCH_LIST="8.0;8.6+PTX" python setup.py build @@ -182,6 +190,8 @@ COPY --from=flash-att-v2-builder /usr/src/flash-attention-v2/build/lib.linux-x86 COPY --from=custom-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages # Copy build artifacts from exllama kernels builder COPY --from=exllama-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages +# Copy build artifacts from exllamav2 kernels builder +COPY --from=exllamav2-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages # Copy build artifacts from awq kernels builder COPY --from=awq-kernels-builder /usr/src/llm-awq/awq/kernels/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages # Copy build artifacts from eetq kernels builder diff --git a/router/src/lib.rs b/router/src/lib.rs index d6df2f56..b547dc15 100644 --- a/router/src/lib.rs +++ b/router/src/lib.rs @@ -315,7 +315,6 @@ mod tests { if !filename.exists() { std::fs::rename(tmp_filename, filename).unwrap() } - } Tokenizer::from_file("tokenizer.json").unwrap() } diff --git a/server/exllamav2_kernels/exllamav2_kernels/config.h b/server/exllamav2_kernels/exllamav2_kernels/config.h new file mode 100644 index 00000000..86baaf41 --- /dev/null +++ b/server/exllamav2_kernels/exllamav2_kernels/config.h @@ -0,0 +1,13 @@ +#ifndef _config_h +#define _config_h + +#define MAX_Q_GEMM_ROWS 50 + +#define QMODE_2BIT 1 +#define QMODE_3BIT 1 +#define QMODE_4BIT 1 +#define QMODE_5BIT 1 +#define QMODE_6BIT 0 +#define QMODE_8BIT 0 + +#endif diff --git a/server/exllamav2_kernels/exllamav2_kernels/cpp/util.h b/server/exllamav2_kernels/exllamav2_kernels/cpp/util.h new file mode 100644 index 00000000..919703a8 --- /dev/null +++ b/server/exllamav2_kernels/exllamav2_kernels/cpp/util.h @@ -0,0 +1,12 @@ +#ifndef _util_h +#define _util_h + +#define DBGS(__x) printf("%s\n", __x) +#define DBGI(__x) printf("%s: %i\n", #__x, __x) +#define DBGI2(__x, __y) printf("%s, %s: %i, %i\n", #__x, #__y, __x, __y) +#define DBGI3(__x, __y, __z) printf("%s, %s, %s: %i, %i, %i\n", #__x, #__y, #__z, __x, __y, __z) +#define DBGF(__x) printf("%s: %f\n", #__x, __x) +#define DBGF2(__x, __y) printf("%s, %s: %f, %f\n", #__x, #__y, __x, __y) +#define DBGF3(__x, __y, __z) printf("%s, %s, %s: %f, %f, %f\n", #__x, #__y, #__z, __x, __y, __z) + +#endif diff --git a/server/exllamav2_kernels/exllamav2_kernels/cuda/compat.cuh b/server/exllamav2_kernels/exllamav2_kernels/cuda/compat.cuh new file mode 100644 index 00000000..12684ff8 --- /dev/null +++ b/server/exllamav2_kernels/exllamav2_kernels/cuda/compat.cuh @@ -0,0 +1,56 @@ +#ifndef _compat_cuh +#define _compat_cuh + +// atomicAdd for half types, to support CC < 7.x + +__device__ __forceinline__ void atomicAdd_half(half* address, half val) +{ + unsigned int * address_as_ui = (unsigned int *) ((char *)address - ((size_t)address & 2)); + unsigned int old = *address_as_ui; + unsigned int assumed; + + do + { + assumed = old; + __half_raw hsum; + hsum.x = (size_t)address & 2 ? (old >> 16) : (old & 0xffff); + half tmpres = __hadd(hsum, val); + hsum = __half_raw(tmpres); + old = (size_t)address & 2 ? (old & 0xffff) | (hsum.x << 16) : (old & 0xffff0000) | hsum.x; + old = atomicCAS(address_as_ui, assumed, old); + } + while (assumed != old); +} + +// atomicAdd for half2 types + +__device__ __forceinline__ void atomicAdd_half2(half2* address, half2 val) +{ + unsigned int* address_as_ui = (unsigned int*)address; + unsigned int old = *address_as_ui; + unsigned int assumed; + do + { + assumed = old; + half2 old_val = *((half2*)&old); + half2 new_val = __hadd2(old_val, val); + old = atomicCAS(address_as_ui, assumed, *((unsigned int*)&new_val)); + } + while (assumed != old); +} + +// + +#if defined(__CUDA_ARCH__) || defined(USE_ROCM) +#if __CUDA_ARCH__ < 700 || defined(USE_ROCM) + +__device__ __forceinline__ void atomicAdd(half* address, half val) { atomicAdd_half(address, val); } + +#if __CUDA_ARCH__ < 600 || defined(USE_ROCM) +__device__ __forceinline__ void atomicAdd(half2* address, half2 val) { atomicAdd_half2(address, val); } +#endif + +#endif +#endif + +#endif diff --git a/server/exllamav2_kernels/exllamav2_kernels/cuda/compat_gemm.cuh b/server/exllamav2_kernels/exllamav2_kernels/cuda/compat_gemm.cuh new file mode 100644 index 00000000..19b1e4a6 --- /dev/null +++ b/server/exllamav2_kernels/exllamav2_kernels/cuda/compat_gemm.cuh @@ -0,0 +1,38 @@ +#ifndef _compat_gemm_cuh +#define _compat_gemm_cuh + +#if defined(USE_ROCM) + +// For some reason this include is not present anywhere in exllama_v2 codebase, but it is required +// for symbols as hipblasHalf. +#include + +__host__ __forceinline__ hipblasStatus_t __compat_hipblasHgemm(hipblasHandle_t handle, + hipblasOperation_t transA, + hipblasOperation_t transB, + int m, + int n, + int k, + const half* alpha, + const half* AP, + int lda, + const half* BP, + int ldb, + const half* beta, + half* CP, + int ldc) { + return hipblasHgemm(handle, transA, transB, m, n, k, + reinterpret_cast(alpha), + reinterpret_cast(AP), lda, + reinterpret_cast(BP), ldb, + reinterpret_cast(beta), + reinterpret_cast(CP), ldc); +} +#define hipblasHgemm __compat_hipblasHgemm + +// Previous version of PyTorch were converting to rocBLAS instead of hipBLAS. +#define rocblas_operation_none HIPBLAS_OP_N +#define rocblas_hgemm __compat_hipblasHgemm +#endif + +#endif diff --git a/server/exllamav2_kernels/exllamav2_kernels/cuda/matrix_view.cuh b/server/exllamav2_kernels/exllamav2_kernels/cuda/matrix_view.cuh new file mode 100644 index 00000000..55af84f2 --- /dev/null +++ b/server/exllamav2_kernels/exllamav2_kernels/cuda/matrix_view.cuh @@ -0,0 +1,121 @@ +#ifndef _matrix_view_cuh +#define _matrix_view_cuh + +#include +#include + +#include "quant/qdq_util.cuh" + +class MatrixView_half +{ +public: + const half* data; + const int height; + const int width; + + __device__ __forceinline__ MatrixView_half(const half* data, const int height, const int width) + : data(data), height(height), width(width) + { } + + __device__ __forceinline__ half item(int row, int column) const { return data[row * width + column]; } + __device__ __forceinline__ half2 item_half2(int row, int column) const { return ((half2*)data)[(row * width + column) / 2]; } + __device__ __forceinline__ half2 item_half2half2(int row, int column) const { return __half2half2(data[row * width + column]); } + __device__ __forceinline__ const half* item_ptr(int row, int column) const { return &data[row * width + column]; } + + __device__ __forceinline__ void item4(half (&items)[4], int row, int column) const + { + half2* ptr = (half2*) item_ptr(row, column); + half2 i01 = ptr[0]; + half2 i23 = ptr[1]; + items[0] = __low2half(i01); + items[1] = __high2half(i01); + items[2] = __low2half(i23); + items[3] = __high2half(i23); + } + __device__ __forceinline__ void item4_f(float (&items)[4], int row, int column) const + { + half2* ptr = (half2*)item_ptr(row, column); + half2 i01 = ptr[0]; + half2 i23 = ptr[1]; + items[0] = __half2float(__low2half(i01)); + items[1] = __half2float(__high2half(i01)); + items[2] = __half2float(__low2half(i23)); + items[3] = __half2float(__high2half(i23)); + } + + __device__ __forceinline__ void item4_h2(half2 (&items)[4], int row, int column) const + { + half2* ptr = (half2*)item_ptr(row, column); + half2 i01 = ptr[0]; + half2 i23 = ptr[1]; + items[0] = __half2half2(__low2half(i01)); + items[1] = __half2half2(__high2half(i01)); + items[2] = __half2half2(__low2half(i23)); + items[3] = __half2half2(__high2half(i23)); + } +}; + +class MatrixView_half_rw +{ +public: + half* data; + const int height; + const int width; + + __device__ __forceinline__ MatrixView_half_rw(half* data, const int height, const int width) + : data(data), height(height), width(width) + { } + + __device__ __forceinline__ half item(int row, int column) const { return data[row * width + column]; } + __device__ __forceinline__ half2 item_half2(int row, int column) const { return ((half2*)data)[(row * width + column) / 2]; } + __device__ __forceinline__ half* item_ptr(int row, int column) { return &data[row * width + column]; } + __device__ __forceinline__ void set(int row, int column, half value) { data[row * width + column] = value; } + __device__ __forceinline__ void set_half2(int row, int column, half2 value) { ((half2*)data)[(row * width + column) / 2] = value; } + + __device__ __forceinline__ void set4(int row, int column, half v0, half v1, half v2, half v3) + { + half2 v01 = __halves2half2(v0, v1); + half2 v23 = __halves2half2(v2, v3); + half2* ptr = (half2*) item_ptr(row, column); + ptr[0] = v01; + ptr[1] = v23; + } +}; + +class MatrixView_q4_row +{ +public: + const uint32_t* data; + const int height; + const int width; + + __device__ __forceinline__ MatrixView_q4_row(const uint32_t* data, const int height, const int width) + : data(data), height(height), width(width) + { } + + __device__ __forceinline__ int item(int row, int column) const + { + int shift = (column & 0x07) * 4; + return (data[row * width / 8 + column / 8] >> shift) & 0x0f; + } + + __device__ __forceinline__ void item2(int (&items)[2], int row, int column) const + { + int shift = (column & 0x07) * 4; + uint32_t d = data[row * width / 8 + column / 8] >> shift; + items[0] = d & 0x0f; + items[1] = (d >> 4) & 0x0f; + } + + __device__ __forceinline__ void item4(int (&items)[4], int row, int column) const + { + int shift = (column & 0x07) * 4; + uint32_t d = data[row * width / 8 + column / 8] >> shift; + items[0] = d & 0x0f; + items[1] = (d >> 4) & 0x0f; + items[2] = (d >> 8) & 0x0f; + items[3] = (d >> 12) & 0x0f; + } +}; + +#endif \ No newline at end of file diff --git a/server/exllamav2_kernels/exllamav2_kernels/cuda/q_gemm.cu b/server/exllamav2_kernels/exllamav2_kernels/cuda/q_gemm.cu new file mode 100644 index 00000000..351b9cd5 --- /dev/null +++ b/server/exllamav2_kernels/exllamav2_kernels/cuda/q_gemm.cu @@ -0,0 +1,211 @@ +#include "q_gemm.cuh" +#include "util.cuh" +#include "matrix_view.cuh" +#include "../config.h" + +#include "quant/qdq_2.cuh" +#include "quant/qdq_3.cuh" +#include "quant/qdq_4.cuh" +#include "quant/qdq_5.cuh" +#include "quant/qdq_6.cuh" +#include "quant/qdq_8.cuh" + +#define BLOCK_KN_SIZE 128 +#define BLOCK_M_SIZE_MAX 8 +#define MAX_GROUPS_IN_BLOCK (BLOCK_KN_SIZE / 32) +#define CLEAR_N_SIZE 256 + +#include "q_gemm_kernel.cuh" +#include "q_gemm_kernel_gptq.cuh" + +#include "compat_gemm.cuh" + +void gemm_half_q_half_cuda_part +( + const half* a, + QMatrix* b, + half* c, + int size_m, + int size_n, + int size_k, + int m_count, + bool clear +) +{ + if (!b->is_gptq) + { + dim3 blockDim, gridDim; + blockDim.x = BLOCK_KN_SIZE; + blockDim.y = 1; + blockDim.z = 1; + gridDim.x = DIVIDE(size_n, BLOCK_KN_SIZE * 4); + gridDim.y = DIVIDE(size_m, m_count); + gridDim.z = DIVIDE(size_k, BLOCK_KN_SIZE); + + fp_gemm_half_q_half_kernel kernel = pick_gemm_half_q_half_kernel(true, m_count); + + kernel<<>> + ( + a, + b->cuda_q_weight, + b->cuda_q_scale, + b->cuda_q_scale_max, + c, + size_m, + size_n, + size_k, + b->groups, + b->groupsize, + b->cuda_q_perm, + b->rows_8, + b->rows_6, + b->rows_5, + b->rows_4, + b->rows_3, + b->rows_2, + clear + ); + } + else + { + dim3 blockDim, gridDim; + blockDim.x = BLOCK_KN_SIZE; + blockDim.y = 1; + blockDim.z = 1; + gridDim.x = DIVIDE(size_n, BLOCK_KN_SIZE * 4); + gridDim.y = DIVIDE(size_m, m_count); + gridDim.z = DIVIDE(size_k, BLOCK_KN_SIZE); + + fp_gemm_half_q_half_gptq_kernel kernel = pick_gemm_half_q_half_gptq_kernel(true, m_count); + +// DBGX((uint64_t) b->cuda_q_perm); +// DBGI(b->rows_4); +// DBGI(b->height); + + kernel<<>> + ( + a, + b->cuda_q_weight, + b->cuda_gptq_qzeros, + b->cuda_gptq_scales, + c, + size_m, + size_n, + size_k, + b->groups, + b->groupsize, + b->cuda_q_perm, + b->rows_4, + clear + ); + } +} + +void gemm_half_q_half_cuda +( + cublasHandle_t cublas_handle, + const half* a, + QMatrix* b, + half* c, + int size_m, + int size_n, + int size_k, + bool clear, + half* temp_dq, + bool force_cuda +) +{ + if (size_m > MAX_Q_GEMM_ROWS && !force_cuda) + { + //printf("cublas\n"); + + // Reconstruct FP16 matrix, then cuBLAS + + if (!temp_dq) temp_dq = b->temp_dq; + b->reconstruct(temp_dq); + + //cublasSetMathMode(cublas_handle, CUBLAS_TENSOR_OP_MATH); + + const half alpha = __float2half(1.0f); + const half beta = clear ? __float2half(0.0f) : __float2half(1.0f); + cublasHgemm(cublas_handle, + CUBLAS_OP_N, + CUBLAS_OP_N, + size_n, size_m, size_k, + &alpha, temp_dq, size_n, + a, size_k, + &beta, c, size_n); + + //const float alpha = 1.0f; + //const float beta = clear ? 0.0f : 1.0f; + //cublasSgemmEx(cublas_handle, + // CUBLAS_OP_N, + // CUBLAS_OP_N, + // size_n, size_m, size_k, + // &alpha, temp_dq, CUDA_R_16F, size_n, + // a, CUDA_R_16F, size_k, + // &beta, c, CUDA_R_16F, size_n); + + //const float alpha = 1.0f; + //const float beta = clear ? 0.0f : 1.0f; + //cublasGemmEx(cublas_handle, + // CUBLAS_OP_N, CUBLAS_OP_N, + // size_n, size_m, size_k, + // &alpha, temp_dq, CUDA_R_16F, size_n, + // a, CUDA_R_16F, size_k, + // &beta, c, CUDA_R_16F, size_n, + // CUDA_R_16F, CUBLAS_GEMM_DFALT_TENSOR_OP); + } + else + { + //printf("cuda\n"); + + // Quantized matmul + + //if (clear) clear_tensor_cuda(c, size_m, size_n); + + int max_chunks = size_m / BLOCK_M_SIZE_MAX; + int last_chunk = max_chunks * BLOCK_M_SIZE_MAX; + int last_chunk_size = size_m - last_chunk; + + if (max_chunks) + { + gemm_half_q_half_cuda_part(a, b, c, last_chunk, size_n, size_k, BLOCK_M_SIZE_MAX, clear); + } + + if (last_chunk_size) + { + gemm_half_q_half_cuda_part(a + last_chunk * size_k, b, c + last_chunk * size_n, last_chunk_size, size_n, size_k, last_chunk_size, clear); + } + } +} + +__global__ void clear_kernel +( + half* __restrict__ c, + const int size_m, + const int size_n +) +{ + int m = blockIdx.y; + int n = (blockIdx.x * CLEAR_N_SIZE + threadIdx.x) * 8; + if (n >= size_n) return; + int4* c_ptr = (int4*)(c + m * size_n + n); + *c_ptr = {}; +} + +void clear_tensor_cuda +( + half* c, + int size_m, + int size_n +) +{ + return; + dim3 blockDim, gridDim; + blockDim.x = CLEAR_N_SIZE; + blockDim.y = 1; + gridDim.x = DIVIDE(size_n / 8, CLEAR_N_SIZE); + gridDim.y = size_m; + clear_kernel<<>>(c, size_m, size_n); +} diff --git a/server/exllamav2_kernels/exllamav2_kernels/cuda/q_gemm.cuh b/server/exllamav2_kernels/exllamav2_kernels/cuda/q_gemm.cuh new file mode 100644 index 00000000..c69f1a70 --- /dev/null +++ b/server/exllamav2_kernels/exllamav2_kernels/cuda/q_gemm.cuh @@ -0,0 +1,33 @@ +#ifndef _q_gemm_cuh +#define _q_gemm_cuh + +#include +#include +#include +#include +#include + +#include "q_matrix.cuh" + +void gemm_half_q_half_cuda +( + cublasHandle_t cublas_handle, + const half* a, + QMatrix* b, + half* c, + int size_m, + int size_n, + int size_k, + bool clear = false, + half* reconstruct = NULL, + bool force_cuda = false +); + +void clear_tensor_cuda +( + half* c, + int size_m, + int size_n +); + +#endif \ No newline at end of file diff --git a/server/exllamav2_kernels/exllamav2_kernels/cuda/q_gemm_kernel.cuh b/server/exllamav2_kernels/exllamav2_kernels/cuda/q_gemm_kernel.cuh new file mode 100644 index 00000000..0b899a84 --- /dev/null +++ b/server/exllamav2_kernels/exllamav2_kernels/cuda/q_gemm_kernel.cuh @@ -0,0 +1,487 @@ +#include "compat.cuh" + +#include +#include + +__forceinline__ __device__ half2 dot22_8(half2(&dq)[4], const half* a_ptr, const half2 g_result, const half qs_h) +{ + half2 result = {}; + const half2* a2_ptr = (const half2*)a_ptr; + #pragma unroll + for (int i = 0; i < 4; i++) result = __hfma2(dq[i], *a2_ptr++, result); + return __hfma2(result, __halves2half2(qs_h, qs_h), g_result); +} + +__forceinline__ __device__ half2 dot22_16(half2(&dq)[8], const half* a_ptr, const half2 g_result, const half qs_h) +{ + half2 result = {}; + const half2* a2_ptr = (const half2*)a_ptr; + #pragma unroll + for (int i = 0; i < 8; i++) result = __hfma2(dq[i], *a2_ptr++, result); + return __hfma2(result, __halves2half2(qs_h, qs_h), g_result); +} + +__forceinline__ __device__ half2 dot22_32(half2(&dq)[16], const half* a_ptr, const half2 g_result, const half qs_h) +{ + half2 result = {}; + const half2* a2_ptr = (const half2*)a_ptr; + #pragma unroll + for (int i = 0; i < 16; i += 1) result = __hfma2(dq[i], *a2_ptr++, result); + return __hfma2(result, __halves2half2(qs_h, qs_h), g_result); +} + +__forceinline__ __device__ float dot22_8_f(half2(&dq)[4], const half* a_ptr, const float g_result, const float qs_f) +{ + half2 result = {}; + const half2* a2_ptr = (const half2*)a_ptr; + #pragma unroll + for (int i = 0; i < 4; i++) result = __hfma2(dq[i], *a2_ptr++, result); + float result_f = __half2float(__low2half(result)) + __half2float(__high2half(result)); + return fma(result_f, qs_f, g_result); +} + +__forceinline__ __device__ float dot22_16_f(half2(&dq)[8], const half* a_ptr, const float g_result, const float qs_f) +{ + half2 result = {}; + const half2* a2_ptr = (const half2*)a_ptr; + #pragma unroll + for (int i = 0; i < 8; i++) result = __hfma2(dq[i], *a2_ptr++, result); + float result_f = __half2float(__low2half(result)) + __half2float(__high2half(result)); + return fma(result_f, qs_f, g_result); +} + +__forceinline__ __device__ float dot22_32_f(half2(&dq)[16], const half* a_ptr, const float g_result, const float qs_f) +{ + half2 result = {}; + const half2* a2_ptr = (const half2*)a_ptr; + #pragma unroll + for (int i = 0; i < 16; i += 1) result = __hfma2(dq[i], *a2_ptr++, result); + float result_f = __half2float(__low2half(result)) + __half2float(__high2half(result)); + return fma(result_f, qs_f, g_result); +} + + + +typedef void (*fp_gemm_half_q_half_kernel) +( + const half*, + const uint32_t*, + const uint32_t*, + const half*, + half*, + const int, + const int, + const int, + const int, + const int, + const uint16_t*, + const int, + const int, + const int, + const int, + const int, + const int, + const bool +); + +template +__global__ void gemm_half_q_half_kernel +( + const half* __restrict__ a, + const uint32_t* __restrict__ b_q_weight, + const uint32_t* __restrict__ b_q_scale, + const half* __restrict__ b_q_scale_max, + half* __restrict__ c, + const int size_m, + const int size_n, + const int size_k, + const int groups, + const int groupsize, + const uint16_t* __restrict__ b_q_perm, + const int rows_8, + const int rows_6, + const int rows_5, + const int rows_4, + const int rows_3, + const int rows_2, + const bool clear +) +{ + MatrixView_half a_(a, size_m, size_k); + MatrixView_half_rw c_(c, size_m, size_n); + MatrixView_q4_row b_q_scale_(b_q_scale, groups, size_n); + + int t = threadIdx.x; + + // Block + + int offset_n = blockIdx.x * BLOCK_KN_SIZE * 4; + int offset_m = blockIdx.y * m_count; + int offset_k = blockIdx.z * BLOCK_KN_SIZE; + + int end_n = min(offset_n + BLOCK_KN_SIZE * 4, size_n); + int end_m = min(offset_m + m_count, size_m); + int end_k = min(offset_k + BLOCK_KN_SIZE, size_k); + int n = offset_n + t * 4; + + // Preload block_a + + __shared__ half block_a[m_count][BLOCK_KN_SIZE]; + + if (offset_k + t < end_k) + { + for (int m = 0; m < m_count; ++m) + { + const half* a_ptr = a_.item_ptr(offset_m + m, 0); + half* block_a_ptr = block_a[m]; + half a0 = a_ptr[b_q_perm[offset_k + t]]; + block_a_ptr[t] = a0; + } + } + + // Clear + + if (n >= size_n) return; + + if (clear && blockIdx.z == 0) // && (threadIdx.x & 1) == 0) + { + for (int m = 0; m < m_count; m++) + *((uint64_t*) c_.item_ptr(offset_m + m, n)) = 0; + } + + __syncthreads(); + + // Find initial group + + int group = offset_k / groupsize; + + // Preload scales + + float scales[MAX_GROUPS_IN_BLOCK][4]; + + int groups_in_block = DIVIDE((end_k - offset_k), groupsize); + for (int g = 0; g < groups_in_block; g++) + { + int qscales[4]; + b_q_scale_.item4(qscales, group + g, n); + qscales[0]++; + qscales[1]++; + qscales[2]++; + qscales[3]++; + float maxscale = __half2float(b_q_scale_max[group + g]); + scales[g][0] = __int2float_rn(qscales[0] * qscales[0]) * maxscale; + scales[g][1] = __int2float_rn(qscales[1] * qscales[1]) * maxscale; + scales[g][2] = __int2float_rn(qscales[2] * qscales[2]) * maxscale; + scales[g][3] = __int2float_rn(qscales[3] * qscales[3]) * maxscale; + } + + // a, b offset + + int pre_rows_8 = min(rows_8, offset_k); + int pre_rows_6 = offset_k > rows_8 ? min(rows_6, offset_k) - rows_8 : 0; + int pre_rows_5 = offset_k > rows_6 ? min(rows_5, offset_k) - rows_6 : 0; + int pre_rows_4 = offset_k > rows_5 ? min(rows_4, offset_k) - rows_5 : 0; + int pre_rows_3 = offset_k > rows_4 ? min(rows_3, offset_k) - rows_4 : 0; + int pre_rows_2 = offset_k > rows_3 ? min(rows_2, offset_k) - rows_3 : 0; + int qk = 0; + qk += pre_rows_8 / 32 * 8; + qk += pre_rows_6 / 32 * 6; + qk += pre_rows_5 / 32 * 5; + qk += pre_rows_4 / 32 * 4; + qk += pre_rows_3 / 32 * 3; + qk += pre_rows_2 / 32 * 2; + + const uint32_t* b_ptr = b_q_weight + qk * size_n + n; + const half* a_ptr = &block_a[0][0]; + int a_stride = BLOCK_KN_SIZE; + + // Initial group + + int scales_idx = 0; + float qs_f0 = scales[scales_idx][0]; + float qs_f1 = scales[scales_idx][1]; + float qs_f2 = scales[scales_idx][2]; + float qs_f3 = scales[scales_idx][3]; + int nextgroup = offset_k + groupsize; + + // Column result + + float block_c[m_count][4] = {}; + + // Dequantize groups + + int k = offset_k; + + while (k < rows_8 && k < end_k) + { + if (k == nextgroup) + { + group++; + scales_idx++; + qs_f0 = scales[scales_idx][0]; + qs_f1 = scales[scales_idx][1]; + qs_f2 = scales[scales_idx][2]; + qs_f3 = scales[scales_idx][3]; + nextgroup += groupsize; + } + + #pragma unroll + for (int j = 0; j < 4; j++) + { + int4 load_int4[2]; + load_int4[0] = *((int4*) b_ptr); b_ptr += size_n; + load_int4[1] = *((int4*) b_ptr); b_ptr += size_n; + + half2 dq[4][4]; + dequant_8bit_8(load_int4[0].x, load_int4[1].x, dq[0], size_n); + dequant_8bit_8(load_int4[0].y, load_int4[1].y, dq[1], size_n); + dequant_8bit_8(load_int4[0].z, load_int4[1].z, dq[2], size_n); + dequant_8bit_8(load_int4[0].w, load_int4[1].w, dq[3], size_n); + + for (int m = 0; m < m_count; m++) + { + block_c[m][0] = dot22_8_f(dq[0], a_ptr + m * a_stride, block_c[m][0], qs_f0); + block_c[m][1] = dot22_8_f(dq[1], a_ptr + m * a_stride, block_c[m][1], qs_f1); + block_c[m][2] = dot22_8_f(dq[2], a_ptr + m * a_stride, block_c[m][2], qs_f2); + block_c[m][3] = dot22_8_f(dq[3], a_ptr + m * a_stride, block_c[m][3], qs_f3); + } + a_ptr += 8; + } + k += 32; + } + + while (k < rows_6 && k < end_k) + { + if (k == nextgroup) + { + group++; + scales_idx++; + qs_f0 = scales[scales_idx][0]; + qs_f1 = scales[scales_idx][1]; + qs_f2 = scales[scales_idx][2]; + qs_f3 = scales[scales_idx][3]; + nextgroup += groupsize; + } + + #pragma unroll + for (int j = 0; j < 2; j++) + { + int4 load_int4[3]; + load_int4[0] = *((int4*) b_ptr); b_ptr += size_n; + load_int4[1] = *((int4*) b_ptr); b_ptr += size_n; + load_int4[2] = *((int4*) b_ptr); b_ptr += size_n; + + half2 dq[4][8]; + dequant_6bit_16(load_int4[0].x, load_int4[1].x, load_int4[2].x, dq[0], size_n); + dequant_6bit_16(load_int4[0].y, load_int4[1].y, load_int4[2].y, dq[1], size_n); + dequant_6bit_16(load_int4[0].z, load_int4[1].z, load_int4[2].z, dq[2], size_n); + dequant_6bit_16(load_int4[0].w, load_int4[1].w, load_int4[2].w, dq[3], size_n); + + for (int m = 0; m < m_count; m++) + { + block_c[m][0] = dot22_16_f(dq[0], a_ptr + m * a_stride, block_c[m][0], qs_f0); + block_c[m][1] = dot22_16_f(dq[1], a_ptr + m * a_stride, block_c[m][1], qs_f1); + block_c[m][2] = dot22_16_f(dq[2], a_ptr + m * a_stride, block_c[m][2], qs_f2); + block_c[m][3] = dot22_16_f(dq[3], a_ptr + m * a_stride, block_c[m][3], qs_f3); + } + a_ptr += 16; + } + k += 32; + } + + while (k < rows_5 && k < end_k) + { + if (k == nextgroup) + { + group++; + scales_idx++; + qs_f0 = scales[scales_idx][0]; + qs_f1 = scales[scales_idx][1]; + qs_f2 = scales[scales_idx][2]; + qs_f3 = scales[scales_idx][3]; + nextgroup += groupsize; + } + + #pragma unroll + for (int j = 0; j < 1; j++) + { + int4 load_int4[5]; + load_int4[0] = *((int4*) b_ptr); b_ptr += size_n; + load_int4[1] = *((int4*) b_ptr); b_ptr += size_n; + load_int4[2] = *((int4*) b_ptr); b_ptr += size_n; + load_int4[3] = *((int4*) b_ptr); b_ptr += size_n; + load_int4[4] = *((int4*) b_ptr); b_ptr += size_n; + + half2 dq[4][16]; + dequant_5bit_32(load_int4[0].x, load_int4[1].x, load_int4[2].x, load_int4[3].x, load_int4[4].x, dq[0], size_n); + dequant_5bit_32(load_int4[0].y, load_int4[1].y, load_int4[2].y, load_int4[3].y, load_int4[4].y, dq[1], size_n); + dequant_5bit_32(load_int4[0].z, load_int4[1].z, load_int4[2].z, load_int4[3].z, load_int4[4].z, dq[2], size_n); + dequant_5bit_32(load_int4[0].w, load_int4[1].w, load_int4[2].w, load_int4[3].w, load_int4[4].w, dq[3], size_n); + + for (int m = 0; m < m_count; m++) + { + block_c[m][0] = dot22_32_f(dq[0], a_ptr + m * a_stride, block_c[m][0], qs_f0); + block_c[m][1] = dot22_32_f(dq[1], a_ptr + m * a_stride, block_c[m][1], qs_f1); + block_c[m][2] = dot22_32_f(dq[2], a_ptr + m * a_stride, block_c[m][2], qs_f2); + block_c[m][3] = dot22_32_f(dq[3], a_ptr + m * a_stride, block_c[m][3], qs_f3); + } + a_ptr += 32; + } + + k += 32; + } + + while (k < rows_4 && k < end_k) + { + if (k == nextgroup) + { + group++; + scales_idx++; + qs_f0 = scales[scales_idx][0]; + qs_f1 = scales[scales_idx][1]; + qs_f2 = scales[scales_idx][2]; + qs_f3 = scales[scales_idx][3]; + nextgroup += groupsize; + } + + #pragma unroll + for (int j = 0; j < 4; j++) + { + int4 load_int4[1]; + load_int4[0] = *((int4*) b_ptr); b_ptr += size_n; + + half2 dq[4][4]; + dequant_4bit_8(load_int4[0].x, dq[0], size_n); + dequant_4bit_8(load_int4[0].y, dq[1], size_n); + dequant_4bit_8(load_int4[0].z, dq[2], size_n); + dequant_4bit_8(load_int4[0].w, dq[3], size_n); + + for (int m = 0; m < m_count; m++) + { + block_c[m][0] = dot22_8_f(dq[0], a_ptr + m * a_stride, block_c[m][0], qs_f0); + block_c[m][1] = dot22_8_f(dq[1], a_ptr + m * a_stride, block_c[m][1], qs_f1); + block_c[m][2] = dot22_8_f(dq[2], a_ptr + m * a_stride, block_c[m][2], qs_f2); + block_c[m][3] = dot22_8_f(dq[3], a_ptr + m * a_stride, block_c[m][3], qs_f3); + } + a_ptr += 8; + } + k += 32; + } + + while (k < rows_3 && k < end_k) + { + if (k == nextgroup) + { + group++; + scales_idx++; + qs_f0 = scales[scales_idx][0]; + qs_f1 = scales[scales_idx][1]; + qs_f2 = scales[scales_idx][2]; + qs_f3 = scales[scales_idx][3]; + nextgroup += groupsize; + } + + #pragma unroll + for (int j = 0; j < 1; j++) + { + int4 load_int4[3]; + load_int4[0] = *((int4*) b_ptr); b_ptr += size_n; + load_int4[1] = *((int4*) b_ptr); b_ptr += size_n; + load_int4[2] = *((int4*) b_ptr); b_ptr += size_n; + + half2 dq[4][16]; + dequant_3bit_32(load_int4[0].x, load_int4[1].x, load_int4[2].x, dq[0], size_n); + dequant_3bit_32(load_int4[0].y, load_int4[1].y, load_int4[2].y, dq[1], size_n); + dequant_3bit_32(load_int4[0].z, load_int4[1].z, load_int4[2].z, dq[2], size_n); + dequant_3bit_32(load_int4[0].w, load_int4[1].w, load_int4[2].w, dq[3], size_n); + + for (int m = 0; m < m_count; m++) + { + block_c[m][0] = dot22_32_f(dq[0], a_ptr + m * a_stride, block_c[m][0], qs_f0); + block_c[m][1] = dot22_32_f(dq[1], a_ptr + m * a_stride, block_c[m][1], qs_f1); + block_c[m][2] = dot22_32_f(dq[2], a_ptr + m * a_stride, block_c[m][2], qs_f2); + block_c[m][3] = dot22_32_f(dq[3], a_ptr + m * a_stride, block_c[m][3], qs_f3); + } + a_ptr += 32; + } + k += 32; + } + + while (k < rows_2 && k < end_k) + { + if (k == nextgroup) + { + group++; + scales_idx++; + qs_f0 = scales[scales_idx][0]; + qs_f1 = scales[scales_idx][1]; + qs_f2 = scales[scales_idx][2]; + qs_f3 = scales[scales_idx][3]; + nextgroup += groupsize; + } + + #pragma unroll + for (int j = 0; j < 2; j++) + { + int4 load_int4[1]; + load_int4[0] = *((int4*) b_ptr); b_ptr += size_n; + + half2 dq[4][8]; + dequant_2bit_16(load_int4[0].x, dq[0], size_n); + dequant_2bit_16(load_int4[0].y, dq[1], size_n); + dequant_2bit_16(load_int4[0].z, dq[2], size_n); + dequant_2bit_16(load_int4[0].w, dq[3], size_n); + + for (int m = 0; m < m_count; m++) + { + block_c[m][0] = dot22_16_f(dq[0], a_ptr + m * a_stride, block_c[m][0], qs_f0); + block_c[m][1] = dot22_16_f(dq[1], a_ptr + m * a_stride, block_c[m][1], qs_f1); + block_c[m][2] = dot22_16_f(dq[2], a_ptr + m * a_stride, block_c[m][2], qs_f2); + block_c[m][3] = dot22_16_f(dq[3], a_ptr + m * a_stride, block_c[m][3], qs_f3); + } + + a_ptr += 16; + } + k += 32; + } + + // Accumulate column sums in c + + for (int m = 0; m < m_count; m++) + { + half2* out = (half2*)c_.item_ptr(offset_m + m, n); + half2 result01 = __halves2half2(__float2half_rn(block_c[m][0]), __float2half_rn(block_c[m][1])); + half2 result23 = __halves2half2(__float2half_rn(block_c[m][2]), __float2half_rn(block_c[m][3])); + atomicAdd(out , result01); + atomicAdd(out + 1, result23); + } +} + +fp_gemm_half_q_half_kernel pick_gemm_half_q_half_kernel(bool first_block, const int m_count) +{ + #if BLOCK_M_SIZE_MAX >= 1 + if (m_count == 1) return gemm_half_q_half_kernel; + #endif + #if BLOCK_M_SIZE_MAX >= 2 + if (m_count == 2) return gemm_half_q_half_kernel; + #endif + #if BLOCK_M_SIZE_MAX >= 3 + if (m_count == 3) return gemm_half_q_half_kernel; + #endif + #if BLOCK_M_SIZE_MAX >= 4 + if (m_count == 4) return gemm_half_q_half_kernel; + #endif + #if BLOCK_M_SIZE_MAX >= 5 + if (m_count == 5) return gemm_half_q_half_kernel; + #endif + #if BLOCK_M_SIZE_MAX >= 6 + if (m_count == 6) return gemm_half_q_half_kernel; + #endif + #if BLOCK_M_SIZE_MAX >= 7 + if (m_count == 7) return gemm_half_q_half_kernel; + #endif + #if BLOCK_M_SIZE_MAX >= 8 + if (m_count == 8) return gemm_half_q_half_kernel; + #endif + return NULL; +} diff --git a/server/exllamav2_kernels/exllamav2_kernels/cuda/q_gemm_kernel_gptq.cuh b/server/exllamav2_kernels/exllamav2_kernels/cuda/q_gemm_kernel_gptq.cuh new file mode 100644 index 00000000..ebaa42d0 --- /dev/null +++ b/server/exllamav2_kernels/exllamav2_kernels/cuda/q_gemm_kernel_gptq.cuh @@ -0,0 +1,219 @@ +#include "compat.cuh" + +__forceinline__ __device__ half2 dot22_8(half2(&dq)[4], const half* a_ptr, const half2 g_result) +{ + half2 result = {}; + const half2* a2_ptr = (const half2*)a_ptr; + #pragma unroll + for (int i = 0; i < 4; i++) result = __hfma2(dq[i], *a2_ptr++, result); + return __hadd2(result, g_result); +} + +__forceinline__ __device__ float dot22_8_f(half2(&dq)[4], const half* a_ptr) +{ + half2 result = {}; + const half2* a2_ptr = (const half2*)a_ptr; + #pragma unroll + for (int i = 0; i < 4; i++) result = __hfma2(dq[i], *a2_ptr++, result); + return __half2float(__low2half(result)) + __half2float(__high2half(result)); +} + +typedef void (*fp_gemm_half_q_half_gptq_kernel) +( + const half*, + const uint32_t*, + const uint32_t*, + const half*, + half*, + const int, + const int, + const int, + const int, + const int, + const uint16_t*, + const int, + const bool +); + +template +__global__ void gemm_half_q_half_gptq_kernel +( + const half* __restrict__ a, + const uint32_t* __restrict__ b_q_weight, + const uint32_t* __restrict__ b_gptq_qzeros, + const half* __restrict__ b_gptq_scales, + half* __restrict__ c, + const int size_m, + const int size_n, + const int size_k, + const int groups, + const int groupsize, + const uint16_t* __restrict__ b_q_perm, + const int rows_4, + const bool clear +) +{ + MatrixView_half a_(a, size_m, size_k); + MatrixView_half_rw c_(c, size_m, size_n); + MatrixView_q4_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n); + MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n); + + int t = threadIdx.x; + + // Block + + int offset_n = blockIdx.x * BLOCK_KN_SIZE * 4; + int offset_m = blockIdx.y * m_count; + int offset_k = blockIdx.z * BLOCK_KN_SIZE; + + int end_n = min(offset_n + BLOCK_KN_SIZE * 4, size_n); + int end_m = min(offset_m + m_count, size_m); + int end_k = min(offset_k + BLOCK_KN_SIZE, size_k); + + int n = offset_n + t * 4; + + // Preload block_a + + __shared__ half block_a[m_count][BLOCK_KN_SIZE]; + + if (offset_k + t < end_k) + { + for (int m = 0; m < m_count; ++m) + { + const half* a_ptr = a_.item_ptr(offset_m + m, 0); + half* block_a_ptr = block_a[m]; + + half a0; + if (b_q_perm) a0 = a_ptr[b_q_perm[offset_k + t]]; + else a0 = a_ptr[offset_k + t]; + block_a_ptr[t] = a0; + } + } + + // Zero output + + if (n >= size_n) return; + + if (clear && blockIdx.z == 0) // && (threadIdx.x & 1) == 0) + { + for (int m = 0; m < m_count; m++) + *((uint64_t*)c_.item_ptr(offset_m + m, n)) = 0; + } + + __syncthreads(); + + // Find initial group + + int group = offset_k / groupsize; + int nextgroup = offset_k + groupsize; + + // a, b offset + + int qk = offset_k / (32 / 4); + + const uint32_t* b_ptr = b_q_weight + qk * size_n + n; + const half* a_ptr = &block_a[0][0]; + int a_stride = BLOCK_KN_SIZE; + + // Initial group + + int zeros[4]; + float scales[4]; + half2 z1z16[4][2]; + half2 y1y16[4][2]; + b_gptq_qzeros_.item4(zeros, group, n); + b_gptq_scales_.item4_f(scales, group, n); + dequant_4bit_8_prep_zero(zeros[0] + 1, z1z16[0], y1y16[0]); + dequant_4bit_8_prep_zero(zeros[1] + 1, z1z16[1], y1y16[1]); + dequant_4bit_8_prep_zero(zeros[2] + 1, z1z16[2], y1y16[2]); + dequant_4bit_8_prep_zero(zeros[3] + 1, z1z16[3], y1y16[3]); + +// __syncthreads(); + + // Column result + + float block_c[m_count][4] = {}; + + // Dequantize and multiply + + int k = offset_k; + while (k < end_k) + { + if (k == nextgroup) + { + group++; + nextgroup += groupsize; + b_gptq_qzeros_.item4(zeros, group, n); + b_gptq_scales_.item4_f(scales, group, n); + dequant_4bit_8_prep_zero(zeros[0] + 1, z1z16[0], y1y16[0]); + dequant_4bit_8_prep_zero(zeros[1] + 1, z1z16[1], y1y16[1]); + dequant_4bit_8_prep_zero(zeros[2] + 1, z1z16[2], y1y16[2]); + dequant_4bit_8_prep_zero(zeros[3] + 1, z1z16[3], y1y16[3]); + } + + #pragma unroll + for (int j = 0; j < 4; j++) + { + const int4* b_ptr4 = (int4*) b_ptr; + int4 load_int4 = *b_ptr4; + + half2 dq[4][4]; + dequant_4bit_8_gptq(load_int4.x, dq[0], z1z16[0], y1y16[0], size_n, false); + dequant_4bit_8_gptq(load_int4.y, dq[1], z1z16[1], y1y16[1], size_n, false); + dequant_4bit_8_gptq(load_int4.z, dq[2], z1z16[2], y1y16[2], size_n, false); + dequant_4bit_8_gptq(load_int4.w, dq[3], z1z16[3], y1y16[3], size_n, false); + + #pragma unroll + for (int m = 0; m < m_count; m++) + { + block_c[m][0] = fma(dot22_8_f(dq[0], a_ptr + m * a_stride), scales[0], block_c[m][0]); + block_c[m][1] = fma(dot22_8_f(dq[1], a_ptr + m * a_stride), scales[1], block_c[m][1]); + block_c[m][2] = fma(dot22_8_f(dq[2], a_ptr + m * a_stride), scales[2], block_c[m][2]); + block_c[m][3] = fma(dot22_8_f(dq[3], a_ptr + m * a_stride), scales[3], block_c[m][3]); + } + + b_ptr += size_n; + a_ptr += 8; + } + + k += 32; + } + + for (int m = 0; m < m_count; m++) + { + half2 *out = (half2*) c_.item_ptr(offset_m + m, n); + half2 result01 = __halves2half2(__float2half_rn(block_c[m][0]), __float2half_rn(block_c[m][1])); + half2 result23 = __halves2half2(__float2half_rn(block_c[m][2]), __float2half_rn(block_c[m][3])); + atomicAdd(out , result01); + atomicAdd(out + 1, result23); + } +} + +fp_gemm_half_q_half_gptq_kernel pick_gemm_half_q_half_gptq_kernel(bool first_block, const int m_count) +{ + #if BLOCK_M_SIZE_MAX >= 1 + if (m_count == 1) return gemm_half_q_half_gptq_kernel; + #endif + #if BLOCK_M_SIZE_MAX >= 2 + if (m_count == 2) return gemm_half_q_half_gptq_kernel; + #endif + #if BLOCK_M_SIZE_MAX >= 3 + if (m_count == 3) return gemm_half_q_half_gptq_kernel; + #endif + #if BLOCK_M_SIZE_MAX >= 4 + if (m_count == 4) return gemm_half_q_half_gptq_kernel; + #endif + #if BLOCK_M_SIZE_MAX >= 5 + if (m_count == 5) return gemm_half_q_half_gptq_kernel; + #endif + #if BLOCK_M_SIZE_MAX >= 6 + if (m_count == 6) return gemm_half_q_half_gptq_kernel; + #endif + #if BLOCK_M_SIZE_MAX >= 7 + if (m_count == 7) return gemm_half_q_half_gptq_kernel; + #endif + #if BLOCK_M_SIZE_MAX >= 8 + if (m_count == 8) return gemm_half_q_half_gptq_kernel; + #endif + return NULL; +} diff --git a/server/exllamav2_kernels/exllamav2_kernels/cuda/q_matrix.cu b/server/exllamav2_kernels/exllamav2_kernels/cuda/q_matrix.cu new file mode 100644 index 00000000..6aed7470 --- /dev/null +++ b/server/exllamav2_kernels/exllamav2_kernels/cuda/q_matrix.cu @@ -0,0 +1,623 @@ +#include "q_matrix.cuh" +#include "matrix_view.cuh" +#include "util.cuh" + +#include "quant/qdq_2.cuh" +#include "quant/qdq_3.cuh" +#include "quant/qdq_4.cuh" +#include "quant/qdq_5.cuh" +#include "quant/qdq_6.cuh" +#include "quant/qdq_8.cuh" + +#define BLOCK_KN_SIZE 128 + +#define THREADS_X 32 +#define THREADS_Y 32 + +// Shuffle quantized data on load + +__global__ void shuffle_kernel +( + uint32_t* __restrict__ b_q_weight, + const int size_k, + const int size_n, + const int rows_8, + const int rows_6, + const int rows_5, + const int rows_4, + const int rows_3, + const int rows_2 +) +{ + int n = blockIdx.x * THREADS_X + threadIdx.x; + if (n >= size_n) return; + int k = 0; + uint32_t* b_ptr = b_q_weight + n; + while (k < rows_8) { shuffle_8bit_4 (b_ptr, size_n); b_ptr += 1 * size_n; k += 4; } + while (k < rows_6) { shuffle_6bit_16(b_ptr, size_n); b_ptr += 3 * size_n; k += 16; } + while (k < rows_5) { shuffle_5bit_32(b_ptr, size_n); b_ptr += 5 * size_n; k += 32; } + while (k < rows_4) { shuffle_4bit_8 (b_ptr, size_n); b_ptr += 1 * size_n; k += 8; } + while (k < rows_3) { shuffle_3bit_32(b_ptr, size_n); b_ptr += 3 * size_n; k += 32; } + while (k < rows_2) { shuffle_2bit_16(b_ptr, size_n); b_ptr += 1 * size_n; k += 16; } +} + + +// QMatrix constructor + +QMatrix::QMatrix +( + const int _device, + const int _height, + const int _width, + const int _groups, + + uint32_t* _q_weight, + uint16_t* _q_perm, + uint16_t* _q_invperm, + uint32_t* _q_scale, + half* _q_scale_max, + uint16_t* _q_groups, + + uint32_t* _gptq_qzeros, + half* _gptq_scales, + uint32_t* _gptq_g_idx, + + half* _temp_dq +) : + device(_device), + height(_height), + width(_width), + groups(_groups), + temp_dq(_temp_dq) +{ + cudaSetDevice(device); + + failed = false; + + cuda_q_weight = _q_weight; + cuda_q_perm = _q_perm; + cuda_q_invperm = _q_invperm; + cuda_q_scale = _q_scale; + cuda_q_scale_max = _q_scale_max; + cuda_q_groups = _q_groups; + cuda_gptq_qzeros = _gptq_qzeros; + cuda_gptq_scales = _gptq_scales; + + is_gptq = (_gptq_qzeros != NULL); + + groupsize = 1; + while (groupsize * groups < height) groupsize *= 2; + + // Create group map + + rows_8 = 0; + rows_6 = 0; + rows_5 = 0; + rows_4 = 0; + rows_3 = 0; + rows_2 = 0; + + if (!is_gptq) + { + uint16_t* cpu_q_groups = (uint16_t*)calloc(groups * 2, sizeof(uint16_t)); + cudaMemcpy(cpu_q_groups, cuda_q_groups, groups * 2 * sizeof(uint16_t), cudaMemcpyDeviceToHost); + + for (int i = 0; i < groups; i++) + { + int bits = cpu_q_groups[i * 2]; + if (bits == 8) rows_8 += groupsize; + if (bits == 6) rows_6 += groupsize; + if (bits == 5) rows_5 += groupsize; + if (bits == 4) rows_4 += groupsize; + if (bits == 3) rows_3 += groupsize; + if (bits == 2) rows_2 += groupsize; + } + + free(cpu_q_groups); + + rows_6 += rows_8; + rows_5 += rows_6; + rows_4 += rows_5; + rows_3 += rows_4; + rows_2 += rows_3; + } + else + { + rows_4 = height; + rows_3 = height; + rows_2 = height; + + if (_gptq_g_idx) + { + if (!make_sequential(_gptq_g_idx)) + { + failed = true; + //printf("FAIL\n"); + return; + } + } + } + + // Shuffle quantized data + + dim3 blockDim, gridDim; + blockDim.x = THREADS_X; + blockDim.y = 1; + gridDim.x = DIVIDE(width, THREADS_X); + gridDim.y = 1; + + shuffle_kernel<<>>(cuda_q_weight, height, width, rows_8, rows_6, rows_5, rows_4, rows_3, rows_2); +} + +QMatrix::~QMatrix() +{ +} + +// Reconstruct b[k,n] (GPTQ) + +__global__ void reconstruct_gptq_kernel +( + const uint32_t* __restrict__ b_q_weight, + const uint16_t* __restrict__ b_q_perm, + const uint32_t* __restrict__ b_gptq_qzeros, + const half* __restrict__ b_gptq_scales, + //const uint16_t* __restrict__ b_q_groups, + const int size_k, + const int size_n, + const int groupsize, + const int groups, + half* __restrict__ b, + const int rows_4 +) +{ + MatrixView_half_rw b_(b, size_k, size_n); + MatrixView_q4_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n); + MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n); + + int offset_k = BLOCK_KN_SIZE * blockIdx.y; + int offset_n = BLOCK_KN_SIZE * blockIdx.x * 4; + + int end_k = min(offset_k + BLOCK_KN_SIZE, size_k); + + // Preload remapping table + + __shared__ uint16_t perm[BLOCK_KN_SIZE]; + int t = threadIdx.x; + + if (b_q_perm) + { + if (offset_k + t < size_k) + perm[t] = b_q_perm[offset_k + t]; + } + + // Column + + int n = offset_n + t * 4; + if (n >= size_n) return; + + // Find initial group + + int group = offset_k / groupsize; + int nextgroup = offset_k + groupsize; + + // b offset + + int qk = offset_k / (32 / 4); + + const uint32_t* b_ptr = b_q_weight + qk * size_n + n; + + // Initial zeros/scale + + int zeros[4]; + half2 scales[4]; + half2 z1z16[4][2]; + half2 y1y16[4][2]; + b_gptq_qzeros_.item4(zeros, group, n); + b_gptq_scales_.item4_h2(scales, group, n); + dequant_4bit_8_prep_zero(zeros[0] + 1, z1z16[0], y1y16[0]); + dequant_4bit_8_prep_zero(zeros[1] + 1, z1z16[1], y1y16[1]); + dequant_4bit_8_prep_zero(zeros[2] + 1, z1z16[2], y1y16[2]); + dequant_4bit_8_prep_zero(zeros[3] + 1, z1z16[3], y1y16[3]); + + __syncthreads(); + + int k = offset_k; + int lk = 0; + + while (k < end_k) + { + if (k == nextgroup) + { + group++; + nextgroup += groupsize; + b_gptq_qzeros_.item4(zeros, group, n); + b_gptq_scales_.item4_h2(scales, group, n); + dequant_4bit_8_prep_zero(zeros[0] + 1, z1z16[0], y1y16[0]); + dequant_4bit_8_prep_zero(zeros[1] + 1, z1z16[1], y1y16[1]); + dequant_4bit_8_prep_zero(zeros[2] + 1, z1z16[2], y1y16[2]); + dequant_4bit_8_prep_zero(zeros[3] + 1, z1z16[3], y1y16[3]); + } + + for (int p = 0; p < 4; p++) + { + half2 dq[4][4]; + const int4* b_ptr4 = (int4*) b_ptr; + int4 load_int4 = *b_ptr4; + + dequant_4bit_8_gptq(load_int4.x, dq[0], z1z16[0], y1y16[0], size_n, false); + dequant_4bit_8_gptq(load_int4.y, dq[1], z1z16[1], y1y16[1], size_n, false); + dequant_4bit_8_gptq(load_int4.z, dq[2], z1z16[2], y1y16[2], size_n, false); + dequant_4bit_8_gptq(load_int4.w, dq[3], z1z16[3], y1y16[3], size_n, false); + + b_ptr += size_n; + //half* dqh = (half*)dq; + if (b_q_perm) + { + for (int j = 0; j < 4; j++) + { + for (int v = 0; v < 4; v++) dq[v][j] = __hmul2(scales[v], dq[v][j]); + b_.set4(perm[lk++], n, __low2half(dq[0][j]), __low2half(dq[1][j]), __low2half(dq[2][j]), __low2half(dq[3][j])); + b_.set4(perm[lk++], n, __high2half(dq[0][j]), __high2half(dq[1][j]), __high2half(dq[2][j]), __high2half(dq[3][j])); + } + } + else + { + for (int j = 0; j < 4; j++) + { + for (int v = 0; v < 4; v++) dq[v][j] = __hmul2(scales[v], dq[v][j]); + b_.set4(offset_k + lk++, n, __low2half(dq[0][j]), __low2half(dq[1][j]), __low2half(dq[2][j]), __low2half(dq[3][j])); + b_.set4(offset_k + lk++, n, __high2half(dq[0][j]), __high2half(dq[1][j]), __high2half(dq[2][j]), __high2half(dq[3][j])); + } + } + } + k += 32; + } +} + + +// Reconstruct b[k,n] + +__global__ void reconstruct_kernel +( + const uint32_t* __restrict__ b_q_weight, + const uint16_t* __restrict__ b_q_perm, + const uint32_t* __restrict__ b_q_scale, + const half* __restrict__ b_q_scale_max, + //const uint16_t* __restrict__ b_q_groups, + const int size_k, + const int size_n, + const int groupsize, + const int groups, + half* __restrict__ b, + const int rows_8, + const int rows_6, + const int rows_5, + const int rows_4, + const int rows_3, + const int rows_2 +) +{ + MatrixView_half_rw b_(b, size_k, size_n); + MatrixView_q4_row b_q_scale_(b_q_scale, groups, size_n); + + int offset_k = BLOCK_KN_SIZE * blockIdx.y; + int offset_n = BLOCK_KN_SIZE * blockIdx.x; + + // Preload remapping table + + int t = threadIdx.x; + __shared__ uint16_t perm[BLOCK_KN_SIZE]; + if (offset_k + t < size_k) + perm[t] = b_q_perm[offset_k + t]; + + // Column + + int n = offset_n + t; + if (n >= size_n) return; + + // Find initial group + + int group = offset_k / groupsize; + + int pre_rows_8 = min(rows_8, offset_k); + int pre_rows_6 = offset_k > rows_8 ? min(rows_6, offset_k) - rows_8 : 0; + int pre_rows_5 = offset_k > rows_6 ? min(rows_5, offset_k) - rows_6 : 0; + int pre_rows_4 = offset_k > rows_5 ? min(rows_4, offset_k) - rows_5 : 0; + int pre_rows_3 = offset_k > rows_4 ? min(rows_3, offset_k) - rows_4 : 0; + int pre_rows_2 = offset_k > rows_3 ? min(rows_2, offset_k) - rows_3 : 0; + int qk = 0; + qk += pre_rows_8 / 32 * 8; + qk += pre_rows_6 / 32 * 6; + qk += pre_rows_5 / 32 * 5; + qk += pre_rows_4 / 32 * 4; + qk += pre_rows_3 / 32 * 3; + qk += pre_rows_2 / 32 * 2; + + const uint32_t* b_ptr = b_q_weight + qk * size_n + n; + + half qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]); + half2 qs_h2 = __halves2half2(qs_h, qs_h); + int nextgroup = offset_k + groupsize; + + int end_k = min(offset_k + BLOCK_KN_SIZE, size_k); + int k = offset_k; + int lk = 0; + + __syncthreads(); + + while (k < rows_8 && k < end_k) + { + if (k == nextgroup) { group++; qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]); nextgroup += groupsize; qs_h2 = __halves2half2(qs_h, qs_h); } + for (int p = 0; p < 4; p++) + { + half2 dq[4]; + uint32_t q_0 = *b_ptr; b_ptr += size_n; + uint32_t q_1 = *b_ptr; b_ptr += size_n; + dequant_8bit_8(q_0, q_1, dq, size_n); + for (int j = 0; j < 4; j++) dq[j] = __hmul2(dq[j], qs_h2); + half* dqh = (half*) dq; + for (int j = 0; j < 8; j++) b_.set(perm[lk++], n, dqh[j]); + } + k += 32; + } + + while (k < rows_6 && k < end_k) + { + if (k == nextgroup) { group++; qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]); nextgroup += groupsize; qs_h2 = __halves2half2(qs_h, qs_h); } + for (int p = 0; p < 2; p++) + { + half2 dq[8]; + uint32_t q_0 = *b_ptr; b_ptr += size_n; + uint32_t q_1 = *b_ptr; b_ptr += size_n; + uint32_t q_2 = *b_ptr; b_ptr += size_n; + dequant_6bit_16(q_0, q_1, q_2, dq, size_n); + for (int j = 0; j < 8; j++) dq[j] = __hmul2(dq[j], qs_h2); + half* dqh = (half*) dq; + for (int j = 0; j < 16; j++) b_.set(perm[lk++], n, dqh[j]); + } + k += 32; + } + + while (k < rows_5 && k < end_k) + { + if (k == nextgroup) { group++; qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]); nextgroup += groupsize; qs_h2 = __halves2half2(qs_h, qs_h); } + for (int p = 0; p < 1; p++) + { + half2 dq[16]; + uint32_t q_0 = *b_ptr; b_ptr += size_n; + uint32_t q_1 = *b_ptr; b_ptr += size_n; + uint32_t q_2 = *b_ptr; b_ptr += size_n; + uint32_t q_3 = *b_ptr; b_ptr += size_n; + uint32_t q_4 = *b_ptr; b_ptr += size_n; + dequant_5bit_32(q_0, q_1, q_2, q_3, q_4, dq, size_n); + for (int j = 0; j < 16; j++) dq[j] = __hmul2(dq[j], qs_h2); + half* dqh = (half*) dq; + for (int j = 0; j < 32; j++) b_.set(perm[lk++], n, dqh[j]); + } + k += 32; + } + + while (k < rows_4 && k < end_k) + { + if (k == nextgroup) { group++; qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]); nextgroup += groupsize; qs_h2 = __halves2half2(qs_h, qs_h); } + for (int p = 0; p < 4; p++) + { + half2 dq[4]; + uint32_t q_0 = *b_ptr; b_ptr += size_n; + dequant_4bit_8(q_0, dq, size_n); + for (int j = 0; j < 4; j++) dq[j] = __hmul2(dq[j], qs_h2); + half* dqh = (half*) dq; + for (int j = 0; j < 8; j++) b_.set(perm[lk++], n, dqh[j]); + } + k += 32; + } + + while (k < rows_3 && k < end_k) + { + if (k == nextgroup) { group++; qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]); nextgroup += groupsize; qs_h2 = __halves2half2(qs_h, qs_h); } + for (int p = 0; p < 1; p++) + { + half2 dq[16]; + uint32_t q_0 = *b_ptr; b_ptr += size_n; + uint32_t q_1 = *b_ptr; b_ptr += size_n; + uint32_t q_2 = *b_ptr; b_ptr += size_n; + dequant_3bit_32(q_0, q_1, q_2, dq, size_n); + for (int j = 0; j < 16; j++) dq[j] = __hmul2(dq[j], qs_h2); + half* dqh = (half*) dq; + for (int j = 0; j < 32; j++) b_.set(perm[lk++], n, dqh[j]); + } + k += 32; + } + + while (k < rows_2 && k < end_k) + { + if (k == nextgroup) { group++; qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]); nextgroup += groupsize; qs_h2 = __halves2half2(qs_h, qs_h); } + for (int p = 0; p < 2; p++) + { + half2 dq[8]; + uint32_t q_0 = *b_ptr; b_ptr += size_n; + dequant_2bit_16(q_0, dq, size_n); + for (int j = 0; j < 8; j++) dq[j] = __hmul2(dq[j], qs_h2); + half* dqh = (half*) dq; + for (int j = 0; j < 16; j++) b_.set(perm[lk++], n, dqh[j]); + } + k += 32; + } +} + +void QMatrix::reconstruct(half* out) +{ + dim3 blockDim, gridDim; + blockDim.x = BLOCK_KN_SIZE; + blockDim.y = 1; + gridDim.y = DIVIDE(height, BLOCK_KN_SIZE); + + if (!is_gptq) + { + gridDim.x = DIVIDE(width, BLOCK_KN_SIZE); + reconstruct_kernel<<>> + ( + cuda_q_weight, + cuda_q_perm, + cuda_q_scale, + cuda_q_scale_max, + //cuda_q_groups, + height, + width, + groupsize, + groups, + out, + rows_8, + rows_6, + rows_5, + rows_4, + rows_3, + rows_2 + ); + } + else + { + gridDim.x = DIVIDE(width, BLOCK_KN_SIZE * 4); + reconstruct_gptq_kernel<<>> + ( + cuda_q_weight, + cuda_q_perm, + cuda_gptq_qzeros, + cuda_gptq_scales, + //const uint16_t* __restrict__ b_q_groups, + height, + width, + groupsize, + groups, + out, + rows_4 + ); + } +} + +__global__ void make_sequential_kernel +( + const uint32_t* __restrict__ w, + uint32_t* __restrict__ w_new, + const uint16_t* __restrict__ q_perm, + const int w_height, + const int w_width +) +{ + const uint64_t* w2 = (uint64_t*) w; + uint64_t* w_new2 = (uint64_t*) w_new; + int w2_stride = w_width >> 1; + + int w2_column = THREADS_X * blockIdx.x + threadIdx.x; + if (w2_column >= w2_stride) return; + + int w_new2_row = blockIdx.y; + + int q_perm_idx = w_new2_row << 3; + + uint64_t dst = 0; + + #pragma unroll + for (int i = 0; i < 8; i++) + { + int source_row = q_perm[q_perm_idx++]; + + int w2_row = source_row >> 3; + int w2_subrow = source_row & 0x07; + int w2_row_shift = w2_subrow << 2; + int wnew2_row_shift = i << 2; + + uint64_t src = w2[w2_row * w2_stride + w2_column]; + src >>= w2_row_shift; + src &= 0x0000000f0000000f; + src <<= wnew2_row_shift; + dst |= src; + } + + w_new2[w_new2_row * w2_stride + w2_column] = dst; +} + +bool QMatrix::make_sequential(const uint32_t* cpu_g_idx) +{ + uint32_t* cuda_new_qweight = NULL; + cudaError_t err = cudaMalloc(&cuda_new_qweight, height / 8 * width * sizeof(uint32_t)); + if (err != cudaSuccess) { + cudaError_t cuda_status = cudaGetLastError(); // Clear error + return false; + } + + uint32_t* cpu_g_idx_map = (uint32_t*) calloc(groups, sizeof(uint32_t)); + uint32_t* cpu_x_map = (uint32_t*) malloc(height * sizeof(uint32_t)); + uint32_t* cpu_x_map_inv = (uint32_t*) malloc(height * sizeof(uint32_t)); + + // Group histogram + + for (int i = 0; i < height; i++) cpu_g_idx_map[cpu_g_idx[i]]++; + + // Group map + + for (int i = 0, acc = 0; i < groups; i++) + { + short tmp = cpu_g_idx_map[i]; + cpu_g_idx_map[i] = acc; + acc += tmp; + } + + // X map (inverse) + + for (int row = 0; row < height; row++) + { + uint32_t target_group = cpu_g_idx[row]; + uint32_t target_row = cpu_g_idx_map[target_group]; + cpu_g_idx_map[target_group]++; + cpu_x_map_inv[row] = target_row; + } + + // X map + + for (int row = 0; row < height; row++) cpu_x_map[cpu_x_map_inv[row]] = row; + + // Reduce to uint16_t + + uint16_t* cpu_x_map16 = (uint16_t*)cpu_x_map; + uint16_t* cpu_x_map_inv16 = (uint16_t*)cpu_x_map_inv; + for (int row = 0; row < height; row++) cpu_x_map16[row] = (uint16_t) cpu_x_map[row]; + for (int row = 0; row < height; row++) cpu_x_map_inv16[row] = (uint16_t) cpu_x_map_inv[row]; + + // Move to CUDA + + cudaMemcpyAsync(cuda_q_perm, cpu_x_map16, height * sizeof(uint16_t), cudaMemcpyHostToDevice); + cudaMemcpyAsync(cuda_q_invperm, cpu_x_map_inv16, height * sizeof(uint16_t), cudaMemcpyHostToDevice); + + // Rearrange rows in w + + dim3 blockDim, gridDim; + blockDim.x = THREADS_X; + blockDim.y = 1; + gridDim.x = DIVIDE(width, THREADS_X); + gridDim.y = height / 8; + + make_sequential_kernel<<>> + ( + cuda_q_weight, + cuda_new_qweight, + cuda_q_perm, + height / 8, + width + ); + + // Replace qweights + + cudaMemcpyAsync(cuda_q_weight, cuda_new_qweight, height / 8 * width * sizeof(uint32_t), cudaMemcpyDeviceToDevice); + + // Cleanup + + cudaDeviceSynchronize(); + + cudaFree(cuda_new_qweight); + free(cpu_g_idx_map); + free(cpu_x_map); + free(cpu_x_map_inv); + + return true; +} diff --git a/server/exllamav2_kernels/exllamav2_kernels/cuda/q_matrix.cuh b/server/exllamav2_kernels/exllamav2_kernels/cuda/q_matrix.cuh new file mode 100644 index 00000000..dda83a4f --- /dev/null +++ b/server/exllamav2_kernels/exllamav2_kernels/cuda/q_matrix.cuh @@ -0,0 +1,73 @@ +#ifndef _q_matrix_cuh +#define _q_matrix_cuh + +#include +#include +#include +#include + +#define MAX_SUPERGROUPS 16 + +class QMatrix +{ +public: + + int device; + bool is_gptq; + + int height; + int width; + int groups; + int groupsize; + + int rows_8; + int rows_6; + int rows_5; + int rows_4; + int rows_3; + int rows_2; + + uint32_t* cuda_q_weight = NULL; + uint16_t* cuda_q_perm = NULL; + uint16_t* cuda_q_invperm = NULL; + uint32_t* cuda_q_scale = NULL; + half* cuda_q_scale_max = NULL; + uint16_t* cuda_q_groups = NULL; + uint32_t* cuda_gptq_qzeros = NULL; + half* cuda_gptq_scales = NULL; + + half* temp_dq; + + bool failed; + + QMatrix + ( + const int _device, + const int _height, + const int _width, + const int _groups, + + uint32_t* _q_weight, + uint16_t* _q_perm, + uint16_t* _q_invperm, + uint32_t* _q_scale, + half* _q_scale_max, + uint16_t* _q_groups, + + uint32_t* _gptq_qzeros, + half* _gptq_scales, + uint32_t* _gptq_g_idx, + + half* _temp_dq + ); + + ~QMatrix(); + + void reconstruct(half* out); + bool make_sequential(const uint32_t* cpu_g_idx); + +private: + +}; + +#endif diff --git a/server/exllamav2_kernels/exllamav2_kernels/cuda/quant/qdq_2.cuh b/server/exllamav2_kernels/exllamav2_kernels/cuda/quant/qdq_2.cuh new file mode 100644 index 00000000..3beaeefa --- /dev/null +++ b/server/exllamav2_kernels/exllamav2_kernels/cuda/quant/qdq_2.cuh @@ -0,0 +1,103 @@ +#ifndef _qdq_2_cuh +#define _qdq_2_cuh + +#include "qdq_util.cuh" +#include "../../config.h" + +#if QMODE_2BIT == 1 + +// Permutation: +// +// ffddbb99 77553311 eeccaa88 66442200 + +__forceinline__ __device__ void shuffle_2bit_16 +( + uint32_t* q, + int stride +) +{ + uint32_t qa = q[0]; + uint32_t qb = 0; + + #pragma unroll + for (int i = 0; i < 8; i++) + { + uint32_t qa0 = qa & 0x03; + uint32_t qa1 = (qa & 0x0c) >> 2; + qa >>= 4; + qb |= (qa1 << (i * 2 + 16)); + qb |= (qa0 << (i * 2)); + } + q[0] = qb; +} + +__forceinline__ __device__ void dequant_2bit_16 +( + const uint32_t q_0, + half2 (&dq)[8], + int stride +) +{ + const uint32_t c0 = 0x64006400; + const half y4_ = __float2half_rn(1.0f / 4.0f); + const half y16_ = __float2half_rn(1.0f / 16.0f); + const half y64_ = __float2half_rn(1.0f / 64.0f); + const half2 y4 = __halves2half2(y4_, y4_); + const half2 y16 = __halves2half2(y16_, y16_); + const half2 y64 = __halves2half2(y64_, y64_); + const half z1_ = __float2half_rn(-1024.0f - 2.0f); + const half z4_ = __float2half_rn(-1024.0f / 4.0f - 2.0f); + const half z16_ = __float2half_rn(-1024.0f / 16.0f - 2.0f); + const half z64_ = __float2half_rn(-1024.0f / 64.0f - 2.0f); + const half2 z1 = __halves2half2(z1_, z1_); + const half2 z4 = __halves2half2(z4_, z4_); + const half2 z16 = __halves2half2(z16_, z16_); + const half2 z64 = __halves2half2(z64_, z64_); + + uint32_t qa = q_0; + half2_uint32 q0((qa & 0x00030003) | c0); // half2(q[ 0], q[ 1]) + 1024 + half2_uint32 q1((qa & 0x000c000c) | c0); // half2(q[ 2], q[ 3]) * 4 + 1024 + half2_uint32 q2((qa & 0x00300030) | c0); // half2(q[ 4], q[ 5]) * 16 + 1024 + half2_uint32 q3((qa & 0x00c000c0) | c0); // half2(q[ 6], q[ 7]) * 64 + 1024 + qa >>= 8; + half2_uint32 q4((qa & 0x00030003) | c0); // half2(q[ 8], q[ 8]) + 1024 + half2_uint32 q5((qa & 0x000c000c) | c0); // half2(q[10], q[11]) * 4 + 1024 + half2_uint32 q6((qa & 0x00300030) | c0); // half2(q[12], q[13]) * 16 + 1024 + half2_uint32 q7((qa & 0x00c000c0) | c0); // half2(q[14], q[15]) * 64 + 1024 + + dq[0] = __hadd2(q0.as_half2, z1); + dq[1] = __hfma2(q1.as_half2, y4, z4); + dq[2] = __hfma2(q2.as_half2, y16, z16); + dq[3] = __hfma2(q3.as_half2, y64, z64); + dq[4] = __hadd2(q4.as_half2, z1); + dq[5] = __hfma2(q5.as_half2, y4, z4); + dq[6] = __hfma2(q6.as_half2, y16, z16); + dq[7] = __hfma2(q7.as_half2, y64, z64); +} + +#else + +__forceinline__ __device__ void shuffle_2bit_16 +( + uint32_t* q, + int stride +) +{ +} + +__forceinline__ __device__ void dequant_2bit_16 +( + const uint32_t q_0, + half2 (&dq)[8], + int stride +) +{ + half dqh[16]; + for (int i = 0; i < 16; i++) dqh[i] = dq_ns(exb(q_0, i * 2, 0x03), 2); + + for (int i = 0; i < 8; i++) dq[i] = __halves2half2(dqh[i * 2], dqh[i * 2 + 1]); +} + +#endif + +#endif \ No newline at end of file diff --git a/server/exllamav2_kernels/exllamav2_kernels/cuda/quant/qdq_3.cuh b/server/exllamav2_kernels/exllamav2_kernels/cuda/quant/qdq_3.cuh new file mode 100644 index 00000000..10117376 --- /dev/null +++ b/server/exllamav2_kernels/exllamav2_kernels/cuda/quant/qdq_3.cuh @@ -0,0 +1,169 @@ +#ifndef _qdq_3_cuh +#define _qdq_3_cuh + +#include "qdq_util.cuh" +#include "../../config.h" + +#if QMODE_3BIT == 1 + +// Permutation: +// +// v9997775 55333111 u8886664 44222000 (u, v lsb) +// vjjjhhhf ffdddbbb uiiiggge eecccaaa +// vtttrrrp ppnnnlll usssqqqo oommmkkk + +__forceinline__ __device__ void shuffle_3bit_32 +( + uint32_t* q, + int stride +) +{ + uint32_t qa = q[0 * stride]; + uint32_t qb = q[1 * stride]; + uint32_t qc = q[2 * stride]; + + // qa: aa999888 77766655 54443332 22111000 + // qb: lkkkjjji iihhhggg fffeeedd dcccbbba + // qc: vvvuuutt tsssrrrq qqpppooo nnnmmmll + + uint32_t qd = qc >> 26; + qc <<= 4; + qc |= qb >> 28; + qb <<= 2; + qb |= qa >> 30; + + // qa: ..999888 77766655 54443332 22111000 + // qb: ..jjjiii hhhgggff feeedddc ccbbbaaa + // qc: ..tttsss rrrqqqpp pooonnnm mmlllkkk + // qd: vvvuuu + + uint32_t za = 0; + uint32_t zb = 0; + uint32_t zc = 0; + + for (int i = 0; i < 5; i++) { uint32_t t0 = qa & 0x07; uint32_t t1 = (qa & 0x38) >> 3; qa >>= 6; za |= (t0 << (i * 3)); za |= (t1 << (i * 3 + 16)); } + for (int i = 0; i < 5; i++) { uint32_t t0 = qb & 0x07; uint32_t t1 = (qb & 0x38) >> 3; qb >>= 6; zb |= (t0 << (i * 3)); zb |= (t1 << (i * 3 + 16)); } + for (int i = 0; i < 5; i++) { uint32_t t0 = qc & 0x07; uint32_t t1 = (qc & 0x38) >> 3; qc >>= 6; zc |= (t0 << (i * 3)); zc |= (t1 << (i * 3 + 16)); } + + // za: 9997775 55333111 8886664 44222000 + // zb: jjjhhhf ffdddbbb iiiggge eecccaaa + // zc: tttrrrp ppnnnlll sssqqqo oommmkkk + // qd: vvvuuu + + za |= ((qd & 0x01) >> 0) << 15; + zb |= ((qd & 0x02) >> 1) << 15; + zc |= ((qd & 0x04) >> 2) << 15; + za |= ((qd & 0x08) >> 3) << 31; + zb |= ((qd & 0x10) >> 4) << 31; + zc |= ((qd & 0x20) >> 5) << 31; + + // za: v9997775 55333111 u8886664 44222000 (u, v lsb) + // zb: vjjjhhhf ffdddbbb uiiiggge eecccaaa + // zc: vtttrrrp ppnnnlll usssqqqo oommmkkk + + q[0 * stride] = za; + q[1 * stride] = zb; + q[2 * stride] = zc; +} + +__forceinline__ __device__ void dequant_3bit_32 +( + const uint32_t q_0, + const uint32_t q_1, + const uint32_t q_2, + half2 (&dq)[16], + int stride +) +{ + const uint32_t c0 = 0x64006400; + const half y8_ = __float2half_rn(1.0f / 8.0f); + const half y64_ = __float2half_rn(1.0f / 64.0f); + const half2 y8 = __halves2half2(y8_, y8_); + const half2 y64 = __halves2half2(y64_, y64_); + const half z1_ = __float2half_rn(-1024.0f - 4.0f); + const half z8_ = __float2half_rn(-1024.0f / 8.0f - 4.0f); + const half z64_ = __float2half_rn(-1024.0f / 64.0f - 4.0f); + const half2 z1 = __halves2half2(z1_, z1_); + const half2 z8 = __halves2half2(z8_, z8_); + const half2 z64 = __halves2half2(z64_, z64_); + + uint32_t qa = q_0; + uint32_t qb = q_1; + uint32_t qc = q_2; + + half2_uint32 q0((qa & 0x00070007) | c0); // half2(q[ 0], q[ 1]) + 1024 + half2_uint32 q1((qa & 0x00380038) | c0); // half2(q[ 2], q[ 3]) * 8 + 1024 + qa >>= 6; + half2_uint32 q2((qa & 0x00070007) | c0); // half2(q[ 4], q[ 5]) + 1024 + half2_uint32 q3((qa & 0x00380038) | c0); // half2(q[ 6], q[ 7]) * 8 + 1024 + half2_uint32 q4((qa & 0x01c001c0) | c0); // half2(q[ 8], q[ 9]) * 64 + 1024 + qa >>= 9; + qa &= 0x00010001; + half2_uint32 q5((qb & 0x00070007) | c0); // half2(q[10], q[11]) + 1024 + half2_uint32 q6((qb & 0x00380038) | c0); // half2(q[12], q[13]) * 8 + 1024 + qb >>= 6; + half2_uint32 q7((qb & 0x00070007) | c0); // half2(q[14], q[15]) + 1024 + half2_uint32 q8((qb & 0x00380038) | c0); // half2(q[16], q[17]) * 8 + 1024 + half2_uint32 q9((qb & 0x01c001c0) | c0); // half2(q[18], q[19]) * 64 + 1024 + qb >>= 8; + qb &= 0x00020002; + half2_uint32 q10((qc & 0x00070007) | c0); // half2(q[20], q[21]) + 1024 + half2_uint32 q11((qc & 0x00380038) | c0); // half2(q[22], q[23]) * 8 + 1024 + qc >>= 6; + half2_uint32 q12((qc & 0x00070007) | c0); // half2(q[24], q[25]) + 1024 + half2_uint32 q13((qc & 0x00380038) | c0); // half2(q[26], q[27]) * 8 + 1024 + half2_uint32 q14((qc & 0x01c001c0) | c0); // half2(q[28], q[29]) * 64 + 1024 + qc >>= 7; + qc &= 0x00040004; + half2_uint32 q15((qa | qb | qc) | c0); + + dq[ 0] = __hadd2( q0.as_half2, z1); + dq[ 1] = __hfma2( q1.as_half2, y8, z8); + dq[ 2] = __hadd2( q2.as_half2, z1); + dq[ 3] = __hfma2( q3.as_half2, y8, z8); + dq[ 4] = __hfma2( q4.as_half2, y64, z64); + dq[ 5] = __hadd2( q5.as_half2, z1); + dq[ 6] = __hfma2( q6.as_half2, y8, z8); + dq[ 7] = __hadd2( q7.as_half2, z1); + dq[ 8] = __hfma2( q8.as_half2, y8, z8); + dq[ 9] = __hfma2( q9.as_half2, y64, z64); + dq[10] = __hadd2(q10.as_half2, z1); + dq[11] = __hfma2(q11.as_half2, y8, z8); + dq[12] = __hadd2(q12.as_half2, z1); + dq[13] = __hfma2(q13.as_half2, y8, z8); + dq[14] = __hfma2(q14.as_half2, y64, z64); + dq[15] = __hadd2(q15.as_half2, z1); +} + +#else + +__forceinline__ __device__ void shuffle_3bit_32 +( + uint32_t* q, + int stride +) +{ +} + +__forceinline__ __device__ void dequant_3bit_32 +( + const uint32_t q_0, + const uint32_t q_1, + const uint32_t q_2, + half2 (&dq)[16], + int stride +) +{ + half dqh[32]; + for (int i = 0; i < 10; i++) dqh[ i] = dq_ns(exb( q_0, i * 3 , 0x07), 4); + dqh[10 ] = dq_ns(exb(q_1, q_0, 30, 0x07), 4); + for (int i = 0; i < 10; i++) dqh[11 + i] = dq_ns(exb( q_1, i * 3 + 1, 0x07), 4); + dqh[21 ] = dq_ns(exb(q_2, q_1, 31, 0x07), 4); + for (int i = 0; i < 10; i++) dqh[22 + i] = dq_ns(exb( q_2, i * 3 + 2, 0x07), 4); + + for (int i = 0; i < 16; i++) dq[i] = __halves2half2(dqh[i * 2], dqh[i * 2 + 1]); +} + +#endif + +#endif diff --git a/server/exllamav2_kernels/exllamav2_kernels/cuda/quant/qdq_4.cuh b/server/exllamav2_kernels/exllamav2_kernels/cuda/quant/qdq_4.cuh new file mode 100644 index 00000000..5fb070d0 --- /dev/null +++ b/server/exllamav2_kernels/exllamav2_kernels/cuda/quant/qdq_4.cuh @@ -0,0 +1,227 @@ +#ifndef _qdq_4_cuh +#define _qdq_4_cuh + +#include "qdq_util.cuh" +#include "../../config.h" + +#if QMODE_4BIT == 1 + +// Permutation: +// +// 77775555 33331111 66664444 22220000 + +__forceinline__ __device__ void shuffle_4bit_8 +( + uint32_t* q, + int stride +) +{ + uint32_t qa = q[0]; + uint32_t qb = 0; + + #pragma unroll + for (int i = 0; i < 4; i++) + { + uint32_t qa0 = qa & 0x0f; + uint32_t qa1 = (qa & 0xf0) >> 4; + qa >>= 8; + qb |= (qa1 << (i * 4 + 16)); + qb |= (qa0 << (i * 4)); + } + q[0] = qb; +} + +__forceinline__ __device__ void dequant_4bit_8 +( + const uint32_t q_0, + half2 (&dq)[4], + int stride +) +{ + const uint32_t c0 = 0x64006400; + const half y16_ = __float2half_rn(1.0f / 16.0f); + const half2 y16 = __halves2half2(y16_, y16_); + const half z1_ = __float2half_rn(-1024.0f - 8.0f); + const half z16_ = __float2half_rn(-1024.0f / 16.0f - 8.0f); + const half2 z1 = __halves2half2(z1_, z1_); + const half2 z16 = __halves2half2(z16_, z16_); + + uint32_t qa = q_0; + half2_uint32 q0((qa & 0x000f000f) | c0); // half2(q[ 0], q[ 1]) + 1024 + half2_uint32 q1((qa & 0x00f000f0) | c0); // half2(q[ 2], q[ 3]) * 16 + 1024 + qa >>= 8; + half2_uint32 q2((qa & 0x000f000f) | c0); // half2(q[ 4], q[ 5]) + 1024 + half2_uint32 q3((qa & 0x00f000f0) | c0); // half2(q[ 6], q[ 7]) * 16 + 1024 + + dq[0] = __hadd2(q0.as_half2, z1); + dq[1] = __hfma2(q1.as_half2, y16, z16); + dq[2] = __hadd2(q2.as_half2, z1); + dq[3] = __hfma2(q3.as_half2, y16, z16); +} + +__forceinline__ __device__ void dequant_4bit_8_prep_zero_scale +( + const uint32_t zero, + const half scale, + half2 (&z1z16)[2], + half2 (&y1y16)[2] +) +{ + half_uint16 z1(0xe400 | zero); // half(-1024.0f - zero); + half z16 = __hsub(__int2half_rn(-64), __int2half_rn(zero)); + + half2 scale2 = __half2half2(scale); + + z1z16[0] = __hmul2(scale2, __half2half2(z1.as_half)); + z1z16[1] = __hmul2(scale2, __half2half2(z16)); + + const half y1 = __float2half_rn(1.0f); + const half y16 = __float2half_rn(1.0f / 16.0f); + + y1y16[0] = __hmul2(scale2, __half2half2(y1)); + y1y16[1] = __hmul2(scale2, __half2half2(y16)); +} + +__forceinline__ __device__ void dequant_4bit_8_prep_zero +( + const uint32_t zero, + half2(&z1z16)[2], + half2(&y1y16)[2] +) +{ + half_uint16 z1(0xe400 | zero); // half(-1024.0f - zero); + half z16 = __hsub(__int2half_rn(-64), __int2half_rn(zero)); + + z1z16[0] = __half2half2(z1.as_half); + z1z16[1] = __half2half2(z16); + + const half y1 = __float2half_rn(1.0f); + const half y16 = __float2half_rn(1.0f / 16.0f); + + y1y16[0] = __half2half2(y1); + y1y16[1] = __half2half2(y16); +} + + +__forceinline__ __device__ void dequant_4bit_8_gptq +( + const uint32_t q_0, + half2 (&dq)[4], + half2 (&z1z16)[2], + half2 (&y1y16)[2], + int stride, + bool scaled +) +{ + const uint32_t c0 = 0x64006400; + + uint32_t qa = q_0; + half2_uint32 q0((qa & 0x000f000f) | c0); // half2( q[0] + 1024, q[1] + 1024 ) + half2_uint32 q1((qa & 0x00f000f0) | c0); // half2( q[2] * 16 + 1024, q[3] * 16 + 1024 ) + qa >>= 8; + half2_uint32 q2((qa & 0x000f000f) | c0); // half2( q[4] + 1024, q[5] + 1024 ) + half2_uint32 q3((qa & 0x00f000f0) | c0); // half2( q[6] * 16 + 1024, q[7] * 16 + 1024 ) + + if (scaled) + { + dq[0] = __hfma2(q0.as_half2, y1y16[0], z1z16[0]); // half2( q[0] * s - z * s, q[1] * s - z * s) + dq[1] = __hfma2(q1.as_half2, y1y16[1], z1z16[1]); // half2( q[2] * s - z * s, q[3] * s - z * s) + dq[2] = __hfma2(q2.as_half2, y1y16[0], z1z16[0]); + dq[3] = __hfma2(q3.as_half2, y1y16[1], z1z16[1]); + } + else + { + dq[0] = __hadd2(q0.as_half2, z1z16[0]); // half2( q[0] - z, q[1] - z ) + dq[1] = __hfma2(q1.as_half2, y1y16[1], z1z16[1]); // half2( q[2] - z, q[3] - z ) + dq[2] = __hadd2(q2.as_half2, z1z16[0]); // half2( q[4] - z, q[5] - z ) + dq[3] = __hfma2(q3.as_half2, y1y16[1], z1z16[1]); // half2( q[6] - z, q[7] - z ) + } +} + +#else + +__forceinline__ __device__ void shuffle_4bit_8 +( + uint32_t* q, + int stride +) +{ +} + +__forceinline__ __device__ void dequant_4bit_8 +( + const uint32_t q_0, + half2 (&dq)[4], + int stride +) +{ + half dqh[8]; + for (int i = 0; i < 8; i++) dqh[i] = dq_ns(exb(q_0, i * 4, 0x0f), 8); + + for (int i = 0; i < 4; i++) dq[i] = __halves2half2(dqh[i * 2], dqh[i * 2 + 1]); +} + +__forceinline__ __device__ void dequant_4bit_8_prep_zero_scale +( + const uint32_t zero, + const half scale, + half2 (&z1)[2], + half2 (&y1)[2] +) +{ + half z = __int2half_rn(-((int)zero)); + z = __hmul(z, scale); + z1[0] = __half2half2(z); + y1[0] = __half2half2(scale); +} + +__forceinline__ __device__ void dequant_4bit_8_prep_zero +( + const uint32_t zero, + half2(&z1)[2], + half2(&y1)[2] +) +{ + half z = __int2half_rn(-((int)zero)); + z1[0] = __half2half2(z); +} + +__forceinline__ __device__ void dequant_4bit_8_gptq +( + const uint32_t q_0, + half2 (&dq)[4], + half2 (&z1)[2], + half2 (&y1)[2], + int stride, + bool scaled +) +{ + half2 dqh2[8]; + + uint32_t qa = q_0; + for (int i = 0; i < 4; i++) + { + half d0 = __int2half_rn(qa & 0x0f); qa >>= 4; + half d1 = __int2half_rn(qa & 0x0f); qa >>= 4; + dqh2[i] = __halves2half2(d0, d1); + } + + if (scaled) + { + dq[0] = __hfma2(dqh2[0], y1[0], z1[0]); + dq[1] = __hfma2(dqh2[1], y1[0], z1[0]); + dq[2] = __hfma2(dqh2[2], y1[0], z1[0]); + dq[3] = __hfma2(dqh2[3], y1[0], z1[0]); + } + else + { + dq[0] = __hadd2(dqh2[0], z1[0]); + dq[1] = __hadd2(dqh2[1], z1[0]); + dq[2] = __hadd2(dqh2[2], z1[0]); + dq[3] = __hadd2(dqh2[3], z1[0]); + } +} + +#endif + +#endif \ No newline at end of file diff --git a/server/exllamav2_kernels/exllamav2_kernels/cuda/quant/qdq_5.cuh b/server/exllamav2_kernels/exllamav2_kernels/cuda/quant/qdq_5.cuh new file mode 100644 index 00000000..454e4b93 --- /dev/null +++ b/server/exllamav2_kernels/exllamav2_kernels/cuda/quant/qdq_5.cuh @@ -0,0 +1,207 @@ +#ifndef _qdq_5_cuh +#define _qdq_5_cuh + +#include "qdq_util.cuh" +#include "../../config.h" + +#if QMODE_5BIT == 1 + +// Permutation: +// +// v5555533 33311111 u4444422 22200000 (u, v lsb) +// vbbbbb99 99977777 uaaaaa88 88866666 +// vhhhhhff fffddddd ugggggee eeeccccc +// vnnnnnll llljjjjj ummmmmkk kkkiiiii +// vtttttrr rrrppppp usssssqq qqqooooo + +__forceinline__ __device__ void shuffle_5bit_32 +( + uint32_t* q, + int stride +) +{ + uint32_t qa = q[0 * stride]; + uint32_t qb = q[1 * stride]; + uint32_t qc = q[2 * stride]; + uint32_t qd = q[3 * stride]; + uint32_t qe = q[4 * stride]; + + // qa: 66555554 44443333 32222211 11100000 + // qb: ccccbbbb baaaaa99 99988888 77777666 + // qc: jiiiiihh hhhggggg fffffeee eedddddc + // qd: pppooooo nnnnnmmm mmlllllk kkkkjjjj + // qe: vvvvvuuu uuttttts ssssrrrr rqqqqqpp + + uint32_t qf = qe >> 22; + qe <<= 8; + qe |= qd >> 24; + qd <<= 6; + qd |= qc >> 26; + qc <<= 4; + qc |= qb >> 28; + qb <<= 2; + qb |= qa >> 30; + + // qa: 555554 44443333 32222211 11100000 + // qb: bbbbba aaaa9999 98888877 77766666 + // qc: hhhhhg ggggffff feeeeedd dddccccc + // qd: nnnnnm mmmmllll lkkkkkjj jjjiiiii + // qe: ttttts ssssrrrr rqqqqqpp pppooooo + // qf: vv vvvuuuuu + + uint32_t za = 0; + uint32_t zb = 0; + uint32_t zc = 0; + uint32_t zd = 0; + uint32_t ze = 0; + + for (int i = 0; i < 3; i++) { uint32_t t0 = qa & 0x1f; uint32_t t1 = (qa & 0x3e0) >> 5; qa >>= 10; za |= (t0 << (i * 5)); za |= (t1 << (i * 5 + 16)); } + for (int i = 0; i < 3; i++) { uint32_t t0 = qb & 0x1f; uint32_t t1 = (qb & 0x3e0) >> 5; qb >>= 10; zb |= (t0 << (i * 5)); zb |= (t1 << (i * 5 + 16)); } + for (int i = 0; i < 3; i++) { uint32_t t0 = qc & 0x1f; uint32_t t1 = (qc & 0x3e0) >> 5; qc >>= 10; zc |= (t0 << (i * 5)); zc |= (t1 << (i * 5 + 16)); } + for (int i = 0; i < 3; i++) { uint32_t t0 = qd & 0x1f; uint32_t t1 = (qd & 0x3e0) >> 5; qd >>= 10; zd |= (t0 << (i * 5)); zd |= (t1 << (i * 5 + 16)); } + for (int i = 0; i < 3; i++) { uint32_t t0 = qe & 0x1f; uint32_t t1 = (qe & 0x3e0) >> 5; qe >>= 10; ze |= (t0 << (i * 5)); ze |= (t1 << (i * 5 + 16)); } + + // za: 5555533 33311111 4444422 22200000 + // zb: bbbbb99 99977777 aaaaa88 88866666 + // zc: hhhhhff fffddddd gggggee eeeccccc + // zd: nnnnnll llljjjjj mmmmmkk kkkiiiii + // ze: tttttrr rrrppppp sssssqq qqqooooo + // qf: vv vvvuuuuu + + za |= ((qf & 0x001) >> 0) << 15; + zb |= ((qf & 0x002) >> 1) << 15; + zc |= ((qf & 0x004) >> 2) << 15; + zd |= ((qf & 0x008) >> 3) << 15; + ze |= ((qf & 0x010) >> 4) << 15; + za |= ((qf & 0x020) >> 5) << 31; + zb |= ((qf & 0x040) >> 6) << 31; + zc |= ((qf & 0x080) >> 7) << 31; + zd |= ((qf & 0x100) >> 8) << 31; + ze |= ((qf & 0x200) >> 9) << 31; + + // za: v5555533 33311111 u4444422 22200000 (u, v lsb) + // zb: vbbbbb99 99977777 uaaaaa88 88866666 + // zc: vhhhhhff fffddddd ugggggee eeeccccc + // zd: vnnnnnll llljjjjj ummmmmkk kkkiiiii + // ze: vtttttrr rrrppppp usssssqq qqqooooo + + q[0 * stride] = za; + q[1 * stride] = zb; + q[2 * stride] = zc; + q[3 * stride] = zd; + q[4 * stride] = ze; +} + +__forceinline__ __device__ void dequant_5bit_32 +( + const uint32_t q_0, + const uint32_t q_1, + const uint32_t q_2, + const uint32_t q_3, + const uint32_t q_4, + half2 (&dq)[16], + int stride +) +{ + const uint32_t c0 = 0x64006400; + const half y32_ = __float2half_rn(1.0f / 32.0f); + const half2 y32 = __halves2half2(y32_, y32_); + const half z1_ = __float2half_rn(-1024.0f - 16.0f); + const half z32_ = __float2half_rn(-1024.0f / 32.0f - 16.0f); + const half2 z1 = __halves2half2(z1_, z1_); + const half2 z32 = __halves2half2(z32_, z32_); + + uint32_t qa = q_0; + uint32_t qb = q_1; + uint32_t qc = q_2; + uint32_t qd = q_3; + uint32_t qe = q_4; + + half2_uint32 q0 ((qa & 0x001f001f) | c0); // half2(q[ 0], q[ 1]) + 1024 + half2_uint32 q1 ((qa & 0x03e003e0) | c0); // half2(q[ 2], q[ 3]) * 32 + 1024 + qa >>= 10; + half2_uint32 q2 ((qa & 0x001f001f) | c0); // half2(q[ 4], q[ 5]) + 1024 + qa >>= 5; + qa &= 0x00010001; + half2_uint32 q3 ((qb & 0x001f001f) | c0); // half2(q[ 6], q[ 7]) + 1024 + half2_uint32 q4 ((qb & 0x03e003e0) | c0); // half2(q[ 8], q[ 9]) * 32 + 1024 + qb >>= 10; + half2_uint32 q5 ((qb & 0x001f001f) | c0); // half2(q[10], q[11]) + 1024 + qb >>= 4; + qb &= 0x00020002; + half2_uint32 q6 ((qc & 0x001f001f) | c0); // half2(q[12], q[13]) + 1024 + half2_uint32 q7 ((qc & 0x03e003e0) | c0); // half2(q[14], q[15]) * 32 + 1024 + qc >>= 10; + half2_uint32 q8 ((qc & 0x001f001f) | c0); // half2(q[16], q[17]) + 1024 + qc >>= 3; + qc &= 0x00040004; + half2_uint32 q9 ((qd & 0x001f001f) | c0); // half2(q[18], q[19]) + 1024 + half2_uint32 q10((qd & 0x03e003e0) | c0); // half2(q[20], q[21]) * 32 + 1024 + qd >>= 10; + half2_uint32 q11((qd & 0x001f001f) | c0); // half2(q[22], q[23]) + 1024 + qd >>= 2; + qd &= 0x00080008; + half2_uint32 q12((qe & 0x001f001f) | c0); // half2(q[24], q[25]) + 1024 + half2_uint32 q13((qe & 0x03e003e0) | c0); // half2(q[26], q[27]) * 32 + 1024 + qe >>= 10; + half2_uint32 q14((qe & 0x001f001f) | c0); // half2(q[28], q[29]) + 1024 + qe >>= 1; + qe &= 0x00100010; + half2_uint32 q15((qa | qb | qc | qd | qe) | c0); + + dq[ 0] = __hadd2( q0.as_half2, z1); + dq[ 1] = __hfma2( q1.as_half2, y32, z32); + dq[ 2] = __hadd2( q2.as_half2, z1); + dq[ 3] = __hadd2( q3.as_half2, z1); + dq[ 4] = __hfma2( q4.as_half2, y32, z32); + dq[ 5] = __hadd2( q5.as_half2, z1); + dq[ 6] = __hadd2( q6.as_half2, z1); + dq[ 7] = __hfma2( q7.as_half2, y32, z32); + dq[ 8] = __hadd2( q8.as_half2, z1); + dq[ 9] = __hadd2( q9.as_half2, z1); + dq[10] = __hfma2(q10.as_half2, y32, z32); + dq[11] = __hadd2(q11.as_half2, z1); + dq[12] = __hadd2(q12.as_half2, z1); + dq[13] = __hfma2(q13.as_half2, y32, z32); + dq[14] = __hadd2(q14.as_half2, z1); + dq[15] = __hadd2(q15.as_half2, z1); +} + +#else + +__forceinline__ __device__ void shuffle_5bit_32 +( + uint32_t* q, + int stride +) +{ +} + +__forceinline__ __device__ void dequant_5bit_32 +( + const uint32_t q_0, + const uint32_t q_1, + const uint32_t q_2, + const uint32_t q_3, + const uint32_t q_4, + half2 (&dq)[16], + int stride +) +{ + half dqh[32]; + for (int i = 0; i < 6; i++) dqh[ i] = dq_ns(exb( q_0, i * 5 , 0x1f), 16); + dqh[ 6 ] = dq_ns(exb(q_1, q_0, 30, 0x1f), 16); + for (int i = 0; i < 5; i++) dqh[ 7 + i] = dq_ns(exb( q_1, i * 5 + 3, 0x1f), 16); + dqh[12 ] = dq_ns(exb(q_2, q_1, 28, 0x1f), 16); + for (int i = 0; i < 6; i++) dqh[13 + i] = dq_ns(exb( q_2, i * 5 + 1, 0x1f), 16); + dqh[19 ] = dq_ns(exb(q_3, q_2, 31, 0x1f), 16); + for (int i = 0; i < 5; i++) dqh[20 + i] = dq_ns(exb( q_3, i * 5 + 4, 0x1f), 16); + dqh[25 ] = dq_ns(exb(q_4, q_3, 29, 0x1f), 16); + for (int i = 0; i < 6; i++) dqh[26 + i] = dq_ns(exb( q_4, i * 5 + 2, 0x1f), 16); + + for (int i = 0; i < 16; i++) dq[i] = __halves2half2(dqh[i * 2], dqh[i * 2 + 1]); +} + +#endif + +#endif \ No newline at end of file diff --git a/server/exllamav2_kernels/exllamav2_kernels/cuda/quant/qdq_6.cuh b/server/exllamav2_kernels/exllamav2_kernels/cuda/quant/qdq_6.cuh new file mode 100644 index 00000000..c2eb8cfb --- /dev/null +++ b/server/exllamav2_kernels/exllamav2_kernels/cuda/quant/qdq_6.cuh @@ -0,0 +1,44 @@ +#ifndef _qdq_6_cuh +#define _qdq_6_cuh + +#include "qdq_util.cuh" +#include "../../config.h" + +#if QMODE_6BIT == 1 + + // Not implemented + +#else + +__forceinline__ __device__ void shuffle_6bit_16 +( + uint32_t* q, + int stride +) +{ +} + +__forceinline__ __device__ void dequant_6bit_16 +( + const uint32_t q_0, + const uint32_t q_1, + const uint32_t q_2, + half2 (&dq)[8], + int stride +) +{ + half dqh[16]; + for (int i = 0; i < 5; i++) dqh[ i] = dq_ns(exb( q_0, i * 6 , 0x3f), 32); + dqh[ 5 ] = dq_ns(exb(q_1, q_0, 30, 0x3f), 32); + for (int i = 0; i < 4; i++) dqh[ 6 + i] = dq_ns(exb( q_1, i * 6 + 4, 0x3f), 32); + dqh[10 ] = dq_ns(exb(q_2, q_1, 28, 0x3f), 32); + for (int i = 0; i < 5; i++) dqh[11 + i] = dq_ns(exb( q_2, i * 6 + 2, 0x3f), 32); + + for (int i = 0; i < 8; i++) dq[i] = __halves2half2(dqh[i * 2], dqh[i * 2 + 1]); +} + +#endif + +#endif + + diff --git a/server/exllamav2_kernels/exllamav2_kernels/cuda/quant/qdq_8.cuh b/server/exllamav2_kernels/exllamav2_kernels/cuda/quant/qdq_8.cuh new file mode 100644 index 00000000..e2409efa --- /dev/null +++ b/server/exllamav2_kernels/exllamav2_kernels/cuda/quant/qdq_8.cuh @@ -0,0 +1,38 @@ +#ifndef _qdq_8_cuh +#define _qdq_8_cuh + +#include "qdq_util.cuh" +#include "../../config.h" + +#if QMODE_8BIT == 1 + + // Not implemented + +#else + +__forceinline__ __device__ void shuffle_8bit_4 +( + uint32_t* q, + int stride +) +{ +} + +__forceinline__ __device__ void dequant_8bit_8 +( + const uint32_t q_0, + const uint32_t q_1, + half2 (&dq)[4], + int stride +) +{ + half dqh[8]; + for (int i = 0; i < 4; i++) dqh[i ] = dq_ns(exb(q_0, i * 8, 0xff), 128); + for (int i = 0; i < 4; i++) dqh[i + 4] = dq_ns(exb(q_1, i * 8, 0xff), 128); + + for (int i = 0; i < 4; i++) dq[i] = __halves2half2(dqh[i * 2], dqh[i * 2 + 1]); +} + +#endif + +#endif \ No newline at end of file diff --git a/server/exllamav2_kernels/exllamav2_kernels/cuda/quant/qdq_util.cuh b/server/exllamav2_kernels/exllamav2_kernels/cuda/quant/qdq_util.cuh new file mode 100644 index 00000000..71657191 --- /dev/null +++ b/server/exllamav2_kernels/exllamav2_kernels/cuda/quant/qdq_util.cuh @@ -0,0 +1,51 @@ +#ifndef _qdq_util_cuh +#define _qdq_util_cuh + +union half2_uint32 +{ + uint32_t as_uint32; + half2 as_half2; + __device__ half2_uint32(uint32_t val) : as_uint32(val) {} + __device__ half2_uint32(half2 val) : as_half2(val) {} +}; + +union half_uint16 +{ + uint16_t as_uint16; + half as_half; + __device__ half_uint16(uint16_t val) : as_uint16(val) {} + __device__ half_uint16(half val) : as_half(val) {} +}; + +// Max_scale premultiplied by 1/256 + +__forceinline__ __device__ half dq_scale(const int qs, const half max_scale) +{ + int qs_i = qs + 1; + half qs_h = __int2half_rn(qs_i * qs_i); + qs_h = __hmul(qs_h, max_scale); + return qs_h; +} + +__forceinline__ __device__ half dq(const int q, const int qzero, const half scale) +{ + return __hmul(__int2half_rn(q - qzero), scale); +} + +__forceinline__ __device__ half dq_ns(const int q, const int qzero) +{ + //return __hsub(__int2half_rn(q), __int2half_rn(qzero)); + return __int2half_rn(q - qzero); +} + +__forceinline__ __device__ int exb(const uint32_t q, const int shift, const int mask) +{ + return (int)((q >> shift) & mask); +} + +__forceinline__ __device__ int exb(const uint32_t q1, const uint32_t q0, const int shift, const int mask) +{ + return (int)(__funnelshift_rc(q0, q1, shift) & mask); +} + +#endif diff --git a/server/exllamav2_kernels/exllamav2_kernels/cuda/util.cuh b/server/exllamav2_kernels/exllamav2_kernels/cuda/util.cuh new file mode 100644 index 00000000..06a58d18 --- /dev/null +++ b/server/exllamav2_kernels/exllamav2_kernels/cuda/util.cuh @@ -0,0 +1,42 @@ + +#define DIVIDE(x, size) (((x) + (size) - 1) / (size)) + +#define DBGS(__x) printf("%s\n", __x) +#define DBGI(__x) printf("%s: %i\n", #__x, __x) +#define DBGI2(__x, __y) printf("%s, %s: %i, %i\n", #__x, #__y, __x, __y) +#define DBGI3(__x, __y, __z) printf("%s, %s, %s: %i, %i, %i\n", #__x, #__y, #__z, __x, __y, __z) +#define DBGX(__x) printf("%s: %x\n", #__x, __x) +#define DBGX2(__x, __y) printf("%s, %s: %x, %x\n", #__x, #__y, __x, __y) +#define DBGX3(__x, __y, __z) printf("%s, %s, %s: %x, %x, %x\n", #__x, #__y, #__z, __x, __y, __z) +#define DBGF(__x) printf("%s: %f\n", #__x, __x) +#define DBGF2(__x, __y) printf("%s, %s: %f, %f\n", #__x, #__y, __x, __y) +#define DBGF3(__x, __y, __z) printf("%s, %s, %s: %f, %f, %f\n", #__x, #__y, #__z, __x, __y, __z) +#define DBGH(__x) printf("%s: %f\n", #__x, __half2float(__x)) +#define DBGH2(__x, __y) printf("%s, %s: %f, %f\n", #__x, #__y, __half2float(__x), __half2float(__y)) +#define DBGH3(__x, __y, __z) printf("%s, %s, %s: %f, %f, %f\n", #__x, #__y, #__z, __half2float(__x), __half2float(__y), __half2float(__z)) + +#define DBGIH(__x, __y) printf("%s, %s: %i, %f\n", #__x, #__y, __x, __half2float(__y)) +#define DBGIH2(__x, __y, __z) printf("%s, %s, %s: %i, %f, %f\n", #__x, #__y, #__z, __x, __half2float(__y), __half2float(__z)) + +__forceinline__ __device__ half dq_scale_(const int qs, const half max_scale) +{ + half qs_h = __hmul(__int2half_rn(qs + 1), __float2half_rn(1.0f / 16.0f)); + qs_h = __hmul(qs_h, qs_h); + qs_h = __hmul(qs_h, max_scale); + return qs_h; +} + +__forceinline__ __device__ float clamp(float x, float a, float b) +{ + return fmaxf(a, fminf(b, x)); +} + +#define cuda_check(ans) { gpu_assert((ans), __FILE__, __LINE__); } +inline void gpu_assert(cudaError_t code, const char *file, int line, bool abort=true) +{ + if (code != cudaSuccess) + { + fprintf(stderr,"CUDA error: %s %s %d\n", cudaGetErrorString(code), file, line); + if (abort) exit(code); + } +} diff --git a/server/exllamav2_kernels/exllamav2_kernels/ext.cpp b/server/exllamav2_kernels/exllamav2_kernels/ext.cpp new file mode 100644 index 00000000..5e52e6ab --- /dev/null +++ b/server/exllamav2_kernels/exllamav2_kernels/ext.cpp @@ -0,0 +1,134 @@ +#include +#include +#include +#include +#include +#include +#include + +#include "config.h" + +#include "cuda/q_matrix.cuh" +#include "cuda/q_gemm.cuh" + +#include "cpp/util.h" + +// Some decluttering macros + +#define TORCH_CHECK_DTYPE(__x, __dtype) TORCH_CHECK((__x).dtype() == torch::__dtype, #__x " is incorrect datatype, must be " #__dtype) +#define TORCH_CHECK_DTYPE_OPT(__x, __dtype) TORCH_CHECK((__x).device().is_meta() || (__x).dtype() == torch::__dtype, #__x " is incorrect datatype, must be " #__dtype) +#define TORCH_CHECK_SHAPES(__x, __dim_x, __y, __dim_y, __scale_y) TORCH_CHECK((__x).size(__dim_x) == (__y).size(__dim_y) * __scale_y, #__x " and " #__y " have incompatible shapes") +#define TORCH_CHECK_SHAPES_OPT(__x, __dim_x, __y, __dim_y, __scale_y) TORCH_CHECK((__x).device().is_meta() || (__x).size(__dim_x) == (__y).size(__dim_y) * __scale_y, #__x " and " #__y " have incompatible shapes") + + +// Quant matrix + +uintptr_t make_q_matrix +( + torch::Tensor q_weight, + torch::Tensor q_perm, + torch::Tensor q_invperm, + torch::Tensor q_scale, + torch::Tensor q_scale_max, + torch::Tensor q_groups, + torch::Tensor gptq_qzeros, + torch::Tensor gptq_scales, + torch::Tensor gptq_g_idx, + torch::Tensor temp_dq +) +{ + TORCH_CHECK_DTYPE(q_weight, kInt); + TORCH_CHECK_DTYPE_OPT(q_perm, kShort); + TORCH_CHECK_DTYPE_OPT(q_invperm, kShort); + TORCH_CHECK_DTYPE_OPT(q_scale, kInt); + TORCH_CHECK_DTYPE_OPT(q_scale_max, kHalf); + TORCH_CHECK_DTYPE_OPT(q_groups, kShort); + TORCH_CHECK_DTYPE_OPT(gptq_qzeros, kInt); + TORCH_CHECK_DTYPE_OPT(gptq_scales, kHalf); + TORCH_CHECK_DTYPE_OPT(gptq_g_idx, kInt); + + TORCH_CHECK_SHAPES(q_perm, 0, q_invperm, 0, 1); + + int device = q_weight.device().index(); + int width = q_weight.size(1); + int groups; + int height; + + if (!q_scale.device().is_meta()) + { + TORCH_CHECK_SHAPES(q_weight, 1, q_scale, 1, 8); + TORCH_CHECK_SHAPES(q_scale_max, 0, q_scale, 0, 1); + groups = q_scale.size(0); + height = q_invperm.size(0); + } + else + { + TORCH_CHECK_SHAPES(q_weight, 1, gptq_qzeros, 1, 8); + TORCH_CHECK_SHAPES(q_weight, 1, gptq_scales, 1, 1); + groups = gptq_qzeros.size(0); + height = q_weight.size(0) * 8; + } + + TORCH_CHECK(temp_dq.size(0) >= width * height, "Insufficient size of temp_dq buffer") + + QMatrix* m = new QMatrix + ( + device, + height, + width, + groups, + (uint32_t*) q_weight.data_ptr(), + q_perm.device().is_meta() ? NULL : (uint16_t*) q_perm.data_ptr(), + q_invperm.device().is_meta() ? NULL : (uint16_t*) q_invperm.data_ptr(), + q_scale.device().is_meta() ? NULL : (uint32_t*) q_scale.data_ptr(), + q_scale_max.device().is_meta() ? NULL : (half*) q_scale_max.data_ptr(), + q_groups.device().is_meta() ? NULL : (uint16_t*) q_groups.data_ptr(), + gptq_qzeros.device().is_meta() ? NULL : (uint32_t*) gptq_qzeros.data_ptr(), + gptq_scales.device().is_meta() ? NULL : (half*) gptq_scales.data_ptr(), + gptq_g_idx.device().is_meta() ? NULL : (uint32_t*) gptq_g_idx.data_ptr(), + (half*) temp_dq.data_ptr() + ); + + return reinterpret_cast (m); +} + +void gemm_half_q_half +( + torch::Tensor a, + uintptr_t b, + torch::Tensor c, + bool force_cuda +) +{ + QMatrix* qm = reinterpret_cast (b); + + TORCH_CHECK_DTYPE(a, kHalf); + TORCH_CHECK_DTYPE(c, kHalf); + TORCH_CHECK_SHAPES(a, 0, c, 0, 1); + TORCH_CHECK(qm->height == a.size(1), "a and b have incompatible shapes") + TORCH_CHECK(qm->width == c.size(1), "b and c have incompatible shapes") + + const at::cuda::OptionalCUDAGuard device_guard(device_of(a)); + + gemm_half_q_half_cuda + ( + at::cuda::getCurrentCUDABlasHandle(), + (const half*) a.data_ptr(), + qm, + (half*) c.data_ptr(), + c.size(0), // m + c.size(1), // n + a.size(1), // k + true, + NULL, + force_cuda + ); +} + +// Bindings + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) +{ + m.def("make_q_matrix", &make_q_matrix, "make_q_matrix"); + m.def("gemm_half_q_half", &gemm_half_q_half, "gemm_half_q_half"); +} diff --git a/server/exllamav2_kernels/setup.py b/server/exllamav2_kernels/setup.py new file mode 100644 index 00000000..518db1df --- /dev/null +++ b/server/exllamav2_kernels/setup.py @@ -0,0 +1,17 @@ +from setuptools import setup +from torch.utils.cpp_extension import BuildExtension, CUDAExtension + +setup( + name="exllamav2_kernels", + ext_modules=[ + CUDAExtension( + name="exllamav2_kernels", + sources=[ + "exllamav2_kernels/ext.cpp", + "exllamav2_kernels/cuda/q_matrix.cu", + "exllamav2_kernels/cuda/q_gemm.cu", + ], + ) + ], + cmdclass={"build_ext": BuildExtension}, +) diff --git a/server/text_generation_server/server.py b/server/text_generation_server/server.py index 75d2b159..fa831682 100644 --- a/server/text_generation_server/server.py +++ b/server/text_generation_server/server.py @@ -168,7 +168,7 @@ def serve( # When using GPTQ, Exllama kernels need some global kernels # For which we have the finale shapes only after the model has loaded # This will allocate those buffers. - from text_generation_server.utils.gptq.exllama import ( + from text_generation_server.utils.layers import ( create_exllama_buffers, set_device, ) diff --git a/server/text_generation_server/utils/gptq/exllamav2.py b/server/text_generation_server/utils/gptq/exllamav2.py new file mode 100644 index 00000000..1945338b --- /dev/null +++ b/server/text_generation_server/utils/gptq/exllamav2.py @@ -0,0 +1,191 @@ +# Adapted from turboderp exllama: https://github.com/turboderp/exllamav2 + +from logging import getLogger + +import torch +import torch.nn as nn +import math + +logger = getLogger(__name__) + +try: + from exllamav2_kernels import make_q_matrix, gemm_half_q_half +except ImportError: + logger.error('exllamav2_kernels not installed.') + raise + +# Dummy tensor to pass instead of g_idx since there is no way to pass "None" to a C++ extension +none_tensor = torch.empty((1, 1), device="meta") + +def ext_gemm_half_q_half(x, q_handle, q4_width, force_cuda): + """Matrix multiplication, returns x @ q4""" + output_shape = x.shape[:-1] + (q4_width,) + x = x.view(-1, x.shape[-1]) + output = torch.empty((x.shape[0], q4_width), dtype = torch.half, device = x.device) + gemm_half_q_half(x, q_handle, output, force_cuda) + return output.view(output_shape) + +def ext_make_q_matrix(w: dict, temp_dq, key: str = None): + """ + Create Q matrix + """ + # EXL2 + # won't work as the moment because the tensors are not the same. + if "q_weight" in w: + w["q_scale_max"] /= 256 + w["q_perm"] = w["q_perm"].short() + w["q_invperm"] = w["q_invperm"].short() + return make_q_matrix(w["q_weight"], + w["q_perm"], + w["q_invperm"], + w["q_scale"], + w["q_scale_max"], + w["q_groups"], + none_tensor, + none_tensor, + none_tensor, + temp_dq) + # GPTQ + elif "qweight" in w: + if w["scales"].dtype == torch.float: + w["scales"] = w["scales"].half() + + # GPTQ with g_idx (act_order) + if w.get("g_idx", None) is not None and not (w["g_idx"] == 0).all().item(): + w["q_perm"] = torch.empty((w["qweight"].shape[0] * 8,), dtype = torch.short, device = w["qweight"].device) + w["q_invperm"] = torch.empty_like(w["q_perm"]) + # make_q4 segfaults if g_idx is not on cpu in the act-order case. In the non act-order case, None needs to be passed for g_idx. + return make_q_matrix(w["qweight"], + w["q_perm"], + w["q_invperm"], + none_tensor, + none_tensor, + none_tensor, + w["qzeros"], + w["scales"], + w["g_idx"].cpu(), + temp_dq) + # GPTQ without g_idx + else: + return make_q_matrix(w["qweight"], + none_tensor, + none_tensor, + none_tensor, + none_tensor, + none_tensor, + w["qzeros"], + w["scales"], + none_tensor, + temp_dq) + +DEVICE = None +FIXED_BYTES = 0 +LAYERS = [] + + +def set_device(device): + global DEVICE + DEVICE = device + + +def create_exllama_buffers(): + global FIXED_BYTES, LAYERS, DEVICE + temp_dq = ExLlamaV2DeviceTensors(DEVICE, FIXED_BYTES) + + for layer in LAYERS: + layer.post_init(temp_dq) + + +class QuantLinear(nn.Module): + QUANT_TYPE = "exllamav2" + + """Linear layer implementation with per-group 4-bit quantization of the weights""" + + # def __init__(self, bits, group_size, infeatures, outfeatures, bias, trainable=False, **kwargs): + def __init__(self, qweight, qzeros, scales, g_idx, bias, bits, groupsize): + super().__init__() + if bits != 4: + raise ValueError( + f"Exllamav2 kernel supports only bits=4, requested bits={bits}. Something is wrong in the model initialization.") + self.q_handle = None + self.q_tensors = None + self.bits = bits + self.maxq = 2 ** self.bits - 1 + self.infeatures = qweight.shape[0] // self.bits * 32 + self.outfeatures = qweight.shape[1] + self.padding = - self.outfeatures % 32 + self.outfeatures = self.outfeatures + self.padding + + self.device = qweight.device + self.qweight = qweight + self.qzeros = qzeros + self.scales = scales + self.g_idx = g_idx + self.bias = bias if bias is not None else None + self.group_size = groupsize + + infeatures = self.infeatures + outfeatures = self.outfeatures + assert qweight.shape == (infeatures // 32 * self.bits, outfeatures) + assert infeatures % self.group_size == 0 + assert qzeros.shape == (infeatures // self.group_size, outfeatures // 32 * self.bits) + assert scales.shape == (infeatures // self.group_size, outfeatures) + assert g_idx.shape == (infeatures, ), f"{g_idx.shape}, {infeatures}" + + global FIXED_BYTES, LAYERS + FIXED_BYTES = max(FIXED_BYTES, self.scratch_space_fixed()) + LAYERS.append(self) + + def post_init(self, temp_dq): + assert self.qweight.device.type == "cuda" + assert self.qweight.device.index is not None + self.q_tensors = { + "qweight":self.qweight, + "qzeros":self.qzeros, + "scales":self.scales, + "g_idx":self.g_idx + } + temp_dq = temp_dq.get_scratch_slice(self.temp_dq_size()) + self.q_handle = ext_make_q_matrix( + self.q_tensors, temp_dq + ) + + def forward(self, x, force_cuda = False): + output = ext_gemm_half_q_half(x, self.q_handle, self.outfeatures, force_cuda) + + if self.bias is not None: + output.add_(self.bias) + return output + + def temp_dq_size(self): + return self.infeatures * self.outfeatures * 2 + 128 + + def temp_fwd_size(self, max_input_len, max_batch_size): + return self.outfeatures * max_input_len * max_batch_size * 4 + 128 + + def scratch_space_fixed(self, max_input_len=4096, max_batch_size=16): + return self.temp_dq_size() + self.temp_fwd_size(max_input_len, max_batch_size) + + +class ExLlamaV2DeviceTensors: + + device_idx: int + scratch_bytes: int + scratch_idx: int + scratch: torch.tensor = None + + def __init__(self, device, scratch_bytes): + self.device = device + self.scratch_bytes = scratch_bytes + + def prepare(self): + self.scratch = torch.empty((self.scratch_bytes // 2,), dtype = torch.half, device = self.device) + + def get_scratch_slice(self, size_bytes): + + if self.scratch is None: self.prepare() + + size_bytes = ((size_bytes + 127) // 128) * 128 + size_half = size_bytes // 2 + scratch_slice = self.scratch.narrow(0, 0, size_half) + return scratch_slice diff --git a/server/text_generation_server/utils/layers.py b/server/text_generation_server/utils/layers.py index 7bb95dd2..e6a90116 100644 --- a/server/text_generation_server/utils/layers.py +++ b/server/text_generation_server/utils/layers.py @@ -31,15 +31,31 @@ try: major, _minor = torch.cuda.get_device_capability() except Exception: major = 1 + HAS_EXLLAMA = False CAN_EXLLAMA = major >= 8 +V2 = os.getenv("EXLLAMA_VERSION", "2") == "2" +if V2 and int(os.getenv("WORLD_SIZE", "1")) > 1: + logger.warning("Disabling exllama v2 and using v1 instead because there are issues when sharding") + V2 = False + if os.getenv("DISABLE_EXLLAMA") == "True": HAS_EXLLAMA = False elif CAN_EXLLAMA: try: - from text_generation_server.utils.gptq.exllama import Ex4bitLinear + if V2: + from text_generation_server.utils.gptq.exllamav2 import (QuantLinear as ExllamaQuantLinear, + create_exllama_buffers, + set_device, + ) + HAS_EXLLAMA = "2" + else: + from text_generation_server.utils.gptq.exllama import (Ex4bitLinear as ExllamaQuantLinear, + create_exllama_buffers, + set_device, + ) + HAS_EXLLAMA = "1" - HAS_EXLLAMA = True except ImportError: pass @@ -308,7 +324,7 @@ def get_linear(weight, bias, quantize): ) if use_exllama: - linear = Ex4bitLinear(qweight, qzeros, scales, g_idx, bias, bits, groupsize) + linear = ExllamaQuantLinear(qweight, qzeros, scales, g_idx, bias, bits, groupsize) else: linear = QuantLinear( qweight, diff --git a/server/text_generation_server/utils/weights.py b/server/text_generation_server/utils/weights.py index 2f330d9c..f3344988 100644 --- a/server/text_generation_server/utils/weights.py +++ b/server/text_generation_server/utils/weights.py @@ -278,23 +278,13 @@ class Weights: ) use_exllama = False else: - logger.info("Using exllama kernels") + logger.info(f"Using exllama kernels v{HAS_EXLLAMA}") if use_exllama: - if groupsize >= 0: - # Exllama reorders the weights in advance and the activations on the fly, thus - # the scales and zero-points do not need to be reordered. - qzeros = self.get_sharded(f"{prefix}.qzeros", dim=0) - scales = self.get_sharded(f"{prefix}.scales", dim=0) - else: - qzeros = self.get_tensor(f"{prefix}.qzeros") - scales = self.get_tensor(f"{prefix}.scales") - - # For tp > 1, at this point we know we do not use act-order - if self.process_group.size() == 1: - g_idx = self.get_tensor(f"{prefix}.g_idx") - else: - g_idx = None + qzeros = self.get_sharded(f"{prefix}.qzeros", dim=0) + scales = self.get_sharded(f"{prefix}.scales", dim=0) + g_idx = self.get_sharded(f"{prefix}.g_idx", dim= 0) + g_idx = g_idx - g_idx[0] else: # The triton kernel reorders the scales/zero points instead of the weight/activation. # Thus, each rank needs the full qzeros/scales.