#include <ATen/Dispatch.h>
#include <THC/THCAtomics.cuh>
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <torch/torch.h>

#include <limits>
#include <optional>
#include <tuple>
#include <vector>
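/**
 * Friendly reminder of how multithreading works in CUDA: threads are grouped into blocks, and blocks into a grid
 * (TODO: the link that originally followed this sentence is missing, restore it).
 * Check example at: (TODO: the link to the example is missing, restore it).
 */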
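/*
 * Available in pytorch main. Kept here, commented out, for reference; presumably this is the body of the
 * `DISPATCH_CASE_FLOATING_TYPES` helper mentioned in `forward` (the `#define` line itself is missing, TODO confirm):
 *
 *     at::AT_DISPATCH_CASE(at::ScalarType::Double, __VA_ARGS__) \
 *     at::AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \
 *     at::AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \
 *     at::AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__)
 */

/*
 * Forward passes
 */

/*
 * Cast to fp32 if in fp16 + mask + softmax computation in fp32 + cast back to the original dtype.
 */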
template<typename attention_scores_scalar, int64_t min_kv_length_shard_size_per_thread>
__global__ void forward_masked_softmax_kernel(
const torch::PackedTensorAccessor32<attention_scores_scalar, 2, torch::RestrictPtrTraits> attention_scores, // [B, KV]
const torch::PackedTensorAccessor32<bool, 2, torch::RestrictPtrTraits> mask, // [B, KV]
torch::PackedTensorAccessor32<attention_scores_scalar, 2, torch::RestrictPtrTraits> result, // [B, KV]
const int64_t effective_kv_length,
const dim3 blockDim,
const int64_t rows_per_block,
const int64_t kv_length,
const int64_t batch_size
) {
const auto row_id = threadIdx.x / effective_kv_length;
const auto effective_kv_length_id = threadIdx.x % effective_kv_length;
const auto kv_length_start = effective_kv_length_id * min_kv_length_shard_size_per_thread;
auto kv_length_end_ = (effective_kv_length_id + 1) * min_kv_length_shard_size_per_thread;
kv_length_end_ = (kv_length_end_ > kv_length) ? kv_length : kv_length_end_;
const auto kv_length_end = kv_length_end_;
const auto batch_id = blockIdx.x * rows_per_block + row_id;
// We need 2 float storage for each row, one for max computation, the other for normalizing exponential
extern __shared__ float temp_storage[];
const auto row_id_mem_offset = row_id * 2;
if (effective_kv_length_id == 0) {
temp_storage[row_id_mem_offset] = -std::numeric_limits<float>::infinity();
temp_storage[row_id_mem_offset + 1] = 0;
// Compute mask and max
if (batch_id < batch_size) {
float thread_max = -std::numeric_limits<float>::infinity();
for (int kv_length_id = kv_length_start; kv_length_id < kv_length_end; ++kv_length_id) {
if (mask[batch_id][kv_length_id] == 0) {
const float candidate = attention_scores[batch_id][kv_length_id];
thread_max = (thread_max < candidate) ? candidate : thread_max;
if (thread_max != -std::numeric_limits<float>::infinity()) {
// TODO @thomasw21 with more memory we can probably compute a much faster `max-reduce` in parallel O(ln(n)) operations in each memory slot
gpuAtomicMax(&temp_storage[row_id_mem_offset], thread_max);
// Compute exp(elt - max) masked
float exponential[min_kv_length_shard_size_per_thread];
if (batch_id < batch_size) {
float thread_add = 0;
for (int kv_length_id = kv_length_start; kv_length_id < kv_length_end; ++kv_length_id) {
if (mask[batch_id][kv_length_id] == 0) {
exponential[kv_length_id - kv_length_start] = std::exp(static_cast<float>(attention_scores[batch_id][kv_length_id]) - temp_storage[row_id_mem_offset]);
thread_add = thread_add + exponential[kv_length_id - kv_length_start];
} else {
exponential[kv_length_id - kv_length_start] = 0.;
if (thread_add > 0) {
// TODO @thomasw21 with more memory we can probably compute a much faster `sum-reduce` in parallel O(ln(n)) operations in each memory slot
gpuAtomicAdd(&temp_storage[row_id_mem_offset + 1], thread_add);
// Compute softmax
if (batch_id < batch_size) {
// If sum of all exponential is 0, we set the softmax values to 0
if (temp_storage[row_id_mem_offset + 1] == 0.) {
for (int kv_length_id = kv_length_start; kv_length_id < kv_length_end; ++kv_length_id) {
result[batch_id][kv_length_id] = 0.;
} else {
for (int kv_length_id = kv_length_start; kv_length_id < kv_length_end; ++kv_length_id) {
result[batch_id][kv_length_id] = static_cast<attention_scores_scalar>(exponential[kv_length_id - kv_length_start] / temp_storage[row_id_mem_offset + 1]);
// Input validation helpers: raise through TORCH_CHECK, with the stringified argument in the
// message, when the tensor expression `x` is not on a CUDA device / is not contiguous.
// The argument is parenthesised so that any tensor expression can be passed safely.
#define CHECK_CUDA(x) TORCH_CHECK((x).device().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK((x).is_contiguous(), #x " must be contiguous")
/**
 * Attention forward pass.
 *
 * Steps: optional concatenation of the cached keys/values, `(query * inv_norm_factor) @ key^T`,
 * fused "cast to fp32 + mask + softmax + cast back" through `forward_masked_softmax_kernel`,
 * `probs @ value`, and finally merging of the heads.
 *
 * @param query           query tensor; dimensions 0, 2 and 3 are read as batch_size, q_length and
 *                        attn_head_size (layout presumably [batch, heads, length, head_size] - confirm with caller)
 * @param key             key tensor, same layout, length on dimension 2
 * @param value           value tensor, same layout, length on dimension 2
 * @param layer_past      optional {past_key, past_value}; when set they are concatenated in front of
 *                        `key` / `value` along dimension 2
 * @param attention_mask  boolean mask, viewable as [batch_size * num_heads * q_length, kv_length];
 *                        `true` entries are excluded from the softmax
 * @param head_mask       currently unused by this function
 * @param inv_norm_factor factor the queries are multiplied by before the score computation
 * @param num_heads       number of attention heads
 * @param use_cache       when true, the (concatenated) keys and values are returned as `present`
 * @return (context_layer [batch_size, q_length, num_heads * attn_head_size],
 *          present (empty optional when `use_cache` is false),
 *          attention_probs [batch_size * num_heads, q_length, kv_length])
 */
std::tuple<at::Tensor, std::optional<std::vector<at::Tensor>>, at::Tensor> forward(
    const at::Tensor query,
    const at::Tensor key,
    const at::Tensor value,
    const std::optional<std::vector<at::Tensor>> layer_past,
    const at::Tensor attention_mask,
    const std::optional<at::Tensor> head_mask,
    const float inv_norm_factor,
    const int num_heads,
    const bool use_cache
) {
    auto query_layer = query;
    auto key_layer = key;
    auto value_layer = value;

    if (layer_past) {
        const auto past_key = (*layer_past).at(0);
        const auto past_value = (*layer_past).at(1);
        key_layer = at::cat({past_key, key_layer}, 2);
        value_layer = at::cat({past_value, value_layer}, 2);
    }

    // Stays an empty optional unless the caller asked for the cache.
    std::optional<std::vector<at::Tensor>> present;
    if (use_cache) {
        present = std::vector<at::Tensor>{key_layer, value_layer};
    }

    const auto batch_size = query_layer.size(0);
    const auto q_length = query_layer.size(2);
    const auto attn_head_size = query_layer.size(3);
    const auto batch_size_times_num_heads = batch_size * num_heads;
    const auto kv_length = key_layer.size(2);

    const auto query_view = query_layer.reshape({batch_size_times_num_heads, q_length, attn_head_size});
    auto key_view = key_layer.reshape({batch_size_times_num_heads, kv_length, attn_head_size}).transpose(1, 2);
    auto value_view = value_layer.reshape({batch_size_times_num_heads, kv_length, attn_head_size});

    auto query_scaled = query_view * inv_norm_factor;
    auto attention_scores = at::bmm(query_scaled, key_view);

    // Computing `optionally_cast_fp16_to_fp32 + masked_fill + softmax + cast_to_intial_dtype`
    // with the custom kernel. The equivalent (slower) ATen formulation is:
    //   attention_scores.masked_fill_(attention_mask, -1e34).softmax(-1, at::ScalarType::Float).to(input_dtype)
    at::Tensor attention_probs;
    {
        // TODO @thomasw21: it's easier to think of attention_scores as 2D tensors
        const auto num_rows = batch_size_times_num_heads * q_length;
        const auto attention_scores_2d = attention_scores.view({num_rows, kv_length});
        const auto attention_mask_2d = attention_mask.view({num_rows, kv_length});

        // Check that inputs are contiguous + cuda tensors
        CHECK_CUDA(attention_scores_2d);
        CHECK_CONTIGUOUS(attention_scores_2d);
        CHECK_CUDA(attention_mask_2d);
        CHECK_CONTIGUOUS(attention_mask_2d);

        attention_probs = at::empty_like(attention_scores_2d);

        // Nothing to compute for empty inputs; skipping also avoids a pointless launch.
        if (attention_scores_2d.numel() > 0) {
            // TODO @thomas21: change by to this as it's cleaner when pytorch 1.13 comes out
            // DISPATCH_CASE_FLOATING_TYPES(attention_scores.scalar_type(), "masked_softmax", [&] {
            AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, attention_scores.scalar_type(), "masked_softmax", [&] {
                /*
                 * Understanding how GPUs work:
                 * A100 specifications:
                 *  - SMs: 108
                 *  - TPCs: 56 (What's that?)
                 *  - Memory size: 40 GB
                 *  - L2 Cache size: 40960 KB (shared across all SMs)
                 *  - L1/Shared memory size: 192 KB (shared across all threads within a SM)
                 *  - Max Threads / SM: 2048
                 *  - Max Thread Blocks / SM: 32
                 *
                 * We should split [batch_size_times_num_heads_block, q_length] in separate blocks and
                 * [batch_size_times_num_heads_block_size, kv_length] a single block with multiple threads,
                 * as we need to `sync_threads` to run the exponential sum.
                 * We maximise the usage of threads within a single block.
                 */
                // TODO @thomasw21 figure out everything warp related:
                //  - why do they have to be power of 2
                // 1024 is the per-block thread limit (the 2048 figure above is per SM).
                constexpr int64_t MAX_THREADS_PER_BLOCK = 1024;
                // NOTE(review): the definition of this constant is not visible in this part of the file;
                // 4 is assumed here - confirm against the intended shard size.
                constexpr int64_t MIN_KV_LENGTH_SHARD_SIZE_PER_THREAD = 4;

                // `effective_kv_length = ceil(kv_length / MIN_KV_LENGTH_SHARD_SIZE_PER_THREAD)`
                const int64_t effective_kv_length = (kv_length - 1) / MIN_KV_LENGTH_SHARD_SIZE_PER_THREAD + 1;
                // TODO @thomasw21 figure out how to have longer sequences, currently the maximum is
                // `max_kv_length = MAX_THREADS_PER_BLOCK * MIN_KV_LENGTH_SHARD_SIZE_PER_THREAD`
                TORCH_CHECK(
                    effective_kv_length <= MAX_THREADS_PER_BLOCK,
                    "kv_length (", kv_length, ") exceeds the maximum supported by the masked softmax kernel (",
                    MAX_THREADS_PER_BLOCK * MIN_KV_LENGTH_SHARD_SIZE_PER_THREAD, ")");

                const int64_t rows_per_block = MAX_THREADS_PER_BLOCK / effective_kv_length;
                const int64_t num_blocks = (num_rows - 1) / rows_per_block + 1;

                const dim3 grid_dim(static_cast<unsigned int>(num_blocks)); // Number of blocks that run
                // Only launch threads that map to a row of the block: a full 1024-thread block would
                // leave trailing threads with `row_id >= rows_per_block` whenever 1024 is not a
                // multiple of `effective_kv_length`.
                const dim3 block_dim(static_cast<unsigned int>(rows_per_block * effective_kv_length)); // Number of threads that run per block
                // Two floats per row: running max and running sum of exponentials.
                const size_t shared_mem_forward = static_cast<size_t>(rows_per_block) * 2 * sizeof(float);

                // Launch on PyTorch's current stream so the kernel is ordered after the `bmm` above.
                forward_masked_softmax_kernel<scalar_t, MIN_KV_LENGTH_SHARD_SIZE_PER_THREAD><<<grid_dim, block_dim, shared_mem_forward, at::cuda::getCurrentCUDAStream()>>>(
                    attention_scores_2d.packed_accessor32<scalar_t, 2, torch::RestrictPtrTraits>(),
                    attention_mask_2d.packed_accessor32<bool, 2, torch::RestrictPtrTraits>(),
                    attention_probs.packed_accessor32<scalar_t, 2, torch::RestrictPtrTraits>(),
                    effective_kv_length,
                    block_dim,
                    rows_per_block,
                    kv_length,
                    num_rows
                );

                // Kernel launches do not return errors directly; surface bad launch configurations here.
                const cudaError_t launch_status = cudaGetLastError();
                TORCH_CHECK(
                    launch_status == cudaSuccess,
                    "forward_masked_softmax_kernel launch failed: ", cudaGetErrorString(launch_status));
            });
        }

        attention_probs = attention_probs.view({batch_size_times_num_heads, q_length, kv_length});
    }

    auto context_layer = attention_probs.bmm(value_view);

    // `_merge_heads`
    context_layer = context_layer.view({batch_size, num_heads, q_length, attn_head_size});
    context_layer = context_layer.permute({0, 2, 1, 3});
    context_layer = context_layer.reshape({batch_size, q_length, attn_head_size * num_heads});

    return std::make_tuple(context_layer, present, attention_probs);
}
"GPT-Neox attention mechanism forward (CUDA)"