Add support for GPTQ Marlin (#2052)

Add support for GPTQ Marlin kernels

GPTQ Marlin extends the Marlin kernels to support common GPTQ
configurations:

- bits: 4 or 8
- groupsize: -1, 32, 64, or 128
- desc_act: true/false

Using the GPTQ Marlin kernels requires repacking the parameters in the
Marlin quantizer format.

The kernels were contributed by Neural Magic to VLLM. We vendor them
here for convenience.
This commit is contained in:
Daniël de Kok 2024-06-14 09:45:42 +02:00 committed by GitHub
parent f433f1f770
commit 093a27c528
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
40 changed files with 4654 additions and 140 deletions

View File

@ -140,9 +140,9 @@ RUN TORCH_CUDA_ARCH_LIST="8.0;8.6+PTX" make build-eetq
# Build marlin kernels
FROM kernel-builder as marlin-kernels-builder
WORKDIR /usr/src
COPY server/Makefile-marlin Makefile
COPY server/marlin/ .
# Build specific version of transformers
RUN TORCH_CUDA_ARCH_LIST="8.0;8.6+PTX" make build-marlin
RUN TORCH_CUDA_ARCH_LIST="8.0;8.6+PTX" python setup.py build
# Build Transformers CUDA kernels
FROM kernel-builder as custom-kernels-builder
@ -213,7 +213,7 @@ COPY --from=awq-kernels-builder /usr/src/llm-awq/awq/kernels/build/lib.linux-x86
# Copy build artifacts from eetq kernels builder
COPY --from=eetq-kernels-builder /usr/src/eetq/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
# Copy build artifacts from marlin kernels builder
COPY --from=marlin-kernels-builder /usr/src/marlin/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
COPY --from=marlin-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
# Copy builds artifacts from vllm builder
COPY --from=vllm-builder /usr/src/vllm/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages

View File

@ -0,0 +1,84 @@
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 2323,
"logprob": null,
"text": "Test"
},
{
"id": 1715,
"logprob": -11.34375,
"text": " request"
}
],
"seed": null,
"tokens": [
{
"id": 198,
"logprob": -2.5742188,
"special": false,
"text": "\n"
},
{
"id": 262,
"logprob": -1.6230469,
"special": false,
"text": " "
},
{
"id": 3270,
"logprob": -2.046875,
"special": false,
"text": " \"\"\"\n"
},
{
"id": 262,
"logprob": -0.015281677,
"special": false,
"text": " "
},
{
"id": 422,
"logprob": -2.1425781,
"special": false,
"text": " if"
},
{
"id": 1715,
"logprob": -0.9238281,
"special": false,
"text": " request"
},
{
"id": 13204,
"logprob": -0.076660156,
"special": false,
"text": ".method"
},
{
"id": 624,
"logprob": -0.021987915,
"special": false,
"text": " =="
},
{
"id": 364,
"logprob": -0.39208984,
"special": false,
"text": " '"
},
{
"id": 3019,
"logprob": -0.10821533,
"special": false,
"text": "POST"
}
],
"top_tokens": null
},
"generated_text": "\n \"\"\"\n if request.method == 'POST"
}

View File

@ -0,0 +1,84 @@
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 2323,
"logprob": null,
"text": "Test"
},
{
"id": 1715,
"logprob": -11.34375,
"text": " request"
}
],
"seed": 0,
"tokens": [
{
"id": 13,
"logprob": -2.2539062,
"special": false,
"text": "."
},
{
"id": 578,
"logprob": -0.15563965,
"special": false,
"text": " The"
},
{
"id": 3622,
"logprob": -0.8203125,
"special": false,
"text": " server"
},
{
"id": 706,
"logprob": 0.0,
"special": false,
"text": " has"
},
{
"id": 539,
"logprob": 0.0,
"special": false,
"text": " not"
},
{
"id": 3686,
"logprob": 0.0,
"special": false,
"text": " yet"
},
{
"id": 3288,
"logprob": 0.0,
"special": false,
"text": " sent"
},
{
"id": 904,
"logprob": 0.0,
"special": false,
"text": " any"
},
{
"id": 828,
"logprob": 0.0,
"special": false,
"text": " data"
},
{
"id": 382,
"logprob": -1.5517578,
"special": false,
"text": ".\n\n"
}
],
"top_tokens": null
},
"generated_text": "Test request. The server has not yet sent any data.\n\n"
}

View File

@ -0,0 +1,338 @@
[
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 2323,
"logprob": null,
"text": "Test"
},
{
"id": 1715,
"logprob": -11.34375,
"text": " request"
}
],
"seed": null,
"tokens": [
{
"id": 198,
"logprob": -2.5742188,
"special": false,
"text": "\n"
},
{
"id": 262,
"logprob": -1.6220703,
"special": false,
"text": " "
},
{
"id": 3270,
"logprob": -2.0410156,
"special": false,
"text": " \"\"\"\n"
},
{
"id": 262,
"logprob": -0.015281677,
"special": false,
"text": " "
},
{
"id": 422,
"logprob": -2.1445312,
"special": false,
"text": " if"
},
{
"id": 1715,
"logprob": -0.92333984,
"special": false,
"text": " request"
},
{
"id": 13204,
"logprob": -0.07672119,
"special": false,
"text": ".method"
},
{
"id": 624,
"logprob": -0.021987915,
"special": false,
"text": " =="
},
{
"id": 364,
"logprob": -0.39208984,
"special": false,
"text": " '"
},
{
"id": 3019,
"logprob": -0.10638428,
"special": false,
"text": "POST"
}
],
"top_tokens": null
},
"generated_text": "\n \"\"\"\n if request.method == 'POST"
},
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 2323,
"logprob": null,
"text": "Test"
},
{
"id": 1715,
"logprob": -11.34375,
"text": " request"
}
],
"seed": null,
"tokens": [
{
"id": 198,
"logprob": -2.5742188,
"special": false,
"text": "\n"
},
{
"id": 262,
"logprob": -1.6220703,
"special": false,
"text": " "
},
{
"id": 3270,
"logprob": -2.0410156,
"special": false,
"text": " \"\"\"\n"
},
{
"id": 262,
"logprob": -0.015281677,
"special": false,
"text": " "
},
{
"id": 422,
"logprob": -2.1445312,
"special": false,
"text": " if"
},
{
"id": 1715,
"logprob": -0.92333984,
"special": false,
"text": " request"
},
{
"id": 13204,
"logprob": -0.07672119,
"special": false,
"text": ".method"
},
{
"id": 624,
"logprob": -0.021987915,
"special": false,
"text": " =="
},
{
"id": 364,
"logprob": -0.39208984,
"special": false,
"text": " '"
},
{
"id": 3019,
"logprob": -0.10638428,
"special": false,
"text": "POST"
}
],
"top_tokens": null
},
"generated_text": "\n \"\"\"\n if request.method == 'POST"
},
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 2323,
"logprob": null,
"text": "Test"
},
{
"id": 1715,
"logprob": -11.34375,
"text": " request"
}
],
"seed": null,
"tokens": [
{
"id": 198,
"logprob": -2.5742188,
"special": false,
"text": "\n"
},
{
"id": 262,
"logprob": -1.6220703,
"special": false,
"text": " "
},
{
"id": 3270,
"logprob": -2.0410156,
"special": false,
"text": " \"\"\"\n"
},
{
"id": 262,
"logprob": -0.015281677,
"special": false,
"text": " "
},
{
"id": 422,
"logprob": -2.1445312,
"special": false,
"text": " if"
},
{
"id": 1715,
"logprob": -0.92333984,
"special": false,
"text": " request"
},
{
"id": 13204,
"logprob": -0.07672119,
"special": false,
"text": ".method"
},
{
"id": 624,
"logprob": -0.021987915,
"special": false,
"text": " =="
},
{
"id": 364,
"logprob": -0.39208984,
"special": false,
"text": " '"
},
{
"id": 3019,
"logprob": -0.10638428,
"special": false,
"text": "POST"
}
],
"top_tokens": null
},
"generated_text": "\n \"\"\"\n if request.method == 'POST"
},
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 2323,
"logprob": null,
"text": "Test"
},
{
"id": 1715,
"logprob": -11.34375,
"text": " request"
}
],
"seed": null,
"tokens": [
{
"id": 198,
"logprob": -2.5742188,
"special": false,
"text": "\n"
},
{
"id": 262,
"logprob": -1.6220703,
"special": false,
"text": " "
},
{
"id": 3270,
"logprob": -2.0410156,
"special": false,
"text": " \"\"\"\n"
},
{
"id": 262,
"logprob": -0.015281677,
"special": false,
"text": " "
},
{
"id": 422,
"logprob": -2.1445312,
"special": false,
"text": " if"
},
{
"id": 1715,
"logprob": -0.92333984,
"special": false,
"text": " request"
},
{
"id": 13204,
"logprob": -0.07672119,
"special": false,
"text": ".method"
},
{
"id": 624,
"logprob": -0.021987915,
"special": false,
"text": " =="
},
{
"id": 364,
"logprob": -0.39208984,
"special": false,
"text": " '"
},
{
"id": 3019,
"logprob": -0.10638428,
"special": false,
"text": "POST"
}
],
"top_tokens": null
},
"generated_text": "\n \"\"\"\n if request.method == 'POST"
}
]

