feat: update exllamav2 kernels (#1370)

Co-authored-by: Nicolas Patry <patry.nicolas@protonmail.com>
This commit is contained in:
OlivierDehaene 2023-12-21 17:25:22 +01:00 committed by GitHub
parent 987c959f73
commit 564199bab3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
17 changed files with 525 additions and 255 deletions

View File

@ -2,6 +2,7 @@
#define _config_h
#define MAX_Q_GEMM_ROWS 50
#define MAX_Q_GEMM_WEIGHTS 4 // must be <= MAX_Q_GEMM_ROWS
#define QMODE_2BIT 1
#define QMODE_3BIT 1
@ -10,4 +11,5 @@
#define QMODE_6BIT 0
#define QMODE_8BIT 0
#endif

View File

@ -10,16 +10,19 @@
#include "quant/qdq_6.cuh"
#include "quant/qdq_8.cuh"
#define BLOCK_KN_SIZE 128
#define BLOCK_M_SIZE_MAX 8
#define MAX_GROUPS_IN_BLOCK (BLOCK_KN_SIZE / 32)
#define GPTQ_BLOCK_KN_SIZE 128
#define GPTQ_BLOCK_M_SIZE_MAX 8
#define GPTQ_MAX_GROUPS_IN_BLOCK (GPTQ_BLOCK_KN_SIZE / 32)
#define EXL2_BLOCK_KN_SIZE 64
#define EXL2_BLOCK_M_SIZE_MAX 8
#define EXL2_MAX_GROUPS_IN_BLOCK (EXL2_BLOCK_KN_SIZE / 32)
#define CLEAR_N_SIZE 256
#include "q_gemm_kernel.cuh"
#include "q_gemm_kernel_gptq.cuh"
#include "compat_gemm.cuh"
void gemm_half_q_half_cuda_part
(
const half* a,
@ -29,20 +32,23 @@ void gemm_half_q_half_cuda_part
int size_n,
int size_k,
int m_count,
bool clear
bool clear,
const half* r_weights,
int r_weights_stride,
bool mul_r_weights
)
{
if (!b->is_gptq)
{
dim3 blockDim, gridDim;
blockDim.x = BLOCK_KN_SIZE;
blockDim.x = EXL2_BLOCK_KN_SIZE;
blockDim.y = 1;
blockDim.z = 1;
gridDim.x = DIVIDE(size_n, BLOCK_KN_SIZE * 4);
gridDim.x = DIVIDE(size_n, EXL2_BLOCK_KN_SIZE * 4);
gridDim.y = DIVIDE(size_m, m_count);
gridDim.z = DIVIDE(size_k, BLOCK_KN_SIZE);
gridDim.z = DIVIDE(size_k, EXL2_BLOCK_KN_SIZE);
fp_gemm_half_q_half_kernel kernel = pick_gemm_half_q_half_kernel(true, m_count);
fp_gemm_half_q_half_kernel kernel = pick_gemm_half_q_half_kernel(m_count, r_weights != NULL, mul_r_weights);
kernel<<<gridDim, blockDim>>>
(
@ -55,7 +61,7 @@ void gemm_half_q_half_cuda_part
size_n,
size_k,
b->groups,
b->groupsize,
b->cuda_q_group_map,
b->cuda_q_perm,
b->rows_8,
b->rows_6,
@ -63,24 +69,27 @@ void gemm_half_q_half_cuda_part
b->rows_4,
b->rows_3,
b->rows_2,
clear
clear,
r_weights,
r_weights_stride
);
}
else
{
dim3 blockDim, gridDim;
blockDim.x = BLOCK_KN_SIZE;
blockDim.x = GPTQ_BLOCK_KN_SIZE;
blockDim.y = 1;
blockDim.z = 1;
gridDim.x = DIVIDE(size_n, BLOCK_KN_SIZE * 4);
gridDim.x = DIVIDE(size_n, GPTQ_BLOCK_KN_SIZE * 4);
gridDim.y = DIVIDE(size_m, m_count);
gridDim.z = DIVIDE(size_k, BLOCK_KN_SIZE);
gridDim.z = DIVIDE(size_k, GPTQ_BLOCK_KN_SIZE);
fp_gemm_half_q_half_gptq_kernel kernel = pick_gemm_half_q_half_gptq_kernel(true, m_count);
fp_gemm_half_q_half_gptq_kernel kernel = pick_gemm_half_q_half_gptq_kernel(m_count, r_weights != NULL, mul_r_weights);
// DBGX((uint64_t) b->cuda_q_perm);
// DBGI(b->rows_4);
// DBGI(b->height);
// DBGX((uint64_t) r_weights);
// if (r_weights)
// print_global_mem(r_weights, 1, 1, 1);
// DBGI(r_weights_stride);
kernel<<<gridDim, blockDim>>>
(
@ -93,10 +102,12 @@ void gemm_half_q_half_cuda_part
size_n,
size_k,
b->groups,
b->groupsize,
b->gptq_groupsize,
b->cuda_q_perm,
b->rows_4,
clear
clear,
r_weights,
r_weights_stride
);
}
}
@ -112,13 +123,14 @@ void gemm_half_q_half_cuda
int size_k,
bool clear,
half* temp_dq,
bool force_cuda
bool force_cuda,
const half* r_weights,
const int r_weights_stride,
bool mul_r_weights
)
{
if (size_m > MAX_Q_GEMM_ROWS && !force_cuda)
{
//printf("cublas\n");
// Reconstruct FP16 matrix, then cuBLAS
if (!temp_dq) temp_dq = b->temp_dq;
@ -139,12 +151,12 @@ void gemm_half_q_half_cuda
//const float alpha = 1.0f;
//const float beta = clear ? 0.0f : 1.0f;
//cublasSgemmEx(cublas_handle,
// CUBLAS_OP_N,
// CUBLAS_OP_N,
// size_n, size_m, size_k,
// &alpha, temp_dq, CUDA_R_16F, size_n,
// a, CUDA_R_16F, size_k,
// &beta, c, CUDA_R_16F, size_n);
// CUBLAS_OP_N,
// CUBLAS_OP_N,
// size_n, size_m, size_k,
// &alpha, temp_dq, CUDA_R_16F, size_n,
// a, CUDA_R_16F, size_k,
// &beta, c, CUDA_R_16F, size_n);
//const float alpha = 1.0f;
//const float beta = clear ? 0.0f : 1.0f;
@ -158,24 +170,21 @@ void gemm_half_q_half_cuda
}
else
{
//printf("cuda\n");
// Quantized matmul
//if (clear) clear_tensor_cuda(c, size_m, size_n);
int max_chunks = size_m / BLOCK_M_SIZE_MAX;
int last_chunk = max_chunks * BLOCK_M_SIZE_MAX;
int block_m_size_max = b->is_gptq ? GPTQ_BLOCK_M_SIZE_MAX : EXL2_BLOCK_M_SIZE_MAX;
int max_chunks = size_m / block_m_size_max;
int last_chunk = max_chunks * block_m_size_max;
int last_chunk_size = size_m - last_chunk;
if (max_chunks)
{
gemm_half_q_half_cuda_part(a, b, c, last_chunk, size_n, size_k, BLOCK_M_SIZE_MAX, clear);
gemm_half_q_half_cuda_part(a, b, c, last_chunk, size_n, size_k, block_m_size_max, clear, r_weights, r_weights_stride, mul_r_weights);
}
if (last_chunk_size)
{
gemm_half_q_half_cuda_part(a + last_chunk * size_k, b, c + last_chunk * size_n, last_chunk_size, size_n, size_k, last_chunk_size, clear);
gemm_half_q_half_cuda_part(a + last_chunk * size_k, b, c + last_chunk * size_n, last_chunk_size, size_n, size_k, last_chunk_size, clear, r_weights, r_weights_stride, mul_r_weights);
}
}
}
@ -201,11 +210,10 @@ void clear_tensor_cuda
int size_n
)
{
return;
dim3 blockDim, gridDim;
blockDim.x = CLEAR_N_SIZE;
blockDim.y = 1;
gridDim.x = DIVIDE(size_n / 8, CLEAR_N_SIZE);
gridDim.y = size_m;
clear_kernel<<<gridDim, blockDim>>>(c, size_m, size_n);
// dim3 blockDim, gridDim;
// blockDim.x = CLEAR_N_SIZE;
// blockDim.y = 1;
// gridDim.x = DIVIDE(size_n / 8, CLEAR_N_SIZE);
// gridDim.y = size_m;
// clear_kernel<<<gridDim, blockDim>>>(c, size_m, size_n);
}

