diff --git a/Dockerfile b/Dockerfile index c3a4623e..52393a76 100644 --- a/Dockerfile +++ b/Dockerfile @@ -140,13 +140,6 @@ COPY server/Makefile-eetq Makefile # Build specific version of transformers RUN TORCH_CUDA_ARCH_LIST="8.0;8.6+PTX" make build-eetq -# Build marlin kernels -FROM kernel-builder AS marlin-kernels-builder -WORKDIR /usr/src -COPY server/marlin/ . -# Build specific version of transformers -RUN TORCH_CUDA_ARCH_LIST="8.0;8.6+PTX" python setup.py build - # Build Lorax Punica kernels FROM kernel-builder AS lorax-punica-builder WORKDIR /usr/src @@ -231,9 +224,6 @@ COPY --from=exllamav2-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-31 COPY --from=awq-kernels-builder /usr/src/llm-awq/awq/kernels/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages # Copy build artifacts from eetq kernels builder COPY --from=eetq-kernels-builder /usr/src/eetq/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages -# Copy build artifacts from marlin kernels builder -COPY --from=marlin-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages -COPY --from=lorax-punica-builder /usr/src/lorax-punica/server/punica_kernels/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages # Copy build artifacts from fbgemm builder COPY --from=fbgemm-builder /usr/src/fbgemm/fbgemm_gpu/_skbuild/linux-x86_64-3.10/cmake-install /opt/conda/lib/python3.10/site-packages # Copy build artifacts from vllm builder @@ -252,7 +242,7 @@ COPY server/Makefile server/Makefile RUN cd server && \ make gen-server && \ pip install -r requirements_cuda.txt && \ - pip install ".[bnb, accelerate, quantize, peft, outlines]" --no-cache-dir && \ + pip install ".[bnb, accelerate, marlin, quantize, peft, outlines]" --no-cache-dir && \ pip install nvidia-nccl-cu12==2.22.3 ENV LD_PRELOAD=/opt/conda/lib/python3.10/site-packages/nvidia/nccl/lib/libnccl.so.2 diff --git a/server/marlin/COPYRIGHT b/server/marlin/COPYRIGHT deleted file mode 100644 index 69f3b8e6..00000000 --- a/server/marlin/COPYRIGHT +++ /dev/null @@ -1,20 +0,0 @@ -These kernels were vendored from VLLM. The Marlin kernels were developed -by Elias Frantar and extended by Neural Magic. - ---- - -Copyright (C) Marlin.2024 Elias Frantar -Modified by Neural Magic -Copyright 2024 The vLLM team. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. diff --git a/server/marlin/marlin_kernels/__init__.pyi b/server/marlin/marlin_kernels/__init__.pyi deleted file mode 100644 index 53464719..00000000 --- a/server/marlin/marlin_kernels/__init__.pyi +++ /dev/null @@ -1,76 +0,0 @@ -import torch - -def gptq_marlin_gemm( - a: torch.Tensor, - b_q_weight: torch.Tensor, - b_scales: torch.Tensor, - g_idx: torch.Tensor, - perm: torch.Tensor, - workspace: torch.Tensor, - num_bits: int, - size_m: int, - size_n: int, - size_k: int, - is_k_full: bool, -) -> torch.Tensor: - """ - Matrix multiplication using Marlin kernels. This is an extension of - `marlin_gemm` that supports converted GPTQ kernels. - """ - ... - -def gptq_marlin_24_gemm( - a: torch.Tensor, - b_q_weight: torch.Tensor, - b_meta: torch.Tensor, - b_scales: torch.Tensor, - workspace: torch.Tensor, - num_bits: int, - size_m: int, - size_n: int, - size_k: int, -) -> torch.Tensor: - """ - Matrix multiplication using Marlin kernels. This is an extension of - `marlin_gemm` that supports 2:4 sparsity. - """ - ... - -def gptq_marlin_repack( - b_q_weight: torch.Tensor, - perm: torch.Tensor, - size_k: int, - size_n: int, - num_bits: int, -) -> torch.Tensor: - """Repack GPTQ parameters for Marlin kernels.""" - ... - -def marlin_gemm( - a: torch.Tensor, - b_q_weight: torch.Tensor, - b_scales: torch.Tensor, - workspace: torch.Tensor, - size_m: int, - size_n: int, - size_k: int, -) -> torch.Tensor: - """ - Matrix multiplication using Marlin kernels. - """ - ... - -# fp8 marlin -def fp8_marlin_gemm( - a: torch.Tensor, - b_q_weight: torch.Tensor, - b_scales: torch.Tensor, - workspace: torch.Tensor, - num_bits: int, - size_m: int, - size_n: int, - size_k: int, -) -> torch.Tensor: - return torch.ops._C.fp8_marlin_gemm( - a, b_q_weight, b_scales, workspace, num_bits, size_m, size_n, size_k - ) diff --git a/server/marlin/marlin_kernels/awq_marlin_repack.cu b/server/marlin/marlin_kernels/awq_marlin_repack.cu deleted file mode 100644 index c58216d8..00000000 --- a/server/marlin/marlin_kernels/awq_marlin_repack.cu +++ /dev/null @@ -1,269 +0,0 @@ -#include "marlin.cuh" - -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 - -namespace marlin { - -template -__global__ void awq_marlin_repack_kernel( - uint32_t const* __restrict__ b_q_weight_ptr, uint32_t* __restrict__ out_ptr, - int size_k, int size_n) {} - -} // namespace marlin - -torch::Tensor awq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm, - int64_t size_k, int64_t size_n, - int64_t num_bits) { - TORCH_CHECK_NOT_IMPLEMENTED( - false, "marlin_repack_from_gptq(..) requires CUDA_ARCH >= 8.0"); - return torch::empty({1, 1}); -} - -#else - -namespace marlin { - -template -__global__ void awq_marlin_repack_kernel( - uint32_t const* __restrict__ b_q_weight_ptr, uint32_t* __restrict__ out_ptr, - int size_k, int size_n) { - constexpr int pack_factor = 32 / num_bits; - - int k_tiles = size_k / tile_k_size; - int n_tiles = size_n / tile_n_size; - int block_k_tiles = div_ceil(k_tiles, gridDim.x); - - int start_k_tile = blockIdx.x * block_k_tiles; - if (start_k_tile >= k_tiles) { - return; - } - - int finish_k_tile = min(start_k_tile + block_k_tiles, k_tiles); - - // Wait until the next thread tile has been loaded to shared memory. - auto wait_for_stage = [&]() { - // We only have `stages - 2` active fetches since we are double buffering - // and can only issue the next fetch when it is guaranteed that the previous - // shared memory load is fully complete (as it may otherwise be - // overwritten). - cp_async_wait(); - __syncthreads(); - }; - - extern __shared__ int4 sh[]; - - constexpr int tile_n_ints = tile_n_size / pack_factor; - - constexpr int stage_n_threads = tile_n_ints / 4; - constexpr int stage_k_threads = tile_k_size; - constexpr int stage_size = stage_k_threads * stage_n_threads; - - auto fetch_to_shared = [&](int pipe, int k_tile_id, int n_tile_id) { - if (n_tile_id >= n_tiles) { - cp_async_fence(); - return; - } - - int first_n = n_tile_id * tile_n_size; - int first_n_packed = first_n / pack_factor; - - int4* sh_ptr = sh + stage_size * pipe; - - if (threadIdx.x < stage_size) { - int k_id = threadIdx.x / stage_n_threads; - int n_id = threadIdx.x % stage_n_threads; - - int first_k = k_tile_id * tile_k_size; - - cp_async4(&sh_ptr[k_id * stage_n_threads + n_id], - reinterpret_cast( - &(b_q_weight_ptr[(first_k + k_id) * (size_n / pack_factor) + - first_n_packed + (n_id * 4)]))); - } - - cp_async_fence(); - }; - - auto repack_tile = [&](int pipe, int k_tile_id, int n_tile_id) { - if (n_tile_id >= n_tiles) { - return; - } - - int warp_id = threadIdx.x / 32; - int th_id = threadIdx.x % 32; - - if (warp_id >= 4) { - return; - } - - int tc_col = th_id / 4; - int tc_row = (th_id % 4) * 2; - - constexpr int tc_offsets[4] = {0, 1, 8, 9}; - - int cur_n = warp_id * 16 + tc_col; - int cur_n_packed = cur_n / pack_factor; - int cur_n_pos = cur_n % pack_factor; - - constexpr int sh_stride = tile_n_ints; - constexpr uint32_t mask = (1 << num_bits) - 1; - - int4* sh_stage_ptr = sh + stage_size * pipe; - uint32_t* sh_stage_int_ptr = reinterpret_cast(sh_stage_ptr); - - // Undo interleaving - int cur_n_pos_unpacked; - if constexpr (num_bits == 4) { - constexpr int undo_pack[8] = {0, 4, 1, 5, 2, 6, 3, 7}; - cur_n_pos_unpacked = undo_pack[cur_n_pos]; - } else { - constexpr int undo_pack[4] = {0, 2, 1, 3}; - cur_n_pos_unpacked = undo_pack[cur_n_pos]; - } - - uint32_t vals[8]; - #pragma unroll - for (int i = 0; i < 4; i++) { - int cur_elem = tc_row + tc_offsets[i]; - - int packed_src_0 = sh_stage_int_ptr[cur_n_packed + sh_stride * cur_elem]; - int packed_src_1 = sh_stage_int_ptr[cur_n_packed + (8 / pack_factor) + - sh_stride * cur_elem]; - - vals[i] = (packed_src_0 >> (cur_n_pos_unpacked * num_bits)) & mask; - vals[4 + i] = (packed_src_1 >> (cur_n_pos_unpacked * num_bits)) & mask; - } - - constexpr int tile_size = tile_k_size * tile_n_size / pack_factor; - int out_offset = (k_tile_id * n_tiles + n_tile_id) * tile_size; - - // Result of: - // https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h - if constexpr (num_bits == 4) { - constexpr int pack_idx[8] = {0, 2, 4, 6, 1, 3, 5, 7}; - - uint32_t res = 0; - #pragma unroll - for (int i = 0; i < 8; i++) { - res |= vals[pack_idx[i]] << (i * 4); - } - - out_ptr[out_offset + th_id * 4 + warp_id] = res; - - } else { - constexpr int pack_idx[4] = {0, 2, 1, 3}; - - uint32_t res1 = 0; - uint32_t res2 = 0; - #pragma unroll - for (int i = 0; i < 4; i++) { - res1 |= vals[pack_idx[i]] << (i * 8); - res2 |= vals[4 + pack_idx[i]] << (i * 8); - } - - out_ptr[out_offset + th_id * 8 + (warp_id * 2) + 0] = res1; - out_ptr[out_offset + th_id * 8 + (warp_id * 2) + 1] = res2; - } - }; - - auto start_pipes = [&](int k_tile_id, int n_tile_id) { - #pragma unroll - for (int pipe = 0; pipe < repack_stages - 1; pipe++) { - fetch_to_shared(pipe, k_tile_id, n_tile_id + pipe); - } - - wait_for_stage(); - }; - #pragma unroll - for (int k_tile_id = start_k_tile; k_tile_id < finish_k_tile; k_tile_id++) { - int n_tile_id = 0; - - start_pipes(k_tile_id, n_tile_id); - - while (n_tile_id < n_tiles) { - #pragma unroll - for (int pipe = 0; pipe < repack_stages; pipe++) { - fetch_to_shared((pipe + repack_stages - 1) % repack_stages, k_tile_id, - n_tile_id + pipe + repack_stages - 1); - repack_tile(pipe, k_tile_id, n_tile_id + pipe); - wait_for_stage(); - } - n_tile_id += repack_stages; - } - } -} - -} // namespace marlin - - #define CALL_IF(NUM_BITS) \ - else if (num_bits == NUM_BITS) { \ - cudaFuncSetAttribute( \ - marlin::awq_marlin_repack_kernel, \ - cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \ - marlin::awq_marlin_repack_kernel \ - <<>>( \ - b_q_weight_ptr, out_ptr, size_k, size_n); \ - } - -torch::Tensor awq_marlin_repack(torch::Tensor& b_q_weight, int64_t size_k, - int64_t size_n, int64_t num_bits) { - // Verify compatibility with marlin tile of 16x64 - TORCH_CHECK(size_k % marlin::tile_k_size == 0, "size_k = ", size_k, - " is not divisible by tile_k_size = ", marlin::tile_k_size); - TORCH_CHECK(size_n % marlin::tile_n_size == 0, "size_n = ", size_n, - " is not divisible by tile_n_size = ", marlin::tile_n_size); - - TORCH_CHECK(num_bits == 4 || num_bits == 8, - "num_bits must be 4 or 8. Got = ", num_bits); - int const pack_factor = 32 / num_bits; - - // Verify B - TORCH_CHECK(b_q_weight.size(0) == size_k, - "b_q_weight.size(0) = ", b_q_weight.size(0), - " is not size_k = ", size_k); - TORCH_CHECK((size_n / pack_factor) == b_q_weight.size(1), - "Shape mismatch: b_q_weight.size(1) = ", b_q_weight.size(1), - ", size_n = ", size_n, ", pack_factor = ", pack_factor); - - // Verify device and strides - TORCH_CHECK(b_q_weight.device().is_cuda(), "b_q_weight is not on GPU"); - TORCH_CHECK(b_q_weight.is_contiguous(), "b_q_weight is not contiguous"); - TORCH_CHECK(b_q_weight.dtype() == at::kInt, "b_q_weight type is not kInt"); - - // Alloc buffers - const at::cuda::OptionalCUDAGuard device_guard(device_of(b_q_weight)); - auto options = torch::TensorOptions() - .dtype(b_q_weight.dtype()) - .device(b_q_weight.device()); - torch::Tensor out = torch::empty( - {size_k / marlin::tile_size, size_n * marlin::tile_size / pack_factor}, - options); - - // Get ptrs - uint32_t const* b_q_weight_ptr = - reinterpret_cast(b_q_weight.data_ptr()); - uint32_t* out_ptr = reinterpret_cast(out.data_ptr()); - - // Get dev info - int dev = b_q_weight.get_device(); - cudaStream_t stream = at::cuda::getCurrentCUDAStream(dev); - int blocks; - cudaDeviceGetAttribute(&blocks, cudaDevAttrMultiProcessorCount, dev); - - int max_shared_mem = 0; - cudaDeviceGetAttribute(&max_shared_mem, - cudaDevAttrMaxSharedMemoryPerBlockOptin, dev); - TORCH_CHECK(max_shared_mem > 0); - - if (false) { - } - CALL_IF(4) - CALL_IF(8) - else { - TORCH_CHECK(false, "Unsupported repack config: num_bits = ", num_bits); - } - - return out; -} - -#endif diff --git a/server/marlin/marlin_kernels/ext.cpp b/server/marlin/marlin_kernels/ext.cpp deleted file mode 100644 index 8dcd2ff9..00000000 --- a/server/marlin/marlin_kernels/ext.cpp +++ /dev/null @@ -1,16 +0,0 @@ -#include - -#include "ext.hh" - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("awq_marlin_repack", &awq_marlin_repack, - "Repack AWQ parameters for Marlin"); - m.def("gptq_marlin_gemm", &gptq_marlin_gemm, - "Marlin gemm with GPTQ compatibility"); - m.def("gptq_marlin_24_gemm", &gptq_marlin_24_gemm, "Marlin sparse 2:4 gemm"); - m.def("gptq_marlin_repack", &gptq_marlin_repack, - "Repack GPTQ parameters for Marlin"); - m.def("marlin_gemm", &marlin_gemm, "Marlin gemm"); - // fp8_marlin Optimized Quantized GEMM for FP8 weight-only. - m.def("fp8_marlin_gemm", &fp8_marlin_gemm); -} diff --git a/server/marlin/marlin_kernels/ext.hh b/server/marlin/marlin_kernels/ext.hh deleted file mode 100644 index ec76b5f4..00000000 --- a/server/marlin/marlin_kernels/ext.hh +++ /dev/null @@ -1,39 +0,0 @@ -#pragma once - -#include - -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 -// No support for async -#else - -torch::Tensor awq_marlin_repack(torch::Tensor &b_q_weight, int64_t size_k, - int64_t size_n, int64_t num_bits); - -torch::Tensor gptq_marlin_gemm(torch::Tensor &a, torch::Tensor &b_q_weight, - torch::Tensor &b_scales, torch::Tensor &b_zeros, - torch::Tensor &g_idx, torch::Tensor &perm, - torch::Tensor &workspace, int64_t num_bits, - int64_t size_m, int64_t size_n, int64_t size_k, - bool is_k_full, bool has_zp); - -torch::Tensor gptq_marlin_24_gemm(torch::Tensor &a, torch::Tensor &b_q_weight, - torch::Tensor &b_meta, - torch::Tensor &b_scales, - torch::Tensor &workspace, int64_t num_bits, - int64_t size_m, int64_t size_n, - int64_t size_k); - -torch::Tensor gptq_marlin_repack(torch::Tensor &b_q_weight, torch::Tensor &perm, - int64_t size_k, int64_t size_n, - int64_t num_bits); - -torch::Tensor marlin_gemm(torch::Tensor &a, torch::Tensor &b_q_weight, - torch::Tensor &b_scales, torch::Tensor &workspace, - int64_t size_m, int64_t size_n, int64_t size_k); - -torch::Tensor fp8_marlin_gemm(torch::Tensor &a, torch::Tensor &b_q_weight, - torch::Tensor &b_scales, torch::Tensor &workspace, - int64_t num_bits, int64_t size_m, int64_t size_n, - int64_t size_k); - -#endif diff --git a/server/marlin/marlin_kernels/fp8_marlin.cu b/server/marlin/marlin_kernels/fp8_marlin.cu deleted file mode 100644 index f6458ca6..00000000 --- a/server/marlin/marlin_kernels/fp8_marlin.cu +++ /dev/null @@ -1,1305 +0,0 @@ -/* - * Modified by Neural Magic - * Copyright (C) Marlin.2024 Elias Frantar - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -/* - * Adapted from https://github.com/IST-DASLab/marlin - */ - -#include "marlin.cuh" -#include "marlin_dtypes.cuh" - -using namespace marlin; - -#define STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t) \ - static_assert(std::is_same::value || \ - std::is_same::value, \ - "only float16 and bfloat16 is supported"); - -template -inline std::string str(T x) { - return std::to_string(x); -} - -namespace fp8_marlin { - -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 - -template shared - // fetch pipeline - const int group_blocks = -1 // number of consecutive 16x16 blocks - // with a separate quantization scale - > -__global__ void Marlin( - const int4* __restrict__ A, // fp16 input matrix of shape mxk - const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn - int4* __restrict__ C, // fp16 output buffer of shape mxn - const int4* __restrict__ scales_ptr, // fp16 quantization scales of shape - // (k/groupsize)xn - int num_groups, // number of scale groups per output channel - int prob_m, // batch dimension m - int prob_n, // output dimension n - int prob_k, // reduction dimension k - int* locks // extra global storage for barrier synchronization -) {} - -} // namespace fp8_marlin - -torch::Tensor fp8_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, - torch::Tensor& b_scales, torch::Tensor& workspace, - int64_t num_bits, int64_t size_m, int64_t size_n, - int64_t size_k) { - TORCH_CHECK_NOT_IMPLEMENTED(false, - "marlin_gemm(..) requires CUDA_ARCH >= 8.0"); - return torch::empty({1, 1}); -} - -#else - -// m16n8k16 tensor core mma instruction with fp16 inputs and fp32 -// output/accumulation. -template -__device__ inline void mma(const typename ScalarType::FragA& a_frag, - const typename ScalarType::FragB& frag_b, - typename ScalarType::FragC& frag_c) { - const uint32_t* a = reinterpret_cast(&a_frag); - const uint32_t* b = reinterpret_cast(&frag_b); - float* c = reinterpret_cast(&frag_c); - if constexpr (std::is_same::value) { - asm volatile( - "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 " - "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" - : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) - : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]), - "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3])); - } else if constexpr (std::is_same::value) { - asm volatile( - "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 " - "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" - : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) - : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]), - "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3])); - } else { - STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t); - } -} - -// Instruction for loading a full 16x16 matrix fragment of operand A from shared -// memory, directly in tensor core layout. -template -__device__ inline void ldsm4(typename ScalarType::FragA& frag_a, - const void* smem_ptr) { - uint32_t* a = reinterpret_cast(&frag_a); - uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); - asm volatile("ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0,%1,%2,%3}, [%4];\n" - : "=r"(a[0]), "=r"(a[1]), "=r"(a[2]), "=r"(a[3]) - : "r"(smem)); -} - -// Fast FP8ToFp16/FP8ToBf16: Efficiently dequantize 8bit fp8_e4m3 values to fp16 -// bf16 Reference: -// - FP16: -// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L53-L85 -// - BF16: -// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L125-L175 -template -__device__ inline typename ScalarType::FragB dequant_8bit(int q) { - STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t); -} - -template <> -__device__ inline typename ScalarType::FragB dequant_8bit(int q) { - // Constants for FP8 (E4M3) and FP16 formats - constexpr int FP8_EXPONENT = 4, FP8_MANTISSA = 3, FP16_EXPONENT = 5; - constexpr int RIGHT_SHIFT = FP16_EXPONENT - FP8_EXPONENT; - - // Calculate MASK for extracting mantissa and exponent - constexpr int MASK1 = 0x80000000; - constexpr int MASK2 = MASK1 >> (FP8_EXPONENT + FP8_MANTISSA); - constexpr int MASK3 = MASK2 & 0x7fffffff; - constexpr int MASK = MASK3 | (MASK3 >> 16); - // Final MASK value: 0x7F007F00 - - // Extract and shift FP8 values to FP16 format - int Out1 = (q & 0x80008000) | ((q & MASK) >> RIGHT_SHIFT); - int Out2 = ((q << 8) & 0x80008000) | (((q << 8) & MASK) >> RIGHT_SHIFT); - - // Construct and apply exponent bias - constexpr int BIAS_OFFSET = - (1 << (FP16_EXPONENT - 1)) - (1 << (FP8_EXPONENT - 1)); - const half2 bias_reg = __float2half2_rn(float(1 << BIAS_OFFSET)); - - // Convert to half2 and apply bias - typename ScalarType::FragB frag_b; - // Note: reverse indexing is intentional because weights are permuted - frag_b[1] = __hmul2(*reinterpret_cast(&Out1), bias_reg); - frag_b[0] = __hmul2(*reinterpret_cast(&Out2), bias_reg); - return frag_b; -} - -template <> -__device__ inline typename ScalarType::FragB -dequant_8bit(int q) { - // Constants for FP8 (E4M3) and BF16 formats - constexpr int FP8_EXPONENT = 4, FP8_MANTISSA = 3, BF16_EXPONENT = 8; - constexpr int RIGHT_SHIFT = BF16_EXPONENT - FP8_EXPONENT; - - // Calculate MASK for extracting mantissa and exponent - constexpr int MASK1 = 0x80000000; - constexpr int MASK2 = MASK1 >> (FP8_EXPONENT + FP8_MANTISSA); - constexpr int MASK3 = MASK2 & 0x7fffffff; - constexpr int MASK = MASK3 | (MASK3 >> 16); - // Final MASK value: 0x7F007F00 - - // Extract and shift FP8 values to BF16 format - int Out1 = (q & 0x80008000) | ((q & MASK) >> RIGHT_SHIFT); - int Out2 = ((q << 8) & 0x80008000) | (((q << 8) & MASK) >> RIGHT_SHIFT); - - // Construct and apply exponent bias - constexpr int BIAS_OFFSET = - (1 << (BF16_EXPONENT - 1)) - (1 << (FP8_EXPONENT - 1)); - // Add 127 (float exponent bias) to BIAS_OFFSET and shift to float exponent - // position - constexpr uint32_t BIAS = (BIAS_OFFSET + 127) << 23; - const nv_bfloat162 bias_reg = - __float2bfloat162_rn(*reinterpret_cast(&BIAS)); - - // Convert to bfloat162 and apply bias - typename ScalarType::FragB frag_b; - // Note: reverse indexing is intentional because weights are permuted - frag_b[1] = __hmul2(*reinterpret_cast(&Out1), bias_reg); - frag_b[0] = __hmul2(*reinterpret_cast(&Out2), bias_reg); - return frag_b; -} - -// Multiply dequantized values by the corresponding quantization scale; used -// only for grouped quantization. -template -__device__ inline void scale(typename ScalarType::FragB& frag_b, - typename ScalarType::FragS& frag_s, - int i) { - using scalar_t2 = typename ScalarType::scalar_t2; - scalar_t2 s = - ScalarType::num2num2(reinterpret_cast(&frag_s)[i]); - frag_b[0] = __hmul2(frag_b[0], s); - frag_b[1] = __hmul2(frag_b[1], s); -} - -// Given 2 floats multiply by 2 scales (halves) -template -__device__ inline void scale_float(float* c, - typename ScalarType::FragS& s) { - scalar_t* s_ptr = reinterpret_cast(&s); - c[0] = __fmul_rn(c[0], ScalarType::num2float(s_ptr[0])); - c[1] = __fmul_rn(c[1], ScalarType::num2float(s_ptr[1])); -} - -// Wait until barrier reaches `count`, then lock for current threadblock. -__device__ inline void barrier_acquire(int* lock, int count) { - if (threadIdx.x == 0) { - int state = -1; - do - // Guarantee that subsequent writes by this threadblock will be visible - // globally. - asm volatile("ld.global.acquire.gpu.b32 %0, [%1];\n" - : "=r"(state) - : "l"(lock)); - while (state != count); - } - __syncthreads(); -} - -// Release barrier and increment visitation count. -__device__ inline void barrier_release(int* lock, bool reset = false) { - __syncthreads(); - if (threadIdx.x == 0) { - if (reset) { - lock[0] = 0; - return; - } - int val = 1; - // Make sure that all writes since acquiring this barrier are visible - // globally, while releasing the barrier. - asm volatile("fence.acq_rel.gpu;\n"); - asm volatile("red.relaxed.gpu.global.add.s32 [%0], %1;\n" - : - : "l"(lock), "r"(val)); - } -} - -template shared - // fetch pipeline - const int group_blocks = -1 // number of consecutive 16x16 blocks - // with a separate quantization scale - > -__global__ void Marlin( - const int4* __restrict__ A, // fp16 input matrix of shape mxk - const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn - int4* __restrict__ C, // fp16 output buffer of shape mxn - const int4* __restrict__ scales_ptr, // fp16 quantization scales of shape - // (k/groupsize)xn - int num_groups, // number of scale groups per output channel - int prob_m, // batch dimension m - int prob_n, // output dimension n - int prob_k, // reduction dimension k - int* locks // extra global storage for barrier synchronization -) { - // Each threadblock processes one "stripe" of the B matrix with (roughly) the - // same size, which might involve multiple column "slices" (of width 16 * - // `thread_n_blocks`). Stripes are defined as shown in the 3x3 matrix 5 SM - // example: - // 0 1 3 - // 0 2 3 - // 1 2 4 - // While this kind of partitioning makes things somewhat more complicated, it - // ensures good utilization of all SMs for many kinds of shape and GPU - // configurations, while requiring as few slow global cross-threadblock - // reductions as possible. - using Dtype = ScalarType; - using scalar_t2 = typename ScalarType::scalar_t2; - using FragA = typename ScalarType::FragA; - using FragB = typename ScalarType::FragB; - using FragC = typename ScalarType::FragC; - using FragS = typename ScalarType::FragS; - - constexpr int pack_factor = 32 / num_bits; - - // For larger GEMMs we run multiple batchsize 64 versions in parallel for a - // better partitioning with less reductions - int parallel = 1; - if (prob_m > 16 * thread_m_blocks) { - parallel = prob_m / (16 * thread_m_blocks); - prob_m = 16 * thread_m_blocks; - } - - int k_tiles = prob_k / 16 / thread_k_blocks; - int n_tiles = prob_n / 16 / thread_n_blocks; - int iters = div_ceil(k_tiles * n_tiles * parallel, gridDim.x); - - int slice_row = (iters * blockIdx.x) % k_tiles; - int slice_col_par = (iters * blockIdx.x) / k_tiles; - int slice_col = slice_col_par; - int slice_iters; // number of threadblock tiles in the current slice - int slice_count = - 0; // total number of active threadblocks in the current slice - int slice_idx; // index of threadblock in current slice; numbered bottom to - // top - - // We can easily implement parallel problem execution by just remapping - // indices and advancing global pointers - if (slice_col_par >= n_tiles) { - A += (slice_col_par / n_tiles) * 16 * thread_m_blocks * prob_k / 8; - C += (slice_col_par / n_tiles) * 16 * thread_m_blocks * prob_n / 8; - locks += (slice_col_par / n_tiles) * n_tiles; - slice_col = slice_col_par % n_tiles; - } - - // Compute all information about the current slice which is required for - // synchronization. - auto init_slice = [&]() { - slice_iters = - iters * (blockIdx.x + 1) - (k_tiles * slice_col_par + slice_row); - if (slice_iters < 0 || slice_col_par >= n_tiles * parallel) slice_iters = 0; - if (slice_iters == 0) return; - if (slice_row + slice_iters > k_tiles) slice_iters = k_tiles - slice_row; - slice_count = 1; - slice_idx = 0; - int col_first = iters * div_ceil(k_tiles * slice_col_par, iters); - if (col_first <= k_tiles * (slice_col_par + 1)) { - int col_off = col_first - k_tiles * slice_col_par; - slice_count = div_ceil(k_tiles - col_off, iters); - if (col_off > 0) slice_count++; - int delta_first = iters * blockIdx.x - col_first; - if (delta_first < 0 || (col_off == 0 && delta_first == 0)) - slice_idx = slice_count - 1; - else { - slice_idx = slice_count - 1 - delta_first / iters; - if (col_off > 0) slice_idx--; - } - } - if (slice_col == n_tiles) { - A += 16 * thread_m_blocks * prob_k / 8; - C += 16 * thread_m_blocks * prob_n / 8; - locks += n_tiles; - slice_col = 0; - } - }; - init_slice(); - - // A sizes/strides - - // stride of the A matrix in global memory - int a_gl_stride = prob_k / 8; - // stride of an A matrix tile in shared memory - constexpr int a_sh_stride = 16 * thread_k_blocks / 8; - // delta between subsequent A tiles in global memory - constexpr int a_gl_rd_delta_o = 16 * thread_k_blocks / 8; - // between subsequent accesses within a tile - int a_gl_rd_delta_i = a_gl_stride * (threads / a_gl_rd_delta_o); - // between shared memory writes - constexpr int a_sh_wr_delta = a_sh_stride * (threads / a_gl_rd_delta_o); - // between shared memory tile reads - constexpr int a_sh_rd_delta_o = 2 * ((threads / 32) / (thread_n_blocks / 4)); - // within a shared memory tile - constexpr int a_sh_rd_delta_i = a_sh_stride * 16; - // overall size of a tile - constexpr int a_sh_stage = a_sh_stride * (16 * thread_m_blocks); - // number of shared write iterations for a tile - constexpr int a_sh_wr_iters = div_ceil(a_sh_stage, a_sh_wr_delta); - - // B sizes/strides - int b_gl_stride = 16 * prob_n / (pack_factor * 4); - constexpr int b_sh_stride = ((thread_n_blocks * 16) * 16 / pack_factor) / 4; - constexpr int b_thread_vecs = num_bits == 4 ? 1 : 2; - constexpr int b_sh_stride_threads = b_sh_stride / b_thread_vecs; - - int b_gl_rd_delta_o = b_gl_stride * thread_k_blocks; - int b_gl_rd_delta_i = b_gl_stride * (threads / b_sh_stride_threads); - constexpr int b_sh_wr_delta = threads * b_thread_vecs; - constexpr int b_sh_rd_delta = threads * b_thread_vecs; - constexpr int b_sh_stage = b_sh_stride * thread_k_blocks; - constexpr int b_sh_wr_iters = b_sh_stage / b_sh_wr_delta; - - // Scale sizes/strides without act_order - int s_gl_stride = prob_n / 8; - constexpr int s_sh_stride = 16 * thread_n_blocks / 8; - - // Scale size/strides with act_order - constexpr int tb_k = 16 * thread_k_blocks; - constexpr int g_idx_stage = 0; - // constexpr int act_s_row_stride = 1; - // int act_s_col_stride = act_s_row_stride * num_groups; - int act_s_col_stride = 1; - int act_s_col_warp_stride = act_s_col_stride * 8; - int tb_n_warps = thread_n_blocks / 4; - int act_s_col_tb_stride = act_s_col_warp_stride * tb_n_warps; - - // Global A read index of current thread. - int a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) + - (threadIdx.x % a_gl_rd_delta_o); - a_gl_rd += a_gl_rd_delta_o * slice_row; - // Shared write index of current thread. - int a_sh_wr = a_sh_stride * (threadIdx.x / a_gl_rd_delta_o) + - (threadIdx.x % a_gl_rd_delta_o); - // Shared read index. - int a_sh_rd = - a_sh_stride * ((threadIdx.x % 32) % 16) + (threadIdx.x % 32) / 16; - a_sh_rd += 2 * ((threadIdx.x / 32) / (thread_n_blocks / 4)); - - int b_gl_rd = b_gl_stride * (threadIdx.x / b_sh_stride_threads) + - (threadIdx.x % b_sh_stride_threads) * b_thread_vecs; - b_gl_rd += b_sh_stride * slice_col; - b_gl_rd += b_gl_rd_delta_o * slice_row; - int b_sh_wr = threadIdx.x * b_thread_vecs; - int b_sh_rd = threadIdx.x * b_thread_vecs; - - // For act_order - int slice_k_start = tb_k * slice_row; - int slice_k_start_shared_fetch = slice_k_start; - int slice_n_offset = act_s_col_tb_stride * slice_col; - - // No act_order - int s_gl_rd = s_sh_stride * slice_col + threadIdx.x; - int s_sh_wr = threadIdx.x; - bool s_sh_wr_pred = threadIdx.x < s_sh_stride; - - // We scale a `half2` tile in row-major layout for column-wise quantization. - int s_sh_rd = - 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + (threadIdx.x % 32) % 4; - - // Precompute which thread should not read memory in which iterations; this is - // needed if there are more threads than required for a certain tilesize or - // when the batchsize is not a multiple of 16. - bool a_sh_wr_pred[a_sh_wr_iters]; - #pragma unroll - for (int i = 0; i < a_sh_wr_iters; i++) - a_sh_wr_pred[i] = a_sh_wr_delta * i + a_sh_wr < a_sh_stride * prob_m; - - // To ensure that writing and reading A tiles to/from shared memory, the - // latter in fragment format, is fully bank conflict free, we need to use a - // rather fancy XOR-based layout. The key here is that neither reads nor - // writes of the 16-byte `int4` blocks of 8 consecutive threads involve the - // same shared memory banks. Further, it seems (based on NSight-Compute) that - // each warp must also write a consecutive memory segment? - auto transform_a = [&](int i) { - int row = i / a_gl_rd_delta_o; - return a_gl_rd_delta_o * row + (i % a_gl_rd_delta_o) ^ row; - }; - // Since the computation of this remapping is non-trivial and, due to our main - // loop unrolls, all shared memory accesses are static, we simply precompute - // both transformed reads and writes. - int a_sh_wr_trans[a_sh_wr_iters]; - #pragma unroll - for (int i = 0; i < a_sh_wr_iters; i++) - a_sh_wr_trans[i] = transform_a(a_sh_wr_delta * i + a_sh_wr); - int a_sh_rd_trans[b_sh_wr_iters][thread_m_blocks]; - #pragma unroll - for (int i = 0; i < b_sh_wr_iters; i++) { - #pragma unroll - for (int j = 0; j < thread_m_blocks; j++) - a_sh_rd_trans[i][j] = - transform_a(a_sh_rd_delta_o * i + a_sh_rd_delta_i * j + a_sh_rd); - } - - // Since B-accesses have non-constant stride they have to be computed at - // runtime; we break dependencies between subsequent accesses with a tile by - // maintining multiple pointers (we have enough registers), a tiny - // optimization. - const int4* B_ptr[b_sh_wr_iters]; - #pragma unroll - for (int i = 0; i < b_sh_wr_iters; i++) - B_ptr[i] = B + b_gl_rd_delta_i * i + b_gl_rd; - - extern __shared__ int4 sh[]; - // Shared memory storage for global fetch pipelines. - int4* sh_a = sh; - int4* sh_b = sh_a + (stages * a_sh_stage); - int4* sh_g_idx = sh_b + (stages * b_sh_stage); - int4* sh_s = sh_g_idx + (stages * g_idx_stage); - - // Register storage for double buffer of shared memory reads. - FragA frag_a[2][thread_m_blocks]; - I4 frag_b_quant[2][b_thread_vecs]; - FragC frag_c[thread_m_blocks][4][2]; - FragS frag_s[2][4]; - - // Zero accumulators. - auto zero_accums = [&]() { - #pragma unroll - for (int i = 0; i < thread_m_blocks * 4 * 2 * 4; i++) - reinterpret_cast(frag_c)[i] = 0; - }; - - int sh_first_group_id = -1; - int sh_num_groups = -1; - constexpr int sh_max_num_groups = 32; - - auto fetch_scales_to_shared = [&](bool is_async, int first_group_id, - int last_group_id) { - sh_first_group_id = first_group_id; - sh_num_groups = last_group_id - first_group_id + 1; - - if (sh_num_groups < sh_max_num_groups) { - sh_num_groups = sh_max_num_groups; - } - - if (sh_first_group_id + sh_num_groups > num_groups) { - sh_num_groups = num_groups - sh_first_group_id; - } - - int row_offset = first_group_id * s_gl_stride; - - if (is_async) { - for (int i = 0; i < sh_num_groups; i++) { - if (threadIdx.x < s_sh_stride) { - cp_async4_pred(&sh_s[(i * s_sh_stride) + threadIdx.x], - &scales_ptr[row_offset + (i * s_gl_stride) + - slice_n_offset + threadIdx.x]); - } - } - } else { - for (int i = 0; i < sh_num_groups; i++) { - if (threadIdx.x < s_sh_stride) { - sh_s[(i * s_sh_stride) + threadIdx.x] = - scales_ptr[row_offset + (i * s_gl_stride) + slice_n_offset + - threadIdx.x]; - } - } - } - }; - // Asynchronously fetch the next A, B and s tile from global to the next - // shared memory pipeline location. - auto fetch_to_shared = [&](int pipe, int a_off, bool pred = true) { - if (pred) { - int4* sh_a_stage = sh_a + a_sh_stage * pipe; - #pragma unroll - for (int i = 0; i < a_sh_wr_iters; i++) { - cp_async4_pred( - &sh_a_stage[a_sh_wr_trans[i]], - &A[a_gl_rd_delta_i * i + a_gl_rd + a_gl_rd_delta_o * a_off], - a_sh_wr_pred[i]); - } - int4* sh_b_stage = sh_b + b_sh_stage * pipe; - #pragma unroll - for (int i = 0; i < b_sh_wr_iters; i++) { - #pragma unroll - for (int j = 0; j < b_thread_vecs; j++) { - cp_async4(&sh_b_stage[b_sh_wr_delta * i + b_sh_wr + j], B_ptr[i] + j); - } - - B_ptr[i] += b_gl_rd_delta_o; - } - } - // Insert a fence even when we are winding down the pipeline to ensure that - // waiting is also correct at this point. - cp_async_fence(); - }; - - // Wait until the next thread tile has been loaded to shared memory. - auto wait_for_stage = [&]() { - // We only have `stages - 2` active fetches since we are double buffering - // and can only issue the next fetch when it is guaranteed that the previous - // shared memory load is fully complete (as it may otherwise be - // overwritten). - cp_async_wait(); - __syncthreads(); - }; - - // Load the next sub-tile from the current location in the shared memory pipe - // into the current register buffer. - auto fetch_to_registers = [&](int k, int pipe) { - int4* sh_a_stage = sh_a + a_sh_stage * pipe; - #pragma unroll - for (int i = 0; i < thread_m_blocks; i++) - ldsm4(frag_a[k % 2][i], - &sh_a_stage[a_sh_rd_trans[k % b_sh_wr_iters][i]]); - int4* sh_b_stage = sh_b + b_sh_stage * pipe; - - #pragma unroll - for (int i = 0; i < b_thread_vecs; i++) { - frag_b_quant[k % 2][i] = *reinterpret_cast( - &sh_b_stage[b_sh_rd_delta * (k % b_sh_wr_iters) + b_sh_rd + i]); - } - }; - - bool is_same_group[stages]; - int same_group_id[stages]; - - auto init_same_group = [&](int pipe) { - is_same_group[pipe] = false; - same_group_id[pipe] = 0; - return; - }; - - // Execute the actual tensor core matmul of a sub-tile. - auto matmul = [&](int k) { - // We have the m dimension as the inner loop in order to encourage overlapping - // dequantization and matmul operations. - #pragma unroll - for (int j = 0; j < 4; j++) { - FragB frag_b0; - FragB frag_b1; - - int* frag_b_quant_ptr = reinterpret_cast(frag_b_quant[k % 2]); - int b_quant_0 = frag_b_quant_ptr[j * 2 + 0]; - int b_quant_1 = frag_b_quant_ptr[j * 2 + 1]; - - frag_b0 = dequant_8bit(b_quant_0); - frag_b1 = dequant_8bit(b_quant_1); - - #pragma unroll - for (int i = 0; i < thread_m_blocks; i++) { - mma(frag_a[k % 2][i], frag_b0, frag_c[i][j][0]); - mma(frag_a[k % 2][i], frag_b1, frag_c[i][j][1]); - } - } - }; - - // Since we slice across the k dimension of a tile in order to increase the - // number of warps while keeping the n dimension of a tile reasonable, we have - // multiple warps that accumulate their partial sums of the same output - // location; which we have to reduce over in the end. We do in shared memory. - auto thread_block_reduce = [&]() { - constexpr int red_off = threads / b_sh_stride_threads / 2; - if (red_off >= 1) { - int red_idx = threadIdx.x / b_sh_stride_threads; - constexpr int red_sh_stride = b_sh_stride_threads * 4 * 2; - constexpr int red_sh_delta = b_sh_stride_threads; - int red_sh_rd = red_sh_stride * (threadIdx.x / b_sh_stride_threads) + - (threadIdx.x % b_sh_stride_threads); - - // Parallel logarithmic shared memory reduction. We make sure to avoid any - // unnecessary read or write iterations, e.g., for two warps we write only - // once by warp 1 and read only once by warp 0. - - #pragma unroll - for (int m_block = 0; m_block < thread_m_blocks; m_block++) { - #pragma unroll - for (int i = red_off; i > 0; i /= 2) { - if (i <= red_idx && red_idx < 2 * i) { - #pragma unroll - for (int j = 0; j < 4 * 2; j++) { - int red_sh_wr = - red_sh_delta * j + (red_sh_rd - red_sh_stride * i); - if (i < red_off) { - float* c_rd = - reinterpret_cast(&sh[red_sh_delta * j + red_sh_rd]); - float* c_wr = reinterpret_cast(&sh[red_sh_wr]); - #pragma unroll - for (int k = 0; k < 4; k++) - reinterpret_cast(frag_c)[4 * 2 * m_block + j][k] += - c_rd[k] + c_wr[k]; - } - sh[red_sh_wr] = - reinterpret_cast(&frag_c)[4 * 2 * m_block + j]; - } - } - __syncthreads(); - } - if (red_idx == 0) { - #pragma unroll - for (int i = 0; i < 4 * 2; i++) { - float* c_rd = - reinterpret_cast(&sh[red_sh_delta * i + red_sh_rd]); - #pragma unroll - for (int j = 0; j < 4; j++) - reinterpret_cast(frag_c)[4 * 2 * m_block + i][j] += - c_rd[j]; - } - } - __syncthreads(); - } - } - }; - - // Since multiple threadblocks may process parts of the same column slice, we - // finally have to globally reduce over the results. As the striped - // partitioning minimizes the number of such reductions and our outputs are - // usually rather small, we perform this reduction serially in L2 cache. - auto global_reduce = [&](bool first = false, bool last = false) { - // We are very careful here to reduce directly in the output buffer to - // maximize L2 cache utilization in this step. To do this, we write out - // results in FP16 (but still reduce with FP32 compute). - constexpr int active_threads = 32 * thread_n_blocks / 4; - if (threadIdx.x < active_threads) { - int c_gl_stride = prob_n / 8; - int c_gl_wr_delta_o = 8 * c_gl_stride; - int c_gl_wr_delta_i = 4 * (active_threads / 32); - int c_gl_wr = c_gl_stride * ((threadIdx.x % 32) / 4) + - 4 * (threadIdx.x / 32) + threadIdx.x % 4; - c_gl_wr += (2 * thread_n_blocks) * slice_col; - constexpr int c_sh_wr_delta = active_threads; - int c_sh_wr = threadIdx.x; - - int row = (threadIdx.x % 32) / 4; - - if (!first) { - // Interestingly, doing direct global accesses here really seems to mess up - // the compiler and lead to slowdowns, hence we also use async-copies even - // though these fetches are not actually asynchronous. - #pragma unroll - for (int i = 0; i < thread_m_blocks * 4; i++) { - cp_async4_pred( - &sh[c_sh_wr + c_sh_wr_delta * i], - &C[c_gl_wr + c_gl_wr_delta_o * (i / 2) + - c_gl_wr_delta_i * (i % 2)], - i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m); - } - cp_async_fence(); - cp_async_wait<0>(); - } - - #pragma unroll - for (int i = 0; i < thread_m_blocks * 4; i++) { - if (i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m) { - if (!first) { - int4 c_red = sh[c_sh_wr + i * c_sh_wr_delta]; - #pragma unroll - for (int j = 0; j < 2 * 4; j++) { - reinterpret_cast( - &frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4)] += - Dtype::num2float(reinterpret_cast(&c_red)[j]); - } - } - if (!last) { - int4 c; - #pragma unroll - for (int j = 0; j < 2 * 4; j++) { - reinterpret_cast(&c)[j] = - Dtype::float2num(reinterpret_cast( - &frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4)]); - } - C[c_gl_wr + c_gl_wr_delta_o * (i / 2) + c_gl_wr_delta_i * (i % 2)] = - c; - } - } - } - } - }; - - // Write out the reduce final result in the correct layout. We only actually - // reshuffle matrix fragments in this step, the reduction above is performed - // in fragment layout. - auto write_result = [&]() { - int c_gl_stride = prob_n / 8; - constexpr int c_sh_stride = 2 * thread_n_blocks + 1; - int c_gl_wr_delta = c_gl_stride * (threads / (2 * thread_n_blocks)); - constexpr int c_sh_rd_delta = - c_sh_stride * (threads / (2 * thread_n_blocks)); - - int c_gl_wr = c_gl_stride * (threadIdx.x / (2 * thread_n_blocks)) + - (threadIdx.x % (2 * thread_n_blocks)); - c_gl_wr += (2 * thread_n_blocks) * slice_col; - int c_sh_wr = - (4 * c_sh_stride) * ((threadIdx.x % 32) / 4) + (threadIdx.x % 32) % 4; - c_sh_wr += 32 * (threadIdx.x / 32); - int c_sh_rd = c_sh_stride * (threadIdx.x / (2 * thread_n_blocks)) + - (threadIdx.x % (2 * thread_n_blocks)); - - int c_gl_wr_end = c_gl_stride * prob_m; - - // We first reorder in shared memory to guarantee the most efficient final - // global write patterns - auto write = [&](int idx, float c0, float c1, FragS& s) { - scalar_t2 res = - Dtype::nums2num2(Dtype::float2num(c0), Dtype::float2num(c1)); - - ((scalar_t2*)sh)[idx] = res; - }; - - if (threadIdx.x / 32 < thread_n_blocks / 4) { - #pragma unroll - for (int i = 0; i < thread_m_blocks; i++) { - #pragma unroll - for (int j = 0; j < 4; j++) { - int wr = c_sh_wr + 8 * j; - write(wr + (4 * c_sh_stride) * 0 + 0, frag_c[i][j][0][0], - frag_c[i][j][0][1], frag_s[j / 2][2 * (j % 2) + 0]); - write(wr + (4 * c_sh_stride) * 8 + 0, frag_c[i][j][0][2], - frag_c[i][j][0][3], frag_s[j / 2][2 * (j % 2) + 0]); - write(wr + (4 * c_sh_stride) * 0 + 4, frag_c[i][j][1][0], - frag_c[i][j][1][1], frag_s[j / 2][2 * (j % 2) + 1]); - write(wr + (4 * c_sh_stride) * 8 + 4, frag_c[i][j][1][2], - frag_c[i][j][1][3], frag_s[j / 2][2 * (j % 2) + 1]); - } - c_sh_wr += 16 * (4 * c_sh_stride); - } - } - __syncthreads(); - - #pragma unroll - for (int i = 0; - i < div_ceil(16 * thread_m_blocks, threads / (2 * thread_n_blocks)); - i++) { - if (c_gl_wr < c_gl_wr_end) { - C[c_gl_wr] = sh[c_sh_rd]; - c_gl_wr += c_gl_wr_delta; - c_sh_rd += c_sh_rd_delta; - } - } - }; - - // Start global fetch and register load pipelines. - auto start_pipes = [&]() { - - #pragma unroll - for (int i = 0; i < stages - 1; i++) { - fetch_to_shared(i, i, i < slice_iters); - } - - zero_accums(); - wait_for_stage(); - init_same_group(0); - fetch_to_registers(0, 0); - a_gl_rd += a_gl_rd_delta_o * (stages - 1); - slice_k_start_shared_fetch += tb_k * (stages - 1); - }; - if (slice_iters) { - start_pipes(); - } - - // Main loop. - while (slice_iters) { - // We unroll over both the global fetch and the register load pipeline to - // ensure all shared memory accesses are static. Note that both pipelines - // have even length meaning that the next iteration will always start at - // index 0. - - #pragma unroll - for (int pipe = 0; pipe < stages;) { - #pragma unroll - for (int k = 0; k < b_sh_wr_iters; k++) { - fetch_to_registers(k + 1, pipe % stages); - if (k == b_sh_wr_iters - 2) { - fetch_to_shared((pipe + stages - 1) % stages, pipe, - slice_iters >= stages); - pipe++; - wait_for_stage(); - init_same_group(pipe % stages); - } - matmul(k); - } - slice_iters--; - if (slice_iters == 0) { - break; - } - } - - a_gl_rd += a_gl_rd_delta_o * stages; - slice_k_start += tb_k * stages; - slice_k_start_shared_fetch += tb_k * stages; - - // Process results and, if necessary, proceed to the next column slice. - // While this pattern may not be the most readable, other ways of writing - // the loop seemed to noticeably worse performance after compilation. - if (slice_iters == 0) { - cp_async_wait<0>(); - bool last = slice_idx == slice_count - 1; - // For per-column scales, we only fetch them here in the final step before - // write-out - if (s_sh_wr_pred) { - cp_async4(&sh_s[s_sh_wr], &scales_ptr[s_gl_rd]); - } - cp_async_fence(); - - thread_block_reduce(); - - cp_async_wait<0>(); - __syncthreads(); - if (threadIdx.x / 32 < thread_n_blocks / 4) { - reinterpret_cast(&frag_s)[0] = sh_s[s_sh_rd + 0]; - reinterpret_cast(&frag_s)[1] = sh_s[s_sh_rd + 4]; - } - - // For 8-bit channelwise, we apply the scale before the global reduction - // that converts the fp32 results to fp16 (so that we avoid possible - // overflow in fp16) - if (threadIdx.x / 32 < thread_n_blocks / 4) { - #pragma unroll - for (int i = 0; i < thread_m_blocks; i++) { - #pragma unroll - for (int j = 0; j < 4; j++) { - scale_float(reinterpret_cast(&frag_c[i][j][0][0]), - frag_s[j / 2][2 * (j % 2) + 0]); - scale_float(reinterpret_cast(&frag_c[i][j][0][2]), - frag_s[j / 2][2 * (j % 2) + 0]); - - scale_float(reinterpret_cast(&frag_c[i][j][1][0]), - frag_s[j / 2][2 * (j % 2) + 1]); - scale_float(reinterpret_cast(&frag_c[i][j][1][2]), - frag_s[j / 2][2 * (j % 2) + 1]); - } - } - } - - if (slice_count > 1) { // only globally reduce if there is more than one - // block in a slice - barrier_acquire(&locks[slice_col], slice_idx); - global_reduce(slice_idx == 0, last); - barrier_release(&locks[slice_col], last); - } - if (last) // only the last block in a slice actually writes the result - write_result(); - slice_row = 0; - slice_col_par++; - slice_col++; - init_slice(); - if (slice_iters) { - a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) + - (threadIdx.x % a_gl_rd_delta_o); - #pragma unroll - for (int i = 0; i < b_sh_wr_iters; i++) - B_ptr[i] += b_sh_stride - b_gl_rd_delta_o * k_tiles; - if (slice_col == 0) { - #pragma unroll - for (int i = 0; i < b_sh_wr_iters; i++) B_ptr[i] -= b_gl_stride; - } - - // Update slice k/n for scales loading - s_gl_rd = s_sh_stride * slice_col + threadIdx.x; - - start_pipes(); - } - } - } -} - - #define __CALL_IF(NUM_BITS, THREAD_M_BLOCKS, THREAD_N_BLOCKS, \ - THREAD_K_BLOCKS, GROUP_BLOCKS, NUM_THREADS) \ - else if (num_bits == NUM_BITS && thread_m_blocks == THREAD_M_BLOCKS && \ - thread_n_blocks == THREAD_N_BLOCKS && \ - thread_k_blocks == THREAD_K_BLOCKS && \ - group_blocks == GROUP_BLOCKS && num_threads == NUM_THREADS) { \ - cudaFuncSetAttribute( \ - Marlin, \ - cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \ - Marlin \ - <<>>( \ - A_ptr, B_ptr, C_ptr, s_ptr, num_groups, prob_m, prob_n, prob_k, \ - locks); \ - } - -typedef struct { - int thread_k; - int thread_n; - int num_threads; -} thread_config_t; - -typedef struct { - int max_m_blocks; - thread_config_t tb_cfg; -} exec_config_t; - -thread_config_t small_batch_thread_configs[] = { - // Ordered by priority - - // thread_k, thread_n, num_threads - {128, 128, 256}, - {64, 128, 128}, - {128, 64, 128}, -}; - -thread_config_t large_batch_thread_configs[] = { - // Ordered by priority - - // thread_k, thread_n, num_threads - {64, 256, 256}, - {64, 128, 128}, - {128, 64, 128}, - -}; - -int get_scales_cache_size(thread_config_t const& th_config, int prob_m, - int prob_n, int prob_k, int num_bits, - int group_size) { - int tb_n = th_config.thread_n; - - // Get max scale groups per thread-block - // Fixed for channelwise - int tb_groups = 1; - int tb_scales = tb_groups * tb_n * 2; - - return tb_scales * pipe_stages; -} - -bool is_valid_cache_size(thread_config_t const& th_config, int max_m_blocks, - int prob_m, int prob_n, int prob_k, int num_bits, - int scales_cache_size, int max_shared_mem) { - int pack_factor = 32 / num_bits; - - // Get B size - int tb_k = th_config.thread_k; - int tb_n = th_config.thread_n; - - int b_size = (tb_k * tb_n / pack_factor) * 4; - - // Get A size - int m_blocks = div_ceil(prob_m, 16); - int tb_max_m = 16; - - while (true) { - if (m_blocks >= max_m_blocks) { - tb_max_m *= max_m_blocks; - break; - } - - max_m_blocks--; - if (max_m_blocks == 0) { - TORCH_CHECK(false, "Unexpected m_blocks = ", m_blocks); - } - } - - int a_size = (tb_max_m * tb_k) * 2; - - float pipe_size = (a_size + b_size) * pipe_stages; - - TORCH_CHECK(max_shared_mem / 2 > scales_cache_size); // Sanity - - return pipe_size < 0.95f * (max_shared_mem - scales_cache_size); -} - -bool is_valid_config(thread_config_t const& th_config, int max_m_blocks, - int prob_m, int prob_n, int prob_k, int num_bits, - int group_size, int max_shared_mem) { - // Sanity - if (th_config.thread_k == -1 || th_config.thread_n == -1 || - th_config.num_threads == -1) { - return false; - } - - // Verify K/N are divisible by thread K/N - if (prob_k % th_config.thread_k != 0 || prob_n % th_config.thread_n != 0) { - return false; - } - - // Verify min for thread K/N - if (th_config.thread_n < min_thread_n || th_config.thread_k < min_thread_k) { - return false; - } - - // num_threads must be at least 128 (= 4 warps) - if (th_config.num_threads < 128) { - return false; - } - - // Determine cache for scales - int scales_cache_size = get_scales_cache_size(th_config, prob_m, prob_n, - prob_k, num_bits, group_size); - - // Check that pipeline fits into cache - if (!is_valid_cache_size(th_config, max_m_blocks, prob_m, prob_n, prob_k, - num_bits, scales_cache_size, max_shared_mem)) { - return false; - } - - return true; -} - -exec_config_t determine_thread_config(int prob_m, int prob_n, int prob_k, - int num_bits, int group_size, - int max_shared_mem) { - int max_m_blocks = 4; - while (max_m_blocks > 0) { - if (prob_m <= 16) { - for (auto th_config : small_batch_thread_configs) { - if (is_valid_config(th_config, max_m_blocks, prob_m, prob_n, prob_k, - num_bits, group_size, max_shared_mem)) { - return exec_config_t{max_m_blocks, th_config}; - } - } - } else { - for (auto th_config : large_batch_thread_configs) { - if (is_valid_config(th_config, max_m_blocks, prob_m, prob_n, prob_k, - num_bits, group_size, max_shared_mem)) { - return exec_config_t{max_m_blocks, th_config}; - } - } - } - - max_m_blocks--; // Process less M blocks per invocation to reduce cache - // usage - } - - return exec_config_t{0, {-1, -1, -1}}; -} - - #define CALL_IF(NUM_BITS, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ - __CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \ - __CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \ - __CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \ - __CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) - -template -void marlin_mm_f16i4(const void* A, const void* B, void* C, void* s, int prob_m, - int prob_n, int prob_k, void* workspace, int num_bits, - int num_groups, int group_size, int dev, - cudaStream_t stream, int thread_k, int thread_n, int sms, - int max_par) { - TORCH_CHECK(num_bits == 8, "num_bits must be 8. Got = ", num_bits); - TORCH_CHECK(prob_m > 0 && prob_n > 0 && prob_k > 0, "Invalid MNK = [", prob_m, - ", ", prob_n, ", ", prob_k, "]"); - - int tot_m = prob_m; - int tot_m_blocks = div_ceil(tot_m, 16); - int pad = 16 * tot_m_blocks - tot_m; - - if (sms == -1) { - cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, dev); - } - - int max_shared_mem = 0; - cudaDeviceGetAttribute(&max_shared_mem, - cudaDevAttrMaxSharedMemoryPerBlockOptin, dev); - TORCH_CHECK(max_shared_mem > 0); - - // Set thread config - exec_config_t exec_cfg; - if (thread_k != -1 && thread_n != -1) { - // User-defined config - exec_cfg = - exec_config_t{4, thread_config_t{thread_k, thread_n, default_threads}}; - } else { - // Auto config - exec_cfg = determine_thread_config(prob_m, prob_n, prob_k, num_bits, - group_size, max_shared_mem); - } - - TORCH_CHECK( - exec_cfg.max_m_blocks > 0 && - is_valid_config(exec_cfg.tb_cfg, exec_cfg.max_m_blocks, prob_m, - prob_n, prob_k, num_bits, group_size, max_shared_mem), - "Invalid thread config: max_m_blocks = ", exec_cfg.max_m_blocks, - ", thread_k = ", exec_cfg.tb_cfg.thread_k, - ", thread_n = ", exec_cfg.tb_cfg.thread_n, - ", num_threads = ", exec_cfg.tb_cfg.num_threads, " for MKN = [", prob_m, - ", ", prob_k, ", ", prob_n, "] and num_bits = ", num_bits, - ", group_size = ", group_size, ", max_shared_mem = ", max_shared_mem); - - int num_threads = exec_cfg.tb_cfg.num_threads; - thread_k = exec_cfg.tb_cfg.thread_k; - thread_n = exec_cfg.tb_cfg.thread_n; - - int thread_k_blocks = thread_k / 16; - int thread_n_blocks = thread_n / 16; - - int blocks = sms; - - TORCH_CHECK(prob_n % thread_n == 0, "prob_n = ", prob_n, - " is not divisible by thread_n = ", thread_n); - TORCH_CHECK(prob_k % thread_k == 0, "prob_k = ", prob_k, - " is not divisible by thread_k = ", thread_k); - - int group_blocks = -1; - - const int4* A_ptr = (const int4*)A; - const int4* B_ptr = (const int4*)B; - int4* C_ptr = (int4*)C; - const int4* s_ptr = (const int4*)s; - - int* locks = (int*)workspace; - - // Main loop - for (int i = 0; i < tot_m_blocks; i += exec_cfg.max_m_blocks) { - int thread_m_blocks = tot_m_blocks - i; - prob_m = tot_m - 16 * i; - int par = 1; - if (thread_m_blocks > exec_cfg.max_m_blocks) { - // Note that parallel > 1 currently only works for inputs without any - // padding - par = (16 * thread_m_blocks - pad) / (16 * exec_cfg.max_m_blocks); - if (par > max_par) par = max_par; - prob_m = (16 * exec_cfg.max_m_blocks) * par; - i += exec_cfg.max_m_blocks * (par - 1); - thread_m_blocks = exec_cfg.max_m_blocks; - } - - // Define kernel configurations - if (false) { - } - CALL_IF(8, 32, 2, 256) - CALL_IF(8, 16, 4, 256) - CALL_IF(8, 8, 8, 256) - CALL_IF(8, 8, 4, 128) - CALL_IF(8, 4, 8, 128) - else { - TORCH_CHECK(false, "Unsupported shapes: MNK = [" + str(prob_m) + ", " + - str(prob_n) + ", " + str(prob_k) + "]" + - ", num_groups = " + str(num_groups) + - ", group_size = " + str(group_size) + - ", thread_m_blocks = " + str(thread_m_blocks) + - ", thread_n_blocks = " + str(thread_n_blocks) + - ", thread_k_blocks = " + str(thread_k_blocks)); - } - - A_ptr += 16 * thread_m_blocks * (prob_k / 8) * par; - C_ptr += 16 * thread_m_blocks * (prob_n / 8) * par; - } -} - -} // namespace fp8_marlin - -torch::Tensor fp8_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, - torch::Tensor& b_scales, torch::Tensor& workspace, - int64_t num_bits, int64_t size_m, int64_t size_n, - int64_t size_k) { - // Verify num_bits - TORCH_CHECK(num_bits == 8, "num_bits must be 8. Got = ", num_bits); - int pack_factor = 32 / num_bits; - - // Verify A - TORCH_CHECK(a.size(0) == size_m, "Shape mismatch: a.size(0) = ", a.size(0), - ", size_m = ", size_m); - TORCH_CHECK(a.size(1) == size_k, "Shape mismatch: a.size(1) = ", a.size(1), - ", size_k = ", size_k); - - // Verify B - TORCH_CHECK(size_k % marlin::tile_size == 0, "size_k = ", size_k, - " is not divisible by tile_size = ", marlin::tile_size); - TORCH_CHECK((size_k / marlin::tile_size) == b_q_weight.size(0), - "Shape mismatch: b_q_weight.size(0) = ", b_q_weight.size(0), - ", size_k = ", size_k, ", tile_size = ", marlin::tile_size); - TORCH_CHECK(b_q_weight.size(1) % marlin::tile_size == 0, - "b_q_weight.size(1) = ", b_q_weight.size(1), - " is not divisible by tile_size = ", marlin::tile_size); - int actual_size_n = (b_q_weight.size(1) / marlin::tile_size) * pack_factor; - TORCH_CHECK(size_n == actual_size_n, "size_n = ", size_n, - ", actual_size_n = ", actual_size_n); - - // Verify device and strides - TORCH_CHECK(a.device().is_cuda(), "A is not on GPU"); - TORCH_CHECK(a.is_contiguous(), "A is not contiguous"); - - TORCH_CHECK(b_q_weight.device().is_cuda(), "b_q_weight is not on GPU"); - TORCH_CHECK(b_q_weight.is_contiguous(), "b_q_weight is not contiguous"); - - TORCH_CHECK(b_scales.device().is_cuda(), "b_scales is not on GPU"); - TORCH_CHECK(b_scales.is_contiguous(), "b_scales is not contiguous"); - - // Alloc buffers - const at::cuda::OptionalCUDAGuard device_guard(device_of(a)); - auto options = torch::TensorOptions().dtype(a.dtype()).device(a.device()); - torch::Tensor c = torch::empty({size_m, size_n}, options); - - // thread_k: `k` size of a thread_tile in `weights` (can usually be left as - // auto -1) - int thread_k = -1; - // thread_n: `n` size of a thread_tile in `weights` (can usually be left as - // auto -1) - int thread_n = -1; - // sms: number of SMs to use for the kernel (can usually be left as auto -1) - int sms = -1; - - // Detect groupsize and act_order - int num_groups = -1; - int group_size = -1; - - int b_rank = b_scales.sizes().size(); - TORCH_CHECK(b_rank == 2, "b_scales rank = ", b_rank, " is not 2"); - TORCH_CHECK(b_scales.size(1) == size_n, "b_scales dim 1 = ", b_scales.size(1), - " is not size_n = ", size_n); - // Channelwise only for FP8 - TORCH_CHECK(b_scales.size(0) == 1) - num_groups = b_scales.size(0); - - // Verify workspace size - TORCH_CHECK(size_n % marlin::min_thread_n == 0, "size_n = ", size_n, - ", is not divisible by min_thread_n = ", marlin::min_thread_n); - int min_workspace_size = (size_n / marlin::min_thread_n) * marlin::max_par; - TORCH_CHECK(workspace.numel() >= min_workspace_size, - "workspace.numel = ", workspace.numel(), - " is below min_workspace_size = ", min_workspace_size); - - int dev = a.get_device(); - if (a.scalar_type() == at::ScalarType::Half) { - fp8_marlin::marlin_mm_f16i4( - a.data_ptr(), b_q_weight.data_ptr(), c.data_ptr(), - b_scales.data_ptr(), size_m, size_n, size_k, - workspace.data_ptr(), num_bits, num_groups, group_size, dev, - at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, sms, - marlin::max_par); - } else if (a.scalar_type() == at::ScalarType::BFloat16) { - fp8_marlin::marlin_mm_f16i4( - a.data_ptr(), b_q_weight.data_ptr(), - c.data_ptr(), b_scales.data_ptr(), size_m, - size_n, size_k, workspace.data_ptr(), num_bits, num_groups, group_size, - dev, at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, sms, - marlin::max_par); - } else { - TORCH_CHECK(false, "fp8_marlin_gemm only supports bfloat16 and float16"); - } - - return c; -} - -#endif diff --git a/server/marlin/marlin_kernels/gptq_marlin.cu b/server/marlin/marlin_kernels/gptq_marlin.cu deleted file mode 100644 index 122c5c16..00000000 --- a/server/marlin/marlin_kernels/gptq_marlin.cu +++ /dev/null @@ -1,2195 +0,0 @@ -/* - * Modified by Neural Magic - * Copyright (C) Marlin.2024 Elias Frantar - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -/* - * Adapted from https://github.com/IST-DASLab/marlin - */ - -#include "marlin.cuh" -#include "marlin_dtypes.cuh" - -#define STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t) \ - static_assert(std::is_same::value || \ - std::is_same::value, \ - "only float16 and bfloat16 is supported"); - -template -inline std::string str(T x) { - return std::to_string(x); -} - -namespace marlin { - -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 - -__global__ void permute_cols_kernel(int4 const* __restrict__ a_int4_ptr, - int const* __restrict__ perm_int_ptr, - int4* __restrict__ out_int4_ptr, int size_m, - int size_k, int block_rows) {} - -template shared - // fetch pipeline - const bool has_act_order, // whether act_order is enabled - const int group_blocks = -1 // number of consecutive 16x16 blocks - // with a separate quantization scale - > -__global__ void Marlin( - const int4* __restrict__ A, // fp16 input matrix of shape mxk - const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn - int4* __restrict__ C, // fp16 output buffer of shape mxn - const int4* __restrict__ scales_ptr, // fp16 quantization scales of shape - // (k/groupsize)xn - const int* __restrict__ g_idx, // int32 group indices of shape k - int num_groups, // number of scale groups per output channel - int prob_m, // batch dimension m - int prob_n, // output dimension n - int prob_k, // reduction dimension k - int* locks // extra global storage for barrier synchronization -) {} - -} // namespace gptq_marlin - -torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, - torch::Tensor& b_scales, torch::Tensor& b_zeros, - torch::Tensor& g_idx, torch::Tensor& perm, - torch::Tensor& workspace, int64_t num_bits, - int64_t size_m, int64_t size_n, int64_t size_k, - bool is_k_full) { - TORCH_CHECK_NOT_IMPLEMENTED(false, - "marlin_gemm(..) requires CUDA_ARCH >= 8.0"); - return torch::empty({1, 1}); -} - -#else - -// m16n8k16 tensor core mma instruction with fp16 inputs and fp32 -// output/accumulation. -template -__device__ inline void mma(const typename ScalarType::FragA& a_frag, - const typename ScalarType::FragB& frag_b, - typename ScalarType::FragC& frag_c) { - const uint32_t* a = reinterpret_cast(&a_frag); - const uint32_t* b = reinterpret_cast(&frag_b); - float* c = reinterpret_cast(&frag_c); - if constexpr (std::is_same::value) { - asm volatile( - "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 " - "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" - : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) - : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]), - "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3])); - } else if constexpr (std::is_same::value) { - asm volatile( - "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 " - "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" - : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) - : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]), - "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3])); - } else { - STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t); - } -} - -// Instruction for loading a full 16x16 matrix fragment of operand A from shared -// memory, directly in tensor core layout. -template -__device__ inline void ldsm4(typename ScalarType::FragA& frag_a, - const void* smem_ptr) { - uint32_t* a = reinterpret_cast(&frag_a); - uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); - asm volatile("ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0,%1,%2,%3}, [%4];\n" - : "=r"(a[0]), "=r"(a[1]), "=r"(a[2]), "=r"(a[3]) - : "r"(smem)); -} - -// Lookup-table based 3-input logical operation; explicitly used for -// dequantization as the compiler does not seem to automatically recognize it in -// all cases. -template -__device__ inline int lop3(int a, int b, int c) { - int res; - asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" - : "=r"(res) - : "r"(a), "r"(b), "r"(c), "n"(lut)); - return res; -} - -// Constructs destination register by taking bytes from 2 sources (based on -// mask) -template -__device__ inline uint32_t prmt(uint32_t a) { - uint32_t res; - asm volatile("prmt.b32 %0, %1, %2, %3;\n" - : "=r"(res) - : "r"(a), "n"(start_byte), "n"(mask)); - return res; -} - -// Efficiently dequantize an int32 value into a full B-fragment of 4 fp16 -// values. We mostly follow the strategy in the link below, with some small -// changes: -// - FP16: -// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L215-L287 -// - BF16: -// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L327-L385 -template -__device__ inline typename ScalarType::FragB dequant_4bit(int q) { - STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t); -} - -template <> -__device__ inline typename ScalarType::FragB dequant_4bit(int q) { - const int LO = 0x000f000f; - const int HI = 0x00f000f0; - const int EX = 0x64006400; - // Guarantee that the `(a & b) | c` operations are LOP3s. - int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX); - int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX); - // We want signed int4 outputs, hence we fuse the `-8` symmetric zero point - // directly into `SUB` and `ADD`. - const int SUB = 0x64086408; - const int MUL = 0x2c002c00; - const int ADD = 0xd480d480; - typename ScalarType::FragB frag_b; - frag_b[0] = __hsub2(*reinterpret_cast(&lo), - *reinterpret_cast(&SUB)); - frag_b[1] = __hfma2(*reinterpret_cast(&hi), - *reinterpret_cast(&MUL), - *reinterpret_cast(&ADD)); - return frag_b; -} - -template <> -__device__ inline typename ScalarType::FragB -dequant_4bit(int q) { - static constexpr uint32_t MASK = 0x000f000f; - static constexpr uint32_t EX = 0x43004300; - - // Guarantee that the `(a & b) | c` operations are LOP3s. - - int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX); - q >>= 4; - int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX); - - typename ScalarType::FragB frag_b; - static constexpr uint32_t MUL = 0x3F803F80; - static constexpr uint32_t ADD = 0xC308C308; - - frag_b[0] = __hfma2(*reinterpret_cast(&lo), - *reinterpret_cast(&MUL), - *reinterpret_cast(&ADD)); - frag_b[1] = __hfma2(*reinterpret_cast(&hi), - *reinterpret_cast(&MUL), - *reinterpret_cast(&ADD)); - return frag_b; -} - -// Fast Int8ToFp16/Int8ToBf16: Efficiently dequantize 8bit int values to fp16 or -// bf16 Reference: -// - FP16: -// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L53-L85 -// - BF16: -// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L125-L175 -template -__device__ inline typename ScalarType::FragB dequant_8bit(int q) { - STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t); -} - -template <> -__device__ inline typename ScalarType::FragB dequant_8bit(int q) { - static constexpr uint32_t mask_for_elt_01 = 0x5250; - static constexpr uint32_t mask_for_elt_23 = 0x5351; - static constexpr uint32_t start_byte_for_fp16 = 0x64646464; - - uint32_t lo = prmt(q); - uint32_t hi = prmt(q); - - static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64806480; - - typename ScalarType::FragB frag_b; - frag_b[0] = __hsub2(*reinterpret_cast(&lo), - *reinterpret_cast(&I8s_TO_F16s_MAGIC_NUM)); - frag_b[1] = __hsub2(*reinterpret_cast(&hi), - *reinterpret_cast(&I8s_TO_F16s_MAGIC_NUM)); - return frag_b; -} - -template <> -__device__ inline typename ScalarType::FragB -dequant_8bit(int q) { - typename ScalarType::FragB frag_b; - - float fp32_intermediates[4]; - uint32_t* fp32_intermediates_casted = - reinterpret_cast(fp32_intermediates); - - static constexpr uint32_t fp32_base = 0x4B000000; - fp32_intermediates_casted[0] = __byte_perm(q, fp32_base, 0x7650); - fp32_intermediates_casted[1] = __byte_perm(q, fp32_base, 0x7652); - fp32_intermediates_casted[2] = __byte_perm(q, fp32_base, 0x7651); - fp32_intermediates_casted[3] = __byte_perm(q, fp32_base, 0x7653); - - fp32_intermediates[0] -= 8388736.f; - fp32_intermediates[1] -= 8388736.f; - fp32_intermediates[2] -= 8388736.f; - fp32_intermediates[3] -= 8388736.f; - - uint32_t* bf16_result_ptr = reinterpret_cast(&frag_b); - bf16_result_ptr[0] = __byte_perm(fp32_intermediates_casted[0], - fp32_intermediates_casted[1], 0x7632); - bf16_result_ptr[1] = __byte_perm(fp32_intermediates_casted[2], - fp32_intermediates_casted[3], 0x7632); - - return frag_b; -} - -// Zero-point dequantizers - -template -__device__ inline typename ScalarType::FragB dequant_4bit_zp(int q) { - STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t); -} - -template <> -__device__ inline typename ScalarType::FragB dequant_4bit_zp( - int q) { - const int LO = 0x000f000f; - const int HI = 0x00f000f0; - const int EX = 0x64006400; - // Guarantee that the `(a & b) | c` operations are LOP3s. - int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX); - int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX); - - const int SUB = 0x64006400; - const int MUL = 0x2c002c00; - const int ADD = 0xd400d400; - typename ScalarType::FragB frag_b; - frag_b[0] = __hsub2(*reinterpret_cast(&lo), - *reinterpret_cast(&SUB)); - frag_b[1] = __hfma2(*reinterpret_cast(&hi), - *reinterpret_cast(&MUL), - *reinterpret_cast(&ADD)); - return frag_b; -} - -template <> -__device__ inline typename ScalarType::FragB -dequant_4bit_zp(int q) { - static constexpr uint32_t MASK = 0x000f000f; - static constexpr uint32_t EX = 0x43004300; - - // Guarantee that the `(a & b) | c` operations are LOP3s. - - int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX); - q >>= 4; - int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX); - - typename ScalarType::FragB frag_b; - static constexpr uint32_t MUL = 0x3F803F80; - static constexpr uint32_t ADD = 0xC300C300; - - frag_b[0] = __hfma2(*reinterpret_cast(&lo), - *reinterpret_cast(&MUL), - *reinterpret_cast(&ADD)); - frag_b[1] = __hfma2(*reinterpret_cast(&hi), - *reinterpret_cast(&MUL), - *reinterpret_cast(&ADD)); - return frag_b; -} - -template -__device__ inline typename ScalarType::FragB dequant_8bit_zp(int q) { - STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t); -} - -template <> -__device__ inline typename ScalarType::FragB dequant_8bit_zp( - int q) { - static constexpr uint32_t mask_for_elt_01 = 0x5250; - static constexpr uint32_t mask_for_elt_23 = 0x5351; - static constexpr uint32_t start_byte_for_fp16 = 0x64646464; - - uint32_t lo = prmt(q); - uint32_t hi = prmt(q); - - static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64006400; - - typename ScalarType::FragB frag_b; - frag_b[0] = __hsub2(*reinterpret_cast(&lo), - *reinterpret_cast(&I8s_TO_F16s_MAGIC_NUM)); - frag_b[1] = __hsub2(*reinterpret_cast(&hi), - *reinterpret_cast(&I8s_TO_F16s_MAGIC_NUM)); - return frag_b; -} - -template <> -__device__ inline typename ScalarType::FragB -dequant_8bit_zp(int q) { - typename ScalarType::FragB frag_b; - - float fp32_intermediates[4]; - uint32_t* fp32_intermediates_casted = - reinterpret_cast(fp32_intermediates); - - static constexpr uint32_t fp32_base = 0x4B000000; - fp32_intermediates_casted[0] = __byte_perm(q, fp32_base, 0x7650); - fp32_intermediates_casted[1] = __byte_perm(q, fp32_base, 0x7652); - fp32_intermediates_casted[2] = __byte_perm(q, fp32_base, 0x7651); - fp32_intermediates_casted[3] = __byte_perm(q, fp32_base, 0x7653); - - fp32_intermediates[0] -= 8388608.f; - fp32_intermediates[1] -= 8388608.f; - fp32_intermediates[2] -= 8388608.f; - fp32_intermediates[3] -= 8388608.f; - - uint32_t* bf16_result_ptr = reinterpret_cast(&frag_b); - bf16_result_ptr[0] = __byte_perm(fp32_intermediates_casted[0], - fp32_intermediates_casted[1], 0x7632); - bf16_result_ptr[1] = __byte_perm(fp32_intermediates_casted[2], - fp32_intermediates_casted[3], 0x7632); - - return frag_b; -} - -// Multiply dequantized values by the corresponding quantization scale; used -// only for grouped quantization. -template -__device__ inline void scale(typename ScalarType::FragB& frag_b, - typename ScalarType::FragS& frag_s, - int i) { - using scalar_t2 = typename ScalarType::scalar_t2; - scalar_t2 s = - ScalarType::num2num2(reinterpret_cast(&frag_s)[i]); - frag_b[0] = __hmul2(frag_b[0], s); - frag_b[1] = __hmul2(frag_b[1], s); -} - -template -__device__ inline void sub_zp(typename ScalarType::FragB& frag_b, - typename ScalarType::scalar_t2& frag_zp, - int i) { - using scalar_t2 = typename ScalarType::scalar_t2; - scalar_t2 zp = - ScalarType::num2num2(reinterpret_cast(&frag_zp)[i]); - frag_b[0] = __hsub2(frag_b[0], zp); - frag_b[1] = __hsub2(frag_b[1], zp); -} - -// Same as above, but for act_order (each K is multiplied individually) -template -__device__ inline void scale4(typename ScalarType::FragB& frag_b, - typename ScalarType::FragS& frag_s_1, - typename ScalarType::FragS& frag_s_2, - typename ScalarType::FragS& frag_s_3, - typename ScalarType::FragS& frag_s_4, - int i) { - using scalar_t2 = typename ScalarType::scalar_t2; - scalar_t2 s_val_1_2; - s_val_1_2.x = reinterpret_cast(&frag_s_1)[i]; - s_val_1_2.y = reinterpret_cast(&frag_s_2)[i]; - - scalar_t2 s_val_3_4; - s_val_3_4.x = reinterpret_cast(&frag_s_3)[i]; - s_val_3_4.y = reinterpret_cast(&frag_s_4)[i]; - - frag_b[0] = __hmul2(frag_b[0], s_val_1_2); - frag_b[1] = __hmul2(frag_b[1], s_val_3_4); -} - -// Given 2 floats multiply by 2 scales (halves) -template -__device__ inline void scale_float(float* c, - typename ScalarType::FragS& s) { - scalar_t* s_ptr = reinterpret_cast(&s); - c[0] = __fmul_rn(c[0], ScalarType::num2float(s_ptr[0])); - c[1] = __fmul_rn(c[1], ScalarType::num2float(s_ptr[1])); -} - -// Wait until barrier reaches `count`, then lock for current threadblock. -__device__ inline void barrier_acquire(int* lock, int count) { - if (threadIdx.x == 0) { - int state = -1; - do - // Guarantee that subsequent writes by this threadblock will be visible - // globally. - asm volatile("ld.global.acquire.gpu.b32 %0, [%1];\n" - : "=r"(state) - : "l"(lock)); - while (state != count); - } - __syncthreads(); -} - -// Release barrier and increment visitation count. -__device__ inline void barrier_release(int* lock, bool reset = false) { - __syncthreads(); - if (threadIdx.x == 0) { - if (reset) { - lock[0] = 0; - return; - } - int val = 1; - // Make sure that all writes since acquiring this barrier are visible - // globally, while releasing the barrier. - asm volatile("fence.acq_rel.gpu;\n"); - asm volatile("red.relaxed.gpu.global.add.s32 [%0], %1;\n" - : - : "l"(lock), "r"(val)); - } -} - -// For a given "a" of size [M,K] performs a permutation of the K columns based -// on the given "perm" indices. -__global__ void permute_cols_kernel(int4 const* __restrict__ a_int4_ptr, - int const* __restrict__ perm_int_ptr, - int4* __restrict__ out_int4_ptr, int size_m, - int size_k, int block_rows) { - int start_row = block_rows * blockIdx.x; - int finish_row = start_row + block_rows; - if (finish_row > size_m) { - finish_row = size_m; - } - int cur_block_rows = finish_row - start_row; - - int row_stride = size_k * sizeof(half) / 16; - - auto permute_row = [&](int row) { - int iters = size_k / default_threads; - int rest = size_k % default_threads; - - int offset = row * row_stride; - - half const* a_row_half = reinterpret_cast(a_int4_ptr + offset); - half* out_half = reinterpret_cast(out_int4_ptr + offset); - - int base_k = 0; - - for (int i = 0; i < iters; i++) { - int cur_k = base_k + threadIdx.x; - int src_pos = perm_int_ptr[cur_k]; - - out_half[cur_k] = a_row_half[src_pos]; - - base_k += default_threads; - } - - if (rest) { - if (threadIdx.x < rest) { - int cur_k = base_k + threadIdx.x; - int src_pos = perm_int_ptr[cur_k]; - - out_half[cur_k] = a_row_half[src_pos]; - } - } - }; - - for (int i = 0; i < cur_block_rows; i++) { - int cur_row = start_row + i; - if (cur_row < size_m) { - permute_row(cur_row); - } - } -} - -template shared - // fetch pipeline - const bool has_act_order, // whether act_order is enabled - const bool has_zp, // whether zero-points are enabled - const int group_blocks = -1 // number of consecutive 16x16 blocks - // with a separate quantization scale - > -__global__ void Marlin( - const int4* __restrict__ A, // fp16 input matrix of shape mxk - const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn - int4* __restrict__ C, // fp16 output buffer of shape mxn - const int4* __restrict__ scales_ptr, // fp16 quantization scales of shape - // (k/groupsize)xn - const int4* __restrict__ zp_ptr, // 4bit packed zero-points of shape - // (k/groupsize)x(n/pack_factor) - const int* __restrict__ g_idx, // int32 group indices of shape k - int num_groups, // number of scale groups per output channel - int prob_m, // batch dimension m - int prob_n, // output dimension n - int prob_k, // reduction dimension k - int* locks // extra global storage for barrier synchronization -) { - // Each threadblock processes one "stripe" of the B matrix with (roughly) the - // same size, which might involve multiple column "slices" (of width 16 * - // `thread_n_blocks`). Stripes are defined as shown in the 3x3 matrix 5 SM - // example: - // 0 1 3 - // 0 2 3 - // 1 2 4 - // While this kind of partitioning makes things somewhat more complicated, it - // ensures good utilization of all SMs for many kinds of shape and GPU - // configurations, while requiring as few slow global cross-threadblock - // reductions as possible. - using Dtype = ScalarType; - using scalar_t2 = typename ScalarType::scalar_t2; - using FragA = typename ScalarType::FragA; - using FragB = typename ScalarType::FragB; - using FragC = typename ScalarType::FragC; - using FragS = typename ScalarType::FragS; - using FragZP = typename ScalarType::FragZP; - - constexpr int pack_factor = 32 / num_bits; - - // For larger GEMMs we run multiple batchsize 64 versions in parallel for a - // better partitioning with less reductions - int parallel = 1; - if (prob_m > 16 * thread_m_blocks) { - parallel = prob_m / (16 * thread_m_blocks); - prob_m = 16 * thread_m_blocks; - } - - int k_tiles = prob_k / 16 / thread_k_blocks; - int n_tiles = prob_n / 16 / thread_n_blocks; - int iters = div_ceil(k_tiles * n_tiles * parallel, gridDim.x); - - if constexpr (!has_act_order && group_blocks != -1) { - if (group_blocks >= thread_k_blocks) { - // Ensure that the number of tiles in each stripe is a multiple of the - // groupsize; this avoids an annoying special case where a stripe starts - // in the middle of group. - iters = (group_blocks / thread_k_blocks) * - div_ceil(iters, (group_blocks / thread_k_blocks)); - } - } - - int slice_row = (iters * blockIdx.x) % k_tiles; - int slice_col_par = (iters * blockIdx.x) / k_tiles; - int slice_col = slice_col_par; - int slice_iters; // number of threadblock tiles in the current slice - int slice_count = - 0; // total number of active threadblocks in the current slice - int slice_idx; // index of threadblock in current slice; numbered bottom to - // top - - // We can easily implement parallel problem execution by just remapping - // indices and advancing global pointers - if (slice_col_par >= n_tiles) { - A += (slice_col_par / n_tiles) * 16 * thread_m_blocks * prob_k / 8; - C += (slice_col_par / n_tiles) * 16 * thread_m_blocks * prob_n / 8; - locks += (slice_col_par / n_tiles) * n_tiles; - slice_col = slice_col_par % n_tiles; - } - - // Compute all information about the current slice which is required for - // synchronization. - auto init_slice = [&]() { - slice_iters = - iters * (blockIdx.x + 1) - (k_tiles * slice_col_par + slice_row); - if (slice_iters < 0 || slice_col_par >= n_tiles * parallel) slice_iters = 0; - if (slice_iters == 0) return; - if (slice_row + slice_iters > k_tiles) slice_iters = k_tiles - slice_row; - slice_count = 1; - slice_idx = 0; - int col_first = iters * div_ceil(k_tiles * slice_col_par, iters); - if (col_first <= k_tiles * (slice_col_par + 1)) { - int col_off = col_first - k_tiles * slice_col_par; - slice_count = div_ceil(k_tiles - col_off, iters); - if (col_off > 0) slice_count++; - int delta_first = iters * blockIdx.x - col_first; - if (delta_first < 0 || (col_off == 0 && delta_first == 0)) - slice_idx = slice_count - 1; - else { - slice_idx = slice_count - 1 - delta_first / iters; - if (col_off > 0) slice_idx--; - } - } - if (slice_col == n_tiles) { - A += 16 * thread_m_blocks * prob_k / 8; - C += 16 * thread_m_blocks * prob_n / 8; - locks += n_tiles; - slice_col = 0; - } - }; - init_slice(); - - // A sizes/strides - - // stride of the A matrix in global memory - int a_gl_stride = prob_k / 8; - // stride of an A matrix tile in shared memory - constexpr int a_sh_stride = 16 * thread_k_blocks / 8; - // delta between subsequent A tiles in global memory - constexpr int a_gl_rd_delta_o = 16 * thread_k_blocks / 8; - // between subsequent accesses within a tile - int a_gl_rd_delta_i = a_gl_stride * (threads / a_gl_rd_delta_o); - // between shared memory writes - constexpr int a_sh_wr_delta = a_sh_stride * (threads / a_gl_rd_delta_o); - // between shared memory tile reads - constexpr int a_sh_rd_delta_o = 2 * ((threads / 32) / (thread_n_blocks / 4)); - // within a shared memory tile - constexpr int a_sh_rd_delta_i = a_sh_stride * 16; - // overall size of a tile - constexpr int a_sh_stage = a_sh_stride * (16 * thread_m_blocks); - // number of shared write iterations for a tile - constexpr int a_sh_wr_iters = div_ceil(a_sh_stage, a_sh_wr_delta); - - // B sizes/strides - int b_gl_stride = 16 * prob_n / (pack_factor * 4); - constexpr int b_sh_stride = ((thread_n_blocks * 16) * 16 / pack_factor) / 4; - constexpr int b_thread_vecs = num_bits == 4 ? 1 : 2; - constexpr int b_sh_stride_threads = b_sh_stride / b_thread_vecs; - - int b_gl_rd_delta_o = b_gl_stride * thread_k_blocks; - int b_gl_rd_delta_i = b_gl_stride * (threads / b_sh_stride_threads); - constexpr int b_sh_wr_delta = threads * b_thread_vecs; - constexpr int b_sh_rd_delta = threads * b_thread_vecs; - constexpr int b_sh_stage = b_sh_stride * thread_k_blocks; - constexpr int b_sh_wr_iters = b_sh_stage / b_sh_wr_delta; - - // Scale sizes/strides without act_order - int s_gl_stride = prob_n / 8; - constexpr int s_sh_stride = 16 * thread_n_blocks / 8; - constexpr int s_tb_groups = - !has_act_order && group_blocks != -1 && group_blocks < thread_k_blocks - ? thread_k_blocks / group_blocks - : 1; - constexpr int s_sh_stage = s_tb_groups * s_sh_stride; - int s_gl_rd_delta = s_gl_stride; - - // Scale size/strides with act_order - constexpr int tb_k = 16 * thread_k_blocks; - constexpr int g_idx_stage = has_act_order ? (tb_k * sizeof(int)) / 16 : 0; - // constexpr int act_s_row_stride = 1; - // int act_s_col_stride = act_s_row_stride * num_groups; - int act_s_col_stride = 1; - int act_s_col_warp_stride = act_s_col_stride * 8; - int tb_n_warps = thread_n_blocks / 4; - int act_s_col_tb_stride = act_s_col_warp_stride * tb_n_warps; - - // Zero-points sizes/strides - int zp_gl_stride = (prob_n / pack_factor) / 4; - constexpr int zp_sh_stride = ((16 * thread_n_blocks) / pack_factor) / 4; - constexpr int zp_tb_groups = s_tb_groups; - constexpr int zp_sh_stage = has_zp ? zp_tb_groups * zp_sh_stride : 0; - int zp_gl_rd_delta = zp_gl_stride; - - // Global A read index of current thread. - int a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) + - (threadIdx.x % a_gl_rd_delta_o); - a_gl_rd += a_gl_rd_delta_o * slice_row; - // Shared write index of current thread. - int a_sh_wr = a_sh_stride * (threadIdx.x / a_gl_rd_delta_o) + - (threadIdx.x % a_gl_rd_delta_o); - // Shared read index. - int a_sh_rd = - a_sh_stride * ((threadIdx.x % 32) % 16) + (threadIdx.x % 32) / 16; - a_sh_rd += 2 * ((threadIdx.x / 32) / (thread_n_blocks / 4)); - - int b_gl_rd = b_gl_stride * (threadIdx.x / b_sh_stride_threads) + - (threadIdx.x % b_sh_stride_threads) * b_thread_vecs; - b_gl_rd += b_sh_stride * slice_col; - b_gl_rd += b_gl_rd_delta_o * slice_row; - int b_sh_wr = threadIdx.x * b_thread_vecs; - int b_sh_rd = threadIdx.x * b_thread_vecs; - - // For act_order - constexpr int k_iter_size = tb_k / b_sh_wr_iters; - int slice_k_start = tb_k * slice_row; - int slice_k_finish = slice_k_start + tb_k * slice_iters; - int slice_k_start_shared_fetch = slice_k_start; - int slice_n_offset = act_s_col_tb_stride * slice_col; - - // No act_order - int s_gl_rd; - if constexpr (!has_act_order) { - if constexpr (group_blocks == -1) { - s_gl_rd = s_sh_stride * slice_col + threadIdx.x; - } else { - s_gl_rd = s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) + - s_sh_stride * slice_col + threadIdx.x; - } - } - int s_sh_wr = threadIdx.x; - bool s_sh_wr_pred = threadIdx.x < s_sh_stride; - - // Zero-points - int zp_gl_rd; - if constexpr (has_zp) { - if constexpr (group_blocks == -1) { - zp_gl_rd = zp_sh_stride * slice_col + threadIdx.x; - } else { - zp_gl_rd = zp_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) + - zp_sh_stride * slice_col + threadIdx.x; - } - } - int zp_sh_wr = threadIdx.x; - bool zp_sh_wr_pred = threadIdx.x < zp_sh_stride; - - // We use a different scale layout for grouped and column-wise quantization as - // we scale a `half2` tile in column-major layout in the former and in - // row-major in the latter case. - int s_sh_rd; - if constexpr (group_blocks != -1) - s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + - (threadIdx.x % 32) / 4; - else - s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + - (threadIdx.x % 32) % 4; - - // Zero-points have the same read layout as the scales - // (without column-wise case) - constexpr int num_col_threads = 8; - constexpr int num_row_threads = 4; - constexpr int num_ints_per_thread = 8 / pack_factor; - int zp_sh_rd; - if constexpr (has_zp) { - zp_sh_rd = num_ints_per_thread * num_col_threads * - ((threadIdx.x / 32) % (thread_n_blocks / 4)) + - num_ints_per_thread * ((threadIdx.x % 32) / num_row_threads); - } - - // Precompute which thread should not read memory in which iterations; this is - // needed if there are more threads than required for a certain tilesize or - // when the batchsize is not a multiple of 16. - bool a_sh_wr_pred[a_sh_wr_iters]; - #pragma unroll - for (int i = 0; i < a_sh_wr_iters; i++) - a_sh_wr_pred[i] = a_sh_wr_delta * i + a_sh_wr < a_sh_stride * prob_m; - - // To ensure that writing and reading A tiles to/from shared memory, the - // latter in fragment format, is fully bank conflict free, we need to use a - // rather fancy XOR-based layout. The key here is that neither reads nor - // writes of the 16-byte `int4` blocks of 8 consecutive threads involve the - // same shared memory banks. Further, it seems (based on NSight-Compute) that - // each warp must also write a consecutive memory segment? - auto transform_a = [&](int i) { - int row = i / a_gl_rd_delta_o; - return a_gl_rd_delta_o * row + (i % a_gl_rd_delta_o) ^ row; - }; - // Since the computation of this remapping is non-trivial and, due to our main - // loop unrolls, all shared memory accesses are static, we simply precompute - // both transformed reads and writes. - int a_sh_wr_trans[a_sh_wr_iters]; - #pragma unroll - for (int i = 0; i < a_sh_wr_iters; i++) - a_sh_wr_trans[i] = transform_a(a_sh_wr_delta * i + a_sh_wr); - int a_sh_rd_trans[b_sh_wr_iters][thread_m_blocks]; - #pragma unroll - for (int i = 0; i < b_sh_wr_iters; i++) { - #pragma unroll - for (int j = 0; j < thread_m_blocks; j++) - a_sh_rd_trans[i][j] = - transform_a(a_sh_rd_delta_o * i + a_sh_rd_delta_i * j + a_sh_rd); - } - - // Since B-accesses have non-constant stride they have to be computed at - // runtime; we break dependencies between subsequent accesses with a tile by - // maintining multiple pointers (we have enough registers), a tiny - // optimization. - const int4* B_ptr[b_sh_wr_iters]; - #pragma unroll - for (int i = 0; i < b_sh_wr_iters; i++) - B_ptr[i] = B + b_gl_rd_delta_i * i + b_gl_rd; - - extern __shared__ int4 sh[]; - // Shared memory storage for global fetch pipelines. - int4* sh_a = sh; - int4* sh_b = sh_a + (stages * a_sh_stage); - int4* sh_g_idx = sh_b + (stages * b_sh_stage); - int4* sh_zp = sh_g_idx + (stages * g_idx_stage); - int4* sh_s = sh_zp + (stages * zp_sh_stage); - - // Register storage for double buffer of shared memory reads. - FragA frag_a[2][thread_m_blocks]; - I4 frag_b_quant[2][b_thread_vecs]; - FragC frag_c[thread_m_blocks][4][2]; - FragS frag_s[2][4]; // No act-order - FragS act_frag_s[2][4][4]; // For act-order - int frag_qzp[2][num_ints_per_thread]; // Zero-points - FragZP frag_zp; // Zero-points in fp16 - - // Zero accumulators. - auto zero_accums = [&]() { - #pragma unroll - for (int i = 0; i < thread_m_blocks * 4 * 2 * 4; i++) - reinterpret_cast(frag_c)[i] = 0; - }; - - int sh_first_group_id = -1; - int sh_num_groups = -1; - constexpr int sh_max_num_groups = 32; - - auto fetch_scales_to_shared = [&](bool is_async, int first_group_id, - int last_group_id) { - sh_first_group_id = first_group_id; - sh_num_groups = last_group_id - first_group_id + 1; - - if (sh_num_groups < sh_max_num_groups) { - sh_num_groups = sh_max_num_groups; - } - - if (sh_first_group_id + sh_num_groups > num_groups) { - sh_num_groups = num_groups - sh_first_group_id; - } - - int row_offset = first_group_id * s_gl_stride; - - if (is_async) { - for (int i = 0; i < sh_num_groups; i++) { - if (threadIdx.x < s_sh_stride) { - cp_async4_pred(&sh_s[(i * s_sh_stride) + threadIdx.x], - &scales_ptr[row_offset + (i * s_gl_stride) + - slice_n_offset + threadIdx.x]); - } - } - } else { - for (int i = 0; i < sh_num_groups; i++) { - if (threadIdx.x < s_sh_stride) { - sh_s[(i * s_sh_stride) + threadIdx.x] = - scales_ptr[row_offset + (i * s_gl_stride) + slice_n_offset + - threadIdx.x]; - } - } - } - }; - // Asynchronously fetch the next A, B and s tile from global to the next - // shared memory pipeline location. - auto fetch_to_shared = [&](int pipe, int a_off, bool pred = true) { - if (pred) { - int4* sh_a_stage = sh_a + a_sh_stage * pipe; - #pragma unroll - for (int i = 0; i < a_sh_wr_iters; i++) { - cp_async4_pred( - &sh_a_stage[a_sh_wr_trans[i]], - &A[a_gl_rd_delta_i * i + a_gl_rd + a_gl_rd_delta_o * a_off], - a_sh_wr_pred[i]); - } - int4* sh_b_stage = sh_b + b_sh_stage * pipe; - #pragma unroll - for (int i = 0; i < b_sh_wr_iters; i++) { - #pragma unroll - for (int j = 0; j < b_thread_vecs; j++) { - cp_async4(&sh_b_stage[b_sh_wr_delta * i + b_sh_wr + j], B_ptr[i] + j); - } - - B_ptr[i] += b_gl_rd_delta_o; - } - - if constexpr (has_act_order) { - // Fetch g_idx thread-block portion - int full_pipe = a_off; - int cur_k = slice_k_start_shared_fetch + tb_k * full_pipe; - if (cur_k < prob_k && cur_k < slice_k_finish) { - int4* sh_g_idx_stage = sh_g_idx + g_idx_stage * pipe; - - int4 const* cur_g_idx_stage_ptr = - reinterpret_cast(&g_idx[cur_k]); - - if (threadIdx.x < g_idx_stage) { - cp_async4_pred(&sh_g_idx_stage[threadIdx.x], - &cur_g_idx_stage_ptr[threadIdx.x]); - } - } - } else { - if constexpr (group_blocks != -1) { - int4* sh_s_stage = sh_s + s_sh_stage * pipe; - - if constexpr (group_blocks >= thread_k_blocks) { - // Only fetch scales if this tile starts a new group - if (pipe % (group_blocks / thread_k_blocks) == 0) { - if (s_sh_wr_pred) { - cp_async4(&sh_s_stage[s_sh_wr], &scales_ptr[s_gl_rd]); - } - s_gl_rd += s_gl_rd_delta; - } - } else { - for (int i = 0; i < s_tb_groups; i++) { - if (s_sh_wr_pred) { - cp_async4(&sh_s_stage[i * s_sh_stride + s_sh_wr], - &scales_ptr[s_gl_rd]); - } - s_gl_rd += s_gl_rd_delta; - } - } - } - - if constexpr (has_zp && group_blocks != -1) { - int4* sh_zp_stage = sh_zp + zp_sh_stage * pipe; - - if constexpr (group_blocks >= thread_k_blocks) { - // Only fetch zero-points if this tile starts a new group - if (pipe % (group_blocks / thread_k_blocks) == 0) { - if (zp_sh_wr_pred) { - cp_async4(&sh_zp_stage[zp_sh_wr], &zp_ptr[zp_gl_rd]); - } - zp_gl_rd += zp_gl_rd_delta; - } - } else { - for (int i = 0; i < zp_tb_groups; i++) { - if (zp_sh_wr_pred) { - cp_async4(&sh_zp_stage[i * zp_sh_stride + zp_sh_wr], - &zp_ptr[zp_gl_rd]); - } - zp_gl_rd += zp_gl_rd_delta; - } - } - } - } - } - // Insert a fence even when we are winding down the pipeline to ensure that - // waiting is also correct at this point. - cp_async_fence(); - }; - - auto fetch_zp_to_shared = [&]() { - if (zp_sh_wr_pred) { - cp_async4(&sh_zp[zp_sh_wr], &zp_ptr[zp_gl_rd]); - } - }; - - // Wait until the next thread tile has been loaded to shared memory. - auto wait_for_stage = [&]() { - // We only have `stages - 2` active fetches since we are double buffering - // and can only issue the next fetch when it is guaranteed that the previous - // shared memory load is fully complete (as it may otherwise be - // overwritten). - cp_async_wait(); - __syncthreads(); - }; - - // Load the next sub-tile from the current location in the shared memory pipe - // into the current register buffer. - auto fetch_to_registers = [&](int k, int pipe) { - int4* sh_a_stage = sh_a + a_sh_stage * pipe; - #pragma unroll - for (int i = 0; i < thread_m_blocks; i++) - ldsm4(frag_a[k % 2][i], - &sh_a_stage[a_sh_rd_trans[k % b_sh_wr_iters][i]]); - int4* sh_b_stage = sh_b + b_sh_stage * pipe; - - #pragma unroll - for (int i = 0; i < b_thread_vecs; i++) { - frag_b_quant[k % 2][i] = *reinterpret_cast( - &sh_b_stage[b_sh_rd_delta * (k % b_sh_wr_iters) + b_sh_rd + i]); - } - }; - - bool is_same_group[stages]; - int same_group_id[stages]; - - auto init_same_group = [&](int pipe) { - if constexpr (!has_act_order) { - is_same_group[pipe] = false; - same_group_id[pipe] = 0; - return; - } - - int4* sh_g_idx_stage = sh_g_idx + g_idx_stage * pipe; - int* sh_g_idx_int_ptr = reinterpret_cast(sh_g_idx_stage); - - int group_id_1 = sh_g_idx_int_ptr[0]; - int group_id_2 = sh_g_idx_int_ptr[tb_k - 1]; - - is_same_group[pipe] = group_id_1 == group_id_2; - same_group_id[pipe] = group_id_1; - }; - - auto fetch_scales_to_registers = [&](int k, int full_pipe) { - int pipe = full_pipe % stages; - - if constexpr (!has_act_order) { - // No act-order case - if constexpr (group_blocks != -1) { - if constexpr (group_blocks >= thread_k_blocks) { - int4* sh_s_stage = - sh_s + s_sh_stage * ((group_blocks / thread_k_blocks) * - (pipe / (group_blocks / thread_k_blocks))); - reinterpret_cast(&frag_s[k % 2])[0] = sh_s_stage[s_sh_rd]; - } else { - int warp_id = threadIdx.x / 32; - int n_warps = thread_n_blocks / 4; - - int warp_row = warp_id / n_warps; - - int cur_k = warp_row * 16; - cur_k += k_iter_size * (k % b_sh_wr_iters); - - int k_blocks = cur_k / 16; - int cur_group_id = k_blocks / group_blocks; - - int4* sh_s_stage = sh_s + s_sh_stage * pipe; - - reinterpret_cast(&frag_s[k % 2])[0] = - sh_s_stage[s_sh_rd + cur_group_id * s_sh_stride]; - } - } - - return; - } - - // Act-order case - - // Determine K of the "current" thread-block - int cur_k = slice_k_start + tb_k * full_pipe; - if (cur_k >= prob_k || cur_k >= slice_k_finish) { - return; - } - - // Reset (to current thread-block) since we read g_idx portion from the - // shared memory - cur_k = 0; - - // Progress to current iteration - cur_k += k_iter_size * (k % b_sh_wr_iters); - - // Determine "position" inside the thread-block (based on warp and - // thread-id) - int warp_id = threadIdx.x / 32; - int n_warps = - thread_n_blocks / 4; // Each warp processes 4 16-size tiles over N - - int warp_row = warp_id / n_warps; - int warp_col = warp_id % n_warps; - - cur_k += warp_row * 16; - - int th_id = threadIdx.x % 32; - cur_k += (th_id % 4) * 2; // Due to tensor-core layout for fp16 B matrix - - int s_col_shift = - /*slice_n_offset +*/ (act_s_col_warp_stride * warp_col) + - (th_id / 4) * act_s_col_stride; - - if (is_same_group[pipe]) { - if (k % 2 == 0) { - *(reinterpret_cast(&(act_frag_s[k % 2][0][0]))) = - sh_s[(same_group_id[pipe] - sh_first_group_id) * s_sh_stride + - s_col_shift]; - } else { - *(reinterpret_cast(&(act_frag_s[k % 2][0][0]))) = - *(reinterpret_cast(&(act_frag_s[(k - 1) % 2][0][0]))); - } - - for (int i = 1; i < 4; i++) { - *(reinterpret_cast(&(act_frag_s[k % 2][i][0]))) = - *(reinterpret_cast(&(act_frag_s[k % 2][0][0]))); - } - return; - } - - int4* sh_g_idx_stage = sh_g_idx + g_idx_stage * pipe; - int* sh_g_idx_int_ptr = reinterpret_cast(sh_g_idx_stage); - - constexpr int k_frag_offsets[4] = {0, 1, 8, - 9}; // Tensor core offsets per thread - - #pragma unroll - for (int i = 0; i < 4; i++) { - int actual_k = cur_k + k_frag_offsets[i]; - - int group_id = sh_g_idx_int_ptr[actual_k]; - int rel_group_id = group_id - sh_first_group_id; - - *(reinterpret_cast(&(act_frag_s[k % 2][i][0]))) = - sh_s[rel_group_id * s_sh_stride + s_col_shift]; - } - }; - - auto fetch_zp_to_registers = [&](int k, int full_pipe) { - if constexpr (!has_zp) { - return; - } - - int pipe = full_pipe % stages; - - if constexpr (group_blocks == -1) { - for (int i = 0; i < num_ints_per_thread; i++) { - frag_qzp[k % 2][i] = (reinterpret_cast(sh_zp))[zp_sh_rd + i]; - } - - } else if constexpr (group_blocks >= thread_k_blocks) { - int4* sh_zp_stage = - sh_zp + zp_sh_stage * ((group_blocks / thread_k_blocks) * - (pipe / (group_blocks / thread_k_blocks))); - for (int i = 0; i < num_ints_per_thread; i++) { - frag_qzp[k % 2][i] = - (reinterpret_cast(sh_zp_stage))[zp_sh_rd + i]; - } - } else { - int warp_id = threadIdx.x / 32; - int n_warps = thread_n_blocks / 4; - - int warp_row = warp_id / n_warps; - - int cur_k = warp_row * 16; - cur_k += k_iter_size * (k % b_sh_wr_iters); - - int k_blocks = cur_k / 16; - int cur_group_id = k_blocks / group_blocks; - - int4* sh_zp_stage = sh_zp + zp_sh_stage * pipe; - - sh_zp_stage += cur_group_id * zp_sh_stride; - - for (int i = 0; i < num_ints_per_thread; i++) { - frag_qzp[k % 2][i] = - (reinterpret_cast(sh_zp_stage))[zp_sh_rd + i]; - } - } - }; - - // Execute the actual tensor core matmul of a sub-tile. - auto matmul = [&](int k) { - if constexpr (has_zp) { - FragB frag_zp_0; - FragB frag_zp_1; - if constexpr (num_bits == 4) { - int zp_quant = frag_qzp[k % 2][0]; - int zp_quant_shift = zp_quant >> 8; - frag_zp_0 = dequant_4bit_zp(zp_quant); - frag_zp_1 = dequant_4bit_zp(zp_quant_shift); - - } else { - int zp_quant_0 = frag_qzp[k % 2][0]; - int zp_quant_1 = frag_qzp[k % 2][1]; - frag_zp_0 = dequant_8bit_zp(zp_quant_0); - frag_zp_1 = dequant_8bit_zp(zp_quant_1); - } - - frag_zp[0] = frag_zp_0[0]; - frag_zp[1] = frag_zp_0[1]; - frag_zp[2] = frag_zp_1[0]; - frag_zp[3] = frag_zp_1[1]; - } - - // We have the m dimension as the inner loop in order to encourage overlapping - // dequantization and matmul operations. - #pragma unroll - for (int j = 0; j < 4; j++) { - FragB frag_b0; - FragB frag_b1; - if constexpr (num_bits == 4) { - int b_quant = frag_b_quant[k % 2][0][j]; - int b_quant_shift = b_quant >> 8; - - if constexpr (has_zp) { - frag_b0 = dequant_4bit_zp(b_quant); - frag_b1 = dequant_4bit_zp(b_quant_shift); - - } else { - frag_b0 = dequant_4bit(b_quant); - frag_b1 = dequant_4bit(b_quant_shift); - } - - } else { - int* frag_b_quant_ptr = reinterpret_cast(frag_b_quant[k % 2]); - int b_quant_0 = frag_b_quant_ptr[j * 2 + 0]; - int b_quant_1 = frag_b_quant_ptr[j * 2 + 1]; - - if constexpr (has_zp) { - frag_b0 = dequant_8bit_zp(b_quant_0); - frag_b1 = dequant_8bit_zp(b_quant_1); - } else { - frag_b0 = dequant_8bit(b_quant_0); - frag_b1 = dequant_8bit(b_quant_1); - } - } - - // Apply zero-point to frag_b0 - if constexpr (has_zp) { - sub_zp(frag_b0, frag_zp[j], 0); - } - - // Apply scale to frag_b0 - if constexpr (has_act_order) { - scale4(frag_b0, act_frag_s[k % 2][0][j], - act_frag_s[k % 2][1][j], act_frag_s[k % 2][2][j], - act_frag_s[k % 2][3][j], 0); - } else { - if constexpr (group_blocks != -1) { - scale(frag_b0, frag_s[k % 2][j], 0); - } - } - - // Apply zero-point to frag_b1 - if constexpr (has_zp) { - sub_zp(frag_b1, frag_zp[j], 1); - } - - // Apply scale to frag_b1 - if constexpr (has_act_order) { - scale4(frag_b1, act_frag_s[k % 2][0][j], - act_frag_s[k % 2][1][j], act_frag_s[k % 2][2][j], - act_frag_s[k % 2][3][j], 1); - - } else { - if constexpr (group_blocks != -1) { - scale(frag_b1, frag_s[k % 2][j], 1); - } - } - - #pragma unroll - for (int i = 0; i < thread_m_blocks; i++) { - mma(frag_a[k % 2][i], frag_b0, frag_c[i][j][0]); - mma(frag_a[k % 2][i], frag_b1, frag_c[i][j][1]); - } - } - }; - - // Since we slice across the k dimension of a tile in order to increase the - // number of warps while keeping the n dimension of a tile reasonable, we have - // multiple warps that accumulate their partial sums of the same output - // location; which we have to reduce over in the end. We do in shared memory. - auto thread_block_reduce = [&]() { - constexpr int red_off = threads / b_sh_stride_threads / 2; - if (red_off >= 1) { - int red_idx = threadIdx.x / b_sh_stride_threads; - constexpr int red_sh_stride = b_sh_stride_threads * 4 * 2; - constexpr int red_sh_delta = b_sh_stride_threads; - int red_sh_rd = red_sh_stride * (threadIdx.x / b_sh_stride_threads) + - (threadIdx.x % b_sh_stride_threads); - - // Parallel logarithmic shared memory reduction. We make sure to avoid any - // unnecessary read or write iterations, e.g., for two warps we write only - // once by warp 1 and read only once by warp 0. - - #pragma unroll - for (int m_block = 0; m_block < thread_m_blocks; m_block++) { - #pragma unroll - for (int i = red_off; i > 0; i /= 2) { - if (i <= red_idx && red_idx < 2 * i) { - #pragma unroll - for (int j = 0; j < 4 * 2; j++) { - int red_sh_wr = - red_sh_delta * j + (red_sh_rd - red_sh_stride * i); - if (i < red_off) { - float* c_rd = - reinterpret_cast(&sh[red_sh_delta * j + red_sh_rd]); - float* c_wr = reinterpret_cast(&sh[red_sh_wr]); - #pragma unroll - for (int k = 0; k < 4; k++) - reinterpret_cast(frag_c)[4 * 2 * m_block + j][k] += - c_rd[k] + c_wr[k]; - } - sh[red_sh_wr] = - reinterpret_cast(&frag_c)[4 * 2 * m_block + j]; - } - } - __syncthreads(); - } - if (red_idx == 0) { - #pragma unroll - for (int i = 0; i < 4 * 2; i++) { - float* c_rd = - reinterpret_cast(&sh[red_sh_delta * i + red_sh_rd]); - #pragma unroll - for (int j = 0; j < 4; j++) - reinterpret_cast(frag_c)[4 * 2 * m_block + i][j] += - c_rd[j]; - } - } - __syncthreads(); - } - } - }; - - // Since multiple threadblocks may process parts of the same column slice, we - // finally have to globally reduce over the results. As the striped - // partitioning minimizes the number of such reductions and our outputs are - // usually rather small, we perform this reduction serially in L2 cache. - auto global_reduce = [&](bool first = false, bool last = false) { - // We are very careful here to reduce directly in the output buffer to - // maximize L2 cache utilization in this step. To do this, we write out - // results in FP16 (but still reduce with FP32 compute). - constexpr int active_threads = 32 * thread_n_blocks / 4; - if (threadIdx.x < active_threads) { - int c_gl_stride = prob_n / 8; - int c_gl_wr_delta_o = 8 * c_gl_stride; - int c_gl_wr_delta_i = 4 * (active_threads / 32); - int c_gl_wr = c_gl_stride * ((threadIdx.x % 32) / 4) + - 4 * (threadIdx.x / 32) + threadIdx.x % 4; - c_gl_wr += (2 * thread_n_blocks) * slice_col; - constexpr int c_sh_wr_delta = active_threads; - int c_sh_wr = threadIdx.x; - - int row = (threadIdx.x % 32) / 4; - - if (!first) { - // Interestingly, doing direct global accesses here really seems to mess up - // the compiler and lead to slowdowns, hence we also use async-copies even - // though these fetches are not actually asynchronous. - #pragma unroll - for (int i = 0; i < thread_m_blocks * 4; i++) { - cp_async4_pred( - &sh[c_sh_wr + c_sh_wr_delta * i], - &C[c_gl_wr + c_gl_wr_delta_o * (i / 2) + - c_gl_wr_delta_i * (i % 2)], - i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m); - } - cp_async_fence(); - cp_async_wait<0>(); - } - - #pragma unroll - for (int i = 0; i < thread_m_blocks * 4; i++) { - if (i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m) { - if (!first) { - int4 c_red = sh[c_sh_wr + i * c_sh_wr_delta]; - #pragma unroll - for (int j = 0; j < 2 * 4; j++) { - reinterpret_cast( - &frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4)] += - Dtype::num2float(reinterpret_cast(&c_red)[j]); - } - } - if (!last) { - int4 c; - #pragma unroll - for (int j = 0; j < 2 * 4; j++) { - reinterpret_cast(&c)[j] = - Dtype::float2num(reinterpret_cast( - &frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4)]); - } - C[c_gl_wr + c_gl_wr_delta_o * (i / 2) + c_gl_wr_delta_i * (i % 2)] = - c; - } - } - } - } - }; - - // Write out the reduce final result in the correct layout. We only actually - // reshuffle matrix fragments in this step, the reduction above is performed - // in fragment layout. - auto write_result = [&]() { - int c_gl_stride = prob_n / 8; - constexpr int c_sh_stride = 2 * thread_n_blocks + 1; - int c_gl_wr_delta = c_gl_stride * (threads / (2 * thread_n_blocks)); - constexpr int c_sh_rd_delta = - c_sh_stride * (threads / (2 * thread_n_blocks)); - - int c_gl_wr = c_gl_stride * (threadIdx.x / (2 * thread_n_blocks)) + - (threadIdx.x % (2 * thread_n_blocks)); - c_gl_wr += (2 * thread_n_blocks) * slice_col; - int c_sh_wr = - (4 * c_sh_stride) * ((threadIdx.x % 32) / 4) + (threadIdx.x % 32) % 4; - c_sh_wr += 32 * (threadIdx.x / 32); - int c_sh_rd = c_sh_stride * (threadIdx.x / (2 * thread_n_blocks)) + - (threadIdx.x % (2 * thread_n_blocks)); - - int c_gl_wr_end = c_gl_stride * prob_m; - - // We first reorder in shared memory to guarantee the most efficient final - // global write patterns - auto write = [&](int idx, float c0, float c1, FragS& s) { - scalar_t2 res = - Dtype::nums2num2(Dtype::float2num(c0), Dtype::float2num(c1)); - - // For per-column quantization we finally apply the scale here (only for - // 4-bit) - if constexpr (!has_act_order && group_blocks == -1 && num_bits == 4) { - res = __hmul2(res, s[0]); - } - - ((scalar_t2*)sh)[idx] = res; - }; - - if (threadIdx.x / 32 < thread_n_blocks / 4) { - #pragma unroll - for (int i = 0; i < thread_m_blocks; i++) { - #pragma unroll - for (int j = 0; j < 4; j++) { - int wr = c_sh_wr + 8 * j; - write(wr + (4 * c_sh_stride) * 0 + 0, frag_c[i][j][0][0], - frag_c[i][j][0][1], frag_s[j / 2][2 * (j % 2) + 0]); - write(wr + (4 * c_sh_stride) * 8 + 0, frag_c[i][j][0][2], - frag_c[i][j][0][3], frag_s[j / 2][2 * (j % 2) + 0]); - write(wr + (4 * c_sh_stride) * 0 + 4, frag_c[i][j][1][0], - frag_c[i][j][1][1], frag_s[j / 2][2 * (j % 2) + 1]); - write(wr + (4 * c_sh_stride) * 8 + 4, frag_c[i][j][1][2], - frag_c[i][j][1][3], frag_s[j / 2][2 * (j % 2) + 1]); - } - c_sh_wr += 16 * (4 * c_sh_stride); - } - } - __syncthreads(); - - #pragma unroll - for (int i = 0; - i < div_ceil(16 * thread_m_blocks, threads / (2 * thread_n_blocks)); - i++) { - if (c_gl_wr < c_gl_wr_end) { - C[c_gl_wr] = sh[c_sh_rd]; - c_gl_wr += c_gl_wr_delta; - c_sh_rd += c_sh_rd_delta; - } - } - }; - - // Start global fetch and register load pipelines. - auto start_pipes = [&]() { - - #pragma unroll - for (int i = 0; i < stages - 1; i++) { - if (has_act_order && i == 0) { - int last_g_idx = slice_k_start + stages * tb_k * 2; - if (last_g_idx >= prob_k) { - last_g_idx = prob_k - 1; - } - fetch_scales_to_shared(true, g_idx[slice_k_start], g_idx[last_g_idx]); - } - - if constexpr (has_zp && group_blocks == -1) { - if (i == 0) { - fetch_zp_to_shared(); - } - } - fetch_to_shared(i, i, i < slice_iters); - } - - zero_accums(); - wait_for_stage(); - init_same_group(0); - fetch_to_registers(0, 0); - fetch_scales_to_registers(0, 0); - fetch_zp_to_registers(0, 0); - a_gl_rd += a_gl_rd_delta_o * (stages - 1); - slice_k_start_shared_fetch += tb_k * (stages - 1); - }; - if (slice_iters) { - start_pipes(); - } - - // Main loop. - while (slice_iters) { - // We unroll over both the global fetch and the register load pipeline to - // ensure all shared memory accesses are static. Note that both pipelines - // have even length meaning that the next iteration will always start at - // index 0. - - #pragma unroll - for (int pipe = 0; pipe < stages;) { - #pragma unroll - for (int k = 0; k < b_sh_wr_iters; k++) { - fetch_to_registers(k + 1, pipe % stages); - fetch_scales_to_registers(k + 1, pipe); - fetch_zp_to_registers(k + 1, pipe); - if (k == b_sh_wr_iters - 2) { - fetch_to_shared((pipe + stages - 1) % stages, pipe, - slice_iters >= stages); - pipe++; - wait_for_stage(); - init_same_group(pipe % stages); - } - matmul(k); - } - slice_iters--; - if (slice_iters == 0) { - break; - } - } - - a_gl_rd += a_gl_rd_delta_o * stages; - slice_k_start += tb_k * stages; - slice_k_start_shared_fetch += tb_k * stages; - - if constexpr (has_act_order) { - int first_group_id = g_idx[slice_k_start]; - int last_g_idx = slice_k_start + stages * tb_k * 2; - if (last_g_idx >= prob_k) { - last_g_idx = prob_k - 1; - } - int last_group_id = g_idx[last_g_idx]; - if (last_group_id >= sh_first_group_id + sh_num_groups) { - fetch_scales_to_shared(false, first_group_id, last_group_id); - __syncthreads(); - } - } - - // Process results and, if necessary, proceed to the next column slice. - // While this pattern may not be the most readable, other ways of writing - // the loop seemed to noticeably worse performance after compilation. - if (slice_iters == 0) { - cp_async_wait<0>(); - bool last = slice_idx == slice_count - 1; - // For per-column scales, we only fetch them here in the final step before - // write-out - if constexpr (!has_act_order && group_blocks == -1) { - if constexpr (num_bits == 8) { - if (s_sh_wr_pred) { - cp_async4(&sh_s[s_sh_wr], &scales_ptr[s_gl_rd]); - } - cp_async_fence(); - } else { - if (last) { - if (s_sh_wr_pred) { - cp_async4(&sh_s[s_sh_wr], &scales_ptr[s_gl_rd]); - } - cp_async_fence(); - } - } - } - - thread_block_reduce(); - if constexpr (!has_act_order && group_blocks == -1) { - if constexpr (num_bits == 8) { - cp_async_wait<0>(); - __syncthreads(); - if (threadIdx.x / 32 < thread_n_blocks / 4) { - reinterpret_cast(&frag_s)[0] = sh_s[s_sh_rd + 0]; - reinterpret_cast(&frag_s)[1] = sh_s[s_sh_rd + 4]; - } - - } else { - if (last) { - cp_async_wait<0>(); - __syncthreads(); - if (threadIdx.x / 32 < thread_n_blocks / 4) { - reinterpret_cast(&frag_s)[0] = sh_s[s_sh_rd + 0]; - reinterpret_cast(&frag_s)[1] = sh_s[s_sh_rd + 4]; - } - } - } - } - - // For 8-bit channelwise, we apply the scale before the global reduction - // that converts the fp32 results to fp16 (so that we avoid possible - // overflow in fp16) - if constexpr (!has_act_order && group_blocks == -1 && num_bits == 8) { - if (threadIdx.x / 32 < thread_n_blocks / 4) { - #pragma unroll - for (int i = 0; i < thread_m_blocks; i++) { - #pragma unroll - for (int j = 0; j < 4; j++) { - scale_float( - reinterpret_cast(&frag_c[i][j][0][0]), - frag_s[j / 2][2 * (j % 2) + 0]); - scale_float( - reinterpret_cast(&frag_c[i][j][0][2]), - frag_s[j / 2][2 * (j % 2) + 0]); - - scale_float( - reinterpret_cast(&frag_c[i][j][1][0]), - frag_s[j / 2][2 * (j % 2) + 1]); - scale_float( - reinterpret_cast(&frag_c[i][j][1][2]), - frag_s[j / 2][2 * (j % 2) + 1]); - } - } - } - } - - if (slice_count > 1) { // only globally reduce if there is more than one - // block in a slice - barrier_acquire(&locks[slice_col], slice_idx); - global_reduce(slice_idx == 0, last); - barrier_release(&locks[slice_col], last); - } - if (last) // only the last block in a slice actually writes the result - write_result(); - slice_row = 0; - slice_col_par++; - slice_col++; - init_slice(); - if (slice_iters) { - a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) + - (threadIdx.x % a_gl_rd_delta_o); - #pragma unroll - for (int i = 0; i < b_sh_wr_iters; i++) - B_ptr[i] += b_sh_stride - b_gl_rd_delta_o * k_tiles; - if (slice_col == 0) { - #pragma unroll - for (int i = 0; i < b_sh_wr_iters; i++) B_ptr[i] -= b_gl_stride; - } - - // Update slice k/n for scales loading - if constexpr (has_act_order) { - slice_k_start = tb_k * slice_row; - slice_k_finish = slice_k_start + tb_k * slice_iters; - slice_k_start_shared_fetch = slice_k_start; - slice_n_offset = act_s_col_tb_stride * slice_col; - - } else { - s_gl_rd = s_sh_stride * slice_col + threadIdx.x; - zp_gl_rd = zp_sh_stride * slice_col + threadIdx.x; - } - - start_pipes(); - } - } - } -} - - #define __CALL_IF(NUM_BITS, THREAD_M_BLOCKS, THREAD_N_BLOCKS, \ - THREAD_K_BLOCKS, HAS_ACT_ORDER, HAS_ZP, GROUP_BLOCKS, \ - NUM_THREADS) \ - else if (num_bits == NUM_BITS && thread_m_blocks == THREAD_M_BLOCKS && \ - thread_n_blocks == THREAD_N_BLOCKS && \ - thread_k_blocks == THREAD_K_BLOCKS && \ - has_act_order == HAS_ACT_ORDER && has_zp == HAS_ZP && \ - group_blocks == GROUP_BLOCKS && num_threads == NUM_THREADS) { \ - cudaFuncSetAttribute( \ - Marlin, \ - cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \ - Marlin \ - <<>>( \ - A_ptr, B_ptr, C_ptr, s_ptr, zp_ptr, g_idx_ptr, num_groups, \ - prob_m, prob_n, prob_k, locks); \ - } - -typedef struct { - int thread_k; - int thread_n; - int num_threads; -} thread_config_t; - -typedef struct { - int max_m_blocks; - thread_config_t tb_cfg; -} exec_config_t; - -thread_config_t small_batch_thread_configs[] = { - // Ordered by priority - - // thread_k, thread_n, num_threads - {128, 128, 256}, - {64, 128, 128}, - {128, 64, 128}, -}; - -thread_config_t large_batch_thread_configs[] = { - // Ordered by priority - - // thread_k, thread_n, num_threads - {64, 256, 256}, - {64, 128, 128}, - {128, 64, 128}, - -}; - -int get_scales_cache_size(thread_config_t const& th_config, int prob_m, - int prob_n, int prob_k, int num_bits, int group_size, - bool has_act_order, bool is_k_full) { - bool cache_scales_chunk = has_act_order && !is_k_full; - - int tb_n = th_config.thread_n; - int tb_k = th_config.thread_k; - - // Get max scale groups per thread-block - int tb_groups; - if (group_size == -1) { - tb_groups = 1; - } else if (group_size == 0) { - tb_groups = div_ceil(tb_k, 32); // Worst case is 32 group size - } else { - tb_groups = div_ceil(tb_k, group_size); - } - - if (cache_scales_chunk) { - int load_groups = - tb_groups * pipe_stages * 2; // Chunk size is 2x pipeline over dim K - load_groups = max(load_groups, 32); // We load at least 32 scale groups - return load_groups * tb_n * 2; - - } else { - int tb_scales = tb_groups * tb_n * 2; - - return tb_scales * pipe_stages; - } -} - -bool is_valid_cache_size(thread_config_t const& th_config, int max_m_blocks, - int prob_m, int prob_n, int prob_k, int num_bits, - int scales_cache_size, int max_shared_mem) { - int pack_factor = 32 / num_bits; - - // Get B size - int tb_k = th_config.thread_k; - int tb_n = th_config.thread_n; - - int b_size = (tb_k * tb_n / pack_factor) * 4; - - // Get A size - int m_blocks = div_ceil(prob_m, 16); - int tb_max_m = 16; - - while (true) { - if (m_blocks >= max_m_blocks) { - tb_max_m *= max_m_blocks; - break; - } - - max_m_blocks--; - if (max_m_blocks == 0) { - TORCH_CHECK(false, "Unexpected m_blocks = ", m_blocks); - } - } - - int a_size = (tb_max_m * tb_k) * 2; - - float pipe_size = (a_size + b_size) * pipe_stages; - - TORCH_CHECK(max_shared_mem / 2 > scales_cache_size); // Sanity - - return pipe_size < 0.95f * (max_shared_mem - scales_cache_size); -} - -bool is_valid_config(thread_config_t const& th_config, int max_m_blocks, - int prob_m, int prob_n, int prob_k, int num_bits, - int group_size, bool has_act_order, bool is_k_full, - int max_shared_mem) { - // Sanity - if (th_config.thread_k == -1 || th_config.thread_n == -1 || - th_config.num_threads == -1) { - return false; - } - - // Verify K/N are divisible by thread K/N - if (prob_k % th_config.thread_k != 0 || prob_n % th_config.thread_n != 0) { - return false; - } - - // Verify min for thread K/N - if (th_config.thread_n < min_thread_n || th_config.thread_k < min_thread_k) { - return false; - } - - // num_threads must be at least 128 (= 4 warps) - if (th_config.num_threads < 128) { - return false; - } - - // Determine cache for scales - int scales_cache_size = - get_scales_cache_size(th_config, prob_m, prob_n, prob_k, num_bits, - group_size, has_act_order, is_k_full); - - // Check that pipeline fits into cache - if (!is_valid_cache_size(th_config, max_m_blocks, prob_m, prob_n, prob_k, - num_bits, scales_cache_size, max_shared_mem)) { - return false; - } - - return true; -} - -exec_config_t determine_thread_config(int prob_m, int prob_n, int prob_k, - int num_bits, int group_size, - bool has_act_order, bool is_k_full, - int max_shared_mem) { - int max_m_blocks = 4; - while (max_m_blocks > 0) { - if (prob_m <= 16) { - for (auto th_config : small_batch_thread_configs) { - if (is_valid_config(th_config, max_m_blocks, prob_m, prob_n, prob_k, - num_bits, group_size, has_act_order, is_k_full, - max_shared_mem)) { - return exec_config_t{max_m_blocks, th_config}; - } - } - } else { - for (auto th_config : large_batch_thread_configs) { - if (is_valid_config(th_config, max_m_blocks, prob_m, prob_n, prob_k, - num_bits, group_size, has_act_order, is_k_full, - max_shared_mem)) { - return exec_config_t{max_m_blocks, th_config}; - } - } - } - - max_m_blocks--; // Process less M blocks per invocation to reduce cache - // usage - } - - return exec_config_t{0, {-1, -1, -1}}; -} - - #define GPTQ_CALL_IF(NUM_BITS, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ - __CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, true, false, 0, NUM_THREADS) \ - __CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, true, false, 0, NUM_THREADS) \ - __CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, true, false, 0, NUM_THREADS) \ - __CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, true, false, 0, NUM_THREADS) \ - \ - __CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, false, -1, NUM_THREADS) \ - __CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, false, 2, NUM_THREADS) \ - __CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, false, 4, NUM_THREADS) \ - __CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, false, 8, NUM_THREADS) \ - \ - __CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, false, -1, NUM_THREADS) \ - __CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, false, 2, NUM_THREADS) \ - __CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, false, 4, NUM_THREADS) \ - __CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, false, 8, NUM_THREADS) \ - \ - __CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, false, -1, NUM_THREADS) \ - __CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, false, 2, NUM_THREADS) \ - __CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, false, 4, NUM_THREADS) \ - __CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, false, 8, NUM_THREADS) \ - \ - __CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, false, -1, NUM_THREADS) \ - __CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, false, 2, NUM_THREADS) \ - __CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, false, 4, NUM_THREADS) \ - __CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, false, 8, NUM_THREADS) - - #define AWQ_CALL_IF(NUM_BITS, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ - __CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS) \ - __CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS) \ - __CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS) \ - __CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS) \ - \ - __CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS) \ - __CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS) \ - __CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS) \ - __CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS) \ - \ - __CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS) \ - __CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS) \ - __CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS) \ - __CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS) \ - \ - __CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS) \ - __CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS) \ - __CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS) \ - __CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS) - -template -void marlin_mm_f16i4(const void* A, const void* B, void* C, void* s, void* zp, - void* g_idx, void* perm, void* a_tmp, int prob_m, - int prob_n, int prob_k, void* workspace, int num_bits, - bool has_act_order, bool is_k_full, bool has_zp, - int num_groups, int group_size, int dev, - cudaStream_t stream, int thread_k, int thread_n, int sms, - int max_par) { - TORCH_CHECK(num_bits == 4 || num_bits == 8, - "num_bits must be 4 or 8. Got = ", num_bits); - TORCH_CHECK(prob_m > 0 && prob_n > 0 && prob_k > 0, "Invalid MNK = [", prob_m, - ", ", prob_n, ", ", prob_k, "]"); - - int tot_m = prob_m; - int tot_m_blocks = div_ceil(tot_m, 16); - int pad = 16 * tot_m_blocks - tot_m; - - if (sms == -1) { - cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, dev); - } - - int max_shared_mem = 0; - cudaDeviceGetAttribute(&max_shared_mem, - cudaDevAttrMaxSharedMemoryPerBlockOptin, dev); - TORCH_CHECK(max_shared_mem > 0); - - // Set thread config - exec_config_t exec_cfg; - if (thread_k != -1 && thread_n != -1) { - // User-defined config - exec_cfg = - exec_config_t{4, thread_config_t{thread_k, thread_n, default_threads}}; - } else { - // Auto config - exec_cfg = - determine_thread_config(prob_m, prob_n, prob_k, num_bits, group_size, - has_act_order, is_k_full, max_shared_mem); - } - - TORCH_CHECK(exec_cfg.max_m_blocks > 0 && - is_valid_config(exec_cfg.tb_cfg, exec_cfg.max_m_blocks, - prob_m, prob_n, prob_k, num_bits, group_size, - has_act_order, is_k_full, max_shared_mem), - "Invalid thread config: max_m_blocks = ", exec_cfg.max_m_blocks, - ", thread_k = ", exec_cfg.tb_cfg.thread_k, - ", thread_n = ", exec_cfg.tb_cfg.thread_n, - ", num_threads = ", exec_cfg.tb_cfg.num_threads, " for MKN = [", - prob_m, ", ", prob_k, ", ", prob_n, "] and num_bits = ", num_bits, - ", group_size = ", group_size, - ", has_act_order = ", has_act_order, ", is_k_full = ", is_k_full, - ", max_shared_mem = ", max_shared_mem); - - int num_threads = exec_cfg.tb_cfg.num_threads; - thread_k = exec_cfg.tb_cfg.thread_k; - thread_n = exec_cfg.tb_cfg.thread_n; - - int thread_k_blocks = thread_k / 16; - int thread_n_blocks = thread_n / 16; - - int blocks = sms; - - TORCH_CHECK(prob_n % thread_n == 0, "prob_n = ", prob_n, - " is not divisible by thread_n = ", thread_n); - TORCH_CHECK(prob_k % thread_k == 0, "prob_k = ", prob_k, - " is not divisible by thread_k = ", thread_k); - - int group_blocks = 0; - if (has_act_order) { - if (is_k_full) { - TORCH_CHECK(group_size != -1); - group_blocks = group_size / 16; - TORCH_CHECK(prob_k % group_blocks == 0, "prob_k = ", prob_k, - " is not divisible by group_blocks = ", group_blocks); - } else { - TORCH_CHECK(group_size == 0); - group_blocks = 0; - } - - } else { - if (group_size == -1) { - group_blocks = -1; - } else { - group_blocks = group_size / 16; - TORCH_CHECK(prob_k % group_blocks == 0, "prob_k = ", prob_k, - " is not divisible by group_blocks = ", group_blocks); - } - } - - const int4* A_ptr = (const int4*)A; - const int4* B_ptr = (const int4*)B; - int4* C_ptr = (int4*)C; - const int4* s_ptr = (const int4*)s; - const int4* zp_ptr = (const int4*)zp; - const int* g_idx_ptr = (const int*)g_idx; - const int* perm_ptr = (const int*)perm; - int4* a_tmp_ptr = (int4*)a_tmp; - - int* locks = (int*)workspace; - - if (has_act_order) { - // Permute A columns - int block_rows = div_ceil(prob_m, blocks); - permute_cols_kernel<<>>( - A_ptr, perm_ptr, a_tmp_ptr, prob_m, prob_k, block_rows); - A_ptr = a_tmp_ptr; - } - - // If we have a full K, then we can run the non-act-order version of Marlin - // (since the weight rows are reordered by increasing group ids, and by having - // a full K, we have full original groups) - if (is_k_full) { - has_act_order = false; - } - - // Main loop - for (int i = 0; i < tot_m_blocks; i += exec_cfg.max_m_blocks) { - int thread_m_blocks = tot_m_blocks - i; - prob_m = tot_m - 16 * i; - int par = 1; - if (thread_m_blocks > exec_cfg.max_m_blocks) { - // Note that parallel > 1 currently only works for inputs without any - // padding - par = (16 * thread_m_blocks - pad) / (16 * exec_cfg.max_m_blocks); - if (par > max_par) par = max_par; - prob_m = (16 * exec_cfg.max_m_blocks) * par; - i += exec_cfg.max_m_blocks * (par - 1); - thread_m_blocks = exec_cfg.max_m_blocks; - } - - if (false) { - } - GPTQ_CALL_IF(4, 16, 4, 256) - GPTQ_CALL_IF(4, 8, 8, 256) - GPTQ_CALL_IF(4, 8, 4, 128) - GPTQ_CALL_IF(4, 4, 8, 128) - GPTQ_CALL_IF(8, 16, 4, 256) - GPTQ_CALL_IF(8, 8, 8, 256) - GPTQ_CALL_IF(8, 8, 4, 128) - GPTQ_CALL_IF(8, 4, 8, 128) - - AWQ_CALL_IF(4, 16, 4, 256) - AWQ_CALL_IF(4, 8, 8, 256) - AWQ_CALL_IF(4, 8, 4, 128) - AWQ_CALL_IF(4, 4, 8, 128) - AWQ_CALL_IF(8, 16, 4, 256) - AWQ_CALL_IF(8, 8, 8, 256) - AWQ_CALL_IF(8, 8, 4, 128) - AWQ_CALL_IF(8, 4, 8, 128) - else { - TORCH_CHECK(false, "Unsupported shapes: MNK = [", prob_m, ", ", prob_n, - ", ", prob_k, "]", ", has_act_order = ", has_act_order, - ", num_groups = ", num_groups, ", group_size = ", group_size, - ", thread_m_blocks = ", thread_m_blocks, - ", thread_n_blocks = ", thread_n_blocks, - ", thread_k_blocks = ", thread_k_blocks, - ", num_bits = ", num_bits); - } - - A_ptr += 16 * thread_m_blocks * (prob_k / 8) * par; - C_ptr += 16 * thread_m_blocks * (prob_n / 8) * par; - } -} - -} // namespace gptq_marlin - -torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, - torch::Tensor& b_scales, torch::Tensor& b_zeros, - torch::Tensor& g_idx, torch::Tensor& perm, - torch::Tensor& workspace, int64_t num_bits, - int64_t size_m, int64_t size_n, int64_t size_k, - bool is_k_full, bool has_zp) { - // Verify num_bits - TORCH_CHECK(num_bits == 4 || num_bits == 8, - "num_bits must be 4 or 8. Got = ", num_bits); - int pack_factor = 32 / num_bits; - - // Verify A - TORCH_CHECK(a.size(0) == size_m, "Shape mismatch: a.size(0) = ", a.size(0), - ", size_m = ", size_m); - TORCH_CHECK(a.size(1) == size_k, "Shape mismatch: a.size(1) = ", a.size(1), - ", size_k = ", size_k); - - // Verify B - TORCH_CHECK(size_k % marlin::tile_size == 0, "size_k = ", size_k, - " is not divisible by tile_size = ", marlin::tile_size); - TORCH_CHECK((size_k / marlin::tile_size) == b_q_weight.size(0), - "Shape mismatch: b_q_weight.size(0) = ", b_q_weight.size(0), - ", size_k = ", size_k, ", tile_size = ", marlin::tile_size); - TORCH_CHECK(b_q_weight.size(1) % marlin::tile_size == 0, - "b_q_weight.size(1) = ", b_q_weight.size(1), - " is not divisible by tile_size = ", marlin::tile_size); - int actual_size_n = (b_q_weight.size(1) / marlin::tile_size) * pack_factor; - TORCH_CHECK(size_n == actual_size_n, "size_n = ", size_n, - ", actual_size_n = ", actual_size_n); - - // Verify device and strides - TORCH_CHECK(a.device().is_cuda(), "A is not on GPU"); - TORCH_CHECK(a.is_contiguous(), "A is not contiguous"); - - TORCH_CHECK(b_q_weight.device().is_cuda(), "b_q_weight is not on GPU"); - TORCH_CHECK(b_q_weight.is_contiguous(), "b_q_weight is not contiguous"); - - TORCH_CHECK(b_scales.device().is_cuda(), "b_scales is not on GPU"); - TORCH_CHECK(b_scales.is_contiguous(), "b_scales is not contiguous"); - - TORCH_CHECK(b_zeros.device().is_cuda(), "b_zeros is not on GPU"); - TORCH_CHECK(b_zeros.is_contiguous(), "b_zeros is not contiguous"); - - TORCH_CHECK(g_idx.device().is_cuda(), "g_idx is not on GPU"); - TORCH_CHECK(g_idx.is_contiguous(), "g_idx is not contiguous"); - - TORCH_CHECK(perm.device().is_cuda(), "perm is not on GPU"); - TORCH_CHECK(perm.is_contiguous(), "perm is not contiguous"); - - // Alloc buffers - const at::cuda::OptionalCUDAGuard device_guard(device_of(a)); - auto options = torch::TensorOptions().dtype(a.dtype()).device(a.device()); - torch::Tensor c = torch::empty({size_m, size_n}, options); - torch::Tensor a_tmp = torch::empty({size_m, size_k}, options); - - // thread_k: `k` size of a thread_tile in `weights` (can usually be left as - // auto -1) - int thread_k = -1; - // thread_n: `n` size of a thread_tile in `weights` (can usually be left as - // auto -1) - int thread_n = -1; - // sms: number of SMs to use for the kernel (can usually be left as auto -1) - int sms = -1; - - // Verify g_idx and perm - TORCH_CHECK((g_idx.size(0) == 0 && perm.size(0) == 0) || - (g_idx.size(0) == size_k && perm.size(0) == size_k), - "Unexpected g_idx.size(0) = ", g_idx.size(0), - " and perm.size(0) = ", perm.size(0), - ", where size_k = ", size_k); - - // Detect groupsize and act_order - int num_groups = -1; - int group_size = -1; - bool has_act_order = g_idx.size(0) != 0; - - int rank = b_scales.sizes().size(); - TORCH_CHECK(rank == 2, "b_scales rank = ", rank, " is not 2"); - TORCH_CHECK(b_scales.size(1) == size_n, "b_scales dim 1 = ", b_scales.size(1), - " is not size_n = ", size_n); - num_groups = b_scales.size(0); - - if (has_act_order) { - if (is_k_full) { - TORCH_CHECK(num_groups > 1, "For act_order, num_groups must be > 1"); - TORCH_CHECK(size_k % num_groups == 0, "size_k = ", size_k, - ", is not divisible by num_groups = ", num_groups); - group_size = size_k / num_groups; - } else { - group_size = 0; - } - - } else { - if (num_groups > 1) { - TORCH_CHECK( - size_k % num_groups == 0, "size_k = ", size_k, - ", is not divisible by b_scales.size(0) = ", b_scales.size(0)); - group_size = size_k / num_groups; - } else { - group_size = -1; - } - } - - // Verify b_zeros - if (has_zp) { - int rank = b_zeros.sizes().size(); - TORCH_CHECK(rank == 2, "b_zeros rank = ", rank, " is not 2"); - TORCH_CHECK(b_zeros.size(0) == num_groups, - "b_zeros dim 0 = ", b_zeros.size(0), - " is not num_groups = ", num_groups); - TORCH_CHECK(b_zeros.size(1) == size_n / pack_factor, - "b_zeros dim 1 = ", b_scales.size(1), - " is not size_n / pack_factor = ", size_n / pack_factor); - } - - // Verify workspace size - TORCH_CHECK(size_n % marlin::min_thread_n == 0, "size_n = ", size_n, - ", is not divisible by min_thread_n = ", marlin::min_thread_n); - int min_workspace_size = (size_n / marlin::min_thread_n) * marlin::max_par; - TORCH_CHECK(workspace.numel() >= min_workspace_size, - "workspace.numel = ", workspace.numel(), - " is below min_workspace_size = ", min_workspace_size); - - int dev = a.get_device(); - if (a.scalar_type() == at::ScalarType::Half) { - marlin::marlin_mm_f16i4( - a.data_ptr(), b_q_weight.data_ptr(), c.data_ptr(), - b_scales.data_ptr(), b_zeros.data_ptr(), g_idx.data_ptr(), - perm.data_ptr(), a_tmp.data_ptr(), size_m, size_n, size_k, - workspace.data_ptr(), num_bits, has_act_order, is_k_full, has_zp, - num_groups, group_size, dev, at::cuda::getCurrentCUDAStream(dev), - thread_k, thread_n, sms, marlin::max_par); - } else if (a.scalar_type() == at::ScalarType::BFloat16) { - marlin::marlin_mm_f16i4( - a.data_ptr(), b_q_weight.data_ptr(), - c.data_ptr(), b_scales.data_ptr(), - b_zeros.data_ptr(), g_idx.data_ptr(), perm.data_ptr(), - a_tmp.data_ptr(), size_m, size_n, size_k, - workspace.data_ptr(), num_bits, has_act_order, is_k_full, has_zp, - num_groups, group_size, dev, at::cuda::getCurrentCUDAStream(dev), - thread_k, thread_n, sms, marlin::max_par); - } else { - TORCH_CHECK(false, "gpt_marlin_gemm only supports bfloat16 and float16"); - } - - return c; -} - -#endif diff --git a/server/marlin/marlin_kernels/gptq_marlin_repack.cu b/server/marlin/marlin_kernels/gptq_marlin_repack.cu deleted file mode 100644 index c71b1bf5..00000000 --- a/server/marlin/marlin_kernels/gptq_marlin_repack.cu +++ /dev/null @@ -1,344 +0,0 @@ -#include "marlin.cuh" - -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 - -namespace marlin { - -template -__global__ void gptq_marlin_repack_kernel( - uint32_t const* __restrict__ b_q_weight_ptr, - uint32_t const* __restrict__ perm_ptr, uint32_t* __restrict__ out_ptr, - int size_k, int size_n) {} - -} // namespace marlin - -torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm, - int64_t size_k, int64_t size_n, - int64_t num_bits) { - TORCH_CHECK_NOT_IMPLEMENTED( - false, "marlin_repack_from_gptq(..) requires CUDA_ARCH >= 8.0"); - return torch::empty({1, 1}); -} - -#else - -namespace marlin { - -template -__global__ void gptq_marlin_repack_kernel( - uint32_t const* __restrict__ b_q_weight_ptr, - uint32_t const* __restrict__ perm_ptr, uint32_t* __restrict__ out_ptr, - int size_k, int size_n) { - constexpr int pack_factor = 32 / num_bits; - - int k_tiles = size_k / tile_k_size; - int n_tiles = size_n / tile_n_size; - int block_k_tiles = div_ceil(k_tiles, gridDim.x); - - int start_k_tile = blockIdx.x * block_k_tiles; - if (start_k_tile >= k_tiles) { - return; - } - - int finish_k_tile = min(start_k_tile + block_k_tiles, k_tiles); - - // Wait until the next thread tile has been loaded to shared memory. - auto wait_for_stage = [&]() { - // We only have `stages - 2` active fetches since we are double buffering - // and can only issue the next fetch when it is guaranteed that the previous - // shared memory load is fully complete (as it may otherwise be - // overwritten). - cp_async_wait(); - __syncthreads(); - }; - - extern __shared__ int4 sh[]; - - constexpr int perm_size = tile_k_size / 4; - - int4* sh_perm_ptr = sh; - int4* sh_pipe_ptr = sh_perm_ptr; - if constexpr (has_perm) { - sh_pipe_ptr += perm_size; - } - - constexpr int tile_ints = tile_k_size / pack_factor; - - constexpr int stage_n_threads = tile_n_size / 4; - constexpr int stage_k_threads = has_perm ? tile_k_size : tile_ints; - constexpr int stage_size = stage_k_threads * stage_n_threads; - - auto load_perm_to_shared = [&](int k_tile_id) { - int first_k_int4 = (k_tile_id * tile_k_size) / 4; - - int4 const* perm_int4_ptr = reinterpret_cast(perm_ptr); - - if (threadIdx.x < perm_size) { - sh_perm_ptr[threadIdx.x] = perm_int4_ptr[first_k_int4 + threadIdx.x]; - } - __syncthreads(); - }; - - auto fetch_to_shared = [&](int pipe, int k_tile_id, int n_tile_id) { - if (n_tile_id >= n_tiles) { - cp_async_fence(); - return; - } - - int first_n = n_tile_id * tile_n_size; - - int4* sh_ptr = sh_pipe_ptr + stage_size * pipe; - - if constexpr (has_perm) { - if (threadIdx.x < stage_size) { - int k_id = threadIdx.x / stage_n_threads; - int n_id = threadIdx.x % stage_n_threads; - - uint32_t const* sh_perm_int_ptr = - reinterpret_cast(sh_perm_ptr); - - int src_k = sh_perm_int_ptr[k_id]; - int src_k_packed = src_k / pack_factor; - - cp_async4( - &sh_ptr[k_id * stage_n_threads + n_id], - reinterpret_cast(&( - b_q_weight_ptr[src_k_packed * size_n + first_n + (n_id * 4)]))); - } - - } else { - if (threadIdx.x < stage_size) { - int k_id = threadIdx.x / stage_n_threads; - int n_id = threadIdx.x % stage_n_threads; - - int first_k = k_tile_id * tile_k_size; - int first_k_packed = first_k / pack_factor; - - cp_async4(&sh_ptr[k_id * stage_n_threads + n_id], - reinterpret_cast( - &(b_q_weight_ptr[(first_k_packed + k_id) * size_n + - first_n + (n_id * 4)]))); - } - } - - cp_async_fence(); - }; - - auto repack_tile = [&](int pipe, int k_tile_id, int n_tile_id) { - if (n_tile_id >= n_tiles) { - return; - } - - int warp_id = threadIdx.x / 32; - int th_id = threadIdx.x % 32; - - if (warp_id >= 4) { - return; - } - - int tc_col = th_id / 4; - int tc_row = (th_id % 4) * 2; - - constexpr int tc_offsets[4] = {0, 1, 8, 9}; - - int cur_n = warp_id * 16 + tc_col; - - constexpr int sh_stride = 64; - constexpr uint32_t mask = (1 << num_bits) - 1; - - int4* sh_stage_ptr = sh_pipe_ptr + stage_size * pipe; - uint32_t* sh_stage_int_ptr = reinterpret_cast(sh_stage_ptr); - - uint32_t* sh_perm_int_ptr = reinterpret_cast(sh_perm_ptr); - - uint32_t vals[8]; - - if constexpr (has_perm) { - for (int i = 0; i < 4; i++) { - int k_idx = tc_row + tc_offsets[i]; - - uint32_t src_k = sh_perm_int_ptr[k_idx]; - uint32_t src_k_pos = src_k % pack_factor; - - uint32_t b1_val = sh_stage_int_ptr[k_idx * sh_stride + cur_n]; - uint32_t b1_cur_val = (b1_val >> (src_k_pos * num_bits)) & mask; - - uint32_t b2_val = sh_stage_int_ptr[k_idx * sh_stride + cur_n + 8]; - uint32_t b2_cur_val = (b2_val >> (src_k_pos * num_bits)) & mask; - - vals[i] = b1_cur_val; - vals[4 + i] = b2_cur_val; - } - - } else { - uint32_t b1_vals[tile_ints]; - uint32_t b2_vals[tile_ints]; - - #pragma unroll - for (int i = 0; i < tile_ints; i++) { - b1_vals[i] = sh_stage_int_ptr[cur_n + sh_stride * i]; - b2_vals[i] = sh_stage_int_ptr[cur_n + 8 + sh_stride * i]; - } - - #pragma unroll - for (int i = 0; i < 4; i++) { - int cur_elem = tc_row + tc_offsets[i]; - int cur_int = cur_elem / pack_factor; - int cur_pos = cur_elem % pack_factor; - - vals[i] = (b1_vals[cur_int] >> (cur_pos * num_bits)) & mask; - vals[4 + i] = (b2_vals[cur_int] >> (cur_pos * num_bits)) & mask; - } - } - - constexpr int tile_size = tile_k_size * tile_n_size / pack_factor; - int out_offset = (k_tile_id * n_tiles + n_tile_id) * tile_size; - - // Result of: - // https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h - if constexpr (num_bits == 4) { - constexpr int pack_idx[8] = {0, 2, 4, 6, 1, 3, 5, 7}; - - uint32_t res = 0; - #pragma unroll - for (int i = 0; i < 8; i++) { - res |= vals[pack_idx[i]] << (i * 4); - } - - out_ptr[out_offset + th_id * 4 + warp_id] = res; - - } else { - constexpr int pack_idx[4] = {0, 2, 1, 3}; - - uint32_t res1 = 0; - uint32_t res2 = 0; - #pragma unroll - for (int i = 0; i < 4; i++) { - res1 |= vals[pack_idx[i]] << (i * 8); - res2 |= vals[4 + pack_idx[i]] << (i * 8); - } - - out_ptr[out_offset + th_id * 8 + (warp_id * 2) + 0] = res1; - out_ptr[out_offset + th_id * 8 + (warp_id * 2) + 1] = res2; - } - }; - - auto start_pipes = [&](int k_tile_id, int n_tile_id) { - #pragma unroll - for (int pipe = 0; pipe < repack_stages - 1; pipe++) { - fetch_to_shared(pipe, k_tile_id, n_tile_id + pipe); - } - - wait_for_stage(); - }; - #pragma unroll - for (int k_tile_id = start_k_tile; k_tile_id < finish_k_tile; k_tile_id++) { - int n_tile_id = 0; - - if constexpr (has_perm) { - load_perm_to_shared(k_tile_id); - } - - start_pipes(k_tile_id, n_tile_id); - - while (n_tile_id < n_tiles) { - #pragma unroll - for (int pipe = 0; pipe < repack_stages; pipe++) { - fetch_to_shared((pipe + repack_stages - 1) % repack_stages, k_tile_id, - n_tile_id + pipe + repack_stages - 1); - repack_tile(pipe, k_tile_id, n_tile_id + pipe); - wait_for_stage(); - } - n_tile_id += repack_stages; - } - } -} - -} // namespace marlin - - #define CALL_IF(NUM_BITS, HAS_PERM) \ - else if (num_bits == NUM_BITS && has_perm == HAS_PERM) { \ - cudaFuncSetAttribute( \ - marlin::gptq_marlin_repack_kernel, \ - cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \ - marlin::gptq_marlin_repack_kernel \ - <<>>( \ - b_q_weight_ptr, perm_ptr, out_ptr, size_k, size_n); \ - } - -torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm, - int64_t size_k, int64_t size_n, - int64_t num_bits) { - // Verify compatibility with marlin tile of 16x64 - TORCH_CHECK(size_k % marlin::tile_k_size == 0, "size_k = ", size_k, - " is not divisible by tile_k_size = ", marlin::tile_k_size); - TORCH_CHECK(size_n % marlin::tile_n_size == 0, "size_n = ", size_n, - " is not divisible by tile_n_size = ", marlin::tile_n_size); - - TORCH_CHECK(num_bits == 4 || num_bits == 8, - "num_bits must be 4 or 8. Got = ", num_bits); - int const pack_factor = 32 / num_bits; - - // Verify B - TORCH_CHECK((size_k / pack_factor) == b_q_weight.size(0), - "Shape mismatch: b_q_weight.size(0) = ", b_q_weight.size(0), - ", size_k = ", size_k, ", pack_factor = ", pack_factor); - TORCH_CHECK(b_q_weight.size(1) == size_n, - "b_q_weight.size(1) = ", b_q_weight.size(1), - " is not size_n = ", size_n); - - // Verify device and strides - TORCH_CHECK(b_q_weight.device().is_cuda(), "b_q_weight is not on GPU"); - TORCH_CHECK(b_q_weight.is_contiguous(), "b_q_weight is not contiguous"); - TORCH_CHECK(b_q_weight.dtype() == at::kInt, "b_q_weight type is not kInt"); - - TORCH_CHECK(perm.device().is_cuda(), "perm is not on GPU"); - TORCH_CHECK(perm.is_contiguous(), "perm is not contiguous"); - TORCH_CHECK(perm.dtype() == at::kInt, "perm type is not at::kInt"); - - // Alloc buffers - const at::cuda::OptionalCUDAGuard device_guard(device_of(b_q_weight)); - auto options = torch::TensorOptions() - .dtype(b_q_weight.dtype()) - .device(b_q_weight.device()); - torch::Tensor out = torch::empty( - {size_k / marlin::tile_size, size_n * marlin::tile_size / pack_factor}, - options); - - // Detect if there is act_order - bool has_perm = perm.size(0) != 0; - - // Get ptrs - uint32_t const* b_q_weight_ptr = - reinterpret_cast(b_q_weight.data_ptr()); - uint32_t const* perm_ptr = reinterpret_cast(perm.data_ptr()); - uint32_t* out_ptr = reinterpret_cast(out.data_ptr()); - - // Get dev info - int dev = b_q_weight.get_device(); - cudaStream_t stream = at::cuda::getCurrentCUDAStream(dev); - int blocks; - cudaDeviceGetAttribute(&blocks, cudaDevAttrMultiProcessorCount, dev); - - int max_shared_mem = 0; - cudaDeviceGetAttribute(&max_shared_mem, - cudaDevAttrMaxSharedMemoryPerBlockOptin, dev); - TORCH_CHECK(max_shared_mem > 0); - - if (false) { - } - CALL_IF(4, false) - CALL_IF(4, true) - CALL_IF(8, false) - CALL_IF(8, true) - else { - TORCH_CHECK(false, "Unsupported repack config: num_bits = ", num_bits, - ", has_perm = ", has_perm); - } - - return out; -} - -#endif diff --git a/server/marlin/marlin_kernels/marlin.cuh b/server/marlin/marlin_kernels/marlin.cuh deleted file mode 100644 index 74ccbac5..00000000 --- a/server/marlin/marlin_kernels/marlin.cuh +++ /dev/null @@ -1,87 +0,0 @@ -#pragma once - -#include - -#include -#include -#include -#include -#include -#include - -namespace marlin { - -// Marlin params - -// 8 warps are a good choice since every SM has 4 schedulers and having more -// than 1 warp per schedule allows some more latency hiding. At the same time, -// we want relatively few warps to have many registers per warp and small tiles. -static constexpr int default_threads = 256; - -static constexpr int pipe_stages = - 4; // 4 pipeline stages fit into shared memory - -static constexpr int min_thread_n = 64; -static constexpr int min_thread_k = 64; - -static constexpr int tile_size = 16; -static constexpr int max_par = 16; - -// Repack params -static constexpr int repack_stages = 8; - -static constexpr int repack_threads = 256; - -static constexpr int tile_k_size = tile_size; -static constexpr int tile_n_size = tile_k_size * 4; - -// Helpers -template -struct Vec { - T elems[n]; - __device__ T& operator[](int i) { return elems[i]; } -}; - -using I4 = Vec; - -constexpr int div_ceil(int a, int b) { return (a + b - 1) / b; } - -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 -// No support for async -#else - -__device__ inline void cp_async4_pred(void* smem_ptr, const void* glob_ptr, - bool pred = true) { - const int BYTES = 16; - uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); - asm volatile( - "{\n" - " .reg .pred p;\n" - " setp.ne.b32 p, %0, 0;\n" - " @p cp.async.cg.shared.global [%1], [%2], %3;\n" - "}\n" ::"r"((int)pred), - "r"(smem), "l"(glob_ptr), "n"(BYTES)); -} - -__device__ inline void cp_async4(void* smem_ptr, const void* glob_ptr) { - const int BYTES = 16; - uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); - asm volatile( - "{\n" - " cp.async.cg.shared.global [%0], [%1], %2;\n" - "}\n" ::"r"(smem), - "l"(glob_ptr), "n"(BYTES)); -} - -__device__ inline void cp_async_fence() { - asm volatile("cp.async.commit_group;\n" ::); -} - -template -__device__ inline void cp_async_wait() { - asm volatile("cp.async.wait_group %0;\n" ::"n"(n)); -} - -#endif - -} // namespace marlin diff --git a/server/marlin/marlin_kernels/marlin_cuda_kernel.cu b/server/marlin/marlin_kernels/marlin_cuda_kernel.cu deleted file mode 100644 index 37339b84..00000000 --- a/server/marlin/marlin_kernels/marlin_cuda_kernel.cu +++ /dev/null @@ -1,1138 +0,0 @@ -/* - * Modified by Neural Magic - * Copyright (C) Marlin.2024 Elias Frantar - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include - -#include -#include -#include -#include -#include - -#include - -template -inline std::string str(T x) { - return std::to_string(x); -} - -namespace marlin_dense { - -constexpr int ceildiv(int a, int b) { return (a + b - 1) / b; } - -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 - -// Instances of `Vec` are used to organize groups of >>registers<<, as needed -// for instance as inputs to tensor core operations. Consequently, all -// corresponding index accesses must be compile-time constants, which is why we -// extensively use `#pragma unroll` throughout the kernel code to guarantee -// this. -template -struct Vec { - T elems[n]; - __device__ T& operator[](int i) { return elems[i]; } -}; - -using I4 = Vec; - -// Matrix fragments for tensor core instructions; their precise layout is -// documented here: -// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-mma-m16n8k16-with-floating-point-type -using FragA = Vec; -using FragB = Vec; -using FragC = Vec; -using FragS = Vec; // quantization scales - -// Predicated asynchronous global->shared copy; used for inputs A where we apply -// predication to handle batchsizes that are not multiples of 16. -__device__ inline void cp_async4_pred(void* smem_ptr, const void* glob_ptr, - bool pred = true) { - const int BYTES = 16; - uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); - asm volatile( - "{\n" - " .reg .pred p;\n" - " setp.ne.b32 p, %0, 0;\n" - " @p cp.async.cg.shared.global [%1], [%2], %3;\n" - "}\n" ::"r"((int)pred), - "r"(smem), "l"(glob_ptr), "n"(BYTES)); -} - -// Asynchronous global->shared copy -__device__ inline void cp_async4(void* smem_ptr, const void* glob_ptr) { - const int BYTES = 16; - uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); - asm volatile( - "{\n" - " cp.async.cg.shared.global [%0], [%1], %2;\n" - "}\n" ::"r"(smem), - "l"(glob_ptr), "n"(BYTES)); -} - -// Async copy fence. -__device__ inline void cp_async_fence() { - asm volatile("cp.async.commit_group;\n" ::); -} - -// Wait until at most `n` async copy stages are still pending. -template -__device__ inline void cp_async_wait() { - asm volatile("cp.async.wait_group %0;\n" ::"n"(n)); -} - -// m16n8k16 tensor core mma instruction with fp16 inputs and fp32 -// output/accumulation. -__device__ inline void mma(const FragA& a_frag, const FragB& frag_b, - FragC& frag_c) { - const uint32_t* a = reinterpret_cast(&a_frag); - const uint32_t* b = reinterpret_cast(&frag_b); - float* c = reinterpret_cast(&frag_c); - asm volatile( - "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 " - "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" - : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) - : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]), - "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3])); -} - -// Instruction for loading a full 16x16 matrix fragment of operand A from shared -// memory, directly in tensor core layout. -__device__ inline void ldsm4(FragA& frag_a, const void* smem_ptr) { - uint32_t* a = reinterpret_cast(&frag_a); - uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); - asm volatile("ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0,%1,%2,%3}, [%4];\n" - : "=r"(a[0]), "=r"(a[1]), "=r"(a[2]), "=r"(a[3]) - : "r"(smem)); -} - -// Lookup-table based 3-input logical operation; explicitly used for -// dequantization as the compiler does not seem to automatically recognize it in -// all cases. -template -__device__ inline int lop3(int a, int b, int c) { - int res; - asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" - : "=r"(res) - : "r"(a), "r"(b), "r"(c), "n"(lut)); - return res; -} - -// Efficiently dequantize an int32 value into a full B-fragment of 4 fp16 -// values. We mostly follow the strategy in the link below, with some small -// changes: -// https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h -__device__ inline FragB dequant(int q) { - const int LO = 0x000f000f; - const int HI = 0x00f000f0; - const int EX = 0x64006400; - // Guarantee that the `(a & b) | c` operations are LOP3s. - int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX); - int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX); - // We want signed int4 outputs, hence we fuse the `-8` symmetric zero point - // directly into `SUB` and `ADD`. - const int SUB = 0x64086408; - const int MUL = 0x2c002c00; - const int ADD = 0xd480d480; - FragB frag_b; - frag_b[0] = __hsub2(*reinterpret_cast(&lo), - *reinterpret_cast(&SUB)); - frag_b[1] = __hfma2(*reinterpret_cast(&hi), - *reinterpret_cast(&MUL), - *reinterpret_cast(&ADD)); - return frag_b; -} - -// Multiply dequantized values by the corresponding quantization scale; used -// only for grouped quantization. -__device__ inline void scale(FragB& frag_b, FragS& frag_s, int i) { - half2 s = __half2half2(reinterpret_cast<__half*>(&frag_s)[i]); - frag_b[0] = __hmul2(frag_b[0], s); - frag_b[1] = __hmul2(frag_b[1], s); -} - -// Wait until barrier reaches `count`, then lock for current threadblock. -__device__ inline void barrier_acquire(int* lock, int count) { - if (threadIdx.x == 0) { - int state = -1; - do - // Guarantee that subsequent writes by this threadblock will be visible - // globally. - asm volatile("ld.global.acquire.gpu.b32 %0, [%1];\n" - : "=r"(state) - : "l"(lock)); - while (state != count); - } - __syncthreads(); -} - -// Release barrier and increment visitation count. -__device__ inline void barrier_release(int* lock, bool reset = false) { - __syncthreads(); - if (threadIdx.x == 0) { - if (reset) { - lock[0] = 0; - return; - } - int val = 1; - // Make sure that all writes since acquiring this barrier are visible - // globally, while releasing the barrier. - asm volatile("fence.acq_rel.gpu;\n"); - asm volatile("red.relaxed.gpu.global.add.s32 [%0], %1;\n" - : - : "l"(lock), "r"(val)); - } -} - -template shared - // fetch pipeline - const int group_blocks = -1 // number of consecutive 16x16 blocks - // with a separate quantization scale - > -__global__ void Marlin( - const int4* __restrict__ A, // fp16 input matrix of shape mxk - const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn - int4* __restrict__ C, // fp16 output buffer of shape mxn - const int4* __restrict__ s, // fp16 quantization scales of shape - // (k/groupsize)xn - int prob_m, // batch dimension m - int prob_n, // output dimension n - int prob_k, // reduction dimension k - int* locks // extra global storage for barrier synchronization -) { - // Each threadblock processes one "stripe" of the B matrix with (roughly) the - // same size, which might involve multiple column "slices" (of width 16 * - // `thread_n_blocks`). Stripes are defined as shown in the 3x3 matrix 5 SM - // example: - // 0 1 3 - // 0 2 3 - // 1 2 4 - // While this kind of partitioning makes things somewhat more complicated, it - // ensures good utilization of all SMs for many kinds of shape and GPU - // configurations, while requiring as few slow global cross-threadblock - // reductions as possible. - - // For larger GEMMs we run multiple batchsize 64 versions in parallel for a - // better partitioning with less reductions - int parallel = 1; - if (prob_m > 16 * thread_m_blocks) { - parallel = prob_m / (16 * thread_m_blocks); - prob_m = 16 * thread_m_blocks; - } - - int k_tiles = prob_k / 16 / thread_k_blocks; - int n_tiles = prob_n / 16 / thread_n_blocks; - int iters = ceildiv(k_tiles * n_tiles * parallel, gridDim.x); - // Ensure that the number of tiles in each stripe is a multiple of the - // groupsize; this avoids an annoying special case where a stripe starts in - // the middle of group. - if (group_blocks != -1) - iters = (group_blocks / thread_k_blocks) * - ceildiv(iters, (group_blocks / thread_k_blocks)); - - int slice_row = (iters * blockIdx.x) % k_tiles; - int slice_col_par = (iters * blockIdx.x) / k_tiles; - int slice_col = slice_col_par; - int slice_iters; // number of threadblock tiles in the current slice - int slice_count = - 0; // total number of active threadblocks in the current slice - int slice_idx; // index of threadblock in current slice; numbered bottom to - // top - - // We can easily implement parallel problem execution by just remapping - // indices and advancing global pointers - if (slice_col_par >= n_tiles) { - A += (slice_col_par / n_tiles) * 16 * thread_m_blocks * prob_k / 8; - C += (slice_col_par / n_tiles) * 16 * thread_m_blocks * prob_n / 8; - locks += (slice_col_par / n_tiles) * n_tiles; - slice_col = slice_col_par % n_tiles; - } - - // Compute all information about the current slice which is required for - // synchronization. - auto init_slice = [&]() { - slice_iters = - iters * (blockIdx.x + 1) - (k_tiles * slice_col_par + slice_row); - if (slice_iters < 0 || slice_col_par >= n_tiles * parallel) slice_iters = 0; - if (slice_iters == 0) return; - if (slice_row + slice_iters > k_tiles) slice_iters = k_tiles - slice_row; - slice_count = 1; - slice_idx = 0; - int col_first = iters * ceildiv(k_tiles * slice_col_par, iters); - if (col_first <= k_tiles * (slice_col_par + 1)) { - int col_off = col_first - k_tiles * slice_col_par; - slice_count = ceildiv(k_tiles - col_off, iters); - if (col_off > 0) slice_count++; - int delta_first = iters * blockIdx.x - col_first; - if (delta_first < 0 || (col_off == 0 && delta_first == 0)) - slice_idx = slice_count - 1; - else { - slice_idx = slice_count - 1 - delta_first / iters; - if (col_off > 0) slice_idx--; - } - } - if (slice_col == n_tiles) { - A += 16 * thread_m_blocks * prob_k / 8; - C += 16 * thread_m_blocks * prob_n / 8; - locks += n_tiles; - slice_col = 0; - } - }; - init_slice(); - - int a_gl_stride = prob_k / 8; // stride of the A matrix in global memory - // We typically use `constexpr` to indicate that this value is a compile-time - // constant - constexpr int a_sh_stride = - 16 * thread_k_blocks / 8; // stride of an A matrix tile in shared memory - constexpr int a_gl_rd_delta_o = - 16 * thread_k_blocks / - 8; // delta between subsequent A tiles in global memory - int a_gl_rd_delta_i = - a_gl_stride * - (threads / a_gl_rd_delta_o); // between subsequent accesses within a tile - constexpr int a_sh_wr_delta = - a_sh_stride * - (threads / a_gl_rd_delta_o); // between shared memory writes - constexpr int a_sh_rd_delta_o = - 2 * ((threads / 32) / - (thread_n_blocks / 4)); // between shared memory tile reads - constexpr int a_sh_rd_delta_i = - a_sh_stride * 16; // within a shared memory tile - constexpr int a_sh_stage = - a_sh_stride * (16 * thread_m_blocks); // overall size of a tile - constexpr int a_sh_wr_iters = - ceildiv(a_sh_stage, - a_sh_wr_delta); // number of shared write iterations for a tile - - int b_gl_stride = 16 * prob_n / 32; - constexpr int b_sh_stride = 32 * thread_n_blocks / 4; - int b_gl_rd_delta_o = b_gl_stride * thread_k_blocks; - int b_gl_rd_delta_i = b_gl_stride * (threads / b_sh_stride); - constexpr int b_sh_wr_delta = threads; - constexpr int b_sh_rd_delta = threads; - constexpr int b_sh_stage = b_sh_stride * thread_k_blocks; - constexpr int b_sh_wr_iters = b_sh_stage / b_sh_wr_delta; - - int s_gl_stride = prob_n / 8; - constexpr int s_sh_stride = 16 * thread_n_blocks / 8; - constexpr int s_sh_stage = s_sh_stride; - int s_gl_rd_delta = s_gl_stride; - - // Global A read index of current thread. - int a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) + - (threadIdx.x % a_gl_rd_delta_o); - a_gl_rd += a_gl_rd_delta_o * slice_row; - // Shared write index of current thread. - int a_sh_wr = a_sh_stride * (threadIdx.x / a_gl_rd_delta_o) + - (threadIdx.x % a_gl_rd_delta_o); - // Shared read index. - int a_sh_rd = - a_sh_stride * ((threadIdx.x % 32) % 16) + (threadIdx.x % 32) / 16; - a_sh_rd += 2 * ((threadIdx.x / 32) / (thread_n_blocks / 4)); - - int b_gl_rd = - b_gl_stride * (threadIdx.x / b_sh_stride) + (threadIdx.x % b_sh_stride); - b_gl_rd += b_sh_stride * slice_col; - b_gl_rd += b_gl_rd_delta_o * slice_row; - int b_sh_wr = threadIdx.x; - int b_sh_rd = threadIdx.x; - - int s_gl_rd = s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) + - s_sh_stride * slice_col + threadIdx.x; - int s_sh_wr = threadIdx.x; - int s_sh_rd; - // We use a different scale layout for grouped and column-wise quantization as - // we scale a `half2` tile in column-major layout in the former and in - // row-major in the latter case. - if (group_blocks != -1) - s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + - (threadIdx.x % 32) / 4; - else - s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + - (threadIdx.x % 32) % 4; - - // Precompute which thread should not read memory in which iterations; this is - // needed if there are more threads than required for a certain tilesize or - // when the batchsize is not a multiple of 16. - bool a_sh_wr_pred[a_sh_wr_iters]; - #pragma unroll - for (int i = 0; i < a_sh_wr_iters; i++) - a_sh_wr_pred[i] = a_sh_wr_delta * i + a_sh_wr < a_sh_stride * prob_m; - bool s_sh_wr_pred = threadIdx.x < s_sh_stride; - - // To ensure that writing and reading A tiles to/from shared memory, the - // latter in fragment format, is fully bank conflict free, we need to use a - // rather fancy XOR-based layout. The key here is that neither reads nor - // writes of the 16-byte `int4` blocks of 8 consecutive threads involve the - // same shared memory banks. Further, it seems (based on NSight-Compute) that - // each warp must also write a consecutive memory segment? - auto transform_a = [&](int i) { - int row = i / a_gl_rd_delta_o; - return a_gl_rd_delta_o * row + (i % a_gl_rd_delta_o) ^ row; - }; - // Since the computation of this remapping is non-trivial and, due to our main - // loop unrolls, all shared memory accesses are static, we simply precompute - // both transformed reads and writes. - int a_sh_wr_trans[a_sh_wr_iters]; - #pragma unroll - for (int i = 0; i < a_sh_wr_iters; i++) - a_sh_wr_trans[i] = transform_a(a_sh_wr_delta * i + a_sh_wr); - int a_sh_rd_trans[b_sh_wr_iters][thread_m_blocks]; - #pragma unroll - for (int i = 0; i < b_sh_wr_iters; i++) { - #pragma unroll - for (int j = 0; j < thread_m_blocks; j++) - a_sh_rd_trans[i][j] = - transform_a(a_sh_rd_delta_o * i + a_sh_rd_delta_i * j + a_sh_rd); - } - - // Since B-accesses have non-constant stride they have to be computed at - // runtime; we break dependencies between subsequent accesses with a tile by - // maintining multiple pointers (we have enough registers), a tiny - // optimization. - const int4* B_ptr[b_sh_wr_iters]; - #pragma unroll - for (int i = 0; i < b_sh_wr_iters; i++) - B_ptr[i] = B + b_gl_rd_delta_i * i + b_gl_rd; - - extern __shared__ int4 sh[]; - // Shared memory storage for global fetch pipelines. - int4* sh_a = sh; - int4* sh_b = sh_a + (stages * a_sh_stage); - int4* sh_s = sh_b + (stages * b_sh_stage); - // Register storage for double buffer of shared memory reads. - FragA frag_a[2][thread_m_blocks]; - I4 frag_b_quant[2]; - FragC frag_c[thread_m_blocks][4][2]; - FragS frag_s[2][4]; - - // Zero accumulators. - auto zero_accums = [&]() { - #pragma unroll - for (int i = 0; i < thread_m_blocks * 4 * 2 * 4; i++) - reinterpret_cast(frag_c)[i] = 0; - }; - - // Asynchronously fetch the next A, B and s tile from global to the next - // shared memory pipeline location. - auto fetch_to_shared = [&](int pipe, int a_off, bool pred = true) { - if (pred) { - int4* sh_a_stage = sh_a + a_sh_stage * pipe; - #pragma unroll - for (int i = 0; i < a_sh_wr_iters; i++) { - cp_async4_pred( - &sh_a_stage[a_sh_wr_trans[i]], - &A[a_gl_rd_delta_i * i + a_gl_rd + a_gl_rd_delta_o * a_off], - a_sh_wr_pred[i]); - } - int4* sh_b_stage = sh_b + b_sh_stage * pipe; - #pragma unroll - for (int i = 0; i < b_sh_wr_iters; i++) { - cp_async4(&sh_b_stage[b_sh_wr_delta * i + b_sh_wr], B_ptr[i]); - B_ptr[i] += b_gl_rd_delta_o; - } - // Only fetch scales if this tile starts a new group - if (group_blocks != -1 && pipe % (group_blocks / thread_k_blocks) == 0) { - int4* sh_s_stage = sh_s + s_sh_stage * pipe; - if (s_sh_wr_pred) cp_async4(&sh_s_stage[s_sh_wr], &s[s_gl_rd]); - s_gl_rd += s_gl_rd_delta; - } - } - // Insert a fence even when we are winding down the pipeline to ensure that - // waiting is also correct at this point. - cp_async_fence(); - }; - - // Wait until the next thread tile has been loaded to shared memory. - auto wait_for_stage = [&]() { - // We only have `stages - 2` active fetches since we are double buffering - // and can only issue the next fetch when it is guaranteed that the previous - // shared memory load is fully complete (as it may otherwise be - // overwritten). - cp_async_wait(); - __syncthreads(); - }; - - // Load the next sub-tile from the current location in the shared memory pipe - // into the current register buffer. - auto fetch_to_registers = [&](int k, int pipe) { - // It may seem inefficient that we reload the groups for every sub-tile; - // however, this does not seem to be a significant bottleneck, while some - // theoretically better attempts have lead to bad instruction ordering by - // the compiler and correspondingly a noticeable drop in performance. - if (group_blocks != -1) { - int4* sh_s_stage = - sh_s + s_sh_stage * ((group_blocks / thread_k_blocks) * - (pipe / (group_blocks / thread_k_blocks))); - reinterpret_cast(&frag_s[k % 2])[0] = sh_s_stage[s_sh_rd]; - } - int4* sh_a_stage = sh_a + a_sh_stage * pipe; - #pragma unroll - for (int i = 0; i < thread_m_blocks; i++) - ldsm4(frag_a[k % 2][i], &sh_a_stage[a_sh_rd_trans[k % b_sh_wr_iters][i]]); - int4* sh_b_stage = sh_b + b_sh_stage * pipe; - frag_b_quant[k % 2] = *reinterpret_cast( - &sh_b_stage[b_sh_rd_delta * (k % b_sh_wr_iters) + b_sh_rd]); - }; - - // Execute the actual tensor core matmul of a sub-tile. - auto matmul = [&](int k) { - // We have the m dimension as the inner loop in order to encourage overlapping - // dequantization and matmul operations. - #pragma unroll - for (int j = 0; j < 4; j++) { - int b_quant = frag_b_quant[k % 2][j]; - int b_quant_shift = b_quant >> 8; - FragB frag_b0 = dequant(b_quant); - // If there are no groups, we can just scale the final output once and can - // avoid doing so for each weight. - if (group_blocks != -1) scale(frag_b0, frag_s[k % 2][j], 0); - FragB frag_b1 = dequant(b_quant_shift); - if (group_blocks != -1) scale(frag_b1, frag_s[k % 2][j], 1); - #pragma unroll - for (int i = 0; i < thread_m_blocks; i++) { - mma(frag_a[k % 2][i], frag_b0, frag_c[i][j][0]); - mma(frag_a[k % 2][i], frag_b1, frag_c[i][j][1]); - } - } - }; - - // Since we slice across the k dimension of a tile in order to increase the - // number of warps while keeping the n dimension of a tile reasonable, we have - // multiple warps that accumulate their partial sums of the same output - // location; which we have to reduce over in the end. We do in shared memory. - auto thread_block_reduce = [&]() { - constexpr int red_off = threads / b_sh_stride / 2; - if (red_off >= 1) { - int red_idx = threadIdx.x / b_sh_stride; - constexpr int red_sh_stride = b_sh_stride * 4 * 2; - constexpr int red_sh_delta = b_sh_stride; - int red_sh_rd = red_sh_stride * (threadIdx.x / b_sh_stride) + - (threadIdx.x % b_sh_stride); - - // Parallel logarithmic shared memory reduction. We make sure to avoid any - // unnecessary read or write iterations, e.g., for two warps we write only - // once by warp 1 and read only once by warp 0. - - #pragma unroll - for (int m_block = 0; m_block < thread_m_blocks; m_block++) { - #pragma unroll - for (int i = red_off; i > 0; i /= 2) { - if (i <= red_idx && red_idx < 2 * i) { - #pragma unroll - for (int j = 0; j < 4 * 2; j++) { - int red_sh_wr = - red_sh_delta * j + (red_sh_rd - red_sh_stride * i); - if (i < red_off) { - float* c_rd = - reinterpret_cast(&sh[red_sh_delta * j + red_sh_rd]); - float* c_wr = reinterpret_cast(&sh[red_sh_wr]); - #pragma unroll - for (int k = 0; k < 4; k++) - reinterpret_cast(frag_c)[4 * 2 * m_block + j][k] += - c_rd[k] + c_wr[k]; - } - sh[red_sh_wr] = - reinterpret_cast(&frag_c)[4 * 2 * m_block + j]; - } - } - __syncthreads(); - } - if (red_idx == 0) { - #pragma unroll - for (int i = 0; i < 4 * 2; i++) { - float* c_rd = - reinterpret_cast(&sh[red_sh_delta * i + red_sh_rd]); - #pragma unroll - for (int j = 0; j < 4; j++) - reinterpret_cast(frag_c)[4 * 2 * m_block + i][j] += - c_rd[j]; - } - } - __syncthreads(); - } - } - }; - - // Since multiple threadblocks may process parts of the same column slice, we - // finally have to globally reduce over the results. As the striped - // partitioning minimizes the number of such reductions and our outputs are - // usually rather small, we perform this reduction serially in L2 cache. - auto global_reduce = [&](bool first = false, bool last = false) { - // We are very careful here to reduce directly in the output buffer to - // maximize L2 cache utilization in this step. To do this, we write out - // results in FP16 (but still reduce with FP32 compute). - constexpr int active_threads = 32 * thread_n_blocks / 4; - if (threadIdx.x < active_threads) { - int c_gl_stride = prob_n / 8; - int c_gl_wr_delta_o = 8 * c_gl_stride; - int c_gl_wr_delta_i = 4 * (active_threads / 32); - int c_gl_wr = c_gl_stride * ((threadIdx.x % 32) / 4) + - 4 * (threadIdx.x / 32) + threadIdx.x % 4; - c_gl_wr += (2 * thread_n_blocks) * slice_col; - constexpr int c_sh_wr_delta = active_threads; - int c_sh_wr = threadIdx.x; - - int row = (threadIdx.x % 32) / 4; - - if (!first) { - // Interestingly, doing direct global accesses here really seems to mess up - // the compiler and lead to slowdowns, hence we also use async-copies even - // though these fetches are not actually asynchronous. - #pragma unroll - for (int i = 0; i < thread_m_blocks * 4; i++) { - cp_async4_pred( - &sh[c_sh_wr + c_sh_wr_delta * i], - &C[c_gl_wr + c_gl_wr_delta_o * (i / 2) + - c_gl_wr_delta_i * (i % 2)], - i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m); - } - cp_async_fence(); - cp_async_wait<0>(); - } - - #pragma unroll - for (int i = 0; i < thread_m_blocks * 4; i++) { - if (i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m) { - if (!first) { - int4 c_red = sh[c_sh_wr + i * c_sh_wr_delta]; - #pragma unroll - for (int j = 0; j < 2 * 4; j++) { - reinterpret_cast( - &frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4)] += - __half2float(reinterpret_cast<__half*>(&c_red)[j]); - } - } - if (!last) { - int4 c; - #pragma unroll - for (int j = 0; j < 2 * 4; j++) { - reinterpret_cast<__half*>(&c)[j] = - __float2half(reinterpret_cast( - &frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4)]); - } - C[c_gl_wr + c_gl_wr_delta_o * (i / 2) + c_gl_wr_delta_i * (i % 2)] = - c; - } - } - } - } - }; - - // Write out the reduce final result in the correct layout. We only actually - // reshuffle matrix fragments in this step, the reduction above is performed - // in fragment layout. - auto write_result = [&]() { - int c_gl_stride = prob_n / 8; - constexpr int c_sh_stride = 2 * thread_n_blocks + 1; - int c_gl_wr_delta = c_gl_stride * (threads / (2 * thread_n_blocks)); - constexpr int c_sh_rd_delta = - c_sh_stride * (threads / (2 * thread_n_blocks)); - - int c_gl_wr = c_gl_stride * (threadIdx.x / (2 * thread_n_blocks)) + - (threadIdx.x % (2 * thread_n_blocks)); - c_gl_wr += (2 * thread_n_blocks) * slice_col; - int c_sh_wr = - (4 * c_sh_stride) * ((threadIdx.x % 32) / 4) + (threadIdx.x % 32) % 4; - c_sh_wr += 32 * (threadIdx.x / 32); - int c_sh_rd = c_sh_stride * (threadIdx.x / (2 * thread_n_blocks)) + - (threadIdx.x % (2 * thread_n_blocks)); - - int c_gl_wr_end = c_gl_stride * prob_m; - - // We first reorder in shared memory to guarantee the most efficient final - // global write patterns - auto write = [&](int idx, float c0, float c1, FragS& s) { - half2 res = __halves2half2(__float2half(c0), __float2half(c1)); - if (group_blocks == - -1) // for per-column quantization we finally apply the scale here - res = __hmul2(res, s[0]); - ((half2*)sh)[idx] = res; - }; - if (threadIdx.x / 32 < thread_n_blocks / 4) { - #pragma unroll - for (int i = 0; i < thread_m_blocks; i++) { - #pragma unroll - for (int j = 0; j < 4; j++) { - int wr = c_sh_wr + 8 * j; - write(wr + (4 * c_sh_stride) * 0 + 0, frag_c[i][j][0][0], - frag_c[i][j][0][1], frag_s[j / 2][2 * (j % 2) + 0]); - write(wr + (4 * c_sh_stride) * 8 + 0, frag_c[i][j][0][2], - frag_c[i][j][0][3], frag_s[j / 2][2 * (j % 2) + 0]); - write(wr + (4 * c_sh_stride) * 0 + 4, frag_c[i][j][1][0], - frag_c[i][j][1][1], frag_s[j / 2][2 * (j % 2) + 1]); - write(wr + (4 * c_sh_stride) * 8 + 4, frag_c[i][j][1][2], - frag_c[i][j][1][3], frag_s[j / 2][2 * (j % 2) + 1]); - } - c_sh_wr += 16 * (4 * c_sh_stride); - } - } - __syncthreads(); - - #pragma unroll - for (int i = 0; - i < ceildiv(16 * thread_m_blocks, threads / (2 * thread_n_blocks)); - i++) { - if (c_gl_wr < c_gl_wr_end) { - C[c_gl_wr] = sh[c_sh_rd]; - c_gl_wr += c_gl_wr_delta; - c_sh_rd += c_sh_rd_delta; - } - } - }; - - // Start global fetch and register load pipelines. - auto start_pipes = [&]() { - #pragma unroll - for (int i = 0; i < stages - 1; i++) fetch_to_shared(i, i, i < slice_iters); - zero_accums(); - wait_for_stage(); - fetch_to_registers(0, 0); - a_gl_rd += a_gl_rd_delta_o * (stages - 1); - }; - start_pipes(); - - // Main loop. - while (slice_iters) { - // We unroll over both the global fetch and the register load pipeline to - // ensure all shared memory accesses are static. Note that both pipelines have - // even length meaning that the next iteration will always start at index 0. - #pragma unroll - for (int pipe = 0; pipe < stages;) { - #pragma unroll - for (int k = 0; k < b_sh_wr_iters; k++) { - fetch_to_registers(k + 1, pipe % stages); - if (k == b_sh_wr_iters - 2) { - fetch_to_shared((pipe + stages - 1) % stages, pipe, - slice_iters >= stages); - pipe++; - wait_for_stage(); - } - matmul(k); - } - slice_iters--; - if (slice_iters == 0) break; - } - a_gl_rd += a_gl_rd_delta_o * stages; - - // Process results and, if necessary, proceed to the next column slice. - // While this pattern may not be the most readable, other ways of writing - // the loop seemed to noticeably worse performance after compilation. - if (slice_iters == 0) { - cp_async_wait<0>(); - bool last = slice_idx == slice_count - 1; - // For per-column scales, we only fetch them here in the final step before - // write-out - if (group_blocks == -1 && last) { - if (s_sh_wr_pred) cp_async4(&sh_s[s_sh_wr], &s[s_gl_rd]); - cp_async_fence(); - } - thread_block_reduce(); - if (group_blocks == -1 && last) { - cp_async_wait<0>(); - __syncthreads(); - if (threadIdx.x / 32 < thread_n_blocks / 4) { - reinterpret_cast(&frag_s)[0] = sh_s[s_sh_rd + 0]; - reinterpret_cast(&frag_s)[1] = sh_s[s_sh_rd + 4]; - } - } - if (slice_count > 1) { // only globally reduce if there is more than one - // block in a slice - barrier_acquire(&locks[slice_col], slice_idx); - global_reduce(slice_idx == 0, last); - barrier_release(&locks[slice_col], last); - } - if (last) // only the last block in a slice actually writes the result - write_result(); - slice_row = 0; - slice_col_par++; - slice_col++; - init_slice(); - if (slice_iters) { - a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) + - (threadIdx.x % a_gl_rd_delta_o); - #pragma unroll - for (int i = 0; i < b_sh_wr_iters; i++) - B_ptr[i] += b_sh_stride - b_gl_rd_delta_o * k_tiles; - if (slice_col == 0) { - #pragma unroll - for (int i = 0; i < b_sh_wr_iters; i++) B_ptr[i] -= b_gl_stride; - } - s_gl_rd = s_sh_stride * slice_col + threadIdx.x; - start_pipes(); - } - } - } -} - -#else - -template shared - // fetch pipeline - const int group_blocks = -1 // number of consecutive 16x16 blocks - // with a separate quantization scale - > -__global__ void Marlin( - const int4* __restrict__ A, // fp16 input matrix of shape mxk - const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn - int4* __restrict__ C, // fp16 output buffer of shape mxn - const int4* __restrict__ s, // fp16 quantization scales of shape - // (k/groupsize)xn - int prob_m, // batch dimension m - int prob_n, // output dimension n - int prob_k, // reduction dimension k - int* locks // extra global storage for barrier synchronization -) { - // Marlin is not implemented yet for SM < 8.0 - assert(false); - return; -} - -#endif - -// 8 warps are a good choice since every SM has 4 schedulers and having more -// than 1 warp per schedule allows some more latency hiding. At the same time, -// we want relatively few warps to have many registers per warp and small tiles. -const int USER_THREADS = - 256; // Note: This is only used with user-provided thread_k/n -const int STAGES = 4; // 4 pipeline stages fit into shared memory -const int SHARED_MEM = - 96 * 1024; // max shared memory on compute capability 8.6 (< 8.0) - -static constexpr int min_thread_n = 64; -static constexpr int min_thread_k = 64; - -static constexpr int tile_size = 16; -static constexpr int max_par = 16; - -static constexpr int pack_factor_4bit = - 8; // We have 8 4-bit vals inside a 32 bit - -#define __CALL_IF(THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, \ - GROUP_BLOCKS, NUM_THREADS) \ - else if (thread_m_blocks == THREAD_M_BLOCKS && \ - thread_n_blocks == THREAD_N_BLOCKS && \ - thread_k_blocks == THREAD_K_BLOCKS && \ - group_blocks == GROUP_BLOCKS && num_threads == NUM_THREADS) { \ - cudaFuncSetAttribute(Marlin, \ - cudaFuncAttributeMaxDynamicSharedMemorySize, \ - SHARED_MEM); \ - Marlin<<>>( \ - A_ptr, B_ptr, C_ptr, s_ptr, prob_m, prob_n, prob_k, locks); \ - } - -typedef struct { - int thread_k; - int thread_n; - int num_threads; -} thread_config_t; - -thread_config_t small_batch_thread_configs[] = { - // Ordered by priority - - // thread_k, thread_n, num_threads - {128, 128, 256}, // Default - {128, 64, 128}, // Reduce N 2X, same K - {64, 256, 256}, // Reduce K 2X, increase N 2X - {64, 128, 128}, // Reduce K 2X, same N -}; - -thread_config_t large_batch_thread_configs[] = { - // Ordered by priority - - // thread_k, thread_n, num_threads - {64, 256, 256}, // Default - {128, 128, 256}, // Reduce N 2X, increase K 2X - {64, 128, 128}, // Reduce N 2X, same K - {128, 64, 128}, // Reduce N 4X, increase K 2X -}; - -bool is_valid_config(thread_config_t const& th_config, int prob_m, int prob_n, - int prob_k) { - // Sanity - if (th_config.thread_k == -1 || th_config.thread_n == -1 || - th_config.num_threads == -1) { - return false; - } - - // Verify K/N are divisible by thread K/N - if (prob_k % th_config.thread_k != 0 || prob_n % th_config.thread_n != 0) { - return false; - } - - // thread_k can be only 128 or 64 (because it must be less than groupsize - // which is 128) - if (th_config.thread_k != 128 && th_config.thread_k != 64) { - return false; - } - - // Verify min for thread K/N - if (th_config.thread_n < min_thread_n || th_config.thread_k < min_thread_k) { - return false; - } - - // num_threads must be at least 128 (= 4 warps) - if (th_config.num_threads < 128) { - return false; - } - - return true; -} - -thread_config_t determine_thread_config(int prob_m, int prob_n, int prob_k) { - if (prob_m <= 16) { - for (auto th_config : small_batch_thread_configs) { - if (is_valid_config(th_config, prob_m, prob_n, prob_k)) { - return th_config; - } - } - - } else { - for (auto th_config : large_batch_thread_configs) { - if (is_valid_config(th_config, prob_m, prob_n, prob_k)) { - return th_config; - } - } - } - - return thread_config_t{-1, -1, -1}; -} - -#define CALL_IF(N_BLOCKS, K_BLOCKS, NUM_THREADS) \ - __CALL_IF(1, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \ - __CALL_IF(1, N_BLOCKS, K_BLOCKS, 8, NUM_THREADS) \ - __CALL_IF(1, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \ - __CALL_IF(1, N_BLOCKS, K_BLOCKS, 8, NUM_THREADS) \ - __CALL_IF(2, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \ - __CALL_IF(2, N_BLOCKS, K_BLOCKS, 8, NUM_THREADS) \ - __CALL_IF(3, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \ - __CALL_IF(3, N_BLOCKS, K_BLOCKS, 8, NUM_THREADS) \ - __CALL_IF(4, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \ - __CALL_IF(4, N_BLOCKS, K_BLOCKS, 8, NUM_THREADS) - -void marlin_cuda(const void* A, const void* B, void* C, void* s, int prob_m, - int prob_n, int prob_k, void* workspace, int groupsize = -1, - int dev = 0, cudaStream_t stream = 0, int thread_k = -1, - int thread_n = -1, int sms = -1, int max_par = 16) { - int tot_m = prob_m; - int tot_m_blocks = ceildiv(tot_m, 16); - int pad = 16 * tot_m_blocks - tot_m; - - if (sms == -1) - cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, dev); - - // Set thread config - thread_config_t th_config; - if (thread_k != -1 && thread_n != -1) { - // User-defined config - th_config = thread_config_t{thread_k, thread_n, USER_THREADS}; - } else { - // Auto config - th_config = determine_thread_config(prob_m, prob_n, prob_k); - } - - if (!is_valid_config(th_config, prob_m, prob_n, prob_k)) { - throw std::runtime_error( - "Invalid thread config: thread_k = " + str(th_config.thread_k) + - ", thread_n = " + str(th_config.thread_n) + - ", num_threads = " + str(th_config.num_threads) + " for MKN = [" + - str(prob_m) + ", " + str(prob_k) + ", " + str(prob_n) + "]"); - } - - // Uncomment for debug - // std::cout << "Using thread_config: thread_k = " + str(th_config.thread_k) + - // ", thread_n = " + str(th_config.thread_n) + - // ", num_threads = " + str(th_config.num_threads) + " for - // MKN = [" + str(prob_m) + - // ", " + str(prob_k) + ", " + str(prob_n) + "]\n"; - - int num_threads = th_config.num_threads; - thread_k = th_config.thread_k; - thread_n = th_config.thread_n; - - int thread_k_blocks = thread_k / 16; - int thread_n_blocks = thread_n / 16; - int group_blocks = (groupsize == -1) ? -1 : groupsize / 16; - int blocks = sms; - - if (prob_m == 0 || prob_n == 0 || prob_k == 0) { - return; - } - - TORCH_CHECK(prob_n % thread_n == 0, "prob_n = ", prob_n, - " is not divisible by thread_n = ", thread_n); - TORCH_CHECK(prob_k % thread_k == 0, "prob_k = ", prob_k, - " is not divisible by thread_k = ", thread_k); - if (group_blocks != -1) { - TORCH_CHECK(prob_k % group_blocks == 0, "prob_k = ", prob_k, - " is not divisible by group_blocks = ", group_blocks); - } - - const int4* A_ptr = (const int4*)A; - const int4* B_ptr = (const int4*)B; - int4* C_ptr = (int4*)C; - const int4* s_ptr = (const int4*)s; - - int* locks = (int*)workspace; - - for (int i = 0; i < tot_m_blocks; i += 4) { - int thread_m_blocks = tot_m_blocks - i; - prob_m = tot_m - 16 * i; - int par = 1; - if (thread_m_blocks > 4) { - // Note that parallel > 1 currently only works for inputs without any - // padding - par = (16 * thread_m_blocks - pad) / 64; - if (par > max_par) par = max_par; - prob_m = 64 * par; - i += 4 * (par - 1); - thread_m_blocks = 4; - } - - // For compilation speed, we only define the kernel configurations that have - // seemed useful (in terms of performance) in our testing, however many more - // are, in principle, possible. - if (false) { - } - CALL_IF(8, 8, 256) - CALL_IF(16, 4, 256) - CALL_IF(8, 4, 128) - CALL_IF(4, 8, 128) - else { - throw std::runtime_error("Unsupported shapes: MKN = [" + str(prob_m) + - ", " + str(prob_k) + ", " + str(prob_n) + "]" + - ", groupsize = " + str(groupsize) + - ", thread_m_blocks = " + str(thread_m_blocks) + - ", thread_n_blocks = " + str(thread_n_blocks) + - ", thread_k_blocks = " + str(thread_k_blocks)); - } - - A_ptr += 16 * thread_m_blocks * (prob_k / 8) * par; - C_ptr += 16 * thread_m_blocks * (prob_n / 8) * par; - } -} - -} // namespace marlin_dense - -torch::Tensor marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, - torch::Tensor& b_scales, torch::Tensor& workspace, - int64_t size_m, int64_t size_n, int64_t size_k) { - // Verify M - TORCH_CHECK(size_m == a.size(0), - "Shape mismatch: a.size(0) = " + str(a.size(0)) + - ", size_m = " + str(size_m)); - - // Verify K - TORCH_CHECK(size_k == a.size(1), - "Shape mismatch: a.size(1) = " + str(a.size(1)) + - ", size_k = " + str(size_k)); - TORCH_CHECK(size_k % marlin_dense::tile_size == 0, - "size_k = " + str(size_k) + " is not divisible by tile_size = " + - str(marlin_dense::tile_size)); - TORCH_CHECK((size_k / marlin_dense::tile_size) == b_q_weight.size(0), - "Shape mismatch: b_q_weight.size(0) = " + - str(b_q_weight.size(0)) + ", size_k = " + str(size_k) + - ", tile_size = " + str(marlin_dense::tile_size)); - - // Verify N - TORCH_CHECK(b_scales.size(1) == size_n, - "b_scales.size(1) = " + str(b_scales.size(1)) + - ", size_n = " + str(size_n)); - TORCH_CHECK( - b_q_weight.size(1) % marlin_dense::tile_size == 0, - "b_q_weight.size(1) = " + str(b_q_weight.size(1)) + - " is not divisible by tile_size = " + str(marlin_dense::tile_size)); - - int actual_size_n = (b_q_weight.size(1) / marlin_dense::tile_size) * - marlin_dense::pack_factor_4bit; - TORCH_CHECK( - size_n == actual_size_n, - "size_n = " + str(size_n) + ", actual_size_n = " + str(actual_size_n)); - - // Verify A device and strides - TORCH_CHECK(a.device().is_cuda(), "A is not on GPU"); - TORCH_CHECK(a.is_contiguous(), "A is not contiguous"); - - // Verify B device and strides - TORCH_CHECK(b_q_weight.device().is_cuda(), "b_q_weight is not on GPU"); - TORCH_CHECK(b_q_weight.is_contiguous(), "b_q_weight is not contiguous"); - - // Verify scales device and strides - TORCH_CHECK(b_scales.device().is_cuda(), "b_scales is not on GPU"); - TORCH_CHECK(b_scales.is_contiguous(), "b_scales is not contiguous"); - - // Alloc C matrix - const at::cuda::OptionalCUDAGuard device_guard(device_of(a)); - auto options = torch::TensorOptions().dtype(a.dtype()).device(a.device()); - torch::Tensor c = torch::empty({size_m, size_n}, options); - - // thread_k: `k` size of a thread_tile in `weights` (can usually be left as - // auto -1) - int thread_k = -1; - // thread_n: `n` size of a thread_tile in `weights` (can usually be left as - // auto -1) - int thread_n = -1; - // sms: number of SMs to use for the kernel (can usually be left as auto -1) - int sms = -1; - - // Detect groupsize - if (b_scales.size(0) != 1) { - TORCH_CHECK(size_k % b_scales.size(0) == 0, - "size_k = " + str(size_k) + - ", is not divisible by b_scales.size(0) = " + - str(b_scales.size(0))); - } - int groupsize = b_scales.size(0) == 1 ? -1 : size_k / b_scales.size(0); - - // Verify groupsize - TORCH_CHECK(groupsize == -1 || groupsize == 128, - "Unexpected groupsize = " + str(groupsize)); - - // Verify workspace size - TORCH_CHECK(size_n % marlin_dense::min_thread_n == 0, - "size_n = " + str(size_n) + - ", is not divisible by min_thread_n = " + - str(marlin_dense::min_thread_n)); - int min_workspace_size = - (size_n / marlin_dense::min_thread_n) * marlin_dense::max_par; - TORCH_CHECK(workspace.numel() >= min_workspace_size, - "workspace.numel = " + str(workspace.numel()) + - " is below min_workspace_size = " + str(min_workspace_size)); - - int dev = a.get_device(); - marlin_dense::marlin_cuda(a.data_ptr(), b_q_weight.data_ptr(), c.data_ptr(), - b_scales.data_ptr(), size_m, size_n, size_k, - workspace.data_ptr(), groupsize, dev, - at::cuda::getCurrentCUDAStream(dev), thread_k, - thread_n, sms, marlin_dense::max_par); - - return c; -} diff --git a/server/marlin/marlin_kernels/marlin_dtypes.cuh b/server/marlin/marlin_kernels/marlin_dtypes.cuh deleted file mode 100644 index be06c09b..00000000 --- a/server/marlin/marlin_kernels/marlin_dtypes.cuh +++ /dev/null @@ -1,79 +0,0 @@ - -#ifndef _data_types_cuh -#define _data_types_cuh -#include "marlin.cuh" -#include -#include - -namespace marlin { - -template -class ScalarType {}; - -template <> -class ScalarType { - public: - using scalar_t = half; - using scalar_t2 = half2; - - // Matrix fragments for tensor core instructions; their precise layout is - // documented here: - // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-mma-m16n8k16-with-floating-point-type - using FragA = Vec; - using FragB = Vec; - using FragC = Vec; - using FragS = Vec; - using FragZP = Vec; - - static __device__ float inline num2float(const half x) { - return __half2float(x); - } - - static __device__ half2 inline num2num2(const half x) { - return __half2half2(x); - } - - static __device__ half2 inline nums2num2(const half x1, const half x2) { - return __halves2half2(x1, x2); - } - - static __host__ __device__ half inline float2num(const float x) { - return __float2half(x); - } -}; - -template <> -class ScalarType { - public: - using scalar_t = nv_bfloat16; - using scalar_t2 = nv_bfloat162; - - using FragA = Vec; - using FragB = Vec; - using FragC = Vec; - using FragS = Vec; - using FragZP = Vec; - -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 - static __device__ float inline num2float(const nv_bfloat16 x) { - return __bfloat162float(x); - } - - static __device__ nv_bfloat162 inline num2num2(const nv_bfloat16 x) { - return __bfloat162bfloat162(x); - } - - static __device__ nv_bfloat162 inline nums2num2(const nv_bfloat16 x1, - const nv_bfloat16 x2) { - return __halves2bfloat162(x1, x2); - } - - static __host__ __device__ nv_bfloat16 inline float2num(const float x) { - return __float2bfloat16(x); - } -#endif -}; - -} // namespace marlin - -#endif diff --git a/server/marlin/marlin_kernels/py.typed b/server/marlin/marlin_kernels/py.typed deleted file mode 100644 index e69de29b..00000000 diff --git a/server/marlin/marlin_kernels/sparse/common/base.h b/server/marlin/marlin_kernels/sparse/common/base.h deleted file mode 100644 index 16018d33..00000000 --- a/server/marlin/marlin_kernels/sparse/common/base.h +++ /dev/null @@ -1,51 +0,0 @@ -/* - * Copyright (C) 2024 Roberto Lopez Castro (roberto.lopez.castro@udc.es). All - * Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -namespace marlin_24 { - -constexpr int ceildiv(int a, int b) { return (a + b - 1) / b; } - -// Instances of `Vec` are used to organize groups of >>registers<<, as needed -// for instance as inputs to tensor core operations. Consequently, all -// corresponding index accesses must be compile-time constants, which is why we -// extensively use `#pragma unroll` throughout the kernel code to guarantee -// this. -template -struct Vec { - T elems[n]; - __device__ T& operator[](int i) { return elems[i]; } -}; - -template -struct ShapeBase { - static constexpr int M = M_, N = N_, K = K_; -}; - -using I4 = Vec; - -// Matrix fragments for tensor core instructions; their precise layout is -// documented here: -// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-mma-m16n8k16-with-floating-point-type -using FragA = Vec; -using FragB = Vec; -using FragM = Vec; -using FragC = Vec; -using FragS = Vec; // quantization scales - -} // namespace marlin_24 diff --git a/server/marlin/marlin_kernels/sparse/common/mem.h b/server/marlin/marlin_kernels/sparse/common/mem.h deleted file mode 100644 index 83e3578d..00000000 --- a/server/marlin/marlin_kernels/sparse/common/mem.h +++ /dev/null @@ -1,136 +0,0 @@ -/* - * Copyright (C) 2024 Roberto Lopez Castro (roberto.lopez.castro@udc.es). All - * Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once -#include "base.h" - -namespace marlin_24 { -// Predicated asynchronous global->shared copy; used for inputs A where we apply -// predication to handle batchsizes that are not multiples of 16. -__device__ inline void cp_async4_pred_zfill(void* smem_ptr, - const void* glob_ptr, - bool pred = true, - const bool zfill = false) { - const int BYTES = 16; - int src_in_bytes = (zfill ? 0 : BYTES); - uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); - asm volatile( - "{\n" - " .reg .pred p;\n" - " setp.ne.b32 p, %0, 0;\n" - " @p cp.async.cg.shared.global [%1], [%2], %3;\n" - "}\n" ::"r"((int)pred), - "r"(smem), "l"(glob_ptr), "n"(BYTES), "r"(src_in_bytes)); -} - -__device__ inline void cp_async4_pred(void* smem_ptr, const void* glob_ptr, - bool pred = true) { - const int BYTES = 16; - uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); - asm volatile( - "{\n" - " .reg .pred p;\n" - " setp.ne.b32 p, %0, 0;\n" - " @p cp.async.cg.shared.global [%1], [%2], %3;\n" - "}\n" ::"r"((int)pred), - "r"(smem), "l"(glob_ptr), "n"(BYTES)); -} - -// Asynchronous global->shared copy -__device__ inline void cp_async4(void* smem_ptr, const void* glob_ptr) { - const int BYTES = 16; - uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); - asm volatile( - "{\n" - " cp.async.cg.shared.global [%0], [%1], %2;\n" - "}\n" ::"r"(smem), - "l"(glob_ptr), "n"(BYTES)); -} - -// Async copy fence. -__device__ inline void cp_async_fence() { - asm volatile("cp.async.commit_group;\n" ::); -} - -// Wait until at most `n` async copy stages are still pending. -template -__device__ inline void cp_async_wait() { - asm volatile("cp.async.wait_group %0;\n" ::"n"(n)); -} - -// Instruction for loading a full 16x16 matrix fragment of operand A from shared -// memory, directly in tensor core layout. -__device__ inline void ldsm4(FragA& frag_a, const void* smem_ptr) { - uint32_t* a = reinterpret_cast(&frag_a); - uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); - asm volatile("ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0,%1,%2,%3}, [%4];\n" - : "=r"(a[0]), "=r"(a[1]), "=r"(a[2]), "=r"(a[3]) - : "r"(smem)); -} - -__device__ inline void ldsm4_m(FragM& frag_m, const void* smem_ptr) { - uint32_t* a = reinterpret_cast(&frag_m); - uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); - asm volatile("ldmatrix.sync.aligned.m8n8.x2.shared.b16 {%0,%1}, [%2];\n" - : "=r"(a[0]), "=r"(a[1]) - : "r"(smem)); -} - -// Instruction for loading a full 16x16 matrix fragment of operand A from shared -// memory, directly in tensor core layout. -__device__ inline void ldsm4_t(FragA& frag_a, const void* smem_ptr) { - uint32_t* a = reinterpret_cast(&frag_a); - uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); - asm volatile( - "ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16 {%0,%1,%2,%3}, [%4];\n" - : "=r"(a[0]), "=r"(a[1]), "=r"(a[2]), "=r"(a[3]) - : "r"(smem)); -} - -// Wait until barrier reaches `count`, then lock for current threadblock. -__device__ inline void barrier_acquire(int* lock, int count) { - if (threadIdx.x == 0) { - int state = -1; - do - // Guarantee that subsequent writes by this threadblock will be visible - // globally. - asm volatile("ld.global.acquire.gpu.b32 %0, [%1];\n" - : "=r"(state) - : "l"(lock)); - while (state != count); - } - __syncthreads(); -} - -// Release barrier and increment visitation count. -__device__ inline void barrier_release(int* lock, bool reset = false) { - __syncthreads(); - if (threadIdx.x == 0) { - if (reset) { - lock[0] = 0; - return; - } - int val = 1; - // Make sure that all writes since acquiring this barrier are visible - // globally, while releasing the barrier. - asm volatile("fence.acq_rel.gpu;\n"); - asm volatile("red.relaxed.gpu.global.add.s32 [%0], %1;\n" - : - : "l"(lock), "r"(val)); - } -} -} // namespace marlin_24 diff --git a/server/marlin/marlin_kernels/sparse/common/mma.h b/server/marlin/marlin_kernels/sparse/common/mma.h deleted file mode 100644 index b26505f7..00000000 --- a/server/marlin/marlin_kernels/sparse/common/mma.h +++ /dev/null @@ -1,191 +0,0 @@ -/* - * Copyright (C) 2024 Roberto Lopez Castro (roberto.lopez.castro@udc.es). All - * Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once -#include "base.h" -#include - -namespace marlin_24 { - -// On CUDA earlier than 12.5, the ordered_metadata version of this instruction -// is not supported. On later versions of CUDA the version without ordered -// metadata results in the following warning: -// | Advisory: Modifier ‘.sp::ordered_metadata’ should be used on instruction -// | ‘mma’ instead of modifier ‘.sp’ as it is expected to have substantially -// | reduced performance on some future architectures -#if defined CUDA_VERSION && CUDA_VERSION >= 12050 - #define MMA_SP_INST \ - "mma.sp::ordered_metadata.sync.aligned.m16n8k32.row.col.f32.f16.f16.f32 " -#else - #define MMA_SP_INST "mma.sp.sync.aligned.m16n8k32.row.col.f32.f16.f16.f32 " -#endif - -// m16n8k32 sparse tensor core mma instruction with fp16 inputs and fp32 -// output/accumulation. -__device__ inline void mma_sp(const FragB& a_frag0, const FragB& a_frag1, - const FragA& frag_b, FragC& frag_c, FragM& frag_m, - const int psel) { - const uint32_t* a0 = reinterpret_cast(&a_frag0); - const uint32_t* a1 = reinterpret_cast(&a_frag1); - const uint32_t* b = reinterpret_cast(&frag_b); - const uint32_t* e = reinterpret_cast(&frag_m); - - float* c = reinterpret_cast(&frag_c); - if (psel == 0) { - asm volatile(MMA_SP_INST - "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9, %10,%11}, " - "{%12,%13,%14,%15}, %16, 0x0;\n" - : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) - : "r"(a0[0]), "r"(a1[0]), "r"(a0[1]), "r"(a1[1]), "r"(b[0]), - "r"(b[2]), "r"(b[4]), "r"(b[6]), "f"(c[0]), "f"(c[1]), - "f"(c[2]), "f"(c[3]), "r"(e[0])); - asm volatile(MMA_SP_INST - "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9, %10,%11}, " - "{%12,%13,%14,%15}, %16, 0x0;\n" - : "=f"(c[4]), "=f"(c[5]), "=f"(c[6]), "=f"(c[7]) - : "r"(a0[0]), "r"(a1[0]), "r"(a0[1]), "r"(a1[1]), "r"(b[1]), - "r"(b[3]), "r"(b[5]), "r"(b[7]), "f"(c[4]), "f"(c[5]), - "f"(c[6]), "f"(c[7]), "r"(e[0])); - } else { - asm volatile(MMA_SP_INST - "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9, %10,%11}, " - "{%12,%13,%14,%15}, %16, 0x1;\n" - : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) - : "r"(a0[0]), "r"(a1[0]), "r"(a0[1]), "r"(a1[1]), "r"(b[0]), - "r"(b[2]), "r"(b[4]), "r"(b[6]), "f"(c[0]), "f"(c[1]), - "f"(c[2]), "f"(c[3]), "r"(e[0])); - asm volatile(MMA_SP_INST - "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9, %10,%11}, " - "{%12,%13,%14,%15}, %16, 0x1;\n" - : "=f"(c[4]), "=f"(c[5]), "=f"(c[6]), "=f"(c[7]) - : "r"(a0[0]), "r"(a1[0]), "r"(a0[1]), "r"(a1[1]), "r"(b[1]), - "r"(b[3]), "r"(b[5]), "r"(b[7]), "f"(c[4]), "f"(c[5]), - "f"(c[6]), "f"(c[7]), "r"(e[0])); - } -} - -// Lookup-table based 3-input logical operation; explicitly used for -// dequantization as the compiler does not seem to automatically recognize it in -// all cases. -template -__device__ inline int lop3(int a, int b, int c) { - int res; - asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" - : "=r"(res) - : "r"(a), "r"(b), "r"(c), "n"(lut)); - return res; -} - -__device__ __forceinline__ uint2 to_half4(float c0, float c1, float c2, - float c3) { - uint2 r; - asm("{\n\t" - ".reg .f16 a, b, c, d; \n\t" - "cvt.rn.f16.f32 a, %2; \n\t" - "cvt.rn.f16.f32 b, %3; \n\t" - "cvt.rn.f16.f32 c, %4; \n\t" - "cvt.rn.f16.f32 d, %5; \n\t" - "mov.b32 %0, {a, b}; \n\t" - "mov.b32 %1, {c, d}; \n\t" - "}" - : "=r"(r.x), "=r"(r.y) - : "f"(c0), "f"(c1), "f"(c2), "f"(c3)); - return r; -} - -// Constructs destination register by taking bytes from 2 sources (based on -// mask) -template -__device__ inline uint32_t prmt(uint32_t a) { - uint32_t res; - asm volatile("prmt.b32 %0, %1, %2, %3;\n" - : "=r"(res) - : "r"(a), "n"(start_byte), "n"(mask)); - return res; -} - -// Efficiently dequantize an int32 value into a full B-fragment of 4 fp16 -// values. We mostly follow the strategy in the link below, with some small -// changes: -// https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h -__device__ inline FragB dequant_4bit(int q) { - const int LO = 0x000f000f; - const int HI = 0x00f000f0; - const int EX = 0x64006400; - // Guarantee that the `(a & b) | c` operations are LOP3s. - int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX); - int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX); - // We want signed int4 outputs, hence we fuse the `-8` symmetric zero point - // directly into `SUB` and `ADD`. - const int SUB = 0x64086408; - const int MUL = 0x2c002c00; - const int ADD = 0xd480d480; - - FragB frag_b; - frag_b[0] = __hsub2(*reinterpret_cast(&lo), - *reinterpret_cast(&SUB)); - frag_b[1] = __hfma2(*reinterpret_cast(&hi), - *reinterpret_cast(&MUL), - *reinterpret_cast(&ADD)); - return frag_b; -} - -// Efficiently dequantize an int32 value into a full B-fragment of 4 fp16 -// values. We mostly follow the strategy in the link below, with some small -// changes: -// https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h -__device__ inline FragB dequant_8bit(int q) { - static constexpr uint32_t mask_for_elt_01 = 0x5250; - static constexpr uint32_t mask_for_elt_23 = 0x5351; - static constexpr uint32_t start_byte_for_fp16 = 0x64646464; - - uint32_t lo = prmt(q); - uint32_t hi = prmt(q); - - static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64806480; - - FragB frag_b; - frag_b[0] = __hsub2(*reinterpret_cast(&lo), - *reinterpret_cast(&I8s_TO_F16s_MAGIC_NUM)); - frag_b[1] = __hsub2(*reinterpret_cast(&hi), - *reinterpret_cast(&I8s_TO_F16s_MAGIC_NUM)); - return frag_b; -} - -// Multiply dequantized values by the corresponding quantization scale; used -// only for grouped quantization. -__device__ inline void scale(FragB& frag_b, FragS& frag_s, int i) { - half2 s = __half2half2(reinterpret_cast<__half*>(&frag_s)[i]); - frag_b[0] = __hmul2(frag_b[0], s); - frag_b[1] = __hmul2(frag_b[1], s); -} - -__device__ inline void scale_floats(float* c0, float* c1, float* c2, float* c3, - FragS& s0, float* c4, float* c5, float* c6, - float* c7, FragS& s1) { - *c0 = __fmul_rn(*c0, __half2float(s0[0].x)); - *c1 = __fmul_rn(*c1, __half2float(s0[0].y)); - *c2 = __fmul_rn(*c2, __half2float(s0[1].x)); - *c3 = __fmul_rn(*c3, __half2float(s0[1].y)); - - *c4 = __fmul_rn(*c4, __half2float(s1[0].x)); - *c5 = __fmul_rn(*c5, __half2float(s1[0].y)); - *c6 = __fmul_rn(*c6, __half2float(s1[1].x)); - *c7 = __fmul_rn(*c7, __half2float(s1[1].y)); -} - -} // namespace marlin_24 diff --git a/server/marlin/marlin_kernels/sparse/marlin_24_cuda_kernel.cu b/server/marlin/marlin_kernels/sparse/marlin_24_cuda_kernel.cu deleted file mode 100644 index b5effc30..00000000 --- a/server/marlin/marlin_kernels/sparse/marlin_24_cuda_kernel.cu +++ /dev/null @@ -1,1125 +0,0 @@ -/* - * Notice: This file was modified by Neuralmagic inc to include 8-bit support - * - * Copyright (C) 2024 Roberto Lopez Castro (roberto.lopez.castro@udc.es). All - * Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include - -#include -#include -#include -#include -#include - -#include - -#include "common/base.h" - -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 - -#else - - #include "common/mem.h" - #include "common/mma.h" - -#endif - -template -inline std::string str(T x) { - return std::to_string(x); -} - -namespace marlin_24 { - -// 8 warps are a good choice since every SM has 4 schedulers and having more -// than 1 warp per schedule allows some more latency hiding. At the same time, -// we want relatively few warps to have many registers per warp and small tiles. -static constexpr int THREADS = 256; -static constexpr int STAGES = 4; - -static constexpr int min_thread_n = 128; - -static constexpr int tile_size = 16; -static constexpr int max_par = 64; - -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 - -template shared - // fetch pipeline - const int group_blocks = -1 // number of consecutive 16x16 blocks - // with a separate quantization scale - > -__global__ void Marlin_24( - const int4* __restrict__ A, // fp16 input matrix of shape mxk - const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn - const int4* __restrict__ meta, // 2bit metadata information about 2:4 - // format on B - int4* __restrict__ C, // fp16 output buffer of shape mxn - const int4* __restrict__ s, // fp16 quantization scales of shape - // (k/groupsize)xn - int prob_m, // batch dimension m - int prob_n, // output dimension n - int prob_k, // reduction dimension k - int* locks // extra global storage for barrier synchronization -) {} - -torch::Tensor gptq_marlin_24_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, - torch::Tensor& b_meta, - torch::Tensor& b_scales, - torch::Tensor& workspace, int64_t num_bits, - int64_t size_m, int64_t size_n, - int64_t size_k) { - TORCH_CHECK_NOT_IMPLEMENTED( - false, "gptq_marlin_24_gemm(..) requires CUDA_ARCH >= 8.0"); - return torch::empty({1, 1}); -} - -#else - -template shared - // fetch pipeline - const int group_blocks = -1 // number of consecutive 16x16 blocks - // with a separate quantization scale - > -__global__ void Marlin_24( - const int4* __restrict__ A, // fp16 input matrix of shape mxk - const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn - const int4* __restrict__ meta, // 2bit metadata information about 2:4 - // format on B - int4* __restrict__ C, // fp16 output buffer of shape mxn - const int4* __restrict__ s, // fp16 quantization scales of shape - // (k/groupsize)xn - int prob_m, // batch dimension m - int prob_n, // output dimension n - int prob_k, // reduction dimension k - int* locks // extra global storage for barrier synchronization -) { - // Each threadblock processes one "stripe" of the B matrix with (roughly) the - // same size, which might involve multiple column "slices" (of width 16 * - // `thread_n_blocks`). Stripes are defined as shown in the 3x3 matrix 5 SM - // example: - // 0 1 3 - // 0 2 3 - // 1 2 4 - // While this kind of partitioning makes things somewhat more complicated, it - // ensures good utilization of all SMs for many kinds of shape and GPU - // configurations, while requiring as few slow global cross-threadblock - // reductions as possible. - - // For larger GEMMs we run multiple batchsize 64 versions in parallel for a - // better partitioning with less reductions - int parallel = 1; - if (prob_m > 16 * thread_m_blocks) { - parallel = prob_m / (16 * thread_m_blocks); - prob_m = 16 * thread_m_blocks; - } - - // number of thread_k_blocks in k-dim - int k_tiles = prob_k / 32 / thread_k_blocks; - // number of thread_n_blocks in n-dim - int n_tiles = prob_n / 16 / thread_n_blocks; - // iters needed to cover all slices - int iters = ceildiv(k_tiles * n_tiles * parallel, gridDim.x); - - // Ensure that the number of tiles in each stripe is a multiple of the - // groupsize; this avoids an annoying special case where a stripe starts in - // the middle of group. - if (group_blocks != -1) - iters = (group_blocks / thread_k_blocks) * - ceildiv(iters, (group_blocks / thread_k_blocks)); - - int slice_row = (iters * blockIdx.x) % k_tiles; - int slice_col_par = (iters * blockIdx.x) / k_tiles; - int slice_col = slice_col_par; - // number of threadblock tiles in the current slice - int slice_iters; - // total number of active threadblocks in the current slice - int slice_count = 0; - // index of threadblock in current slice; numbered bottom to top - int slice_idx; - - // We can easily implement parallel problem execution by just remapping - // indices and advancing global pointers - if (slice_col_par >= n_tiles) { - A += (slice_col_par / n_tiles) * 16 * thread_m_blocks * prob_k / 8; - C += (slice_col_par / n_tiles) * 16 * thread_m_blocks * prob_n / 8; - locks += (slice_col_par / n_tiles) * n_tiles; - slice_col = slice_col_par % n_tiles; - } - - // Compute all information about the current slice which is required for - // synchronization. - auto init_slice = [&]() { - slice_iters = - iters * (blockIdx.x + 1) - (k_tiles * slice_col_par + slice_row); - if (slice_iters < 0 || slice_col_par >= n_tiles * parallel) slice_iters = 0; - if (slice_iters == 0) return; - if (slice_row + slice_iters > k_tiles) slice_iters = k_tiles - slice_row; - slice_count = 1; - slice_idx = 0; - int col_first = iters * ceildiv(k_tiles * slice_col_par, iters); - if (col_first <= k_tiles * (slice_col_par + 1)) { - int col_off = col_first - k_tiles * slice_col_par; - slice_count = ceildiv(k_tiles - col_off, iters); - if (col_off > 0) slice_count++; - int delta_first = iters * blockIdx.x - col_first; - if (delta_first < 0 || (col_off == 0 && delta_first == 0)) - slice_idx = slice_count - 1; - else { - slice_idx = slice_count - 1 - delta_first / iters; - if (col_off > 0) slice_idx--; - } - } - if (slice_col == n_tiles) { - A += 16 * thread_m_blocks * prob_k / 8; - C += 16 * thread_m_blocks * prob_n / 8; - locks += n_tiles; - slice_col = 0; - } - }; - init_slice(); - - // RLC: 8 is vec_size -> 128-bit instructions, 8 fp16 elements - int a_gl_stride = prob_k / 8; // stride of the A matrix in global memory - - // stride of an A matrix tile in shared memory - constexpr int a_sh_stride = 32 * thread_k_blocks / 8; - // delta between subsequent A tiles in global memory - constexpr int a_gl_rd_delta_o = 32 * thread_k_blocks / 8; - // between subsequent accesses within a tile - int a_gl_rd_delta_i = a_gl_stride * (threads / a_gl_rd_delta_o); - // between shared memory writes - constexpr int a_sh_wr_delta = a_sh_stride * (threads / a_gl_rd_delta_o); - // between shared memory tile reads //RLC: 2 * #warps k-dim - constexpr int a_sh_rd_delta_o = 4 * ((threads / 32) / (thread_n_blocks / 4)); - // within a shared memory tile - constexpr int a_sh_rd_delta_i = a_sh_stride * 16; - // overall size of a tile - constexpr int a_sh_stage = a_sh_stride * (16 * thread_m_blocks); - // number of shared write iterations for a tile - constexpr int a_sh_wr_iters = ceildiv(a_sh_stage, a_sh_wr_delta); - - constexpr int pack_factor = 32 / num_bits; - - int b_gl_stride = 16 * prob_n / (pack_factor * 4); - constexpr int b_sh_stride = ((thread_n_blocks * 16) * 16 / pack_factor) / 4; - constexpr int b_thread_vecs = num_bits == 4 ? 1 : 2; - constexpr int b_sh_stride_threads = b_sh_stride / b_thread_vecs; - int b_gl_rd_delta_o = b_gl_stride * thread_k_blocks; - int b_gl_rd_delta_i = b_gl_stride * (threads / b_sh_stride_threads); - constexpr int b_sh_wr_delta = threads * b_thread_vecs; - constexpr int b_sh_rd_delta = threads * b_thread_vecs; - constexpr int b_sh_stage = b_sh_stride * thread_k_blocks; - constexpr int b_sh_wr_iters = b_sh_stage / b_sh_wr_delta; - - int m_gl_stride = 2 * prob_n / 8; // (16*2*4 / 8) = 16 - constexpr int m_sh_stride = - (16 * thread_n_blocks) / 4; // #warps n-dim * threads/warp - int m_gl_rd_delta_o = m_gl_stride * thread_k_blocks; - int m_gl_rd_delta_i = m_gl_stride * (threads / m_sh_stride); - constexpr int m_sh_wr_delta = threads / 2; - constexpr int m_sh_rd_delta = threads / 2; - constexpr int m_sh_stage = m_sh_stride * thread_k_blocks; - constexpr int m_sh_iters = ceildiv(m_sh_stage, m_sh_wr_delta); - - int s_gl_stride = prob_n / 8; - constexpr int s_sh_stride = 16 * thread_n_blocks / 8; - constexpr int s_sh_stage = s_sh_stride; - int s_gl_rd_delta = s_gl_stride; - - // Global A read index of current thread. - int a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) + - (threadIdx.x % a_gl_rd_delta_o); - a_gl_rd += a_gl_rd_delta_o * slice_row; - // Shared write index of current thread. - int a_sh_wr = a_sh_stride * (threadIdx.x / a_gl_rd_delta_o) + - (threadIdx.x % a_gl_rd_delta_o); - // Shared read index. - int a_sh_rd = - a_sh_stride * ((threadIdx.x % 32) % 16) + (threadIdx.x % 32) / 16; - a_sh_rd += 4 * ((threadIdx.x / 32) / (thread_n_blocks / 4)); - - int b_gl_rd = b_gl_stride * (threadIdx.x / b_sh_stride_threads) + - (threadIdx.x % b_sh_stride_threads) * b_thread_vecs; - b_gl_rd += b_sh_stride * slice_col; - b_gl_rd += b_gl_rd_delta_o * slice_row; - int b_sh_wr = threadIdx.x * b_thread_vecs; - int b_sh_rd = threadIdx.x * b_thread_vecs; - - int m_gl_rd = m_gl_stride * (threadIdx.x / (m_sh_stride)) + - (threadIdx.x % (m_sh_stride)); - m_gl_rd += (m_sh_stride)*slice_col; - m_gl_rd += m_gl_rd_delta_o * slice_row; - int m_sh_wr = threadIdx.x; - int m_sh_rd = threadIdx.x % 16 + (threadIdx.x / 32) * 16; - - int s_gl_rd; - if constexpr (group_blocks == -1) { - s_gl_rd = s_sh_stride * slice_col + threadIdx.x; - } else { - s_gl_rd = s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) + - s_sh_stride * slice_col + threadIdx.x; - } - - int s_sh_wr = threadIdx.x; - int s_sh_rd; - // We use a different scale layout for grouped and column-wise quantization as - // we scale a `half2` tile in column-major layout in the former and in - // row-major in the latter case. - if (group_blocks != -1) { - s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + - (threadIdx.x % 32) / 4; - } else { - s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + - (threadIdx.x % 32) / 4; - } - - // Precompute which thread should not read memory in which iterations; this is - // needed if there are more threads than required for a certain tilesize or - // when the batchsize is not a multiple of 16. - bool a_sh_wr_pred[a_sh_wr_iters]; - #pragma unroll - for (int i = 0; i < a_sh_wr_iters; i++) { - a_sh_wr_pred[i] = a_sh_wr_delta * i + a_sh_wr < a_sh_stride * prob_m; - } - bool s_sh_wr_pred = threadIdx.x < s_sh_stride; - - // To ensure that writing and reading A tiles to/from shared memory, the - // latter in fragment format, is fully bank conflict free, we need to use a - // rather fancy XOR-based layout. The key here is that neither reads nor - // writes of the 16-byte `int4` blocks of 8 consecutive threads involve the - // same shared memory banks. Further, it seems (based on NSight-Compute) that - // each warp must also write a consecutive memory segment? - auto transform_a = [&](int i) { - int row = i / a_gl_rd_delta_o; - return a_gl_rd_delta_o * row + (i % a_gl_rd_delta_o) ^ row; - }; - // Since the computation of this remapping is non-trivial and, due to our main - // loop unrolls, all shared memory accesses are static, we simply precompute - // both transformed reads and writes. - int a_sh_wr_trans[a_sh_wr_iters]; - #pragma unroll - for (int i = 0; i < a_sh_wr_iters; i++) - a_sh_wr_trans[i] = transform_a(a_sh_wr_delta * i + a_sh_wr); - int a_sh_rd_trans[2][b_sh_wr_iters][thread_m_blocks]; - #pragma unroll - for (int i = 0; i < b_sh_wr_iters; i++) { - #pragma unroll - for (int j = 0; j < thread_m_blocks; j++) { - a_sh_rd_trans[0][i][j] = - transform_a(a_sh_rd_delta_o * i + a_sh_rd_delta_i * j + a_sh_rd); - a_sh_rd_trans[1][i][j] = - transform_a(a_sh_rd_delta_o * i + a_sh_rd_delta_i * j + a_sh_rd + 2); - } - } - - // Since B-accesses have non-constant stride they have to be computed at - // runtime; we break dependencies between subsequent accesses with a tile by - // maintining multiple pointers (we have enough registers), a tiny - // optimization. - const int4* B_ptr[b_sh_wr_iters]; - #pragma unroll - for (int i = 0; i < b_sh_wr_iters; i++) - B_ptr[i] = B + b_gl_rd_delta_i * i + b_gl_rd; - - bool m_sh_wr_pred = threadIdx.x < m_sh_wr_delta; - const int4* meta_ptr[m_sh_iters]; - #pragma unroll - for (int i = 0; i < m_sh_iters; i++) - meta_ptr[i] = meta + m_gl_rd_delta_i * i + m_gl_rd; - - extern __shared__ int4 sh[]; - // Shared memory storage for global fetch pipelines. - int4* sh_a = sh; - int4* sh_b = sh_a + (stages * a_sh_stage); - int4* sh_s = sh_b + (stages * b_sh_stage); - int4* sh_m = sh_s + (stages * s_sh_stage); - // Register storage for double buffer of shared memory reads. - FragA frag_a[2][thread_m_blocks][2]; - I4 frag_b_quant[2][b_thread_vecs]; - FragM frag_m[2][2]; - FragC frag_c[thread_m_blocks][4][2]; - FragS frag_s[2][4]; - - // Zero accumulators. - auto zero_accums = [&]() { - #pragma unroll - for (int i = 0; i < thread_m_blocks * 4 * 2 * 4; i++) - reinterpret_cast(frag_c)[i] = 0; - }; - - // Asynchronously fetch the next A, B and s tile from global to the next - // shared memory pipeline location. - auto fetch_to_shared = [&](int pipe, int a_off, bool pred = true) { - if (pred) { - int4* sh_a_stage = sh_a + a_sh_stage * pipe; - #pragma unroll - for (int i = 0; i < a_sh_wr_iters; i++) { - cp_async4_pred( - &sh_a_stage[a_sh_wr_trans[i]], - &A[a_gl_rd_delta_i * i + a_gl_rd + a_gl_rd_delta_o * a_off], - a_sh_wr_pred[i]); - } - int4* sh_b_stage = sh_b + b_sh_stage * pipe; - #pragma unroll - for (int i = 0; i < b_sh_wr_iters; i++) { - #pragma unroll - for (int j = 0; j < b_thread_vecs; j++) { - cp_async4(&sh_b_stage[b_sh_wr_delta * i + b_sh_wr + j], B_ptr[i] + j); - } - B_ptr[i] += b_gl_rd_delta_o; - } - int4* sh_meta_stage = sh_m + m_sh_stage * pipe; - #pragma unroll - for (int i = 0; i < m_sh_iters; i++) { - if (m_sh_wr_pred) - cp_async4(&sh_meta_stage[m_sh_wr_delta * i + m_sh_wr], meta_ptr[i]); - meta_ptr[i] += m_gl_rd_delta_o; - } - // Only fetch scales if this tile starts a new group - if (group_blocks != -1 && pipe % (group_blocks / thread_k_blocks) == 0) { - int4* sh_s_stage = sh_s + s_sh_stage * pipe; - if (s_sh_wr_pred) cp_async4(&sh_s_stage[s_sh_wr], &s[s_gl_rd]); - s_gl_rd += s_gl_rd_delta; - } - } - // Insert a fence even when we are winding down the pipeline to ensure that - // waiting is also correct at this point. - cp_async_fence(); - }; - - // Wait until the next thread tile has been loaded to shared memory. - auto wait_for_stage = [&]() { - // We only have `stages - 2` active fetches since we are double buffering - // and can only issue the next fetch when it is guaranteed that the previous - // shared memory load is fully complete (as it may otherwise be - // overwritten). - cp_async_wait(); - __syncthreads(); - }; - - // Load the next sub-tile from the current location in the shared memory pipe - // into the current register buffer. - auto fetch_to_registers = [&](int k, int pipe) { - // It may seem inefficient that we reload the groups for every sub-tile; - // however, this does not seem to be a significant bottleneck, while some - // theoretically better attempts have lead to bad instruction ordering by - // the compiler and correspondingly a noticeable drop in performance. - if (group_blocks != -1) { - int4* sh_s_stage = - sh_s + s_sh_stage * ((group_blocks / thread_k_blocks) * - (pipe / (group_blocks / thread_k_blocks))); - reinterpret_cast(&frag_s[k % 2])[0] = sh_s_stage[s_sh_rd]; - } - int4* sh_a_stage = sh_a + a_sh_stage * pipe; - #pragma unroll - for (int i = 0; i < thread_m_blocks; i++) { - ldsm4(frag_a[k % 2][i][0], - &sh_a_stage[a_sh_rd_trans[0][k % b_sh_wr_iters][i]]); - ldsm4(frag_a[k % 2][i][1], - &sh_a_stage[a_sh_rd_trans[1][k % b_sh_wr_iters][i]]); - } - - int4* sh_b_stage = sh_b + b_sh_stage * pipe; - #pragma unroll - for (int i = 0; i < b_thread_vecs; i++) { - frag_b_quant[k % 2][i] = *reinterpret_cast( - &sh_b_stage[b_sh_rd_delta * (k % b_sh_wr_iters) + b_sh_rd + i]); - } - - // Load meta with ldsm4 - int4* sh_m_stage = sh_m + m_sh_stage * pipe; - ldsm4_m(frag_m[k % 2][0], - &sh_m_stage[m_sh_rd_delta * (k % m_sh_iters) + m_sh_rd]); - }; - - // Execute the actual tensor core matmul of a sub-tile. - auto matmul = [&](int k) { - // We have the m dimension as the inner loop in order to encourage overlapping - // dequantization and matmul operations. - #pragma unroll - for (int j = 0; j < 4; j++) { - FragB frag_b0; - FragB frag_b1; - - if constexpr (num_bits == 4) { - int b_quant = frag_b_quant[k % 2][0][j]; - int b_quant_shift = b_quant >> 8; - - frag_b0 = dequant_4bit(b_quant); - frag_b1 = dequant_4bit(b_quant_shift); - - } else { - int* frag_b_quant_ptr = reinterpret_cast(frag_b_quant[k % 2]); - int b_quant_0 = frag_b_quant_ptr[j * 2 + 0]; - int b_quant_1 = frag_b_quant_ptr[j * 2 + 1]; - - frag_b0 = dequant_8bit(b_quant_0); - frag_b1 = dequant_8bit(b_quant_1); - } - - // If there are no groups, we can just scale the final output once and can - // avoid doing so for each weight. - if constexpr (group_blocks != -1) { - scale(frag_b0, frag_s[k % 2][j], 0); - } - if constexpr (group_blocks != -1) { - scale(frag_b1, frag_s[k % 2][j], 1); - } - - #pragma unroll - for (int i = 0; i < thread_m_blocks; i++) { - mma_sp(frag_b0, frag_b1, frag_a[k % 2][i][0], frag_c[i][j][0], - frag_m[k % 2][j / 2], j % 2); - } - } - }; - - // Since we slice across the k dimension of a tile in order to increase the - // number of warps while keeping the n dimension of a tile reasonable, we have - // multiple warps that accumulate their partial sums of the same output - // location; which we have to reduce over in the end. We do in shared memory. - auto thread_block_reduce = [&]() { - constexpr int red_off = threads / b_sh_stride_threads / 2; - if (red_off >= 1) { - int red_idx = threadIdx.x / b_sh_stride_threads; - constexpr int red_sh_stride = b_sh_stride_threads * 4 * 2; - constexpr int red_sh_delta = b_sh_stride_threads; - int red_sh_rd = red_sh_stride * (threadIdx.x / b_sh_stride_threads) + - (threadIdx.x % b_sh_stride_threads); - - // Parallel logarithmic shared memory reduction. We make sure to avoid any - // unnecessary read or write iterations, e.g., for two warps we write only - // once by warp 1 and read only once by warp 0. - #pragma unroll - for (int m_block = 0; m_block < thread_m_blocks; m_block++) { - #pragma unroll - for (int i = red_off; i > 0; i /= 2) { - if (i <= red_idx && red_idx < 2 * i) { - #pragma unroll - for (int j = 0; j < 4 * 2; j++) { - int red_sh_wr = - red_sh_delta * j + (red_sh_rd - red_sh_stride * i); - if (i < red_off) { - float* c_rd = - reinterpret_cast(&sh[red_sh_delta * j + red_sh_rd]); - float* c_wr = reinterpret_cast(&sh[red_sh_wr]); - #pragma unroll - for (int k = 0; k < 4; k++) - reinterpret_cast(frag_c)[4 * 2 * m_block + j][k] += - c_rd[k] + c_wr[k]; - } - sh[red_sh_wr] = - reinterpret_cast(&frag_c)[4 * 2 * m_block + j]; - } - } - __syncthreads(); - } - if (red_idx == 0) { - #pragma unroll - for (int i = 0; i < 4 * 2; i++) { - float* c_rd = - reinterpret_cast(&sh[red_sh_delta * i + red_sh_rd]); - #pragma unroll - for (int j = 0; j < 4; j++) - reinterpret_cast(frag_c)[4 * 2 * m_block + i][j] += - c_rd[j]; - } - } - __syncthreads(); - } - } - }; - - // Since multiple threadblocks may process parts of the same column slice, we - // finally have to globally reduce over the results. As the striped - // partitioning minimizes the number of such reductions and our outputs are - // usually rather small, we perform this reduction serially in L2 cache. - auto global_reduce = [&](bool first = false, bool last = false) { - // We are very careful here to reduce directly in the output buffer to - // maximize L2 cache utilization in this step. To do this, we write out - // results in FP16 (but still reduce with FP32 compute). - constexpr int active_threads = 32 * thread_n_blocks / 4; - if (threadIdx.x < active_threads) { - int c_gl_stride = prob_n / 8; - int c_gl_wr_delta_o = 2 * 4 * c_gl_stride; - int c_gl_wr_delta_i = - c_gl_stride; // 8 threads (e.g., 0,4,8,12,16,20,24,28) - int c_gl_wr = 2 * c_gl_stride * (threadIdx.x % 4) + - 8 * (threadIdx.x / 32) + (threadIdx.x % 32) / 4; - c_gl_wr += (2 * thread_n_blocks) * slice_col; - constexpr int c_sh_wr_delta = active_threads; - int c_sh_wr = threadIdx.x; - - int col = 2 * ((threadIdx.x % 32) % 4); - - if (!first) { - // Interestingly, doing direct global accesses here really seems to mess up - // the compiler and lead to slowdowns, hence we also use async-copies even - // though these fetches are not actually asynchronous. - #pragma unroll - for (int i = 0; i < thread_m_blocks * 4; i++) { - cp_async4_pred(&sh[c_sh_wr + c_sh_wr_delta * i], - &C[c_gl_wr + c_gl_wr_delta_o * (i / 2) + - c_gl_wr_delta_i * (i % 2)], - i < (thread_m_blocks - 1) * 4 || - 8 * (i / 2) + col + (i % 2) < prob_m); - } - cp_async_fence(); - cp_async_wait<0>(); - } - - #pragma unroll - for (int i = 0; i < thread_m_blocks * 4; i++) { - if (i < (thread_m_blocks - 1) * 4 || - 8 * (i / 2) + col + (i % 2) < prob_m) { - if (!first) { - int4 c_red = sh[c_sh_wr + i * c_sh_wr_delta]; - #pragma unroll - for (int j2 = 0; j2 < 2; j2++) { - #pragma unroll - for (int j1 = 0; j1 < 4; j1++) { - reinterpret_cast( - &frag_c)[4 * 2 * 4 * (i / 4) + 8 * j1 + 2 * j2 + - 4 * ((i % 4) / 2) + i % 2] += - __half2float( - reinterpret_cast<__half*>(&c_red)[(j2 * 4 + j1)]); - } - } - } - if (!last) { - int4 c; - #pragma unroll - for (int j2 = 0; j2 < 2; j2++) { - #pragma unroll - for (int j1 = 0; j1 < 4; j1++) { - reinterpret_cast<__half*>(&c)[(j2 * 4 + j1)] = - __float2half(reinterpret_cast( - &frag_c)[4 * 2 * 4 * (i / 4) + 8 * j1 + 2 * j2 + - 4 * ((i % 4) / 2) + i % 2]); - } - } - C[c_gl_wr + c_gl_wr_delta_o * (i / 2) + c_gl_wr_delta_i * (i % 2)] = - c; - } - } - } - } - }; - - // Write out the reduce final result in the correct layout. We only actually - // reshuffle matrix fragments in this step, the reduction above is performed - // in fragment layout. - auto write_result = [&]() { - int c_gl_stride = prob_n / 8; - - constexpr int c_sh_stride = 2 * thread_n_blocks; // RLC: - constexpr int c_sh_stride_2 = 2 * c_sh_stride + 2; // RLC: - constexpr int c_sh_stride_3 = 2 * (2 * thread_n_blocks) + 2; // RLC: - - int c_gl_wr_delta = c_gl_stride * (threads / (2 * thread_n_blocks)); - - int c_gl_wr = c_gl_stride * (threadIdx.x / (2 * thread_n_blocks)) + - (threadIdx.x % (2 * thread_n_blocks)); - c_gl_wr += (2 * thread_n_blocks) * slice_col; - - int c_sh_wr = c_sh_stride_2 * ((threadIdx.x % 32) % 4) + - ((threadIdx.x % 32) / 4); // RLC: - c_sh_wr += 8 * (threadIdx.x / 32); // 128/4(half4) - - constexpr int c_sh_rd_delta = - c_sh_stride_3 * (threads / (2 * 2 * thread_n_blocks)); // RLC: - int c_sh_rd = c_sh_stride_3 * (threadIdx.x / (2 * 2 * thread_n_blocks)) + - (threadIdx.x % (2 * 2 * thread_n_blocks)); - - int c_gl_wr_end = c_gl_stride * prob_m; - - auto write = [&](int idx, float c0, float c1, float c2, float c3, FragS& s0, - float c4, float c5, float c6, float c7, FragS& s1) { - uint2 res[2]; - res[0] = to_half4(c0, c1, c2, c3); - res[1] = to_half4(c4, c5, c6, c7); - half2* tmp = (half2*)&res; - // for per-column quantization we finally apply the scale here - if constexpr (group_blocks == -1 && num_bits == 4) { - tmp[0] = __hmul2(tmp[0], s0[0]); - tmp[1] = __hmul2(tmp[1], s0[1]); - tmp[2] = __hmul2(tmp[2], s1[0]); - tmp[3] = __hmul2(tmp[3], s1[1]); - } - ((int4*)sh)[idx] = *((int4*)&res[0]); - }; - - // RLC: only warp 0 and 1 baseline example - if (threadIdx.x / 32 < thread_n_blocks / 4) { - #pragma unroll - for (int i = 0; i < thread_m_blocks; i++) { - int wr = c_sh_wr; - write(wr, frag_c[i][0][0][0], frag_c[i][1][0][0], frag_c[i][2][0][0], - frag_c[i][3][0][0], frag_s[0][0], frag_c[i][0][0][2], - frag_c[i][1][0][2], frag_c[i][2][0][2], frag_c[i][3][0][2], - frag_s[0][2]); - write(wr + c_sh_stride, frag_c[i][0][0][1], frag_c[i][1][0][1], - frag_c[i][2][0][1], frag_c[i][3][0][1], frag_s[0][0], - frag_c[i][0][0][3], frag_c[i][1][0][3], frag_c[i][2][0][3], - frag_c[i][3][0][3], frag_s[0][2]); - write(wr + 4 * c_sh_stride_2, frag_c[i][0][1][0], frag_c[i][1][1][0], - frag_c[i][2][1][0], frag_c[i][3][1][0], frag_s[0][0], - frag_c[i][0][1][2], frag_c[i][1][1][2], frag_c[i][2][1][2], - frag_c[i][3][1][2], frag_s[0][2]); - write(wr + 4 * c_sh_stride_2 + c_sh_stride, frag_c[i][0][1][1], - frag_c[i][1][1][1], frag_c[i][2][1][1], frag_c[i][3][1][1], - frag_s[0][0], frag_c[i][0][1][3], frag_c[i][1][1][3], - frag_c[i][2][1][3], frag_c[i][3][1][3], frag_s[0][2]); - - c_sh_wr += 8 * c_sh_stride_2; - } - } - __syncthreads(); - - #pragma unroll - for (int i = 0; - i < ceildiv(16 * thread_m_blocks, threads / (2 * thread_n_blocks)); - i++) { - if (c_gl_wr < c_gl_wr_end) { - C[c_gl_wr] = sh[c_sh_rd]; - c_gl_wr += c_gl_wr_delta; - c_sh_rd += c_sh_rd_delta; - } - } - }; - - // Start global fetch and register load pipelines. - auto start_pipes = [&]() { - #pragma unroll - for (int i = 0; i < stages - 1; i++) fetch_to_shared(i, i, i < slice_iters); - zero_accums(); - wait_for_stage(); - fetch_to_registers(0, 0); - a_gl_rd += a_gl_rd_delta_o * (stages - 1); - }; - start_pipes(); - - // Main loop. - while (slice_iters) { - // We unroll over both the global fetch and the register load pipeline to - // ensure all shared memory accesses are static. Note that both pipelines have - // even length meaning that the next iteration will always start at index 0. - #pragma unroll - for (int pipe = 0; pipe < stages;) { - fetch_to_shared((pipe + stages - 1) % stages, pipe, - slice_iters >= stages); - matmul(pipe); - wait_for_stage(); - - fetch_to_registers(pipe + 1, (pipe + 1) % stages); - - pipe++; - slice_iters--; - if (slice_iters == 0) break; - } - a_gl_rd += a_gl_rd_delta_o * stages; - - // Process results and, if necessary, proceed to the next column slice. - // While this pattern may not be the most readable, other ways of writing - // the loop seemed to noticeably worse performance after compilation. - if (slice_iters == 0) { - cp_async_wait<0>(); - bool last = slice_idx == slice_count - 1; - // For per-column scales, we only fetch them here in the final step before - // write-out - if constexpr (group_blocks == -1) { - if constexpr (num_bits == 8) { - if (s_sh_wr_pred) cp_async4(&sh_s[s_sh_wr], &s[s_gl_rd]); - cp_async_fence(); - } else { - if (last) { - if (s_sh_wr_pred) cp_async4(&sh_s[s_sh_wr], &s[s_gl_rd]); - cp_async_fence(); - } - } - } - thread_block_reduce(); - - if constexpr (group_blocks == -1) { - if constexpr (num_bits == 8) { - cp_async_wait<0>(); - __syncthreads(); - if (threadIdx.x / 32 < thread_n_blocks / 4) { - *(float4*)(frag_s) = *(float4*)(&sh_s[s_sh_rd]); - } - } else { - if (last) { - cp_async_wait<0>(); - __syncthreads(); - if (threadIdx.x / 32 < thread_n_blocks / 4) { - *(float4*)(frag_s) = *(float4*)(&sh_s[s_sh_rd]); - } - } - } - } - - // For 8-bit channelwise, we apply the scale before the global reduction - // that converts the fp32 results to fp16 (so that we avoid possible - // overflow in fp16) - if constexpr (group_blocks == -1 && num_bits == 8) { - if (threadIdx.x / 32 < thread_n_blocks / 4) { - #pragma unroll - for (int i = 0; i < thread_m_blocks; i++) { - scale_floats(&frag_c[i][0][0][0], &frag_c[i][1][0][0], - &frag_c[i][2][0][0], &frag_c[i][3][0][0], frag_s[0][0], - &frag_c[i][0][0][2], &frag_c[i][1][0][2], - &frag_c[i][2][0][2], &frag_c[i][3][0][2], - frag_s[0][2]); - - scale_floats(&frag_c[i][0][0][1], &frag_c[i][1][0][1], - &frag_c[i][2][0][1], &frag_c[i][3][0][1], frag_s[0][0], - &frag_c[i][0][0][3], &frag_c[i][1][0][3], - &frag_c[i][2][0][3], &frag_c[i][3][0][3], - frag_s[0][2]); - - scale_floats(&frag_c[i][0][1][0], &frag_c[i][1][1][0], - &frag_c[i][2][1][0], &frag_c[i][3][1][0], frag_s[0][0], - &frag_c[i][0][1][2], &frag_c[i][1][1][2], - &frag_c[i][2][1][2], &frag_c[i][3][1][2], - frag_s[0][2]); - - scale_floats(&frag_c[i][0][1][1], &frag_c[i][1][1][1], - &frag_c[i][2][1][1], &frag_c[i][3][1][1], frag_s[0][0], - &frag_c[i][0][1][3], &frag_c[i][1][1][3], - &frag_c[i][2][1][3], &frag_c[i][3][1][3], - frag_s[0][2]); - } - } - } - - if (slice_count > 1) { // only globally reduce if there is more than one - // block in a slice - barrier_acquire(&locks[slice_col], slice_idx); - global_reduce(slice_idx == 0, last); - barrier_release(&locks[slice_col], last); - } - if (last) // only the last block in a slice actually writes the result - write_result(); - - slice_row = 0; - slice_col_par++; - slice_col++; - init_slice(); - if (slice_iters) { - a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) + - (threadIdx.x % a_gl_rd_delta_o); - #pragma unroll - for (int i = 0; i < b_sh_wr_iters; i++) - B_ptr[i] += b_sh_stride - b_gl_rd_delta_o * k_tiles; - #pragma unroll - for (int i = 0; i < m_sh_iters; i++) - meta_ptr[i] += (m_sh_stride)-m_gl_rd_delta_o * k_tiles; - if (slice_col == 0) { - #pragma unroll - for (int i = 0; i < b_sh_wr_iters; i++) B_ptr[i] -= b_gl_stride; - #pragma unroll - for (int i = 0; i < m_sh_iters; i++) meta_ptr[i] -= m_gl_stride; - } - s_gl_rd = s_sh_stride * slice_col + threadIdx.x; - start_pipes(); - } - } - } -} - -#endif - -#define CALL_IF_2_4(NUM_BITS, THREAD_M_BLOCKS, THREAD_N_BLOCKS, \ - THREAD_K_BLOCKS, GROUP_BLOCKS) \ - else if (num_bits == NUM_BITS && thread_m_blocks == THREAD_M_BLOCKS && \ - thread_n_blocks == THREAD_N_BLOCKS && \ - thread_k_blocks == THREAD_K_BLOCKS && \ - group_blocks == GROUP_BLOCKS) { \ - cudaFuncSetAttribute( \ - Marlin_24, \ - cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \ - Marlin_24 \ - <<>>(A_ptr, B_ptr, meta_ptr, \ - C_ptr, s_ptr, prob_n, \ - prob_m, prob_k, locks); \ - } - -void marlin_cuda_2_4(const void* A, const void* B, const void* meta, void* C, - void* s, int prob_m, int prob_n, int prob_k, - void* workspace, int num_bits, int groupsize = -1, - int dev = 0, cudaStream_t stream = 0, int thread_k = -1, - int thread_m = -1, int sms = -1, int max_par = 16) { - int tot_n = prob_n; - int tot_n_blocks = ceildiv(tot_n, 16); - int pad = 16 * tot_n_blocks - tot_n; - - if (sms == -1) { - cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, dev); - } - TORCH_CHECK(sms > 0); - - int max_shared_mem = 0; - cudaDeviceGetAttribute(&max_shared_mem, - cudaDevAttrMaxSharedMemoryPerBlockOptin, dev); - TORCH_CHECK(max_shared_mem > 0); - - if (thread_k == -1 || thread_m == -1) { - if (prob_n <= 16) { - // For small batchizes, better partitioningif is slightly more important - // than better compute utilization - thread_k = 128; - thread_m = 128; - } else if (prob_n <= 256) { - thread_k = 64; - thread_m = 256; - } else { - thread_k = 32; - thread_m = 512; - } - } - - int thread_k_blocks = thread_k / 32; // 2:4 version with m16n8k32 instruction - int thread_m_blocks = thread_m / 16; - int group_blocks = (groupsize == -1) ? -1 : groupsize / 16; - int blocks = sms; - - TORCH_CHECK(prob_m % thread_m == 0, "prob_m = ", prob_m, - " is not divisible by thread_m = ", thread_m); - TORCH_CHECK(prob_k % thread_k == 0, "prob_k = ", prob_k, - " is not divisible by thread_k = ", thread_k); - if (group_blocks != -1) { - TORCH_CHECK((prob_k / 2) % group_blocks == 0, "prob_k/2 = ", prob_k / 2, - " is not divisible by group_blocks = ", group_blocks); - } - - TORCH_CHECK(prob_m > 0 && prob_n > 0 && prob_k > 0, "Invalid MNK = [", prob_m, - ", ", prob_n, ", ", prob_k, "]"); - - const int4* A_ptr = (const int4*)A; - const int4* B_ptr = (const int4*)B; - const int4* meta_ptr = (const int4*)meta; - int4* C_ptr = (int4*)C; - const int4* s_ptr = (const int4*)s; - - constexpr int max_m_blocks = 4; - - int* locks = (int*)workspace; - for (int i = 0; i < tot_n_blocks; i += max_m_blocks) { - int thread_n_blocks = tot_n_blocks - i; - prob_n = tot_n - 16 * i; - int par = 1; - if (thread_n_blocks > max_m_blocks) { - // Note that parallel > 1 currently only works for inputs without any - // padding - par = (16 * thread_n_blocks - pad) / (max_m_blocks * 16); - if (par > max_par) par = max_par; - prob_n = (max_m_blocks * 16) * par; - i += max_m_blocks * (par - 1); - thread_n_blocks = max_m_blocks; - } - - // For compilation speed, we only define the kernel configurations that have - // seemed useful (in terms of performance) in our testing, however many more - // are, in principle, possible. - - // the false is start of the CALL_IF macros - if (false) { - } // BMxBNxBK, group - // 4-bit - CALL_IF_2_4(4, 8, 1, 4, -1) // e.g., 16x128x128 - CALL_IF_2_4(4, 8, 1, 4, 4) // e.g., 16x128x128, 64 - - CALL_IF_2_4(4, 16, 1, 2, -1) // e.g., 16x256x64 - CALL_IF_2_4(4, 16, 1, 2, 4) // e.g., 16x256x64, 64 - CALL_IF_2_4(4, 16, 2, 2, -1) // e.g.. 32x256x64 - CALL_IF_2_4(4, 16, 2, 2, 4) - CALL_IF_2_4(4, 16, 3, 2, -1) - CALL_IF_2_4(4, 16, 3, 2, 4) - CALL_IF_2_4(4, 16, 4, 2, -1) - CALL_IF_2_4(4, 16, 4, 2, 4) - - CALL_IF_2_4(4, 32, 1, 1, -1) // e.g., 16x256x64 - CALL_IF_2_4(4, 32, 1, 1, 4) // e.g., 16x256x64, 64 - CALL_IF_2_4(4, 32, 2, 1, -1) // e.g.. 32x256x64 - CALL_IF_2_4(4, 32, 2, 1, 4) - CALL_IF_2_4(4, 32, 3, 1, -1) - CALL_IF_2_4(4, 32, 3, 1, 4) - CALL_IF_2_4(4, 32, 4, 1, -1) - CALL_IF_2_4(4, 32, 4, 1, 4) - - // 8-bit - CALL_IF_2_4(8, 8, 1, 4, -1) // e.g., 16x128x128 - CALL_IF_2_4(8, 8, 1, 4, 4) // e.g., 16x128x128, 64 - - CALL_IF_2_4(8, 16, 1, 2, -1) // e.g., 16x256x64 - CALL_IF_2_4(8, 16, 1, 2, 4) // e.g., 16x256x64, 64 - CALL_IF_2_4(8, 16, 2, 2, -1) // e.g.. 32x256x64 - CALL_IF_2_4(8, 16, 2, 2, 4) - CALL_IF_2_4(8, 16, 3, 2, -1) - CALL_IF_2_4(8, 16, 3, 2, 4) - CALL_IF_2_4(8, 16, 4, 2, -1) - CALL_IF_2_4(8, 16, 4, 2, 4) - - CALL_IF_2_4(8, 32, 1, 1, -1) // e.g., 16x256x64 - CALL_IF_2_4(8, 32, 1, 1, 4) // e.g., 16x256x64, 64 - CALL_IF_2_4(8, 32, 2, 1, -1) // e.g.. 32x256x64 - CALL_IF_2_4(8, 32, 2, 1, 4) - CALL_IF_2_4(8, 32, 3, 1, -1) - CALL_IF_2_4(8, 32, 3, 1, 4) - CALL_IF_2_4(8, 32, 4, 1, -1) - CALL_IF_2_4(8, 32, 4, 1, 4) - else { - throw std::runtime_error("Unsupported shapes: MKN = [" + str(prob_m) + - ", " + str(prob_k) + ", " + str(prob_n) + "]" + - ", groupsize = " + str(groupsize) + - ", thread_m_blocks = " + str(thread_m_blocks) + - ", thread_n_blocks = " + str(thread_n_blocks) + - ", thread_k_blocks = " + str(thread_k_blocks)); - } - - A_ptr += 16 * thread_n_blocks * (prob_k / 8) * par; - C_ptr += 16 * thread_n_blocks * (prob_m / 8) * par; - } -} - -} // namespace marlin_24 - -torch::Tensor gptq_marlin_24_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, - torch::Tensor& b_meta, - torch::Tensor& b_scales, - torch::Tensor& workspace, int64_t num_bits, - int64_t size_m, int64_t size_n, - int64_t size_k) { - // Verify num_bits - TORCH_CHECK(num_bits == 4 || num_bits == 8, - "num_bits must be 4 or 8. Got = ", num_bits); - int pack_factor = 32 / num_bits; - - // Verify M - TORCH_CHECK(size_m == a.size(0), - "Shape mismatch: a.size(0) = " + str(a.size(0)) + - ", size_m = " + str(size_m)); - - // Verify K - TORCH_CHECK(size_k == a.size(1), - "Shape mismatch: a.size(1) = " + str(a.size(1)) + - ", size_k = " + str(size_k)); - TORCH_CHECK(size_k % marlin_24::tile_size == 0, - "size_k = " + str(size_k) + " is not divisible by tile_size = " + - str(marlin_24::tile_size)); - TORCH_CHECK((size_k / marlin_24::tile_size / 2) == b_q_weight.size(0), - "Shape mismatch: b_q_weight.size(0) = " + - str(b_q_weight.size(0)) + ", size_k = " + str(size_k) + - ", tile_size = " + str(marlin_24::tile_size)); - - // Verify N - TORCH_CHECK(b_scales.size(1) == size_n, - "b_scales.size(1) = " + str(b_scales.size(1)) + - ", size_n = " + str(size_n)); - TORCH_CHECK( - b_q_weight.size(1) % marlin_24::tile_size == 0, - "b_q_weight.size(1) = " + str(b_q_weight.size(1)) + - " is not divisible by tile_size = " + str(marlin_24::tile_size)); - - int actual_size_n = (b_q_weight.size(1) / marlin_24::tile_size) * pack_factor; - TORCH_CHECK( - size_n == actual_size_n, - "size_n = " + str(size_n) + ", actual_size_n = " + str(actual_size_n)); - - // Verify meta - TORCH_CHECK(b_meta.size(0) == size_k / 8 / 2 / 2, - "b_meta.size(0) = ", b_meta.size(0), - " is not size_k / 8 / 2 / 2 = ", size_k / 8 / 2 / 2); - TORCH_CHECK(b_meta.size(1) == size_n * 2, "b_meta.size(1) = ", b_meta.size(1), - " is not size_n * 2 = ", size_n * 2); - - // Verify A device and strides - TORCH_CHECK(a.device().is_cuda(), "A is not on GPU"); - TORCH_CHECK(a.is_contiguous(), "A is not contiguous"); - - // Verify B device and strides - TORCH_CHECK(b_q_weight.device().is_cuda(), "b_q_weight is not on GPU"); - TORCH_CHECK(b_q_weight.is_contiguous(), "b_q_weight is not contiguous"); - - // Verify b_meta device and strides - TORCH_CHECK(b_meta.device().is_cuda(), "b_meta is not on GPU"); - TORCH_CHECK(b_meta.is_contiguous(), "b_meta is not contiguous"); - - // Verify scales device and strides - TORCH_CHECK(b_scales.device().is_cuda(), "b_scales is not on GPU"); - TORCH_CHECK(b_scales.is_contiguous(), "b_scales is not contiguous"); - - // Alloc C matrix - const at::cuda::OptionalCUDAGuard device_guard(device_of(a)); - auto options = torch::TensorOptions().dtype(a.dtype()).device(a.device()); - torch::Tensor c = torch::empty({size_m, size_n}, options); - - int thread_k = -1; - int thread_m = -1; - int sms = -1; - int max_par = marlin_24::max_par; - - int groupsize = -1; - if (b_scales.size(0) > 1) { - TORCH_CHECK(size_k % b_scales.size(0) == 0, - "size_k = " + str(size_k) + - ", is not divisible by b_scales.size(0) = " + - str(b_scales.size(0))); - groupsize = size_k / b_scales.size(0); - groupsize /= 2; // Because of 24 - } - - // Verify groupsize - TORCH_CHECK(groupsize == -1 || groupsize == 64, - "Unexpected groupsize = " + str(groupsize)); - - // Verify workspace size - TORCH_CHECK(size_n % marlin_24::min_thread_n == 0, - "size_n = " + str(size_n) + - ", is not divisible by min_thread_n = " + - str(marlin_24::min_thread_n)); - int min_workspace_size = - (size_n / marlin_24::min_thread_n) * marlin_24::max_par; - TORCH_CHECK(workspace.numel() >= min_workspace_size, - "workspace.numel = " + str(workspace.numel()) + - " is below min_workspace_size = " + str(min_workspace_size)); - - int dev = a.get_device(); - marlin_24::marlin_cuda_2_4( - a.data_ptr(), b_q_weight.data_ptr(), b_meta.data_ptr(), c.data_ptr(), - b_scales.data_ptr(), size_n, size_m, size_k, workspace.data_ptr(), - num_bits, groupsize, dev, at::cuda::getCurrentCUDAStream(dev), thread_k, - thread_m, sms, max_par); - - return c; -} diff --git a/server/marlin/setup.py b/server/marlin/setup.py deleted file mode 100644 index f92eef04..00000000 --- a/server/marlin/setup.py +++ /dev/null @@ -1,24 +0,0 @@ -from setuptools import setup -from torch.utils.cpp_extension import BuildExtension, CUDAExtension - -extra_compile_args = [] - -setup( - name="marlin_kernels", - ext_modules=[ - CUDAExtension( - name="marlin_kernels", - sources=[ - "marlin_kernels/awq_marlin_repack.cu", - "marlin_kernels/fp8_marlin.cu", - "marlin_kernels/gptq_marlin.cu", - "marlin_kernels/gptq_marlin_repack.cu", - "marlin_kernels/marlin_cuda_kernel.cu", - "marlin_kernels/sparse/marlin_24_cuda_kernel.cu", - "marlin_kernels/ext.cpp", - ], - extra_compile_args=extra_compile_args, - ), - ], - cmdclass={"build_ext": BuildExtension}, -) diff --git a/server/poetry.lock b/server/poetry.lock index c993a441..5072aa0b 100644 --- a/server/poetry.lock +++ b/server/poetry.lock @@ -1139,6 +1139,74 @@ files = [ {file = "MarkupSafe-2.1.5.tar.gz", hash = "sha256:d283d37a890ba4c1ae73ffadf8046435c76e7bc2247bbb63c00bd1a709c6544b"}, ] +[[package]] +name = "marlin-kernels" +version = "0.2.0" +description = "Marlin quantization kernels" +optional = true +python-versions = ">=3.7" +files = [ + {file = "marlin_kernels-0.2.0+cu123torch2.4-cp310-cp310-linux_x86_64.whl", hash = "sha256:9a5afcf19b0f5917e43353cc19873fb3c4d4d0b924e2a95a37884f9ce208d0bd"}, +] + +[package.dependencies] +torch = "*" + +[package.source] +type = "url" +url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.2.0/marlin_kernels-0.2.0+cu123torch2.4-cp310-cp310-linux_x86_64.whl" + +[[package]] +name = "marlin-kernels" +version = "0.2.0" +description = "Marlin quantization kernels" +optional = true +python-versions = ">=3.7" +files = [ + {file = "marlin_kernels-0.2.0+cu123torch2.4-cp311-cp311-linux_x86_64.whl", hash = "sha256:1e64fcc7ebadfaffa60091ee9201ae3daaf5c1be3be60c8c054143a3dcb72d5d"}, +] + +[package.dependencies] +torch = "*" + +[package.source] +type = "url" +url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.2.0/marlin_kernels-0.2.0+cu123torch2.4-cp311-cp311-linux_x86_64.whl" + +[[package]] +name = "marlin-kernels" +version = "0.2.0" +description = "Marlin quantization kernels" +optional = true +python-versions = ">=3.7" +files = [ + {file = "marlin_kernels-0.2.0+cu123torch2.4-cp312-cp312-linux_x86_64.whl", hash = "sha256:e75f3ce9b1c13a4ed43a380d88e1d34d297259452db037ec1973ec33dc2eb78e"}, +] + +[package.dependencies] +torch = "*" + +[package.source] +type = "url" +url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.2.0/marlin_kernels-0.2.0+cu123torch2.4-cp312-cp312-linux_x86_64.whl" + +[[package]] +name = "marlin-kernels" +version = "0.2.0" +description = "Marlin quantization kernels" +optional = true +python-versions = ">=3.7" +files = [ + {file = "marlin_kernels-0.2.0+cu123torch2.4-cp39-cp39-linux_x86_64.whl", hash = "sha256:2f99a27f70b391887ee6adffeeee7c3f4df7fac37393f9fb16d4cace2b3f6457"}, +] + +[package.dependencies] +torch = "*" + +[package.source] +type = "url" +url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.2.0/marlin_kernels-0.2.0+cu123torch2.4-cp39-cp39-linux_x86_64.whl" + [[package]] name = "mpmath" version = "1.3.0" @@ -3507,6 +3575,7 @@ test = ["big-O", "importlib-resources", "jaraco.functools", "jaraco.itertools", [extras] accelerate = ["accelerate"] bnb = ["bitsandbytes"] +marlin = ["marlin-kernels", "marlin-kernels", "marlin-kernels", "marlin-kernels"] outlines = ["outlines"] peft = ["peft"] quantize = ["accelerate", "datasets", "texttable"] @@ -3515,4 +3584,4 @@ torch = ["torch"] [metadata] lock-version = "2.0" python-versions = ">=3.9,<3.13" -content-hash = "c94bbdf8131750891fb3f7132066718534129d85a4c09126d8d01c2de6c72798" +content-hash = "a89867b23017d2efa8a7aa14d4764bcbd3b4dea9bfbf06a7a68464cb184ac6a1" diff --git a/server/pyproject.toml b/server/pyproject.toml index 37bf2afd..15da4a8f 100644 --- a/server/pyproject.toml +++ b/server/pyproject.toml @@ -40,10 +40,18 @@ py-cpuinfo = "^9.0.0" # Remove later, temporary workaround for outlines. numpy = "^1.26" +marlin-kernels = [ + { url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.2.0/marlin_kernels-0.2.0+cu123torch2.4-cp39-cp39-linux_x86_64.whl", python = "~3.9", optional = true }, + { url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.2.0/marlin_kernels-0.2.0+cu123torch2.4-cp310-cp310-linux_x86_64.whl", python = "~3.10", optional = true }, + { url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.2.0/marlin_kernels-0.2.0+cu123torch2.4-cp311-cp311-linux_x86_64.whl", python = "~3.11", optional = true }, + { url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.2.0/marlin_kernels-0.2.0+cu123torch2.4-cp312-cp312-linux_x86_64.whl", python = "~3.12", optional = true }, +] + [tool.poetry.extras] torch = ["torch"] accelerate = ["accelerate"] bnb = ["bitsandbytes"] +marlin = ["marlin-kernels"] peft = ["peft"] quantize = ["texttable", "datasets", "accelerate"] outlines = ["outlines"] diff --git a/server/text_generation_server/layers/marlin/gptq.py b/server/text_generation_server/layers/marlin/gptq.py index 697634ea..80674b14 100644 --- a/server/text_generation_server/layers/marlin/gptq.py +++ b/server/text_generation_server/layers/marlin/gptq.py @@ -223,6 +223,7 @@ class GPTQMarlinLinear(nn.Module): A_flat.shape[1], self.is_full_k, self.qzeros.numel() > 0, + True, ) C = C.reshape(A.shape[:-1] + (self.scales.shape[1],))