77 lines
2.0 KiB
Plaintext
77 lines
2.0 KiB
Plaintext
#pragma once
|
|
|
|
#include <torch/all.h>
|
|
|
|
#include <ATen/cuda/CUDAContext.h>
|
|
#include <c10/cuda/CUDAGuard.h>
|
|
#include <cuda.h>
|
|
#include <cuda_fp16.h>
|
|
#include <cuda_runtime.h>
|
|
#include <iostream>
|
|
|
|
namespace gptq_marlin {
|
|
|
|
// Threads per block: 256 threads = 8 warps. Every SM has 4 warp schedulers, so
// running more than 1 warp per scheduler allows some extra latency hiding. At
// the same time, relatively few warps leave many registers per warp and permit
// small tiles.
static constexpr int default_threads = 256;

// Number of pipeline stages; 4 stages fit into shared memory.
static constexpr int pipe_stages = 4;

// NOTE(review): presumably the smallest N / K extents a thread block is
// configured to work on — confirm against the kernel that consumes these.
static constexpr int min_thread_n = 64;
static constexpr int min_thread_k = 64;

// NOTE(review): presumably the edge length of one weight tile — confirm.
static constexpr int tile_size = 16;
// NOTE(review): presumably an upper bound on the kernel's parallelism
// parameter — confirm against the launch code.
static constexpr int max_par = 16;
|
|
|
|
// Fixed-size array wrapper holding `n` elements of type `T` that can be
// indexed from device code.
template <typename T, int n>
struct Vec {
  T elems[n];

  // Mutable element access. `i` is not bounds-checked.
  __device__ T& operator[](int i) { return elems[i]; }

  // Read-only element access, so that const-qualified vectors (e.g. values
  // passed by const reference) can be indexed as well. `i` is not
  // bounds-checked.
  __device__ const T& operator[](int i) const { return elems[i]; }
};

// Vector of four ints.
using I4 = Vec<int, 4>;
|
|
|
|
// Integer ceiling division: for a >= 0 and b > 0, returns the smallest integer
// that is not less than a / b.
//
// For positive operands the result is computed as (a - 1) / b + 1. This equals
// the textbook (a + b - 1) / b but cannot overflow `int` when a + b exceeds
// INT_MAX. Every other operand combination keeps the original formula, so
// results are unchanged wherever the original was well-defined.
constexpr int div_ceil(int a, int b) {
  return (a > 0 && b > 0) ? (a - 1) / b + 1 : (a + b - 1) / b;
}
|
|
|
|
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
|
|
// No support for async copies: the cp.async helpers below require SM80+.
|
|
#else
|
|
|
|
// Predicated asynchronous copy of 16 bytes from global to shared memory.
//
//   smem_ptr  destination; converted to a shared-state-space address before
//             being handed to the PTX instruction.
//   glob_ptr  source address passed as the global-memory operand.
//   pred      when false, the cp.async instruction is skipped.
//
// The copy is issued with `cp.async.cg.shared.global`, so it completes
// asynchronously; this function does not wait for it.
// NOTE(review): PTX is expected to require both addresses to be 16-byte
// aligned for a 16-byte cp.async — confirm against the PTX ISA.
__device__ inline void cp_async4_pred(void* smem_ptr, const void* glob_ptr,
                                      bool pred = true) {
  constexpr int kCopyBytes = 16;
  const uint32_t smem_addr =
      static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
  // The predicate is passed as a 32-bit integer and turned into a PTX
  // predicate register inside the asm block.
  const int pred_flag = static_cast<int>(pred);
  asm volatile(
      "{\n"
      " .reg .pred p;\n"
      " setp.ne.b32 p, %0, 0;\n"
      " @p cp.async.cg.shared.global [%1], [%2], %3;\n"
      "}\n" ::"r"(pred_flag),
      "r"(smem_addr), "l"(glob_ptr), "n"(kCopyBytes));
}
|
|
|
|
// Unconditional asynchronous copy of 16 bytes from global to shared memory.
//
//   smem_ptr  destination; converted to a shared-state-space address before
//             being handed to the PTX instruction.
//   glob_ptr  source address passed as the global-memory operand.
//
// The copy is issued with `cp.async.cg.shared.global`, so it completes
// asynchronously; this function does not wait for it.
// NOTE(review): PTX is expected to require both addresses to be 16-byte
// aligned for a 16-byte cp.async — confirm against the PTX ISA.
__device__ inline void cp_async4(void* smem_ptr, const void* glob_ptr) {
  constexpr int kCopyBytes = 16;
  const uint32_t smem_addr =
      static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
  asm volatile(
      "{\n"
      " cp.async.cg.shared.global [%0], [%1], %2;\n"
      "}\n" ::"r"(smem_addr),
      "l"(glob_ptr), "n"(kCopyBytes));
}
|
|
|
|
// Issues `cp.async.commit_group`, which closes the calling thread's current
// batch of outstanding cp.async operations into one group that can later be
// waited on.
__device__ inline void cp_async_fence() {
  asm volatile("cp.async.commit_group;\n" ::);
}
|
|
|
|
// Issues `cp.async.wait_group n`: blocks the calling thread until at most `n`
// of its most recently committed cp.async groups are still pending. `n` is a
// template parameter because the PTX instruction takes it as an immediate.
template <int n>
__device__ inline void cp_async_wait() {
  asm volatile("cp.async.wait_group %0;\n" ::"n"(n));
}
|
|
|
|
#endif
|
|
|
|
} // namespace gptq_marlin
|