View File

@ -20,7 +20,10 @@ void gemm_half_q_half_cuda
int size_k,
bool clear = false,
half* reconstruct = NULL,
bool force_cuda = false
bool force_cuda = false,
const half* r_weights = NULL,
const int r_weights_stride = 0,
bool mul_r_weights = false
);
void clear_tensor_cuda

View File

@ -1,8 +1,5 @@
#include "compat.cuh"
#include <cuda_runtime.h>
#include <cuda_fp16.h>
__forceinline__ __device__ half2 dot22_8(half2(&dq)[4], const half* a_ptr, const half2 g_result, const half qs_h)
{
half2 result = {};
@ -60,6 +57,47 @@ __forceinline__ __device__ float dot22_32_f(half2(&dq)[16], const half* a_ptr, c
return fma(result_f, qs_f, g_result);
}
__forceinline__ __device__ half dot22_8_h(half2(&dq)[4], const half* a_ptr, const half g_result, const half qs_h)
{
// Use FP32 accumulator to avoid potential overflow since unscaled weights are in the range -128..127
float result = {};
#pragma unroll
for (int i = 0; i < 4; i++)
{
half2 w01 = dq[i];
float w0 = __low2float(w01);
float w1 = __high2float(w01);
float x0 = __half2float(*a_ptr++);
float x1 = __half2float(*a_ptr++);
result = fma(w0, x0, result);
result = fma(w1, x1, result);
}
float qs = __half2float(qs_h);
result *= qs;
half result_h = __float2half_rn(result);
return __hadd(result_h, g_result);
}
__forceinline__ __device__ half dot22_16_h(half2(&dq)[8], const half* a_ptr, const half g_result, const half qs_h)
{
half2 result = {};
const half2* a2_ptr = (const half2*)a_ptr;
#pragma unroll
for (int i = 0; i < 8; i++) result = __hfma2(dq[i], *a2_ptr++, result);
half result_h = __hadd(__low2half(result), __high2half(result));
return __hfma(result_h, qs_h, g_result);
}
__forceinline__ __device__ half dot22_32_h(half2(&dq)[16], const half* a_ptr, const half g_result, const half qs_h)
{
half2 result = {};
const half2* a2_ptr = (const half2*)a_ptr;
#pragma unroll
for (int i = 0; i < 16; i += 1) result = __hfma2(dq[i], *a2_ptr++, result);
half result_h = __hadd(__low2half(result), __high2half(result));
return __hfma(result_h, qs_h, g_result);
}
typedef void (*fp_gemm_half_q_half_kernel)
@ -73,7 +111,7 @@ typedef void (*fp_gemm_half_q_half_kernel)
const int,
const int,
const int,
const int,
const uint16_t*,
const uint16_t*,
const int,
const int,
@ -81,10 +119,12 @@ typedef void (*fp_gemm_half_q_half_kernel)
const int,
const int,
const int,
const bool
const bool,
const half*,
const int
);
template <bool first_block, int m_count>
template <int m_count, bool use_r_weights, bool mul_r_weights>
__global__ void gemm_half_q_half_kernel
(
const half* __restrict__ a,
@ -96,7 +136,7 @@ __global__ void gemm_half_q_half_kernel
const int size_n,
const int size_k,
const int groups,
const int groupsize,
const uint16_t* __restrict__ b_q_group_map,
const uint16_t* __restrict__ b_q_perm,
const int rows_8,
const int rows_6,
@ -104,7 +144,9 @@ __global__ void gemm_half_q_half_kernel
const int rows_4,
const int rows_3,
const int rows_2,
const bool clear
const bool clear,
const half* r_weights,
const int r_weights_stride
)
{
MatrixView_half a_(a, size_m, size_k);
@ -115,18 +157,34 @@ __global__ void gemm_half_q_half_kernel
// Block
int offset_n = blockIdx.x * BLOCK_KN_SIZE * 4;
int offset_n = blockIdx.x * EXL2_BLOCK_KN_SIZE * 4;
int offset_m = blockIdx.y * m_count;
int offset_k = blockIdx.z * BLOCK_KN_SIZE;
int offset_k = blockIdx.z * EXL2_BLOCK_KN_SIZE;
int end_n = min(offset_n + BLOCK_KN_SIZE * 4, size_n);
int end_n = min(offset_n + EXL2_BLOCK_KN_SIZE * 4, size_n);
int end_m = min(offset_m + m_count, size_m);
int end_k = min(offset_k + BLOCK_KN_SIZE, size_k);
int end_k = min(offset_k + EXL2_BLOCK_KN_SIZE, size_k);
int n = offset_n + t * 4;
// Read weights
half_uint16 weights[MAX_Q_GEMM_WEIGHTS];
if constexpr (use_r_weights)
{
uint16_t any_w = 0;
const half* w_ptr = r_weights;
for (int m = 0; m < m_count; ++m)
{
weights[m].as_half = *w_ptr;
w_ptr += r_weights_stride;
any_w |= weights[m].as_uint16;
}
if (!any_w) return; // Early exit if all weights are zero -- does not zero output (!!!)
}
// Preload block_a
__shared__ half block_a[m_count][BLOCK_KN_SIZE];
__shared__ half block_a[m_count][EXL2_BLOCK_KN_SIZE];
if (offset_k + t < end_k)
{
@ -135,6 +193,7 @@ __global__ void gemm_half_q_half_kernel
const half* a_ptr = a_.item_ptr(offset_m + m, 0);
half* block_a_ptr = block_a[m];
half a0 = a_ptr[b_q_perm[offset_k + t]];
// half a0 = a_ptr[offset_k + t];
block_a_ptr[t] = a0;
}
}
@ -153,14 +212,19 @@ __global__ void gemm_half_q_half_kernel
// Find initial group
int group = offset_k / groupsize;
//int group = offset_k / groupsize;
int group = b_q_group_map[offset_k * 2];
// if (offset_m == 0 && t == 0)
// DBGI2(offset_k, group);
// Preload scales
float scales[MAX_GROUPS_IN_BLOCK][4];
half scales[EXL2_MAX_GROUPS_IN_BLOCK][4];
int groups_in_block = DIVIDE((end_k - offset_k), groupsize);
for (int g = 0; g < groups_in_block; g++)
//int groups_in_block = DIVIDE((end_k - offset_k), groupsize);
int temp_k = offset_k;
for (int g = 0; temp_k < end_k; g++)
{
int qscales[4];
b_q_scale_.item4(qscales, group + g, n);
@ -168,11 +232,12 @@ __global__ void gemm_half_q_half_kernel
qscales[1]++;
qscales[2]++;
qscales[3]++;
float maxscale = __half2float(b_q_scale_max[group + g]);
scales[g][0] = __int2float_rn(qscales[0] * qscales[0]) * maxscale;
scales[g][1] = __int2float_rn(qscales[1] * qscales[1]) * maxscale;
scales[g][2] = __int2float_rn(qscales[2] * qscales[2]) * maxscale;
scales[g][3] = __int2float_rn(qscales[3] * qscales[3]) * maxscale;
half maxscale = b_q_scale_max[group + g];
scales[g][0] = __hmul(__int2half_rn(qscales[0] * qscales[0]), maxscale);
scales[g][1] = __hmul(__int2half_rn(qscales[1] * qscales[1]), maxscale);
scales[g][2] = __hmul(__int2half_rn(qscales[2] * qscales[2]), maxscale);
scales[g][3] = __hmul(__int2half_rn(qscales[3] * qscales[3]), maxscale);
temp_k += b_q_group_map[temp_k * 2 + 1];
}
// a, b offset
@ -193,20 +258,20 @@ __global__ void gemm_half_q_half_kernel
const uint32_t* b_ptr = b_q_weight + qk * size_n + n;
const half* a_ptr = &block_a[0][0];
int a_stride = BLOCK_KN_SIZE;
int a_stride = EXL2_BLOCK_KN_SIZE;
// Initial group
int scales_idx = 0;
float qs_f0 = scales[scales_idx][0];
float qs_f1 = scales[scales_idx][1];
float qs_f2 = scales[scales_idx][2];
float qs_f3 = scales[scales_idx][3];
int nextgroup = offset_k + groupsize;
half qs_h0 = scales[scales_idx][0];
half qs_h1 = scales[scales_idx][1];
half qs_h2 = scales[scales_idx][2];
half qs_h3 = scales[scales_idx][3];
int nextgroup = offset_k + b_q_group_map[offset_k * 2 + 1];
// Column result
float block_c[m_count][4] = {};
half block_c[m_count][4] = {};
// Dequantize groups
@ -218,11 +283,11 @@ __global__ void gemm_half_q_half_kernel
{
group++;
scales_idx++;
qs_f0 = scales[scales_idx][0];
qs_f1 = scales[scales_idx][1];
qs_f2 = scales[scales_idx][2];
qs_f3 = scales[scales_idx][3];
nextgroup += groupsize;
qs_h0 = scales[scales_idx][0];
qs_h1 = scales[scales_idx][1];
qs_h2 = scales[scales_idx][2];
qs_h3 = scales[scales_idx][3];
nextgroup += b_q_group_map[k * 2 + 1];
}
#pragma unroll
@ -240,10 +305,11 @@ __global__ void gemm_half_q_half_kernel
for (int m = 0; m < m_count; m++)
{
block_c[m][0] = dot22_8_f(dq[0], a_ptr + m * a_stride, block_c[m][0], qs_f0);
block_c[m][1] = dot22_8_f(dq[1], a_ptr + m * a_stride, block_c[m][1], qs_f1);
block_c[m][2] = dot22_8_f(dq[2], a_ptr + m * a_stride, block_c[m][2], qs_f2);
block_c[m][3] = dot22_8_f(dq[3], a_ptr + m * a_stride, block_c[m][3], qs_f3);
if constexpr (use_r_weights) { if (!weights[m].as_uint16) continue; }
block_c[m][0] = dot22_8_h(dq[0], a_ptr + m * a_stride, block_c[m][0], qs_h0);
block_c[m][1] = dot22_8_h(dq[1], a_ptr + m * a_stride, block_c[m][1], qs_h1);
block_c[m][2] = dot22_8_h(dq[2], a_ptr + m * a_stride, block_c[m][2], qs_h2);
block_c[m][3] = dot22_8_h(dq[3], a_ptr + m * a_stride, block_c[m][3], qs_h3);
}
a_ptr += 8;
}
@ -256,11 +322,11 @@ __global__ void gemm_half_q_half_kernel
{
group++;
scales_idx++;
qs_f0 = scales[scales_idx][0];
qs_f1 = scales[scales_idx][1];
qs_f2 = scales[scales_idx][2];
qs_f3 = scales[scales_idx][3];
nextgroup += groupsize;
qs_h0 = scales[scales_idx][0];
qs_h1 = scales[scales_idx][1];
qs_h2 = scales[scales_idx][2];
qs_h3 = scales[scales_idx][3];
nextgroup += b_q_group_map[k * 2 + 1];
}
#pragma unroll
@ -279,10 +345,11 @@ __global__ void gemm_half_q_half_kernel
for (int m = 0; m < m_count; m++)
{
block_c[m][0] = dot22_16_f(dq[0], a_ptr + m * a_stride, block_c[m][0], qs_f0);
block_c[m][1] = dot22_16_f(dq[1], a_ptr + m * a_stride, block_c[m][1], qs_f1);
block_c[m][2] = dot22_16_f(dq[2], a_ptr + m * a_stride, block_c[m][2], qs_f2);
block_c[m][3] = dot22_16_f(dq[3], a_ptr + m * a_stride, block_c[m][3], qs_f3);
if constexpr (use_r_weights) { if (!weights[m].as_uint16) continue; }
block_c[m][0] = dot22_16_h(dq[0], a_ptr + m * a_stride, block_c[m][0], qs_h0);
block_c[m][1] = dot22_16_h(dq[1], a_ptr + m * a_stride, block_c[m][1], qs_h1);
block_c[m][2] = dot22_16_h(dq[2], a_ptr + m * a_stride, block_c[m][2], qs_h2);
block_c[m][3] = dot22_16_h(dq[3], a_ptr + m * a_stride, block_c[m][3], qs_h3);
}
a_ptr += 16;
}
@ -295,11 +362,11 @@ __global__ void gemm_half_q_half_kernel
{
group++;
scales_idx++;
qs_f0 = scales[scales_idx][0];
qs_f1 = scales[scales_idx][1];
qs_f2 = scales[scales_idx][2];
qs_f3 = scales[scales_idx][3];
nextgroup += groupsize;
qs_h0 = scales[scales_idx][0];
qs_h1 = scales[scales_idx][1];
qs_h2 = scales[scales_idx][2];
qs_h3 = scales[scales_idx][3];
nextgroup += b_q_group_map[k * 2 + 1];
}
#pragma unroll
@ -320,10 +387,11 @@ __global__ void gemm_half_q_half_kernel
for (int m = 0; m < m_count; m++)
{
block_c[m][0] = dot22_32_f(dq[0], a_ptr + m * a_stride, block_c[m][0], qs_f0);
block_c[m][1] = dot22_32_f(dq[1], a_ptr + m * a_stride, block_c[m][1], qs_f1);
block_c[m][2] = dot22_32_f(dq[2], a_ptr + m * a_stride, block_c[m][2], qs_f2);
block_c[m][3] = dot22_32_f(dq[3], a_ptr + m * a_stride, block_c[m][3], qs_f3);
if constexpr (use_r_weights) { if (!weights[m].as_uint16) continue; }
block_c[m][0] = dot22_32_h(dq[0], a_ptr + m * a_stride, block_c[m][0], qs_h0);
block_c[m][1] = dot22_32_h(dq[1], a_ptr + m * a_stride, block_c[m][1], qs_h1);
block_c[m][2] = dot22_32_h(dq[2], a_ptr + m * a_stride, block_c[m][2], qs_h2);
block_c[m][3] = dot22_32_h(dq[3], a_ptr + m * a_stride, block_c[m][3], qs_h3);
}
a_ptr += 32;
}
@ -337,11 +405,11 @@ __global__ void gemm_half_q_half_kernel
{
group++;
scales_idx++;
qs_f0 = scales[scales_idx][0];
qs_f1 = scales[scales_idx][1];
qs_f2 = scales[scales_idx][2];
qs_f3 = scales[scales_idx][3];
nextgroup += groupsize;
qs_h0 = scales[scales_idx][0];
qs_h1 = scales[scales_idx][1];
qs_h2 = scales[scales_idx][2];
qs_h3 = scales[scales_idx][3];
nextgroup += b_q_group_map[k * 2 + 1];
}
#pragma unroll
@ -358,10 +426,11 @@ __global__ void gemm_half_q_half_kernel
for (int m = 0; m < m_count; m++)
{
block_c[m][0] = dot22_8_f(dq[0], a_ptr + m * a_stride, block_c[m][0], qs_f0);
block_c[m][1] = dot22_8_f(dq[1], a_ptr + m * a_stride, block_c[m][1], qs_f1);
block_c[m][2] = dot22_8_f(dq[2], a_ptr + m * a_stride, block_c[m][2], qs_f2);
block_c[m][3] = dot22_8_f(dq[3], a_ptr + m * a_stride, block_c[m][3], qs_f3);
if constexpr (use_r_weights) { if (!weights[m].as_uint16) continue; }
block_c[m][0] = dot22_8_h(dq[0], a_ptr + m * a_stride, block_c[m][0], qs_h0);
block_c[m][1] = dot22_8_h(dq[1], a_ptr + m * a_stride, block_c[m][1], qs_h1);
block_c[m][2] = dot22_8_h(dq[2], a_ptr + m * a_stride, block_c[m][2], qs_h2);
block_c[m][3] = dot22_8_h(dq[3], a_ptr + m * a_stride, block_c[m][3], qs_h3);
}
a_ptr += 8;
}
@ -374,11 +443,11 @@ __global__ void gemm_half_q_half_kernel
{
group++;
scales_idx++;
qs_f0 = scales[scales_idx][0];
qs_f1 = scales[scales_idx][1];
qs_f2 = scales[scales_idx][2];
qs_f3 = scales[scales_idx][3];
nextgroup += groupsize;
qs_h0 = scales[scales_idx][0];
qs_h1 = scales[scales_idx][1];
qs_h2 = scales[scales_idx][2];
qs_h3 = scales[scales_idx][3];
nextgroup += b_q_group_map[k * 2 + 1];
}
#pragma unroll
@ -397,10 +466,11 @@ __global__ void gemm_half_q_half_kernel
for (int m = 0; m < m_count; m++)
{
block_c[m][0] = dot22_32_f(dq[0], a_ptr + m * a_stride, block_c[m][0], qs_f0);
block_c[m][1] = dot22_32_f(dq[1], a_ptr + m * a_stride, block_c[m][1], qs_f1);
block_c[m][2] = dot22_32_f(dq[2], a_ptr + m * a_stride, block_c[m][2], qs_f2);
block_c[m][3] = dot22_32_f(dq[3], a_ptr + m * a_stride, block_c[m][3], qs_f3);
if constexpr (use_r_weights) { if (!weights[m].as_uint16) continue; }
block_c[m][0] = dot22_32_h(dq[0], a_ptr + m * a_stride, block_c[m][0], qs_h0);
block_c[m][1] = dot22_32_h(dq[1], a_ptr + m * a_stride, block_c[m][1], qs_h1);
block_c[m][2] = dot22_32_h(dq[2], a_ptr + m * a_stride, block_c[m][2], qs_h2);
block_c[m][3] = dot22_32_h(dq[3], a_ptr + m * a_stride, block_c[m][3], qs_h3);
}
a_ptr += 32;
}
@ -413,15 +483,15 @@ __global__ void gemm_half_q_half_kernel
{
group++;
scales_idx++;
qs_f0 = scales[scales_idx][0];
qs_f1 = scales[scales_idx][1];
qs_f2 = scales[scales_idx][2];
qs_f3 = scales[scales_idx][3];
nextgroup += groupsize;
qs_h0 = scales[scales_idx][0];
qs_h1 = scales[scales_idx][1];
qs_h2 = scales[scales_idx][2];
qs_h3 = scales[scales_idx][3];
nextgroup += b_q_group_map[k * 2 + 1];
}
#pragma unroll
for (int j = 0; j < 2; j++)
for (int j = 0; j < 1; j++)
{
int4 load_int4[1];
load_int4[0] = *((int4*) b_ptr); b_ptr += size_n;
@ -434,15 +504,16 @@ __global__ void gemm_half_q_half_kernel
for (int m = 0; m < m_count; m++)
{
block_c[m][0] = dot22_16_f(dq[0], a_ptr + m * a_stride, block_c[m][0], qs_f0);
block_c[m][1] = dot22_16_f(dq[1], a_ptr + m * a_stride, block_c[m][1], qs_f1);
block_c[m][2] = dot22_16_f(dq[2], a_ptr + m * a_stride, block_c[m][2], qs_f2);
block_c[m][3] = dot22_16_f(dq[3], a_ptr + m * a_stride, block_c[m][3], qs_f3);
if constexpr (use_r_weights) { if (!weights[m].as_uint16) continue; }
block_c[m][0] = dot22_16_h(dq[0], a_ptr + m * a_stride, block_c[m][0], qs_h0);
block_c[m][1] = dot22_16_h(dq[1], a_ptr + m * a_stride, block_c[m][1], qs_h1);
block_c[m][2] = dot22_16_h(dq[2], a_ptr + m * a_stride, block_c[m][2], qs_h2);
block_c[m][3] = dot22_16_h(dq[3], a_ptr + m * a_stride, block_c[m][3], qs_h3);
}
a_ptr += 16;
}
k += 32;
k += 16;
}
// Accumulate column sums in c
@ -450,38 +521,60 @@ __global__ void gemm_half_q_half_kernel
for (int m = 0; m < m_count; m++)
{
half2* out = (half2*)c_.item_ptr(offset_m + m, n);
half2 result01 = __halves2half2(__float2half_rn(block_c[m][0]), __float2half_rn(block_c[m][1]));
half2 result23 = __halves2half2(__float2half_rn(block_c[m][2]), __float2half_rn(block_c[m][3]));
half2 result01 = __halves2half2(block_c[m][0], block_c[m][1]);
half2 result23 = __halves2half2(block_c[m][2], block_c[m][3]);
if constexpr (mul_r_weights)
{
half2 w_mul2 = __half2half2(weights[m].as_half);
result01 = __hmul2(result01, w_mul2);
result23 = __hmul2(result23, w_mul2);
}
atomicAdd(out , result01);
atomicAdd(out + 1, result23);
// *out = result01;
// *(out + 1) = result23;
}
}
fp_gemm_half_q_half_kernel pick_gemm_half_q_half_kernel(bool first_block, const int m_count)
template <bool use_r_weights, bool mul_r_weights>
struct map_m_count_exl2 {
static constexpr fp_gemm_half_q_half_kernel pick_gemm_half_q_half_kernel(const int m_count)
{
#if EXL2_BLOCK_M_SIZE_MAX >= 1
if (m_count == 1) return gemm_half_q_half_kernel<1, use_r_weights, mul_r_weights>;
#endif
#if EXL2_BLOCK_M_SIZE_MAX >= 2
if (m_count == 2) return gemm_half_q_half_kernel<2, use_r_weights, mul_r_weights>;
#endif
#if EXL2_BLOCK_M_SIZE_MAX >= 3
if (m_count == 3) return gemm_half_q_half_kernel<3, use_r_weights, mul_r_weights>;
#endif
#if EXL2_BLOCK_M_SIZE_MAX >= 4
if (m_count == 4) return gemm_half_q_half_kernel<4, use_r_weights, mul_r_weights>;
#endif
#if EXL2_BLOCK_M_SIZE_MAX >= 5
if (m_count == 5) return gemm_half_q_half_kernel<5, use_r_weights, mul_r_weights>;
#endif
#if EXL2_BLOCK_M_SIZE_MAX >= 6
if (m_count == 6) return gemm_half_q_half_kernel<6, use_r_weights, mul_r_weights>;
#endif
#if EXL2_BLOCK_M_SIZE_MAX >= 7
if (m_count == 7) return gemm_half_q_half_kernel<7, use_r_weights, mul_r_weights>;
#endif
#if EXL2_BLOCK_M_SIZE_MAX >= 8
if (m_count == 8) return gemm_half_q_half_kernel<8, use_r_weights, mul_r_weights>;
#endif
return NULL;
}
};
fp_gemm_half_q_half_kernel pick_gemm_half_q_half_kernel(const int m_count, bool r_weights, bool mul_r_weights)
{
#if BLOCK_M_SIZE_MAX >= 1
if (m_count == 1) return gemm_half_q_half_kernel<true, 1>;
#endif
#if BLOCK_M_SIZE_MAX >= 2
if (m_count == 2) return gemm_half_q_half_kernel<true, 2>;
#endif
#if BLOCK_M_SIZE_MAX >= 3
if (m_count == 3) return gemm_half_q_half_kernel<true, 3>;
#endif
#if BLOCK_M_SIZE_MAX >= 4
if (m_count == 4) return gemm_half_q_half_kernel<true, 4>;
#endif
#if BLOCK_M_SIZE_MAX >= 5
if (m_count == 5) return gemm_half_q_half_kernel<true, 5>;
#endif
#if BLOCK_M_SIZE_MAX >= 6
if (m_count == 6) return gemm_half_q_half_kernel<true, 6>;
#endif
#if BLOCK_M_SIZE_MAX >= 7
if (m_count == 7) return gemm_half_q_half_kernel<true, 7>;
#endif
#if BLOCK_M_SIZE_MAX >= 8
if (m_count == 8) return gemm_half_q_half_kernel<true, 8>;
#endif
if (!r_weights && !mul_r_weights) return map_m_count_exl2<false, false>::pick_gemm_half_q_half_kernel(m_count);
if (!r_weights && mul_r_weights) return map_m_count_exl2<false, true>::pick_gemm_half_q_half_kernel(m_count);
if ( r_weights && !mul_r_weights) return map_m_count_exl2< true, false>::pick_gemm_half_q_half_kernel(m_count);
if ( r_weights && mul_r_weights) return map_m_count_exl2< true, true>::pick_gemm_half_q_half_kernel(m_count);
return NULL;
}