View File

@ -0,0 +1,65 @@
import pytest
@pytest.fixture(scope="module")
def flash_llama_gptq_marlin_handle(launcher):
with launcher(
"astronomer/Llama-3-8B-Instruct-GPTQ-4-Bit", num_shard=2, quantize="marlin"
) as handle:
yield handle
@pytest.fixture(scope="module")
async def flash_llama_gptq_marlin(flash_llama_gptq_marlin_handle):
await flash_llama_gptq_marlin_handle.health(300)
return flash_llama_gptq_marlin_handle.client
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_gptq_marlin(flash_llama_gptq_marlin, response_snapshot):
response = await flash_llama_gptq_marlin.generate(
"Test request", max_new_tokens=10, decoder_input_details=True
)
assert response.details.generated_tokens == 10
assert response == response_snapshot
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_gptq_marlin_all_params(
flash_llama_gptq_marlin, response_snapshot
):
response = await flash_llama_gptq_marlin.generate(
"Test request",
max_new_tokens=10,
repetition_penalty=1.2,
return_full_text=True,
temperature=0.5,
top_p=0.9,
top_k=10,
truncate=5,
typical_p=0.9,
watermark=True,
decoder_input_details=True,
seed=0,
)
assert response.details.generated_tokens == 10
assert response == response_snapshot
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_gptq_marlin_load(
flash_llama_gptq_marlin, generate_load, response_snapshot
):
responses = await generate_load(
flash_llama_gptq_marlin, "Test request", max_new_tokens=10, n=4
)
assert len(responses) == 4
assert all([r.generated_text == responses[0].generated_text for r in responses])
assert responses == response_snapshot

View File

@ -3,7 +3,6 @@ include Makefile-flash-att-v2
include Makefile-vllm
include Makefile-awq
include Makefile-eetq
include Makefile-marlin
include Makefile-selective-scan
unit-tests:

View File

@ -1,11 +0,0 @@
marlin_commit := 2f6d7c10e124b3c5fa29ff8d77d568bd7af3274c
build-marlin:
if [ ! -d 'marlin' ]; then \
pip install -U ninja packaging --no-cache-dir && \
git clone https://github.com/IST-DASLab/marlin.git marlin; \
fi
cd marlin && git fetch && git checkout $(marlin_commit) && python setup.py build
install-marlin: build-marlin
cd marlin && git fetch && git checkout $(marlin_commit) && pip install -e .

20
server/marlin/COPYRIGHT Normal file
View File

@ -0,0 +1,20 @@
These kernels were vendored from VLLM. The Marlin kernels were developed
by Elias Frantar and extended by Neural Magic.
---
Copyright (C) Marlin.2024 Elias Frantar
Modified by Neural Magic
Copyright 2024 The vLLM team.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

View File

@ -0,0 +1,44 @@
import torch
def gptq_marlin_gemm(
a: torch.Tensor,
b_q_weight: torch.Tensor,
b_scales: torch.Tensor,
g_idx: torch.Tensor,
perm: torch.Tensor,
workspace: torch.Tensor,
num_bits: int,
size_m: int,
size_n: int,
size_k: int,
is_k_full: bool,
) -> torch.Tensor:
"""
Matrix multiplication using Marlin kernels. This is an extension of
`marlin_gemm` that supports converted GPTQ kernels.
"""
...
def gptq_marlin_repack(
b_q_weight: torch.Tensor,
perm: torch.Tensor,
size_k: int,
size_n: int,
num_bits: int,
) -> torch.Tensor:
"""Repack GPTQ parameters for Marlin kernels."""
...
def marlin_gemm(
a: torch.Tensor,
b_q_weight: torch.Tensor,
b_scales: torch.Tensor,
workspace: torch.Tensor,
size_m: int,
size_n: int,
size_k: int,
) -> torch.Tensor:
"""
Matrix multiplication using Marlin kernels.
"""
...

View File

@ -0,0 +1,11 @@
#include <torch/extension.h>
#include "ext.hh"
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("gptq_marlin_gemm", &gptq_marlin_gemm,
"Marlin gemm with GPTQ compatibility");
m.def("gptq_marlin_repack", &gptq_marlin_repack,
"Repack GPTQ parameters for Marlin");
m.def("marlin_gemm", &marlin_gemm, "Marlin gemm");
}

View File

@ -0,0 +1,23 @@
#pragma once
#include <torch/library.h>
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
// No support for async
#else
torch::Tensor gptq_marlin_gemm(torch::Tensor &a, torch::Tensor &b_q_weight,
torch::Tensor &b_scales, torch::Tensor &g_idx,
torch::Tensor &perm, torch::Tensor &workspace,
int64_t num_bits, int64_t size_m, int64_t size_n,
int64_t size_k, bool is_k_full);
torch::Tensor gptq_marlin_repack(torch::Tensor &b_q_weight, torch::Tensor &perm,
int64_t size_k, int64_t size_n,
int64_t num_bits);
torch::Tensor marlin_gemm(torch::Tensor &a, torch::Tensor &b_q_weight,
torch::Tensor &b_scales, torch::Tensor &workspace,
int64_t size_m, int64_t size_n, int64_t size_k);
#endif

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,76 @@
#pragma once
#include <torch/all.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <cuda.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include <iostream>
namespace gptq_marlin {
// 8 warps are a good choice since every SM has 4 schedulers and having more
// than 1 warp per schedule allows some more latency hiding. At the same time,
// we want relatively few warps to have many registers per warp and small tiles.
static constexpr int default_threads = 256;
static constexpr int pipe_stages =
4; // 4 pipeline stages fit into shared memory
static constexpr int min_thread_n = 64;
static constexpr int min_thread_k = 64;
static constexpr int tile_size = 16;
static constexpr int max_par = 16;
template <typename T, int n>
struct Vec {
T elems[n];
__device__ T& operator[](int i) { return elems[i]; }
};
using I4 = Vec<int, 4>;
constexpr int div_ceil(int a, int b) { return (a + b - 1) / b; }
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
// No support for async
#else
__device__ inline void cp_async4_pred(void* smem_ptr, const void* glob_ptr,
bool pred = true) {
const int BYTES = 16;
uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
asm volatile(
"{\n"
" .reg .pred p;\n"
" setp.ne.b32 p, %0, 0;\n"
" @p cp.async.cg.shared.global [%1], [%2], %3;\n"
"}\n" ::"r"((int)pred),
"r"(smem), "l"(glob_ptr), "n"(BYTES));
}
__device__ inline void cp_async4(void* smem_ptr, const void* glob_ptr) {
const int BYTES = 16;
uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
asm volatile(
"{\n"
" cp.async.cg.shared.global [%0], [%1], %2;\n"
"}\n" ::"r"(smem),
"l"(glob_ptr), "n"(BYTES));
}
__device__ inline void cp_async_fence() {
asm volatile("cp.async.commit_group;\n" ::);
}
template <int n>
__device__ inline void cp_async_wait() {
asm volatile("cp.async.wait_group %0;\n" ::"n"(n));
}
#endif
} // namespace gptq_marlin

