Install Marlin from standalone package (#2320)

parent 583d37a2f8, commit 922732b255

Dockerfile: 12 changes
@@ -140,13 +140,6 @@ COPY server/Makefile-eetq Makefile
 # Build specific version of transformers
 RUN TORCH_CUDA_ARCH_LIST="8.0;8.6+PTX" make build-eetq
 
-# Build marlin kernels
-FROM kernel-builder AS marlin-kernels-builder
-WORKDIR /usr/src
-COPY server/marlin/ .
-# Build specific version of transformers
-RUN TORCH_CUDA_ARCH_LIST="8.0;8.6+PTX" python setup.py build
-
 # Build Lorax Punica kernels
 FROM kernel-builder AS lorax-punica-builder
 WORKDIR /usr/src
@@ -231,9 +224,6 @@ COPY --from=exllamav2-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-31
 COPY --from=awq-kernels-builder /usr/src/llm-awq/awq/kernels/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
 # Copy build artifacts from eetq kernels builder
 COPY --from=eetq-kernels-builder /usr/src/eetq/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
-# Copy build artifacts from marlin kernels builder
-COPY --from=marlin-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
-COPY --from=lorax-punica-builder /usr/src/lorax-punica/server/punica_kernels/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
 # Copy build artifacts from fbgemm builder
 COPY --from=fbgemm-builder /usr/src/fbgemm/fbgemm_gpu/_skbuild/linux-x86_64-3.10/cmake-install /opt/conda/lib/python3.10/site-packages
 # Copy build artifacts from vllm builder
@@ -252,7 +242,7 @@ COPY server/Makefile server/Makefile
 RUN cd server && \
     make gen-server && \
     pip install -r requirements_cuda.txt && \
-    pip install ".[bnb, accelerate, quantize, peft, outlines]" --no-cache-dir && \
+    pip install ".[bnb, accelerate, marlin, quantize, peft, outlines]" --no-cache-dir && \
     pip install nvidia-nccl-cu12==2.22.3
 
 ENV LD_PRELOAD=/opt/conda/lib/python3.10/site-packages/nvidia/nccl/lib/libnccl.so.2
@@ -1,20 +0,0 @@
These kernels were vendored from vLLM. The Marlin kernels were developed
by Elias Frantar and extended by Neural Magic.

---

Copyright (C) Marlin.2024 Elias Frantar
Modified by Neural Magic
Copyright 2024 The vLLM team.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
@@ -1,76 +0,0 @@
import torch

def gptq_marlin_gemm(
    a: torch.Tensor,
    b_q_weight: torch.Tensor,
    b_scales: torch.Tensor,
    g_idx: torch.Tensor,
    perm: torch.Tensor,
    workspace: torch.Tensor,
    num_bits: int,
    size_m: int,
    size_n: int,
    size_k: int,
    is_k_full: bool,
) -> torch.Tensor:
    """
    Matrix multiplication using Marlin kernels. This is an extension of
    `marlin_gemm` that supports converted GPTQ kernels.
    """
    ...

def gptq_marlin_24_gemm(
    a: torch.Tensor,
    b_q_weight: torch.Tensor,
    b_meta: torch.Tensor,
    b_scales: torch.Tensor,
    workspace: torch.Tensor,
    num_bits: int,
    size_m: int,
    size_n: int,
    size_k: int,
) -> torch.Tensor:
    """
    Matrix multiplication using Marlin kernels. This is an extension of
    `marlin_gemm` that supports 2:4 sparsity.
    """
    ...

def gptq_marlin_repack(
    b_q_weight: torch.Tensor,
    perm: torch.Tensor,
    size_k: int,
    size_n: int,
    num_bits: int,
) -> torch.Tensor:
    """Repack GPTQ parameters for Marlin kernels."""
    ...

def marlin_gemm(
    a: torch.Tensor,
    b_q_weight: torch.Tensor,
    b_scales: torch.Tensor,
    workspace: torch.Tensor,
    size_m: int,
    size_n: int,
    size_k: int,
) -> torch.Tensor:
    """
    Matrix multiplication using Marlin kernels.
    """
    ...

# fp8 marlin
def fp8_marlin_gemm(
    a: torch.Tensor,
    b_q_weight: torch.Tensor,
    b_scales: torch.Tensor,
    workspace: torch.Tensor,
    num_bits: int,
    size_m: int,
    size_n: int,
    size_k: int,
) -> torch.Tensor:
    return torch.ops._C.fp8_marlin_gemm(
        a, b_q_weight, b_scales, workspace, num_bits, size_m, size_n, size_k
    )
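The stubs above only describe signatures. For orientation, the following is a minimal PyTorch sketch (not from this repository) of the computation a 4-bit Marlin GEMM performs, written as plain dequantize-then-matmul. The group size of 128, the symmetric zero point of 8 and the row-major int4 layout are illustrative assumptions; the real kernels operate on a repacked, interleaved weight format and fuse all of this into a single kernel.

import torch

def reference_int4_gemm(a, q, scales, group_size=128):
    # a: (m, k) fp16, q: (k, n) int4 values stored as int8 in [0, 15],
    # scales: (k // group_size, n) fp16 per-group quantization scales.
    w = q.float() - 8.0                                   # symmetric int4 -> signed
    w = w * scales.repeat_interleave(group_size, dim=0).float()
    return (a.float() @ w).half()

m, k, n = 16, 256, 128
a = torch.randn(m, k, dtype=torch.float16)
q = torch.randint(0, 16, (k, n), dtype=torch.int8)
scales = torch.rand(k // 128, n, dtype=torch.float16) * 0.01
c = reference_int4_gemm(a, q, scales)
print(c.shape)  # torch.Size([16, 128])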
@@ -1,269 +0,0 @@
#include "marlin.cuh"

#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800

namespace marlin {

template <int const num_threads, int const num_bits, bool const has_perm>
__global__ void awq_marlin_repack_kernel(
    uint32_t const* __restrict__ b_q_weight_ptr, uint32_t* __restrict__ out_ptr,
    int size_k, int size_n) {}

}  // namespace marlin

torch::Tensor awq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm,
                                int64_t size_k, int64_t size_n,
                                int64_t num_bits) {
  TORCH_CHECK_NOT_IMPLEMENTED(
      false, "marlin_repack_from_gptq(..) requires CUDA_ARCH >= 8.0");
  return torch::empty({1, 1});
}

#else

namespace marlin {

template <int const num_threads, int const num_bits>
__global__ void awq_marlin_repack_kernel(
    uint32_t const* __restrict__ b_q_weight_ptr, uint32_t* __restrict__ out_ptr,
    int size_k, int size_n) {
  constexpr int pack_factor = 32 / num_bits;

  int k_tiles = size_k / tile_k_size;
  int n_tiles = size_n / tile_n_size;
  int block_k_tiles = div_ceil(k_tiles, gridDim.x);

  int start_k_tile = blockIdx.x * block_k_tiles;
  if (start_k_tile >= k_tiles) {
    return;
  }

  int finish_k_tile = min(start_k_tile + block_k_tiles, k_tiles);

  // Wait until the next thread tile has been loaded to shared memory.
  auto wait_for_stage = [&]() {
    // We only have `stages - 2` active fetches since we are double buffering
    // and can only issue the next fetch when it is guaranteed that the previous
    // shared memory load is fully complete (as it may otherwise be
    // overwritten).
    cp_async_wait<repack_stages - 2>();
    __syncthreads();
  };

  extern __shared__ int4 sh[];

  constexpr int tile_n_ints = tile_n_size / pack_factor;

  constexpr int stage_n_threads = tile_n_ints / 4;
  constexpr int stage_k_threads = tile_k_size;
  constexpr int stage_size = stage_k_threads * stage_n_threads;

  auto fetch_to_shared = [&](int pipe, int k_tile_id, int n_tile_id) {
    if (n_tile_id >= n_tiles) {
      cp_async_fence();
      return;
    }

    int first_n = n_tile_id * tile_n_size;
    int first_n_packed = first_n / pack_factor;

    int4* sh_ptr = sh + stage_size * pipe;

    if (threadIdx.x < stage_size) {
      int k_id = threadIdx.x / stage_n_threads;
      int n_id = threadIdx.x % stage_n_threads;

      int first_k = k_tile_id * tile_k_size;

      cp_async4(&sh_ptr[k_id * stage_n_threads + n_id],
                reinterpret_cast<int4 const*>(
                    &(b_q_weight_ptr[(first_k + k_id) * (size_n / pack_factor) +
                                     first_n_packed + (n_id * 4)])));
    }

    cp_async_fence();
  };

  auto repack_tile = [&](int pipe, int k_tile_id, int n_tile_id) {
    if (n_tile_id >= n_tiles) {
      return;
    }

    int warp_id = threadIdx.x / 32;
    int th_id = threadIdx.x % 32;

    if (warp_id >= 4) {
      return;
    }

    int tc_col = th_id / 4;
    int tc_row = (th_id % 4) * 2;

    constexpr int tc_offsets[4] = {0, 1, 8, 9};

    int cur_n = warp_id * 16 + tc_col;
    int cur_n_packed = cur_n / pack_factor;
    int cur_n_pos = cur_n % pack_factor;

    constexpr int sh_stride = tile_n_ints;
    constexpr uint32_t mask = (1 << num_bits) - 1;

    int4* sh_stage_ptr = sh + stage_size * pipe;
    uint32_t* sh_stage_int_ptr = reinterpret_cast<uint32_t*>(sh_stage_ptr);

    // Undo interleaving
    int cur_n_pos_unpacked;
    if constexpr (num_bits == 4) {
      constexpr int undo_pack[8] = {0, 4, 1, 5, 2, 6, 3, 7};
      cur_n_pos_unpacked = undo_pack[cur_n_pos];
    } else {
      constexpr int undo_pack[4] = {0, 2, 1, 3};
      cur_n_pos_unpacked = undo_pack[cur_n_pos];
    }

    uint32_t vals[8];
#pragma unroll
    for (int i = 0; i < 4; i++) {
      int cur_elem = tc_row + tc_offsets[i];

      int packed_src_0 = sh_stage_int_ptr[cur_n_packed + sh_stride * cur_elem];
      int packed_src_1 = sh_stage_int_ptr[cur_n_packed + (8 / pack_factor) +
                                          sh_stride * cur_elem];

      vals[i] = (packed_src_0 >> (cur_n_pos_unpacked * num_bits)) & mask;
      vals[4 + i] = (packed_src_1 >> (cur_n_pos_unpacked * num_bits)) & mask;
    }

    constexpr int tile_size = tile_k_size * tile_n_size / pack_factor;
    int out_offset = (k_tile_id * n_tiles + n_tile_id) * tile_size;

    // Result of:
    // https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h
    if constexpr (num_bits == 4) {
      constexpr int pack_idx[8] = {0, 2, 4, 6, 1, 3, 5, 7};

      uint32_t res = 0;
#pragma unroll
      for (int i = 0; i < 8; i++) {
        res |= vals[pack_idx[i]] << (i * 4);
      }

      out_ptr[out_offset + th_id * 4 + warp_id] = res;

    } else {
      constexpr int pack_idx[4] = {0, 2, 1, 3};

      uint32_t res1 = 0;
      uint32_t res2 = 0;
#pragma unroll
      for (int i = 0; i < 4; i++) {
        res1 |= vals[pack_idx[i]] << (i * 8);
        res2 |= vals[4 + pack_idx[i]] << (i * 8);
      }

      out_ptr[out_offset + th_id * 8 + (warp_id * 2) + 0] = res1;
      out_ptr[out_offset + th_id * 8 + (warp_id * 2) + 1] = res2;
    }
  };

  auto start_pipes = [&](int k_tile_id, int n_tile_id) {
#pragma unroll
    for (int pipe = 0; pipe < repack_stages - 1; pipe++) {
      fetch_to_shared(pipe, k_tile_id, n_tile_id + pipe);
    }

    wait_for_stage();
  };
#pragma unroll
  for (int k_tile_id = start_k_tile; k_tile_id < finish_k_tile; k_tile_id++) {
    int n_tile_id = 0;

    start_pipes(k_tile_id, n_tile_id);

    while (n_tile_id < n_tiles) {
#pragma unroll
      for (int pipe = 0; pipe < repack_stages; pipe++) {
        fetch_to_shared((pipe + repack_stages - 1) % repack_stages, k_tile_id,
                        n_tile_id + pipe + repack_stages - 1);
        repack_tile(pipe, k_tile_id, n_tile_id + pipe);
        wait_for_stage();
      }
      n_tile_id += repack_stages;
    }
  }
}

}  // namespace marlin

#define CALL_IF(NUM_BITS)                                                    \
  else if (num_bits == NUM_BITS) {                                          \
    cudaFuncSetAttribute(                                                    \
        marlin::awq_marlin_repack_kernel<marlin::repack_threads, NUM_BITS>,  \
        cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem);        \
    marlin::awq_marlin_repack_kernel<marlin::repack_threads, NUM_BITS>       \
        <<<blocks, marlin::repack_threads, max_shared_mem, stream>>>(        \
            b_q_weight_ptr, out_ptr, size_k, size_n);                        \
  }

torch::Tensor awq_marlin_repack(torch::Tensor& b_q_weight, int64_t size_k,
                                int64_t size_n, int64_t num_bits) {
  // Verify compatibility with marlin tile of 16x64
  TORCH_CHECK(size_k % marlin::tile_k_size == 0, "size_k = ", size_k,
              " is not divisible by tile_k_size = ", marlin::tile_k_size);
  TORCH_CHECK(size_n % marlin::tile_n_size == 0, "size_n = ", size_n,
              " is not divisible by tile_n_size = ", marlin::tile_n_size);

  TORCH_CHECK(num_bits == 4 || num_bits == 8,
              "num_bits must be 4 or 8. Got = ", num_bits);
  int const pack_factor = 32 / num_bits;

  // Verify B
  TORCH_CHECK(b_q_weight.size(0) == size_k,
              "b_q_weight.size(0) = ", b_q_weight.size(0),
              " is not size_k = ", size_k);
  TORCH_CHECK((size_n / pack_factor) == b_q_weight.size(1),
              "Shape mismatch: b_q_weight.size(1) = ", b_q_weight.size(1),
              ", size_n = ", size_n, ", pack_factor = ", pack_factor);

  // Verify device and strides
  TORCH_CHECK(b_q_weight.device().is_cuda(), "b_q_weight is not on GPU");
  TORCH_CHECK(b_q_weight.is_contiguous(), "b_q_weight is not contiguous");
  TORCH_CHECK(b_q_weight.dtype() == at::kInt, "b_q_weight type is not kInt");

  // Alloc buffers
  const at::cuda::OptionalCUDAGuard device_guard(device_of(b_q_weight));
  auto options = torch::TensorOptions()
                     .dtype(b_q_weight.dtype())
                     .device(b_q_weight.device());
  torch::Tensor out = torch::empty(
      {size_k / marlin::tile_size, size_n * marlin::tile_size / pack_factor},
      options);

  // Get ptrs
  uint32_t const* b_q_weight_ptr =
      reinterpret_cast<uint32_t const*>(b_q_weight.data_ptr());
  uint32_t* out_ptr = reinterpret_cast<uint32_t*>(out.data_ptr());

  // Get dev info
  int dev = b_q_weight.get_device();
  cudaStream_t stream = at::cuda::getCurrentCUDAStream(dev);
  int blocks;
  cudaDeviceGetAttribute(&blocks, cudaDevAttrMultiProcessorCount, dev);

  int max_shared_mem = 0;
  cudaDeviceGetAttribute(&max_shared_mem,
                         cudaDevAttrMaxSharedMemoryPerBlockOptin, dev);
  TORCH_CHECK(max_shared_mem > 0);

  if (false) {
  }
  CALL_IF(4)
  CALL_IF(8)
  else {
    TORCH_CHECK(false, "Unsupported repack config: num_bits = ", num_bits);
  }

  return out;
}

#endif
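For orientation, here is a small Python sketch (not part of the repository) of the num_bits == 4 packing loop above: eight 4-bit values go into one 32-bit word, with pack_idx deciding which source value lands in which nibble, matching the interleaved layout referenced in the FasterTransformer link.

pack_idx = [0, 2, 4, 6, 1, 3, 5, 7]

def pack8_nibbles(vals):
    # vals: eight integers in [0, 15]; nibble i of the result holds vals[pack_idx[i]].
    res = 0
    for i in range(8):
        res |= vals[pack_idx[i]] << (i * 4)
    return res

def unpack8_nibbles(word):
    # Invert the packing above.
    vals = [0] * 8
    for i in range(8):
        vals[pack_idx[i]] = (word >> (i * 4)) & 0xF
    return vals

assert unpack8_nibbles(pack8_nibbles(list(range(8)))) == list(range(8))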
@@ -1,16 +0,0 @@
#include <torch/extension.h>

#include "ext.hh"

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("awq_marlin_repack", &awq_marlin_repack,
        "Repack AWQ parameters for Marlin");
  m.def("gptq_marlin_gemm", &gptq_marlin_gemm,
        "Marlin gemm with GPTQ compatibility");
  m.def("gptq_marlin_24_gemm", &gptq_marlin_24_gemm, "Marlin sparse 2:4 gemm");
  m.def("gptq_marlin_repack", &gptq_marlin_repack,
        "Repack GPTQ parameters for Marlin");
  m.def("marlin_gemm", &marlin_gemm, "Marlin gemm");
  // fp8_marlin Optimized Quantized GEMM for FP8 weight-only.
  m.def("fp8_marlin_gemm", &fp8_marlin_gemm);
}
@@ -1,39 +0,0 @@
#pragma once

#include <torch/library.h>

#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
// No support for async
#else

torch::Tensor awq_marlin_repack(torch::Tensor &b_q_weight, int64_t size_k,
                                int64_t size_n, int64_t num_bits);

torch::Tensor gptq_marlin_gemm(torch::Tensor &a, torch::Tensor &b_q_weight,
                               torch::Tensor &b_scales, torch::Tensor &b_zeros,
                               torch::Tensor &g_idx, torch::Tensor &perm,
                               torch::Tensor &workspace, int64_t num_bits,
                               int64_t size_m, int64_t size_n, int64_t size_k,
                               bool is_k_full, bool has_zp);

torch::Tensor gptq_marlin_24_gemm(torch::Tensor &a, torch::Tensor &b_q_weight,
                                  torch::Tensor &b_meta,
                                  torch::Tensor &b_scales,
                                  torch::Tensor &workspace, int64_t num_bits,
                                  int64_t size_m, int64_t size_n,
                                  int64_t size_k);

torch::Tensor gptq_marlin_repack(torch::Tensor &b_q_weight, torch::Tensor &perm,
                                 int64_t size_k, int64_t size_n,
                                 int64_t num_bits);

torch::Tensor marlin_gemm(torch::Tensor &a, torch::Tensor &b_q_weight,
                          torch::Tensor &b_scales, torch::Tensor &workspace,
                          int64_t size_m, int64_t size_n, int64_t size_k);

torch::Tensor fp8_marlin_gemm(torch::Tensor &a, torch::Tensor &b_q_weight,
                              torch::Tensor &b_scales, torch::Tensor &workspace,
                              int64_t num_bits, int64_t size_m, int64_t size_n,
                              int64_t size_k);

#endif
File diff suppressed because it is too large
File diff suppressed because it is too large
@@ -1,344 +0,0 @@
#include "marlin.cuh"

#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800

namespace marlin {

template <int const num_threads, int const num_bits, bool const has_perm>
__global__ void gptq_marlin_repack_kernel(
    uint32_t const* __restrict__ b_q_weight_ptr,
    uint32_t const* __restrict__ perm_ptr, uint32_t* __restrict__ out_ptr,
    int size_k, int size_n) {}

}  // namespace marlin

torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm,
                                 int64_t size_k, int64_t size_n,
                                 int64_t num_bits) {
  TORCH_CHECK_NOT_IMPLEMENTED(
      false, "marlin_repack_from_gptq(..) requires CUDA_ARCH >= 8.0");
  return torch::empty({1, 1});
}

#else

namespace marlin {

template <int const num_threads, int const num_bits, bool const has_perm>
__global__ void gptq_marlin_repack_kernel(
    uint32_t const* __restrict__ b_q_weight_ptr,
    uint32_t const* __restrict__ perm_ptr, uint32_t* __restrict__ out_ptr,
    int size_k, int size_n) {
  constexpr int pack_factor = 32 / num_bits;

  int k_tiles = size_k / tile_k_size;
  int n_tiles = size_n / tile_n_size;
  int block_k_tiles = div_ceil(k_tiles, gridDim.x);

  int start_k_tile = blockIdx.x * block_k_tiles;
  if (start_k_tile >= k_tiles) {
    return;
  }

  int finish_k_tile = min(start_k_tile + block_k_tiles, k_tiles);

  // Wait until the next thread tile has been loaded to shared memory.
  auto wait_for_stage = [&]() {
    // We only have `stages - 2` active fetches since we are double buffering
    // and can only issue the next fetch when it is guaranteed that the previous
    // shared memory load is fully complete (as it may otherwise be
    // overwritten).
    cp_async_wait<repack_stages - 2>();
    __syncthreads();
  };

  extern __shared__ int4 sh[];

  constexpr int perm_size = tile_k_size / 4;

  int4* sh_perm_ptr = sh;
  int4* sh_pipe_ptr = sh_perm_ptr;
  if constexpr (has_perm) {
    sh_pipe_ptr += perm_size;
  }

  constexpr int tile_ints = tile_k_size / pack_factor;

  constexpr int stage_n_threads = tile_n_size / 4;
  constexpr int stage_k_threads = has_perm ? tile_k_size : tile_ints;
  constexpr int stage_size = stage_k_threads * stage_n_threads;

  auto load_perm_to_shared = [&](int k_tile_id) {
    int first_k_int4 = (k_tile_id * tile_k_size) / 4;

    int4 const* perm_int4_ptr = reinterpret_cast<int4 const*>(perm_ptr);

    if (threadIdx.x < perm_size) {
      sh_perm_ptr[threadIdx.x] = perm_int4_ptr[first_k_int4 + threadIdx.x];
    }
    __syncthreads();
  };

  auto fetch_to_shared = [&](int pipe, int k_tile_id, int n_tile_id) {
    if (n_tile_id >= n_tiles) {
      cp_async_fence();
      return;
    }

    int first_n = n_tile_id * tile_n_size;

    int4* sh_ptr = sh_pipe_ptr + stage_size * pipe;

    if constexpr (has_perm) {
      if (threadIdx.x < stage_size) {
        int k_id = threadIdx.x / stage_n_threads;
        int n_id = threadIdx.x % stage_n_threads;

        uint32_t const* sh_perm_int_ptr =
            reinterpret_cast<uint32_t const*>(sh_perm_ptr);

        int src_k = sh_perm_int_ptr[k_id];
        int src_k_packed = src_k / pack_factor;

        cp_async4(
            &sh_ptr[k_id * stage_n_threads + n_id],
            reinterpret_cast<int4 const*>(&(
                b_q_weight_ptr[src_k_packed * size_n + first_n + (n_id * 4)])));
      }

    } else {
      if (threadIdx.x < stage_size) {
        int k_id = threadIdx.x / stage_n_threads;
        int n_id = threadIdx.x % stage_n_threads;

        int first_k = k_tile_id * tile_k_size;
        int first_k_packed = first_k / pack_factor;

        cp_async4(&sh_ptr[k_id * stage_n_threads + n_id],
                  reinterpret_cast<int4 const*>(
                      &(b_q_weight_ptr[(first_k_packed + k_id) * size_n +
                                       first_n + (n_id * 4)])));
      }
    }

    cp_async_fence();
  };

  auto repack_tile = [&](int pipe, int k_tile_id, int n_tile_id) {
    if (n_tile_id >= n_tiles) {
      return;
    }

    int warp_id = threadIdx.x / 32;
    int th_id = threadIdx.x % 32;

    if (warp_id >= 4) {
      return;
    }

    int tc_col = th_id / 4;
    int tc_row = (th_id % 4) * 2;

    constexpr int tc_offsets[4] = {0, 1, 8, 9};

    int cur_n = warp_id * 16 + tc_col;

    constexpr int sh_stride = 64;
    constexpr uint32_t mask = (1 << num_bits) - 1;

    int4* sh_stage_ptr = sh_pipe_ptr + stage_size * pipe;
    uint32_t* sh_stage_int_ptr = reinterpret_cast<uint32_t*>(sh_stage_ptr);

    uint32_t* sh_perm_int_ptr = reinterpret_cast<uint32_t*>(sh_perm_ptr);

    uint32_t vals[8];

    if constexpr (has_perm) {
      for (int i = 0; i < 4; i++) {
        int k_idx = tc_row + tc_offsets[i];

        uint32_t src_k = sh_perm_int_ptr[k_idx];
        uint32_t src_k_pos = src_k % pack_factor;

        uint32_t b1_val = sh_stage_int_ptr[k_idx * sh_stride + cur_n];
        uint32_t b1_cur_val = (b1_val >> (src_k_pos * num_bits)) & mask;

        uint32_t b2_val = sh_stage_int_ptr[k_idx * sh_stride + cur_n + 8];
        uint32_t b2_cur_val = (b2_val >> (src_k_pos * num_bits)) & mask;

        vals[i] = b1_cur_val;
        vals[4 + i] = b2_cur_val;
      }

    } else {
      uint32_t b1_vals[tile_ints];
      uint32_t b2_vals[tile_ints];

#pragma unroll
      for (int i = 0; i < tile_ints; i++) {
        b1_vals[i] = sh_stage_int_ptr[cur_n + sh_stride * i];
        b2_vals[i] = sh_stage_int_ptr[cur_n + 8 + sh_stride * i];
      }

#pragma unroll
      for (int i = 0; i < 4; i++) {
        int cur_elem = tc_row + tc_offsets[i];
        int cur_int = cur_elem / pack_factor;
        int cur_pos = cur_elem % pack_factor;

        vals[i] = (b1_vals[cur_int] >> (cur_pos * num_bits)) & mask;
        vals[4 + i] = (b2_vals[cur_int] >> (cur_pos * num_bits)) & mask;
      }
    }

    constexpr int tile_size = tile_k_size * tile_n_size / pack_factor;
    int out_offset = (k_tile_id * n_tiles + n_tile_id) * tile_size;

    // Result of:
    // https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h
    if constexpr (num_bits == 4) {
      constexpr int pack_idx[8] = {0, 2, 4, 6, 1, 3, 5, 7};

      uint32_t res = 0;
#pragma unroll
      for (int i = 0; i < 8; i++) {
        res |= vals[pack_idx[i]] << (i * 4);
      }

      out_ptr[out_offset + th_id * 4 + warp_id] = res;

    } else {
      constexpr int pack_idx[4] = {0, 2, 1, 3};

      uint32_t res1 = 0;
      uint32_t res2 = 0;
#pragma unroll
      for (int i = 0; i < 4; i++) {
        res1 |= vals[pack_idx[i]] << (i * 8);
        res2 |= vals[4 + pack_idx[i]] << (i * 8);
      }

      out_ptr[out_offset + th_id * 8 + (warp_id * 2) + 0] = res1;
      out_ptr[out_offset + th_id * 8 + (warp_id * 2) + 1] = res2;
    }
  };

  auto start_pipes = [&](int k_tile_id, int n_tile_id) {
#pragma unroll
    for (int pipe = 0; pipe < repack_stages - 1; pipe++) {
      fetch_to_shared(pipe, k_tile_id, n_tile_id + pipe);
    }

    wait_for_stage();
  };
#pragma unroll
  for (int k_tile_id = start_k_tile; k_tile_id < finish_k_tile; k_tile_id++) {
    int n_tile_id = 0;

    if constexpr (has_perm) {
      load_perm_to_shared(k_tile_id);
    }

    start_pipes(k_tile_id, n_tile_id);

    while (n_tile_id < n_tiles) {
#pragma unroll
      for (int pipe = 0; pipe < repack_stages; pipe++) {
        fetch_to_shared((pipe + repack_stages - 1) % repack_stages, k_tile_id,
                        n_tile_id + pipe + repack_stages - 1);
        repack_tile(pipe, k_tile_id, n_tile_id + pipe);
        wait_for_stage();
      }
      n_tile_id += repack_stages;
    }
  }
}

}  // namespace marlin

#define CALL_IF(NUM_BITS, HAS_PERM)                                          \
  else if (num_bits == NUM_BITS && has_perm == HAS_PERM) {                   \
    cudaFuncSetAttribute(                                                    \
        marlin::gptq_marlin_repack_kernel<marlin::repack_threads, NUM_BITS,  \
                                          HAS_PERM>,                         \
        cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem);        \
    marlin::gptq_marlin_repack_kernel<marlin::repack_threads, NUM_BITS,      \
                                      HAS_PERM>                              \
        <<<blocks, marlin::repack_threads, max_shared_mem, stream>>>(        \
            b_q_weight_ptr, perm_ptr, out_ptr, size_k, size_n);              \
  }

torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm,
                                 int64_t size_k, int64_t size_n,
                                 int64_t num_bits) {
  // Verify compatibility with marlin tile of 16x64
  TORCH_CHECK(size_k % marlin::tile_k_size == 0, "size_k = ", size_k,
              " is not divisible by tile_k_size = ", marlin::tile_k_size);
  TORCH_CHECK(size_n % marlin::tile_n_size == 0, "size_n = ", size_n,
              " is not divisible by tile_n_size = ", marlin::tile_n_size);

  TORCH_CHECK(num_bits == 4 || num_bits == 8,
              "num_bits must be 4 or 8. Got = ", num_bits);
  int const pack_factor = 32 / num_bits;

  // Verify B
  TORCH_CHECK((size_k / pack_factor) == b_q_weight.size(0),
              "Shape mismatch: b_q_weight.size(0) = ", b_q_weight.size(0),
              ", size_k = ", size_k, ", pack_factor = ", pack_factor);
  TORCH_CHECK(b_q_weight.size(1) == size_n,
              "b_q_weight.size(1) = ", b_q_weight.size(1),
              " is not size_n = ", size_n);

  // Verify device and strides
  TORCH_CHECK(b_q_weight.device().is_cuda(), "b_q_weight is not on GPU");
  TORCH_CHECK(b_q_weight.is_contiguous(), "b_q_weight is not contiguous");
  TORCH_CHECK(b_q_weight.dtype() == at::kInt, "b_q_weight type is not kInt");

  TORCH_CHECK(perm.device().is_cuda(), "perm is not on GPU");
  TORCH_CHECK(perm.is_contiguous(), "perm is not contiguous");
  TORCH_CHECK(perm.dtype() == at::kInt, "perm type is not at::kInt");

  // Alloc buffers
  const at::cuda::OptionalCUDAGuard device_guard(device_of(b_q_weight));
  auto options = torch::TensorOptions()
                     .dtype(b_q_weight.dtype())
                     .device(b_q_weight.device());
  torch::Tensor out = torch::empty(
      {size_k / marlin::tile_size, size_n * marlin::tile_size / pack_factor},
      options);

  // Detect if there is act_order
  bool has_perm = perm.size(0) != 0;

  // Get ptrs
  uint32_t const* b_q_weight_ptr =
      reinterpret_cast<uint32_t const*>(b_q_weight.data_ptr());
  uint32_t const* perm_ptr = reinterpret_cast<uint32_t const*>(perm.data_ptr());
  uint32_t* out_ptr = reinterpret_cast<uint32_t*>(out.data_ptr());

  // Get dev info
  int dev = b_q_weight.get_device();
  cudaStream_t stream = at::cuda::getCurrentCUDAStream(dev);
  int blocks;
  cudaDeviceGetAttribute(&blocks, cudaDevAttrMultiProcessorCount, dev);

  int max_shared_mem = 0;
  cudaDeviceGetAttribute(&max_shared_mem,
                         cudaDevAttrMaxSharedMemoryPerBlockOptin, dev);
  TORCH_CHECK(max_shared_mem > 0);

  if (false) {
  }
  CALL_IF(4, false)
  CALL_IF(4, true)
  CALL_IF(8, false)
  CALL_IF(8, true)
  else {
    TORCH_CHECK(false, "Unsupported repack config: num_bits = ", num_bits,
                ", has_perm = ", has_perm);
  }

  return out;
}

#endif
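As a quick check of the output allocation above, a hedged Python sketch with an assumed 4096x4096 4-bit weight matrix; only the formulas from the code are used.

size_k, size_n, num_bits = 4096, 4096, 4
tile_size = 16                      # marlin::tile_size
pack_factor = 32 // num_bits        # eight int4 values per uint32
out_shape = (size_k // tile_size, size_n * tile_size // pack_factor)
print(out_shape)                    # (256, 8192)
# Repacking changes the layout, not the number of packed 32-bit words:
assert out_shape[0] * out_shape[1] == size_k * size_n // pack_factor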
@@ -1,87 +0,0 @@
#pragma once

#include <torch/all.h>

#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <cuda.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include <iostream>

namespace marlin {

// Marlin params

// 8 warps are a good choice since every SM has 4 schedulers and having more
// than 1 warp per schedule allows some more latency hiding. At the same time,
// we want relatively few warps to have many registers per warp and small tiles.
static constexpr int default_threads = 256;

static constexpr int pipe_stages =
    4;  // 4 pipeline stages fit into shared memory

static constexpr int min_thread_n = 64;
static constexpr int min_thread_k = 64;

static constexpr int tile_size = 16;
static constexpr int max_par = 16;

// Repack params
static constexpr int repack_stages = 8;

static constexpr int repack_threads = 256;

static constexpr int tile_k_size = tile_size;
static constexpr int tile_n_size = tile_k_size * 4;

// Helpers
template <typename T, int n>
struct Vec {
  T elems[n];
  __device__ T& operator[](int i) { return elems[i]; }
};

using I4 = Vec<int, 4>;

constexpr int div_ceil(int a, int b) { return (a + b - 1) / b; }

#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
// No support for async
#else

__device__ inline void cp_async4_pred(void* smem_ptr, const void* glob_ptr,
                                      bool pred = true) {
  const int BYTES = 16;
  uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
  asm volatile(
      "{\n"
      "   .reg .pred p;\n"
      "   setp.ne.b32 p, %0, 0;\n"
      "   @p cp.async.cg.shared.global [%1], [%2], %3;\n"
      "}\n" ::"r"((int)pred),
      "r"(smem), "l"(glob_ptr), "n"(BYTES));
}

__device__ inline void cp_async4(void* smem_ptr, const void* glob_ptr) {
  const int BYTES = 16;
  uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
  asm volatile(
      "{\n"
      "   cp.async.cg.shared.global [%0], [%1], %2;\n"
      "}\n" ::"r"(smem),
      "l"(glob_ptr), "n"(BYTES));
}

__device__ inline void cp_async_fence() {
  asm volatile("cp.async.commit_group;\n" ::);
}

template <int n>
__device__ inline void cp_async_wait() {
  asm volatile("cp.async.wait_group %0;\n" ::"n"(n));
}

#endif

}  // namespace marlin
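A small hedged Python sketch of how div_ceil above is used by the repack kernels to split k tiles over thread blocks; the block count of 108 (one block per SM on an A100) is an assumption for illustration.

def div_ceil(a, b):
    return (a + b - 1) // b

k_tiles, blocks = 256, 108                         # 4096 rows / tile_k_size 16
block_k_tiles = div_ceil(k_tiles, blocks)          # 3 k tiles per block
active_blocks = div_ceil(k_tiles, block_k_tiles)   # 86 blocks get work
print(block_k_tiles, active_blocks)                # blocks 86..107 exit early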
File diff suppressed because it is too large
@@ -1,79 +0,0 @@
#ifndef _data_types_cuh
#define _data_types_cuh
#include "marlin.cuh"
#include <cuda_fp16.h>
#include <cuda_bf16.h>

namespace marlin {

template <typename scalar_t>
class ScalarType {};

template <>
class ScalarType<half> {
 public:
  using scalar_t = half;
  using scalar_t2 = half2;

  // Matrix fragments for tensor core instructions; their precise layout is
  // documented here:
  // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-mma-m16n8k16-with-floating-point-type
  using FragA = Vec<half2, 4>;
  using FragB = Vec<half2, 2>;
  using FragC = Vec<float, 4>;
  using FragS = Vec<half2, 1>;
  using FragZP = Vec<half2, 4>;

  static __device__ float inline num2float(const half x) {
    return __half2float(x);
  }

  static __device__ half2 inline num2num2(const half x) {
    return __half2half2(x);
  }

  static __device__ half2 inline nums2num2(const half x1, const half x2) {
    return __halves2half2(x1, x2);
  }

  static __host__ __device__ half inline float2num(const float x) {
    return __float2half(x);
  }
};

template <>
class ScalarType<nv_bfloat16> {
 public:
  using scalar_t = nv_bfloat16;
  using scalar_t2 = nv_bfloat162;

  using FragA = Vec<nv_bfloat162, 4>;
  using FragB = Vec<nv_bfloat162, 2>;
  using FragC = Vec<float, 4>;
  using FragS = Vec<nv_bfloat162, 1>;
  using FragZP = Vec<nv_bfloat162, 4>;

#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
  static __device__ float inline num2float(const nv_bfloat16 x) {
    return __bfloat162float(x);
  }

  static __device__ nv_bfloat162 inline num2num2(const nv_bfloat16 x) {
    return __bfloat162bfloat162(x);
  }

  static __device__ nv_bfloat162 inline nums2num2(const nv_bfloat16 x1,
                                                  const nv_bfloat16 x2) {
    return __halves2bfloat162(x1, x2);
  }

  static __host__ __device__ nv_bfloat16 inline float2num(const float x) {
    return __float2bfloat16(x);
  }
#endif
};

}  // namespace marlin

#endif
@@ -1,51 +0,0 @@
/*
 * Copyright (C) 2024 Roberto Lopez Castro (roberto.lopez.castro@udc.es). All
 * Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#pragma once

namespace marlin_24 {

constexpr int ceildiv(int a, int b) { return (a + b - 1) / b; }

// Instances of `Vec` are used to organize groups of >>registers<<, as needed
// for instance as inputs to tensor core operations. Consequently, all
// corresponding index accesses must be compile-time constants, which is why we
// extensively use `#pragma unroll` throughout the kernel code to guarantee
// this.
template <typename T, int n>
struct Vec {
  T elems[n];
  __device__ T& operator[](int i) { return elems[i]; }
};

template <int M_, int N_, int K_>
struct ShapeBase {
  static constexpr int M = M_, N = N_, K = K_;
};

using I4 = Vec<int, 4>;

// Matrix fragments for tensor core instructions; their precise layout is
// documented here:
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-mma-m16n8k16-with-floating-point-type
using FragA = Vec<half2, 4>;
using FragB = Vec<half2, 2>;
using FragM = Vec<uint, 1>;
using FragC = Vec<float, 4>;
using FragS = Vec<half2, 1>;  // quantization scales

}  // namespace marlin_24
@@ -1,136 +0,0 @@
/*
 * Copyright (C) 2024 Roberto Lopez Castro (roberto.lopez.castro@udc.es). All
 * Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#pragma once
#include "base.h"

namespace marlin_24 {
// Predicated asynchronous global->shared copy; used for inputs A where we apply
// predication to handle batchsizes that are not multiples of 16.
__device__ inline void cp_async4_pred_zfill(void* smem_ptr,
                                            const void* glob_ptr,
                                            bool pred = true,
                                            const bool zfill = false) {
  const int BYTES = 16;
  int src_in_bytes = (zfill ? 0 : BYTES);
  uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
  asm volatile(
      "{\n"
      "   .reg .pred p;\n"
      "   setp.ne.b32 p, %0, 0;\n"
      "   @p cp.async.cg.shared.global [%1], [%2], %3;\n"
      "}\n" ::"r"((int)pred),
      "r"(smem), "l"(glob_ptr), "n"(BYTES), "r"(src_in_bytes));
}

__device__ inline void cp_async4_pred(void* smem_ptr, const void* glob_ptr,
                                      bool pred = true) {
  const int BYTES = 16;
  uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
  asm volatile(
      "{\n"
      "   .reg .pred p;\n"
      "   setp.ne.b32 p, %0, 0;\n"
      "   @p cp.async.cg.shared.global [%1], [%2], %3;\n"
      "}\n" ::"r"((int)pred),
      "r"(smem), "l"(glob_ptr), "n"(BYTES));
}

// Asynchronous global->shared copy
__device__ inline void cp_async4(void* smem_ptr, const void* glob_ptr) {
  const int BYTES = 16;
  uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
  asm volatile(
      "{\n"
      "   cp.async.cg.shared.global [%0], [%1], %2;\n"
      "}\n" ::"r"(smem),
      "l"(glob_ptr), "n"(BYTES));
}

// Async copy fence.
__device__ inline void cp_async_fence() {
  asm volatile("cp.async.commit_group;\n" ::);
}

// Wait until at most `n` async copy stages are still pending.
template <int n>
__device__ inline void cp_async_wait() {
  asm volatile("cp.async.wait_group %0;\n" ::"n"(n));
}

// Instruction for loading a full 16x16 matrix fragment of operand A from shared
// memory, directly in tensor core layout.
__device__ inline void ldsm4(FragA& frag_a, const void* smem_ptr) {
  uint32_t* a = reinterpret_cast<uint32_t*>(&frag_a);
  uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
  asm volatile("ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0,%1,%2,%3}, [%4];\n"
               : "=r"(a[0]), "=r"(a[1]), "=r"(a[2]), "=r"(a[3])
               : "r"(smem));
}

__device__ inline void ldsm4_m(FragM& frag_m, const void* smem_ptr) {
  uint32_t* a = reinterpret_cast<uint32_t*>(&frag_m);
  uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
  asm volatile("ldmatrix.sync.aligned.m8n8.x2.shared.b16 {%0,%1}, [%2];\n"
               : "=r"(a[0]), "=r"(a[1])
               : "r"(smem));
}

// Instruction for loading a full 16x16 matrix fragment of operand A from shared
// memory, directly in tensor core layout.
__device__ inline void ldsm4_t(FragA& frag_a, const void* smem_ptr) {
  uint32_t* a = reinterpret_cast<uint32_t*>(&frag_a);
  uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
  asm volatile(
      "ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16 {%0,%1,%2,%3}, [%4];\n"
      : "=r"(a[0]), "=r"(a[1]), "=r"(a[2]), "=r"(a[3])
      : "r"(smem));
}

// Wait until barrier reaches `count`, then lock for current threadblock.
__device__ inline void barrier_acquire(int* lock, int count) {
  if (threadIdx.x == 0) {
    int state = -1;
    do
      // Guarantee that subsequent writes by this threadblock will be visible
      // globally.
      asm volatile("ld.global.acquire.gpu.b32 %0, [%1];\n"
                   : "=r"(state)
                   : "l"(lock));
    while (state != count);
  }
  __syncthreads();
}

// Release barrier and increment visitation count.
__device__ inline void barrier_release(int* lock, bool reset = false) {
  __syncthreads();
  if (threadIdx.x == 0) {
    if (reset) {
      lock[0] = 0;
      return;
    }
    int val = 1;
    // Make sure that all writes since acquiring this barrier are visible
    // globally, while releasing the barrier.
    asm volatile("fence.acq_rel.gpu;\n");
    asm volatile("red.relaxed.gpu.global.add.s32 [%0], %1;\n"
                 :
                 : "l"(lock), "r"(val));
  }
}
}  // namespace marlin_24
@@ -1,191 +0,0 @@
/*
 * Copyright (C) 2024 Roberto Lopez Castro (roberto.lopez.castro@udc.es). All
 * Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#pragma once
#include "base.h"
#include <cudaTypedefs.h>

namespace marlin_24 {

// On CUDA earlier than 12.5, the ordered_metadata version of this instruction
// is not supported. On later versions of CUDA the version without ordered
// metadata results in the following warning:
// | Advisory: Modifier ‘.sp::ordered_metadata’ should be used on instruction
// | ‘mma’ instead of modifier ‘.sp’ as it is expected to have substantially
// | reduced performance on some future architectures
#if defined CUDA_VERSION && CUDA_VERSION >= 12050
  #define MMA_SP_INST \
    "mma.sp::ordered_metadata.sync.aligned.m16n8k32.row.col.f32.f16.f16.f32 "
#else
  #define MMA_SP_INST "mma.sp.sync.aligned.m16n8k32.row.col.f32.f16.f16.f32 "
#endif

// m16n8k32 sparse tensor core mma instruction with fp16 inputs and fp32
// output/accumulation.
__device__ inline void mma_sp(const FragB& a_frag0, const FragB& a_frag1,
                              const FragA& frag_b, FragC& frag_c, FragM& frag_m,
                              const int psel) {
  const uint32_t* a0 = reinterpret_cast<const uint32_t*>(&a_frag0);
  const uint32_t* a1 = reinterpret_cast<const uint32_t*>(&a_frag1);
  const uint32_t* b = reinterpret_cast<const uint32_t*>(&frag_b);
  const uint32_t* e = reinterpret_cast<const uint32_t*>(&frag_m);

  float* c = reinterpret_cast<float*>(&frag_c);
  if (psel == 0) {
    asm volatile(MMA_SP_INST
                 "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9, %10,%11}, "
                 "{%12,%13,%14,%15}, %16, 0x0;\n"
                 : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
                 : "r"(a0[0]), "r"(a1[0]), "r"(a0[1]), "r"(a1[1]), "r"(b[0]),
                   "r"(b[2]), "r"(b[4]), "r"(b[6]), "f"(c[0]), "f"(c[1]),
                   "f"(c[2]), "f"(c[3]), "r"(e[0]));
    asm volatile(MMA_SP_INST
                 "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9, %10,%11}, "
                 "{%12,%13,%14,%15}, %16, 0x0;\n"
                 : "=f"(c[4]), "=f"(c[5]), "=f"(c[6]), "=f"(c[7])
                 : "r"(a0[0]), "r"(a1[0]), "r"(a0[1]), "r"(a1[1]), "r"(b[1]),
                   "r"(b[3]), "r"(b[5]), "r"(b[7]), "f"(c[4]), "f"(c[5]),
                   "f"(c[6]), "f"(c[7]), "r"(e[0]));
  } else {
    asm volatile(MMA_SP_INST
                 "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9, %10,%11}, "
                 "{%12,%13,%14,%15}, %16, 0x1;\n"
                 : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
                 : "r"(a0[0]), "r"(a1[0]), "r"(a0[1]), "r"(a1[1]), "r"(b[0]),
                   "r"(b[2]), "r"(b[4]), "r"(b[6]), "f"(c[0]), "f"(c[1]),
                   "f"(c[2]), "f"(c[3]), "r"(e[0]));
    asm volatile(MMA_SP_INST
                 "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9, %10,%11}, "
                 "{%12,%13,%14,%15}, %16, 0x1;\n"
                 : "=f"(c[4]), "=f"(c[5]), "=f"(c[6]), "=f"(c[7])
                 : "r"(a0[0]), "r"(a1[0]), "r"(a0[1]), "r"(a1[1]), "r"(b[1]),
                   "r"(b[3]), "r"(b[5]), "r"(b[7]), "f"(c[4]), "f"(c[5]),
                   "f"(c[6]), "f"(c[7]), "r"(e[0]));
  }
}

// Lookup-table based 3-input logical operation; explicitly used for
// dequantization as the compiler does not seem to automatically recognize it in
// all cases.
template <int lut>
__device__ inline int lop3(int a, int b, int c) {
  int res;
  asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
               : "=r"(res)
               : "r"(a), "r"(b), "r"(c), "n"(lut));
  return res;
}

__device__ __forceinline__ uint2 to_half4(float c0, float c1, float c2,
                                          float c3) {
  uint2 r;
  asm("{\n\t"
      ".reg .f16 a, b, c, d; \n\t"
      "cvt.rn.f16.f32 a, %2; \n\t"
      "cvt.rn.f16.f32 b, %3; \n\t"
      "cvt.rn.f16.f32 c, %4; \n\t"
      "cvt.rn.f16.f32 d, %5; \n\t"
      "mov.b32 %0, {a, b}; \n\t"
      "mov.b32 %1, {c, d}; \n\t"
      "}"
      : "=r"(r.x), "=r"(r.y)
      : "f"(c0), "f"(c1), "f"(c2), "f"(c3));
  return r;
}

// Constructs destination register by taking bytes from 2 sources (based on
// mask)
template <int start_byte, int mask>
__device__ inline uint32_t prmt(uint32_t a) {
  uint32_t res;
  asm volatile("prmt.b32 %0, %1, %2, %3;\n"
               : "=r"(res)
               : "r"(a), "n"(start_byte), "n"(mask));
  return res;
}

// Efficiently dequantize an int32 value into a full B-fragment of 4 fp16
// values. We mostly follow the strategy in the link below, with some small
// changes:
// https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h
__device__ inline FragB dequant_4bit(int q) {
  const int LO = 0x000f000f;
  const int HI = 0x00f000f0;
  const int EX = 0x64006400;
  // Guarantee that the `(a & b) | c` operations are LOP3s.
  int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX);
  int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX);
  // We want signed int4 outputs, hence we fuse the `-8` symmetric zero point
  // directly into `SUB` and `ADD`.
  const int SUB = 0x64086408;
  const int MUL = 0x2c002c00;
  const int ADD = 0xd480d480;

  FragB frag_b;
  frag_b[0] = __hsub2(*reinterpret_cast<half2*>(&lo),
                      *reinterpret_cast<const half2*>(&SUB));
  frag_b[1] = __hfma2(*reinterpret_cast<half2*>(&hi),
                      *reinterpret_cast<const half2*>(&MUL),
                      *reinterpret_cast<const half2*>(&ADD));
  return frag_b;
}

// Efficiently dequantize an int32 value into a full B-fragment of 4 fp16
// values. We mostly follow the strategy in the link below, with some small
// changes:
// https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h
__device__ inline FragB dequant_8bit(int q) {
  static constexpr uint32_t mask_for_elt_01 = 0x5250;
  static constexpr uint32_t mask_for_elt_23 = 0x5351;
  static constexpr uint32_t start_byte_for_fp16 = 0x64646464;

  uint32_t lo = prmt<start_byte_for_fp16, mask_for_elt_01>(q);
  uint32_t hi = prmt<start_byte_for_fp16, mask_for_elt_23>(q);

  static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64806480;

  FragB frag_b;
  frag_b[0] = __hsub2(*reinterpret_cast<half2*>(&lo),
                      *reinterpret_cast<const half2*>(&I8s_TO_F16s_MAGIC_NUM));
  frag_b[1] = __hsub2(*reinterpret_cast<half2*>(&hi),
                      *reinterpret_cast<const half2*>(&I8s_TO_F16s_MAGIC_NUM));
  return frag_b;
}

// Multiply dequantized values by the corresponding quantization scale; used
// only for grouped quantization.
__device__ inline void scale(FragB& frag_b, FragS& frag_s, int i) {
  half2 s = __half2half2(reinterpret_cast<__half*>(&frag_s)[i]);
  frag_b[0] = __hmul2(frag_b[0], s);
  frag_b[1] = __hmul2(frag_b[1], s);
}

__device__ inline void scale_floats(float* c0, float* c1, float* c2, float* c3,
                                    FragS& s0, float* c4, float* c5, float* c6,
                                    float* c7, FragS& s1) {
  *c0 = __fmul_rn(*c0, __half2float(s0[0].x));
  *c1 = __fmul_rn(*c1, __half2float(s0[0].y));
  *c2 = __fmul_rn(*c2, __half2float(s0[1].x));
  *c3 = __fmul_rn(*c3, __half2float(s0[1].y));

  *c4 = __fmul_rn(*c4, __half2float(s1[0].x));
  *c5 = __fmul_rn(*c5, __half2float(s1[0].y));
  *c6 = __fmul_rn(*c6, __half2float(s1[1].x));
  *c7 = __fmul_rn(*c7, __half2float(s1[1].y));
}

}  // namespace marlin_24
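A hedged Python reference (not from the repository) for the arithmetic that dequant_4bit performs per 32-bit word: each nibble is a 4-bit value with a symmetric zero point of 8. The interleaved lane ordering and the lop3/half2 bit tricks of the real code are deliberately ignored here.

def dequant_4bit_reference(q_word):
    # Extract eight 4-bit fields and apply the fused -8 zero point.
    return [((q_word >> (4 * i)) & 0xF) - 8 for i in range(8)]

print(dequant_4bit_reference(0x76543210))
# [-8, -7, -6, -5, -4, -3, -2, -1]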
File diff suppressed because it is too large
@@ -1,24 +0,0 @@
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension

extra_compile_args = []

setup(
    name="marlin_kernels",
    ext_modules=[
        CUDAExtension(
            name="marlin_kernels",
            sources=[
                "marlin_kernels/awq_marlin_repack.cu",
                "marlin_kernels/fp8_marlin.cu",
                "marlin_kernels/gptq_marlin.cu",
                "marlin_kernels/gptq_marlin_repack.cu",
                "marlin_kernels/marlin_cuda_kernel.cu",
                "marlin_kernels/sparse/marlin_24_cuda_kernel.cu",
                "marlin_kernels/ext.cpp",
            ],
            extra_compile_args=extra_compile_args,
        ),
    ],
    cmdclass={"build_ext": BuildExtension},
)
@@ -1139,6 +1139,74 @@ files = [
     {file = "MarkupSafe-2.1.5.tar.gz", hash = "sha256:d283d37a890ba4c1ae73ffadf8046435c76e7bc2247bbb63c00bd1a709c6544b"},
 ]
 
+[[package]]
+name = "marlin-kernels"
+version = "0.2.0"
+description = "Marlin quantization kernels"
+optional = true
+python-versions = ">=3.7"
+files = [
+    {file = "marlin_kernels-0.2.0+cu123torch2.4-cp310-cp310-linux_x86_64.whl", hash = "sha256:9a5afcf19b0f5917e43353cc19873fb3c4d4d0b924e2a95a37884f9ce208d0bd"},
+]
+
+[package.dependencies]
+torch = "*"
+
+[package.source]
+type = "url"
+url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.2.0/marlin_kernels-0.2.0+cu123torch2.4-cp310-cp310-linux_x86_64.whl"
+
+[[package]]
+name = "marlin-kernels"
+version = "0.2.0"
+description = "Marlin quantization kernels"
+optional = true
+python-versions = ">=3.7"
+files = [
+    {file = "marlin_kernels-0.2.0+cu123torch2.4-cp311-cp311-linux_x86_64.whl", hash = "sha256:1e64fcc7ebadfaffa60091ee9201ae3daaf5c1be3be60c8c054143a3dcb72d5d"},
+]
+
+[package.dependencies]
+torch = "*"
+
+[package.source]
+type = "url"
+url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.2.0/marlin_kernels-0.2.0+cu123torch2.4-cp311-cp311-linux_x86_64.whl"
+
+[[package]]
+name = "marlin-kernels"
+version = "0.2.0"
+description = "Marlin quantization kernels"
+optional = true
+python-versions = ">=3.7"
+files = [
+    {file = "marlin_kernels-0.2.0+cu123torch2.4-cp312-cp312-linux_x86_64.whl", hash = "sha256:e75f3ce9b1c13a4ed43a380d88e1d34d297259452db037ec1973ec33dc2eb78e"},
+]
+
+[package.dependencies]
+torch = "*"
+
+[package.source]
+type = "url"
+url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.2.0/marlin_kernels-0.2.0+cu123torch2.4-cp312-cp312-linux_x86_64.whl"
+
+[[package]]
+name = "marlin-kernels"
+version = "0.2.0"
+description = "Marlin quantization kernels"
+optional = true
+python-versions = ">=3.7"
+files = [
+    {file = "marlin_kernels-0.2.0+cu123torch2.4-cp39-cp39-linux_x86_64.whl", hash = "sha256:2f99a27f70b391887ee6adffeeee7c3f4df7fac37393f9fb16d4cace2b3f6457"},
+]
+
+[package.dependencies]
+torch = "*"
+
+[package.source]
+type = "url"
+url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.2.0/marlin_kernels-0.2.0+cu123torch2.4-cp39-cp39-linux_x86_64.whl"
+
 [[package]]
 name = "mpmath"
 version = "1.3.0"
@@ -3507,6 +3575,7 @@ test = ["big-O", "importlib-resources", "jaraco.functools", "jaraco.itertools",
 [extras]
 accelerate = ["accelerate"]
 bnb = ["bitsandbytes"]
+marlin = ["marlin-kernels", "marlin-kernels", "marlin-kernels", "marlin-kernels"]
 outlines = ["outlines"]
 peft = ["peft"]
 quantize = ["accelerate", "datasets", "texttable"]
@@ -3515,4 +3584,4 @@ torch = ["torch"]
 [metadata]
 lock-version = "2.0"
 python-versions = ">=3.9,<3.13"
-content-hash = "c94bbdf8131750891fb3f7132066718534129d85a4c09126d8d01c2de6c72798"
+content-hash = "a89867b23017d2efa8a7aa14d4764bcbd3b4dea9bfbf06a7a68464cb184ac6a1"
@@ -40,10 +40,18 @@ py-cpuinfo = "^9.0.0"
 # Remove later, temporary workaround for outlines.
 numpy = "^1.26"
 
+marlin-kernels = [
+    { url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.2.0/marlin_kernels-0.2.0+cu123torch2.4-cp39-cp39-linux_x86_64.whl", python = "~3.9", optional = true },
+    { url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.2.0/marlin_kernels-0.2.0+cu123torch2.4-cp310-cp310-linux_x86_64.whl", python = "~3.10", optional = true },
+    { url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.2.0/marlin_kernels-0.2.0+cu123torch2.4-cp311-cp311-linux_x86_64.whl", python = "~3.11", optional = true },
+    { url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.2.0/marlin_kernels-0.2.0+cu123torch2.4-cp312-cp312-linux_x86_64.whl", python = "~3.12", optional = true },
+]
+
 [tool.poetry.extras]
 torch = ["torch"]
 accelerate = ["accelerate"]
 bnb = ["bitsandbytes"]
+marlin = ["marlin-kernels"]
 peft = ["peft"]
 quantize = ["texttable", "datasets", "accelerate"]
 outlines = ["outlines"]
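The dependency block above pins one wheel per interpreter. As a hedged Python sketch, the matching wheel URL can be derived from the running interpreter; only the URL template from the list above is used, everything else is illustrative.

import sys

tag = f"cp3{sys.version_info.minor}"
url = (
    "https://github.com/danieldk/marlin-kernels/releases/download/v0.2.0/"
    f"marlin_kernels-0.2.0+cu123torch2.4-{tag}-{tag}-linux_x86_64.whl"
)
print(url)  # on Python 3.10 this ends in -cp310-cp310-linux_x86_64.whl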
@@ -223,6 +223,7 @@ class GPTQMarlinLinear(nn.Module):
             A_flat.shape[1],
             self.is_full_k,
             self.qzeros.numel() > 0,
+            True,
         )
         C = C.reshape(A.shape[:-1] + (self.scales.shape[1],))
 