View File

@ -18,6 +18,15 @@ __forceinline__ __device__ float dot22_8_f(half2(&dq)[4], const half* a_ptr)
return __half2float(__low2half(result)) + __half2float(__high2half(result));
}
__forceinline__ __device__ half2 dot22_8_h2(half2(&dq)[4], const half* a_ptr)
{
half2 result = {};
const half2* a2_ptr = (const half2*)a_ptr;
#pragma unroll
for (int i = 0; i < 4; i++) result = __hfma2(dq[i], *a2_ptr++, result);
return result;
}
typedef void (*fp_gemm_half_q_half_gptq_kernel)
(
const half*,
@ -32,10 +41,12 @@ typedef void (*fp_gemm_half_q_half_gptq_kernel)
const int,
const uint16_t*,
const int,
const bool
const bool,
const half*,
const int
);
template <bool first_block, int m_count>
template <int m_count, bool use_r_weights, bool mul_r_weights>
__global__ void gemm_half_q_half_gptq_kernel
(
const half* __restrict__ a,
@ -50,7 +61,9 @@ __global__ void gemm_half_q_half_gptq_kernel
const int groupsize,
const uint16_t* __restrict__ b_q_perm,
const int rows_4,
const bool clear
const bool clear,
const half* r_weights,
const int r_weights_stride
)
{
MatrixView_half a_(a, size_m, size_k);
@ -62,19 +75,35 @@ __global__ void gemm_half_q_half_gptq_kernel
// Block
int offset_n = blockIdx.x * BLOCK_KN_SIZE * 4;
int offset_n = blockIdx.x * GPTQ_BLOCK_KN_SIZE * 4;
int offset_m = blockIdx.y * m_count;
int offset_k = blockIdx.z * BLOCK_KN_SIZE;
int offset_k = blockIdx.z * GPTQ_BLOCK_KN_SIZE;
int end_n = min(offset_n + BLOCK_KN_SIZE * 4, size_n);
int end_n = min(offset_n + GPTQ_BLOCK_KN_SIZE * 4, size_n);
int end_m = min(offset_m + m_count, size_m);
int end_k = min(offset_k + BLOCK_KN_SIZE, size_k);
int end_k = min(offset_k + GPTQ_BLOCK_KN_SIZE, size_k);
int n = offset_n + t * 4;
// Read weights
half_uint16 weights[MAX_Q_GEMM_WEIGHTS];
if constexpr (use_r_weights)
{
uint16_t any_w = 0;
const half* w_ptr = r_weights;
for (int m = 0; m < m_count; ++m)
{
weights[m].as_half = *w_ptr;
w_ptr += r_weights_stride;
any_w |= weights[m].as_uint16;
}
if (!any_w) return; // Early exit if all weights are zero -- does not zero output (!!!)
}
// Preload block_a
__shared__ half block_a[m_count][BLOCK_KN_SIZE];
__shared__ half block_a[m_count][GPTQ_BLOCK_KN_SIZE];
if (offset_k + t < end_k)
{
@ -113,16 +142,16 @@ __global__ void gemm_half_q_half_gptq_kernel
const uint32_t* b_ptr = b_q_weight + qk * size_n + n;
const half* a_ptr = &block_a[0][0];
int a_stride = BLOCK_KN_SIZE;
int a_stride = GPTQ_BLOCK_KN_SIZE;
// Initial group
int zeros[4];
float scales[4];
half2 scales[4];
half2 z1z16[4][2];
half2 y1y16[4][2];
b_gptq_qzeros_.item4(zeros, group, n);
b_gptq_scales_.item4_f(scales, group, n);
b_gptq_scales_.item4_h2(scales, group, n);
dequant_4bit_8_prep_zero(zeros[0] + 1, z1z16[0], y1y16[0]);
dequant_4bit_8_prep_zero(zeros[1] + 1, z1z16[1], y1y16[1]);
dequant_4bit_8_prep_zero(zeros[2] + 1, z1z16[2], y1y16[2]);
@ -132,7 +161,7 @@ __global__ void gemm_half_q_half_gptq_kernel
// Column result
float block_c[m_count][4] = {};
half2 block_c[m_count][4] = {};
// Dequantize and multiply
@ -144,7 +173,7 @@ __global__ void gemm_half_q_half_gptq_kernel
group++;
nextgroup += groupsize;
b_gptq_qzeros_.item4(zeros, group, n);
b_gptq_scales_.item4_f(scales, group, n);
b_gptq_scales_.item4_h2(scales, group, n);
dequant_4bit_8_prep_zero(zeros[0] + 1, z1z16[0], y1y16[0]);
dequant_4bit_8_prep_zero(zeros[1] + 1, z1z16[1], y1y16[1]);
dequant_4bit_8_prep_zero(zeros[2] + 1, z1z16[2], y1y16[2]);
@ -166,10 +195,11 @@ __global__ void gemm_half_q_half_gptq_kernel
#pragma unroll
for (int m = 0; m < m_count; m++)
{
block_c[m][0] = fma(dot22_8_f(dq[0], a_ptr + m * a_stride), scales[0], block_c[m][0]);
block_c[m][1] = fma(dot22_8_f(dq[1], a_ptr + m * a_stride), scales[1], block_c[m][1]);
block_c[m][2] = fma(dot22_8_f(dq[2], a_ptr + m * a_stride), scales[2], block_c[m][2]);
block_c[m][3] = fma(dot22_8_f(dq[3], a_ptr + m * a_stride), scales[3], block_c[m][3]);
if constexpr (use_r_weights) { if (!weights[m].as_uint16) continue; }
block_c[m][0] = __hfma2(dot22_8_h2(dq[0], a_ptr + m * a_stride), scales[0], block_c[m][0]);
block_c[m][1] = __hfma2(dot22_8_h2(dq[1], a_ptr + m * a_stride), scales[1], block_c[m][1]);
block_c[m][2] = __hfma2(dot22_8_h2(dq[2], a_ptr + m * a_stride), scales[2], block_c[m][2]);
block_c[m][3] = __hfma2(dot22_8_h2(dq[3], a_ptr + m * a_stride), scales[3], block_c[m][3]);
}
b_ptr += size_n;
@ -182,38 +212,62 @@ __global__ void gemm_half_q_half_gptq_kernel
for (int m = 0; m < m_count; m++)
{
half2 *out = (half2*) c_.item_ptr(offset_m + m, n);
half2 result01 = __halves2half2(__float2half_rn(block_c[m][0]), __float2half_rn(block_c[m][1]));
half2 result23 = __halves2half2(__float2half_rn(block_c[m][2]), __float2half_rn(block_c[m][3]));
half result0 = __hadd(__low2half(block_c[m][0]), __high2half(block_c[m][0]));
half result1 = __hadd(__low2half(block_c[m][1]), __high2half(block_c[m][1]));
half result2 = __hadd(__low2half(block_c[m][2]), __high2half(block_c[m][2]));
half result3 = __hadd(__low2half(block_c[m][3]), __high2half(block_c[m][3]));
half2 result01 = __halves2half2(result0, result1);
half2 result23 = __halves2half2(result2, result3);
if constexpr (mul_r_weights)
{
half2 w_mul2 = __half2half2(weights[m].as_half);
result01 = __hmul2(result01, w_mul2);
result23 = __hmul2(result23, w_mul2);
}
atomicAdd(out , result01);
atomicAdd(out + 1, result23);
}
}
fp_gemm_half_q_half_gptq_kernel pick_gemm_half_q_half_gptq_kernel(bool first_block, const int m_count)
template <bool use_r_weights, bool mul_r_weights>
struct map_m_count_gptq {
static constexpr fp_gemm_half_q_half_gptq_kernel pick_gemm_half_q_half_gptq_kernel(int m_count)
{
#if GPTQ_BLOCK_M_SIZE_MAX >= 1
if (m_count == 1) return gemm_half_q_half_gptq_kernel<1, use_r_weights, mul_r_weights>;
#endif
#if GPTQ_BLOCK_M_SIZE_MAX >= 2
if (m_count == 2) return gemm_half_q_half_gptq_kernel<2, use_r_weights, mul_r_weights>;
#endif
#if GPTQ_BLOCK_M_SIZE_MAX >= 3
if (m_count == 3) return gemm_half_q_half_gptq_kernel<3, use_r_weights, mul_r_weights>;
#endif
#if GPTQ_BLOCK_M_SIZE_MAX >= 4
if (m_count == 4) return gemm_half_q_half_gptq_kernel<4, use_r_weights, mul_r_weights>;
#endif
#if GPTQ_BLOCK_M_SIZE_MAX >= 5
if (m_count == 5) return gemm_half_q_half_gptq_kernel<5, use_r_weights, mul_r_weights>;
#endif
#if GPTQ_BLOCK_M_SIZE_MAX >= 6
if (m_count == 6) return gemm_half_q_half_gptq_kernel<6, use_r_weights, mul_r_weights>;
#endif
#if GPTQ_BLOCK_M_SIZE_MAX >= 7
if (m_count == 7) return gemm_half_q_half_gptq_kernel<7, use_r_weights, mul_r_weights>;
#endif
#if GPTQ_BLOCK_M_SIZE_MAX >= 8
if (m_count == 8) return gemm_half_q_half_gptq_kernel<8, use_r_weights, mul_r_weights>;
#endif
return NULL;
}
};
fp_gemm_half_q_half_gptq_kernel pick_gemm_half_q_half_gptq_kernel(const int m_count, bool r_weights, bool mul_r_weights)
{
#if BLOCK_M_SIZE_MAX >= 1
if (m_count == 1) return gemm_half_q_half_gptq_kernel<true, 1>;
#endif
#if BLOCK_M_SIZE_MAX >= 2
if (m_count == 2) return gemm_half_q_half_gptq_kernel<true, 2>;
#endif
#if BLOCK_M_SIZE_MAX >= 3
if (m_count == 3) return gemm_half_q_half_gptq_kernel<true, 3>;
#endif
#if BLOCK_M_SIZE_MAX >= 4
if (m_count == 4) return gemm_half_q_half_gptq_kernel<true, 4>;
#endif
#if BLOCK_M_SIZE_MAX >= 5
if (m_count == 5) return gemm_half_q_half_gptq_kernel<true, 5>;
#endif
#if BLOCK_M_SIZE_MAX >= 6
if (m_count == 6) return gemm_half_q_half_gptq_kernel<true, 6>;
#endif
#if BLOCK_M_SIZE_MAX >= 7
if (m_count == 7) return gemm_half_q_half_gptq_kernel<true, 7>;
#endif
#if BLOCK_M_SIZE_MAX >= 8
if (m_count == 8) return gemm_half_q_half_gptq_kernel<true, 8>;
#endif
if (!r_weights && !mul_r_weights) return map_m_count_gptq<false, false>::pick_gemm_half_q_half_gptq_kernel(m_count);
if (!r_weights && mul_r_weights) return map_m_count_gptq<false, true>::pick_gemm_half_q_half_gptq_kernel(m_count);
if ( r_weights && !mul_r_weights) return map_m_count_gptq< true, false>::pick_gemm_half_q_half_gptq_kernel(m_count);
if ( r_weights && mul_r_weights) return map_m_count_gptq< true, true>::pick_gemm_half_q_half_gptq_kernel(m_count);
return NULL;
}

View File

@ -57,6 +57,7 @@ QMatrix::QMatrix
uint32_t* _q_scale,
half* _q_scale_max,
uint16_t* _q_groups,
uint16_t* _q_group_map,
uint32_t* _gptq_qzeros,
half* _gptq_scales,
@ -80,13 +81,17 @@ QMatrix::QMatrix
cuda_q_scale = _q_scale;
cuda_q_scale_max = _q_scale_max;
cuda_q_groups = _q_groups;
cuda_q_group_map = _q_group_map;
cuda_gptq_qzeros = _gptq_qzeros;
cuda_gptq_scales = _gptq_scales;
is_gptq = (_gptq_qzeros != NULL);
groupsize = 1;
while (groupsize * groups < height) groupsize *= 2;
if (is_gptq)
{
gptq_groupsize = 1;
while (gptq_groupsize * groups < height) gptq_groupsize *= 2;
}
// Create group map
@ -102,15 +107,26 @@ QMatrix::QMatrix
uint16_t* cpu_q_groups = (uint16_t*)calloc(groups * 2, sizeof(uint16_t));
cudaMemcpy(cpu_q_groups, cuda_q_groups, groups * 2 * sizeof(uint16_t), cudaMemcpyDeviceToHost);
int row = 0;
for (int i = 0; i < groups; i++)
{
int bits = cpu_q_groups[i * 2];
if (bits == 8) rows_8 += groupsize;
if (bits == 6) rows_6 += groupsize;
if (bits == 5) rows_5 += groupsize;
if (bits == 4) rows_4 += groupsize;
if (bits == 3) rows_3 += groupsize;
if (bits == 2) rows_2 += groupsize;
int rows;
if (i < groups - 1)
{
int qrows = cpu_q_groups[i * 2 + 3] - cpu_q_groups[i * 2 + 1];
rows = qrows * 32 / bits;
}
else rows = height - row;
if (bits == 8) rows_8 += rows;
if (bits == 6) rows_6 += rows;
if (bits == 5) rows_5 += rows;
if (bits == 4) rows_4 += rows;
if (bits == 3) rows_3 += rows;
if (bits == 2) rows_2 += rows;
row += rows;
}
free(cpu_q_groups);
@ -138,6 +154,13 @@ QMatrix::QMatrix
}
}
// DBGI(rows_8);
// DBGI(rows_6);
// DBGI(rows_5);
// DBGI(rows_4);
// DBGI(rows_3);
// DBGI(rows_2);
// Shuffle quantized data
dim3 blockDim, gridDim;
@ -283,10 +306,10 @@ __global__ void reconstruct_kernel
const uint16_t* __restrict__ b_q_perm,
const uint32_t* __restrict__ b_q_scale,
const half* __restrict__ b_q_scale_max,
//const uint16_t* __restrict__ b_q_groups,
const uint16_t* __restrict__ b_q_group_map,
const int size_k,
const int size_n,
const int groupsize,
//const int groupsize,
const int groups,
half* __restrict__ b,
const int rows_8,
@ -317,7 +340,8 @@ __global__ void reconstruct_kernel
// Find initial group
int group = offset_k / groupsize;
// int group = offset_k / groupsize;
int group = b_q_group_map[offset_k * 2];
int pre_rows_8 = min(rows_8, offset_k);
int pre_rows_6 = offset_k > rows_8 ? min(rows_6, offset_k) - rows_8 : 0;
@ -337,7 +361,7 @@ __global__ void reconstruct_kernel
half qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]);
half2 qs_h2 = __halves2half2(qs_h, qs_h);
int nextgroup = offset_k + groupsize;
int nextgroup = offset_k + b_q_group_map[offset_k * 2 + 1];
int end_k = min(offset_k + BLOCK_KN_SIZE, size_k);
int k = offset_k;
@ -347,7 +371,7 @@ __global__ void reconstruct_kernel
while (k < rows_8 && k < end_k)
{
if (k == nextgroup) { group++; qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]); nextgroup += groupsize; qs_h2 = __halves2half2(qs_h, qs_h); }
if (k == nextgroup) { group++; qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]); nextgroup += b_q_group_map[k * 2 + 1]; qs_h2 = __halves2half2(qs_h, qs_h); }
for (int p = 0; p < 4; p++)
{
half2 dq[4];
@ -363,7 +387,7 @@ __global__ void reconstruct_kernel
while (k < rows_6 && k < end_k)
{
if (k == nextgroup) { group++; qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]); nextgroup += groupsize; qs_h2 = __halves2half2(qs_h, qs_h); }
if (k == nextgroup) { group++; qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]); nextgroup += b_q_group_map[k * 2 + 1]; qs_h2 = __halves2half2(qs_h, qs_h); }
for (int p = 0; p < 2; p++)
{
half2 dq[8];
@ -380,7 +404,7 @@ __global__ void reconstruct_kernel
while (k < rows_5 && k < end_k)
{
if (k == nextgroup) { group++; qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]); nextgroup += groupsize; qs_h2 = __halves2half2(qs_h, qs_h); }
if (k == nextgroup) { group++; qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]); nextgroup += b_q_group_map[k * 2 + 1]; qs_h2 = __halves2half2(qs_h, qs_h); }
for (int p = 0; p < 1; p++)
{
half2 dq[16];
@ -399,7 +423,7 @@ __global__ void reconstruct_kernel
while (k < rows_4 && k < end_k)
{
if (k == nextgroup) { group++; qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]); nextgroup += groupsize; qs_h2 = __halves2half2(qs_h, qs_h); }
if (k == nextgroup) { group++; qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]); nextgroup += b_q_group_map[k * 2 + 1]; qs_h2 = __halves2half2(qs_h, qs_h); }
for (int p = 0; p < 4; p++)
{
half2 dq[4];
@ -414,7 +438,7 @@ __global__ void reconstruct_kernel
while (k < rows_3 && k < end_k)
{
if (k == nextgroup) { group++; qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]); nextgroup += groupsize; qs_h2 = __halves2half2(qs_h, qs_h); }
if (k == nextgroup) { group++; qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]); nextgroup += b_q_group_map[k * 2 + 1]; qs_h2 = __halves2half2(qs_h, qs_h); }
for (int p = 0; p < 1; p++)
{
half2 dq[16];
@ -431,8 +455,8 @@ __global__ void reconstruct_kernel
while (k < rows_2 && k < end_k)
{
if (k == nextgroup) { group++; qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]); nextgroup += groupsize; qs_h2 = __halves2half2(qs_h, qs_h); }
for (int p = 0; p < 2; p++)
if (k == nextgroup) { group++; qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]); nextgroup += b_q_group_map[k * 2 + 1]; qs_h2 = __halves2half2(qs_h, qs_h); }
for (int p = 0; p < 1; p++)
{
half2 dq[8];
uint32_t q_0 = *b_ptr; b_ptr += size_n;
@ -441,7 +465,7 @@ __global__ void reconstruct_kernel
half* dqh = (half*) dq;
for (int j = 0; j < 16; j++) b_.set(perm[lk++], n, dqh[j]);
}
k += 32;
k += 16;
}
}
@ -461,10 +485,10 @@ void QMatrix::reconstruct(half* out)
cuda_q_perm,
cuda_q_scale,
cuda_q_scale_max,
//cuda_q_groups,
cuda_q_group_map,
height,
width,
groupsize,
//groupsize,
groups,
out,
rows_8,
@ -487,7 +511,7 @@ void QMatrix::reconstruct(half* out)
//const uint16_t* __restrict__ b_q_groups,
height,
width,
groupsize,
gptq_groupsize,
groups,
out,
rows_4

