Add support for Marlin 2:4 sparsity (#2102)
This change adds support for 2:4 sparsity when using Marlin quantization. The 2:4 kernel is used when: * The quantizer is `marlin`; * the quantizer checkpoint format is `marlin_24`. Fixes #2098.
This commit is contained in:
parent
14980df2df
commit
f1f98e369f
|
@ -19,6 +19,23 @@ def gptq_marlin_gemm(
|
|||
"""
|
||||
...
|
||||
|
||||
def gptq_marlin_24_gemm(
|
||||
a: torch.Tensor,
|
||||
b_q_weight: torch.Tensor,
|
||||
b_meta: torch.Tensor,
|
||||
b_scales: torch.Tensor,
|
||||
workspace: torch.Tensor,
|
||||
num_bits: int,
|
||||
size_m: int,
|
||||
size_n: int,
|
||||
size_k: int,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Matrix multiplication using Marlin kernels. This is an extension of
|
||||
`marlin_gemm` that supports 2:4 sparsity.
|
||||
"""
|
||||
...
|
||||
|
||||
def gptq_marlin_repack(
|
||||
b_q_weight: torch.Tensor,
|
||||
perm: torch.Tensor,
|
||||
|
|
|
@ -5,6 +5,7 @@
|
|||
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
||||
m.def("gptq_marlin_gemm", &gptq_marlin_gemm,
|
||||
"Marlin gemm with GPTQ compatibility");
|
||||
m.def("gptq_marlin_24_gemm", &gptq_marlin_24_gemm, "Marlin sparse 2:4 gemm");
|
||||
m.def("gptq_marlin_repack", &gptq_marlin_repack,
|
||||
"Repack GPTQ parameters for Marlin");
|
||||
m.def("marlin_gemm", &marlin_gemm, "Marlin gemm");
|
||||
|
|
|
@ -12,6 +12,13 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor &a, torch::Tensor &b_q_weight,
|
|||
int64_t num_bits, int64_t size_m, int64_t size_n,
|
||||
int64_t size_k, bool is_k_full);
|
||||
|
||||
torch::Tensor gptq_marlin_24_gemm(torch::Tensor &a, torch::Tensor &b_q_weight,
|
||||
torch::Tensor &b_meta,
|
||||
torch::Tensor &b_scales,
|
||||
torch::Tensor &workspace, int64_t num_bits,
|
||||
int64_t size_m, int64_t size_n,
|
||||
int64_t size_k);
|
||||
|
||||
torch::Tensor gptq_marlin_repack(torch::Tensor &b_q_weight, torch::Tensor &perm,
|
||||
int64_t size_k, int64_t size_n,
|
||||
int64_t num_bits);
|
||||
|
|
|
@ -0,0 +1,51 @@
|
|||
/*
|
||||
* Copyright (C) 2024 Roberto Lopez Castro (roberto.lopez.castro@udc.es). All
|
||||
* Rights Reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
namespace marlin_24 {
|
||||
|
||||
constexpr int ceildiv(int a, int b) { return (a + b - 1) / b; }
|
||||
|
||||
// Instances of `Vec` are used to organize groups of >>registers<<, as needed
|
||||
// for instance as inputs to tensor core operations. Consequently, all
|
||||
// corresponding index accesses must be compile-time constants, which is why we
|
||||
// extensively use `#pragma unroll` throughout the kernel code to guarantee
|
||||
// this.
|
||||
template <typename T, int n>
|
||||
struct Vec {
|
||||
T elems[n];
|
||||
__device__ T& operator[](int i) { return elems[i]; }
|
||||
};
|
||||
|
||||
template <int M_, int N_, int K_>
|
||||
struct ShapeBase {
|
||||
static constexpr int M = M_, N = N_, K = K_;
|
||||
};
|
||||
|
||||
using I4 = Vec<int, 4>;
|
||||
|
||||
// Matrix fragments for tensor core instructions; their precise layout is
|
||||
// documented here:
|
||||
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-mma-m16n8k16-with-floating-point-type
|
||||
using FragA = Vec<half2, 4>;
|
||||
using FragB = Vec<half2, 2>;
|
||||
using FragM = Vec<uint, 1>;
|
||||
using FragC = Vec<float, 4>;
|
||||
using FragS = Vec<half2, 1>; // quantization scales
|
||||
|
||||
} // namespace marlin_24
|
|
@ -0,0 +1,136 @@
|
|||
/*
|
||||
* Copyright (C) 2024 Roberto Lopez Castro (roberto.lopez.castro@udc.es). All
|
||||
* Rights Reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
#include "base.h"
|
||||
|
||||
namespace marlin_24 {
|
||||
// Predicated asynchronous global->shared copy; used for inputs A where we apply
|
||||
// predication to handle batchsizes that are not multiples of 16.
|
||||
__device__ inline void cp_async4_pred_zfill(void* smem_ptr,
|
||||
const void* glob_ptr,
|
||||
bool pred = true,
|
||||
const bool zfill = false) {
|
||||
const int BYTES = 16;
|
||||
int src_in_bytes = (zfill ? 0 : BYTES);
|
||||
uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
|
||||
asm volatile(
|
||||
"{\n"
|
||||
" .reg .pred p;\n"
|
||||
" setp.ne.b32 p, %0, 0;\n"
|
||||
" @p cp.async.cg.shared.global [%1], [%2], %3;\n"
|
||||
"}\n" ::"r"((int)pred),
|
||||
"r"(smem), "l"(glob_ptr), "n"(BYTES), "r"(src_in_bytes));
|
||||
}
|
||||
|
||||
__device__ inline void cp_async4_pred(void* smem_ptr, const void* glob_ptr,
|
||||
bool pred = true) {
|
||||
const int BYTES = 16;
|
||||
uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
|
||||
asm volatile(
|
||||
"{\n"
|
||||
" .reg .pred p;\n"
|
||||
" setp.ne.b32 p, %0, 0;\n"
|
||||
" @p cp.async.cg.shared.global [%1], [%2], %3;\n"
|
||||
"}\n" ::"r"((int)pred),
|
||||
"r"(smem), "l"(glob_ptr), "n"(BYTES));
|
||||
}
|
||||
|
||||
// Asynchronous global->shared copy
|
||||
__device__ inline void cp_async4(void* smem_ptr, const void* glob_ptr) {
|
||||
const int BYTES = 16;
|
||||
uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
|
||||
asm volatile(
|
||||
"{\n"
|
||||
" cp.async.cg.shared.global [%0], [%1], %2;\n"
|
||||
"}\n" ::"r"(smem),
|
||||
"l"(glob_ptr), "n"(BYTES));
|
||||
}
|
||||
|
||||
// Async copy fence.
|
||||
__device__ inline void cp_async_fence() {
|
||||
asm volatile("cp.async.commit_group;\n" ::);
|
||||
}
|
||||
|
||||
// Wait until at most `n` async copy stages are still pending.
|
||||
template <int n>
|
||||
__device__ inline void cp_async_wait() {
|
||||
asm volatile("cp.async.wait_group %0;\n" ::"n"(n));
|
||||
}
|
||||
|
||||
// Instruction for loading a full 16x16 matrix fragment of operand A from shared
|
||||
// memory, directly in tensor core layout.
|
||||
__device__ inline void ldsm4(FragA& frag_a, const void* smem_ptr) {
|
||||
uint32_t* a = reinterpret_cast<uint32_t*>(&frag_a);
|
||||
uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
|
||||
asm volatile("ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0,%1,%2,%3}, [%4];\n"
|
||||
: "=r"(a[0]), "=r"(a[1]), "=r"(a[2]), "=r"(a[3])
|
||||
: "r"(smem));
|
||||
}
|
||||
|
||||
__device__ inline void ldsm4_m(FragM& frag_m, const void* smem_ptr) {
|
||||
uint32_t* a = reinterpret_cast<uint32_t*>(&frag_m);
|
||||
uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
|
||||
asm volatile("ldmatrix.sync.aligned.m8n8.x2.shared.b16 {%0,%1}, [%2];\n"
|
||||
: "=r"(a[0]), "=r"(a[1])
|
||||
: "r"(smem));
|
||||
}
|
||||
|
||||
// Instruction for loading a full 16x16 matrix fragment of operand A from shared
|
||||
// memory, directly in tensor core layout.
|
||||
__device__ inline void ldsm4_t(FragA& frag_a, const void* smem_ptr) {
|
||||
uint32_t* a = reinterpret_cast<uint32_t*>(&frag_a);
|
||||
uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
|
||||
asm volatile(
|
||||
"ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16 {%0,%1,%2,%3}, [%4];\n"
|
||||
: "=r"(a[0]), "=r"(a[1]), "=r"(a[2]), "=r"(a[3])
|
||||
: "r"(smem));
|
||||
}
|
||||
|
||||
// Wait until barrier reaches `count`, then lock for current threadblock.
|
||||
__device__ inline void barrier_acquire(int* lock, int count) {
|
||||
if (threadIdx.x == 0) {
|
||||
int state = -1;
|
||||
do
|
||||
// Guarantee that subsequent writes by this threadblock will be visible
|
||||
// globally.
|
||||
asm volatile("ld.global.acquire.gpu.b32 %0, [%1];\n"
|
||||
: "=r"(state)
|
||||
: "l"(lock));
|
||||
while (state != count);
|
||||
}
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
// Release barrier and increment visitation count.
|
||||
__device__ inline void barrier_release(int* lock, bool reset = false) {
|
||||
__syncthreads();
|
||||
if (threadIdx.x == 0) {
|
||||
if (reset) {
|
||||
lock[0] = 0;
|
||||
return;
|
||||
}
|
||||
int val = 1;
|
||||
// Make sure that all writes since acquiring this barrier are visible
|
||||
// globally, while releasing the barrier.
|
||||
asm volatile("fence.acq_rel.gpu;\n");
|
||||
asm volatile("red.relaxed.gpu.global.add.s32 [%0], %1;\n"
|
||||
:
|
||||
: "l"(lock), "r"(val));
|
||||
}
|
||||
}
|
||||
} // namespace marlin_24
|
|
@ -0,0 +1,191 @@
|
|||
/*
|
||||
* Copyright (C) 2024 Roberto Lopez Castro (roberto.lopez.castro@udc.es). All
|
||||
* Rights Reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
#include "base.h"
|
||||
#include <cudaTypedefs.h>
|
||||
|
||||
namespace marlin_24 {
|
||||
|
||||
// On CUDA earlier than 12.5, the ordered_metadata version of this instruction
|
||||
// is not supported. On later versions of CUDA the version without ordered
|
||||
// metadata results in the following warning:
|
||||
// | Advisory: Modifier ‘.sp::ordered_metadata’ should be used on instruction
|
||||
// | ‘mma’ instead of modifier ‘.sp’ as it is expected to have substantially
|
||||
// | reduced performance on some future architectures
|
||||
#if defined CUDA_VERSION && CUDA_VERSION >= 12050
|
||||
#define MMA_SP_INST \
|
||||
"mma.sp::ordered_metadata.sync.aligned.m16n8k32.row.col.f32.f16.f16.f32 "
|
||||
#else
|
||||
#define MMA_SP_INST "mma.sp.sync.aligned.m16n8k32.row.col.f32.f16.f16.f32 "
|
||||
#endif
|
||||
|
||||
// m16n8k32 sparse tensor core mma instruction with fp16 inputs and fp32
|
||||
// output/accumulation.
|
||||
__device__ inline void mma_sp(const FragB& a_frag0, const FragB& a_frag1,
|
||||
const FragA& frag_b, FragC& frag_c, FragM& frag_m,
|
||||
const int psel) {
|
||||
const uint32_t* a0 = reinterpret_cast<const uint32_t*>(&a_frag0);
|
||||
const uint32_t* a1 = reinterpret_cast<const uint32_t*>(&a_frag1);
|
||||
const uint32_t* b = reinterpret_cast<const uint32_t*>(&frag_b);
|
||||
const uint32_t* e = reinterpret_cast<const uint32_t*>(&frag_m);
|
||||
|
||||
float* c = reinterpret_cast<float*>(&frag_c);
|
||||
if (psel == 0) {
|
||||
asm volatile(MMA_SP_INST
|
||||
"{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9, %10,%11}, "
|
||||
"{%12,%13,%14,%15}, %16, 0x0;\n"
|
||||
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
|
||||
: "r"(a0[0]), "r"(a1[0]), "r"(a0[1]), "r"(a1[1]), "r"(b[0]),
|
||||
"r"(b[2]), "r"(b[4]), "r"(b[6]), "f"(c[0]), "f"(c[1]),
|
||||
"f"(c[2]), "f"(c[3]), "r"(e[0]));
|
||||
asm volatile(MMA_SP_INST
|
||||
"{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9, %10,%11}, "
|
||||
"{%12,%13,%14,%15}, %16, 0x0;\n"
|
||||
: "=f"(c[4]), "=f"(c[5]), "=f"(c[6]), "=f"(c[7])
|
||||
: "r"(a0[0]), "r"(a1[0]), "r"(a0[1]), "r"(a1[1]), "r"(b[1]),
|
||||
"r"(b[3]), "r"(b[5]), "r"(b[7]), "f"(c[4]), "f"(c[5]),
|
||||
"f"(c[6]), "f"(c[7]), "r"(e[0]));
|
||||
} else {
|
||||
asm volatile(MMA_SP_INST
|
||||
"{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9, %10,%11}, "
|
||||
"{%12,%13,%14,%15}, %16, 0x1;\n"
|
||||
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
|
||||
: "r"(a0[0]), "r"(a1[0]), "r"(a0[1]), "r"(a1[1]), "r"(b[0]),
|
||||
"r"(b[2]), "r"(b[4]), "r"(b[6]), "f"(c[0]), "f"(c[1]),
|
||||
"f"(c[2]), "f"(c[3]), "r"(e[0]));
|
||||
asm volatile(MMA_SP_INST
|
||||
"{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9, %10,%11}, "
|
||||
"{%12,%13,%14,%15}, %16, 0x1;\n"
|
||||
: "=f"(c[4]), "=f"(c[5]), "=f"(c[6]), "=f"(c[7])
|
||||
: "r"(a0[0]), "r"(a1[0]), "r"(a0[1]), "r"(a1[1]), "r"(b[1]),
|
||||
"r"(b[3]), "r"(b[5]), "r"(b[7]), "f"(c[4]), "f"(c[5]),
|
||||
"f"(c[6]), "f"(c[7]), "r"(e[0]));
|
||||
}
|
||||
}
|
||||
|
||||
// Lookup-table based 3-input logical operation; explicitly used for
|
||||
// dequantization as the compiler does not seem to automatically recognize it in
|
||||
// all cases.
|
||||
template <int lut>
|
||||
__device__ inline int lop3(int a, int b, int c) {
|
||||
int res;
|
||||
asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
|
||||
: "=r"(res)
|
||||
: "r"(a), "r"(b), "r"(c), "n"(lut));
|
||||
return res;
|
||||
}
|
||||
|
||||
__device__ __forceinline__ uint2 to_half4(float c0, float c1, float c2,
|
||||
float c3) {
|
||||
uint2 r;
|
||||
asm("{\n\t"
|
||||
".reg .f16 a, b, c, d; \n\t"
|
||||
"cvt.rn.f16.f32 a, %2; \n\t"
|
||||
"cvt.rn.f16.f32 b, %3; \n\t"
|
||||
"cvt.rn.f16.f32 c, %4; \n\t"
|
||||
"cvt.rn.f16.f32 d, %5; \n\t"
|
||||
"mov.b32 %0, {a, b}; \n\t"
|
||||
"mov.b32 %1, {c, d}; \n\t"
|
||||
"}"
|
||||
: "=r"(r.x), "=r"(r.y)
|
||||
: "f"(c0), "f"(c1), "f"(c2), "f"(c3));
|
||||
return r;
|
||||
}
|
||||
|
||||
// Constructs destination register by taking bytes from 2 sources (based on
|
||||
// mask)
|
||||
template <int start_byte, int mask>
|
||||
__device__ inline uint32_t prmt(uint32_t a) {
|
||||
uint32_t res;
|
||||
asm volatile("prmt.b32 %0, %1, %2, %3;\n"
|
||||
: "=r"(res)
|
||||
: "r"(a), "n"(start_byte), "n"(mask));
|
||||
return res;
|
||||
}
|
||||
|
||||
// Efficiently dequantize an int32 value into a full B-fragment of 4 fp16
|
||||
// values. We mostly follow the strategy in the link below, with some small
|
||||
// changes:
|
||||
// https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h
|
||||
__device__ inline FragB dequant_4bit(int q) {
|
||||
const int LO = 0x000f000f;
|
||||
const int HI = 0x00f000f0;
|
||||
const int EX = 0x64006400;
|
||||
// Guarantee that the `(a & b) | c` operations are LOP3s.
|
||||
int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX);
|
||||
int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX);
|
||||
// We want signed int4 outputs, hence we fuse the `-8` symmetric zero point
|
||||
// directly into `SUB` and `ADD`.
|
||||
const int SUB = 0x64086408;
|
||||
const int MUL = 0x2c002c00;
|
||||
const int ADD = 0xd480d480;
|
||||
|
||||
FragB frag_b;
|
||||
frag_b[0] = __hsub2(*reinterpret_cast<half2*>(&lo),
|
||||
*reinterpret_cast<const half2*>(&SUB));
|
||||
frag_b[1] = __hfma2(*reinterpret_cast<half2*>(&hi),
|
||||
*reinterpret_cast<const half2*>(&MUL),
|
||||
*reinterpret_cast<const half2*>(&ADD));
|
||||
return frag_b;
|
||||
}
|
||||
|
||||
// Efficiently dequantize an int32 value into a full B-fragment of 4 fp16
|
||||
// values. We mostly follow the strategy in the link below, with some small
|
||||
// changes:
|
||||
// https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h
|
||||
__device__ inline FragB dequant_8bit(int q) {
|
||||
static constexpr uint32_t mask_for_elt_01 = 0x5250;
|
||||
static constexpr uint32_t mask_for_elt_23 = 0x5351;
|
||||
static constexpr uint32_t start_byte_for_fp16 = 0x64646464;
|
||||
|
||||
uint32_t lo = prmt<start_byte_for_fp16, mask_for_elt_01>(q);
|
||||
uint32_t hi = prmt<start_byte_for_fp16, mask_for_elt_23>(q);
|
||||
|
||||
static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64806480;
|
||||
|
||||
FragB frag_b;
|
||||
frag_b[0] = __hsub2(*reinterpret_cast<half2*>(&lo),
|
||||
*reinterpret_cast<const half2*>(&I8s_TO_F16s_MAGIC_NUM));
|
||||
frag_b[1] = __hsub2(*reinterpret_cast<half2*>(&hi),
|
||||
*reinterpret_cast<const half2*>(&I8s_TO_F16s_MAGIC_NUM));
|
||||
return frag_b;
|
||||
}
|
||||
|
||||
// Multiply dequantized values by the corresponding quantization scale; used
|
||||
// only for grouped quantization.
|
||||
__device__ inline void scale(FragB& frag_b, FragS& frag_s, int i) {
|
||||
half2 s = __half2half2(reinterpret_cast<__half*>(&frag_s)[i]);
|
||||
frag_b[0] = __hmul2(frag_b[0], s);
|
||||
frag_b[1] = __hmul2(frag_b[1], s);
|
||||
}
|
||||
|
||||
__device__ inline void scale_floats(float* c0, float* c1, float* c2, float* c3,
|
||||
FragS& s0, float* c4, float* c5, float* c6,
|
||||
float* c7, FragS& s1) {
|
||||
*c0 = __fmul_rn(*c0, __half2float(s0[0].x));
|
||||
*c1 = __fmul_rn(*c1, __half2float(s0[0].y));
|
||||
*c2 = __fmul_rn(*c2, __half2float(s0[1].x));
|
||||
*c3 = __fmul_rn(*c3, __half2float(s0[1].y));
|
||||
|
||||
*c4 = __fmul_rn(*c4, __half2float(s1[0].x));
|
||||
*c5 = __fmul_rn(*c5, __half2float(s1[0].y));
|
||||
*c6 = __fmul_rn(*c6, __half2float(s1[1].x));
|
||||
*c7 = __fmul_rn(*c7, __half2float(s1[1].y));
|
||||
}
|
||||
|
||||
} // namespace marlin_24
|
File diff suppressed because it is too large
Load Diff
|
@ -12,6 +12,7 @@ setup(
|
|||
"marlin_kernels/gptq_marlin.cu",
|
||||
"marlin_kernels/gptq_marlin_repack.cu",
|
||||
"marlin_kernels/marlin_cuda_kernel.cu",
|
||||
"marlin_kernels/sparse/marlin_24_cuda_kernel.cu",
|
||||
"marlin_kernels/ext.cpp",
|
||||
],
|
||||
extra_compile_args=extra_compile_args,
|
||||
|
|
|
@ -1,7 +1,6 @@
|
|||
from typing import Optional
|
||||
import torch
|
||||
from torch.nn import functional as F
|
||||
from text_generation_server.layers.marlin import GPTQMarlinLinear
|
||||
from text_generation_server.utils.import_utils import SYSTEM
|
||||
|
||||
if SYSTEM == "rocm":
|
||||
|
@ -225,6 +224,9 @@ def get_linear(weight, bias, quantize):
|
|||
)
|
||||
elif quantize == "marlin":
|
||||
from text_generation_server.layers.marlin import (
|
||||
GPTQMarlin24Linear,
|
||||
GPTQMarlin24Weight,
|
||||
GPTQMarlinLinear,
|
||||
GPTQMarlinWeight,
|
||||
MarlinLinear,
|
||||
MarlinWeight,
|
||||
|
@ -235,6 +237,11 @@ def get_linear(weight, bias, quantize):
|
|||
weight=weight,
|
||||
bias=bias,
|
||||
)
|
||||
elif isinstance(weight, GPTQMarlin24Weight):
|
||||
linear = GPTQMarlin24Linear(
|
||||
weight=weight,
|
||||
bias=bias,
|
||||
)
|
||||
elif isinstance(weight, MarlinWeight):
|
||||
linear = MarlinLinear(weight=weight, bias=bias)
|
||||
else:
|
||||
|
|
|
@ -1,9 +1,8 @@
|
|||
from dataclasses import dataclass
|
||||
from typing import Optional, Tuple, List
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from text_generation_server.utils.import_utils import SYSTEM
|
||||
|
||||
try:
|
||||
|
@ -177,12 +176,12 @@ class GPTQMarlinLinear(nn.Module):
|
|||
self.bits = weight.bits
|
||||
self.is_full_k = weight.is_full_k
|
||||
|
||||
self.register_buffer("qweight", weight.qweight)
|
||||
self.register_buffer("scales", weight.scales)
|
||||
self.register_buffer("g_idx", weight.g_idx)
|
||||
self.register_buffer("perm", weight.perm)
|
||||
self.qweight = weight.qweight
|
||||
self.scales = weight.scales
|
||||
self.g_idx = weight.g_idx
|
||||
self.perm = weight.perm
|
||||
if bias is not None:
|
||||
self.register_buffer("bias", bias)
|
||||
self.bias = bias
|
||||
else:
|
||||
self.bias = None
|
||||
|
||||
|
@ -215,6 +214,116 @@ class GPTQMarlinLinear(nn.Module):
|
|||
return C
|
||||
|
||||
|
||||
GPTQ_MARLIN_24_MIN_THREAD_N = 128
|
||||
GPTQ_MARLIN_24_MIN_THREAD_K = 128
|
||||
GPTQ_MARLIN_24_MAX_PARALLEL = 64
|
||||
GPTQ_MARLIN_24_SUPPORTED_NUM_BITS = [4, 8]
|
||||
GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES = [-1, 128]
|
||||
|
||||
|
||||
@dataclass
|
||||
class GPTQMarlin24Weight:
|
||||
"""
|
||||
GPTQ-Marlin 2:4 weights.
|
||||
|
||||
Attributes:
|
||||
B (torch.Tensor): int4-quantized weights packed into int32.
|
||||
B_meta (torch.Tensor): metadata for 2:4 sparsity.
|
||||
s (torch.Tensor): float16 scales.
|
||||
bits: quantized weight size.
|
||||
"""
|
||||
|
||||
B: torch.Tensor
|
||||
B_meta: torch.Tensor
|
||||
s: torch.Tensor
|
||||
bits: int
|
||||
|
||||
def __post_init__(self):
|
||||
assert self.B.dtype == torch.int32
|
||||
assert self.B_meta.dtype == torch.int16
|
||||
assert self.s.dtype == torch.float16
|
||||
|
||||
|
||||
class GPTQMarlin24Linear(nn.Module):
|
||||
def __init__(self, *, weight: GPTQMarlin24Weight, bias: Optional[torch.Tensor]):
|
||||
super().__init__()
|
||||
|
||||
_check_marlin_kernels()
|
||||
assert marlin_kernels is not None
|
||||
|
||||
if weight.bits not in GPTQ_MARLIN_BITS:
|
||||
supported_bits = ", ".join(str(b) for b in GPTQ_MARLIN_BITS)
|
||||
raise RuntimeError(
|
||||
f"{weight.bits}-bit GPTQ Sparse 2:4 Marlin is not supported, must be one of: {supported_bits}"
|
||||
)
|
||||
|
||||
in_features = weight.B.shape[0] * MARLIN_TILE_SIZE * 2
|
||||
out_features = weight.s.shape[1]
|
||||
groupsize = -1 if weight.s.shape[0] == 1 else in_features // weight.s.shape[0]
|
||||
|
||||
if groupsize not in GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES:
|
||||
supported_sizes = ", ".join(
|
||||
str(b) for b in GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES
|
||||
)
|
||||
raise RuntimeError(
|
||||
f"Group size {groupsize} is not supported, must be one of: {supported_sizes}"
|
||||
)
|
||||
|
||||
self.bits = weight.bits
|
||||
weights_per_int32 = 32 // self.bits
|
||||
|
||||
assert (
|
||||
out_features % GPTQ_MARLIN_24_MIN_THREAD_N == 0
|
||||
), f"Number of output features ({out_features}) not divisable by {GPTQ_MARLIN_24_MIN_THREAD_N} threads"
|
||||
assert (
|
||||
out_features % weights_per_int32 == 0
|
||||
), f"Number of output features ({out_features}) not divisable by weights per int32 ({weights_per_int32})"
|
||||
|
||||
assert (
|
||||
in_features % GPTQ_MARLIN_24_MIN_THREAD_K == 0
|
||||
), f"Number of output features ({out_features}) not divisable by {GPTQ_MARLIN_24_MIN_THREAD_K} threads"
|
||||
if groupsize != -1 and in_features % groupsize != 0:
|
||||
raise ValueError(
|
||||
f"Number of input features ({in_features}) not divisable by group size ({groupsize})"
|
||||
)
|
||||
|
||||
self.B = weight.B
|
||||
self.B_meta = weight.B_meta
|
||||
self.s = weight.s
|
||||
if bias is not None:
|
||||
self.bias = bias
|
||||
else:
|
||||
self.bias = None
|
||||
|
||||
self.workspace = torch.zeros(
|
||||
(out_features // GPTQ_MARLIN_24_MIN_THREAD_N) * GPTQ_MARLIN_24_MAX_PARALLEL,
|
||||
dtype=torch.int,
|
||||
device=weight.B.device,
|
||||
)
|
||||
|
||||
def forward(self, A: torch.Tensor) -> torch.Tensor:
|
||||
assert marlin_kernels is not None
|
||||
|
||||
C = marlin_kernels.gptq_marlin_24_gemm(
|
||||
A.view(-1, A.shape[-1]),
|
||||
self.B,
|
||||
self.B_meta,
|
||||
self.s,
|
||||
self.workspace,
|
||||
self.bits,
|
||||
A.shape[0],
|
||||
self.s.shape[1],
|
||||
A.shape[1],
|
||||
)
|
||||
|
||||
C = C.reshape(A.shape[:-1] + (self.s.shape[1],))
|
||||
|
||||
if self.bias is not None:
|
||||
C += self.bias
|
||||
|
||||
return C
|
||||
|
||||
|
||||
@dataclass
|
||||
class MarlinWeight:
|
||||
"""
|
||||
|
@ -255,10 +364,10 @@ class MarlinLinear(nn.Module):
|
|||
128,
|
||||
}, f"Group size must be -1 or 128, was {groupsize}"
|
||||
|
||||
self.register_buffer("B", weight.B)
|
||||
self.register_buffer("s", weight.s)
|
||||
self.B = weight.B
|
||||
self.s = weight.s
|
||||
if bias is not None:
|
||||
self.register_buffer("bias", bias)
|
||||
self.bias = bias
|
||||
else:
|
||||
self.bias = None
|
||||
|
||||
|
|
|
@ -13,6 +13,7 @@ from text_generation_server.utils.log import log_once
|
|||
@dataclass
|
||||
class _GPTQParams:
|
||||
bits: int
|
||||
checkpoint_format: Optional[str]
|
||||
groupsize: int
|
||||
desc_act: bool
|
||||
quant_method: str
|
||||
|
@ -263,12 +264,29 @@ class Weights:
|
|||
)
|
||||
elif quantize == "marlin":
|
||||
from text_generation_server.layers.marlin import (
|
||||
GPTQMarlin24Weight,
|
||||
MarlinWeight,
|
||||
repack_gptq_for_marlin,
|
||||
)
|
||||
|
||||
quant_method = getattr(self, "quant_method", "marlin")
|
||||
if quant_method == "gptq":
|
||||
is_marlin_24 = getattr(self, "gptq_checkpoint_format", None) == "marlin_24"
|
||||
if is_marlin_24:
|
||||
B = self.get_packed_sharded(
|
||||
f"{prefix}.B_24", dim=1, block_sizes=block_sizes
|
||||
)
|
||||
B_meta = self.get_packed_sharded(
|
||||
f"{prefix}.B_meta", dim=1, block_sizes=block_sizes
|
||||
)
|
||||
s = self.get_packed_sharded(
|
||||
f"{prefix}.s", dim=1, block_sizes=block_sizes
|
||||
)
|
||||
|
||||
gptq_params = self._get_gptq_params()
|
||||
weight = GPTQMarlin24Weight(
|
||||
B=B, B_meta=B_meta, s=s, bits=gptq_params.bits
|
||||
)
|
||||
elif quant_method == "gptq":
|
||||
gptq_params = self._get_gptq_params()
|
||||
try:
|
||||
qweight = self.get_packed_sharded(
|
||||
|
@ -293,7 +311,6 @@ class Weights:
|
|||
sym=gptq_params.sym,
|
||||
sharded_infeatures=False,
|
||||
)
|
||||
|
||||
else:
|
||||
B = self.get_packed_sharded(
|
||||
f"{prefix}.B", dim=1, block_sizes=block_sizes
|
||||
|
@ -406,12 +423,36 @@ class Weights:
|
|||
elif quantize == "marlin":
|
||||
from text_generation_server.layers.gptq import GPTQWeight
|
||||
from text_generation_server.layers.marlin import (
|
||||
GPTQMarlin24Weight,
|
||||
MarlinWeight,
|
||||
repack_gptq_for_marlin,
|
||||
)
|
||||
|
||||
quant_method = getattr(self, "quant_method", "marlin")
|
||||
if quant_method == "gptq":
|
||||
is_marlin_24 = getattr(self, "gptq_checkpoint_format", None) == "marlin_24"
|
||||
if is_marlin_24:
|
||||
try:
|
||||
B = torch.cat(
|
||||
[self.get_sharded(f"{p}.B_24", dim=1) for p in prefixes], dim=1
|
||||
)
|
||||
except RuntimeError:
|
||||
raise RuntimeError(
|
||||
f"Cannot load `{quantize}` weight, make sure the model is already quantized"
|
||||
)
|
||||
|
||||
B_meta = torch.cat(
|
||||
[self.get_sharded(f"{p}.B_meta", dim=1) for p in prefixes], dim=1
|
||||
)
|
||||
|
||||
s = torch.cat(
|
||||
[self.get_sharded(f"{p}.s", dim=1) for p in prefixes], dim=1
|
||||
)
|
||||
|
||||
gptq_params = self._get_gptq_params()
|
||||
weight = GPTQMarlin24Weight(
|
||||
B=B, B_meta=B_meta, s=s, bits=gptq_params.bits
|
||||
)
|
||||
elif quant_method == "gptq":
|
||||
gptq_params = self._get_gptq_params()
|
||||
try:
|
||||
qweight = torch.cat(
|
||||
|
@ -629,12 +670,35 @@ class Weights:
|
|||
elif quantize == "marlin":
|
||||
from text_generation_server.layers.gptq import GPTQWeight
|
||||
from text_generation_server.layers.marlin import (
|
||||
GPTQMarlin24Weight,
|
||||
MarlinWeight,
|
||||
repack_gptq_for_marlin,
|
||||
)
|
||||
|
||||
quant_method = getattr(self, "quant_method", "marlin")
|
||||
if quant_method == "gptq":
|
||||
is_marlin_24 = getattr(self, "gptq_checkpoint_format", None) == "marlin_24"
|
||||
if is_marlin_24:
|
||||
try:
|
||||
B = self.get_sharded(f"{prefix}.B_24", dim=0)
|
||||
except RuntimeError:
|
||||
raise RuntimeError(
|
||||
"Cannot load `marlin` 2:4 sparsity weight, make sure the model is already quantized."
|
||||
)
|
||||
|
||||
B_meta = self.get_sharded(f"{prefix}.B_meta", dim=0)
|
||||
num_groups = self._get_slice(f"{prefix}.s").get_shape()[0]
|
||||
if num_groups == 1:
|
||||
# The number of groups is 1 when groupsize == -1. share
|
||||
# scales between all shards in this case.
|
||||
s = self.get_tensor(f"{prefix}.s")
|
||||
else:
|
||||
s = self.get_sharded(f"{prefix}.s", dim=0)
|
||||
|
||||
gptq_params = self._get_gptq_params()
|
||||
weight = GPTQMarlin24Weight(
|
||||
B=B, B_meta=B_meta, s=s, bits=gptq_params.bits
|
||||
)
|
||||
elif quant_method == "gptq":
|
||||
log_once(logger.info, "Converting GPTQ model to Marlin packing format.")
|
||||
gptq_params = self._get_gptq_params()
|
||||
|
||||
|
@ -668,7 +732,7 @@ class Weights:
|
|||
B = self.get_sharded(f"{prefix}.B", dim=0)
|
||||
except RuntimeError:
|
||||
raise RuntimeError(
|
||||
"Cannot load `marlin` weight, make sure the model is already quantized, or quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`"
|
||||
"Cannot load `marlin` weight, make sure the model is already quantized."
|
||||
)
|
||||
|
||||
num_groups = self._get_slice(f"{prefix}.s").get_shape()[0]
|
||||
|
@ -688,6 +752,7 @@ class Weights:
|
|||
try:
|
||||
bits = self.get_tensor("gptq_bits").item()
|
||||
groupsize = self.get_tensor("gptq_groupsize").item()
|
||||
checkpoint_format = getattr(self, "gptq_checkpoint_format", None)
|
||||
desc_act = False
|
||||
sym = True
|
||||
quant_method = "gptq"
|
||||
|
@ -695,6 +760,7 @@ class Weights:
|
|||
try:
|
||||
bits = self.gptq_bits
|
||||
groupsize = self.gptq_groupsize
|
||||
checkpoint_format = getattr(self, "gptq_checkpoint_format", None)
|
||||
desc_act = getattr(self, "gptq_desc_act", False)
|
||||
quant_method = getattr(self, "quant_method", "gptq")
|
||||
sym = getattr(self, "sym", True)
|
||||
|
@ -703,6 +769,7 @@ class Weights:
|
|||
|
||||
return _GPTQParams(
|
||||
bits=bits,
|
||||
checkpoint_format=checkpoint_format,
|
||||
desc_act=desc_act,
|
||||
groupsize=groupsize,
|
||||
quant_method=quant_method,
|
||||
|
@ -724,6 +791,9 @@ class Weights:
|
|||
self.gptq_groupsize = data["quantization_config"]["group_size"]
|
||||
# Order is important here, desc_act is missing on some real models
|
||||
self.quant_method = data["quantization_config"]["quant_method"]
|
||||
self.gptq_checkpoint_format = data["quantization_config"].get(
|
||||
"checkpoint_format"
|
||||
)
|
||||
self.gptq_sym = data["quantization_config"]["sym"]
|
||||
self.gptq_desc_act = data["quantization_config"]["desc_act"]
|
||||
except Exception:
|
||||
|
|
Loading…
Reference in New Issue