Add support for Marlin 2:4 sparsity (#2102)

This change adds support for 2:4 sparsity when using Marlin
quantization. The 2:4 kernel is used when:

* the quantizer is `marlin`, and
* the checkpoint format is `marlin_24`.

Fixes #2098.
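
Concretely, the selection boils down to two checks. As a minimal sketch (the function and variable names here are illustrative, not taken from the code in this change):

    # Hedged sketch of the selection rule described above.
    # `quantize` is the quantization option the server was started with;
    # `checkpoint_format` comes from the checkpoint's quantization config.
    from typing import Optional

    def uses_marlin_24_kernel(quantize: str, checkpoint_format: Optional[str]) -> bool:
        return quantize == "marlin" and checkpoint_format == "marlin_24"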
Author: Daniël de Kok, 2024-06-25 21:09:42 +02:00 (committed via GitHub)
Commit f1f98e369f, parent 14980df2df
11 changed files with 1731 additions and 16 deletions

View File

@@ -19,6 +19,23 @@ def gptq_marlin_gemm(
"""
...
def gptq_marlin_24_gemm(
a: torch.Tensor,
b_q_weight: torch.Tensor,
b_meta: torch.Tensor,
b_scales: torch.Tensor,
workspace: torch.Tensor,
num_bits: int,
size_m: int,
size_n: int,
size_k: int,
) -> torch.Tensor:
"""
Matrix multiplication using Marlin kernels. This is an extension of
`marlin_gemm` that supports 2:4 sparsity.
"""
...
def gptq_marlin_repack(
b_q_weight: torch.Tensor,
perm: torch.Tensor,
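
For orientation, a hedged sketch of how the new stub is invoked; it mirrors the GPTQMarlin24Linear.forward call added later in this change, assumes the extension is importable as `marlin_kernels` (as in marlin.py), and says nothing about the packed layout of B, B_meta and s:

    import torch
    import marlin_kernels

    def marlin_24_matmul(A, B, B_meta, s, workspace, bits: int) -> torch.Tensor:
        # size_m = rows of A, size_n = output features (width of the scales),
        # size_k = input features; exactly the arguments forward() passes below.
        return marlin_kernels.gptq_marlin_24_gemm(
            A.view(-1, A.shape[-1]), B, B_meta, s, workspace,
            bits, A.shape[0], s.shape[1], A.shape[1],
        )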

View File

@@ -5,6 +5,7 @@
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("gptq_marlin_gemm", &gptq_marlin_gemm,
"Marlin gemm with GPTQ compatibility");
m.def("gptq_marlin_24_gemm", &gptq_marlin_24_gemm, "Marlin sparse 2:4 gemm");
m.def("gptq_marlin_repack", &gptq_marlin_repack, m.def("gptq_marlin_repack", &gptq_marlin_repack,
"Repack GPTQ parameters for Marlin"); "Repack GPTQ parameters for Marlin");
m.def("marlin_gemm", &marlin_gemm, "Marlin gemm"); m.def("marlin_gemm", &marlin_gemm, "Marlin gemm");

View File

@@ -12,6 +12,13 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor &a, torch::Tensor &b_q_weight,
int64_t num_bits, int64_t size_m, int64_t size_n,
int64_t size_k, bool is_k_full);
torch::Tensor gptq_marlin_24_gemm(torch::Tensor &a, torch::Tensor &b_q_weight,
torch::Tensor &b_meta,
torch::Tensor &b_scales,
torch::Tensor &workspace, int64_t num_bits,
int64_t size_m, int64_t size_n,
int64_t size_k);
torch::Tensor gptq_marlin_repack(torch::Tensor &b_q_weight, torch::Tensor &perm,
int64_t size_k, int64_t size_n,
int64_t num_bits);

View File

@@ -0,0 +1,51 @@
/*
* Copyright (C) 2024 Roberto Lopez Castro (roberto.lopez.castro@udc.es). All
* Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
namespace marlin_24 {
constexpr int ceildiv(int a, int b) { return (a + b - 1) / b; }
// Instances of `Vec` are used to organize groups of >>registers<<, as needed
// for instance as inputs to tensor core operations. Consequently, all
// corresponding index accesses must be compile-time constants, which is why we
// extensively use `#pragma unroll` throughout the kernel code to guarantee
// this.
template <typename T, int n>
struct Vec {
T elems[n];
__device__ T& operator[](int i) { return elems[i]; }
};
template <int M_, int N_, int K_>
struct ShapeBase {
static constexpr int M = M_, N = N_, K = K_;
};
using I4 = Vec<int, 4>;
// Matrix fragments for tensor core instructions; their precise layout is
// documented here:
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-mma-m16n8k16-with-floating-point-type
using FragA = Vec<half2, 4>;
using FragB = Vec<half2, 2>;
using FragM = Vec<uint, 1>;
using FragC = Vec<float, 4>;
using FragS = Vec<half2, 1>; // quantization scales
} // namespace marlin_24

View File

@@ -0,0 +1,136 @@
/*
* Copyright (C) 2024 Roberto Lopez Castro (roberto.lopez.castro@udc.es). All
* Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "base.h"
namespace marlin_24 {
// Predicated asynchronous global->shared copy; used for inputs A where we apply
// predication to handle batchsizes that are not multiples of 16.
__device__ inline void cp_async4_pred_zfill(void* smem_ptr,
const void* glob_ptr,
bool pred = true,
const bool zfill = false) {
const int BYTES = 16;
int src_in_bytes = (zfill ? 0 : BYTES);
uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
asm volatile(
"{\n"
" .reg .pred p;\n"
" setp.ne.b32 p, %0, 0;\n"
" @p cp.async.cg.shared.global [%1], [%2], %3;\n"
"}\n" ::"r"((int)pred),
"r"(smem), "l"(glob_ptr), "n"(BYTES), "r"(src_in_bytes));
}
__device__ inline void cp_async4_pred(void* smem_ptr, const void* glob_ptr,
bool pred = true) {
const int BYTES = 16;
uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
asm volatile(
"{\n"
" .reg .pred p;\n"
" setp.ne.b32 p, %0, 0;\n"
" @p cp.async.cg.shared.global [%1], [%2], %3;\n"
"}\n" ::"r"((int)pred),
"r"(smem), "l"(glob_ptr), "n"(BYTES));
}
// Asynchronous global->shared copy
__device__ inline void cp_async4(void* smem_ptr, const void* glob_ptr) {
const int BYTES = 16;
uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
asm volatile(
"{\n"
" cp.async.cg.shared.global [%0], [%1], %2;\n"
"}\n" ::"r"(smem),
"l"(glob_ptr), "n"(BYTES));
}
// Async copy fence.
__device__ inline void cp_async_fence() {
asm volatile("cp.async.commit_group;\n" ::);
}
// Wait until at most `n` async copy stages are still pending.
template <int n>
__device__ inline void cp_async_wait() {
asm volatile("cp.async.wait_group %0;\n" ::"n"(n));
}
// Instruction for loading a full 16x16 matrix fragment of operand A from shared
// memory, directly in tensor core layout.
__device__ inline void ldsm4(FragA& frag_a, const void* smem_ptr) {
uint32_t* a = reinterpret_cast<uint32_t*>(&frag_a);
uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
asm volatile("ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0,%1,%2,%3}, [%4];\n"
: "=r"(a[0]), "=r"(a[1]), "=r"(a[2]), "=r"(a[3])
: "r"(smem));
}
__device__ inline void ldsm4_m(FragM& frag_m, const void* smem_ptr) {
uint32_t* a = reinterpret_cast<uint32_t*>(&frag_m);
uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
asm volatile("ldmatrix.sync.aligned.m8n8.x2.shared.b16 {%0,%1}, [%2];\n"
: "=r"(a[0]), "=r"(a[1])
: "r"(smem));
}
// Instruction for loading a full 16x16 matrix fragment of operand A from shared
// memory, directly in tensor core layout.
__device__ inline void ldsm4_t(FragA& frag_a, const void* smem_ptr) {
uint32_t* a = reinterpret_cast<uint32_t*>(&frag_a);
uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
asm volatile(
"ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16 {%0,%1,%2,%3}, [%4];\n"
: "=r"(a[0]), "=r"(a[1]), "=r"(a[2]), "=r"(a[3])
: "r"(smem));
}
// Wait until barrier reaches `count`, then lock for current threadblock.
__device__ inline void barrier_acquire(int* lock, int count) {
if (threadIdx.x == 0) {
int state = -1;
do
// Guarantee that subsequent writes by this threadblock will be visible
// globally.
asm volatile("ld.global.acquire.gpu.b32 %0, [%1];\n"
: "=r"(state)
: "l"(lock));
while (state != count);
}
__syncthreads();
}
// Release barrier and increment visitation count.
__device__ inline void barrier_release(int* lock, bool reset = false) {
__syncthreads();
if (threadIdx.x == 0) {
if (reset) {
lock[0] = 0;
return;
}
int val = 1;
// Make sure that all writes since acquiring this barrier are visible
// globally, while releasing the barrier.
asm volatile("fence.acq_rel.gpu;\n");
asm volatile("red.relaxed.gpu.global.add.s32 [%0], %1;\n"
:
: "l"(lock), "r"(val));
}
}
} // namespace marlin_24

View File

@@ -0,0 +1,191 @@
/*
* Copyright (C) 2024 Roberto Lopez Castro (roberto.lopez.castro@udc.es). All
* Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "base.h"
#include <cudaTypedefs.h>
namespace marlin_24 {
// On CUDA earlier than 12.5, the ordered_metadata version of this instruction
// is not supported. On later versions of CUDA the version without ordered
// metadata results in the following warning:
// | Advisory: Modifier .sp::ordered_metadata should be used on instruction
// | mma instead of modifier .sp as it is expected to have substantially
// | reduced performance on some future architectures
#if defined CUDA_VERSION && CUDA_VERSION >= 12050
#define MMA_SP_INST \
"mma.sp::ordered_metadata.sync.aligned.m16n8k32.row.col.f32.f16.f16.f32 "
#else
#define MMA_SP_INST "mma.sp.sync.aligned.m16n8k32.row.col.f32.f16.f16.f32 "
#endif
// m16n8k32 sparse tensor core mma instruction with fp16 inputs and fp32
// output/accumulation.
__device__ inline void mma_sp(const FragB& a_frag0, const FragB& a_frag1,
const FragA& frag_b, FragC& frag_c, FragM& frag_m,
const int psel) {
const uint32_t* a0 = reinterpret_cast<const uint32_t*>(&a_frag0);
const uint32_t* a1 = reinterpret_cast<const uint32_t*>(&a_frag1);
const uint32_t* b = reinterpret_cast<const uint32_t*>(&frag_b);
const uint32_t* e = reinterpret_cast<const uint32_t*>(&frag_m);
float* c = reinterpret_cast<float*>(&frag_c);
if (psel == 0) {
asm volatile(MMA_SP_INST
"{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9, %10,%11}, "
"{%12,%13,%14,%15}, %16, 0x0;\n"
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
: "r"(a0[0]), "r"(a1[0]), "r"(a0[1]), "r"(a1[1]), "r"(b[0]),
"r"(b[2]), "r"(b[4]), "r"(b[6]), "f"(c[0]), "f"(c[1]),
"f"(c[2]), "f"(c[3]), "r"(e[0]));
asm volatile(MMA_SP_INST
"{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9, %10,%11}, "
"{%12,%13,%14,%15}, %16, 0x0;\n"
: "=f"(c[4]), "=f"(c[5]), "=f"(c[6]), "=f"(c[7])
: "r"(a0[0]), "r"(a1[0]), "r"(a0[1]), "r"(a1[1]), "r"(b[1]),
"r"(b[3]), "r"(b[5]), "r"(b[7]), "f"(c[4]), "f"(c[5]),
"f"(c[6]), "f"(c[7]), "r"(e[0]));
} else {
asm volatile(MMA_SP_INST
"{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9, %10,%11}, "
"{%12,%13,%14,%15}, %16, 0x1;\n"
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
: "r"(a0[0]), "r"(a1[0]), "r"(a0[1]), "r"(a1[1]), "r"(b[0]),
"r"(b[2]), "r"(b[4]), "r"(b[6]), "f"(c[0]), "f"(c[1]),
"f"(c[2]), "f"(c[3]), "r"(e[0]));
asm volatile(MMA_SP_INST
"{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9, %10,%11}, "
"{%12,%13,%14,%15}, %16, 0x1;\n"
: "=f"(c[4]), "=f"(c[5]), "=f"(c[6]), "=f"(c[7])
: "r"(a0[0]), "r"(a1[0]), "r"(a0[1]), "r"(a1[1]), "r"(b[1]),
"r"(b[3]), "r"(b[5]), "r"(b[7]), "f"(c[4]), "f"(c[5]),
"f"(c[6]), "f"(c[7]), "r"(e[0]));
}
}
// Lookup-table based 3-input logical operation; explicitly used for
// dequantization as the compiler does not seem to automatically recognize it in
// all cases.
template <int lut>
__device__ inline int lop3(int a, int b, int c) {
int res;
asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
: "=r"(res)
: "r"(a), "r"(b), "r"(c), "n"(lut));
return res;
}
__device__ __forceinline__ uint2 to_half4(float c0, float c1, float c2,
float c3) {
uint2 r;
asm("{\n\t"
".reg .f16 a, b, c, d; \n\t"
"cvt.rn.f16.f32 a, %2; \n\t"
"cvt.rn.f16.f32 b, %3; \n\t"
"cvt.rn.f16.f32 c, %4; \n\t"
"cvt.rn.f16.f32 d, %5; \n\t"
"mov.b32 %0, {a, b}; \n\t"
"mov.b32 %1, {c, d}; \n\t"
"}"
: "=r"(r.x), "=r"(r.y)
: "f"(c0), "f"(c1), "f"(c2), "f"(c3));
return r;
}
// Constructs destination register by taking bytes from 2 sources (based on
// mask)
template <int start_byte, int mask>
__device__ inline uint32_t prmt(uint32_t a) {
uint32_t res;
asm volatile("prmt.b32 %0, %1, %2, %3;\n"
: "=r"(res)
: "r"(a), "n"(start_byte), "n"(mask));
return res;
}
// Efficiently dequantize an int32 value into a full B-fragment of 4 fp16
// values. We mostly follow the strategy in the link below, with some small
// changes:
// https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h
__device__ inline FragB dequant_4bit(int q) {
const int LO = 0x000f000f;
const int HI = 0x00f000f0;
const int EX = 0x64006400;
// Guarantee that the `(a & b) | c` operations are LOP3s.
int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX);
int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX);
// We want signed int4 outputs, hence we fuse the `-8` symmetric zero point
// directly into `SUB` and `ADD`.
const int SUB = 0x64086408;
const int MUL = 0x2c002c00;
const int ADD = 0xd480d480;
FragB frag_b;
frag_b[0] = __hsub2(*reinterpret_cast<half2*>(&lo),
*reinterpret_cast<const half2*>(&SUB));
frag_b[1] = __hfma2(*reinterpret_cast<half2*>(&hi),
*reinterpret_cast<const half2*>(&MUL),
*reinterpret_cast<const half2*>(&ADD));
return frag_b;
}
// Efficiently dequantize an int32 value into a full B-fragment of 4 fp16
// values. We mostly follow the strategy in the link below, with some small
// changes:
// https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h
__device__ inline FragB dequant_8bit(int q) {
static constexpr uint32_t mask_for_elt_01 = 0x5250;
static constexpr uint32_t mask_for_elt_23 = 0x5351;
static constexpr uint32_t start_byte_for_fp16 = 0x64646464;
uint32_t lo = prmt<start_byte_for_fp16, mask_for_elt_01>(q);
uint32_t hi = prmt<start_byte_for_fp16, mask_for_elt_23>(q);
static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64806480;
FragB frag_b;
frag_b[0] = __hsub2(*reinterpret_cast<half2*>(&lo),
*reinterpret_cast<const half2*>(&I8s_TO_F16s_MAGIC_NUM));
frag_b[1] = __hsub2(*reinterpret_cast<half2*>(&hi),
*reinterpret_cast<const half2*>(&I8s_TO_F16s_MAGIC_NUM));
return frag_b;
}
// Multiply dequantized values by the corresponding quantization scale; used
// only for grouped quantization.
__device__ inline void scale(FragB& frag_b, FragS& frag_s, int i) {
half2 s = __half2half2(reinterpret_cast<__half*>(&frag_s)[i]);
frag_b[0] = __hmul2(frag_b[0], s);
frag_b[1] = __hmul2(frag_b[1], s);
}
__device__ inline void scale_floats(float* c0, float* c1, float* c2, float* c3,
FragS& s0, float* c4, float* c5, float* c6,
float* c7, FragS& s1) {
*c0 = __fmul_rn(*c0, __half2float(s0[0].x));
*c1 = __fmul_rn(*c1, __half2float(s0[0].y));
*c2 = __fmul_rn(*c2, __half2float(s0[1].x));
*c3 = __fmul_rn(*c3, __half2float(s0[1].y));
*c4 = __fmul_rn(*c4, __half2float(s1[0].x));
*c5 = __fmul_rn(*c5, __half2float(s1[0].y));
*c6 = __fmul_rn(*c6, __half2float(s1[1].x));
*c7 = __fmul_rn(*c7, __half2float(s1[1].y));
}
} // namespace marlin_24
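
The dequant_4bit trick above can be checked outside CUDA. The following NumPy sketch (illustrative, not part of the diff) reproduces it for the four nibbles a single call touches: each nibble is spliced into the mantissa of the fp16 bias 1024 (0x6400), and the bias plus the symmetric zero point of 8 are then removed with a subtraction (low nibbles) or a fused multiply-add (high nibbles):

    import numpy as np

    def dequant_4bit_reference(q: int) -> list:
        # Low nibbles (bits 0-3 and 16-19): bit pattern 0x640n -> value 1024 + n.
        lo_bits = [(q & 0x000F) | 0x6400, ((q >> 16) & 0x000F) | 0x6400]
        # High nibbles (bits 4-7 and 20-23): pattern 0x64n0 -> value 1024 + 16 * n.
        hi_bits = [(q & 0x00F0) | 0x6400, ((q >> 16) & 0x00F0) | 0x6400]
        lo = np.array(lo_bits, dtype=np.uint16).view(np.float16)
        hi = np.array(hi_bits, dtype=np.uint16).view(np.float16)
        # SUB = 1032 (= 1024 + 8), MUL = 1/16, ADD = -72 (= -(1024/16 + 8)).
        sub, mul, add = np.array([0x6408, 0x2C00, 0xD480], dtype=np.uint16).view(np.float16)
        return list(lo - sub) + list(hi * mul + add)   # each entry equals nibble - 8

    # Nibbles 0x1, 0x2 (low word) and 0x3, 0x4 (high word) -> [-7.0, -5.0, -6.0, -4.0]
    print(dequant_4bit_reference(0x00430021))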

File diff suppressed because it is too large.

View File

@@ -12,6 +12,7 @@ setup(
"marlin_kernels/gptq_marlin.cu",
"marlin_kernels/gptq_marlin_repack.cu",
"marlin_kernels/marlin_cuda_kernel.cu",
"marlin_kernels/sparse/marlin_24_cuda_kernel.cu",
"marlin_kernels/ext.cpp", "marlin_kernels/ext.cpp",
], ],
extra_compile_args=extra_compile_args, extra_compile_args=extra_compile_args,

View File

@@ -1,7 +1,6 @@
from typing import Optional
import torch
from torch.nn import functional as F
-from text_generation_server.layers.marlin import GPTQMarlinLinear
from text_generation_server.utils.import_utils import SYSTEM
if SYSTEM == "rocm":
@@ -225,6 +224,9 @@ def get_linear(weight, bias, quantize):
)
elif quantize == "marlin":
from text_generation_server.layers.marlin import (
GPTQMarlin24Linear,
GPTQMarlin24Weight,
GPTQMarlinLinear,
GPTQMarlinWeight,
MarlinLinear,
MarlinWeight,
@@ -235,6 +237,11 @@ def get_linear(weight, bias, quantize):
weight=weight,
bias=bias,
)
elif isinstance(weight, GPTQMarlin24Weight):
linear = GPTQMarlin24Linear(
weight=weight,
bias=bias,
)
elif isinstance(weight, MarlinWeight):
linear = MarlinLinear(weight=weight, bias=bias)
else:

View File

@@ -1,9 +1,8 @@
from dataclasses import dataclass
-from typing import Optional, Tuple, List
+from typing import List, Optional, Tuple
import torch
import torch.nn as nn
from text_generation_server.utils.import_utils import SYSTEM
try:
@@ -177,12 +176,12 @@ class GPTQMarlinLinear(nn.Module):
self.bits = weight.bits
self.is_full_k = weight.is_full_k
-self.register_buffer("qweight", weight.qweight)
-self.register_buffer("scales", weight.scales)
-self.register_buffer("g_idx", weight.g_idx)
-self.register_buffer("perm", weight.perm)
+self.qweight = weight.qweight
+self.scales = weight.scales
+self.g_idx = weight.g_idx
+self.perm = weight.perm
if bias is not None:
-self.register_buffer("bias", bias)
+self.bias = bias
else:
self.bias = None
@@ -215,6 +214,116 @@ class GPTQMarlinLinear(nn.Module):
return C
GPTQ_MARLIN_24_MIN_THREAD_N = 128
GPTQ_MARLIN_24_MIN_THREAD_K = 128
GPTQ_MARLIN_24_MAX_PARALLEL = 64
GPTQ_MARLIN_24_SUPPORTED_NUM_BITS = [4, 8]
GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES = [-1, 128]
@dataclass
class GPTQMarlin24Weight:
"""
GPTQ-Marlin 2:4 weights.
Attributes:
B (torch.Tensor): int4-quantized weights packed into int32.
B_meta (torch.Tensor): metadata for 2:4 sparsity.
s (torch.Tensor): float16 scales.
bits: quantized weight size.
"""
B: torch.Tensor
B_meta: torch.Tensor
s: torch.Tensor
bits: int
def __post_init__(self):
assert self.B.dtype == torch.int32
assert self.B_meta.dtype == torch.int16
assert self.s.dtype == torch.float16
class GPTQMarlin24Linear(nn.Module):
def __init__(self, *, weight: GPTQMarlin24Weight, bias: Optional[torch.Tensor]):
super().__init__()
_check_marlin_kernels()
assert marlin_kernels is not None
if weight.bits not in GPTQ_MARLIN_BITS:
supported_bits = ", ".join(str(b) for b in GPTQ_MARLIN_BITS)
raise RuntimeError(
f"{weight.bits}-bit GPTQ Sparse 2:4 Marlin is not supported, must be one of: {supported_bits}"
)
in_features = weight.B.shape[0] * MARLIN_TILE_SIZE * 2
out_features = weight.s.shape[1]
groupsize = -1 if weight.s.shape[0] == 1 else in_features // weight.s.shape[0]
if groupsize not in GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES:
supported_sizes = ", ".join(
str(b) for b in GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES
)
raise RuntimeError(
f"Group size {groupsize} is not supported, must be one of: {supported_sizes}"
)
self.bits = weight.bits
weights_per_int32 = 32 // self.bits
assert (
out_features % GPTQ_MARLIN_24_MIN_THREAD_N == 0
), f"Number of output features ({out_features}) not divisible by {GPTQ_MARLIN_24_MIN_THREAD_N} threads"
assert (
out_features % weights_per_int32 == 0
), f"Number of output features ({out_features}) not divisible by weights per int32 ({weights_per_int32})"
assert (
in_features % GPTQ_MARLIN_24_MIN_THREAD_K == 0
), f"Number of input features ({in_features}) not divisible by {GPTQ_MARLIN_24_MIN_THREAD_K} threads"
if groupsize != -1 and in_features % groupsize != 0:
raise ValueError(
f"Number of input features ({in_features}) not divisible by group size ({groupsize})"
)
self.B = weight.B
self.B_meta = weight.B_meta
self.s = weight.s
if bias is not None:
self.bias = bias
else:
self.bias = None
self.workspace = torch.zeros(
(out_features // GPTQ_MARLIN_24_MIN_THREAD_N) * GPTQ_MARLIN_24_MAX_PARALLEL,
dtype=torch.int,
device=weight.B.device,
)
def forward(self, A: torch.Tensor) -> torch.Tensor:
assert marlin_kernels is not None
C = marlin_kernels.gptq_marlin_24_gemm(
A.view(-1, A.shape[-1]),
self.B,
self.B_meta,
self.s,
self.workspace,
self.bits,
A.shape[0],
self.s.shape[1],
A.shape[1],
)
C = C.reshape(A.shape[:-1] + (self.s.shape[1],))
if self.bias is not None:
C += self.bias
return C
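
The constructor's shape bookkeeping can be made concrete with made-up numbers. MARLIN_TILE_SIZE is assumed to be 16 here (it is defined elsewhere in marlin.py), and the factor of 2 is read as undoing the 2:4 compression of the K dimension; both are interpretations, not something the diff spells out:

    MARLIN_TILE_SIZE = 16                # assumed value

    B_rows = 256                         # weight.B.shape[0]
    num_groups, out_features = 64, 4096  # weight.s.shape

    in_features = B_rows * MARLIN_TILE_SIZE * 2
    groupsize = -1 if num_groups == 1 else in_features // num_groups

    print(in_features, out_features, groupsize)   # 8192 4096 128
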
@dataclass
class MarlinWeight:
"""
@@ -255,10 +364,10 @@ class MarlinLinear(nn.Module):
128,
}, f"Group size must be -1 or 128, was {groupsize}"
-self.register_buffer("B", weight.B)
-self.register_buffer("s", weight.s)
+self.B = weight.B
+self.s = weight.s
if bias is not None:
-self.register_buffer("bias", bias)
+self.bias = bias
else:
self.bias = None

View File

@@ -13,6 +13,7 @@ from text_generation_server.utils.log import log_once
@dataclass
class _GPTQParams:
bits: int
checkpoint_format: Optional[str]
groupsize: int
desc_act: bool
quant_method: str
@@ -263,12 +264,29 @@ class Weights:
)
elif quantize == "marlin":
from text_generation_server.layers.marlin import (
GPTQMarlin24Weight,
MarlinWeight,
repack_gptq_for_marlin,
)
quant_method = getattr(self, "quant_method", "marlin")
-if quant_method == "gptq":
+is_marlin_24 = getattr(self, "gptq_checkpoint_format", None) == "marlin_24"
if is_marlin_24:
B = self.get_packed_sharded(
f"{prefix}.B_24", dim=1, block_sizes=block_sizes
)
B_meta = self.get_packed_sharded(
f"{prefix}.B_meta", dim=1, block_sizes=block_sizes
)
s = self.get_packed_sharded(
f"{prefix}.s", dim=1, block_sizes=block_sizes
)
gptq_params = self._get_gptq_params()
weight = GPTQMarlin24Weight(
B=B, B_meta=B_meta, s=s, bits=gptq_params.bits
)
elif quant_method == "gptq":
gptq_params = self._get_gptq_params()
try:
qweight = self.get_packed_sharded(
@@ -293,7 +311,6 @@ class Weights:
sym=gptq_params.sym,
sharded_infeatures=False,
)
else:
B = self.get_packed_sharded(
f"{prefix}.B", dim=1, block_sizes=block_sizes
@@ -406,12 +423,36 @@ class Weights:
elif quantize == "marlin":
from text_generation_server.layers.gptq import GPTQWeight
from text_generation_server.layers.marlin import (
GPTQMarlin24Weight,
MarlinWeight,
repack_gptq_for_marlin,
)
quant_method = getattr(self, "quant_method", "marlin")
-if quant_method == "gptq":
+is_marlin_24 = getattr(self, "gptq_checkpoint_format", None) == "marlin_24"
if is_marlin_24:
try:
B = torch.cat(
[self.get_sharded(f"{p}.B_24", dim=1) for p in prefixes], dim=1
)
except RuntimeError:
raise RuntimeError(
f"Cannot load `{quantize}` weight, make sure the model is already quantized"
)
B_meta = torch.cat(
[self.get_sharded(f"{p}.B_meta", dim=1) for p in prefixes], dim=1
)
s = torch.cat(
[self.get_sharded(f"{p}.s", dim=1) for p in prefixes], dim=1
)
gptq_params = self._get_gptq_params()
weight = GPTQMarlin24Weight(
B=B, B_meta=B_meta, s=s, bits=gptq_params.bits
)
elif quant_method == "gptq":
gptq_params = self._get_gptq_params()
try:
qweight = torch.cat(
@@ -629,12 +670,35 @@ class Weights:
elif quantize == "marlin":
from text_generation_server.layers.gptq import GPTQWeight
from text_generation_server.layers.marlin import (
GPTQMarlin24Weight,
MarlinWeight,
repack_gptq_for_marlin,
)
quant_method = getattr(self, "quant_method", "marlin")
-if quant_method == "gptq":
+is_marlin_24 = getattr(self, "gptq_checkpoint_format", None) == "marlin_24"
if is_marlin_24:
try:
B = self.get_sharded(f"{prefix}.B_24", dim=0)
except RuntimeError:
raise RuntimeError(
"Cannot load `marlin` 2:4 sparsity weight, make sure the model is already quantized."
)
B_meta = self.get_sharded(f"{prefix}.B_meta", dim=0)
num_groups = self._get_slice(f"{prefix}.s").get_shape()[0]
if num_groups == 1:
# The number of groups is 1 when groupsize == -1. share
# scales between all shards in this case.
s = self.get_tensor(f"{prefix}.s")
else:
s = self.get_sharded(f"{prefix}.s", dim=0)
gptq_params = self._get_gptq_params()
weight = GPTQMarlin24Weight(
B=B, B_meta=B_meta, s=s, bits=gptq_params.bits
)
elif quant_method == "gptq":
log_once(logger.info, "Converting GPTQ model to Marlin packing format.")
gptq_params = self._get_gptq_params()
@@ -668,7 +732,7 @@ class Weights:
B = self.get_sharded(f"{prefix}.B", dim=0)
except RuntimeError:
raise RuntimeError(
-"Cannot load `marlin` weight, make sure the model is already quantized, or quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`"
+"Cannot load `marlin` weight, make sure the model is already quantized."
)
num_groups = self._get_slice(f"{prefix}.s").get_shape()[0]
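
The scale-sharing branch above can be illustrated with a hedged sketch; the helper below mimics the get_tensor/get_sharded choice with plain tensor ops and made-up shapes:

    import torch

    def shard_scales(s: torch.Tensor, rank: int, world_size: int) -> torch.Tensor:
        # With a single group (groupsize == -1) the one scale row covers every
        # rank's slice of the input features, so the full tensor is shared;
        # otherwise each rank takes the rows for its own groups (dim 0).
        if s.shape[0] == 1:
            return s
        return s.chunk(world_size, dim=0)[rank]

    out_features = 128
    print(shard_scales(torch.ones(1, out_features), rank=0, world_size=2).shape)  # torch.Size([1, 128])
    print(shard_scales(torch.ones(4, out_features), rank=0, world_size=2).shape)  # torch.Size([2, 128])
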
@@ -688,6 +752,7 @@
try:
bits = self.get_tensor("gptq_bits").item()
groupsize = self.get_tensor("gptq_groupsize").item()
checkpoint_format = getattr(self, "gptq_checkpoint_format", None)
desc_act = False
sym = True
quant_method = "gptq"
@@ -695,6 +760,7 @@
try:
bits = self.gptq_bits
groupsize = self.gptq_groupsize
checkpoint_format = getattr(self, "gptq_checkpoint_format", None)
desc_act = getattr(self, "gptq_desc_act", False)
quant_method = getattr(self, "quant_method", "gptq")
sym = getattr(self, "sym", True)
@@ -703,6 +769,7 @@
return _GPTQParams(
bits=bits,
checkpoint_format=checkpoint_format,
desc_act=desc_act,
groupsize=groupsize,
quant_method=quant_method,
@@ -724,6 +791,9 @@
self.gptq_groupsize = data["quantization_config"]["group_size"]
# Order is important here, desc_act is missing on some real models
self.quant_method = data["quantization_config"]["quant_method"]
self.gptq_checkpoint_format = data["quantization_config"].get(
"checkpoint_format"
)
self.gptq_sym = data["quantization_config"]["sym"]
self.gptq_desc_act = data["quantization_config"]["desc_act"]
except Exception:
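
Putting the plumbing above together, a hedged sketch of how `checkpoint_format` travels from a checkpoint's config.json to the marlin_24 branches; the config values are illustrative assumptions and only the keys actually read above are shown:

    example_config = {
        "quantization_config": {
            "bits": 4,
            "group_size": 128,
            "quant_method": "gptq",
            "checkpoint_format": "marlin_24",
            "sym": True,
            "desc_act": False,
        }
    }

    quant_cfg = example_config["quantization_config"]
    gptq_checkpoint_format = quant_cfg.get("checkpoint_format")   # "marlin_24"
    is_marlin_24 = gptq_checkpoint_format == "marlin_24"          # selects the B_24/B_meta loading paths
    print(is_marlin_24)                                           # True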