View File

@ -0,0 +1,77 @@
#ifndef _data_types_cuh
#define _data_types_cuh
#include "gptq_marlin.cuh"
#include <cuda_fp16.h>
#include <cuda_bf16.h>
namespace gptq_marlin {
template <typename scalar_t>
class ScalarType {};
template <>
class ScalarType<half> {
public:
using scalar_t = half;
using scalar_t2 = half2;
// Matrix fragments for tensor core instructions; their precise layout is
// documented here:
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-mma-m16n8k16-with-floating-point-type
using FragA = Vec<half2, 4>;
using FragB = Vec<half2, 2>;
using FragC = Vec<float, 4>;
using FragS = Vec<half2, 1>;
static __device__ float inline num2float(const half x) {
return __half2float(x);
}
static __device__ half2 inline num2num2(const half x) {
return __half2half2(x);
}
static __device__ half2 inline nums2num2(const half x1, const half x2) {
return __halves2half2(x1, x2);
}
static __host__ __device__ half inline float2num(const float x) {
return __float2half(x);
}
};
template <>
class ScalarType<nv_bfloat16> {
public:
using scalar_t = nv_bfloat16;
using scalar_t2 = nv_bfloat162;
using FragA = Vec<nv_bfloat162, 4>;
using FragB = Vec<nv_bfloat162, 2>;
using FragC = Vec<float, 4>;
using FragS = Vec<nv_bfloat162, 1>;
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
static __device__ float inline num2float(const nv_bfloat16 x) {
return __bfloat162float(x);
}
static __device__ nv_bfloat162 inline num2num2(const nv_bfloat16 x) {
return __bfloat162bfloat162(x);
}
static __device__ nv_bfloat162 inline nums2num2(const nv_bfloat16 x1,
const nv_bfloat16 x2) {
return __halves2bfloat162(x1, x2);
}
static __host__ __device__ nv_bfloat16 inline float2num(const float x) {
return __float2bfloat16(x);
}
#endif
};
} // namespace gptq_marlin
#endif

View File

