feat: update exllamav2 kernels (#1370)

Co-authored-by: Nicolas Patry <patry.nicolas@protonmail.com>
Authored by OlivierDehaene on 2023-12-21 17:25:22 +01:00 (committed via GitHub)
parent 987c959f73
commit 564199bab3
17 changed files with 525 additions and 255 deletions

View File

@ -2,6 +2,7 @@
#define _config_h #define _config_h
#define MAX_Q_GEMM_ROWS 50 #define MAX_Q_GEMM_ROWS 50
#define MAX_Q_GEMM_WEIGHTS 4 // must be <= MAX_Q_GEMM_ROWS
#define QMODE_2BIT 1 #define QMODE_2BIT 1
#define QMODE_3BIT 1 #define QMODE_3BIT 1
@ -10,4 +11,5 @@
#define QMODE_6BIT 0 #define QMODE_6BIT 0
#define QMODE_8BIT 0 #define QMODE_8BIT 0
#endif #endif

View File

@ -10,16 +10,19 @@
#include "quant/qdq_6.cuh" #include "quant/qdq_6.cuh"
#include "quant/qdq_8.cuh" #include "quant/qdq_8.cuh"
#define BLOCK_KN_SIZE 128 #define GPTQ_BLOCK_KN_SIZE 128
#define BLOCK_M_SIZE_MAX 8 #define GPTQ_BLOCK_M_SIZE_MAX 8
#define MAX_GROUPS_IN_BLOCK (BLOCK_KN_SIZE / 32) #define GPTQ_MAX_GROUPS_IN_BLOCK (GPTQ_BLOCK_KN_SIZE / 32)
#define EXL2_BLOCK_KN_SIZE 64
#define EXL2_BLOCK_M_SIZE_MAX 8
#define EXL2_MAX_GROUPS_IN_BLOCK (EXL2_BLOCK_KN_SIZE / 32)
#define CLEAR_N_SIZE 256 #define CLEAR_N_SIZE 256
#include "q_gemm_kernel.cuh" #include "q_gemm_kernel.cuh"
#include "q_gemm_kernel_gptq.cuh" #include "q_gemm_kernel_gptq.cuh"
#include "compat_gemm.cuh"
void gemm_half_q_half_cuda_part void gemm_half_q_half_cuda_part
( (
const half* a, const half* a,
@ -29,20 +32,23 @@ void gemm_half_q_half_cuda_part
int size_n, int size_n,
int size_k, int size_k,
int m_count, int m_count,
bool clear bool clear,
const half* r_weights,
int r_weights_stride,
bool mul_r_weights
) )
{ {
if (!b->is_gptq) if (!b->is_gptq)
{ {
dim3 blockDim, gridDim; dim3 blockDim, gridDim;
blockDim.x = BLOCK_KN_SIZE; blockDim.x = EXL2_BLOCK_KN_SIZE;
blockDim.y = 1; blockDim.y = 1;
blockDim.z = 1; blockDim.z = 1;
gridDim.x = DIVIDE(size_n, BLOCK_KN_SIZE * 4); gridDim.x = DIVIDE(size_n, EXL2_BLOCK_KN_SIZE * 4);
gridDim.y = DIVIDE(size_m, m_count); gridDim.y = DIVIDE(size_m, m_count);
gridDim.z = DIVIDE(size_k, BLOCK_KN_SIZE); gridDim.z = DIVIDE(size_k, EXL2_BLOCK_KN_SIZE);
fp_gemm_half_q_half_kernel kernel = pick_gemm_half_q_half_kernel(true, m_count); fp_gemm_half_q_half_kernel kernel = pick_gemm_half_q_half_kernel(m_count, r_weights != NULL, mul_r_weights);
kernel<<<gridDim, blockDim>>> kernel<<<gridDim, blockDim>>>
( (
@ -55,7 +61,7 @@ void gemm_half_q_half_cuda_part
size_n, size_n,
size_k, size_k,
b->groups, b->groups,
b->groupsize, b->cuda_q_group_map,
b->cuda_q_perm, b->cuda_q_perm,
b->rows_8, b->rows_8,
b->rows_6, b->rows_6,
@ -63,24 +69,27 @@ void gemm_half_q_half_cuda_part
b->rows_4, b->rows_4,
b->rows_3, b->rows_3,
b->rows_2, b->rows_2,
clear clear,
r_weights,
r_weights_stride
); );
} }
else else
{ {
dim3 blockDim, gridDim; dim3 blockDim, gridDim;
blockDim.x = BLOCK_KN_SIZE; blockDim.x = GPTQ_BLOCK_KN_SIZE;
blockDim.y = 1; blockDim.y = 1;
blockDim.z = 1; blockDim.z = 1;
gridDim.x = DIVIDE(size_n, BLOCK_KN_SIZE * 4); gridDim.x = DIVIDE(size_n, GPTQ_BLOCK_KN_SIZE * 4);
gridDim.y = DIVIDE(size_m, m_count); gridDim.y = DIVIDE(size_m, m_count);
gridDim.z = DIVIDE(size_k, BLOCK_KN_SIZE); gridDim.z = DIVIDE(size_k, GPTQ_BLOCK_KN_SIZE);
fp_gemm_half_q_half_gptq_kernel kernel = pick_gemm_half_q_half_gptq_kernel(true, m_count); fp_gemm_half_q_half_gptq_kernel kernel = pick_gemm_half_q_half_gptq_kernel(m_count, r_weights != NULL, mul_r_weights);
// DBGX((uint64_t) b->cuda_q_perm); // DBGX((uint64_t) r_weights);
// DBGI(b->rows_4); // if (r_weights)
// DBGI(b->height); // print_global_mem(r_weights, 1, 1, 1);
// DBGI(r_weights_stride);
kernel<<<gridDim, blockDim>>> kernel<<<gridDim, blockDim>>>
( (
@ -93,10 +102,12 @@ void gemm_half_q_half_cuda_part
size_n, size_n,
size_k, size_k,
b->groups, b->groups,
b->groupsize, b->gptq_groupsize,
b->cuda_q_perm, b->cuda_q_perm,
b->rows_4, b->rows_4,
clear clear,
r_weights,
r_weights_stride
); );
} }
} }
@ -112,13 +123,14 @@ void gemm_half_q_half_cuda
int size_k, int size_k,
bool clear, bool clear,
half* temp_dq, half* temp_dq,
bool force_cuda bool force_cuda,
const half* r_weights,
const int r_weights_stride,
bool mul_r_weights
) )
{ {
if (size_m > MAX_Q_GEMM_ROWS && !force_cuda) if (size_m > MAX_Q_GEMM_ROWS && !force_cuda)
{ {
//printf("cublas\n");
// Reconstruct FP16 matrix, then cuBLAS // Reconstruct FP16 matrix, then cuBLAS
if (!temp_dq) temp_dq = b->temp_dq; if (!temp_dq) temp_dq = b->temp_dq;
@ -139,12 +151,12 @@ void gemm_half_q_half_cuda
//const float alpha = 1.0f; //const float alpha = 1.0f;
//const float beta = clear ? 0.0f : 1.0f; //const float beta = clear ? 0.0f : 1.0f;
//cublasSgemmEx(cublas_handle, //cublasSgemmEx(cublas_handle,
// CUBLAS_OP_N, // CUBLAS_OP_N,
// CUBLAS_OP_N, // CUBLAS_OP_N,
// size_n, size_m, size_k, // size_n, size_m, size_k,
// &alpha, temp_dq, CUDA_R_16F, size_n, // &alpha, temp_dq, CUDA_R_16F, size_n,
// a, CUDA_R_16F, size_k, // a, CUDA_R_16F, size_k,
// &beta, c, CUDA_R_16F, size_n); // &beta, c, CUDA_R_16F, size_n);
//const float alpha = 1.0f; //const float alpha = 1.0f;
//const float beta = clear ? 0.0f : 1.0f; //const float beta = clear ? 0.0f : 1.0f;
@ -158,24 +170,21 @@ void gemm_half_q_half_cuda
} }
else else
{ {
//printf("cuda\n");
// Quantized matmul // Quantized matmul
//if (clear) clear_tensor_cuda(c, size_m, size_n); int block_m_size_max = b->is_gptq ? GPTQ_BLOCK_M_SIZE_MAX : EXL2_BLOCK_M_SIZE_MAX;
int max_chunks = size_m / block_m_size_max;
int max_chunks = size_m / BLOCK_M_SIZE_MAX; int last_chunk = max_chunks * block_m_size_max;
int last_chunk = max_chunks * BLOCK_M_SIZE_MAX;
int last_chunk_size = size_m - last_chunk; int last_chunk_size = size_m - last_chunk;
if (max_chunks) if (max_chunks)
{ {
gemm_half_q_half_cuda_part(a, b, c, last_chunk, size_n, size_k, BLOCK_M_SIZE_MAX, clear); gemm_half_q_half_cuda_part(a, b, c, last_chunk, size_n, size_k, block_m_size_max, clear, r_weights, r_weights_stride, mul_r_weights);
} }
if (last_chunk_size) if (last_chunk_size)
{ {
gemm_half_q_half_cuda_part(a + last_chunk * size_k, b, c + last_chunk * size_n, last_chunk_size, size_n, size_k, last_chunk_size, clear); gemm_half_q_half_cuda_part(a + last_chunk * size_k, b, c + last_chunk * size_n, last_chunk_size, size_n, size_k, last_chunk_size, clear, r_weights, r_weights_stride, mul_r_weights);
} }
} }
} }
@ -201,11 +210,10 @@ void clear_tensor_cuda
int size_n int size_n
) )
{ {
return; // dim3 blockDim, gridDim;
dim3 blockDim, gridDim; // blockDim.x = CLEAR_N_SIZE;
blockDim.x = CLEAR_N_SIZE; // blockDim.y = 1;
blockDim.y = 1; // gridDim.x = DIVIDE(size_n / 8, CLEAR_N_SIZE);
gridDim.x = DIVIDE(size_n / 8, CLEAR_N_SIZE); // gridDim.y = size_m;
gridDim.y = size_m; // clear_kernel<<<gridDim, blockDim>>>(c, size_m, size_n);
clear_kernel<<<gridDim, blockDim>>>(c, size_m, size_n);
} }
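For reference, a minimal Python sketch (not the kernel code itself) of how gemm_half_q_half_cuda above splits the M dimension: rows are processed in full chunks of the per-format BLOCK_M_SIZE_MAX, plus one smaller tail launch for the remainder. The helper name split_m is hypothetical.

def split_m(size_m: int, block_m_size_max: int):
    max_chunks = size_m // block_m_size_max        # full chunks of block_m_size_max rows
    last_chunk = max_chunks * block_m_size_max     # rows covered by the full chunks
    last_chunk_size = size_m - last_chunk          # leftover rows (< block_m_size_max)
    launches = []
    if max_chunks:
        # one launch covering rows [0, last_chunk) with m_count = block_m_size_max
        launches.append((0, last_chunk, block_m_size_max))
    if last_chunk_size:
        # one launch covering the tail rows with m_count = last_chunk_size
        launches.append((last_chunk, size_m, last_chunk_size))
    return launches

# e.g. split_m(19, 8) -> [(0, 16, 8), (16, 19, 3)]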

View File

@ -20,7 +20,10 @@ void gemm_half_q_half_cuda
int size_k, int size_k,
bool clear = false, bool clear = false,
half* reconstruct = NULL, half* reconstruct = NULL,
bool force_cuda = false bool force_cuda = false,
const half* r_weights = NULL,
const int r_weights_stride = 0,
bool mul_r_weights = false
); );
void clear_tensor_cuda void clear_tensor_cuda

View File

@ -1,8 +1,5 @@
#include "compat.cuh" #include "compat.cuh"
#include <cuda_runtime.h>
#include <cuda_fp16.h>
__forceinline__ __device__ half2 dot22_8(half2(&dq)[4], const half* a_ptr, const half2 g_result, const half qs_h) __forceinline__ __device__ half2 dot22_8(half2(&dq)[4], const half* a_ptr, const half2 g_result, const half qs_h)
{ {
half2 result = {}; half2 result = {};
@ -60,6 +57,47 @@ __forceinline__ __device__ float dot22_32_f(half2(&dq)[16], const half* a_ptr, c
return fma(result_f, qs_f, g_result); return fma(result_f, qs_f, g_result);
} }
__forceinline__ __device__ half dot22_8_h(half2(&dq)[4], const half* a_ptr, const half g_result, const half qs_h)
{
// Use FP32 accumulator to avoid potential overflow since unscaled weights are in the range -128..127
float result = {};
#pragma unroll
for (int i = 0; i < 4; i++)
{
half2 w01 = dq[i];
float w0 = __low2float(w01);
float w1 = __high2float(w01);
float x0 = __half2float(*a_ptr++);
float x1 = __half2float(*a_ptr++);
result = fma(w0, x0, result);
result = fma(w1, x1, result);
}
float qs = __half2float(qs_h);
result *= qs;
half result_h = __float2half_rn(result);
return __hadd(result_h, g_result);
}
__forceinline__ __device__ half dot22_16_h(half2(&dq)[8], const half* a_ptr, const half g_result, const half qs_h)
{
half2 result = {};
const half2* a2_ptr = (const half2*)a_ptr;
#pragma unroll
for (int i = 0; i < 8; i++) result = __hfma2(dq[i], *a2_ptr++, result);
half result_h = __hadd(__low2half(result), __high2half(result));
return __hfma(result_h, qs_h, g_result);
}
__forceinline__ __device__ half dot22_32_h(half2(&dq)[16], const half* a_ptr, const half g_result, const half qs_h)
{
half2 result = {};
const half2* a2_ptr = (const half2*)a_ptr;
#pragma unroll
for (int i = 0; i < 16; i += 1) result = __hfma2(dq[i], *a2_ptr++, result);
half result_h = __hadd(__low2half(result), __high2half(result));
return __hfma(result_h, qs_h, g_result);
}
typedef void (*fp_gemm_half_q_half_kernel) typedef void (*fp_gemm_half_q_half_kernel)
@ -73,7 +111,7 @@ typedef void (*fp_gemm_half_q_half_kernel)
const int, const int,
const int, const int,
const int, const int,
const int, const uint16_t*,
const uint16_t*, const uint16_t*,
const int, const int,
const int, const int,
@ -81,10 +119,12 @@ typedef void (*fp_gemm_half_q_half_kernel)
const int, const int,
const int, const int,
const int, const int,
const bool const bool,
const half*,
const int
); );
template <bool first_block, int m_count> template <int m_count, bool use_r_weights, bool mul_r_weights>
__global__ void gemm_half_q_half_kernel __global__ void gemm_half_q_half_kernel
( (
const half* __restrict__ a, const half* __restrict__ a,
@ -96,7 +136,7 @@ __global__ void gemm_half_q_half_kernel
const int size_n, const int size_n,
const int size_k, const int size_k,
const int groups, const int groups,
const int groupsize, const uint16_t* __restrict__ b_q_group_map,
const uint16_t* __restrict__ b_q_perm, const uint16_t* __restrict__ b_q_perm,
const int rows_8, const int rows_8,
const int rows_6, const int rows_6,
@ -104,7 +144,9 @@ __global__ void gemm_half_q_half_kernel
const int rows_4, const int rows_4,
const int rows_3, const int rows_3,
const int rows_2, const int rows_2,
const bool clear const bool clear,
const half* r_weights,
const int r_weights_stride
) )
{ {
MatrixView_half a_(a, size_m, size_k); MatrixView_half a_(a, size_m, size_k);
@ -115,18 +157,34 @@ __global__ void gemm_half_q_half_kernel
// Block // Block
int offset_n = blockIdx.x * BLOCK_KN_SIZE * 4; int offset_n = blockIdx.x * EXL2_BLOCK_KN_SIZE * 4;
int offset_m = blockIdx.y * m_count; int offset_m = blockIdx.y * m_count;
int offset_k = blockIdx.z * BLOCK_KN_SIZE; int offset_k = blockIdx.z * EXL2_BLOCK_KN_SIZE;
int end_n = min(offset_n + BLOCK_KN_SIZE * 4, size_n); int end_n = min(offset_n + EXL2_BLOCK_KN_SIZE * 4, size_n);
int end_m = min(offset_m + m_count, size_m); int end_m = min(offset_m + m_count, size_m);
int end_k = min(offset_k + BLOCK_KN_SIZE, size_k); int end_k = min(offset_k + EXL2_BLOCK_KN_SIZE, size_k);
int n = offset_n + t * 4; int n = offset_n + t * 4;
// Read weights
half_uint16 weights[MAX_Q_GEMM_WEIGHTS];
if constexpr (use_r_weights)
{
uint16_t any_w = 0;
const half* w_ptr = r_weights;
for (int m = 0; m < m_count; ++m)
{
weights[m].as_half = *w_ptr;
w_ptr += r_weights_stride;
any_w |= weights[m].as_uint16;
}
if (!any_w) return; // Early exit if all weights are zero -- does not zero output (!!!)
}
// Preload block_a // Preload block_a
__shared__ half block_a[m_count][BLOCK_KN_SIZE]; __shared__ half block_a[m_count][EXL2_BLOCK_KN_SIZE];
if (offset_k + t < end_k) if (offset_k + t < end_k)
{ {
@ -135,6 +193,7 @@ __global__ void gemm_half_q_half_kernel
const half* a_ptr = a_.item_ptr(offset_m + m, 0); const half* a_ptr = a_.item_ptr(offset_m + m, 0);
half* block_a_ptr = block_a[m]; half* block_a_ptr = block_a[m];
half a0 = a_ptr[b_q_perm[offset_k + t]]; half a0 = a_ptr[b_q_perm[offset_k + t]];
// half a0 = a_ptr[offset_k + t];
block_a_ptr[t] = a0; block_a_ptr[t] = a0;
} }
} }
@ -153,14 +212,19 @@ __global__ void gemm_half_q_half_kernel
// Find initial group // Find initial group
int group = offset_k / groupsize; //int group = offset_k / groupsize;
int group = b_q_group_map[offset_k * 2];
// if (offset_m == 0 && t == 0)
// DBGI2(offset_k, group);
// Preload scales // Preload scales
float scales[MAX_GROUPS_IN_BLOCK][4]; half scales[EXL2_MAX_GROUPS_IN_BLOCK][4];
int groups_in_block = DIVIDE((end_k - offset_k), groupsize); //int groups_in_block = DIVIDE((end_k - offset_k), groupsize);
for (int g = 0; g < groups_in_block; g++) int temp_k = offset_k;
for (int g = 0; temp_k < end_k; g++)
{ {
int qscales[4]; int qscales[4];
b_q_scale_.item4(qscales, group + g, n); b_q_scale_.item4(qscales, group + g, n);
@ -168,11 +232,12 @@ __global__ void gemm_half_q_half_kernel
qscales[1]++; qscales[1]++;
qscales[2]++; qscales[2]++;
qscales[3]++; qscales[3]++;
float maxscale = __half2float(b_q_scale_max[group + g]); half maxscale = b_q_scale_max[group + g];
scales[g][0] = __int2float_rn(qscales[0] * qscales[0]) * maxscale; scales[g][0] = __hmul(__int2half_rn(qscales[0] * qscales[0]), maxscale);
scales[g][1] = __int2float_rn(qscales[1] * qscales[1]) * maxscale; scales[g][1] = __hmul(__int2half_rn(qscales[1] * qscales[1]), maxscale);
scales[g][2] = __int2float_rn(qscales[2] * qscales[2]) * maxscale; scales[g][2] = __hmul(__int2half_rn(qscales[2] * qscales[2]), maxscale);
scales[g][3] = __int2float_rn(qscales[3] * qscales[3]) * maxscale; scales[g][3] = __hmul(__int2half_rn(qscales[3] * qscales[3]), maxscale);
temp_k += b_q_group_map[temp_k * 2 + 1];
} }
// a, b offset // a, b offset
@ -193,20 +258,20 @@ __global__ void gemm_half_q_half_kernel
const uint32_t* b_ptr = b_q_weight + qk * size_n + n; const uint32_t* b_ptr = b_q_weight + qk * size_n + n;
const half* a_ptr = &block_a[0][0]; const half* a_ptr = &block_a[0][0];
int a_stride = BLOCK_KN_SIZE; int a_stride = EXL2_BLOCK_KN_SIZE;
// Initial group // Initial group
int scales_idx = 0; int scales_idx = 0;
float qs_f0 = scales[scales_idx][0]; half qs_h0 = scales[scales_idx][0];
float qs_f1 = scales[scales_idx][1]; half qs_h1 = scales[scales_idx][1];
float qs_f2 = scales[scales_idx][2]; half qs_h2 = scales[scales_idx][2];
float qs_f3 = scales[scales_idx][3]; half qs_h3 = scales[scales_idx][3];
int nextgroup = offset_k + groupsize; int nextgroup = offset_k + b_q_group_map[offset_k * 2 + 1];
// Column result // Column result
float block_c[m_count][4] = {}; half block_c[m_count][4] = {};
// Dequantize groups // Dequantize groups
@ -218,11 +283,11 @@ __global__ void gemm_half_q_half_kernel
{ {
group++; group++;
scales_idx++; scales_idx++;
qs_f0 = scales[scales_idx][0]; qs_h0 = scales[scales_idx][0];
qs_f1 = scales[scales_idx][1]; qs_h1 = scales[scales_idx][1];
qs_f2 = scales[scales_idx][2]; qs_h2 = scales[scales_idx][2];
qs_f3 = scales[scales_idx][3]; qs_h3 = scales[scales_idx][3];
nextgroup += groupsize; nextgroup += b_q_group_map[k * 2 + 1];
} }
#pragma unroll #pragma unroll
@ -240,10 +305,11 @@ __global__ void gemm_half_q_half_kernel
for (int m = 0; m < m_count; m++) for (int m = 0; m < m_count; m++)
{ {
block_c[m][0] = dot22_8_f(dq[0], a_ptr + m * a_stride, block_c[m][0], qs_f0); if constexpr (use_r_weights) { if (!weights[m].as_uint16) continue; }
block_c[m][1] = dot22_8_f(dq[1], a_ptr + m * a_stride, block_c[m][1], qs_f1); block_c[m][0] = dot22_8_h(dq[0], a_ptr + m * a_stride, block_c[m][0], qs_h0);
block_c[m][2] = dot22_8_f(dq[2], a_ptr + m * a_stride, block_c[m][2], qs_f2); block_c[m][1] = dot22_8_h(dq[1], a_ptr + m * a_stride, block_c[m][1], qs_h1);
block_c[m][3] = dot22_8_f(dq[3], a_ptr + m * a_stride, block_c[m][3], qs_f3); block_c[m][2] = dot22_8_h(dq[2], a_ptr + m * a_stride, block_c[m][2], qs_h2);
block_c[m][3] = dot22_8_h(dq[3], a_ptr + m * a_stride, block_c[m][3], qs_h3);
} }
a_ptr += 8; a_ptr += 8;
} }
@ -256,11 +322,11 @@ __global__ void gemm_half_q_half_kernel
{ {
group++; group++;
scales_idx++; scales_idx++;
qs_f0 = scales[scales_idx][0]; qs_h0 = scales[scales_idx][0];
qs_f1 = scales[scales_idx][1]; qs_h1 = scales[scales_idx][1];
qs_f2 = scales[scales_idx][2]; qs_h2 = scales[scales_idx][2];
qs_f3 = scales[scales_idx][3]; qs_h3 = scales[scales_idx][3];
nextgroup += groupsize; nextgroup += b_q_group_map[k * 2 + 1];
} }
#pragma unroll #pragma unroll
@ -279,10 +345,11 @@ __global__ void gemm_half_q_half_kernel
for (int m = 0; m < m_count; m++) for (int m = 0; m < m_count; m++)
{ {
block_c[m][0] = dot22_16_f(dq[0], a_ptr + m * a_stride, block_c[m][0], qs_f0); if constexpr (use_r_weights) { if (!weights[m].as_uint16) continue; }
block_c[m][1] = dot22_16_f(dq[1], a_ptr + m * a_stride, block_c[m][1], qs_f1); block_c[m][0] = dot22_16_h(dq[0], a_ptr + m * a_stride, block_c[m][0], qs_h0);
block_c[m][2] = dot22_16_f(dq[2], a_ptr + m * a_stride, block_c[m][2], qs_f2); block_c[m][1] = dot22_16_h(dq[1], a_ptr + m * a_stride, block_c[m][1], qs_h1);
block_c[m][3] = dot22_16_f(dq[3], a_ptr + m * a_stride, block_c[m][3], qs_f3); block_c[m][2] = dot22_16_h(dq[2], a_ptr + m * a_stride, block_c[m][2], qs_h2);
block_c[m][3] = dot22_16_h(dq[3], a_ptr + m * a_stride, block_c[m][3], qs_h3);
} }
a_ptr += 16; a_ptr += 16;
} }
@ -295,11 +362,11 @@ __global__ void gemm_half_q_half_kernel
{ {
group++; group++;
scales_idx++; scales_idx++;
qs_f0 = scales[scales_idx][0]; qs_h0 = scales[scales_idx][0];
qs_f1 = scales[scales_idx][1]; qs_h1 = scales[scales_idx][1];
qs_f2 = scales[scales_idx][2]; qs_h2 = scales[scales_idx][2];
qs_f3 = scales[scales_idx][3]; qs_h3 = scales[scales_idx][3];
nextgroup += groupsize; nextgroup += b_q_group_map[k * 2 + 1];
} }
#pragma unroll #pragma unroll
@ -320,10 +387,11 @@ __global__ void gemm_half_q_half_kernel
for (int m = 0; m < m_count; m++) for (int m = 0; m < m_count; m++)
{ {
block_c[m][0] = dot22_32_f(dq[0], a_ptr + m * a_stride, block_c[m][0], qs_f0); if constexpr (use_r_weights) { if (!weights[m].as_uint16) continue; }
block_c[m][1] = dot22_32_f(dq[1], a_ptr + m * a_stride, block_c[m][1], qs_f1); block_c[m][0] = dot22_32_h(dq[0], a_ptr + m * a_stride, block_c[m][0], qs_h0);
block_c[m][2] = dot22_32_f(dq[2], a_ptr + m * a_stride, block_c[m][2], qs_f2); block_c[m][1] = dot22_32_h(dq[1], a_ptr + m * a_stride, block_c[m][1], qs_h1);
block_c[m][3] = dot22_32_f(dq[3], a_ptr + m * a_stride, block_c[m][3], qs_f3); block_c[m][2] = dot22_32_h(dq[2], a_ptr + m * a_stride, block_c[m][2], qs_h2);
block_c[m][3] = dot22_32_h(dq[3], a_ptr + m * a_stride, block_c[m][3], qs_h3);
} }
a_ptr += 32; a_ptr += 32;
} }
@ -337,11 +405,11 @@ __global__ void gemm_half_q_half_kernel
{ {
group++; group++;
scales_idx++; scales_idx++;
qs_f0 = scales[scales_idx][0]; qs_h0 = scales[scales_idx][0];
qs_f1 = scales[scales_idx][1]; qs_h1 = scales[scales_idx][1];
qs_f2 = scales[scales_idx][2]; qs_h2 = scales[scales_idx][2];
qs_f3 = scales[scales_idx][3]; qs_h3 = scales[scales_idx][3];
nextgroup += groupsize; nextgroup += b_q_group_map[k * 2 + 1];
} }
#pragma unroll #pragma unroll
@ -358,10 +426,11 @@ __global__ void gemm_half_q_half_kernel
for (int m = 0; m < m_count; m++) for (int m = 0; m < m_count; m++)
{ {
block_c[m][0] = dot22_8_f(dq[0], a_ptr + m * a_stride, block_c[m][0], qs_f0); if constexpr (use_r_weights) { if (!weights[m].as_uint16) continue; }
block_c[m][1] = dot22_8_f(dq[1], a_ptr + m * a_stride, block_c[m][1], qs_f1); block_c[m][0] = dot22_8_h(dq[0], a_ptr + m * a_stride, block_c[m][0], qs_h0);
block_c[m][2] = dot22_8_f(dq[2], a_ptr + m * a_stride, block_c[m][2], qs_f2); block_c[m][1] = dot22_8_h(dq[1], a_ptr + m * a_stride, block_c[m][1], qs_h1);
block_c[m][3] = dot22_8_f(dq[3], a_ptr + m * a_stride, block_c[m][3], qs_f3); block_c[m][2] = dot22_8_h(dq[2], a_ptr + m * a_stride, block_c[m][2], qs_h2);
block_c[m][3] = dot22_8_h(dq[3], a_ptr + m * a_stride, block_c[m][3], qs_h3);
} }
a_ptr += 8; a_ptr += 8;
} }
@ -374,11 +443,11 @@ __global__ void gemm_half_q_half_kernel
{ {
group++; group++;
scales_idx++; scales_idx++;
qs_f0 = scales[scales_idx][0]; qs_h0 = scales[scales_idx][0];
qs_f1 = scales[scales_idx][1]; qs_h1 = scales[scales_idx][1];
qs_f2 = scales[scales_idx][2]; qs_h2 = scales[scales_idx][2];
qs_f3 = scales[scales_idx][3]; qs_h3 = scales[scales_idx][3];
nextgroup += groupsize; nextgroup += b_q_group_map[k * 2 + 1];
} }
#pragma unroll #pragma unroll
@ -397,10 +466,11 @@ __global__ void gemm_half_q_half_kernel
for (int m = 0; m < m_count; m++) for (int m = 0; m < m_count; m++)
{ {
block_c[m][0] = dot22_32_f(dq[0], a_ptr + m * a_stride, block_c[m][0], qs_f0); if constexpr (use_r_weights) { if (!weights[m].as_uint16) continue; }
block_c[m][1] = dot22_32_f(dq[1], a_ptr + m * a_stride, block_c[m][1], qs_f1); block_c[m][0] = dot22_32_h(dq[0], a_ptr + m * a_stride, block_c[m][0], qs_h0);
block_c[m][2] = dot22_32_f(dq[2], a_ptr + m * a_stride, block_c[m][2], qs_f2); block_c[m][1] = dot22_32_h(dq[1], a_ptr + m * a_stride, block_c[m][1], qs_h1);
block_c[m][3] = dot22_32_f(dq[3], a_ptr + m * a_stride, block_c[m][3], qs_f3); block_c[m][2] = dot22_32_h(dq[2], a_ptr + m * a_stride, block_c[m][2], qs_h2);
block_c[m][3] = dot22_32_h(dq[3], a_ptr + m * a_stride, block_c[m][3], qs_h3);
} }
a_ptr += 32; a_ptr += 32;
} }
@ -413,15 +483,15 @@ __global__ void gemm_half_q_half_kernel
{ {
group++; group++;
scales_idx++; scales_idx++;
qs_f0 = scales[scales_idx][0]; qs_h0 = scales[scales_idx][0];
qs_f1 = scales[scales_idx][1]; qs_h1 = scales[scales_idx][1];
qs_f2 = scales[scales_idx][2]; qs_h2 = scales[scales_idx][2];
qs_f3 = scales[scales_idx][3]; qs_h3 = scales[scales_idx][3];
nextgroup += groupsize; nextgroup += b_q_group_map[k * 2 + 1];
} }
#pragma unroll #pragma unroll
for (int j = 0; j < 2; j++) for (int j = 0; j < 1; j++)
{ {
int4 load_int4[1]; int4 load_int4[1];
load_int4[0] = *((int4*) b_ptr); b_ptr += size_n; load_int4[0] = *((int4*) b_ptr); b_ptr += size_n;
@ -434,15 +504,16 @@ __global__ void gemm_half_q_half_kernel
for (int m = 0; m < m_count; m++) for (int m = 0; m < m_count; m++)
{ {
block_c[m][0] = dot22_16_f(dq[0], a_ptr + m * a_stride, block_c[m][0], qs_f0); if constexpr (use_r_weights) { if (!weights[m].as_uint16) continue; }
block_c[m][1] = dot22_16_f(dq[1], a_ptr + m * a_stride, block_c[m][1], qs_f1); block_c[m][0] = dot22_16_h(dq[0], a_ptr + m * a_stride, block_c[m][0], qs_h0);
block_c[m][2] = dot22_16_f(dq[2], a_ptr + m * a_stride, block_c[m][2], qs_f2); block_c[m][1] = dot22_16_h(dq[1], a_ptr + m * a_stride, block_c[m][1], qs_h1);
block_c[m][3] = dot22_16_f(dq[3], a_ptr + m * a_stride, block_c[m][3], qs_f3); block_c[m][2] = dot22_16_h(dq[2], a_ptr + m * a_stride, block_c[m][2], qs_h2);
block_c[m][3] = dot22_16_h(dq[3], a_ptr + m * a_stride, block_c[m][3], qs_h3);
} }
a_ptr += 16; a_ptr += 16;
} }
k += 32; k += 16;
} }
// Accumulate column sums in c // Accumulate column sums in c
@ -450,38 +521,60 @@ __global__ void gemm_half_q_half_kernel
for (int m = 0; m < m_count; m++) for (int m = 0; m < m_count; m++)
{ {
half2* out = (half2*)c_.item_ptr(offset_m + m, n); half2* out = (half2*)c_.item_ptr(offset_m + m, n);
half2 result01 = __halves2half2(__float2half_rn(block_c[m][0]), __float2half_rn(block_c[m][1])); half2 result01 = __halves2half2(block_c[m][0], block_c[m][1]);
half2 result23 = __halves2half2(__float2half_rn(block_c[m][2]), __float2half_rn(block_c[m][3])); half2 result23 = __halves2half2(block_c[m][2], block_c[m][3]);
if constexpr (mul_r_weights)
{
half2 w_mul2 = __half2half2(weights[m].as_half);
result01 = __hmul2(result01, w_mul2);
result23 = __hmul2(result23, w_mul2);
}
atomicAdd(out , result01); atomicAdd(out , result01);
atomicAdd(out + 1, result23); atomicAdd(out + 1, result23);
// *out = result01;
// *(out + 1) = result23;
} }
} }
fp_gemm_half_q_half_kernel pick_gemm_half_q_half_kernel(bool first_block, const int m_count) template <bool use_r_weights, bool mul_r_weights>
struct map_m_count_exl2 {
static constexpr fp_gemm_half_q_half_kernel pick_gemm_half_q_half_kernel(const int m_count)
{
#if EXL2_BLOCK_M_SIZE_MAX >= 1
if (m_count == 1) return gemm_half_q_half_kernel<1, use_r_weights, mul_r_weights>;
#endif
#if EXL2_BLOCK_M_SIZE_MAX >= 2
if (m_count == 2) return gemm_half_q_half_kernel<2, use_r_weights, mul_r_weights>;
#endif
#if EXL2_BLOCK_M_SIZE_MAX >= 3
if (m_count == 3) return gemm_half_q_half_kernel<3, use_r_weights, mul_r_weights>;
#endif
#if EXL2_BLOCK_M_SIZE_MAX >= 4
if (m_count == 4) return gemm_half_q_half_kernel<4, use_r_weights, mul_r_weights>;
#endif
#if EXL2_BLOCK_M_SIZE_MAX >= 5
if (m_count == 5) return gemm_half_q_half_kernel<5, use_r_weights, mul_r_weights>;
#endif
#if EXL2_BLOCK_M_SIZE_MAX >= 6
if (m_count == 6) return gemm_half_q_half_kernel<6, use_r_weights, mul_r_weights>;
#endif
#if EXL2_BLOCK_M_SIZE_MAX >= 7
if (m_count == 7) return gemm_half_q_half_kernel<7, use_r_weights, mul_r_weights>;
#endif
#if EXL2_BLOCK_M_SIZE_MAX >= 8
if (m_count == 8) return gemm_half_q_half_kernel<8, use_r_weights, mul_r_weights>;
#endif
return NULL;
}
};
fp_gemm_half_q_half_kernel pick_gemm_half_q_half_kernel(const int m_count, bool r_weights, bool mul_r_weights)
{ {
#if BLOCK_M_SIZE_MAX >= 1 if (!r_weights && !mul_r_weights) return map_m_count_exl2<false, false>::pick_gemm_half_q_half_kernel(m_count);
if (m_count == 1) return gemm_half_q_half_kernel<true, 1>; if (!r_weights && mul_r_weights) return map_m_count_exl2<false, true>::pick_gemm_half_q_half_kernel(m_count);
#endif if ( r_weights && !mul_r_weights) return map_m_count_exl2< true, false>::pick_gemm_half_q_half_kernel(m_count);
#if BLOCK_M_SIZE_MAX >= 2 if ( r_weights && mul_r_weights) return map_m_count_exl2< true, true>::pick_gemm_half_q_half_kernel(m_count);
if (m_count == 2) return gemm_half_q_half_kernel<true, 2>;
#endif
#if BLOCK_M_SIZE_MAX >= 3
if (m_count == 3) return gemm_half_q_half_kernel<true, 3>;
#endif
#if BLOCK_M_SIZE_MAX >= 4
if (m_count == 4) return gemm_half_q_half_kernel<true, 4>;
#endif
#if BLOCK_M_SIZE_MAX >= 5
if (m_count == 5) return gemm_half_q_half_kernel<true, 5>;
#endif
#if BLOCK_M_SIZE_MAX >= 6
if (m_count == 6) return gemm_half_q_half_kernel<true, 6>;
#endif
#if BLOCK_M_SIZE_MAX >= 7
if (m_count == 7) return gemm_half_q_half_kernel<true, 7>;
#endif
#if BLOCK_M_SIZE_MAX >= 8
if (m_count == 8) return gemm_half_q_half_kernel<true, 8>;
#endif
return NULL; return NULL;
} }
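A rough numpy sketch of what the new dot22_8_h helper above computes for one 8-element slice, assuming dq holds unscaled weights in the -128..127 range and a holds the matching activations. The point of the change is the float32 accumulator: only the final scaled value is rounded back to fp16 before being added to the running column result.

import numpy as np

def dot22_8_h_sketch(dq: np.ndarray, a: np.ndarray, g_result: np.float16, qs_h: np.float16) -> np.float16:
    acc = np.float32(0)
    for w, x in zip(dq.astype(np.float32), a.astype(np.float32)):
        acc += w * x                                  # fma in fp32, avoids fp16 overflow
    acc *= np.float32(qs_h)                           # apply the group scale
    return np.float16(np.float16(acc) + g_result)     # add previous partial result in fp16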

View File

@ -18,6 +18,15 @@ __forceinline__ __device__ float dot22_8_f(half2(&dq)[4], const half* a_ptr)
return __half2float(__low2half(result)) + __half2float(__high2half(result)); return __half2float(__low2half(result)) + __half2float(__high2half(result));
} }
__forceinline__ __device__ half2 dot22_8_h2(half2(&dq)[4], const half* a_ptr)
{
half2 result = {};
const half2* a2_ptr = (const half2*)a_ptr;
#pragma unroll
for (int i = 0; i < 4; i++) result = __hfma2(dq[i], *a2_ptr++, result);
return result;
}
typedef void (*fp_gemm_half_q_half_gptq_kernel) typedef void (*fp_gemm_half_q_half_gptq_kernel)
( (
const half*, const half*,
@ -32,10 +41,12 @@ typedef void (*fp_gemm_half_q_half_gptq_kernel)
const int, const int,
const uint16_t*, const uint16_t*,
const int, const int,
const bool const bool,
const half*,
const int
); );
template <bool first_block, int m_count> template <int m_count, bool use_r_weights, bool mul_r_weights>
__global__ void gemm_half_q_half_gptq_kernel __global__ void gemm_half_q_half_gptq_kernel
( (
const half* __restrict__ a, const half* __restrict__ a,
@ -50,7 +61,9 @@ __global__ void gemm_half_q_half_gptq_kernel
const int groupsize, const int groupsize,
const uint16_t* __restrict__ b_q_perm, const uint16_t* __restrict__ b_q_perm,
const int rows_4, const int rows_4,
const bool clear const bool clear,
const half* r_weights,
const int r_weights_stride
) )
{ {
MatrixView_half a_(a, size_m, size_k); MatrixView_half a_(a, size_m, size_k);
@ -62,19 +75,35 @@ __global__ void gemm_half_q_half_gptq_kernel
// Block // Block
int offset_n = blockIdx.x * BLOCK_KN_SIZE * 4; int offset_n = blockIdx.x * GPTQ_BLOCK_KN_SIZE * 4;
int offset_m = blockIdx.y * m_count; int offset_m = blockIdx.y * m_count;
int offset_k = blockIdx.z * BLOCK_KN_SIZE; int offset_k = blockIdx.z * GPTQ_BLOCK_KN_SIZE;
int end_n = min(offset_n + BLOCK_KN_SIZE * 4, size_n); int end_n = min(offset_n + GPTQ_BLOCK_KN_SIZE * 4, size_n);
int end_m = min(offset_m + m_count, size_m); int end_m = min(offset_m + m_count, size_m);
int end_k = min(offset_k + BLOCK_KN_SIZE, size_k); int end_k = min(offset_k + GPTQ_BLOCK_KN_SIZE, size_k);
int n = offset_n + t * 4; int n = offset_n + t * 4;
// Read weights
half_uint16 weights[MAX_Q_GEMM_WEIGHTS];
if constexpr (use_r_weights)
{
uint16_t any_w = 0;
const half* w_ptr = r_weights;
for (int m = 0; m < m_count; ++m)
{
weights[m].as_half = *w_ptr;
w_ptr += r_weights_stride;
any_w |= weights[m].as_uint16;
}
if (!any_w) return; // Early exit if all weights are zero -- does not zero output (!!!)
}
// Preload block_a // Preload block_a
__shared__ half block_a[m_count][BLOCK_KN_SIZE]; __shared__ half block_a[m_count][GPTQ_BLOCK_KN_SIZE];
if (offset_k + t < end_k) if (offset_k + t < end_k)
{ {
@ -113,16 +142,16 @@ __global__ void gemm_half_q_half_gptq_kernel
const uint32_t* b_ptr = b_q_weight + qk * size_n + n; const uint32_t* b_ptr = b_q_weight + qk * size_n + n;
const half* a_ptr = &block_a[0][0]; const half* a_ptr = &block_a[0][0];
int a_stride = BLOCK_KN_SIZE; int a_stride = GPTQ_BLOCK_KN_SIZE;
// Initial group // Initial group
int zeros[4]; int zeros[4];
float scales[4]; half2 scales[4];
half2 z1z16[4][2]; half2 z1z16[4][2];
half2 y1y16[4][2]; half2 y1y16[4][2];
b_gptq_qzeros_.item4(zeros, group, n); b_gptq_qzeros_.item4(zeros, group, n);
b_gptq_scales_.item4_f(scales, group, n); b_gptq_scales_.item4_h2(scales, group, n);
dequant_4bit_8_prep_zero(zeros[0] + 1, z1z16[0], y1y16[0]); dequant_4bit_8_prep_zero(zeros[0] + 1, z1z16[0], y1y16[0]);
dequant_4bit_8_prep_zero(zeros[1] + 1, z1z16[1], y1y16[1]); dequant_4bit_8_prep_zero(zeros[1] + 1, z1z16[1], y1y16[1]);
dequant_4bit_8_prep_zero(zeros[2] + 1, z1z16[2], y1y16[2]); dequant_4bit_8_prep_zero(zeros[2] + 1, z1z16[2], y1y16[2]);
@ -132,7 +161,7 @@ __global__ void gemm_half_q_half_gptq_kernel
// Column result // Column result
float block_c[m_count][4] = {}; half2 block_c[m_count][4] = {};
// Dequantize and multiply // Dequantize and multiply
@ -144,7 +173,7 @@ __global__ void gemm_half_q_half_gptq_kernel
group++; group++;
nextgroup += groupsize; nextgroup += groupsize;
b_gptq_qzeros_.item4(zeros, group, n); b_gptq_qzeros_.item4(zeros, group, n);
b_gptq_scales_.item4_f(scales, group, n); b_gptq_scales_.item4_h2(scales, group, n);
dequant_4bit_8_prep_zero(zeros[0] + 1, z1z16[0], y1y16[0]); dequant_4bit_8_prep_zero(zeros[0] + 1, z1z16[0], y1y16[0]);
dequant_4bit_8_prep_zero(zeros[1] + 1, z1z16[1], y1y16[1]); dequant_4bit_8_prep_zero(zeros[1] + 1, z1z16[1], y1y16[1]);
dequant_4bit_8_prep_zero(zeros[2] + 1, z1z16[2], y1y16[2]); dequant_4bit_8_prep_zero(zeros[2] + 1, z1z16[2], y1y16[2]);
@ -166,10 +195,11 @@ __global__ void gemm_half_q_half_gptq_kernel
#pragma unroll #pragma unroll
for (int m = 0; m < m_count; m++) for (int m = 0; m < m_count; m++)
{ {
block_c[m][0] = fma(dot22_8_f(dq[0], a_ptr + m * a_stride), scales[0], block_c[m][0]); if constexpr (use_r_weights) { if (!weights[m].as_uint16) continue; }
block_c[m][1] = fma(dot22_8_f(dq[1], a_ptr + m * a_stride), scales[1], block_c[m][1]); block_c[m][0] = __hfma2(dot22_8_h2(dq[0], a_ptr + m * a_stride), scales[0], block_c[m][0]);
block_c[m][2] = fma(dot22_8_f(dq[2], a_ptr + m * a_stride), scales[2], block_c[m][2]); block_c[m][1] = __hfma2(dot22_8_h2(dq[1], a_ptr + m * a_stride), scales[1], block_c[m][1]);
block_c[m][3] = fma(dot22_8_f(dq[3], a_ptr + m * a_stride), scales[3], block_c[m][3]); block_c[m][2] = __hfma2(dot22_8_h2(dq[2], a_ptr + m * a_stride), scales[2], block_c[m][2]);
block_c[m][3] = __hfma2(dot22_8_h2(dq[3], a_ptr + m * a_stride), scales[3], block_c[m][3]);
} }
b_ptr += size_n; b_ptr += size_n;
@ -182,38 +212,62 @@ __global__ void gemm_half_q_half_gptq_kernel
for (int m = 0; m < m_count; m++) for (int m = 0; m < m_count; m++)
{ {
half2 *out = (half2*) c_.item_ptr(offset_m + m, n); half2 *out = (half2*) c_.item_ptr(offset_m + m, n);
half2 result01 = __halves2half2(__float2half_rn(block_c[m][0]), __float2half_rn(block_c[m][1])); half result0 = __hadd(__low2half(block_c[m][0]), __high2half(block_c[m][0]));
half2 result23 = __halves2half2(__float2half_rn(block_c[m][2]), __float2half_rn(block_c[m][3])); half result1 = __hadd(__low2half(block_c[m][1]), __high2half(block_c[m][1]));
half result2 = __hadd(__low2half(block_c[m][2]), __high2half(block_c[m][2]));
half result3 = __hadd(__low2half(block_c[m][3]), __high2half(block_c[m][3]));
half2 result01 = __halves2half2(result0, result1);
half2 result23 = __halves2half2(result2, result3);
if constexpr (mul_r_weights)
{
half2 w_mul2 = __half2half2(weights[m].as_half);
result01 = __hmul2(result01, w_mul2);
result23 = __hmul2(result23, w_mul2);
}
atomicAdd(out , result01); atomicAdd(out , result01);
atomicAdd(out + 1, result23); atomicAdd(out + 1, result23);
} }
} }
fp_gemm_half_q_half_gptq_kernel pick_gemm_half_q_half_gptq_kernel(bool first_block, const int m_count) template <bool use_r_weights, bool mul_r_weights>
struct map_m_count_gptq {
static constexpr fp_gemm_half_q_half_gptq_kernel pick_gemm_half_q_half_gptq_kernel(int m_count)
{
#if GPTQ_BLOCK_M_SIZE_MAX >= 1
if (m_count == 1) return gemm_half_q_half_gptq_kernel<1, use_r_weights, mul_r_weights>;
#endif
#if GPTQ_BLOCK_M_SIZE_MAX >= 2
if (m_count == 2) return gemm_half_q_half_gptq_kernel<2, use_r_weights, mul_r_weights>;
#endif
#if GPTQ_BLOCK_M_SIZE_MAX >= 3
if (m_count == 3) return gemm_half_q_half_gptq_kernel<3, use_r_weights, mul_r_weights>;
#endif
#if GPTQ_BLOCK_M_SIZE_MAX >= 4
if (m_count == 4) return gemm_half_q_half_gptq_kernel<4, use_r_weights, mul_r_weights>;
#endif
#if GPTQ_BLOCK_M_SIZE_MAX >= 5
if (m_count == 5) return gemm_half_q_half_gptq_kernel<5, use_r_weights, mul_r_weights>;
#endif
#if GPTQ_BLOCK_M_SIZE_MAX >= 6
if (m_count == 6) return gemm_half_q_half_gptq_kernel<6, use_r_weights, mul_r_weights>;
#endif
#if GPTQ_BLOCK_M_SIZE_MAX >= 7
if (m_count == 7) return gemm_half_q_half_gptq_kernel<7, use_r_weights, mul_r_weights>;
#endif
#if GPTQ_BLOCK_M_SIZE_MAX >= 8
if (m_count == 8) return gemm_half_q_half_gptq_kernel<8, use_r_weights, mul_r_weights>;
#endif
return NULL;
}
};
fp_gemm_half_q_half_gptq_kernel pick_gemm_half_q_half_gptq_kernel(const int m_count, bool r_weights, bool mul_r_weights)
{ {
#if BLOCK_M_SIZE_MAX >= 1 if (!r_weights && !mul_r_weights) return map_m_count_gptq<false, false>::pick_gemm_half_q_half_gptq_kernel(m_count);
if (m_count == 1) return gemm_half_q_half_gptq_kernel<true, 1>; if (!r_weights && mul_r_weights) return map_m_count_gptq<false, true>::pick_gemm_half_q_half_gptq_kernel(m_count);
#endif if ( r_weights && !mul_r_weights) return map_m_count_gptq< true, false>::pick_gemm_half_q_half_gptq_kernel(m_count);
#if BLOCK_M_SIZE_MAX >= 2 if ( r_weights && mul_r_weights) return map_m_count_gptq< true, true>::pick_gemm_half_q_half_gptq_kernel(m_count);
if (m_count == 2) return gemm_half_q_half_gptq_kernel<true, 2>;
#endif
#if BLOCK_M_SIZE_MAX >= 3
if (m_count == 3) return gemm_half_q_half_gptq_kernel<true, 3>;
#endif
#if BLOCK_M_SIZE_MAX >= 4
if (m_count == 4) return gemm_half_q_half_gptq_kernel<true, 4>;
#endif
#if BLOCK_M_SIZE_MAX >= 5
if (m_count == 5) return gemm_half_q_half_gptq_kernel<true, 5>;
#endif
#if BLOCK_M_SIZE_MAX >= 6
if (m_count == 6) return gemm_half_q_half_gptq_kernel<true, 6>;
#endif
#if BLOCK_M_SIZE_MAX >= 7
if (m_count == 7) return gemm_half_q_half_gptq_kernel<true, 7>;
#endif
#if BLOCK_M_SIZE_MAX >= 8
if (m_count == 8) return gemm_half_q_half_gptq_kernel<true, 8>;
#endif
return NULL; return NULL;
} }
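A plain-numpy illustration of the row-weight path shared by both kernels above (assumption-level sketch, not the launch code): each of the m_count rows carries an optional weight, rows whose weight is exactly zero are skipped, and with mul_r_weights the row's partial result is multiplied by its weight before accumulation.

import numpy as np

def weighted_gemm_rows(a, b_deq, r_weights=None, mul_r_weights=False):
    c = np.zeros((a.shape[0], b_deq.shape[1]), dtype=np.float16)
    for m in range(a.shape[0]):
        if r_weights is not None and r_weights[m] == 0:
            continue                      # early exit: in the real kernel the output row is left untouched, not zeroed
        row = a[m].astype(np.float32) @ b_deq.astype(np.float32)
        if r_weights is not None and mul_r_weights:
            row *= float(r_weights[m])    # scale the row result by its weight
        c[m] += row.astype(np.float16)
    return c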

View File

@ -57,6 +57,7 @@ QMatrix::QMatrix
uint32_t* _q_scale, uint32_t* _q_scale,
half* _q_scale_max, half* _q_scale_max,
uint16_t* _q_groups, uint16_t* _q_groups,
uint16_t* _q_group_map,
uint32_t* _gptq_qzeros, uint32_t* _gptq_qzeros,
half* _gptq_scales, half* _gptq_scales,
@ -80,13 +81,17 @@ QMatrix::QMatrix
cuda_q_scale = _q_scale; cuda_q_scale = _q_scale;
cuda_q_scale_max = _q_scale_max; cuda_q_scale_max = _q_scale_max;
cuda_q_groups = _q_groups; cuda_q_groups = _q_groups;
cuda_q_group_map = _q_group_map;
cuda_gptq_qzeros = _gptq_qzeros; cuda_gptq_qzeros = _gptq_qzeros;
cuda_gptq_scales = _gptq_scales; cuda_gptq_scales = _gptq_scales;
is_gptq = (_gptq_qzeros != NULL); is_gptq = (_gptq_qzeros != NULL);
groupsize = 1; if (is_gptq)
while (groupsize * groups < height) groupsize *= 2; {
gptq_groupsize = 1;
while (gptq_groupsize * groups < height) gptq_groupsize *= 2;
}
// Create group map // Create group map
@ -102,15 +107,26 @@ QMatrix::QMatrix
uint16_t* cpu_q_groups = (uint16_t*)calloc(groups * 2, sizeof(uint16_t)); uint16_t* cpu_q_groups = (uint16_t*)calloc(groups * 2, sizeof(uint16_t));
cudaMemcpy(cpu_q_groups, cuda_q_groups, groups * 2 * sizeof(uint16_t), cudaMemcpyDeviceToHost); cudaMemcpy(cpu_q_groups, cuda_q_groups, groups * 2 * sizeof(uint16_t), cudaMemcpyDeviceToHost);
int row = 0;
for (int i = 0; i < groups; i++) for (int i = 0; i < groups; i++)
{ {
int bits = cpu_q_groups[i * 2]; int bits = cpu_q_groups[i * 2];
if (bits == 8) rows_8 += groupsize;
if (bits == 6) rows_6 += groupsize; int rows;
if (bits == 5) rows_5 += groupsize; if (i < groups - 1)
if (bits == 4) rows_4 += groupsize; {
if (bits == 3) rows_3 += groupsize; int qrows = cpu_q_groups[i * 2 + 3] - cpu_q_groups[i * 2 + 1];
if (bits == 2) rows_2 += groupsize; rows = qrows * 32 / bits;
}
else rows = height - row;
if (bits == 8) rows_8 += rows;
if (bits == 6) rows_6 += rows;
if (bits == 5) rows_5 += rows;
if (bits == 4) rows_4 += rows;
if (bits == 3) rows_3 += rows;
if (bits == 2) rows_2 += rows;
row += rows;
} }
free(cpu_q_groups); free(cpu_q_groups);
@ -138,6 +154,13 @@ QMatrix::QMatrix
} }
} }
// DBGI(rows_8);
// DBGI(rows_6);
// DBGI(rows_5);
// DBGI(rows_4);
// DBGI(rows_3);
// DBGI(rows_2);
// Shuffle quantized data // Shuffle quantized data
dim3 blockDim, gridDim; dim3 blockDim, gridDim;
@ -283,10 +306,10 @@ __global__ void reconstruct_kernel
const uint16_t* __restrict__ b_q_perm, const uint16_t* __restrict__ b_q_perm,
const uint32_t* __restrict__ b_q_scale, const uint32_t* __restrict__ b_q_scale,
const half* __restrict__ b_q_scale_max, const half* __restrict__ b_q_scale_max,
//const uint16_t* __restrict__ b_q_groups, const uint16_t* __restrict__ b_q_group_map,
const int size_k, const int size_k,
const int size_n, const int size_n,
const int groupsize, //const int groupsize,
const int groups, const int groups,
half* __restrict__ b, half* __restrict__ b,
const int rows_8, const int rows_8,
@ -317,7 +340,8 @@ __global__ void reconstruct_kernel
// Find initial group // Find initial group
int group = offset_k / groupsize; // int group = offset_k / groupsize;
int group = b_q_group_map[offset_k * 2];
int pre_rows_8 = min(rows_8, offset_k); int pre_rows_8 = min(rows_8, offset_k);
int pre_rows_6 = offset_k > rows_8 ? min(rows_6, offset_k) - rows_8 : 0; int pre_rows_6 = offset_k > rows_8 ? min(rows_6, offset_k) - rows_8 : 0;
@ -337,7 +361,7 @@ __global__ void reconstruct_kernel
half qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]); half qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]);
half2 qs_h2 = __halves2half2(qs_h, qs_h); half2 qs_h2 = __halves2half2(qs_h, qs_h);
int nextgroup = offset_k + groupsize; int nextgroup = offset_k + b_q_group_map[offset_k * 2 + 1];
int end_k = min(offset_k + BLOCK_KN_SIZE, size_k); int end_k = min(offset_k + BLOCK_KN_SIZE, size_k);
int k = offset_k; int k = offset_k;
@ -347,7 +371,7 @@ __global__ void reconstruct_kernel
while (k < rows_8 && k < end_k) while (k < rows_8 && k < end_k)
{ {
if (k == nextgroup) { group++; qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]); nextgroup += groupsize; qs_h2 = __halves2half2(qs_h, qs_h); } if (k == nextgroup) { group++; qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]); nextgroup += b_q_group_map[k * 2 + 1]; qs_h2 = __halves2half2(qs_h, qs_h); }
for (int p = 0; p < 4; p++) for (int p = 0; p < 4; p++)
{ {
half2 dq[4]; half2 dq[4];
@ -363,7 +387,7 @@ __global__ void reconstruct_kernel
while (k < rows_6 && k < end_k) while (k < rows_6 && k < end_k)
{ {
if (k == nextgroup) { group++; qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]); nextgroup += groupsize; qs_h2 = __halves2half2(qs_h, qs_h); } if (k == nextgroup) { group++; qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]); nextgroup += b_q_group_map[k * 2 + 1]; qs_h2 = __halves2half2(qs_h, qs_h); }
for (int p = 0; p < 2; p++) for (int p = 0; p < 2; p++)
{ {
half2 dq[8]; half2 dq[8];
@ -380,7 +404,7 @@ __global__ void reconstruct_kernel
while (k < rows_5 && k < end_k) while (k < rows_5 && k < end_k)
{ {
if (k == nextgroup) { group++; qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]); nextgroup += groupsize; qs_h2 = __halves2half2(qs_h, qs_h); } if (k == nextgroup) { group++; qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]); nextgroup += b_q_group_map[k * 2 + 1]; qs_h2 = __halves2half2(qs_h, qs_h); }
for (int p = 0; p < 1; p++) for (int p = 0; p < 1; p++)
{ {
half2 dq[16]; half2 dq[16];
@ -399,7 +423,7 @@ __global__ void reconstruct_kernel
while (k < rows_4 && k < end_k) while (k < rows_4 && k < end_k)
{ {
if (k == nextgroup) { group++; qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]); nextgroup += groupsize; qs_h2 = __halves2half2(qs_h, qs_h); } if (k == nextgroup) { group++; qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]); nextgroup += b_q_group_map[k * 2 + 1]; qs_h2 = __halves2half2(qs_h, qs_h); }
for (int p = 0; p < 4; p++) for (int p = 0; p < 4; p++)
{ {
half2 dq[4]; half2 dq[4];
@ -414,7 +438,7 @@ __global__ void reconstruct_kernel
while (k < rows_3 && k < end_k) while (k < rows_3 && k < end_k)
{ {
if (k == nextgroup) { group++; qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]); nextgroup += groupsize; qs_h2 = __halves2half2(qs_h, qs_h); } if (k == nextgroup) { group++; qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]); nextgroup += b_q_group_map[k * 2 + 1]; qs_h2 = __halves2half2(qs_h, qs_h); }
for (int p = 0; p < 1; p++) for (int p = 0; p < 1; p++)
{ {
half2 dq[16]; half2 dq[16];
@ -431,8 +455,8 @@ __global__ void reconstruct_kernel
while (k < rows_2 && k < end_k) while (k < rows_2 && k < end_k)
{ {
if (k == nextgroup) { group++; qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]); nextgroup += groupsize; qs_h2 = __halves2half2(qs_h, qs_h); } if (k == nextgroup) { group++; qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]); nextgroup += b_q_group_map[k * 2 + 1]; qs_h2 = __halves2half2(qs_h, qs_h); }
for (int p = 0; p < 2; p++) for (int p = 0; p < 1; p++)
{ {
half2 dq[8]; half2 dq[8];
uint32_t q_0 = *b_ptr; b_ptr += size_n; uint32_t q_0 = *b_ptr; b_ptr += size_n;
@ -441,7 +465,7 @@ __global__ void reconstruct_kernel
half* dqh = (half*) dq; half* dqh = (half*) dq;
for (int j = 0; j < 16; j++) b_.set(perm[lk++], n, dqh[j]); for (int j = 0; j < 16; j++) b_.set(perm[lk++], n, dqh[j]);
} }
k += 32; k += 16;
} }
} }
@ -461,10 +485,10 @@ void QMatrix::reconstruct(half* out)
cuda_q_perm, cuda_q_perm,
cuda_q_scale, cuda_q_scale,
cuda_q_scale_max, cuda_q_scale_max,
//cuda_q_groups, cuda_q_group_map,
height, height,
width, width,
groupsize, //groupsize,
groups, groups,
out, out,
rows_8, rows_8,
@ -487,7 +511,7 @@ void QMatrix::reconstruct(half* out)
//const uint16_t* __restrict__ b_q_groups, //const uint16_t* __restrict__ b_q_groups,
height, height,
width, width,
groupsize, gptq_groupsize,
groups, groups,
out, out,
rows_4 rows_4

View File

@ -18,7 +18,7 @@ public:
int height; int height;
int width; int width;
int groups; int groups;
int groupsize; int gptq_groupsize;
int rows_8; int rows_8;
int rows_6; int rows_6;
@ -33,6 +33,7 @@ public:
uint32_t* cuda_q_scale = NULL; uint32_t* cuda_q_scale = NULL;
half* cuda_q_scale_max = NULL; half* cuda_q_scale_max = NULL;
uint16_t* cuda_q_groups = NULL; uint16_t* cuda_q_groups = NULL;
uint16_t* cuda_q_group_map = NULL;
uint32_t* cuda_gptq_qzeros = NULL; uint32_t* cuda_gptq_qzeros = NULL;
half* cuda_gptq_scales = NULL; half* cuda_gptq_scales = NULL;
@ -53,6 +54,7 @@ public:
uint32_t* _q_scale, uint32_t* _q_scale,
half* _q_scale_max, half* _q_scale_max,
uint16_t* _q_groups, uint16_t* _q_groups,
uint16_t* _q_group_map,
uint32_t* _gptq_qzeros, uint32_t* _gptq_qzeros,
half* _gptq_scales, half* _gptq_scales,

View File

@ -7,6 +7,7 @@ union half2_uint32
half2 as_half2; half2 as_half2;
__device__ half2_uint32(uint32_t val) : as_uint32(val) {} __device__ half2_uint32(uint32_t val) : as_uint32(val) {}
__device__ half2_uint32(half2 val) : as_half2(val) {} __device__ half2_uint32(half2 val) : as_half2(val) {}
__device__ half2_uint32() : as_uint32(0) {}
}; };
union half_uint16 union half_uint16
@ -15,6 +16,7 @@ union half_uint16
half as_half; half as_half;
__device__ half_uint16(uint16_t val) : as_uint16(val) {} __device__ half_uint16(uint16_t val) : as_uint16(val) {}
__device__ half_uint16(half val) : as_half(val) {} __device__ half_uint16(half val) : as_half(val) {}
__device__ half_uint16() : as_uint16(0) {}
}; };
// Max_scale premultiplied by 1/256 // Max_scale premultiplied by 1/256
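A small numpy sketch of why the half_uint16 union is convenient for the any_w check in the kernels: an fp16 value (other than negative zero) is zero exactly when its 16-bit pattern is zero, so OR-ing the raw bits of the per-row weights is a cheap "all weights are zero" test.

import numpy as np

w = np.array([0.5, 0.0], dtype=np.float16)
bits = w.view(np.uint16)        # same storage, read as integers
assert bits[1] == 0             # 0.0 has an all-zero bit pattern
assert bits[0] != 0             # any non-zero weight sets some bits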

View File

@ -1,3 +1,11 @@
#ifndef _util_cuh
#define _util_cuh
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cstdint>
#include <cstdio>
#include <ATen/cuda/CUDAContext.h>
#define DIVIDE(x, size) (((x) + (size) - 1) / (size)) #define DIVIDE(x, size) (((x) + (size) - 1) / (size))
@ -40,3 +48,7 @@ inline void gpu_assert(cudaError_t code, const char *file, int line, bool abort=
if (abort) exit(code); if (abort) exit(code);
} }
} }
void print_global_mem(const half* ptr, int rows, int columns, int stride);
#endif

View File

@ -31,6 +31,7 @@ uintptr_t make_q_matrix
torch::Tensor q_scale, torch::Tensor q_scale,
torch::Tensor q_scale_max, torch::Tensor q_scale_max,
torch::Tensor q_groups, torch::Tensor q_groups,
torch::Tensor q_group_map,
torch::Tensor gptq_qzeros, torch::Tensor gptq_qzeros,
torch::Tensor gptq_scales, torch::Tensor gptq_scales,
torch::Tensor gptq_g_idx, torch::Tensor gptq_g_idx,
@ -43,6 +44,7 @@ uintptr_t make_q_matrix
TORCH_CHECK_DTYPE_OPT(q_scale, kInt); TORCH_CHECK_DTYPE_OPT(q_scale, kInt);
TORCH_CHECK_DTYPE_OPT(q_scale_max, kHalf); TORCH_CHECK_DTYPE_OPT(q_scale_max, kHalf);
TORCH_CHECK_DTYPE_OPT(q_groups, kShort); TORCH_CHECK_DTYPE_OPT(q_groups, kShort);
TORCH_CHECK_DTYPE_OPT(q_group_map, kShort);
TORCH_CHECK_DTYPE_OPT(gptq_qzeros, kInt); TORCH_CHECK_DTYPE_OPT(gptq_qzeros, kInt);
TORCH_CHECK_DTYPE_OPT(gptq_scales, kHalf); TORCH_CHECK_DTYPE_OPT(gptq_scales, kHalf);
TORCH_CHECK_DTYPE_OPT(gptq_g_idx, kInt); TORCH_CHECK_DTYPE_OPT(gptq_g_idx, kInt);
@ -83,12 +85,15 @@ uintptr_t make_q_matrix
q_scale.device().is_meta() ? NULL : (uint32_t*) q_scale.data_ptr(), q_scale.device().is_meta() ? NULL : (uint32_t*) q_scale.data_ptr(),
q_scale_max.device().is_meta() ? NULL : (half*) q_scale_max.data_ptr(), q_scale_max.device().is_meta() ? NULL : (half*) q_scale_max.data_ptr(),
q_groups.device().is_meta() ? NULL : (uint16_t*) q_groups.data_ptr(), q_groups.device().is_meta() ? NULL : (uint16_t*) q_groups.data_ptr(),
q_group_map.device().is_meta() ? NULL : (uint16_t*) q_group_map.data_ptr(),
gptq_qzeros.device().is_meta() ? NULL : (uint32_t*) gptq_qzeros.data_ptr(), gptq_qzeros.device().is_meta() ? NULL : (uint32_t*) gptq_qzeros.data_ptr(),
gptq_scales.device().is_meta() ? NULL : (half*) gptq_scales.data_ptr(), gptq_scales.device().is_meta() ? NULL : (half*) gptq_scales.data_ptr(),
gptq_g_idx.device().is_meta() ? NULL : (uint32_t*) gptq_g_idx.data_ptr(), gptq_g_idx.device().is_meta() ? NULL : (uint32_t*) gptq_g_idx.data_ptr(),
(half*) temp_dq.data_ptr() (half*) temp_dq.data_ptr()
); );
if (m->failed) throw std::runtime_error("CUDA out of memory");
return reinterpret_cast<uintptr_t> (m); return reinterpret_cast<uintptr_t> (m);
} }

View File

@ -32,10 +32,10 @@ def fresh_cache():
current_value = huggingface_hub.constants.HUGGINGFACE_HUB_CACHE current_value = huggingface_hub.constants.HUGGINGFACE_HUB_CACHE
huggingface_hub.constants.HUGGINGFACE_HUB_CACHE = d huggingface_hub.constants.HUGGINGFACE_HUB_CACHE = d
text_generation_server.utils.hub.HUGGINGFACE_HUB_CACHE = d text_generation_server.utils.hub.HUGGINGFACE_HUB_CACHE = d
os.environ['HUGGINGFACE_HUB_CACHE'] = d os.environ["HUGGINGFACE_HUB_CACHE"] = d
yield yield
huggingface_hub.constants.HUGGINGFACE_HUB_CACHE = current_value huggingface_hub.constants.HUGGINGFACE_HUB_CACHE = current_value
os.environ['HUGGINGFACE_HUB_CACHE'] = current_value os.environ["HUGGINGFACE_HUB_CACHE"] = current_value
text_generation_server.utils.hub.HUGGINGFACE_HUB_CACHE = current_value text_generation_server.utils.hub.HUGGINGFACE_HUB_CACHE = current_value
@ -47,7 +47,7 @@ def prefetched():
revision="main", revision="main",
local_files_only=False, local_files_only=False,
repo_type="model", repo_type="model",
allow_patterns=["*.safetensors"] allow_patterns=["*.safetensors"],
) )
yield model_id yield model_id
@ -61,7 +61,7 @@ def test_weight_hub_files_offline_error(offline, fresh_cache):
def test_weight_hub_files_offline_ok(prefetched, offline): def test_weight_hub_files_offline_ok(prefetched, offline):
# If the model is prefetched then we should be able to get the weight files from local cache # If the model is prefetched then we should be able to get the weight files from local cache
filenames = weight_hub_files(prefetched) filenames = weight_hub_files(prefetched)
assert filenames == ['model.safetensors'] assert filenames == ["model.safetensors"]
def test_weight_hub_files(): def test_weight_hub_files():

View File

@ -71,7 +71,7 @@ def _load_multi_mqa_gptq(
g_idx = weights.get_tensor(f"{prefix}.c_attn.g_idx") g_idx = weights.get_tensor(f"{prefix}.c_attn.g_idx")
g_idx = g_idx.to(device=weights.device) g_idx = g_idx.to(device=weights.device)
bits, groupsize = weights._get_gptq_params() bits, groupsize, _ = weights._get_gptq_params()
from text_generation_server.utils.layers import HAS_EXLLAMA from text_generation_server.utils.layers import HAS_EXLLAMA

View File

@ -27,6 +27,32 @@ def ext_gemm_half_q_half(x, q_handle, q4_width, force_cuda):
return output.view(output_shape) return output.view(output_shape)
# Group map needed for irregular group sizes
def make_group_map(q_groups, num_qrows):
gr = q_groups.tolist()
group_map = []
num_groups = len(gr) // 2
for i in range(num_groups):
bits = gr[i * 2]
if i < num_groups - 1:
qrows = gr[i * 2 + 3] - gr[i * 2 + 1]
else:
qrows = num_qrows - gr[i * 2 + 1]
rows = qrows * 32 // bits
for j in range(rows):
group_map += [i]
group_map += [rows - j]
return torch.tensor(group_map, dtype=torch.short, device=q_groups.device)
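Toy usage of make_group_map, under the assumption (matching the indexing in the loop above) that q_groups stores [bits, first quantized row] pairs per group. Entry k of the resulting map is the pair [group index of weight row k, weight rows remaining in that group], which is exactly what the kernels read via b_q_group_map[k * 2] and b_q_group_map[k * 2 + 1]. Values here are made up for illustration.

import torch

# two groups: 8-bit covering quantized rows 0..1 (= 2*32/8 = 8 weight rows),
# then 4-bit from quantized row 2 of a 4-row q_weight (= 2*32/4 = 16 weight rows)
q_groups = torch.tensor([8, 0, 4, 2], dtype=torch.short)
gm = make_group_map(q_groups, num_qrows=4)
print(gm[:4])     # tensor([0, 8, 0, 7], ...) -> row 0 is in group 0 with 8 rows left
print(gm[16:20])  # tensor([1, 16, 1, 15], ...) -> weight row 8 starts group 1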
# Create Q matrix
def ext_make_q_matrix(w: dict, temp_dq, key: str = None): def ext_make_q_matrix(w: dict, temp_dq, key: str = None):
""" """
Create Q matrix Create Q matrix
@ -37,6 +63,10 @@ def ext_make_q_matrix(w: dict, temp_dq, key: str = None):
w["q_scale_max"] /= 256 w["q_scale_max"] /= 256
w["q_perm"] = w["q_perm"].short() w["q_perm"] = w["q_perm"].short()
w["q_invperm"] = w["q_invperm"].short() w["q_invperm"] = w["q_invperm"].short()
if "q_group_map" not in w:
w["q_group_map"] = make_group_map(w["q_groups"], w["q_weight"].shape[0])
return make_q_matrix( return make_q_matrix(
w["q_weight"], w["q_weight"],
w["q_perm"], w["q_perm"],
@ -44,6 +74,7 @@ def ext_make_q_matrix(w: dict, temp_dq, key: str = None):
w["q_scale"], w["q_scale"],
w["q_scale_max"], w["q_scale_max"],
w["q_groups"], w["q_groups"],
w["q_group_map"],
none_tensor, none_tensor,
none_tensor, none_tensor,
none_tensor, none_tensor,
@ -70,6 +101,7 @@ def ext_make_q_matrix(w: dict, temp_dq, key: str = None):
none_tensor, none_tensor,
none_tensor, none_tensor,
none_tensor, none_tensor,
none_tensor,
w["qzeros"], w["qzeros"],
w["scales"], w["scales"],
w["g_idx"].cpu(), w["g_idx"].cpu(),
@ -84,6 +116,7 @@ def ext_make_q_matrix(w: dict, temp_dq, key: str = None):
none_tensor, none_tensor,
none_tensor, none_tensor,
none_tensor, none_tensor,
none_tensor,
w["qzeros"], w["qzeros"],
w["scales"], w["scales"],
none_tensor, none_tensor,

View File

@ -18,7 +18,9 @@ WEIGHTS_CACHE_OVERRIDE = os.getenv("WEIGHTS_CACHE_OVERRIDE", None)
HF_HUB_OFFLINE = os.environ.get("HF_HUB_OFFLINE", "0").lower() in ["true", "1", "yes"] HF_HUB_OFFLINE = os.environ.get("HF_HUB_OFFLINE", "0").lower() in ["true", "1", "yes"]
def _cached_weight_files(model_id: str, revision: Optional[str], extension: str) -> List[str]: def _cached_weight_files(
model_id: str, revision: Optional[str], extension: str
) -> List[str]:
"""Guess weight files from the cached revision snapshot directory""" """Guess weight files from the cached revision snapshot directory"""
d = _get_cached_revision_directory(model_id, revision) d = _get_cached_revision_directory(model_id, revision)
if not d: if not d:
@ -27,7 +29,9 @@ def _cached_weight_files(model_id: str, revision: Optional[str], extension: str)
return filenames return filenames
def _weight_hub_files_from_model_info(info: hf_api.ModelInfo, extension: str) -> List[str]: def _weight_hub_files_from_model_info(
info: hf_api.ModelInfo, extension: str
) -> List[str]:
return [ return [
s.rfilename s.rfilename
for s in info.siblings for s in info.siblings
@ -44,21 +48,27 @@ def _weight_files_from_dir(d: Path, extension: str) -> List[str]:
# see _weight_hub_files_from_model_info, that's also what is # see _weight_hub_files_from_model_info, that's also what is
# done there with the len(s.rfilename.split("/")) == 1 condition # done there with the len(s.rfilename.split("/")) == 1 condition
root, _, files = next(os.walk(str(d))) root, _, files = next(os.walk(str(d)))
filenames = [f for f in files filenames = [
if f.endswith(extension) f
and "arguments" not in f for f in files
and "args" not in f if f.endswith(extension)
and "adapter" not in f and "arguments" not in f
and "training" not in f] and "args" not in f
and "adapter" not in f
and "training" not in f
]
return filenames return filenames
def _get_cached_revision_directory(model_id: str, revision: Optional[str]) -> Optional[Path]: def _get_cached_revision_directory(
model_id: str, revision: Optional[str]
) -> Optional[Path]:
if revision is None: if revision is None:
revision = "main" revision = "main"
repo_cache = Path(HUGGINGFACE_HUB_CACHE) / Path( repo_cache = Path(HUGGINGFACE_HUB_CACHE) / Path(
file_download.repo_folder_name(repo_id=model_id, repo_type="model")) file_download.repo_folder_name(repo_id=model_id, repo_type="model")
)
if not repo_cache.is_dir(): if not repo_cache.is_dir():
# No cache for this model # No cache for this model
@ -86,7 +96,7 @@ def _get_cached_revision_directory(model_id: str, revision: Optional[str]) -> Op
def weight_hub_files( def weight_hub_files(
model_id: str, revision: Optional[str] = None, extension: str = ".safetensors" model_id: str, revision: Optional[str] = None, extension: str = ".safetensors"
) -> List[str]: ) -> List[str]:
"""Get the weights filenames on the hub""" """Get the weights filenames on the hub"""
api = HfApi() api = HfApi()

View File

@ -19,6 +19,7 @@ from accelerate import init_empty_weights
from text_generation_server.utils.gptq.quant_linear import QuantLinear from text_generation_server.utils.gptq.quant_linear import QuantLinear
from text_generation_server.utils.import_utils import IS_CUDA_SYSTEM, IS_ROCM_SYSTEM from text_generation_server.utils.import_utils import IS_CUDA_SYSTEM, IS_ROCM_SYSTEM
from text_generation_server.utils.log import log_once
HAS_AWQ = True HAS_AWQ = True
try: try:
@ -35,10 +36,11 @@ HAS_EXLLAMA = False
CAN_EXLLAMA = major >= 8 CAN_EXLLAMA = major >= 8
V2 = os.getenv("EXLLAMA_VERSION", "2") == "2" V2 = os.getenv("EXLLAMA_VERSION", "2") == "2"
if V2 and int(os.getenv("WORLD_SIZE", "1")) > 1: if V2 and int(os.getenv("WORLD_SIZE", "1")) > 1:
logger.warning( V2 = False
log_once(
logger.warning,
"Disabling exllama v2 and using v1 instead because there are issues when sharding" "Disabling exllama v2 and using v1 instead because there are issues when sharding"
) )
V2 = False
if os.getenv("DISABLE_EXLLAMA") == "True": if os.getenv("DISABLE_EXLLAMA") == "True":
HAS_EXLLAMA = False HAS_EXLLAMA = False

View File

@ -0,0 +1,6 @@
from functools import lru_cache
@lru_cache(10)
def log_once(log, msg:str):
log(msg)
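Minimal usage sketch: because log_once is wrapped in lru_cache, repeated calls with the same (log, msg) pair are deduplicated, which keeps per-layer warnings from spamming the log during model loading.

from loguru import logger

for _ in range(3):
    log_once(logger.warning, "Disabling exllama because desc_act=True")
# the warning is emitted a single time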

View File

@ -6,6 +6,7 @@ import torch
from loguru import logger from loguru import logger
from huggingface_hub import hf_hub_download from huggingface_hub import hf_hub_download
import json import json
from text_generation_server.utils.log import log_once
class Weights: class Weights:
@ -161,7 +162,7 @@ class Weights:
else: else:
g_idx = None g_idx = None
bits, groupsize = self._get_gptq_params() bits, groupsize, _ = self._get_gptq_params()
weight = (qweight, qzeros, scales, g_idx, bits, groupsize, False) weight = (qweight, qzeros, scales, g_idx, bits, groupsize, False)
else: else:
slice_ = self._get_slice(f"{prefix}.weight") slice_ = self._get_slice(f"{prefix}.weight")
@ -211,10 +212,10 @@ class Weights:
else: else:
g_idx = None g_idx = None
bits, groupsize = self._get_gptq_params() bits, groupsize, desc_act = self._get_gptq_params()
from text_generation_server.utils.layers import HAS_EXLLAMA from text_generation_server.utils.layers import HAS_EXLLAMA
use_exllama = bits == 4 and HAS_EXLLAMA and quantize == "gptq" use_exllama = bits == 4 and HAS_EXLLAMA and quantize == "gptq" and not desc_act
weight = (qweight, qzeros, scales, g_idx, bits, groupsize, use_exllama) weight = (qweight, qzeros, scales, g_idx, bits, groupsize, use_exllama)
else: else:
w = [self.get_sharded(f"{p}.weight", dim=0) for p in prefixes] w = [self.get_sharded(f"{p}.weight", dim=0) for p in prefixes]
@ -240,11 +241,15 @@ class Weights:
def get_multi_weights_row(self, prefix: str, quantize: str): def get_multi_weights_row(self, prefix: str, quantize: str):
if quantize == "gptq": if quantize == "gptq":
use_exllama = True use_exllama = True
bits, groupsize = self._get_gptq_params() bits, groupsize, desc_act = self._get_gptq_params()
if bits != 4: if bits != 4:
use_exllama = False use_exllama = False
if desc_act:
log_once(logger.warning, "Disabling exllama because desc_act=True")
use_exllama = False
if self.process_group.size() > 1: if self.process_group.size() > 1:
g_idx = self.get_tensor(f"{prefix}.g_idx") g_idx = self.get_tensor(f"{prefix}.g_idx")
if g_idx is not None: if g_idx is not None:
@ -274,12 +279,18 @@ class Weights:
if use_exllama: if use_exllama:
if not HAS_EXLLAMA: if not HAS_EXLLAMA:
if CAN_EXLLAMA: if CAN_EXLLAMA:
logger.warning( log_once(
logger.warning,
"Exllama GPTQ cuda kernels (which are faster) could have been used, but are not currently installed, try using BUILD_EXTENSIONS=True" "Exllama GPTQ cuda kernels (which are faster) could have been used, but are not currently installed, try using BUILD_EXTENSIONS=True"
) )
use_exllama = False use_exllama = False
else: else:
logger.info(f"Using exllama kernels v{HAS_EXLLAMA}") log_once(
logger.info,
f"Using exllama kernels v{HAS_EXLLAMA}"
)
g_idx = self.get_sharded(f"{prefix}.g_idx", dim=0)
if use_exllama and groupsize != -1: if use_exllama and groupsize != -1:
qzeros = self.get_sharded(f"{prefix}.qzeros", dim=0) qzeros = self.get_sharded(f"{prefix}.qzeros", dim=0)
@ -288,14 +299,12 @@ class Weights:
qzeros = self.get_tensor(f"{prefix}.qzeros") qzeros = self.get_tensor(f"{prefix}.qzeros")
scales = self.get_tensor(f"{prefix}.scales") scales = self.get_tensor(f"{prefix}.scales")
g_idx = self.get_sharded(f"{prefix}.g_idx", dim=0)
if use_exllama: if use_exllama:
g_idx = g_idx - g_idx[0] g_idx = g_idx - g_idx[0]
weight = (qweight, qzeros, scales, g_idx, bits, groupsize, use_exllama) weight = (qweight, qzeros, scales, g_idx, bits, groupsize, use_exllama)
elif quantize == "awq": elif quantize == "awq":
bits, groupsize = self._get_gptq_params() bits, groupsize, _ = self._get_gptq_params()
try: try:
qweight = self.get_sharded(f"{prefix}.qweight", dim=0) qweight = self.get_sharded(f"{prefix}.qweight", dim=0)
@ -314,18 +323,20 @@ class Weights:
weight = self.get_sharded(f"{prefix}.weight", dim=1) weight = self.get_sharded(f"{prefix}.weight", dim=1)
return weight return weight
def _get_gptq_params(self) -> Tuple[int, int]: def _get_gptq_params(self) -> Tuple[int, int, int]:
try: try:
bits = self.get_tensor("gptq_bits").item() bits = self.get_tensor("gptq_bits").item()
groupsize = self.get_tensor("gptq_groupsize").item() groupsize = self.get_tensor("gptq_groupsize").item()
desc_act = False
except (SafetensorError, RuntimeError) as e: except (SafetensorError, RuntimeError) as e:
try: try:
bits = self.gptq_bits bits = self.gptq_bits
groupsize = self.gptq_groupsize groupsize = self.gptq_groupsize
desc_act = getattr(self, "gptq_desc_act", False)
except Exception: except Exception:
raise e raise e
return bits, groupsize return bits, groupsize, desc_act
def _set_gptq_params(self, model_id, revision): def _set_gptq_params(self, model_id, revision):
filename = "config.json" filename = "config.json"
@ -340,6 +351,7 @@ class Weights:
data = json.load(f) data = json.load(f)
self.gptq_bits = data["quantization_config"]["bits"] self.gptq_bits = data["quantization_config"]["bits"]
self.gptq_groupsize = data["quantization_config"]["group_size"] self.gptq_groupsize = data["quantization_config"]["group_size"]
self.gptq_desc_act = data["quantization_config"]["desc_act"]
except Exception: except Exception:
filename = "quantize_config.json" filename = "quantize_config.json"
try: try:
@ -353,6 +365,7 @@ class Weights:
data = json.load(f) data = json.load(f)
self.gptq_bits = data["bits"] self.gptq_bits = data["bits"]
self.gptq_groupsize = data["group_size"] self.gptq_groupsize = data["group_size"]
self.gptq_desc_act = data["desc_act"]
except Exception: except Exception:
filename = "quant_config.json" filename = "quant_config.json"
try: try:
@ -366,5 +379,6 @@ class Weights:
data = json.load(f) data = json.load(f)
self.gptq_bits = data["w_bit"] self.gptq_bits = data["w_bit"]
self.gptq_groupsize = data["q_group_size"] self.gptq_groupsize = data["q_group_size"]
self.gptq_desc_act = data["desc_act"]
except Exception: except Exception:
pass pass