View File

@ -18,7 +18,7 @@ public:
int height;
int width;
int groups;
int groupsize;
int gptq_groupsize;
int rows_8;
int rows_6;
@ -33,6 +33,7 @@ public:
uint32_t* cuda_q_scale = NULL;
half* cuda_q_scale_max = NULL;
uint16_t* cuda_q_groups = NULL;
uint16_t* cuda_q_group_map = NULL;
uint32_t* cuda_gptq_qzeros = NULL;
half* cuda_gptq_scales = NULL;
@ -53,6 +54,7 @@ public:
uint32_t* _q_scale,
half* _q_scale_max,
uint16_t* _q_groups,
uint16_t* _q_group_map,
uint32_t* _gptq_qzeros,
half* _gptq_scales,

View File

@ -7,6 +7,7 @@ union half2_uint32
half2 as_half2;
__device__ half2_uint32(uint32_t val) : as_uint32(val) {}
__device__ half2_uint32(half2 val) : as_half2(val) {}
__device__ half2_uint32() : as_uint32(0) {}
};
union half_uint16
@ -15,6 +16,7 @@ union half_uint16
half as_half;
__device__ half_uint16(uint16_t val) : as_uint16(val) {}
__device__ half_uint16(half val) : as_half(val) {}
__device__ half_uint16() : as_uint16(0) {}
};
// Max_scale premultiplied by 1/256