@ -0,0 +1,350 @@
#include "gptq_marlin.cuh"
namespace gptq_marlin {
static constexpr int repack_stages = 8;
static constexpr int repack_threads = 256;
static constexpr int tile_k_size = tile_size;
static constexpr int tile_n_size = tile_k_size * 4;
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
template <int const num_threads, int const num_bits, bool const has_perm>
__global__ void marlin_repack_kernel(
uint32_t const* __restrict__ b_q_weight_ptr,
uint32_t const* __restrict__ perm_ptr, uint32_t* __restrict__ out_ptr,
int size_k, int size_n) {}
} // namespace gptq_marlin
torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm,
int64_t size_k, int64_t size_n,
int64_t num_bits) {
TORCH_CHECK_NOT_IMPLEMENTED(
false, "marlin_repack_from_gptq(..) requires CUDA_ARCH >= 8.0");
return torch::empty({1, 1});
}
#else
template <int const num_threads, int const num_bits, bool const has_perm>
__global__ void marlin_repack_kernel(
uint32_t const* __restrict__ b_q_weight_ptr,
uint32_t const* __restrict__ perm_ptr, uint32_t* __restrict__ out_ptr,
int size_k, int size_n) {
constexpr int pack_factor = 32 / num_bits;
int k_tiles = size_k / tile_k_size;
int n_tiles = size_n / tile_n_size;
int block_k_tiles = div_ceil(k_tiles, gridDim.x);
int start_k_tile = blockIdx.x * block_k_tiles;
if (start_k_tile >= k_tiles) {
return;
}
int finish_k_tile = min(start_k_tile + block_k_tiles, k_tiles);
// Wait until the next thread tile has been loaded to shared memory.
auto wait_for_stage = [&]() {
// We only have `stages - 2` active fetches since we are double buffering
// and can only issue the next fetch when it is guaranteed that the previous
// shared memory load is fully complete (as it may otherwise be
// overwritten).
cp_async_wait<repack_stages - 2>();
__syncthreads();
};
extern __shared__ int4 sh[];
constexpr int perm_size = tile_k_size / 4;
int4* sh_perm_ptr = sh;
int4* sh_pipe_ptr = sh_perm_ptr;
if constexpr (has_perm) {
sh_pipe_ptr += perm_size;
}
constexpr int tile_ints = tile_k_size / pack_factor;
constexpr int stage_n_threads = tile_n_size / 4;
constexpr int stage_k_threads = has_perm ? tile_k_size : tile_ints;
constexpr int stage_size = stage_k_threads * stage_n_threads;
auto load_perm_to_shared = [&](int k_tile_id) {
int first_k_int4 = (k_tile_id * tile_k_size) / 4;
int4 const* perm_int4_ptr = reinterpret_cast<int4 const*>(perm_ptr);
if (threadIdx.x < perm_size) {
sh_perm_ptr[threadIdx.x] = perm_int4_ptr[first_k_int4 + threadIdx.x];
}
__syncthreads();
};
auto fetch_to_shared = [&](int pipe, int k_tile_id, int n_tile_id) {
if (n_tile_id >= n_tiles) {
cp_async_fence();
return;
}
int first_n = n_tile_id * tile_n_size;
int4* sh_ptr = sh_pipe_ptr + stage_size * pipe;
if constexpr (has_perm) {
if (threadIdx.x < stage_size) {
int k_id = threadIdx.x / stage_n_threads;
int n_id = threadIdx.x % stage_n_threads;
uint32_t const* sh_perm_int_ptr =
reinterpret_cast<uint32_t const*>(sh_perm_ptr);
int src_k = sh_perm_int_ptr[k_id];
int src_k_packed = src_k / pack_factor;
cp_async4(
&sh_ptr[k_id * stage_n_threads + n_id],
reinterpret_cast<int4 const*>(&(
b_q_weight_ptr[src_k_packed * size_n + first_n + (n_id * 4)])));
}
} else {
if (threadIdx.x < stage_size) {
int k_id = threadIdx.x / stage_n_threads;
int n_id = threadIdx.x % stage_n_threads;
int first_k = k_tile_id * tile_k_size;
int first_k_packed = first_k / pack_factor;
cp_async4(&sh_ptr[k_id * stage_n_threads + n_id],
reinterpret_cast<int4 const*>(
&(b_q_weight_ptr[(first_k_packed + k_id) * size_n +
first_n + (n_id * 4)])));
}
}
cp_async_fence();
};
auto repack_tile = [&](int pipe, int k_tile_id, int n_tile_id) {
if (n_tile_id >= n_tiles) {
return;
}
int warp_id = threadIdx.x / 32;
int th_id = threadIdx.x % 32;
if (warp_id >= 4) {
return;
}
int tc_col = th_id / 4;
int tc_row = (th_id % 4) * 2;
constexpr int tc_offsets[4] = {0, 1, 8, 9};
int cur_n = warp_id * 16 + tc_col;
constexpr int sh_stride = 64;
constexpr uint32_t mask = (1 << num_bits) - 1;
int4* sh_stage_ptr = sh_pipe_ptr + stage_size * pipe;
uint32_t* sh_stage_int_ptr = reinterpret_cast<uint32_t*>(sh_stage_ptr);
uint32_t* sh_perm_int_ptr = reinterpret_cast<uint32_t*>(sh_perm_ptr);
uint32_t vals[8];
if constexpr (has_perm) {
for (int i = 0; i < 4; i++) {
int k_idx = tc_row + tc_offsets[i];
uint32_t src_k = sh_perm_int_ptr[k_idx];
uint32_t src_k_pos = src_k % pack_factor;
uint32_t b1_val = sh_stage_int_ptr[k_idx * sh_stride + cur_n];
uint32_t b1_cur_val = (b1_val >> (src_k_pos * num_bits)) & mask;
uint32_t b2_val = sh_stage_int_ptr[k_idx * sh_stride + cur_n + 8];
uint32_t b2_cur_val = (b2_val >> (src_k_pos * num_bits)) & mask;
vals[i] = b1_cur_val;
vals[4 + i] = b2_cur_val;
}
} else {
uint32_t b1_vals[tile_ints];
uint32_t b2_vals[tile_ints];
#pragma unroll
for (int i = 0; i < tile_ints; i++) {
b1_vals[i] = sh_stage_int_ptr[cur_n + sh_stride * i];
b2_vals[i] = sh_stage_int_ptr[cur_n + 8 + sh_stride * i];
}
#pragma unroll
for (int i = 0; i < 4; i++) {
int cur_elem = tc_row + tc_offsets[i];
int cur_int = cur_elem / pack_factor;
int cur_pos = cur_elem % pack_factor;
vals[i] = (b1_vals[cur_int] >> (cur_pos * num_bits)) & mask;
vals[4 + i] = (b2_vals[cur_int] >> (cur_pos * num_bits)) & mask;
}
}
constexpr int tile_size = tile_k_size * tile_n_size / pack_factor;
int out_offset = (k_tile_id * n_tiles + n_tile_id) * tile_size;
// Result of:
// https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h
if constexpr (num_bits == 4) {
constexpr int pack_idx[8] = {0, 2, 4, 6, 1, 3, 5, 7};
uint32_t res = 0;
#pragma unroll
for (int i = 0; i < 8; i++) {
res |= vals[pack_idx[i]] << (i * 4);
}
out_ptr[out_offset + th_id * 4 + warp_id] = res;
} else {
constexpr int pack_idx[4] = {0, 2, 1, 3};
uint32_t res1 = 0;
uint32_t res2 = 0;
#pragma unroll
for (int i = 0; i < 4; i++) {
res1 |= vals[pack_idx[i]] << (i * 8);
res2 |= vals[4 + pack_idx[i]] << (i * 8);
}
out_ptr[out_offset + th_id * 8 + (warp_id * 2) + 0] = res1;
out_ptr[out_offset + th_id * 8 + (warp_id * 2) + 1] = res2;
}
};
auto start_pipes = [&](int k_tile_id, int n_tile_id) {
#pragma unroll
for (int pipe = 0; pipe < repack_stages - 1; pipe++) {
fetch_to_shared(pipe, k_tile_id, n_tile_id + pipe);
}
wait_for_stage();
};
#pragma unroll
for (int k_tile_id = start_k_tile; k_tile_id < finish_k_tile; k_tile_id++) {
int n_tile_id = 0;
if constexpr (has_perm) {
load_perm_to_shared(k_tile_id);
}
start_pipes(k_tile_id, n_tile_id);
while (n_tile_id < n_tiles) {
#pragma unroll
for (int pipe = 0; pipe < repack_stages; pipe++) {
fetch_to_shared((pipe + repack_stages - 1) % repack_stages, k_tile_id,
n_tile_id + pipe + repack_stages - 1);
repack_tile(pipe, k_tile_id, n_tile_id + pipe);
wait_for_stage();
}
n_tile_id += repack_stages;
}
}
}
} // namespace gptq_marlin
#define CALL_IF(NUM_BITS, HAS_PERM) \
else if (num_bits == NUM_BITS && has_perm == HAS_PERM) { \
cudaFuncSetAttribute( \
gptq_marlin::marlin_repack_kernel<gptq_marlin::repack_threads, \
NUM_BITS, HAS_PERM>, \
cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \
gptq_marlin::marlin_repack_kernel<gptq_marlin::repack_threads, NUM_BITS, \
HAS_PERM> \
<<<blocks, gptq_marlin::repack_threads, max_shared_mem, stream>>>( \
b_q_weight_ptr, perm_ptr, out_ptr, size_k, size_n); \
}
torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm,
int64_t size_k, int64_t size_n,
int64_t num_bits) {
// Verify compatibility with marlin tile of 16x64
TORCH_CHECK(size_k % gptq_marlin::tile_k_size == 0, "size_k = ", size_k,
" is not divisible by tile_k_size = ", gptq_marlin::tile_k_size);
TORCH_CHECK(size_n % gptq_marlin::tile_n_size == 0, "size_n = ", size_n,
" is not divisible by tile_n_size = ", gptq_marlin::tile_n_size);
TORCH_CHECK(num_bits == 4 || num_bits == 8,
"num_bits must be 4 or 8. Got = ", num_bits);
int const pack_factor = 32 / num_bits;
// Verify B
TORCH_CHECK((size_k / pack_factor) == b_q_weight.size(0),
"Shape mismatch: b_q_weight.size(0) = ", b_q_weight.size(0),
", size_k = ", size_k, ", pack_factor = ", pack_factor);
TORCH_CHECK(b_q_weight.size(1) == size_n,
"b_q_weight.size(1) = ", b_q_weight.size(1),
" is not size_n = ", size_n);
// Verify device and strides
TORCH_CHECK(b_q_weight.device().is_cuda(), "b_q_weight is not on GPU");
TORCH_CHECK(b_q_weight.is_contiguous(), "b_q_weight is not contiguous");
TORCH_CHECK(b_q_weight.dtype() == at::kInt, "b_q_weight type is not kInt");
TORCH_CHECK(perm.device().is_cuda(), "perm is not on GPU");
TORCH_CHECK(perm.is_contiguous(), "perm is not contiguous");
TORCH_CHECK(perm.dtype() == at::kInt, "perm type is not at::kInt");
// Alloc buffers
const at::cuda::OptionalCUDAGuard device_guard(device_of(b_q_weight));
auto options = torch::TensorOptions()
.dtype(b_q_weight.dtype())
.device(b_q_weight.device());
torch::Tensor out =
torch::empty({size_k / gptq_marlin::tile_size,
size_n * gptq_marlin::tile_size / pack_factor},
options);
// Detect if there is act_order
bool has_perm = perm.size(0) != 0;
// Get ptrs
uint32_t const* b_q_weight_ptr =
reinterpret_cast<uint32_t const*>(b_q_weight.data_ptr());
uint32_t const* perm_ptr = reinterpret_cast<uint32_t const*>(perm.data_ptr());
uint32_t* out_ptr = reinterpret_cast<uint32_t*>(out.data_ptr());
// Get dev info
int dev = b_q_weight.get_device();
cudaStream_t stream = at::cuda::getCurrentCUDAStream(dev);
int blocks;
cudaDeviceGetAttribute(&blocks, cudaDevAttrMultiProcessorCount, dev);
int max_shared_mem = 0;
cudaDeviceGetAttribute(&max_shared_mem,
cudaDevAttrMaxSharedMemoryPerBlockOptin, dev);
TORCH_CHECK(max_shared_mem > 0);
if (false) {
}
CALL_IF(4, false)
CALL_IF(4, true)
CALL_IF(8, false)
CALL_IF(8, true)
else {
TORCH_CHECK(false, "Unsupported repack config: num_bits = ", num_bits,
", has_perm = ", has_perm);
}
return out;
}
#endif

File diff suppressed because it is too large Load Diff

View File

21
server/marlin/setup.py Normal file
View File

@ -0,0 +1,21 @@
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension
extra_compile_args = []
setup(
name="marlin_kernels",
ext_modules=[
CUDAExtension(
name="marlin_kernels",
sources=[
"marlin_kernels/gptq_marlin.cu",
"marlin_kernels/gptq_marlin_repack.cu",
"marlin_kernels/marlin_cuda_kernel.cu",
"marlin_kernels/ext.cpp",
],
extra_compile_args=extra_compile_args,
),
],
cmdclass={"build_ext": BuildExtension},
)

View File

