feat: update exllamav2 kernels (#1370)
Co-authored-by: Nicolas Patry <patry.nicolas@protonmail.com>
This commit is contained in:
parent
987c959f73
commit
564199bab3
|
@ -2,6 +2,7 @@
|
||||||
#define _config_h
|
#define _config_h
|
||||||
|
|
||||||
#define MAX_Q_GEMM_ROWS 50
|
#define MAX_Q_GEMM_ROWS 50
|
||||||
|
#define MAX_Q_GEMM_WEIGHTS 4 // must be <= MAX_Q_GEMM_ROWS
|
||||||
|
|
||||||
#define QMODE_2BIT 1
|
#define QMODE_2BIT 1
|
||||||
#define QMODE_3BIT 1
|
#define QMODE_3BIT 1
|
||||||
|
@ -10,4 +11,5 @@
|
||||||
#define QMODE_6BIT 0
|
#define QMODE_6BIT 0
|
||||||
#define QMODE_8BIT 0
|
#define QMODE_8BIT 0
|
||||||
|
|
||||||
|
|
||||||
#endif
|
#endif
|
||||||
|
|
|
@ -10,16 +10,19 @@
|
||||||
#include "quant/qdq_6.cuh"
|
#include "quant/qdq_6.cuh"
|
||||||
#include "quant/qdq_8.cuh"
|
#include "quant/qdq_8.cuh"
|
||||||
|
|
||||||
#define BLOCK_KN_SIZE 128
|
#define GPTQ_BLOCK_KN_SIZE 128
|
||||||
#define BLOCK_M_SIZE_MAX 8
|
#define GPTQ_BLOCK_M_SIZE_MAX 8
|
||||||
#define MAX_GROUPS_IN_BLOCK (BLOCK_KN_SIZE / 32)
|
#define GPTQ_MAX_GROUPS_IN_BLOCK (GPTQ_BLOCK_KN_SIZE / 32)
|
||||||
|
|
||||||
|
#define EXL2_BLOCK_KN_SIZE 64
|
||||||
|
#define EXL2_BLOCK_M_SIZE_MAX 8
|
||||||
|
#define EXL2_MAX_GROUPS_IN_BLOCK (EXL2_BLOCK_KN_SIZE / 32)
|
||||||
|
|
||||||
#define CLEAR_N_SIZE 256
|
#define CLEAR_N_SIZE 256
|
||||||
|
|
||||||
#include "q_gemm_kernel.cuh"
|
#include "q_gemm_kernel.cuh"
|
||||||
#include "q_gemm_kernel_gptq.cuh"
|
#include "q_gemm_kernel_gptq.cuh"
|
||||||
|
|
||||||
#include "compat_gemm.cuh"
|
|
||||||
|
|
||||||
void gemm_half_q_half_cuda_part
|
void gemm_half_q_half_cuda_part
|
||||||
(
|
(
|
||||||
const half* a,
|
const half* a,
|
||||||
|
@ -29,20 +32,23 @@ void gemm_half_q_half_cuda_part
|
||||||
int size_n,
|
int size_n,
|
||||||
int size_k,
|
int size_k,
|
||||||
int m_count,
|
int m_count,
|
||||||
bool clear
|
bool clear,
|
||||||
|
const half* r_weights,
|
||||||
|
int r_weights_stride,
|
||||||
|
bool mul_r_weights
|
||||||
)
|
)
|
||||||
{
|
{
|
||||||
if (!b->is_gptq)
|
if (!b->is_gptq)
|
||||||
{
|
{
|
||||||
dim3 blockDim, gridDim;
|
dim3 blockDim, gridDim;
|
||||||
blockDim.x = BLOCK_KN_SIZE;
|
blockDim.x = EXL2_BLOCK_KN_SIZE;
|
||||||
blockDim.y = 1;
|
blockDim.y = 1;
|
||||||
blockDim.z = 1;
|
blockDim.z = 1;
|
||||||
gridDim.x = DIVIDE(size_n, BLOCK_KN_SIZE * 4);
|
gridDim.x = DIVIDE(size_n, EXL2_BLOCK_KN_SIZE * 4);
|
||||||
gridDim.y = DIVIDE(size_m, m_count);
|
gridDim.y = DIVIDE(size_m, m_count);
|
||||||
gridDim.z = DIVIDE(size_k, BLOCK_KN_SIZE);
|
gridDim.z = DIVIDE(size_k, EXL2_BLOCK_KN_SIZE);
|
||||||
|
|
||||||
fp_gemm_half_q_half_kernel kernel = pick_gemm_half_q_half_kernel(true, m_count);
|
fp_gemm_half_q_half_kernel kernel = pick_gemm_half_q_half_kernel(m_count, r_weights != NULL, mul_r_weights);
|
||||||
|
|
||||||
kernel<<<gridDim, blockDim>>>
|
kernel<<<gridDim, blockDim>>>
|
||||||
(
|
(
|
||||||
|
@ -55,7 +61,7 @@ void gemm_half_q_half_cuda_part
|
||||||
size_n,
|
size_n,
|
||||||
size_k,
|
size_k,
|
||||||
b->groups,
|
b->groups,
|
||||||
b->groupsize,
|
b->cuda_q_group_map,
|
||||||
b->cuda_q_perm,
|
b->cuda_q_perm,
|
||||||
b->rows_8,
|
b->rows_8,
|
||||||
b->rows_6,
|
b->rows_6,
|
||||||
|
@ -63,24 +69,27 @@ void gemm_half_q_half_cuda_part
|
||||||
b->rows_4,
|
b->rows_4,
|
||||||
b->rows_3,
|
b->rows_3,
|
||||||
b->rows_2,
|
b->rows_2,
|
||||||
clear
|
clear,
|
||||||
|
r_weights,
|
||||||
|
r_weights_stride
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
else
|
else
|
||||||
{
|
{
|
||||||
dim3 blockDim, gridDim;
|
dim3 blockDim, gridDim;
|
||||||
blockDim.x = BLOCK_KN_SIZE;
|
blockDim.x = GPTQ_BLOCK_KN_SIZE;
|
||||||
blockDim.y = 1;
|
blockDim.y = 1;
|
||||||
blockDim.z = 1;
|
blockDim.z = 1;
|
||||||
gridDim.x = DIVIDE(size_n, BLOCK_KN_SIZE * 4);
|
gridDim.x = DIVIDE(size_n, GPTQ_BLOCK_KN_SIZE * 4);
|
||||||
gridDim.y = DIVIDE(size_m, m_count);
|
gridDim.y = DIVIDE(size_m, m_count);
|
||||||
gridDim.z = DIVIDE(size_k, BLOCK_KN_SIZE);
|
gridDim.z = DIVIDE(size_k, GPTQ_BLOCK_KN_SIZE);
|
||||||
|
|
||||||
fp_gemm_half_q_half_gptq_kernel kernel = pick_gemm_half_q_half_gptq_kernel(true, m_count);
|
fp_gemm_half_q_half_gptq_kernel kernel = pick_gemm_half_q_half_gptq_kernel(m_count, r_weights != NULL, mul_r_weights);
|
||||||
|
|
||||||
// DBGX((uint64_t) b->cuda_q_perm);
|
// DBGX((uint64_t) r_weights);
|
||||||
// DBGI(b->rows_4);
|
// if (r_weights)
|
||||||
// DBGI(b->height);
|
// print_global_mem(r_weights, 1, 1, 1);
|
||||||
|
// DBGI(r_weights_stride);
|
||||||
|
|
||||||
kernel<<<gridDim, blockDim>>>
|
kernel<<<gridDim, blockDim>>>
|
||||||
(
|
(
|
||||||
|
@ -93,10 +102,12 @@ void gemm_half_q_half_cuda_part
|
||||||
size_n,
|
size_n,
|
||||||
size_k,
|
size_k,
|
||||||
b->groups,
|
b->groups,
|
||||||
b->groupsize,
|
b->gptq_groupsize,
|
||||||
b->cuda_q_perm,
|
b->cuda_q_perm,
|
||||||
b->rows_4,
|
b->rows_4,
|
||||||
clear
|
clear,
|
||||||
|
r_weights,
|
||||||
|
r_weights_stride
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -112,13 +123,14 @@ void gemm_half_q_half_cuda
|
||||||
int size_k,
|
int size_k,
|
||||||
bool clear,
|
bool clear,
|
||||||
half* temp_dq,
|
half* temp_dq,
|
||||||
bool force_cuda
|
bool force_cuda,
|
||||||
|
const half* r_weights,
|
||||||
|
const int r_weights_stride,
|
||||||
|
bool mul_r_weights
|
||||||
)
|
)
|
||||||
{
|
{
|
||||||
if (size_m > MAX_Q_GEMM_ROWS && !force_cuda)
|
if (size_m > MAX_Q_GEMM_ROWS && !force_cuda)
|
||||||
{
|
{
|
||||||
//printf("cublas\n");
|
|
||||||
|
|
||||||
// Reconstruct FP16 matrix, then cuBLAS
|
// Reconstruct FP16 matrix, then cuBLAS
|
||||||
|
|
||||||
if (!temp_dq) temp_dq = b->temp_dq;
|
if (!temp_dq) temp_dq = b->temp_dq;
|
||||||
|
@ -139,12 +151,12 @@ void gemm_half_q_half_cuda
|
||||||
//const float alpha = 1.0f;
|
//const float alpha = 1.0f;
|
||||||
//const float beta = clear ? 0.0f : 1.0f;
|
//const float beta = clear ? 0.0f : 1.0f;
|
||||||
//cublasSgemmEx(cublas_handle,
|
//cublasSgemmEx(cublas_handle,
|
||||||
// CUBLAS_OP_N,
|
// CUBLAS_OP_N,
|
||||||
// CUBLAS_OP_N,
|
// CUBLAS_OP_N,
|
||||||
// size_n, size_m, size_k,
|
// size_n, size_m, size_k,
|
||||||
// &alpha, temp_dq, CUDA_R_16F, size_n,
|
// &alpha, temp_dq, CUDA_R_16F, size_n,
|
||||||
// a, CUDA_R_16F, size_k,
|
// a, CUDA_R_16F, size_k,
|
||||||
// &beta, c, CUDA_R_16F, size_n);
|
// &beta, c, CUDA_R_16F, size_n);
|
||||||
|
|
||||||
//const float alpha = 1.0f;
|
//const float alpha = 1.0f;
|
||||||
//const float beta = clear ? 0.0f : 1.0f;
|
//const float beta = clear ? 0.0f : 1.0f;
|
||||||
|
@ -158,24 +170,21 @@ void gemm_half_q_half_cuda
|
||||||
}
|
}
|
||||||
else
|
else
|
||||||
{
|
{
|
||||||
//printf("cuda\n");
|
|
||||||
|
|
||||||
// Quantized matmul
|
// Quantized matmul
|
||||||
|
|
||||||
//if (clear) clear_tensor_cuda(c, size_m, size_n);
|
int block_m_size_max = b->is_gptq ? GPTQ_BLOCK_M_SIZE_MAX : EXL2_BLOCK_M_SIZE_MAX;
|
||||||
|
int max_chunks = size_m / block_m_size_max;
|
||||||
int max_chunks = size_m / BLOCK_M_SIZE_MAX;
|
int last_chunk = max_chunks * block_m_size_max;
|
||||||
int last_chunk = max_chunks * BLOCK_M_SIZE_MAX;
|
|
||||||
int last_chunk_size = size_m - last_chunk;
|
int last_chunk_size = size_m - last_chunk;
|
||||||
|
|
||||||
if (max_chunks)
|
if (max_chunks)
|
||||||
{
|
{
|
||||||
gemm_half_q_half_cuda_part(a, b, c, last_chunk, size_n, size_k, BLOCK_M_SIZE_MAX, clear);
|
gemm_half_q_half_cuda_part(a, b, c, last_chunk, size_n, size_k, block_m_size_max, clear, r_weights, r_weights_stride, mul_r_weights);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (last_chunk_size)
|
if (last_chunk_size)
|
||||||
{
|
{
|
||||||
gemm_half_q_half_cuda_part(a + last_chunk * size_k, b, c + last_chunk * size_n, last_chunk_size, size_n, size_k, last_chunk_size, clear);
|
gemm_half_q_half_cuda_part(a + last_chunk * size_k, b, c + last_chunk * size_n, last_chunk_size, size_n, size_k, last_chunk_size, clear, r_weights, r_weights_stride, mul_r_weights);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -201,11 +210,10 @@ void clear_tensor_cuda
|
||||||
int size_n
|
int size_n
|
||||||
)
|
)
|
||||||
{
|
{
|
||||||
return;
|
// dim3 blockDim, gridDim;
|
||||||
dim3 blockDim, gridDim;
|
// blockDim.x = CLEAR_N_SIZE;
|
||||||
blockDim.x = CLEAR_N_SIZE;
|
// blockDim.y = 1;
|
||||||
blockDim.y = 1;
|
// gridDim.x = DIVIDE(size_n / 8, CLEAR_N_SIZE);
|
||||||
gridDim.x = DIVIDE(size_n / 8, CLEAR_N_SIZE);
|
// gridDim.y = size_m;
|
||||||
gridDim.y = size_m;
|
// clear_kernel<<<gridDim, blockDim>>>(c, size_m, size_n);
|
||||||
clear_kernel<<<gridDim, blockDim>>>(c, size_m, size_n);
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -20,7 +20,10 @@ void gemm_half_q_half_cuda
|
||||||
int size_k,
|
int size_k,
|
||||||
bool clear = false,
|
bool clear = false,
|
||||||
half* reconstruct = NULL,
|
half* reconstruct = NULL,
|
||||||
bool force_cuda = false
|
bool force_cuda = false,
|
||||||
|
const half* r_weights = NULL,
|
||||||
|
const int r_weights_stride = 0,
|
||||||
|
bool mul_r_weights = false
|
||||||
);
|
);
|
||||||
|
|
||||||
void clear_tensor_cuda
|
void clear_tensor_cuda
|
||||||
|
|
|
@ -1,8 +1,5 @@
|
||||||
#include "compat.cuh"
|
#include "compat.cuh"
|
||||||
|
|
||||||
#include <cuda_runtime.h>
|
|
||||||
#include <cuda_fp16.h>
|
|
||||||
|
|
||||||
__forceinline__ __device__ half2 dot22_8(half2(&dq)[4], const half* a_ptr, const half2 g_result, const half qs_h)
|
__forceinline__ __device__ half2 dot22_8(half2(&dq)[4], const half* a_ptr, const half2 g_result, const half qs_h)
|
||||||
{
|
{
|
||||||
half2 result = {};
|
half2 result = {};
|
||||||
|
@ -60,6 +57,47 @@ __forceinline__ __device__ float dot22_32_f(half2(&dq)[16], const half* a_ptr, c
|
||||||
return fma(result_f, qs_f, g_result);
|
return fma(result_f, qs_f, g_result);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
__forceinline__ __device__ half dot22_8_h(half2(&dq)[4], const half* a_ptr, const half g_result, const half qs_h)
|
||||||
|
{
|
||||||
|
// Use FP32 accumulator to avoid potential overflow since unscaled weights are in the range -128..127
|
||||||
|
|
||||||
|
float result = {};
|
||||||
|
#pragma unroll
|
||||||
|
for (int i = 0; i < 4; i++)
|
||||||
|
{
|
||||||
|
half2 w01 = dq[i];
|
||||||
|
float w0 = __low2float(w01);
|
||||||
|
float w1 = __high2float(w01);
|
||||||
|
float x0 = __half2float(*a_ptr++);
|
||||||
|
float x1 = __half2float(*a_ptr++);
|
||||||
|
result = fma(w0, x0, result);
|
||||||
|
result = fma(w1, x1, result);
|
||||||
|
}
|
||||||
|
float qs = __half2float(qs_h);
|
||||||
|
result *= qs;
|
||||||
|
half result_h = __float2half_rn(result);
|
||||||
|
return __hadd(result_h, g_result);
|
||||||
|
}
|
||||||
|
|
||||||
|
__forceinline__ __device__ half dot22_16_h(half2(&dq)[8], const half* a_ptr, const half g_result, const half qs_h)
|
||||||
|
{
|
||||||
|
half2 result = {};
|
||||||
|
const half2* a2_ptr = (const half2*)a_ptr;
|
||||||
|
#pragma unroll
|
||||||
|
for (int i = 0; i < 8; i++) result = __hfma2(dq[i], *a2_ptr++, result);
|
||||||
|
half result_h = __hadd(__low2half(result), __high2half(result));
|
||||||
|
return __hfma(result_h, qs_h, g_result);
|
||||||
|
}
|
||||||
|
|
||||||
|
__forceinline__ __device__ half dot22_32_h(half2(&dq)[16], const half* a_ptr, const half g_result, const half qs_h)
|
||||||
|
{
|
||||||
|
half2 result = {};
|
||||||
|
const half2* a2_ptr = (const half2*)a_ptr;
|
||||||
|
#pragma unroll
|
||||||
|
for (int i = 0; i < 16; i += 1) result = __hfma2(dq[i], *a2_ptr++, result);
|
||||||
|
half result_h = __hadd(__low2half(result), __high2half(result));
|
||||||
|
return __hfma(result_h, qs_h, g_result);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
typedef void (*fp_gemm_half_q_half_kernel)
|
typedef void (*fp_gemm_half_q_half_kernel)
|
||||||
|
@ -73,7 +111,7 @@ typedef void (*fp_gemm_half_q_half_kernel)
|
||||||
const int,
|
const int,
|
||||||
const int,
|
const int,
|
||||||
const int,
|
const int,
|
||||||
const int,
|
const uint16_t*,
|
||||||
const uint16_t*,
|
const uint16_t*,
|
||||||
const int,
|
const int,
|
||||||
const int,
|
const int,
|
||||||
|
@ -81,10 +119,12 @@ typedef void (*fp_gemm_half_q_half_kernel)
|
||||||
const int,
|
const int,
|
||||||
const int,
|
const int,
|
||||||
const int,
|
const int,
|
||||||
const bool
|
const bool,
|
||||||
|
const half*,
|
||||||
|
const int
|
||||||
);
|
);
|
||||||
|
|
||||||
template <bool first_block, int m_count>
|
template <int m_count, bool use_r_weights, bool mul_r_weights>
|
||||||
__global__ void gemm_half_q_half_kernel
|
__global__ void gemm_half_q_half_kernel
|
||||||
(
|
(
|
||||||
const half* __restrict__ a,
|
const half* __restrict__ a,
|
||||||
|
@ -96,7 +136,7 @@ __global__ void gemm_half_q_half_kernel
|
||||||
const int size_n,
|
const int size_n,
|
||||||
const int size_k,
|
const int size_k,
|
||||||
const int groups,
|
const int groups,
|
||||||
const int groupsize,
|
const uint16_t* __restrict__ b_q_group_map,
|
||||||
const uint16_t* __restrict__ b_q_perm,
|
const uint16_t* __restrict__ b_q_perm,
|
||||||
const int rows_8,
|
const int rows_8,
|
||||||
const int rows_6,
|
const int rows_6,
|
||||||
|
@ -104,7 +144,9 @@ __global__ void gemm_half_q_half_kernel
|
||||||
const int rows_4,
|
const int rows_4,
|
||||||
const int rows_3,
|
const int rows_3,
|
||||||
const int rows_2,
|
const int rows_2,
|
||||||
const bool clear
|
const bool clear,
|
||||||
|
const half* r_weights,
|
||||||
|
const int r_weights_stride
|
||||||
)
|
)
|
||||||
{
|
{
|
||||||
MatrixView_half a_(a, size_m, size_k);
|
MatrixView_half a_(a, size_m, size_k);
|
||||||
|
@ -115,18 +157,34 @@ __global__ void gemm_half_q_half_kernel
|
||||||
|
|
||||||
// Block
|
// Block
|
||||||
|
|
||||||
int offset_n = blockIdx.x * BLOCK_KN_SIZE * 4;
|
int offset_n = blockIdx.x * EXL2_BLOCK_KN_SIZE * 4;
|
||||||
int offset_m = blockIdx.y * m_count;
|
int offset_m = blockIdx.y * m_count;
|
||||||
int offset_k = blockIdx.z * BLOCK_KN_SIZE;
|
int offset_k = blockIdx.z * EXL2_BLOCK_KN_SIZE;
|
||||||
|
|
||||||
int end_n = min(offset_n + BLOCK_KN_SIZE * 4, size_n);
|
int end_n = min(offset_n + EXL2_BLOCK_KN_SIZE * 4, size_n);
|
||||||
int end_m = min(offset_m + m_count, size_m);
|
int end_m = min(offset_m + m_count, size_m);
|
||||||
int end_k = min(offset_k + BLOCK_KN_SIZE, size_k);
|
int end_k = min(offset_k + EXL2_BLOCK_KN_SIZE, size_k);
|
||||||
int n = offset_n + t * 4;
|
int n = offset_n + t * 4;
|
||||||
|
|
||||||
|
// Read weights
|
||||||
|
|
||||||
|
half_uint16 weights[MAX_Q_GEMM_WEIGHTS];
|
||||||
|
if constexpr (use_r_weights)
|
||||||
|
{
|
||||||
|
uint16_t any_w = 0;
|
||||||
|
const half* w_ptr = r_weights;
|
||||||
|
for (int m = 0; m < m_count; ++m)
|
||||||
|
{
|
||||||
|
weights[m].as_half = *w_ptr;
|
||||||
|
w_ptr += r_weights_stride;
|
||||||
|
any_w |= weights[m].as_uint16;
|
||||||
|
}
|
||||||
|
if (!any_w) return; // Early exit if all weights are zero -- does not zero output (!!!)
|
||||||
|
}
|
||||||
|
|
||||||
// Preload block_a
|
// Preload block_a
|
||||||
|
|
||||||
__shared__ half block_a[m_count][BLOCK_KN_SIZE];
|
__shared__ half block_a[m_count][EXL2_BLOCK_KN_SIZE];
|
||||||
|
|
||||||
if (offset_k + t < end_k)
|
if (offset_k + t < end_k)
|
||||||
{
|
{
|
||||||
|
@ -135,6 +193,7 @@ __global__ void gemm_half_q_half_kernel
|
||||||
const half* a_ptr = a_.item_ptr(offset_m + m, 0);
|
const half* a_ptr = a_.item_ptr(offset_m + m, 0);
|
||||||
half* block_a_ptr = block_a[m];
|
half* block_a_ptr = block_a[m];
|
||||||
half a0 = a_ptr[b_q_perm[offset_k + t]];
|
half a0 = a_ptr[b_q_perm[offset_k + t]];
|
||||||
|
// half a0 = a_ptr[offset_k + t];
|
||||||
block_a_ptr[t] = a0;
|
block_a_ptr[t] = a0;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -153,14 +212,19 @@ __global__ void gemm_half_q_half_kernel
|
||||||
|
|
||||||
// Find initial group
|
// Find initial group
|
||||||
|
|
||||||
int group = offset_k / groupsize;
|
//int group = offset_k / groupsize;
|
||||||
|
int group = b_q_group_map[offset_k * 2];
|
||||||
|
|
||||||
|
// if (offset_m == 0 && t == 0)
|
||||||
|
// DBGI2(offset_k, group);
|
||||||
|
|
||||||
// Preload scales
|
// Preload scales
|
||||||
|
|
||||||
float scales[MAX_GROUPS_IN_BLOCK][4];
|
half scales[EXL2_MAX_GROUPS_IN_BLOCK][4];
|
||||||
|
|
||||||
int groups_in_block = DIVIDE((end_k - offset_k), groupsize);
|
//int groups_in_block = DIVIDE((end_k - offset_k), groupsize);
|
||||||
for (int g = 0; g < groups_in_block; g++)
|
int temp_k = offset_k;
|
||||||
|
for (int g = 0; temp_k < end_k; g++)
|
||||||
{
|
{
|
||||||
int qscales[4];
|
int qscales[4];
|
||||||
b_q_scale_.item4(qscales, group + g, n);
|
b_q_scale_.item4(qscales, group + g, n);
|
||||||
|
@ -168,11 +232,12 @@ __global__ void gemm_half_q_half_kernel
|
||||||
qscales[1]++;
|
qscales[1]++;
|
||||||
qscales[2]++;
|
qscales[2]++;
|
||||||
qscales[3]++;
|
qscales[3]++;
|
||||||
float maxscale = __half2float(b_q_scale_max[group + g]);
|
half maxscale = b_q_scale_max[group + g];
|
||||||
scales[g][0] = __int2float_rn(qscales[0] * qscales[0]) * maxscale;
|
scales[g][0] = __hmul(__int2half_rn(qscales[0] * qscales[0]), maxscale);
|
||||||
scales[g][1] = __int2float_rn(qscales[1] * qscales[1]) * maxscale;
|
scales[g][1] = __hmul(__int2half_rn(qscales[1] * qscales[1]), maxscale);
|
||||||
scales[g][2] = __int2float_rn(qscales[2] * qscales[2]) * maxscale;
|
scales[g][2] = __hmul(__int2half_rn(qscales[2] * qscales[2]), maxscale);
|
||||||
scales[g][3] = __int2float_rn(qscales[3] * qscales[3]) * maxscale;
|
scales[g][3] = __hmul(__int2half_rn(qscales[3] * qscales[3]), maxscale);
|
||||||
|
temp_k += b_q_group_map[temp_k * 2 + 1];
|
||||||
}
|
}
|
||||||
|
|
||||||
// a, b offset
|
// a, b offset
|
||||||
|
@ -193,20 +258,20 @@ __global__ void gemm_half_q_half_kernel
|
||||||
|
|
||||||
const uint32_t* b_ptr = b_q_weight + qk * size_n + n;
|
const uint32_t* b_ptr = b_q_weight + qk * size_n + n;
|
||||||
const half* a_ptr = &block_a[0][0];
|
const half* a_ptr = &block_a[0][0];
|
||||||
int a_stride = BLOCK_KN_SIZE;
|
int a_stride = EXL2_BLOCK_KN_SIZE;
|
||||||
|
|
||||||
// Initial group
|
// Initial group
|
||||||
|
|
||||||
int scales_idx = 0;
|
int scales_idx = 0;
|
||||||
float qs_f0 = scales[scales_idx][0];
|
half qs_h0 = scales[scales_idx][0];
|
||||||
float qs_f1 = scales[scales_idx][1];
|
half qs_h1 = scales[scales_idx][1];
|
||||||
float qs_f2 = scales[scales_idx][2];
|
half qs_h2 = scales[scales_idx][2];
|
||||||
float qs_f3 = scales[scales_idx][3];
|
half qs_h3 = scales[scales_idx][3];
|
||||||
int nextgroup = offset_k + groupsize;
|
int nextgroup = offset_k + b_q_group_map[offset_k * 2 + 1];
|
||||||
|
|
||||||
// Column result
|
// Column result
|
||||||
|
|
||||||
float block_c[m_count][4] = {};
|
half block_c[m_count][4] = {};
|
||||||
|
|
||||||
// Dequantize groups
|
// Dequantize groups
|
||||||
|
|
||||||
|
@ -218,11 +283,11 @@ __global__ void gemm_half_q_half_kernel
|
||||||
{
|
{
|
||||||
group++;
|
group++;
|
||||||
scales_idx++;
|
scales_idx++;
|
||||||
qs_f0 = scales[scales_idx][0];
|
qs_h0 = scales[scales_idx][0];
|
||||||
qs_f1 = scales[scales_idx][1];
|
qs_h1 = scales[scales_idx][1];
|
||||||
qs_f2 = scales[scales_idx][2];
|
qs_h2 = scales[scales_idx][2];
|
||||||
qs_f3 = scales[scales_idx][3];
|
qs_h3 = scales[scales_idx][3];
|
||||||
nextgroup += groupsize;
|
nextgroup += b_q_group_map[k * 2 + 1];
|
||||||
}
|
}
|
||||||
|
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
|
@ -240,10 +305,11 @@ __global__ void gemm_half_q_half_kernel
|
||||||
|
|
||||||
for (int m = 0; m < m_count; m++)
|
for (int m = 0; m < m_count; m++)
|
||||||
{
|
{
|
||||||
block_c[m][0] = dot22_8_f(dq[0], a_ptr + m * a_stride, block_c[m][0], qs_f0);
|
if constexpr (use_r_weights) { if (!weights[m].as_uint16) continue; }
|
||||||
block_c[m][1] = dot22_8_f(dq[1], a_ptr + m * a_stride, block_c[m][1], qs_f1);
|
block_c[m][0] = dot22_8_h(dq[0], a_ptr + m * a_stride, block_c[m][0], qs_h0);
|
||||||
block_c[m][2] = dot22_8_f(dq[2], a_ptr + m * a_stride, block_c[m][2], qs_f2);
|
block_c[m][1] = dot22_8_h(dq[1], a_ptr + m * a_stride, block_c[m][1], qs_h1);
|
||||||
block_c[m][3] = dot22_8_f(dq[3], a_ptr + m * a_stride, block_c[m][3], qs_f3);
|
block_c[m][2] = dot22_8_h(dq[2], a_ptr + m * a_stride, block_c[m][2], qs_h2);
|
||||||
|
block_c[m][3] = dot22_8_h(dq[3], a_ptr + m * a_stride, block_c[m][3], qs_h3);
|
||||||
}
|
}
|
||||||
a_ptr += 8;
|
a_ptr += 8;
|
||||||
}
|
}
|
||||||
|
@ -256,11 +322,11 @@ __global__ void gemm_half_q_half_kernel
|
||||||
{
|
{
|
||||||
group++;
|
group++;
|
||||||
scales_idx++;
|
scales_idx++;
|
||||||
qs_f0 = scales[scales_idx][0];
|
qs_h0 = scales[scales_idx][0];
|
||||||
qs_f1 = scales[scales_idx][1];
|
qs_h1 = scales[scales_idx][1];
|
||||||
qs_f2 = scales[scales_idx][2];
|
qs_h2 = scales[scales_idx][2];
|
||||||
qs_f3 = scales[scales_idx][3];
|
qs_h3 = scales[scales_idx][3];
|
||||||
nextgroup += groupsize;
|
nextgroup += b_q_group_map[k * 2 + 1];
|
||||||
}
|
}
|
||||||
|
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
|
@ -279,10 +345,11 @@ __global__ void gemm_half_q_half_kernel
|
||||||
|
|
||||||
for (int m = 0; m < m_count; m++)
|
for (int m = 0; m < m_count; m++)
|
||||||
{
|
{
|
||||||
block_c[m][0] = dot22_16_f(dq[0], a_ptr + m * a_stride, block_c[m][0], qs_f0);
|
if constexpr (use_r_weights) { if (!weights[m].as_uint16) continue; }
|
||||||
block_c[m][1] = dot22_16_f(dq[1], a_ptr + m * a_stride, block_c[m][1], qs_f1);
|
block_c[m][0] = dot22_16_h(dq[0], a_ptr + m * a_stride, block_c[m][0], qs_h0);
|
||||||
block_c[m][2] = dot22_16_f(dq[2], a_ptr + m * a_stride, block_c[m][2], qs_f2);
|
block_c[m][1] = dot22_16_h(dq[1], a_ptr + m * a_stride, block_c[m][1], qs_h1);
|
||||||
block_c[m][3] = dot22_16_f(dq[3], a_ptr + m * a_stride, block_c[m][3], qs_f3);
|
block_c[m][2] = dot22_16_h(dq[2], a_ptr + m * a_stride, block_c[m][2], qs_h2);
|
||||||
|
block_c[m][3] = dot22_16_h(dq[3], a_ptr + m * a_stride, block_c[m][3], qs_h3);
|
||||||
}
|
}
|
||||||
a_ptr += 16;
|
a_ptr += 16;
|
||||||
}
|
}
|
||||||
|
@ -295,11 +362,11 @@ __global__ void gemm_half_q_half_kernel
|
||||||
{
|
{
|
||||||
group++;
|
group++;
|
||||||
scales_idx++;
|
scales_idx++;
|
||||||
qs_f0 = scales[scales_idx][0];
|
qs_h0 = scales[scales_idx][0];
|
||||||
qs_f1 = scales[scales_idx][1];
|
qs_h1 = scales[scales_idx][1];
|
||||||
qs_f2 = scales[scales_idx][2];
|
qs_h2 = scales[scales_idx][2];
|
||||||
qs_f3 = scales[scales_idx][3];
|
qs_h3 = scales[scales_idx][3];
|
||||||
nextgroup += groupsize;
|
nextgroup += b_q_group_map[k * 2 + 1];
|
||||||
}
|
}
|
||||||
|
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
|
@ -320,10 +387,11 @@ __global__ void gemm_half_q_half_kernel
|
||||||
|
|
||||||
for (int m = 0; m < m_count; m++)
|
for (int m = 0; m < m_count; m++)
|
||||||
{
|
{
|
||||||
block_c[m][0] = dot22_32_f(dq[0], a_ptr + m * a_stride, block_c[m][0], qs_f0);
|
if constexpr (use_r_weights) { if (!weights[m].as_uint16) continue; }
|
||||||
block_c[m][1] = dot22_32_f(dq[1], a_ptr + m * a_stride, block_c[m][1], qs_f1);
|
block_c[m][0] = dot22_32_h(dq[0], a_ptr + m * a_stride, block_c[m][0], qs_h0);
|
||||||
block_c[m][2] = dot22_32_f(dq[2], a_ptr + m * a_stride, block_c[m][2], qs_f2);
|
block_c[m][1] = dot22_32_h(dq[1], a_ptr + m * a_stride, block_c[m][1], qs_h1);
|
||||||
block_c[m][3] = dot22_32_f(dq[3], a_ptr + m * a_stride, block_c[m][3], qs_f3);
|
block_c[m][2] = dot22_32_h(dq[2], a_ptr + m * a_stride, block_c[m][2], qs_h2);
|
||||||
|
block_c[m][3] = dot22_32_h(dq[3], a_ptr + m * a_stride, block_c[m][3], qs_h3);
|
||||||
}
|
}
|
||||||
a_ptr += 32;
|
a_ptr += 32;
|
||||||
}
|
}
|
||||||
|
@ -337,11 +405,11 @@ __global__ void gemm_half_q_half_kernel
|
||||||
{
|
{
|
||||||
group++;
|
group++;
|
||||||
scales_idx++;
|
scales_idx++;
|
||||||
qs_f0 = scales[scales_idx][0];
|
qs_h0 = scales[scales_idx][0];
|
||||||
qs_f1 = scales[scales_idx][1];
|
qs_h1 = scales[scales_idx][1];
|
||||||
qs_f2 = scales[scales_idx][2];
|
qs_h2 = scales[scales_idx][2];
|
||||||
qs_f3 = scales[scales_idx][3];
|
qs_h3 = scales[scales_idx][3];
|
||||||
nextgroup += groupsize;
|
nextgroup += b_q_group_map[k * 2 + 1];
|
||||||
}
|
}
|
||||||
|
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
|
@ -358,10 +426,11 @@ __global__ void gemm_half_q_half_kernel
|
||||||
|
|
||||||
for (int m = 0; m < m_count; m++)
|
for (int m = 0; m < m_count; m++)
|
||||||
{
|
{
|
||||||
block_c[m][0] = dot22_8_f(dq[0], a_ptr + m * a_stride, block_c[m][0], qs_f0);
|
if constexpr (use_r_weights) { if (!weights[m].as_uint16) continue; }
|
||||||
block_c[m][1] = dot22_8_f(dq[1], a_ptr + m * a_stride, block_c[m][1], qs_f1);
|
block_c[m][0] = dot22_8_h(dq[0], a_ptr + m * a_stride, block_c[m][0], qs_h0);
|
||||||
block_c[m][2] = dot22_8_f(dq[2], a_ptr + m * a_stride, block_c[m][2], qs_f2);
|
block_c[m][1] = dot22_8_h(dq[1], a_ptr + m * a_stride, block_c[m][1], qs_h1);
|
||||||
block_c[m][3] = dot22_8_f(dq[3], a_ptr + m * a_stride, block_c[m][3], qs_f3);
|
block_c[m][2] = dot22_8_h(dq[2], a_ptr + m * a_stride, block_c[m][2], qs_h2);
|
||||||
|
block_c[m][3] = dot22_8_h(dq[3], a_ptr + m * a_stride, block_c[m][3], qs_h3);
|
||||||
}
|
}
|
||||||
a_ptr += 8;
|
a_ptr += 8;
|
||||||
}
|
}
|
||||||
|
@ -374,11 +443,11 @@ __global__ void gemm_half_q_half_kernel
|
||||||
{
|
{
|
||||||
group++;
|
group++;
|
||||||
scales_idx++;
|
scales_idx++;
|
||||||
qs_f0 = scales[scales_idx][0];
|
qs_h0 = scales[scales_idx][0];
|
||||||
qs_f1 = scales[scales_idx][1];
|
qs_h1 = scales[scales_idx][1];
|
||||||
qs_f2 = scales[scales_idx][2];
|
qs_h2 = scales[scales_idx][2];
|
||||||
qs_f3 = scales[scales_idx][3];
|
qs_h3 = scales[scales_idx][3];
|
||||||
nextgroup += groupsize;
|
nextgroup += b_q_group_map[k * 2 + 1];
|
||||||
}
|
}
|
||||||
|
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
|
@ -397,10 +466,11 @@ __global__ void gemm_half_q_half_kernel
|
||||||
|
|
||||||
for (int m = 0; m < m_count; m++)
|
for (int m = 0; m < m_count; m++)
|
||||||
{
|
{
|
||||||
block_c[m][0] = dot22_32_f(dq[0], a_ptr + m * a_stride, block_c[m][0], qs_f0);
|
if constexpr (use_r_weights) { if (!weights[m].as_uint16) continue; }
|
||||||
block_c[m][1] = dot22_32_f(dq[1], a_ptr + m * a_stride, block_c[m][1], qs_f1);
|
block_c[m][0] = dot22_32_h(dq[0], a_ptr + m * a_stride, block_c[m][0], qs_h0);
|
||||||
block_c[m][2] = dot22_32_f(dq[2], a_ptr + m * a_stride, block_c[m][2], qs_f2);
|
block_c[m][1] = dot22_32_h(dq[1], a_ptr + m * a_stride, block_c[m][1], qs_h1);
|
||||||
block_c[m][3] = dot22_32_f(dq[3], a_ptr + m * a_stride, block_c[m][3], qs_f3);
|
block_c[m][2] = dot22_32_h(dq[2], a_ptr + m * a_stride, block_c[m][2], qs_h2);
|
||||||
|
block_c[m][3] = dot22_32_h(dq[3], a_ptr + m * a_stride, block_c[m][3], qs_h3);
|
||||||
}
|
}
|
||||||
a_ptr += 32;
|
a_ptr += 32;
|
||||||
}
|
}
|
||||||
|
@ -413,15 +483,15 @@ __global__ void gemm_half_q_half_kernel
|
||||||
{
|
{
|
||||||
group++;
|
group++;
|
||||||
scales_idx++;
|
scales_idx++;
|
||||||
qs_f0 = scales[scales_idx][0];
|
qs_h0 = scales[scales_idx][0];
|
||||||
qs_f1 = scales[scales_idx][1];
|
qs_h1 = scales[scales_idx][1];
|
||||||
qs_f2 = scales[scales_idx][2];
|
qs_h2 = scales[scales_idx][2];
|
||||||
qs_f3 = scales[scales_idx][3];
|
qs_h3 = scales[scales_idx][3];
|
||||||
nextgroup += groupsize;
|
nextgroup += b_q_group_map[k * 2 + 1];
|
||||||
}
|
}
|
||||||
|
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int j = 0; j < 2; j++)
|
for (int j = 0; j < 1; j++)
|
||||||
{
|
{
|
||||||
int4 load_int4[1];
|
int4 load_int4[1];
|
||||||
load_int4[0] = *((int4*) b_ptr); b_ptr += size_n;
|
load_int4[0] = *((int4*) b_ptr); b_ptr += size_n;
|
||||||
|
@ -434,15 +504,16 @@ __global__ void gemm_half_q_half_kernel
|
||||||
|
|
||||||
for (int m = 0; m < m_count; m++)
|
for (int m = 0; m < m_count; m++)
|
||||||
{
|
{
|
||||||
block_c[m][0] = dot22_16_f(dq[0], a_ptr + m * a_stride, block_c[m][0], qs_f0);
|
if constexpr (use_r_weights) { if (!weights[m].as_uint16) continue; }
|
||||||
block_c[m][1] = dot22_16_f(dq[1], a_ptr + m * a_stride, block_c[m][1], qs_f1);
|
block_c[m][0] = dot22_16_h(dq[0], a_ptr + m * a_stride, block_c[m][0], qs_h0);
|
||||||
block_c[m][2] = dot22_16_f(dq[2], a_ptr + m * a_stride, block_c[m][2], qs_f2);
|
block_c[m][1] = dot22_16_h(dq[1], a_ptr + m * a_stride, block_c[m][1], qs_h1);
|
||||||
block_c[m][3] = dot22_16_f(dq[3], a_ptr + m * a_stride, block_c[m][3], qs_f3);
|
block_c[m][2] = dot22_16_h(dq[2], a_ptr + m * a_stride, block_c[m][2], qs_h2);
|
||||||
|
block_c[m][3] = dot22_16_h(dq[3], a_ptr + m * a_stride, block_c[m][3], qs_h3);
|
||||||
}
|
}
|
||||||
|
|
||||||
a_ptr += 16;
|
a_ptr += 16;
|
||||||
}
|
}
|
||||||
k += 32;
|
k += 16;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Accumulate column sums in c
|
// Accumulate column sums in c
|
||||||
|
@ -450,38 +521,60 @@ __global__ void gemm_half_q_half_kernel
|
||||||
for (int m = 0; m < m_count; m++)
|
for (int m = 0; m < m_count; m++)
|
||||||
{
|
{
|
||||||
half2* out = (half2*)c_.item_ptr(offset_m + m, n);
|
half2* out = (half2*)c_.item_ptr(offset_m + m, n);
|
||||||
half2 result01 = __halves2half2(__float2half_rn(block_c[m][0]), __float2half_rn(block_c[m][1]));
|
half2 result01 = __halves2half2(block_c[m][0], block_c[m][1]);
|
||||||
half2 result23 = __halves2half2(__float2half_rn(block_c[m][2]), __float2half_rn(block_c[m][3]));
|
half2 result23 = __halves2half2(block_c[m][2], block_c[m][3]);
|
||||||
|
|
||||||
|
if constexpr (mul_r_weights)
|
||||||
|
{
|
||||||
|
half2 w_mul2 = __half2half2(weights[m].as_half);
|
||||||
|
result01 = __hmul2(result01, w_mul2);
|
||||||
|
result23 = __hmul2(result23, w_mul2);
|
||||||
|
}
|
||||||
|
|
||||||
atomicAdd(out , result01);
|
atomicAdd(out , result01);
|
||||||
atomicAdd(out + 1, result23);
|
atomicAdd(out + 1, result23);
|
||||||
|
// *out = result01;
|
||||||
|
// *(out + 1) = result23;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fp_gemm_half_q_half_kernel pick_gemm_half_q_half_kernel(bool first_block, const int m_count)
|
template <bool use_r_weights, bool mul_r_weights>
|
||||||
|
struct map_m_count_exl2 {
|
||||||
|
static constexpr fp_gemm_half_q_half_kernel pick_gemm_half_q_half_kernel(const int m_count)
|
||||||
|
{
|
||||||
|
#if EXL2_BLOCK_M_SIZE_MAX >= 1
|
||||||
|
if (m_count == 1) return gemm_half_q_half_kernel<1, use_r_weights, mul_r_weights>;
|
||||||
|
#endif
|
||||||
|
#if EXL2_BLOCK_M_SIZE_MAX >= 2
|
||||||
|
if (m_count == 2) return gemm_half_q_half_kernel<2, use_r_weights, mul_r_weights>;
|
||||||
|
#endif
|
||||||
|
#if EXL2_BLOCK_M_SIZE_MAX >= 3
|
||||||
|
if (m_count == 3) return gemm_half_q_half_kernel<3, use_r_weights, mul_r_weights>;
|
||||||
|
#endif
|
||||||
|
#if EXL2_BLOCK_M_SIZE_MAX >= 4
|
||||||
|
if (m_count == 4) return gemm_half_q_half_kernel<4, use_r_weights, mul_r_weights>;
|
||||||
|
#endif
|
||||||
|
#if EXL2_BLOCK_M_SIZE_MAX >= 5
|
||||||
|
if (m_count == 5) return gemm_half_q_half_kernel<5, use_r_weights, mul_r_weights>;
|
||||||
|
#endif
|
||||||
|
#if EXL2_BLOCK_M_SIZE_MAX >= 6
|
||||||
|
if (m_count == 6) return gemm_half_q_half_kernel<6, use_r_weights, mul_r_weights>;
|
||||||
|
#endif
|
||||||
|
#if EXL2_BLOCK_M_SIZE_MAX >= 7
|
||||||
|
if (m_count == 7) return gemm_half_q_half_kernel<7, use_r_weights, mul_r_weights>;
|
||||||
|
#endif
|
||||||
|
#if EXL2_BLOCK_M_SIZE_MAX >= 8
|
||||||
|
if (m_count == 8) return gemm_half_q_half_kernel<8, use_r_weights, mul_r_weights>;
|
||||||
|
#endif
|
||||||
|
return NULL;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
fp_gemm_half_q_half_kernel pick_gemm_half_q_half_kernel(const int m_count, bool r_weights, bool mul_r_weights)
|
||||||
{
|
{
|
||||||
#if BLOCK_M_SIZE_MAX >= 1
|
if (!r_weights && !mul_r_weights) return map_m_count_exl2<false, false>::pick_gemm_half_q_half_kernel(m_count);
|
||||||
if (m_count == 1) return gemm_half_q_half_kernel<true, 1>;
|
if (!r_weights && mul_r_weights) return map_m_count_exl2<false, true>::pick_gemm_half_q_half_kernel(m_count);
|
||||||
#endif
|
if ( r_weights && !mul_r_weights) return map_m_count_exl2< true, false>::pick_gemm_half_q_half_kernel(m_count);
|
||||||
#if BLOCK_M_SIZE_MAX >= 2
|
if ( r_weights && mul_r_weights) return map_m_count_exl2< true, true>::pick_gemm_half_q_half_kernel(m_count);
|
||||||
if (m_count == 2) return gemm_half_q_half_kernel<true, 2>;
|
|
||||||
#endif
|
|
||||||
#if BLOCK_M_SIZE_MAX >= 3
|
|
||||||
if (m_count == 3) return gemm_half_q_half_kernel<true, 3>;
|
|
||||||
#endif
|
|
||||||
#if BLOCK_M_SIZE_MAX >= 4
|
|
||||||
if (m_count == 4) return gemm_half_q_half_kernel<true, 4>;
|
|
||||||
#endif
|
|
||||||
#if BLOCK_M_SIZE_MAX >= 5
|
|
||||||
if (m_count == 5) return gemm_half_q_half_kernel<true, 5>;
|
|
||||||
#endif
|
|
||||||
#if BLOCK_M_SIZE_MAX >= 6
|
|
||||||
if (m_count == 6) return gemm_half_q_half_kernel<true, 6>;
|
|
||||||
#endif
|
|
||||||
#if BLOCK_M_SIZE_MAX >= 7
|
|
||||||
if (m_count == 7) return gemm_half_q_half_kernel<true, 7>;
|
|
||||||
#endif
|
|
||||||
#if BLOCK_M_SIZE_MAX >= 8
|
|
||||||
if (m_count == 8) return gemm_half_q_half_kernel<true, 8>;
|
|
||||||
#endif
|
|
||||||
return NULL;
|
return NULL;
|
||||||
}
|
}
|
||||||
|
|
|
@ -18,6 +18,15 @@ __forceinline__ __device__ float dot22_8_f(half2(&dq)[4], const half* a_ptr)
|
||||||
return __half2float(__low2half(result)) + __half2float(__high2half(result));
|
return __half2float(__low2half(result)) + __half2float(__high2half(result));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
__forceinline__ __device__ half2 dot22_8_h2(half2(&dq)[4], const half* a_ptr)
|
||||||
|
{
|
||||||
|
half2 result = {};
|
||||||
|
const half2* a2_ptr = (const half2*)a_ptr;
|
||||||
|
#pragma unroll
|
||||||
|
for (int i = 0; i < 4; i++) result = __hfma2(dq[i], *a2_ptr++, result);
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
typedef void (*fp_gemm_half_q_half_gptq_kernel)
|
typedef void (*fp_gemm_half_q_half_gptq_kernel)
|
||||||
(
|
(
|
||||||
const half*,
|
const half*,
|
||||||
|
@ -32,10 +41,12 @@ typedef void (*fp_gemm_half_q_half_gptq_kernel)
|
||||||
const int,
|
const int,
|
||||||
const uint16_t*,
|
const uint16_t*,
|
||||||
const int,
|
const int,
|
||||||
const bool
|
const bool,
|
||||||
|
const half*,
|
||||||
|
const int
|
||||||
);
|
);
|
||||||
|
|
||||||
template <bool first_block, int m_count>
|
template <int m_count, bool use_r_weights, bool mul_r_weights>
|
||||||
__global__ void gemm_half_q_half_gptq_kernel
|
__global__ void gemm_half_q_half_gptq_kernel
|
||||||
(
|
(
|
||||||
const half* __restrict__ a,
|
const half* __restrict__ a,
|
||||||
|
@ -50,7 +61,9 @@ __global__ void gemm_half_q_half_gptq_kernel
|
||||||
const int groupsize,
|
const int groupsize,
|
||||||
const uint16_t* __restrict__ b_q_perm,
|
const uint16_t* __restrict__ b_q_perm,
|
||||||
const int rows_4,
|
const int rows_4,
|
||||||
const bool clear
|
const bool clear,
|
||||||
|
const half* r_weights,
|
||||||
|
const int r_weights_stride
|
||||||
)
|
)
|
||||||
{
|
{
|
||||||
MatrixView_half a_(a, size_m, size_k);
|
MatrixView_half a_(a, size_m, size_k);
|
||||||
|
@ -62,19 +75,35 @@ __global__ void gemm_half_q_half_gptq_kernel
|
||||||
|
|
||||||
// Block
|
// Block
|
||||||
|
|
||||||
int offset_n = blockIdx.x * BLOCK_KN_SIZE * 4;
|
int offset_n = blockIdx.x * GPTQ_BLOCK_KN_SIZE * 4;
|
||||||
int offset_m = blockIdx.y * m_count;
|
int offset_m = blockIdx.y * m_count;
|
||||||
int offset_k = blockIdx.z * BLOCK_KN_SIZE;
|
int offset_k = blockIdx.z * GPTQ_BLOCK_KN_SIZE;
|
||||||
|
|
||||||
int end_n = min(offset_n + BLOCK_KN_SIZE * 4, size_n);
|
int end_n = min(offset_n + GPTQ_BLOCK_KN_SIZE * 4, size_n);
|
||||||
int end_m = min(offset_m + m_count, size_m);
|
int end_m = min(offset_m + m_count, size_m);
|
||||||
int end_k = min(offset_k + BLOCK_KN_SIZE, size_k);
|
int end_k = min(offset_k + GPTQ_BLOCK_KN_SIZE, size_k);
|
||||||
|
|
||||||
int n = offset_n + t * 4;
|
int n = offset_n + t * 4;
|
||||||
|
|
||||||
|
// Read weights
|
||||||
|
|
||||||
|
half_uint16 weights[MAX_Q_GEMM_WEIGHTS];
|
||||||
|
if constexpr (use_r_weights)
|
||||||
|
{
|
||||||
|
uint16_t any_w = 0;
|
||||||
|
const half* w_ptr = r_weights;
|
||||||
|
for (int m = 0; m < m_count; ++m)
|
||||||
|
{
|
||||||
|
weights[m].as_half = *w_ptr;
|
||||||
|
w_ptr += r_weights_stride;
|
||||||
|
any_w |= weights[m].as_uint16;
|
||||||
|
}
|
||||||
|
if (!any_w) return; // Early exit if all weights are zero -- does not zero output (!!!)
|
||||||
|
}
|
||||||
|
|
||||||
// Preload block_a
|
// Preload block_a
|
||||||
|
|
||||||
__shared__ half block_a[m_count][BLOCK_KN_SIZE];
|
__shared__ half block_a[m_count][GPTQ_BLOCK_KN_SIZE];
|
||||||
|
|
||||||
if (offset_k + t < end_k)
|
if (offset_k + t < end_k)
|
||||||
{
|
{
|
||||||
|
@ -113,16 +142,16 @@ __global__ void gemm_half_q_half_gptq_kernel
|
||||||
|
|
||||||
const uint32_t* b_ptr = b_q_weight + qk * size_n + n;
|
const uint32_t* b_ptr = b_q_weight + qk * size_n + n;
|
||||||
const half* a_ptr = &block_a[0][0];
|
const half* a_ptr = &block_a[0][0];
|
||||||
int a_stride = BLOCK_KN_SIZE;
|
int a_stride = GPTQ_BLOCK_KN_SIZE;
|
||||||
|
|
||||||
// Initial group
|
// Initial group
|
||||||
|
|
||||||
int zeros[4];
|
int zeros[4];
|
||||||
float scales[4];
|
half2 scales[4];
|
||||||
half2 z1z16[4][2];
|
half2 z1z16[4][2];
|
||||||
half2 y1y16[4][2];
|
half2 y1y16[4][2];
|
||||||
b_gptq_qzeros_.item4(zeros, group, n);
|
b_gptq_qzeros_.item4(zeros, group, n);
|
||||||
b_gptq_scales_.item4_f(scales, group, n);
|
b_gptq_scales_.item4_h2(scales, group, n);
|
||||||
dequant_4bit_8_prep_zero(zeros[0] + 1, z1z16[0], y1y16[0]);
|
dequant_4bit_8_prep_zero(zeros[0] + 1, z1z16[0], y1y16[0]);
|
||||||
dequant_4bit_8_prep_zero(zeros[1] + 1, z1z16[1], y1y16[1]);
|
dequant_4bit_8_prep_zero(zeros[1] + 1, z1z16[1], y1y16[1]);
|
||||||
dequant_4bit_8_prep_zero(zeros[2] + 1, z1z16[2], y1y16[2]);
|
dequant_4bit_8_prep_zero(zeros[2] + 1, z1z16[2], y1y16[2]);
|
||||||
|
@ -132,7 +161,7 @@ __global__ void gemm_half_q_half_gptq_kernel
|
||||||
|
|
||||||
// Column result
|
// Column result
|
||||||
|
|
||||||
float block_c[m_count][4] = {};
|
half2 block_c[m_count][4] = {};
|
||||||
|
|
||||||
// Dequantize and multiply
|
// Dequantize and multiply
|
||||||
|
|
||||||
|
@ -144,7 +173,7 @@ __global__ void gemm_half_q_half_gptq_kernel
|
||||||
group++;
|
group++;
|
||||||
nextgroup += groupsize;
|
nextgroup += groupsize;
|
||||||
b_gptq_qzeros_.item4(zeros, group, n);
|
b_gptq_qzeros_.item4(zeros, group, n);
|
||||||
b_gptq_scales_.item4_f(scales, group, n);
|
b_gptq_scales_.item4_h2(scales, group, n);
|
||||||
dequant_4bit_8_prep_zero(zeros[0] + 1, z1z16[0], y1y16[0]);
|
dequant_4bit_8_prep_zero(zeros[0] + 1, z1z16[0], y1y16[0]);
|
||||||
dequant_4bit_8_prep_zero(zeros[1] + 1, z1z16[1], y1y16[1]);
|
dequant_4bit_8_prep_zero(zeros[1] + 1, z1z16[1], y1y16[1]);
|
||||||
dequant_4bit_8_prep_zero(zeros[2] + 1, z1z16[2], y1y16[2]);
|
dequant_4bit_8_prep_zero(zeros[2] + 1, z1z16[2], y1y16[2]);
|
||||||
|
@ -166,10 +195,11 @@ __global__ void gemm_half_q_half_gptq_kernel
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int m = 0; m < m_count; m++)
|
for (int m = 0; m < m_count; m++)
|
||||||
{
|
{
|
||||||
block_c[m][0] = fma(dot22_8_f(dq[0], a_ptr + m * a_stride), scales[0], block_c[m][0]);
|
if constexpr (use_r_weights) { if (!weights[m].as_uint16) continue; }
|
||||||
block_c[m][1] = fma(dot22_8_f(dq[1], a_ptr + m * a_stride), scales[1], block_c[m][1]);
|
block_c[m][0] = __hfma2(dot22_8_h2(dq[0], a_ptr + m * a_stride), scales[0], block_c[m][0]);
|
||||||
block_c[m][2] = fma(dot22_8_f(dq[2], a_ptr + m * a_stride), scales[2], block_c[m][2]);
|
block_c[m][1] = __hfma2(dot22_8_h2(dq[1], a_ptr + m * a_stride), scales[1], block_c[m][1]);
|
||||||
block_c[m][3] = fma(dot22_8_f(dq[3], a_ptr + m * a_stride), scales[3], block_c[m][3]);
|
block_c[m][2] = __hfma2(dot22_8_h2(dq[2], a_ptr + m * a_stride), scales[2], block_c[m][2]);
|
||||||
|
block_c[m][3] = __hfma2(dot22_8_h2(dq[3], a_ptr + m * a_stride), scales[3], block_c[m][3]);
|
||||||
}
|
}
|
||||||
|
|
||||||
b_ptr += size_n;
|
b_ptr += size_n;
|
||||||
|
@ -182,38 +212,62 @@ __global__ void gemm_half_q_half_gptq_kernel
|
||||||
for (int m = 0; m < m_count; m++)
|
for (int m = 0; m < m_count; m++)
|
||||||
{
|
{
|
||||||
half2 *out = (half2*) c_.item_ptr(offset_m + m, n);
|
half2 *out = (half2*) c_.item_ptr(offset_m + m, n);
|
||||||
half2 result01 = __halves2half2(__float2half_rn(block_c[m][0]), __float2half_rn(block_c[m][1]));
|
half result0 = __hadd(__low2half(block_c[m][0]), __high2half(block_c[m][0]));
|
||||||
half2 result23 = __halves2half2(__float2half_rn(block_c[m][2]), __float2half_rn(block_c[m][3]));
|
half result1 = __hadd(__low2half(block_c[m][1]), __high2half(block_c[m][1]));
|
||||||
|
half result2 = __hadd(__low2half(block_c[m][2]), __high2half(block_c[m][2]));
|
||||||
|
half result3 = __hadd(__low2half(block_c[m][3]), __high2half(block_c[m][3]));
|
||||||
|
half2 result01 = __halves2half2(result0, result1);
|
||||||
|
half2 result23 = __halves2half2(result2, result3);
|
||||||
|
|
||||||
|
if constexpr (mul_r_weights)
|
||||||
|
{
|
||||||
|
half2 w_mul2 = __half2half2(weights[m].as_half);
|
||||||
|
result01 = __hmul2(result01, w_mul2);
|
||||||
|
result23 = __hmul2(result23, w_mul2);
|
||||||
|
}
|
||||||
|
|
||||||
atomicAdd(out , result01);
|
atomicAdd(out , result01);
|
||||||
atomicAdd(out + 1, result23);
|
atomicAdd(out + 1, result23);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fp_gemm_half_q_half_gptq_kernel pick_gemm_half_q_half_gptq_kernel(bool first_block, const int m_count)
|
template <bool use_r_weights, bool mul_r_weights>
|
||||||
|
struct map_m_count_gptq {
|
||||||
|
static constexpr fp_gemm_half_q_half_gptq_kernel pick_gemm_half_q_half_gptq_kernel(int m_count)
|
||||||
|
{
|
||||||
|
#if GPTQ_BLOCK_M_SIZE_MAX >= 1
|
||||||
|
if (m_count == 1) return gemm_half_q_half_gptq_kernel<1, use_r_weights, mul_r_weights>;
|
||||||
|
#endif
|
||||||
|
#if GPTQ_BLOCK_M_SIZE_MAX >= 2
|
||||||
|
if (m_count == 2) return gemm_half_q_half_gptq_kernel<2, use_r_weights, mul_r_weights>;
|
||||||
|
#endif
|
||||||
|
#if GPTQ_BLOCK_M_SIZE_MAX >= 3
|
||||||
|
if (m_count == 3) return gemm_half_q_half_gptq_kernel<3, use_r_weights, mul_r_weights>;
|
||||||
|
#endif
|
||||||
|
#if GPTQ_BLOCK_M_SIZE_MAX >= 4
|
||||||
|
if (m_count == 4) return gemm_half_q_half_gptq_kernel<4, use_r_weights, mul_r_weights>;
|
||||||
|
#endif
|
||||||
|
#if GPTQ_BLOCK_M_SIZE_MAX >= 5
|
||||||
|
if (m_count == 5) return gemm_half_q_half_gptq_kernel<5, use_r_weights, mul_r_weights>;
|
||||||
|
#endif
|
||||||
|
#if GPTQ_BLOCK_M_SIZE_MAX >= 6
|
||||||
|
if (m_count == 6) return gemm_half_q_half_gptq_kernel<6, use_r_weights, mul_r_weights>;
|
||||||
|
#endif
|
||||||
|
#if GPTQ_BLOCK_M_SIZE_MAX >= 7
|
||||||
|
if (m_count == 7) return gemm_half_q_half_gptq_kernel<7, use_r_weights, mul_r_weights>;
|
||||||
|
#endif
|
||||||
|
#if GPTQ_BLOCK_M_SIZE_MAX >= 8
|
||||||
|
if (m_count == 8) return gemm_half_q_half_gptq_kernel<8, use_r_weights, mul_r_weights>;
|
||||||
|
#endif
|
||||||
|
return NULL;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
fp_gemm_half_q_half_gptq_kernel pick_gemm_half_q_half_gptq_kernel(const int m_count, bool r_weights, bool mul_r_weights)
|
||||||
{
|
{
|
||||||
#if BLOCK_M_SIZE_MAX >= 1
|
if (!r_weights && !mul_r_weights) return map_m_count_gptq<false, false>::pick_gemm_half_q_half_gptq_kernel(m_count);
|
||||||
if (m_count == 1) return gemm_half_q_half_gptq_kernel<true, 1>;
|
if (!r_weights && mul_r_weights) return map_m_count_gptq<false, true>::pick_gemm_half_q_half_gptq_kernel(m_count);
|
||||||
#endif
|
if ( r_weights && !mul_r_weights) return map_m_count_gptq< true, false>::pick_gemm_half_q_half_gptq_kernel(m_count);
|
||||||
#if BLOCK_M_SIZE_MAX >= 2
|
if ( r_weights && mul_r_weights) return map_m_count_gptq< true, true>::pick_gemm_half_q_half_gptq_kernel(m_count);
|
||||||
if (m_count == 2) return gemm_half_q_half_gptq_kernel<true, 2>;
|
|
||||||
#endif
|
|
||||||
#if BLOCK_M_SIZE_MAX >= 3
|
|
||||||
if (m_count == 3) return gemm_half_q_half_gptq_kernel<true, 3>;
|
|
||||||
#endif
|
|
||||||
#if BLOCK_M_SIZE_MAX >= 4
|
|
||||||
if (m_count == 4) return gemm_half_q_half_gptq_kernel<true, 4>;
|
|
||||||
#endif
|
|
||||||
#if BLOCK_M_SIZE_MAX >= 5
|
|
||||||
if (m_count == 5) return gemm_half_q_half_gptq_kernel<true, 5>;
|
|
||||||
#endif
|
|
||||||
#if BLOCK_M_SIZE_MAX >= 6
|
|
||||||
if (m_count == 6) return gemm_half_q_half_gptq_kernel<true, 6>;
|
|
||||||
#endif
|
|
||||||
#if BLOCK_M_SIZE_MAX >= 7
|
|
||||||
if (m_count == 7) return gemm_half_q_half_gptq_kernel<true, 7>;
|
|
||||||
#endif
|
|
||||||
#if BLOCK_M_SIZE_MAX >= 8
|
|
||||||
if (m_count == 8) return gemm_half_q_half_gptq_kernel<true, 8>;
|
|
||||||
#endif
|
|
||||||
return NULL;
|
return NULL;
|
||||||
}
|
}
|
||||||
|
|
|
@ -57,6 +57,7 @@ QMatrix::QMatrix
|
||||||
uint32_t* _q_scale,
|
uint32_t* _q_scale,
|
||||||
half* _q_scale_max,
|
half* _q_scale_max,
|
||||||
uint16_t* _q_groups,
|
uint16_t* _q_groups,
|
||||||
|
uint16_t* _q_group_map,
|
||||||
|
|
||||||
uint32_t* _gptq_qzeros,
|
uint32_t* _gptq_qzeros,
|
||||||
half* _gptq_scales,
|
half* _gptq_scales,
|
||||||
|
@ -80,13 +81,17 @@ QMatrix::QMatrix
|
||||||
cuda_q_scale = _q_scale;
|
cuda_q_scale = _q_scale;
|
||||||
cuda_q_scale_max = _q_scale_max;
|
cuda_q_scale_max = _q_scale_max;
|
||||||
cuda_q_groups = _q_groups;
|
cuda_q_groups = _q_groups;
|
||||||
|
cuda_q_group_map = _q_group_map;
|
||||||
cuda_gptq_qzeros = _gptq_qzeros;
|
cuda_gptq_qzeros = _gptq_qzeros;
|
||||||
cuda_gptq_scales = _gptq_scales;
|
cuda_gptq_scales = _gptq_scales;
|
||||||
|
|
||||||
is_gptq = (_gptq_qzeros != NULL);
|
is_gptq = (_gptq_qzeros != NULL);
|
||||||
|
|
||||||
groupsize = 1;
|
if (is_gptq)
|
||||||
while (groupsize * groups < height) groupsize *= 2;
|
{
|
||||||
|
gptq_groupsize = 1;
|
||||||
|
while (gptq_groupsize * groups < height) gptq_groupsize *= 2;
|
||||||
|
}
|
||||||
|
|
||||||
// Create group map
|
// Create group map
|
||||||
|
|
||||||
|
@ -102,15 +107,26 @@ QMatrix::QMatrix
|
||||||
uint16_t* cpu_q_groups = (uint16_t*)calloc(groups * 2, sizeof(uint16_t));
|
uint16_t* cpu_q_groups = (uint16_t*)calloc(groups * 2, sizeof(uint16_t));
|
||||||
cudaMemcpy(cpu_q_groups, cuda_q_groups, groups * 2 * sizeof(uint16_t), cudaMemcpyDeviceToHost);
|
cudaMemcpy(cpu_q_groups, cuda_q_groups, groups * 2 * sizeof(uint16_t), cudaMemcpyDeviceToHost);
|
||||||
|
|
||||||
|
int row = 0;
|
||||||
for (int i = 0; i < groups; i++)
|
for (int i = 0; i < groups; i++)
|
||||||
{
|
{
|
||||||
int bits = cpu_q_groups[i * 2];
|
int bits = cpu_q_groups[i * 2];
|
||||||
if (bits == 8) rows_8 += groupsize;
|
|
||||||
if (bits == 6) rows_6 += groupsize;
|
int rows;
|
||||||
if (bits == 5) rows_5 += groupsize;
|
if (i < groups - 1)
|
||||||
if (bits == 4) rows_4 += groupsize;
|
{
|
||||||
if (bits == 3) rows_3 += groupsize;
|
int qrows = cpu_q_groups[i * 2 + 3] - cpu_q_groups[i * 2 + 1];
|
||||||
if (bits == 2) rows_2 += groupsize;
|
rows = qrows * 32 / bits;
|
||||||
|
}
|
||||||
|
else rows = height - row;
|
||||||
|
|
||||||
|
if (bits == 8) rows_8 += rows;
|
||||||
|
if (bits == 6) rows_6 += rows;
|
||||||
|
if (bits == 5) rows_5 += rows;
|
||||||
|
if (bits == 4) rows_4 += rows;
|
||||||
|
if (bits == 3) rows_3 += rows;
|
||||||
|
if (bits == 2) rows_2 += rows;
|
||||||
|
row += rows;
|
||||||
}
|
}
|
||||||
|
|
||||||
free(cpu_q_groups);
|
free(cpu_q_groups);
|
||||||
|
@ -138,6 +154,13 @@ QMatrix::QMatrix
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// DBGI(rows_8);
|
||||||
|
// DBGI(rows_6);
|
||||||
|
// DBGI(rows_5);
|
||||||
|
// DBGI(rows_4);
|
||||||
|
// DBGI(rows_3);
|
||||||
|
// DBGI(rows_2);
|
||||||
|
|
||||||
// Shuffle quantized data
|
// Shuffle quantized data
|
||||||
|
|
||||||
dim3 blockDim, gridDim;
|
dim3 blockDim, gridDim;
|
||||||
|
@ -283,10 +306,10 @@ __global__ void reconstruct_kernel
|
||||||
const uint16_t* __restrict__ b_q_perm,
|
const uint16_t* __restrict__ b_q_perm,
|
||||||
const uint32_t* __restrict__ b_q_scale,
|
const uint32_t* __restrict__ b_q_scale,
|
||||||
const half* __restrict__ b_q_scale_max,
|
const half* __restrict__ b_q_scale_max,
|
||||||
//const uint16_t* __restrict__ b_q_groups,
|
const uint16_t* __restrict__ b_q_group_map,
|
||||||
const int size_k,
|
const int size_k,
|
||||||
const int size_n,
|
const int size_n,
|
||||||
const int groupsize,
|
//const int groupsize,
|
||||||
const int groups,
|
const int groups,
|
||||||
half* __restrict__ b,
|
half* __restrict__ b,
|
||||||
const int rows_8,
|
const int rows_8,
|
||||||
|
@ -317,7 +340,8 @@ __global__ void reconstruct_kernel
|
||||||
|
|
||||||
// Find initial group
|
// Find initial group
|
||||||
|
|
||||||
int group = offset_k / groupsize;
|
// int group = offset_k / groupsize;
|
||||||
|
int group = b_q_group_map[offset_k * 2];
|
||||||
|
|
||||||
int pre_rows_8 = min(rows_8, offset_k);
|
int pre_rows_8 = min(rows_8, offset_k);
|
||||||
int pre_rows_6 = offset_k > rows_8 ? min(rows_6, offset_k) - rows_8 : 0;
|
int pre_rows_6 = offset_k > rows_8 ? min(rows_6, offset_k) - rows_8 : 0;
|
||||||
|
@ -337,7 +361,7 @@ __global__ void reconstruct_kernel
|
||||||
|
|
||||||
half qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]);
|
half qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]);
|
||||||
half2 qs_h2 = __halves2half2(qs_h, qs_h);
|
half2 qs_h2 = __halves2half2(qs_h, qs_h);
|
||||||
int nextgroup = offset_k + groupsize;
|
int nextgroup = offset_k + b_q_group_map[offset_k * 2 + 1];
|
||||||
|
|
||||||
int end_k = min(offset_k + BLOCK_KN_SIZE, size_k);
|
int end_k = min(offset_k + BLOCK_KN_SIZE, size_k);
|
||||||
int k = offset_k;
|
int k = offset_k;
|
||||||
|
@ -347,7 +371,7 @@ __global__ void reconstruct_kernel
|
||||||
|
|
||||||
while (k < rows_8 && k < end_k)
|
while (k < rows_8 && k < end_k)
|
||||||
{
|
{
|
||||||
if (k == nextgroup) { group++; qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]); nextgroup += groupsize; qs_h2 = __halves2half2(qs_h, qs_h); }
|
if (k == nextgroup) { group++; qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]); nextgroup += b_q_group_map[k * 2 + 1]; qs_h2 = __halves2half2(qs_h, qs_h); }
|
||||||
for (int p = 0; p < 4; p++)
|
for (int p = 0; p < 4; p++)
|
||||||
{
|
{
|
||||||
half2 dq[4];
|
half2 dq[4];
|
||||||
|
@ -363,7 +387,7 @@ __global__ void reconstruct_kernel
|
||||||
|
|
||||||
while (k < rows_6 && k < end_k)
|
while (k < rows_6 && k < end_k)
|
||||||
{
|
{
|
||||||
if (k == nextgroup) { group++; qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]); nextgroup += groupsize; qs_h2 = __halves2half2(qs_h, qs_h); }
|
if (k == nextgroup) { group++; qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]); nextgroup += b_q_group_map[k * 2 + 1]; qs_h2 = __halves2half2(qs_h, qs_h); }
|
||||||
for (int p = 0; p < 2; p++)
|
for (int p = 0; p < 2; p++)
|
||||||
{
|
{
|
||||||
half2 dq[8];
|
half2 dq[8];
|
||||||
|
@ -380,7 +404,7 @@ __global__ void reconstruct_kernel
|
||||||
|
|
||||||
while (k < rows_5 && k < end_k)
|
while (k < rows_5 && k < end_k)
|
||||||
{
|
{
|
||||||
if (k == nextgroup) { group++; qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]); nextgroup += groupsize; qs_h2 = __halves2half2(qs_h, qs_h); }
|
if (k == nextgroup) { group++; qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]); nextgroup += b_q_group_map[k * 2 + 1]; qs_h2 = __halves2half2(qs_h, qs_h); }
|
||||||
for (int p = 0; p < 1; p++)
|
for (int p = 0; p < 1; p++)
|
||||||
{
|
{
|
||||||
half2 dq[16];
|
half2 dq[16];
|
||||||
|
@ -399,7 +423,7 @@ __global__ void reconstruct_kernel
|
||||||
|
|
||||||
while (k < rows_4 && k < end_k)
|
while (k < rows_4 && k < end_k)
|
||||||
{
|
{
|
||||||
if (k == nextgroup) { group++; qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]); nextgroup += groupsize; qs_h2 = __halves2half2(qs_h, qs_h); }
|
if (k == nextgroup) { group++; qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]); nextgroup += b_q_group_map[k * 2 + 1]; qs_h2 = __halves2half2(qs_h, qs_h); }
|
||||||
for (int p = 0; p < 4; p++)
|
for (int p = 0; p < 4; p++)
|
||||||
{
|
{
|
||||||
half2 dq[4];
|
half2 dq[4];
|
||||||
|
@ -414,7 +438,7 @@ __global__ void reconstruct_kernel
|
||||||
|
|
||||||
while (k < rows_3 && k < end_k)
|
while (k < rows_3 && k < end_k)
|
||||||
{
|
{
|
||||||
if (k == nextgroup) { group++; qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]); nextgroup += groupsize; qs_h2 = __halves2half2(qs_h, qs_h); }
|
if (k == nextgroup) { group++; qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]); nextgroup += b_q_group_map[k * 2 + 1]; qs_h2 = __halves2half2(qs_h, qs_h); }
|
||||||
for (int p = 0; p < 1; p++)
|
for (int p = 0; p < 1; p++)
|
||||||
{
|
{
|
||||||
half2 dq[16];
|
half2 dq[16];
|
||||||
|
@ -431,8 +455,8 @@ __global__ void reconstruct_kernel
|
||||||
|
|
||||||
while (k < rows_2 && k < end_k)
|
while (k < rows_2 && k < end_k)
|
||||||
{
|
{
|
||||||
if (k == nextgroup) { group++; qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]); nextgroup += groupsize; qs_h2 = __halves2half2(qs_h, qs_h); }
|
if (k == nextgroup) { group++; qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]); nextgroup += b_q_group_map[k * 2 + 1]; qs_h2 = __halves2half2(qs_h, qs_h); }
|
||||||
for (int p = 0; p < 2; p++)
|
for (int p = 0; p < 1; p++)
|
||||||
{
|
{
|
||||||
half2 dq[8];
|
half2 dq[8];
|
||||||
uint32_t q_0 = *b_ptr; b_ptr += size_n;
|
uint32_t q_0 = *b_ptr; b_ptr += size_n;
|
||||||
|
@ -441,7 +465,7 @@ __global__ void reconstruct_kernel
|
||||||
half* dqh = (half*) dq;
|
half* dqh = (half*) dq;
|
||||||
for (int j = 0; j < 16; j++) b_.set(perm[lk++], n, dqh[j]);
|
for (int j = 0; j < 16; j++) b_.set(perm[lk++], n, dqh[j]);
|
||||||
}
|
}
|
||||||
k += 32;
|
k += 16;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -461,10 +485,10 @@ void QMatrix::reconstruct(half* out)
|
||||||
cuda_q_perm,
|
cuda_q_perm,
|
||||||
cuda_q_scale,
|
cuda_q_scale,
|
||||||
cuda_q_scale_max,
|
cuda_q_scale_max,
|
||||||
//cuda_q_groups,
|
cuda_q_group_map,
|
||||||
height,
|
height,
|
||||||
width,
|
width,
|
||||||
groupsize,
|
//groupsize,
|
||||||
groups,
|
groups,
|
||||||
out,
|
out,
|
||||||
rows_8,
|
rows_8,
|
||||||
|
@ -487,7 +511,7 @@ void QMatrix::reconstruct(half* out)
|
||||||
//const uint16_t* __restrict__ b_q_groups,
|
//const uint16_t* __restrict__ b_q_groups,
|
||||||
height,
|
height,
|
||||||
width,
|
width,
|
||||||
groupsize,
|
gptq_groupsize,
|
||||||
groups,
|
groups,
|
||||||
out,
|
out,
|
||||||
rows_4
|
rows_4
|
||||||
|
|
|
@ -18,7 +18,7 @@ public:
|
||||||
int height;
|
int height;
|
||||||
int width;
|
int width;
|
||||||
int groups;
|
int groups;
|
||||||
int groupsize;
|
int gptq_groupsize;
|
||||||
|
|
||||||
int rows_8;
|
int rows_8;
|
||||||
int rows_6;
|
int rows_6;
|
||||||
|
@ -33,6 +33,7 @@ public:
|
||||||
uint32_t* cuda_q_scale = NULL;
|
uint32_t* cuda_q_scale = NULL;
|
||||||
half* cuda_q_scale_max = NULL;
|
half* cuda_q_scale_max = NULL;
|
||||||
uint16_t* cuda_q_groups = NULL;
|
uint16_t* cuda_q_groups = NULL;
|
||||||
|
uint16_t* cuda_q_group_map = NULL;
|
||||||
uint32_t* cuda_gptq_qzeros = NULL;
|
uint32_t* cuda_gptq_qzeros = NULL;
|
||||||
half* cuda_gptq_scales = NULL;
|
half* cuda_gptq_scales = NULL;
|
||||||
|
|
||||||
|
@ -53,6 +54,7 @@ public:
|
||||||
uint32_t* _q_scale,
|
uint32_t* _q_scale,
|
||||||
half* _q_scale_max,
|
half* _q_scale_max,
|
||||||
uint16_t* _q_groups,
|
uint16_t* _q_groups,
|
||||||
|
uint16_t* _q_group_map,
|
||||||
|
|
||||||
uint32_t* _gptq_qzeros,
|
uint32_t* _gptq_qzeros,
|
||||||
half* _gptq_scales,
|
half* _gptq_scales,
|
||||||
|
|
|
@ -7,6 +7,7 @@ union half2_uint32
|
||||||
half2 as_half2;
|
half2 as_half2;
|
||||||
__device__ half2_uint32(uint32_t val) : as_uint32(val) {}
|
__device__ half2_uint32(uint32_t val) : as_uint32(val) {}
|
||||||
__device__ half2_uint32(half2 val) : as_half2(val) {}
|
__device__ half2_uint32(half2 val) : as_half2(val) {}
|
||||||
|
__device__ half2_uint32() : as_uint32(0) {}
|
||||||
};
|
};
|
||||||
|
|
||||||
union half_uint16
|
union half_uint16
|
||||||
|
@ -15,6 +16,7 @@ union half_uint16
|
||||||
half as_half;
|
half as_half;
|
||||||
__device__ half_uint16(uint16_t val) : as_uint16(val) {}
|
__device__ half_uint16(uint16_t val) : as_uint16(val) {}
|
||||||
__device__ half_uint16(half val) : as_half(val) {}
|
__device__ half_uint16(half val) : as_half(val) {}
|
||||||
|
__device__ half_uint16() : as_uint16(0) {}
|
||||||
};
|
};
|
||||||
|
|
||||||
// Max_scale premultiplied by 1/256
|
// Max_scale premultiplied by 1/256
|
||||||
|
|
|
@ -1,3 +1,11 @@
|
||||||
|
#ifndef _util_cuh
|
||||||
|
#define _util_cuh
|
||||||
|
|
||||||
|
#include <cuda_runtime.h>
|
||||||
|
#include <cuda_fp16.h>
|
||||||
|
#include <cstdint>
|
||||||
|
#include <cstdio>
|
||||||
|
#include <ATen/cuda/CUDAContext.h>
|
||||||
|
|
||||||
#define DIVIDE(x, size) (((x) + (size) - 1) / (size))
|
#define DIVIDE(x, size) (((x) + (size) - 1) / (size))
|
||||||
|
|
||||||
|
@ -40,3 +48,7 @@ inline void gpu_assert(cudaError_t code, const char *file, int line, bool abort=
|
||||||
if (abort) exit(code);
|
if (abort) exit(code);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void print_global_mem(const half* ptr, int rows, int columns, int stride);
|
||||||
|
|
||||||
|
#endif
|
|
@ -31,6 +31,7 @@ uintptr_t make_q_matrix
|
||||||
torch::Tensor q_scale,
|
torch::Tensor q_scale,
|
||||||
torch::Tensor q_scale_max,
|
torch::Tensor q_scale_max,
|
||||||
torch::Tensor q_groups,
|
torch::Tensor q_groups,
|
||||||
|
torch::Tensor q_group_map,
|
||||||
torch::Tensor gptq_qzeros,
|
torch::Tensor gptq_qzeros,
|
||||||
torch::Tensor gptq_scales,
|
torch::Tensor gptq_scales,
|
||||||
torch::Tensor gptq_g_idx,
|
torch::Tensor gptq_g_idx,
|
||||||
|
@ -43,6 +44,7 @@ uintptr_t make_q_matrix
|
||||||
TORCH_CHECK_DTYPE_OPT(q_scale, kInt);
|
TORCH_CHECK_DTYPE_OPT(q_scale, kInt);
|
||||||
TORCH_CHECK_DTYPE_OPT(q_scale_max, kHalf);
|
TORCH_CHECK_DTYPE_OPT(q_scale_max, kHalf);
|
||||||
TORCH_CHECK_DTYPE_OPT(q_groups, kShort);
|
TORCH_CHECK_DTYPE_OPT(q_groups, kShort);
|
||||||
|
TORCH_CHECK_DTYPE_OPT(q_group_map, kShort);
|
||||||
TORCH_CHECK_DTYPE_OPT(gptq_qzeros, kInt);
|
TORCH_CHECK_DTYPE_OPT(gptq_qzeros, kInt);
|
||||||
TORCH_CHECK_DTYPE_OPT(gptq_scales, kHalf);
|
TORCH_CHECK_DTYPE_OPT(gptq_scales, kHalf);
|
||||||
TORCH_CHECK_DTYPE_OPT(gptq_g_idx, kInt);
|
TORCH_CHECK_DTYPE_OPT(gptq_g_idx, kInt);
|
||||||
|
@ -83,12 +85,15 @@ uintptr_t make_q_matrix
|
||||||
q_scale.device().is_meta() ? NULL : (uint32_t*) q_scale.data_ptr(),
|
q_scale.device().is_meta() ? NULL : (uint32_t*) q_scale.data_ptr(),
|
||||||
q_scale_max.device().is_meta() ? NULL : (half*) q_scale_max.data_ptr(),
|
q_scale_max.device().is_meta() ? NULL : (half*) q_scale_max.data_ptr(),
|
||||||
q_groups.device().is_meta() ? NULL : (uint16_t*) q_groups.data_ptr(),
|
q_groups.device().is_meta() ? NULL : (uint16_t*) q_groups.data_ptr(),
|
||||||
|
q_group_map.device().is_meta() ? NULL : (uint16_t*) q_group_map.data_ptr(),
|
||||||
gptq_qzeros.device().is_meta() ? NULL : (uint32_t*) gptq_qzeros.data_ptr(),
|
gptq_qzeros.device().is_meta() ? NULL : (uint32_t*) gptq_qzeros.data_ptr(),
|
||||||
gptq_scales.device().is_meta() ? NULL : (half*) gptq_scales.data_ptr(),
|
gptq_scales.device().is_meta() ? NULL : (half*) gptq_scales.data_ptr(),
|
||||||
gptq_g_idx.device().is_meta() ? NULL : (uint32_t*) gptq_g_idx.data_ptr(),
|
gptq_g_idx.device().is_meta() ? NULL : (uint32_t*) gptq_g_idx.data_ptr(),
|
||||||
(half*) temp_dq.data_ptr()
|
(half*) temp_dq.data_ptr()
|
||||||
);
|
);
|
||||||
|
|
||||||
|
if (m->failed) throw std::runtime_error("CUDA out of memory");
|
||||||
|
|
||||||
return reinterpret_cast<uintptr_t> (m);
|
return reinterpret_cast<uintptr_t> (m);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -32,10 +32,10 @@ def fresh_cache():
|
||||||
current_value = huggingface_hub.constants.HUGGINGFACE_HUB_CACHE
|
current_value = huggingface_hub.constants.HUGGINGFACE_HUB_CACHE
|
||||||
huggingface_hub.constants.HUGGINGFACE_HUB_CACHE = d
|
huggingface_hub.constants.HUGGINGFACE_HUB_CACHE = d
|
||||||
text_generation_server.utils.hub.HUGGINGFACE_HUB_CACHE = d
|
text_generation_server.utils.hub.HUGGINGFACE_HUB_CACHE = d
|
||||||
os.environ['HUGGINGFACE_HUB_CACHE'] = d
|
os.environ["HUGGINGFACE_HUB_CACHE"] = d
|
||||||
yield
|
yield
|
||||||
huggingface_hub.constants.HUGGINGFACE_HUB_CACHE = current_value
|
huggingface_hub.constants.HUGGINGFACE_HUB_CACHE = current_value
|
||||||
os.environ['HUGGINGFACE_HUB_CACHE'] = current_value
|
os.environ["HUGGINGFACE_HUB_CACHE"] = current_value
|
||||||
text_generation_server.utils.hub.HUGGINGFACE_HUB_CACHE = current_value
|
text_generation_server.utils.hub.HUGGINGFACE_HUB_CACHE = current_value
|
||||||
|
|
||||||
|
|
||||||
|
@ -47,7 +47,7 @@ def prefetched():
|
||||||
revision="main",
|
revision="main",
|
||||||
local_files_only=False,
|
local_files_only=False,
|
||||||
repo_type="model",
|
repo_type="model",
|
||||||
allow_patterns=["*.safetensors"]
|
allow_patterns=["*.safetensors"],
|
||||||
)
|
)
|
||||||
yield model_id
|
yield model_id
|
||||||
|
|
||||||
|
@ -61,7 +61,7 @@ def test_weight_hub_files_offline_error(offline, fresh_cache):
|
||||||
def test_weight_hub_files_offline_ok(prefetched, offline):
|
def test_weight_hub_files_offline_ok(prefetched, offline):
|
||||||
# If the model is prefetched then we should be able to get the weight files from local cache
|
# If the model is prefetched then we should be able to get the weight files from local cache
|
||||||
filenames = weight_hub_files(prefetched)
|
filenames = weight_hub_files(prefetched)
|
||||||
assert filenames == ['model.safetensors']
|
assert filenames == ["model.safetensors"]
|
||||||
|
|
||||||
|
|
||||||
def test_weight_hub_files():
|
def test_weight_hub_files():
|
||||||
|
|
|
@ -71,7 +71,7 @@ def _load_multi_mqa_gptq(
|
||||||
|
|
||||||
g_idx = weights.get_tensor(f"{prefix}.c_attn.g_idx")
|
g_idx = weights.get_tensor(f"{prefix}.c_attn.g_idx")
|
||||||
g_idx = g_idx.to(device=weights.device)
|
g_idx = g_idx.to(device=weights.device)
|
||||||
bits, groupsize = weights._get_gptq_params()
|
bits, groupsize, _ = weights._get_gptq_params()
|
||||||
|
|
||||||
from text_generation_server.utils.layers import HAS_EXLLAMA
|
from text_generation_server.utils.layers import HAS_EXLLAMA
|
||||||
|
|
||||||
|
|
|
@ -27,6 +27,32 @@ def ext_gemm_half_q_half(x, q_handle, q4_width, force_cuda):
|
||||||
return output.view(output_shape)
|
return output.view(output_shape)
|
||||||
|
|
||||||
|
|
||||||
|
# Group map needed for irregular group sizes
|
||||||
|
|
||||||
|
|
||||||
|
def make_group_map(q_groups, num_qrows):
|
||||||
|
|
||||||
|
gr = q_groups.tolist()
|
||||||
|
group_map = []
|
||||||
|
num_groups = len(gr) // 2
|
||||||
|
|
||||||
|
for i in range(num_groups):
|
||||||
|
bits = gr[i * 2]
|
||||||
|
if i < num_groups - 1:
|
||||||
|
qrows = gr[i * 2 + 3] - gr[i * 2 + 1]
|
||||||
|
else:
|
||||||
|
qrows = num_qrows - gr[i * 2 + 1]
|
||||||
|
rows = qrows * 32 // bits
|
||||||
|
for j in range(rows):
|
||||||
|
group_map += [i]
|
||||||
|
group_map += [rows - j]
|
||||||
|
|
||||||
|
return torch.tensor(group_map, dtype=torch.short, device=q_groups.device)
|
||||||
|
|
||||||
|
|
||||||
|
# Create Q matrix
|
||||||
|
|
||||||
|
|
||||||
def ext_make_q_matrix(w: dict, temp_dq, key: str = None):
|
def ext_make_q_matrix(w: dict, temp_dq, key: str = None):
|
||||||
"""
|
"""
|
||||||
Create Q matrix
|
Create Q matrix
|
||||||
|
@ -37,6 +63,10 @@ def ext_make_q_matrix(w: dict, temp_dq, key: str = None):
|
||||||
w["q_scale_max"] /= 256
|
w["q_scale_max"] /= 256
|
||||||
w["q_perm"] = w["q_perm"].short()
|
w["q_perm"] = w["q_perm"].short()
|
||||||
w["q_invperm"] = w["q_invperm"].short()
|
w["q_invperm"] = w["q_invperm"].short()
|
||||||
|
|
||||||
|
if "q_group_map" not in w:
|
||||||
|
w["q_group_map"] = make_group_map(w["q_groups"], w["q_weight"].shape[0])
|
||||||
|
|
||||||
return make_q_matrix(
|
return make_q_matrix(
|
||||||
w["q_weight"],
|
w["q_weight"],
|
||||||
w["q_perm"],
|
w["q_perm"],
|
||||||
|
@ -44,6 +74,7 @@ def ext_make_q_matrix(w: dict, temp_dq, key: str = None):
|
||||||
w["q_scale"],
|
w["q_scale"],
|
||||||
w["q_scale_max"],
|
w["q_scale_max"],
|
||||||
w["q_groups"],
|
w["q_groups"],
|
||||||
|
w["q_group_map"],
|
||||||
none_tensor,
|
none_tensor,
|
||||||
none_tensor,
|
none_tensor,
|
||||||
none_tensor,
|
none_tensor,
|
||||||
|
@ -70,6 +101,7 @@ def ext_make_q_matrix(w: dict, temp_dq, key: str = None):
|
||||||
none_tensor,
|
none_tensor,
|
||||||
none_tensor,
|
none_tensor,
|
||||||
none_tensor,
|
none_tensor,
|
||||||
|
none_tensor,
|
||||||
w["qzeros"],
|
w["qzeros"],
|
||||||
w["scales"],
|
w["scales"],
|
||||||
w["g_idx"].cpu(),
|
w["g_idx"].cpu(),
|
||||||
|
@ -84,6 +116,7 @@ def ext_make_q_matrix(w: dict, temp_dq, key: str = None):
|
||||||
none_tensor,
|
none_tensor,
|
||||||
none_tensor,
|
none_tensor,
|
||||||
none_tensor,
|
none_tensor,
|
||||||
|
none_tensor,
|
||||||
w["qzeros"],
|
w["qzeros"],
|
||||||
w["scales"],
|
w["scales"],
|
||||||
none_tensor,
|
none_tensor,
|
||||||
|
|
|
@ -18,7 +18,9 @@ WEIGHTS_CACHE_OVERRIDE = os.getenv("WEIGHTS_CACHE_OVERRIDE", None)
|
||||||
HF_HUB_OFFLINE = os.environ.get("HF_HUB_OFFLINE", "0").lower() in ["true", "1", "yes"]
|
HF_HUB_OFFLINE = os.environ.get("HF_HUB_OFFLINE", "0").lower() in ["true", "1", "yes"]
|
||||||
|
|
||||||
|
|
||||||
def _cached_weight_files(model_id: str, revision: Optional[str], extension: str) -> List[str]:
|
def _cached_weight_files(
|
||||||
|
model_id: str, revision: Optional[str], extension: str
|
||||||
|
) -> List[str]:
|
||||||
"""Guess weight files from the cached revision snapshot directory"""
|
"""Guess weight files from the cached revision snapshot directory"""
|
||||||
d = _get_cached_revision_directory(model_id, revision)
|
d = _get_cached_revision_directory(model_id, revision)
|
||||||
if not d:
|
if not d:
|
||||||
|
@ -27,7 +29,9 @@ def _cached_weight_files(model_id: str, revision: Optional[str], extension: str)
|
||||||
return filenames
|
return filenames
|
||||||
|
|
||||||
|
|
||||||
def _weight_hub_files_from_model_info(info: hf_api.ModelInfo, extension: str) -> List[str]:
|
def _weight_hub_files_from_model_info(
|
||||||
|
info: hf_api.ModelInfo, extension: str
|
||||||
|
) -> List[str]:
|
||||||
return [
|
return [
|
||||||
s.rfilename
|
s.rfilename
|
||||||
for s in info.siblings
|
for s in info.siblings
|
||||||
|
@ -44,21 +48,27 @@ def _weight_files_from_dir(d: Path, extension: str) -> List[str]:
|
||||||
# see _weight_hub_files_from_model_info, that's also what is
|
# see _weight_hub_files_from_model_info, that's also what is
|
||||||
# done there with the len(s.rfilename.split("/")) == 1 condition
|
# done there with the len(s.rfilename.split("/")) == 1 condition
|
||||||
root, _, files = next(os.walk(str(d)))
|
root, _, files = next(os.walk(str(d)))
|
||||||
filenames = [f for f in files
|
filenames = [
|
||||||
if f.endswith(extension)
|
f
|
||||||
and "arguments" not in f
|
for f in files
|
||||||
and "args" not in f
|
if f.endswith(extension)
|
||||||
and "adapter" not in f
|
and "arguments" not in f
|
||||||
and "training" not in f]
|
and "args" not in f
|
||||||
|
and "adapter" not in f
|
||||||
|
and "training" not in f
|
||||||
|
]
|
||||||
return filenames
|
return filenames
|
||||||
|
|
||||||
|
|
||||||
def _get_cached_revision_directory(model_id: str, revision: Optional[str]) -> Optional[Path]:
|
def _get_cached_revision_directory(
|
||||||
|
model_id: str, revision: Optional[str]
|
||||||
|
) -> Optional[Path]:
|
||||||
if revision is None:
|
if revision is None:
|
||||||
revision = "main"
|
revision = "main"
|
||||||
|
|
||||||
repo_cache = Path(HUGGINGFACE_HUB_CACHE) / Path(
|
repo_cache = Path(HUGGINGFACE_HUB_CACHE) / Path(
|
||||||
file_download.repo_folder_name(repo_id=model_id, repo_type="model"))
|
file_download.repo_folder_name(repo_id=model_id, repo_type="model")
|
||||||
|
)
|
||||||
|
|
||||||
if not repo_cache.is_dir():
|
if not repo_cache.is_dir():
|
||||||
# No cache for this model
|
# No cache for this model
|
||||||
|
@ -86,7 +96,7 @@ def _get_cached_revision_directory(model_id: str, revision: Optional[str]) -> Op
|
||||||
|
|
||||||
|
|
||||||
def weight_hub_files(
|
def weight_hub_files(
|
||||||
model_id: str, revision: Optional[str] = None, extension: str = ".safetensors"
|
model_id: str, revision: Optional[str] = None, extension: str = ".safetensors"
|
||||||
) -> List[str]:
|
) -> List[str]:
|
||||||
"""Get the weights filenames on the hub"""
|
"""Get the weights filenames on the hub"""
|
||||||
api = HfApi()
|
api = HfApi()
|
||||||
|
|
|
@ -19,6 +19,7 @@ from accelerate import init_empty_weights
|
||||||
|
|
||||||
from text_generation_server.utils.gptq.quant_linear import QuantLinear
|
from text_generation_server.utils.gptq.quant_linear import QuantLinear
|
||||||
from text_generation_server.utils.import_utils import IS_CUDA_SYSTEM, IS_ROCM_SYSTEM
|
from text_generation_server.utils.import_utils import IS_CUDA_SYSTEM, IS_ROCM_SYSTEM
|
||||||
|
from text_generation_server.utils.log import log_once
|
||||||
|
|
||||||
HAS_AWQ = True
|
HAS_AWQ = True
|
||||||
try:
|
try:
|
||||||
|
@ -35,10 +36,11 @@ HAS_EXLLAMA = False
|
||||||
CAN_EXLLAMA = major >= 8
|
CAN_EXLLAMA = major >= 8
|
||||||
V2 = os.getenv("EXLLAMA_VERSION", "2") == "2"
|
V2 = os.getenv("EXLLAMA_VERSION", "2") == "2"
|
||||||
if V2 and int(os.getenv("WORLD_SIZE", "1")) > 1:
|
if V2 and int(os.getenv("WORLD_SIZE", "1")) > 1:
|
||||||
logger.warning(
|
V2 = False
|
||||||
|
log_once(
|
||||||
|
logger.warning,
|
||||||
"Disabling exllama v2 and using v1 instead because there are issues when sharding"
|
"Disabling exllama v2 and using v1 instead because there are issues when sharding"
|
||||||
)
|
)
|
||||||
V2 = False
|
|
||||||
|
|
||||||
if os.getenv("DISABLE_EXLLAMA") == "True":
|
if os.getenv("DISABLE_EXLLAMA") == "True":
|
||||||
HAS_EXLLAMA = False
|
HAS_EXLLAMA = False
|
||||||
|
|
|
@ -0,0 +1,6 @@
|
||||||
|
from functools import lru_cache
|
||||||
|
|
||||||
|
|
||||||
|
@lru_cache(10)
|
||||||
|
def log_once(log, msg:str):
|
||||||
|
log(msg)
|
|
@ -6,6 +6,7 @@ import torch
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
from huggingface_hub import hf_hub_download
|
from huggingface_hub import hf_hub_download
|
||||||
import json
|
import json
|
||||||
|
from text_generation_server.utils.log import log_once
|
||||||
|
|
||||||
|
|
||||||
class Weights:
|
class Weights:
|
||||||
|
@ -161,7 +162,7 @@ class Weights:
|
||||||
else:
|
else:
|
||||||
g_idx = None
|
g_idx = None
|
||||||
|
|
||||||
bits, groupsize = self._get_gptq_params()
|
bits, groupsize, _ = self._get_gptq_params()
|
||||||
weight = (qweight, qzeros, scales, g_idx, bits, groupsize, False)
|
weight = (qweight, qzeros, scales, g_idx, bits, groupsize, False)
|
||||||
else:
|
else:
|
||||||
slice_ = self._get_slice(f"{prefix}.weight")
|
slice_ = self._get_slice(f"{prefix}.weight")
|
||||||
|
@ -211,10 +212,10 @@ class Weights:
|
||||||
else:
|
else:
|
||||||
g_idx = None
|
g_idx = None
|
||||||
|
|
||||||
bits, groupsize = self._get_gptq_params()
|
bits, groupsize, desc_act = self._get_gptq_params()
|
||||||
from text_generation_server.utils.layers import HAS_EXLLAMA
|
from text_generation_server.utils.layers import HAS_EXLLAMA
|
||||||
|
|
||||||
use_exllama = bits == 4 and HAS_EXLLAMA and quantize == "gptq"
|
use_exllama = bits == 4 and HAS_EXLLAMA and quantize == "gptq" and not desc_act
|
||||||
weight = (qweight, qzeros, scales, g_idx, bits, groupsize, use_exllama)
|
weight = (qweight, qzeros, scales, g_idx, bits, groupsize, use_exllama)
|
||||||
else:
|
else:
|
||||||
w = [self.get_sharded(f"{p}.weight", dim=0) for p in prefixes]
|
w = [self.get_sharded(f"{p}.weight", dim=0) for p in prefixes]
|
||||||
|
@ -240,11 +241,15 @@ class Weights:
|
||||||
def get_multi_weights_row(self, prefix: str, quantize: str):
|
def get_multi_weights_row(self, prefix: str, quantize: str):
|
||||||
if quantize == "gptq":
|
if quantize == "gptq":
|
||||||
use_exllama = True
|
use_exllama = True
|
||||||
bits, groupsize = self._get_gptq_params()
|
bits, groupsize, desc_act = self._get_gptq_params()
|
||||||
|
|
||||||
if bits != 4:
|
if bits != 4:
|
||||||
use_exllama = False
|
use_exllama = False
|
||||||
|
|
||||||
|
if desc_act:
|
||||||
|
log_once(logger.warning, "Disabling exllama because desc_act=True")
|
||||||
|
use_exllama = False
|
||||||
|
|
||||||
if self.process_group.size() > 1:
|
if self.process_group.size() > 1:
|
||||||
g_idx = self.get_tensor(f"{prefix}.g_idx")
|
g_idx = self.get_tensor(f"{prefix}.g_idx")
|
||||||
if g_idx is not None:
|
if g_idx is not None:
|
||||||
|
@ -274,12 +279,18 @@ class Weights:
|
||||||
if use_exllama:
|
if use_exllama:
|
||||||
if not HAS_EXLLAMA:
|
if not HAS_EXLLAMA:
|
||||||
if CAN_EXLLAMA:
|
if CAN_EXLLAMA:
|
||||||
logger.warning(
|
log_once(
|
||||||
|
logger.warning,
|
||||||
"Exllama GPTQ cuda kernels (which are faster) could have been used, but are not currently installed, try using BUILD_EXTENSIONS=True"
|
"Exllama GPTQ cuda kernels (which are faster) could have been used, but are not currently installed, try using BUILD_EXTENSIONS=True"
|
||||||
)
|
)
|
||||||
use_exllama = False
|
use_exllama = False
|
||||||
else:
|
else:
|
||||||
logger.info(f"Using exllama kernels v{HAS_EXLLAMA}")
|
log_once(
|
||||||
|
logger.info,
|
||||||
|
f"Using exllama kernels v{HAS_EXLLAMA}"
|
||||||
|
)
|
||||||
|
|
||||||
|
g_idx = self.get_sharded(f"{prefix}.g_idx", dim=0)
|
||||||
|
|
||||||
if use_exllama and groupsize != -1:
|
if use_exllama and groupsize != -1:
|
||||||
qzeros = self.get_sharded(f"{prefix}.qzeros", dim=0)
|
qzeros = self.get_sharded(f"{prefix}.qzeros", dim=0)
|
||||||
|
@ -288,14 +299,12 @@ class Weights:
|
||||||
qzeros = self.get_tensor(f"{prefix}.qzeros")
|
qzeros = self.get_tensor(f"{prefix}.qzeros")
|
||||||
scales = self.get_tensor(f"{prefix}.scales")
|
scales = self.get_tensor(f"{prefix}.scales")
|
||||||
|
|
||||||
g_idx = self.get_sharded(f"{prefix}.g_idx", dim=0)
|
|
||||||
|
|
||||||
if use_exllama:
|
if use_exllama:
|
||||||
g_idx = g_idx - g_idx[0]
|
g_idx = g_idx - g_idx[0]
|
||||||
|
|
||||||
weight = (qweight, qzeros, scales, g_idx, bits, groupsize, use_exllama)
|
weight = (qweight, qzeros, scales, g_idx, bits, groupsize, use_exllama)
|
||||||
elif quantize == "awq":
|
elif quantize == "awq":
|
||||||
bits, groupsize = self._get_gptq_params()
|
bits, groupsize, _ = self._get_gptq_params()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
qweight = self.get_sharded(f"{prefix}.qweight", dim=0)
|
qweight = self.get_sharded(f"{prefix}.qweight", dim=0)
|
||||||
|
@ -314,18 +323,20 @@ class Weights:
|
||||||
weight = self.get_sharded(f"{prefix}.weight", dim=1)
|
weight = self.get_sharded(f"{prefix}.weight", dim=1)
|
||||||
return weight
|
return weight
|
||||||
|
|
||||||
def _get_gptq_params(self) -> Tuple[int, int]:
|
def _get_gptq_params(self) -> Tuple[int, int, int]:
|
||||||
try:
|
try:
|
||||||
bits = self.get_tensor("gptq_bits").item()
|
bits = self.get_tensor("gptq_bits").item()
|
||||||
groupsize = self.get_tensor("gptq_groupsize").item()
|
groupsize = self.get_tensor("gptq_groupsize").item()
|
||||||
|
desc_act = False
|
||||||
except (SafetensorError, RuntimeError) as e:
|
except (SafetensorError, RuntimeError) as e:
|
||||||
try:
|
try:
|
||||||
bits = self.gptq_bits
|
bits = self.gptq_bits
|
||||||
groupsize = self.gptq_groupsize
|
groupsize = self.gptq_groupsize
|
||||||
|
desc_act = getattr(self, "gptq_desc_act", False)
|
||||||
except Exception:
|
except Exception:
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
return bits, groupsize
|
return bits, groupsize, desc_act
|
||||||
|
|
||||||
def _set_gptq_params(self, model_id, revision):
|
def _set_gptq_params(self, model_id, revision):
|
||||||
filename = "config.json"
|
filename = "config.json"
|
||||||
|
@ -340,6 +351,7 @@ class Weights:
|
||||||
data = json.load(f)
|
data = json.load(f)
|
||||||
self.gptq_bits = data["quantization_config"]["bits"]
|
self.gptq_bits = data["quantization_config"]["bits"]
|
||||||
self.gptq_groupsize = data["quantization_config"]["group_size"]
|
self.gptq_groupsize = data["quantization_config"]["group_size"]
|
||||||
|
self.gptq_desc_act = data["quantization_config"]["desc_act"]
|
||||||
except Exception:
|
except Exception:
|
||||||
filename = "quantize_config.json"
|
filename = "quantize_config.json"
|
||||||
try:
|
try:
|
||||||
|
@ -353,6 +365,7 @@ class Weights:
|
||||||
data = json.load(f)
|
data = json.load(f)
|
||||||
self.gptq_bits = data["bits"]
|
self.gptq_bits = data["bits"]
|
||||||
self.gptq_groupsize = data["group_size"]
|
self.gptq_groupsize = data["group_size"]
|
||||||
|
self.gptq_desc_act = data["desc_act"]
|
||||||
except Exception:
|
except Exception:
|
||||||
filename = "quant_config.json"
|
filename = "quant_config.json"
|
||||||
try:
|
try:
|
||||||
|
@ -366,5 +379,6 @@ class Weights:
|
||||||
data = json.load(f)
|
data = json.load(f)
|
||||||
self.gptq_bits = data["w_bit"]
|
self.gptq_bits = data["w_bit"]
|
||||||
self.gptq_groupsize = data["q_group_size"]
|
self.gptq_groupsize = data["q_group_size"]
|
||||||
|
self.gptq_desc_act = data["desc_act"]
|
||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
|
|
Loading…
Reference in New Issue