View File

@ -1,3 +1,11 @@
#ifndef _util_cuh
#define _util_cuh
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cstdint>
#include <cstdio>
#include <ATen/cuda/CUDAContext.h>
#define DIVIDE(x, size) (((x) + (size) - 1) / (size))
@ -40,3 +48,7 @@ inline void gpu_assert(cudaError_t code, const char *file, int line, bool abort=
if (abort) exit(code);
}
}
void print_global_mem(const half* ptr, int rows, int columns, int stride);
#endif

View File

@ -31,6 +31,7 @@ uintptr_t make_q_matrix
torch::Tensor q_scale,
torch::Tensor q_scale_max,
torch::Tensor q_groups,
torch::Tensor q_group_map,
torch::Tensor gptq_qzeros,
torch::Tensor gptq_scales,
torch::Tensor gptq_g_idx,
@ -43,6 +44,7 @@ uintptr_t make_q_matrix
TORCH_CHECK_DTYPE_OPT(q_scale, kInt);
TORCH_CHECK_DTYPE_OPT(q_scale_max, kHalf);
TORCH_CHECK_DTYPE_OPT(q_groups, kShort);
TORCH_CHECK_DTYPE_OPT(q_group_map, kShort);
TORCH_CHECK_DTYPE_OPT(gptq_qzeros, kInt);
TORCH_CHECK_DTYPE_OPT(gptq_scales, kHalf);
TORCH_CHECK_DTYPE_OPT(gptq_g_idx, kInt);
@ -83,12 +85,15 @@ uintptr_t make_q_matrix
q_scale.device().is_meta() ? NULL : (uint32_t*) q_scale.data_ptr(),
q_scale_max.device().is_meta() ? NULL : (half*) q_scale_max.data_ptr(),
q_groups.device().is_meta() ? NULL : (uint16_t*) q_groups.data_ptr(),
q_group_map.device().is_meta() ? NULL : (uint16_t*) q_group_map.data_ptr(),
gptq_qzeros.device().is_meta() ? NULL : (uint32_t*) gptq_qzeros.data_ptr(),
gptq_scales.device().is_meta() ? NULL : (half*) gptq_scales.data_ptr(),
gptq_g_idx.device().is_meta() ? NULL : (uint32_t*) gptq_g_idx.data_ptr(),
(half*) temp_dq.data_ptr()
);
if (m->failed) throw std::runtime_error("CUDA out of memory");
return reinterpret_cast<uintptr_t> (m);
}