@ -1,6 +1,7 @@
from typing import Optional
import torch
from torch.nn import functional as F
from text_generation_server.layers.marlin import GPTQMarlinLinear
from text_generation_server.utils.import_utils import SYSTEM
if SYSTEM == "rocm":
@ -223,13 +224,23 @@ def get_linear(weight, bias, quantize):
"You do not seem to have awq installed, either install it (cd server && make install-awq), or try using GPTQ `---quantize gptq` a conversion AWQ->GPTQ will happen on the fly"
)
elif quantize == "marlin":
from text_generation_server.layers.marlin import MarlinLinear, MarlinWeight
from text_generation_server.layers.marlin import (
GPTQMarlinWeight,
MarlinLinear,
MarlinWeight,
)
if not isinstance(weight, MarlinWeight):
if isinstance(weight, GPTQMarlinWeight):
linear = GPTQMarlinLinear(
weight=weight,
bias=bias,
)
elif isinstance(weight, MarlinWeight):
linear = MarlinLinear(weight=weight, bias=bias)
else:
raise NotImplementedError(
f"The passed weight is not `marlin` compatible, loader needs to be updated."
)
linear = MarlinLinear(B=weight.B, s=weight.s, bias=bias)
else:
raise NotImplementedError(f"Quantization `{quantize}` is not implemented yet.")
return linear

View File

