137 lines
4.9 KiB
C
137 lines
4.9 KiB
C
|
/*
 * Copyright (C) 2024 Roberto Lopez Castro (roberto.lopez.castro@udc.es). All
 * Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *       http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
|
||
|
|
||
|
#pragma once
|
||
|
#include "base.h"
|
||
|
|
||
|
namespace marlin_24 {
|
||
|
// Predicated asynchronous 16-byte global->shared copy with optional zero-fill;
// used for inputs A where we apply predication to handle batchsizes that are
// not multiples of 16.
//
//   smem_ptr  destination pointer; converted to a shared-window address.
//   glob_ptr  source pointer in global memory.
//   pred      when false, the cp.async instruction is not issued at all.
//   zfill     when true, the src-size operand handed to cp.async is 0, so no
//             bytes are read from glob_ptr and the 16 destination bytes are
//             written as zeros.
__device__ inline void cp_async4_pred_zfill(void* smem_ptr,
                                            const void* glob_ptr,
                                            bool pred = true,
                                            const bool zfill = false) {
  const int BYTES = 16;
  // src-size operand of cp.async: number of bytes actually read from global
  // memory; the remaining (cp-size - src-size) destination bytes are
  // zero-filled by the hardware.
  int src_in_bytes = (zfill ? 0 : BYTES);
  uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
  asm volatile(
      "{\n"
      " .reg .pred p;\n"
      " setp.ne.b32 p, %0, 0;\n"
      // %4 (src-size) has to appear in the instruction text; if it is only
      // listed as an input operand the copy is always a full 16-byte read and
      // `zfill` is silently ignored.
      " @p cp.async.cg.shared.global [%1], [%2], %3, %4;\n"
      "}\n" ::"r"((int)pred),
      "r"(smem), "l"(glob_ptr), "n"(BYTES), "r"(src_in_bytes));
}
|
||
|
|
||
|
// Predicated asynchronous 16-byte global->shared copy. The cp.async
// instruction is guarded by `pred`: when it is false no copy is issued.
__device__ inline void cp_async4_pred(void* smem_ptr, const void* glob_ptr,
                                      bool pred = true) {
  constexpr int kCopyBytes = 16;
  // The PTX predicate register is derived from this integer (non-zero = copy).
  const int issue = (int)pred;
  const uint32_t dst =
      static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
  asm volatile(
      "{\n"
      " .reg .pred p;\n"
      " setp.ne.b32 p, %0, 0;\n"
      " @p cp.async.cg.shared.global [%1], [%2], %3;\n"
      "}\n"
      :
      : "r"(issue), "r"(dst), "l"(glob_ptr), "n"(kCopyBytes));
}
|
||
|
|
||
|
// Unconditional asynchronous 16-byte copy from global memory (`glob_ptr`) to
// shared memory (`smem_ptr`), issued as a single cp.async.cg instruction.
__device__ inline void cp_async4(void* smem_ptr, const void* glob_ptr) {
  constexpr int kCopyBytes = 16;
  // cp.async takes the destination as a 32-bit shared-window address.
  const uint32_t dst =
      static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
  asm volatile(
      "{\n"
      " cp.async.cg.shared.global [%0], [%1], %2;\n"
      "}\n"
      :
      : "r"(dst), "l"(glob_ptr), "n"(kCopyBytes));
}
|
||
|
|
||
|
// Async copy fence: issues cp.async.commit_group, closing the current batch of
// cp.async operations so that cp_async_wait() can wait on it as one group.
__device__ inline void cp_async_fence() {
  asm volatile("cp.async.commit_group;\n");
}
|
||
|
|
||
|
// Stall until no more than `kMaxPending` committed async copy groups (stages)
// are still outstanding. The limit is a template argument because
// cp.async.wait_group takes it as an immediate.
template <int kMaxPending>
__device__ inline void cp_async_wait() {
  asm volatile("cp.async.wait_group %0;\n"
               :
               : "n"(kMaxPending));
}
|
||
|
|
||
|
// Loads a full 16x16 matrix fragment of operand A from shared memory straight
// into tensor core layout with one ldmatrix (.x4) instruction.
__device__ inline void ldsm4(FragA& frag_a, const void* smem_ptr) {
  const uint32_t src =
      static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
  // The fragment is filled through four 32-bit register-sized words.
  uint32_t* regs = reinterpret_cast<uint32_t*>(&frag_a);
  asm volatile(
      "ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0,%1,%2,%3}, [%4];\n"
      : "=r"(regs[0]), "=r"(regs[1]), "=r"(regs[2]), "=r"(regs[3])
      : "r"(src));
}
|
||
|
|
||
|
// Loads two 8x8 b16 matrices from shared memory with one ldmatrix (.x2)
// instruction and stores them in the first two 32-bit words of `frag_m`.
// NOTE(review): FragM is presumably the sparsity-metadata fragment declared in
// base.h -- confirm there.
__device__ inline void ldsm4_m(FragM& frag_m, const void* smem_ptr) {
  const uint32_t src =
      static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
  uint32_t* regs = reinterpret_cast<uint32_t*>(&frag_m);
  asm volatile(
      "ldmatrix.sync.aligned.m8n8.x2.shared.b16 {%0,%1}, [%2];\n"
      : "=r"(regs[0]), "=r"(regs[1])
      : "r"(src));
}
|
||
|
|
||
|
// Transposing variant of ldsm4(): loads four 8x8 b16 matrices from shared
// memory into `frag_a` with one ldmatrix (.x4.trans) instruction, i.e. each
// 8x8 matrix is transposed on load.
__device__ inline void ldsm4_t(FragA& frag_a, const void* smem_ptr) {
  const uint32_t src =
      static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
  // The fragment is filled through four 32-bit register-sized words.
  uint32_t* regs = reinterpret_cast<uint32_t*>(&frag_a);
  asm volatile(
      "ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16 {%0,%1,%2,%3}, [%4];\n"
      : "=r"(regs[0]), "=r"(regs[1]), "=r"(regs[2]), "=r"(regs[3])
      : "r"(src));
}
|
||
|
|
||
|
// Spin until the value stored at `lock` equals `count`, then let the whole
// threadblock proceed. Only thread 0 polls; every other thread waits at the
// trailing __syncthreads().
__device__ inline void barrier_acquire(int* lock, int count) {
  const bool is_poller = (threadIdx.x == 0);
  if (is_poller) {
    int observed = -1;
    do {
      // GPU-scope acquire load of the lock word; this is the ordering point
      // that pairs with the fence + reduction in barrier_release().
      asm volatile("ld.global.acquire.gpu.b32 %0, [%1];\n"
                   : "=r"(observed)
                   : "l"(lock));
    } while (observed != count);
  }
  // Hold back the rest of the block until thread 0 has seen the target value.
  __syncthreads();
}
|
||
|
|
||
|
// Release the barrier. After the whole block has synchronized, thread 0
// either resets the lock word to 0 (`reset == true`) or atomically adds 1 to
// it, bumping the visitation count.
__device__ inline void barrier_release(int* lock, bool reset = false) {
  // All threads of the block must be done before the lock is handed on.
  __syncthreads();

  // Only thread 0 touches the lock word.
  if (threadIdx.x != 0) {
    return;
  }

  if (reset) {
    lock[0] = 0;
    return;
  }

  const int increment = 1;
  // The fence makes all writes since acquiring the barrier globally visible
  // before the relaxed reduction below signals the release.
  asm volatile("fence.acq_rel.gpu;\n");
  asm volatile("red.relaxed.gpu.global.add.s32 [%0], %1;\n"
               :
               : "l"(lock), "r"(increment));
}
|
||
|
} // namespace marlin_24
|