View File

@ -32,10 +32,10 @@ def fresh_cache():
current_value = huggingface_hub.constants.HUGGINGFACE_HUB_CACHE
huggingface_hub.constants.HUGGINGFACE_HUB_CACHE = d
text_generation_server.utils.hub.HUGGINGFACE_HUB_CACHE = d
os.environ['HUGGINGFACE_HUB_CACHE'] = d
os.environ["HUGGINGFACE_HUB_CACHE"] = d
yield
huggingface_hub.constants.HUGGINGFACE_HUB_CACHE = current_value
os.environ['HUGGINGFACE_HUB_CACHE'] = current_value
os.environ["HUGGINGFACE_HUB_CACHE"] = current_value
text_generation_server.utils.hub.HUGGINGFACE_HUB_CACHE = current_value
@ -47,7 +47,7 @@ def prefetched():
revision="main",
local_files_only=False,
repo_type="model",
allow_patterns=["*.safetensors"]
allow_patterns=["*.safetensors"],
)
yield model_id
@ -61,7 +61,7 @@ def test_weight_hub_files_offline_error(offline, fresh_cache):
def test_weight_hub_files_offline_ok(prefetched, offline):
# If the model is prefetched then we should be able to get the weight files from local cache
filenames = weight_hub_files(prefetched)
assert filenames == ['model.safetensors']
assert filenames == ["model.safetensors"]
def test_weight_hub_files():