@ -1,13 +1,15 @@
from dataclasses import dataclass
from typing import Optional
from typing import Optional, Tuple, List
import torch
import torch.nn as nn
from text_generation_server.utils.import_utils import SYSTEM
try:
import marlin
import marlin_kernels
except ImportError:
marlin = None
marlin_kernels = None
try:
major, _minor = torch.cuda.get_device_capability()
@ -15,9 +17,204 @@ try:
except Exception:
has_sm_8_0 = False
GPTQ_MARLIN_BITS = [4, 8]
GPTQ_MARLIN_GROUP_SIZES = [-1, 32, 64, 128]
MARLIN_TILE_SIZE = 16
def _check_marlin_kernels():
if not (SYSTEM == "cuda" and has_sm_8_0):
raise NotImplementedError(
"Using quantized Marlin models requires a GPU with CUDA capability 8.0 or later."
)
if marlin_kernels is None:
raise NotImplementedError(
"marlin is not installed, install it with: pip install server/marlin"
)
def _check_valid_shape(in_features: int, out_features: int):
if (in_features % 128 != 0 or out_features % 64 != 0) and (
in_features % 64 != 0 or out_features % 128 != 0
):
raise ValueError(
f"The GPTQ Marlin kernel does not have a valid thread configuration for weight matrix with shape ({out_features}, {in_features})."
" The shape elements must be divisible by (128, 64) or (64, 128)."
)
# https://github.com/IST-DASLab/marlin/blob/2f6d7c10e124b3c5fa29ff8d77d568bd7af3274c/marlin/__init__.py#L40C1-L68C54
def _get_perms() -> Tuple[List[int], List[int]]:
scale_perm = []
for i in range(8):
scale_perm.extend([i + 8 * j for j in range(8)])
scale_perm_single = []
for i in range(4):
scale_perm_single.extend([2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]])
return scale_perm, scale_perm_single
_scale_perm, _scale_perm_single = _get_perms()
def permute_scales(scales: torch.Tensor):
out_features = scales.shape[1]
if scales.shape[0] == 1:
scales = scales.reshape((-1, len(_scale_perm_single)))[:, _scale_perm_single]
else:
scales = scales.reshape((-1, len(_scale_perm)))[:, _scale_perm]
return scales.reshape((-1, out_features)).contiguous()
@dataclass
class GPTQMarlinWeight:
"""
Repacked GPTQ Marlin weights.
"""
qweight: torch.Tensor
scales: torch.Tensor
g_idx: torch.Tensor
perm: torch.Tensor
bits: int
is_full_k: bool
def __post_init__(self):
assert self.qweight.dtype == torch.int32
assert self.scales.dtype == torch.float16
assert self.g_idx.dtype == torch.int32
assert self.perm.dtype == torch.int32
def repack_gptq_for_marlin(
*,
qweight: torch.Tensor,
scales: torch.Tensor,
g_idx: torch.Tensor,
bits: int,
desc_act: bool,
groupsize: int,
sym: bool,
sharded_infeatures: bool,
) -> GPTQMarlinWeight:
"""Convert GPTQ weights to a layout that's compatible with GPTQ-Marlin kernels."""
_check_marlin_kernels()
assert marlin_kernels is not None
if bits not in GPTQ_MARLIN_BITS:
supported_bits = ", ".join(str(b) for b in GPTQ_MARLIN_BITS)
raise RuntimeError(
f"Repacking {bits}-bit GPTQ weights as Marlin is not supported, must be one of: {supported_bits}"
)
if groupsize not in GPTQ_MARLIN_GROUP_SIZES:
supported_sizes = ", ".join(str(b) for b in GPTQ_MARLIN_GROUP_SIZES)
raise RuntimeError(
f"Repacking GPTQ weights with group size {groupsize} as Marlin is not supported, must be one of: {supported_sizes}"
)
if not sym:
raise RuntimeError(
"Repacking GPTQ weights with asymmetric quantization as Marlin is not supported."
)
weights_per_int = 32 // bits
in_features = qweight.shape[0] * weights_per_int
out_features = qweight.shape[1]
if in_features % groupsize != 0:
raise ValueError(
f"Number of input features ({in_features}) not divisible by group size ({groupsize})"
)
if desc_act and groupsize != -1:
perm = torch.argsort(g_idx).to(torch.int)
g_idx = g_idx[perm]
else:
perm = torch.empty(0, dtype=torch.int, device=qweight.device)
g_idx = torch.empty(0, dtype=torch.int, device=qweight.device)
repacked = marlin_kernels.gptq_marlin_repack(
qweight, perm, in_features, out_features, bits
)
scales = permute_scales(scales)
is_full_k = not (desc_act and sharded_infeatures)
return GPTQMarlinWeight(
qweight=repacked,
scales=scales,
g_idx=g_idx,
perm=perm,
bits=bits,
is_full_k=is_full_k,
)
class GPTQMarlinLinear(nn.Module):
"""
Linear layer for GPTQ weights that were converted for the GPTQ-Marlin
kernels.
"""
def __init__(
self,
*,
weight: GPTQMarlinWeight,
bias: Optional[torch.Tensor],
):
super().__init__()
_check_marlin_kernels()
assert marlin_kernels is not None
in_features = weight.qweight.shape[0] * MARLIN_TILE_SIZE
out_features = weight.scales.shape[1]
_check_valid_shape(in_features=in_features, out_features=out_features)
self.bits = weight.bits
self.is_full_k = weight.is_full_k
self.register_buffer("qweight", weight.qweight)
self.register_buffer("scales", weight.scales)
self.register_buffer("g_idx", weight.g_idx)
self.register_buffer("perm", weight.perm)
if bias is not None:
self.register_buffer("bias", bias)
else:
self.bias = None
self.workspace = torch.zeros(
out_features // 64 * 16, dtype=torch.int, device=weight.qweight.device
)
def forward(self, A: torch.Tensor) -> torch.Tensor:
assert marlin_kernels is not None
A_flat = A.view(-1, A.shape[-1])
C = marlin_kernels.gptq_marlin_gemm(
A_flat,
self.qweight,
self.scales,
self.g_idx,
self.perm,
self.workspace,
self.bits,
A_flat.shape[0],
self.scales.shape[1],
A_flat.shape[1],
self.is_full_k,
)
C = C.reshape(A.shape[:-1] + (self.scales.shape[1],))
if self.bias is not None:
C += self.bias
return C
@dataclass
class MarlinWeight:
"""
@ -31,28 +228,20 @@ class MarlinWeight:
B: torch.Tensor
s: torch.Tensor
def __post_init__(self):
assert self.B.dtype == torch.int32
assert self.s.dtype == torch.float16
class MarlinLinear(nn.Module):
def __init__(
self, *, B: torch.Tensor, s: torch.Tensor, bias: Optional[torch.Tensor]
):
def __init__(self, *, weight: MarlinWeight, bias: Optional[torch.Tensor]):
super().__init__()
if not has_sm_8_0:
raise NotImplementedError(
"Using quantized marlin models requires CUDA capability 8.0 or later"
)
_check_marlin_kernels()
assert marlin_kernels is not None
if marlin is None:
raise NotImplementedError(
"You do not seem to have marlin installed, either install it (cd server && make install-marlin)"
)
assert B.dtype == torch.int32
assert s.dtype == torch.float16
in_features = B.shape[0] * MARLIN_TILE_SIZE
out_features = s.shape[1]
in_features = weight.B.shape[0] * MARLIN_TILE_SIZE
out_features = weight.s.shape[1]
assert (
in_features % 128 == 0
), f"Number of input features ({in_features}) not divisable by 128"
@ -60,35 +249,36 @@ class MarlinLinear(nn.Module):
out_features % 256 == 0
), f"Number of output features ({out_features}) not divisable by 256"
group_size = -1 if s.shape[0] == 1 else in_features // s.shape[0]
assert group_size in {
groupsize = -1 if weight.s.shape[0] == 1 else in_features // weight.s.shape[0]
assert groupsize in {
-1,
128,
}, f"Group size must be -1 or 128, was {group_size}"
}, f"Group size must be -1 or 128, was {groupsize}"
self.register_buffer("B", B)
self.register_buffer("s", s)
self.register_buffer("B", weight.B)
self.register_buffer("s", weight.s)
if bias is not None:
self.register_buffer("bias", bias)
else:
self.bias = None
self.workspace = torch.zeros(
out_features // 128 * 16, dtype=torch.int, device=B.device
out_features // 64 * 16, dtype=torch.int, device=weight.B.device
)
def forward(self, A: torch.Tensor) -> torch.Tensor:
assert marlin is not None
C = torch.empty(
A.shape[:-1] + (self.s.shape[1],), dtype=A.dtype, device=A.device
)
marlin.mul(
A.view((-1, A.shape[-1])),
assert marlin_kernels is not None
C = marlin_kernels.marlin_gemm(
A.view(-1, A.shape[-1]),
self.B,
C.view((-1, C.shape[-1])),
self.s,
self.workspace,
A.shape[0],
self.s.shape[1],
A.shape[1],
)
C = C.reshape(A.shape[:-1] + (self.s.shape[1],))
if self.bias is not None:
C += self.bias

View File

@ -83,7 +83,7 @@ class BLOOMSharded(CausalLM):
process_group=self.process_group,
prefix="transformer",
)
if config.quantize == "gptq":
if config.quantize in ["gptq", "marlin"]:
weights._set_gptq_params(model_id, revision)
model = BloomForCausalLM(config, weights)

View File

@ -166,7 +166,7 @@ def _load_gqa(config, prefix: str, weights):
dim=0,
)
if config.quantize not in ["gptq", "awq"]:
if config.quantize not in ["gptq", "awq", "marlin"]:
weight = weight.to(dtype=weights.dtype).to(device=weights.device)
head_size = config.hidden_size // config.num_attention_heads

View File

@ -81,16 +81,11 @@ def _load_multi_mqa_gptq(
qzeros = torch.cat([q_tensor, kv_tensor], dim=1)
qzeros = qzeros.to(device=weights.device)
(
bits,
groupsize,
_,
quant_method,
) = weights._get_gptq_params()
if quant_method == "gptq":
gptq_params = weights._get_gptq_params()
if gptq_params.quant_method == "gptq":
g_idx = weights.get_tensor(f"{prefix}.c_attn.g_idx")
g_idx = g_idx.to(device=weights.device)
elif quant_method == "awq":
elif gptq_params.quant_method == "awq":
g_idx = None
from text_generation_server.layers.awq.conversion_utils import (
fast_awq_to_gptq,
@ -105,8 +100,8 @@ def _load_multi_mqa_gptq(
qzeros=qzeros,
scales=scales,
g_idx=g_idx,
bits=bits,
groupsize=groupsize,
bits=gptq_params.bits,
groupsize=gptq_params.groupsize,
use_exllama=HAS_EXLLAMA,
)

View File

@ -130,7 +130,7 @@ def _load_gqa(config, prefix: str, weights):
dim=0,
)
if config.quantize not in ["gptq", "awq"]:
if config.quantize not in ["gptq", "awq", "marlin"]:
weight = weight.to(dtype=weights.dtype).to(device=weights.device)
head_size = config.hidden_size // config.num_attention_heads

View File

@ -55,7 +55,7 @@ class FlashCohere(FlashCausalLM):
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
weights = Weights(filenames, device, dtype, process_group=self.process_group)
if config.quantize in ["gptq", "awq"]:
if config.quantize in ["gptq", "awq", "marlin"]:
weights._set_gptq_params(model_id, revision)
model = FlashCohereForCausalLM(config, weights)

View File

@ -80,7 +80,7 @@ class FlashDbrx(FlashCausalLM):
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
weights = Weights(filenames, device, dtype, process_group=self.process_group)
if config.quantize in ["gptq", "awq"]:
if config.quantize in ["gptq", "awq", "marlin"]:
weights._set_gptq_params(model_id, revision)
model = FlashDbrxForCausalLM(config, weights)

View File

@ -53,7 +53,7 @@ class FlashGemma(FlashCausalLM):
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
weights = Weights(filenames, device, dtype, process_group=self.process_group)
if config.quantize in ["gptq", "awq"]:
if config.quantize in ["gptq", "awq", "marlin"]:
weights._set_gptq_params(model_id, revision)
# TODO hardcoded

View File

@ -67,7 +67,7 @@ class FlashLlama(FlashCausalLM):
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
weights = Weights(filenames, device, dtype, process_group=self.process_group)
if config.quantize in ["gptq", "awq", "exl2"]:
if config.quantize in ["awq", "exl2", "gptq", "marlin"]:
weights._set_gptq_params(model_id, revision)
prefix = ""

View File

@ -68,7 +68,7 @@ class BaseFlashMistral(FlashCausalLM):
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
weights = Weights(filenames, device, dtype, process_group=self.process_group)
if config.quantize in ["gptq", "awq"]:
if config.quantize in ["gptq", "awq", "marlin"]:
weights._set_gptq_params(model_id, revision)
prefix = ""

View File

@ -58,7 +58,7 @@ class FlashNeoXSharded(FlashCausalLM):
weights = Weights(
filenames, device=device, dtype=dtype, process_group=self.process_group
)
if config.quantize == "gptq":
if config.quantize in ["gptq", "marlin"]:
weights._set_gptq_params(model_id, revision)
model = FlashGPTNeoXForCausalLM(config, weights)

View File

@ -53,7 +53,7 @@ class FlashPhi(FlashCausalLM):
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
weights = Weights(filenames, device, dtype, process_group=self.process_group)
if config.quantize in ["gptq", "awq"]:
if config.quantize in ["gptq", "awq", "marlin"]:
weights._set_gptq_params(model_id, revision)
model = FlashPhiForCausalLM(config, weights)

View File

@ -62,7 +62,7 @@ class FlashQwen2(BaseFlashMistral):
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
weights = Weights(filenames, device, dtype, process_group=self.process_group)
if config.quantize in ["gptq", "awq"]:
if config.quantize in ["gptq", "awq", "marlin"]:
weights._set_gptq_params(model_id, revision)
model = Qwen2ForCausalLM(config, weights)

View File

@ -67,7 +67,7 @@ class FlashRWSharded(FlashCausalLM):
config.quantize = quantize
config.speculator = speculator
if config.quantize == "gptq":
if config.quantize in ["gptq", "marlin"]:
weights._set_gptq_params(model_id, revision)
model = FlashRWForCausalLM(config, weights)

View File

@ -69,7 +69,7 @@ class FlashSantacoderSharded(FlashCausalLM):
process_group=self.process_group,
aliases={"transformer.wte.weight": ["lm_head.weight"]},
)
if config.quantize == "gptq":
if config.quantize in ["gptq", "marlin"]:
weights._set_gptq_params(model_id, revision)
model = FlashSantacoderForCausalLM(config, weights)

View File

@ -61,7 +61,7 @@ class FlashStarcoder2(BaseFlashMistral):
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
weights = Weights(filenames, device, dtype, process_group=self.process_group)
if config.quantize in ["gptq", "awq"]:
if config.quantize in ["gptq", "awq", "marlin"]:
weights._set_gptq_params(model_id, revision)
model = FlashStarcoder2ForCausalLM(config, weights)

View File

@ -205,7 +205,7 @@ class GalacticaSharded(CausalLM):
weights = Weights(
filenames, device=device, dtype=dtype, process_group=self.process_group
)
if config.quantize == "gptq":
if config.quantize in ["gptq", "marlin"]:
weights._set_gptq_params(model_id, revision)
model = OPTForCausalLM(config, weights)

View File

@ -58,7 +58,7 @@ class GPTNeoxSharded(CausalLM):
weights = Weights(
filenames, device=device, dtype=dtype, process_group=self.process_group
)
if config.quantize == "gptq":
if config.quantize in ["gptq", "marlin"]:
weights._set_gptq_params(model_id, revision)
model = GPTNeoxForCausalLM(config, weights)

View File

@ -82,7 +82,7 @@ class MPTSharded(CausalLM):
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
weights = Weights(filenames, device, dtype, process_group=self.process_group)
if config.quantize == "gptq":
if config.quantize in ["gptq", "marlin"]:
weights._set_gptq_params(model_id, revision)
config.quantize = quantize

View File

@ -56,7 +56,7 @@ class OPTSharded(CausalLM):
weights = Weights(
filenames, device=device, dtype=dtype, process_group=self.process_group
)
if config.quantize == "gptq":
if config.quantize in ["gptq", "marlin"]:
weights._set_gptq_params(model_id, revision)
model = OPTForCausalLM(config, weights)

View File

@ -1,4 +1,5 @@
import os
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Union
from safetensors import safe_open, SafetensorError
@ -9,6 +10,15 @@ import json
from text_generation_server.utils.log import log_once
@dataclass
class _GPTQParams:
bits: int
groupsize: int
desc_act: bool
quant_method: str
sym: bool
class Weights:
def __init__(
self,
@ -181,15 +191,15 @@ class Weights:
f"Cannot load `{quantize}` weight, make sure the model is already quantized."
)
bits, groupsize, _, quant_method = self._get_gptq_params()
gptq_params = self._get_gptq_params()
qzeros = self._get_qweight(f"{prefix}.qzeros", block_sizes)
scales = self._get_qweight(f"{prefix}.scales", block_sizes)
scales = scales.to(dtype=self.dtype)
if quantize == "gptq" and quant_method == "gptq":
if quantize == "gptq" and gptq_params.quant_method == "gptq":
g_idx = self.get_tensor(f"{prefix}.g_idx")
elif quantize == "gptq" and quant_method == "awq":
elif quantize == "gptq" and gptq_params.quant_method == "awq":
log_once(
logger.info, "Converting AWQ model to Exllama/GPTQ packing format."
)
@ -199,8 +209,11 @@ class Weights:
qweight, qzeros = fast_awq_to_gptq(qweight, qzeros)
g_idx = (
torch.arange(qweight.shape[0] * (32 // bits), device=qweight.device)
// groupsize
torch.arange(
qweight.shape[0] * (32 // gptq_params.bits),
device=qweight.device,
)
// gptq_params.groupsize
).to(dtype=torch.int32)
else:
g_idx = None
@ -210,16 +223,43 @@ class Weights:
qzeros=qzeros,
scales=scales,
g_idx=g_idx,
bits=bits,
groupsize=groupsize,
bits=gptq_params.bits,
groupsize=gptq_params.groupsize,
use_exllama=False,
)
elif quantize == "marlin":
from text_generation_server.layers.marlin import MarlinWeight
from text_generation_server.layers.marlin import (
MarlinWeight,
repack_gptq_for_marlin,
)
B = self._get_qweight(f"{prefix}.B", block_sizes)
s = self._get_qweight(f"{prefix}.s", block_sizes)
weight = MarlinWeight(B=B, s=s)
quant_method = getattr(self, "quant_method", "marlin")
if quant_method == "gptq":
gptq_params = self._get_gptq_params()
try:
qweight = self._get_qweight(f"{prefix}.qweight", block_sizes)
except RuntimeError:
raise RuntimeError(
f"Cannot load `{quantize}` weight for GPTQ -> Marlin repacking, make sure the model is already quantized"
)
scales = self._get_qweight(f"{prefix}.scales", block_sizes)
g_idx = self.get_tensor(f"{prefix}.g_idx")
weight = repack_gptq_for_marlin(
qweight=qweight,
scales=scales,
g_idx=g_idx,
bits=gptq_params.bits,
desc_act=gptq_params.desc_act,
groupsize=gptq_params.groupsize,
sym=gptq_params.sym,
sharded_infeatures=False,
)
else:
B = self._get_qweight(f"{prefix}.B", block_sizes)
s = self._get_qweight(f"{prefix}.s", block_sizes)
weight = MarlinWeight(B=B, s=s)
else:
slice_ = self._get_slice(f"{prefix}.weight")
total_size = slice_.get_shape()[0]
@ -295,20 +335,23 @@ class Weights:
[self.get_sharded(f"{p}.scales", dim=1) for p in prefixes], dim=1
)
bits, groupsize, desc_act, quant_method = self._get_gptq_params()
gptq_params = self._get_gptq_params()
from text_generation_server.layers.gptq import HAS_EXLLAMA
use_exllama = (
bits == 4 and HAS_EXLLAMA and quantize == "gptq" and not desc_act
gptq_params.bits == 4
and HAS_EXLLAMA
and quantize == "gptq"
and not gptq_params.desc_act
)
if quantize == "gptq" and quant_method == "gptq":
if quantize == "gptq" and gptq_params.quant_method == "gptq":
w = [self.get_tensor(f"{p}.g_idx") for p in prefixes]
for w2 in w[1:]:
torch.testing.assert_close(w2, w[0])
g_idx = w[0]
elif quantize == "gptq" and quant_method == "awq":
elif quantize == "gptq" and gptq_params.quant_method == "awq":
log_once(
logger.info, "Converting AWQ model to Exllama/GPTQ packing format."
)
@ -322,9 +365,10 @@ class Weights:
else:
g_idx = (
torch.arange(
qweight.shape[0] * (32 // bits), device=qweight.device
qweight.shape[0] * (32 // gptq_params.bits),
device=qweight.device,
)
// groupsize
// gptq_params.groupsize
).to(dtype=torch.int32)
else:
g_idx = None
@ -334,24 +378,62 @@ class Weights:
qzeros=qzeros,
scales=scales,
g_idx=g_idx,
bits=bits,
groupsize=groupsize,
bits=gptq_params.bits,
groupsize=gptq_params.groupsize,
use_exllama=use_exllama,
)
elif quantize == "marlin":
from text_generation_server.layers.marlin import MarlinWeight
from text_generation_server.layers.gptq import GPTQWeight
from text_generation_server.layers.marlin import (
MarlinWeight,
repack_gptq_for_marlin,
)
try:
B = torch.cat(
[self.get_sharded(f"{p}.B", dim=1) for p in prefixes], dim=1
)
except RuntimeError:
raise RuntimeError(
f"Cannot load `{quantize}` weight, make sure the model is already quantized"
)
s = torch.cat([self.get_sharded(f"{p}.s", dim=1) for p in prefixes], dim=1)
quant_method = getattr(self, "quant_method", "marlin")
if quant_method == "gptq":
gptq_params = self._get_gptq_params()
try:
qweight = torch.cat(
[self.get_sharded(f"{p}.qweight", dim=1) for p in prefixes],
dim=1,
)
except RuntimeError:
raise RuntimeError(
f"Cannot load `{quantize}` weight for GPTQ -> Marlin repacking, make sure the model is already quantized"
)
weight = MarlinWeight(B=B, s=s)
scales = torch.cat(
[self.get_sharded(f"{p}.scales", dim=1) for p in prefixes], dim=1
)
w = [self.get_tensor(f"{p}.g_idx") for p in prefixes]
for w2 in w[1:]:
torch.testing.assert_close(w2, w[0])
g_idx = w[0]
weight = repack_gptq_for_marlin(
qweight=qweight,
scales=scales,
g_idx=g_idx,
bits=gptq_params.bits,
desc_act=gptq_params.desc_act,
groupsize=gptq_params.groupsize,
sym=gptq_params.sym,
sharded_infeatures=False,
)
else:
try:
B = torch.cat(
[self.get_sharded(f"{p}.B", dim=1) for p in prefixes], dim=1
)
except RuntimeError:
raise RuntimeError(
f"Cannot load `{quantize}` weight, make sure the model is already quantized"
)
s = torch.cat(
[self.get_sharded(f"{p}.s", dim=1) for p in prefixes], dim=1
)
weight = MarlinWeight(B=B, s=s)
else:
w = [self.get_sharded(f"{p}.weight", dim=0) for p in prefixes]
@ -401,12 +483,12 @@ class Weights:
elif quantize == "gptq":
use_exllama = True
bits, groupsize, desc_act, quant_method = self._get_gptq_params()
gptq_params = self._get_gptq_params()
if bits != 4:
if gptq_params.bits != 4:
use_exllama = False
if desc_act:
if gptq_params.desc_act:
log_once(logger.warning, "Disabling exllama because desc_act=True")
use_exllama = False
@ -417,9 +499,9 @@ class Weights:
"Cannot load `gptq` weight, make sure the model is already quantized, or quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`"
)
if quant_method == "gptq":
if gptq_params.quant_method == "gptq":
g_idx = self.get_sharded(f"{prefix}.g_idx", dim=0)
elif quant_method == "awq":
elif gptq_params.quant_method == "awq":
g_idx = None
if self.process_group.size() > 1:
@ -428,7 +510,10 @@ class Weights:
not torch.equal(
g_idx.cpu(),
torch.tensor(
[i // groupsize for i in range(g_idx.shape[0])],
[
i // gptq_params.groupsize
for i in range(g_idx.shape[0])
],
dtype=torch.int32,
),
)
@ -455,7 +540,7 @@ class Weights:
else:
log_once(logger.info, f"Using exllama kernels v{HAS_EXLLAMA}")
if use_exllama and groupsize != -1:
if use_exllama and gptq_params.groupsize != -1:
qzeros = self.get_sharded(f"{prefix}.qzeros", dim=0)
scales = self.get_sharded(f"{prefix}.scales", dim=0)
else:
@ -465,7 +550,7 @@ class Weights:
if use_exllama and g_idx is not None:
g_idx = g_idx - g_idx[0]
if quant_method == "awq":
if gptq_params.quant_method == "awq":
log_once(
logger.info, "Converting AWQ model to Exllama/GPTQ packing format."
)
@ -479,9 +564,10 @@ class Weights:
else:
g_idx = (
torch.arange(
qweight.shape[0] * (32 // bits), device=qweight.device
qweight.shape[0] * (32 // gptq_params.bits),
device=qweight.device,
)
// groupsize
// gptq_params.groupsize
).to(dtype=torch.int32)
weight = GPTQWeight(
@ -489,14 +575,14 @@ class Weights:
qzeros=qzeros,
scales=scales,
g_idx=g_idx,
bits=bits,
groupsize=groupsize,
bits=gptq_params.bits,
groupsize=gptq_params.groupsize,
use_exllama=use_exllama,
)
elif quantize == "awq":
from text_generation_server.layers.gptq import GPTQWeight
bits, groupsize, _, _ = self._get_gptq_params()
gptq_params = self._get_gptq_params()
try:
qweight = self.get_sharded(f"{prefix}.qweight", dim=0)
@ -515,38 +601,74 @@ class Weights:
qzeros=qzeros,
scales=scales,
g_idx=g_idx,
bits=bits,
groupsize=groupsize,
bits=gptq_params.bits,
groupsize=gptq_params.groupsize,
use_exllama=use_exllama,
)
elif quantize == "marlin":
from text_generation_server.layers.marlin import MarlinWeight
from text_generation_server.layers.gptq import GPTQWeight
from text_generation_server.layers.marlin import (
MarlinWeight,
repack_gptq_for_marlin,
)
try:
B = self.get_sharded(f"{prefix}.B", dim=0)
except RuntimeError:
raise RuntimeError(
"Cannot load `marlin` weight, make sure the model is already quantized, or quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`"
quant_method = getattr(self, "quant_method", "marlin")
if quant_method == "gptq":
log_once(logger.info, "Converting GPTQ model to Marlin packing format.")
gptq_params = self._get_gptq_params()
try:
qweight = self.get_sharded(f"{prefix}.qweight", dim=0)
except RuntimeError:
raise RuntimeError(
f"Cannot load `{quantize}` weight for GPTQ -> Marlin repacking, make sure the model is already quantized"
)
g_idx = self.get_sharded(f"{prefix}.g_idx", dim=0)
if gptq_params.desc_act or gptq_params.groupsize == -1:
scales = self.get_tensor(f"{prefix}.scales")
else:
scales = self.get_sharded(f"{prefix}.scales", dim=0)
sharded_in_features = self.process_group.size() > 1
weight = repack_gptq_for_marlin(
qweight=qweight,
scales=scales,
g_idx=g_idx,
bits=gptq_params.bits,
desc_act=gptq_params.desc_act,
groupsize=gptq_params.groupsize,
sym=gptq_params.sym,
sharded_infeatures=sharded_in_features,
)
num_groups = self._get_slice(f"{prefix}.s").get_shape()[0]
if num_groups == 1:
# The number of groups is 1 when group_size == -1. share
# scales between all shards in this case.
s = self.get_tensor(f"{prefix}.s")
else:
s = self.get_sharded(f"{prefix}.s", dim=0)
weight = MarlinWeight(B=B, s=s)
try:
B = self.get_sharded(f"{prefix}.B", dim=0)
except RuntimeError:
raise RuntimeError(
"Cannot load `marlin` weight, make sure the model is already quantized, or quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`"
)
num_groups = self._get_slice(f"{prefix}.s").get_shape()[0]
if num_groups == 1:
# The number of groups is 1 when groupsize == -1. share
# scales between all shards in this case.
s = self.get_tensor(f"{prefix}.s")
else:
s = self.get_sharded(f"{prefix}.s", dim=0)
weight = MarlinWeight(B=B, s=s)
else:
weight = self.get_sharded(f"{prefix}.weight", dim=1)
return weight
def _get_gptq_params(self) -> Tuple[int, int, int, str]:
def _get_gptq_params(self) -> _GPTQParams:
try:
bits = self.get_tensor("gptq_bits").item()
groupsize = self.get_tensor("gptq_groupsize").item()
desc_act = False
sym = True
quant_method = "gptq"
except (SafetensorError, RuntimeError) as e:
try:
@ -554,10 +676,17 @@ class Weights:
groupsize = self.gptq_groupsize
desc_act = getattr(self, "gptq_desc_act", False)
quant_method = getattr(self, "quant_method", "gptq")
sym = getattr(self, "sym", True)
except Exception:
raise e
return bits, groupsize, desc_act, quant_method
return _GPTQParams(
bits=bits,
desc_act=desc_act,
groupsize=groupsize,
quant_method=quant_method,
sym=sym,
)
def _set_gptq_params(self, model_id, revision):
filename = "config.json"
@ -574,6 +703,7 @@ class Weights:
self.gptq_groupsize = data["quantization_config"]["group_size"]
# Order is important here, desc_act is missing on some real models
self.quant_method = data["quantization_config"]["quant_method"]
self.gptq_sym = data["quantization_config"]["sym"]
self.gptq_desc_act = data["quantization_config"]["desc_act"]
except Exception:
filename = "quantize_config.json"
@ -588,6 +718,7 @@ class Weights:
data = json.load(f)
self.gptq_bits = data["bits"]
self.gptq_groupsize = data["group_size"]
self.gptq_sym = data["sym"]
self.gptq_desc_act = data["desc_act"]
if "version" in data and data["version"] == "GEMM":
self.quant_method = "awq"