View File

@ -71,7 +71,7 @@ def _load_multi_mqa_gptq(
g_idx = weights.get_tensor(f"{prefix}.c_attn.g_idx")
g_idx = g_idx.to(device=weights.device)
bits, groupsize = weights._get_gptq_params()
bits, groupsize, _ = weights._get_gptq_params()
from text_generation_server.utils.layers import HAS_EXLLAMA

View File

@ -27,6 +27,32 @@ def ext_gemm_half_q_half(x, q_handle, q4_width, force_cuda):
return output.view(output_shape)
# Group map needed for irregular group sizes
def make_group_map(q_groups, num_qrows):
gr = q_groups.tolist()
group_map = []
num_groups = len(gr) // 2
for i in range(num_groups):
bits = gr[i * 2]
if i < num_groups - 1:
qrows = gr[i * 2 + 3] - gr[i * 2 + 1]
else:
qrows = num_qrows - gr[i * 2 + 1]
rows = qrows * 32 // bits
for j in range(rows):
group_map += [i]
group_map += [rows - j]
return torch.tensor(group_map, dtype=torch.short, device=q_groups.device)
# Create Q matrix
def ext_make_q_matrix(w: dict, temp_dq, key: str = None):
"""
Create Q matrix
@ -37,6 +63,10 @@ def ext_make_q_matrix(w: dict, temp_dq, key: str = None):
w["q_scale_max"] /= 256
w["q_perm"] = w["q_perm"].short()
w["q_invperm"] = w["q_invperm"].short()
if "q_group_map" not in w:
w["q_group_map"] = make_group_map(w["q_groups"], w["q_weight"].shape[0])
return make_q_matrix(
w["q_weight"],
w["q_perm"],
@ -44,6 +74,7 @@ def ext_make_q_matrix(w: dict, temp_dq, key: str = None):
w["q_scale"],
w["q_scale_max"],
w["q_groups"],
w["q_group_map"],
none_tensor,
none_tensor,
none_tensor,
@ -70,6 +101,7 @@ def ext_make_q_matrix(w: dict, temp_dq, key: str = None):
none_tensor,
none_tensor,
none_tensor,
none_tensor,
w["qzeros"],
w["scales"],
w["g_idx"].cpu(),
@ -84,6 +116,7 @@ def ext_make_q_matrix(w: dict, temp_dq, key: str = None):
none_tensor,
none_tensor,
none_tensor,
none_tensor,
w["qzeros"],
w["scales"],
none_tensor,

View File

@ -18,7 +18,9 @@ WEIGHTS_CACHE_OVERRIDE = os.getenv("WEIGHTS_CACHE_OVERRIDE", None)
HF_HUB_OFFLINE = os.environ.get("HF_HUB_OFFLINE", "0").lower() in ["true", "1", "yes"]
def _cached_weight_files(model_id: str, revision: Optional[str], extension: str) -> List[str]:
def _cached_weight_files(
model_id: str, revision: Optional[str], extension: str
) -> List[str]:
"""Guess weight files from the cached revision snapshot directory"""
d = _get_cached_revision_directory(model_id, revision)
if not d:
@ -27,7 +29,9 @@ def _cached_weight_files(model_id: str, revision: Optional[str], extension: str)
return filenames
def _weight_hub_files_from_model_info(info: hf_api.ModelInfo, extension: str) -> List[str]:
def _weight_hub_files_from_model_info(
info: hf_api.ModelInfo, extension: str
) -> List[str]:
return [
s.rfilename
for s in info.siblings
@ -44,21 +48,27 @@ def _weight_files_from_dir(d: Path, extension: str) -> List[str]:
# see _weight_hub_files_from_model_info, that's also what is
# done there with the len(s.rfilename.split("/")) == 1 condition
root, _, files = next(os.walk(str(d)))
filenames = [f for f in files
if f.endswith(extension)
and "arguments" not in f
and "args" not in f
and "adapter" not in f
and "training" not in f]
filenames = [
f
for f in files
if f.endswith(extension)
and "arguments" not in f
and "args" not in f
and "adapter" not in f
and "training" not in f
]
return filenames
def _get_cached_revision_directory(model_id: str, revision: Optional[str]) -> Optional[Path]:
def _get_cached_revision_directory(
model_id: str, revision: Optional[str]
) -> Optional[Path]:
if revision is None:
revision = "main"
repo_cache = Path(HUGGINGFACE_HUB_CACHE) / Path(
file_download.repo_folder_name(repo_id=model_id, repo_type="model"))
file_download.repo_folder_name(repo_id=model_id, repo_type="model")
)
if not repo_cache.is_dir():
# No cache for this model
@ -86,7 +96,7 @@ def _get_cached_revision_directory(model_id: str, revision: Optional[str]) -> Op
def weight_hub_files(
model_id: str, revision: Optional[str] = None, extension: str = ".safetensors"
model_id: str, revision: Optional[str] = None, extension: str = ".safetensors"
) -> List[str]:
"""Get the weights filenames on the hub"""
api = HfApi()

View File

@ -19,6 +19,7 @@ from accelerate import init_empty_weights
from text_generation_server.utils.gptq.quant_linear import QuantLinear
from text_generation_server.utils.import_utils import IS_CUDA_SYSTEM, IS_ROCM_SYSTEM
from text_generation_server.utils.log import log_once
HAS_AWQ = True
try:
@ -35,10 +36,11 @@ HAS_EXLLAMA = False
CAN_EXLLAMA = major >= 8
V2 = os.getenv("EXLLAMA_VERSION", "2") == "2"
if V2 and int(os.getenv("WORLD_SIZE", "1")) > 1:
logger.warning(
V2 = False
log_once(
logger.warning,
"Disabling exllama v2 and using v1 instead because there are issues when sharding"
)
V2 = False
if os.getenv("DISABLE_EXLLAMA") == "True":
HAS_EXLLAMA = False

View File

@ -0,0 +1,6 @@
from functools import lru_cache
@lru_cache(10)
def log_once(log, msg:str):
log(msg)

View File

@ -6,6 +6,7 @@ import torch
from loguru import logger
from huggingface_hub import hf_hub_download
import json
from text_generation_server.utils.log import log_once
class Weights:
@ -161,7 +162,7 @@ class Weights:
else:
g_idx = None
bits, groupsize = self._get_gptq_params()
bits, groupsize, _ = self._get_gptq_params()
weight = (qweight, qzeros, scales, g_idx, bits, groupsize, False)
else:
slice_ = self._get_slice(f"{prefix}.weight")
@ -211,10 +212,10 @@ class Weights:
else:
g_idx = None
bits, groupsize = self._get_gptq_params()
bits, groupsize, desc_act = self._get_gptq_params()
from text_generation_server.utils.layers import HAS_EXLLAMA
use_exllama = bits == 4 and HAS_EXLLAMA and quantize == "gptq"
use_exllama = bits == 4 and HAS_EXLLAMA and quantize == "gptq" and not desc_act
weight = (qweight, qzeros, scales, g_idx, bits, groupsize, use_exllama)
else:
w = [self.get_sharded(f"{p}.weight", dim=0) for p in prefixes]
@ -240,11 +241,15 @@ class Weights:
def get_multi_weights_row(self, prefix: str, quantize: str):
if quantize == "gptq":
use_exllama = True
bits, groupsize = self._get_gptq_params()
bits, groupsize, desc_act = self._get_gptq_params()
if bits != 4:
use_exllama = False
if desc_act:
log_once(logger.warning, "Disabling exllama because desc_act=True")
use_exllama = False
if self.process_group.size() > 1:
g_idx = self.get_tensor(f"{prefix}.g_idx")
if g_idx is not None:
@ -274,12 +279,18 @@ class Weights:
if use_exllama:
if not HAS_EXLLAMA:
if CAN_EXLLAMA:
logger.warning(
log_once(
logger.warning,
"Exllama GPTQ cuda kernels (which are faster) could have been used, but are not currently installed, try using BUILD_EXTENSIONS=True"
)
use_exllama = False
else:
logger.info(f"Using exllama kernels v{HAS_EXLLAMA}")
log_once(
logger.info,
f"Using exllama kernels v{HAS_EXLLAMA}"
)
g_idx = self.get_sharded(f"{prefix}.g_idx", dim=0)
if use_exllama and groupsize != -1:
qzeros = self.get_sharded(f"{prefix}.qzeros", dim=0)
@ -288,14 +299,12 @@ class Weights:
qzeros = self.get_tensor(f"{prefix}.qzeros")
scales = self.get_tensor(f"{prefix}.scales")
g_idx = self.get_sharded(f"{prefix}.g_idx", dim=0)
if use_exllama:
g_idx = g_idx - g_idx[0]
weight = (qweight, qzeros, scales, g_idx, bits, groupsize, use_exllama)
elif quantize == "awq":
bits, groupsize = self._get_gptq_params()
bits, groupsize, _ = self._get_gptq_params()
try:
qweight = self.get_sharded(f"{prefix}.qweight", dim=0)
@ -314,18 +323,20 @@ class Weights:
weight = self.get_sharded(f"{prefix}.weight", dim=1)
return weight
def _get_gptq_params(self) -> Tuple[int, int]:
def _get_gptq_params(self) -> Tuple[int, int, int]:
try:
bits = self.get_tensor("gptq_bits").item()
groupsize = self.get_tensor("gptq_groupsize").item()
desc_act = False
except (SafetensorError, RuntimeError) as e:
try:
bits = self.gptq_bits
groupsize = self.gptq_groupsize
desc_act = getattr(self, "gptq_desc_act", False)
except Exception:
raise e
return bits, groupsize
return bits, groupsize, desc_act
def _set_gptq_params(self, model_id, revision):
filename = "config.json"
@ -340,6 +351,7 @@ class Weights:
data = json.load(f)
self.gptq_bits = data["quantization_config"]["bits"]
self.gptq_groupsize = data["quantization_config"]["group_size"]
self.gptq_desc_act = data["quantization_config"]["desc_act"]
except Exception:
filename = "quantize_config.json"
try:
@ -353,6 +365,7 @@ class Weights:
data = json.load(f)
self.gptq_bits = data["bits"]
self.gptq_groupsize = data["group_size"]
self.gptq_desc_act = data["desc_act"]
except Exception:
filename = "quant_config.json"
try:
@ -366,5 +379,6 @@ class Weights:
data = json.load(f)
self.gptq_bits = data["w_bit"]
self.gptq_groupsize = data["q_group_size"]
self.gptq_desc_act = data["desc_act"]
except Exception:
pass