feat: vendor precompiled llama mlp kernel
This commit is contained in:
parent e4d31a40db
commit 9eeda34427

@@ -0,0 +1,300 @@
from ctypes import c_void_p, c_long
import torch
import math
import random
import os
import tempfile
from math import inf, nan

import triton
import triton.language as tl
from torch._C import _cuda_getCurrentRawStream as get_cuda_stream
from torch import device, empty, empty_strided
from torch._inductor.codecache import AsyncCompile
from torch._inductor.select_algorithm import extern_kernels
# Needed for the `grid=grid(...)` launch argument used in call() below.
from torch._inductor.triton_heuristics import grid

aten = torch.ops.aten
inductor_ops = torch.ops.inductor
assert_size_stride = torch._C._dynamo.guards.assert_size_stride
alloc_from_pool = torch.ops.inductor._alloc_from_pool
reinterpret_tensor = torch.ops.inductor._reinterpret_tensor
async_compile = AsyncCompile()


triton_poi_fused_mul_silu_0 = async_compile.triton(
    "triton_poi_fused_mul_silu_0",
    """
import triton
import triton.language as tl
from torch._inductor.ir import ReductionHint
from torch._inductor.ir import TileHint
from torch._inductor.triton_heuristics import AutotuneHint, pointwise
from torch._inductor.utils import instance_descriptor
from torch._inductor import triton_helpers


@pointwise(
    size_hints=[16384],
    filename=__file__,
    # triton_meta={'signature': {0: '*fp16', 1: '*fp16', 2: 'i32'}, 'device': 0, 'device_type': 'cuda', 'constants': {}, 'configs': [instance_descriptor(divisible_by_16=(0, 1, 2), equal_to_1=(), ids_of_folded_args=(), divisible_by_8=(2,))]},
    # inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_mul_silu_0', 'mutated_arg_names': []},
    # min_elem_per_thread=0
    triton_meta={
        'signature': {0: '*fp16', 1: '*fp16', 2: 'i32'},
        'device': 0,
        'device_type': 'cuda',
        'constants': {},
        'configs': [instance_descriptor(divisible_by_16=(0, 1, 2), equal_to_1=(), ids_of_folded_args=(), divisible_by_8=(2,))]
    },
    inductor_meta={
        'autotune_hints': set(),
        'kernel_name': 'triton_poi_fused_mul_silu_0',
        'mutated_arg_names': []
    },
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused_mul_silu_0(in_ptr0, out_ptr0, xnumel: int, XBLOCK: tl.constexpr):
    # xnumel = 11008
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    x0 = xindex
    tmp0 = tl.load(in_ptr0 + (x0), xmask).to(tl.float32)
    tmp5 = tl.load(in_ptr0 + (xnumel + x0), xmask).to(tl.float32)
    tmp1 = tmp0.to(tl.float32)
    tmp2 = tl.sigmoid(tmp1)
    tmp3 = tmp1 * tmp2
    tmp4 = tmp3.to(tl.float32)
    tmp6 = tmp4 * tmp5
    tl.store(out_ptr0 + (x0), tmp6, xmask)
""",
)
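

# For reference only (not part of the Inductor-generated output): the fused kernel
# above is equivalent to the eager SwiGLU activation below, where `gate_up` is the
# (1, 2 * 11008) result of the combined gate/up projection, gate half first. The
# helper name `_reference_silu_mul` is illustrative and is not used by call().
def _reference_silu_mul(gate_up: torch.Tensor) -> torch.Tensor:
    # Split the fused projection into its gate and up halves along the last dim.
    gate, up = gate_up.chunk(2, dim=-1)
    # SiLU(gate) * up, matching the silu/mul fusion the kernel performs in fp32.
    return torch.nn.functional.silu(gate) * up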


triton_tem_fused_addmm_1 = async_compile.triton(
    "triton_tem_fused_addmm_1",
    """
import triton.language as tl
import triton
from torch._inductor.triton_heuristics import template
from torch._inductor.utils import instance_descriptor
from torch._inductor import triton_helpers


@template(
    num_stages=5,
    num_warps=4,
    triton_meta={'signature': {0: '*fp16', 1: '*fp16', 2: '*fp16', 3: '*fp16'}, 'device': 0, 'device_type': 'cuda', 'constants': {}, 'configs': [instance_descriptor(divisible_by_16=(0, 1, 2, 3), equal_to_1=(), ids_of_folded_args=(), divisible_by_8=())]},
    inductor_meta={'kernel_name': 'triton_tem_fused_addmm_1'},
)
@triton.jit
def triton_tem_fused_addmm_1(in_ptr0, arg_A, arg_B, out_ptr0):
    GROUP_M : tl.constexpr = 8
    EVEN_K : tl.constexpr = True
    ALLOW_TF32 : tl.constexpr = False
    ACC_TYPE : tl.constexpr = tl.float32
    B_PROLOGUE_CAST_TYPE : tl.constexpr = None
    BLOCK_M : tl.constexpr = 16
    BLOCK_N : tl.constexpr = 64
    BLOCK_K : tl.constexpr = 32

    A = arg_A
    B = arg_B

    M = 1
    N = 4096
    K = 11008
    if M * N == 0:
        # early exit due to zero-size input(s)
        return
    stride_am = 11008
    stride_ak = 1
    stride_bk = 1
    stride_bn = 11008

    # based on triton.ops.matmul
    pid = tl.program_id(0)
    grid_m = (M + BLOCK_M - 1) // BLOCK_M
    grid_n = (N + BLOCK_N - 1) // BLOCK_N

    # re-order program ID for better L2 performance
    width = GROUP_M * grid_n
    group_id = pid // width
    group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
    pid_m = group_id * GROUP_M + (pid % group_size)
    pid_n = (pid % width) // (group_size)

    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)
    rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)
    rk = tl.arange(0, BLOCK_K)
    A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak)
    B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn)

    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
    for k in range(K, 0, -BLOCK_K):
        if EVEN_K:
            a = tl.load(A)
            b = tl.load(B)
        else:
            a = tl.load(A, mask=rk[None, :] < k, other=0.)
            b = tl.load(B, mask=rk[:, None] < k, other=0.)
        if B_PROLOGUE_CAST_TYPE is not None:
            b = b.to(B_PROLOGUE_CAST_TYPE)
        acc += tl.dot(a, b, allow_tf32=ALLOW_TF32)
        A += BLOCK_K * stride_ak
        B += BLOCK_K * stride_bk

    # rematerialize rm and rn to save registers
    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    idx_m = rm[:, None]
    idx_n = rn[None, :]
    mask = (idx_m < M) & (idx_n < N)

    # inductor generates a suffix
    xindex = idx_n + (4096*idx_m)
    tmp0 = tl.load(in_ptr0 + (tl.broadcast_to(idx_n, mask.shape)), mask, eviction_policy='evict_last').to(tl.float32)
    tmp1 = acc + tmp0
    tl.store(out_ptr0 + (tl.broadcast_to(idx_n, mask.shape)), tmp1, mask)
""",
)
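

# For reference only (not generated code): with the strides hard-coded above,
# arg_B is read as the transpose of the (4096, 11008) down_proj weight, and the
# bias add is fused into the epilogue, so the template matches the eager linear
# below. The helper name `_reference_down_proj` is illustrative and unused by call().
def _reference_down_proj(act: torch.Tensor, down_weight: torch.Tensor, down_bias: torch.Tensor) -> torch.Tensor:
    # act: (1, 11008), down_weight: (4096, 11008), down_bias: (4096,)
    return torch.nn.functional.linear(act, down_weight, down_bias)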


import torch._inductor.kernel.mm_common
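
# meta0 mirrors the constexpr tile sizes hard-coded in triton_tem_fused_addmm_1
# above; in this file it is only consumed by mm_common.mm_grid() to size the launch grid.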
meta0 = {
    "GROUP_M": 8,
    "EVEN_K": True,
    "ALLOW_TF32": False,
    "ACC_TYPE": "tl.float32",
    "B_PROLOGUE_CAST_TYPE": None,
    "BLOCK_M": 16,
    "BLOCK_N": 64,
    "BLOCK_K": 32,
}


async_compile.wait(globals())
del async_compile


def call(args):
    primals_1, primals_2, primals_3, primals_4, primals_5 = args
    args.clear()
    assert_size_stride(primals_1, (22016, 4096), (4096, 1))
    assert_size_stride(primals_2, (22016,), (1,))
    assert_size_stride(primals_3, (4096, 11008), (11008, 1))
    assert_size_stride(primals_4, (4096,), (1,))
    assert_size_stride(primals_5, (1, 4096), (4096, 1))
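    # primals_1/primals_2 are the fused [gate_proj; up_proj] weight and bias,
    # primals_3/primals_4 are the down_proj weight and bias, and primals_5 is the
    # (1, 4096) hidden-states input (see the argument order in CustomMLP.forward below).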
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)  # no-op to ensure context
        buf0 = empty((1, 22016), device="cuda", dtype=torch.float16)
        # Source Nodes: [gate_up_states], Original ATen: [aten.addmm]
        extern_kernels.bias_addmm(
            reinterpret_tensor(primals_2, (1, 22016), (0, 1), 0),
            primals_5,
            reinterpret_tensor(primals_1, (4096, 22016), (1, 4096), 0),
            alpha=1,
            beta=1,
            out=buf0,
        )
        del primals_1
        del primals_2
        buf1 = empty((1, 11008), device="cuda", dtype=torch.float16)
        # Source Nodes: [l__self___act, mul], Original ATen: [aten.mul, aten.silu]
        stream0 = get_cuda_stream(0)
        triton_poi_fused_mul_silu_0.run(
            buf0, buf1, 11008, grid=grid(11008), stream=stream0
        )
        buf2 = empty((1, 4096), device="cuda", dtype=torch.float16)
        # Source Nodes: [l__self___down_proj], Original ATen: [aten.addmm]
        triton_tem_fused_addmm_1.run(
            primals_4,
            buf1,
            primals_3,
            buf2,
            grid=torch._inductor.kernel.mm_common.mm_grid(1, 4096, meta0),
            stream=stream0,
        )
        del primals_4
    return (
        buf2,
        primals_5,
        buf0,
        buf1,
        reinterpret_tensor(primals_3, (11008, 4096), (1, 11008), 0),
    )
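

# call() returns the down_proj output first; the remaining entries (the input, the
# gate_up output, the activation output, and a transposed view of the down_proj
# weight) are tensors the traced graph also returns. CustomMLP.forward below keeps
# only index 0.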


def benchmark_compiled_module(times=10, repeat=10):
    from torch._dynamo.testing import rand_strided
    from torch._inductor.utils import print_performance

    primals_1 = rand_strided(
        (22016, 4096), (4096, 1), device="cuda:0", dtype=torch.float16
    )
    primals_2 = rand_strided((22016,), (1,), device="cuda:0", dtype=torch.float16)
    primals_3 = rand_strided(
        (4096, 11008), (11008, 1), device="cuda:0", dtype=torch.float16
    )
    primals_4 = rand_strided((4096,), (1,), device="cuda:0", dtype=torch.float16)
    primals_5 = rand_strided((1, 4096), (4096, 1), device="cuda:0", dtype=torch.float16)
    fn = lambda: call([primals_1, primals_2, primals_3, primals_4, primals_5])
    return print_performance(fn, times=times, repeat=repeat)


if __name__ == "__main__":
    from torch._inductor.wrapper_benchmark import compiled_module_main

    compiled_module_main("None", benchmark_compiled_module)
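

# Running this file directly benchmarks the compiled graph on random fp16 CUDA
# tensors via torch._inductor.wrapper_benchmark.compiled_module_main.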


from torch import nn


class CustomMLP(nn.Module):
    # TODO: replace with load when we have a way to handle dynamic shapes
    # def __init__(self, prefix, config, weights):
    #     super(CustomMLP, self).__init__()

    #     prefixes = [f"{prefix}.gate_proj", f"{prefix}.up_proj"]
    #     dim = 0
    #     self.gate_up_weights = weights.get_multi_weights_col(
    #         prefixes, quantize=config.quantize, dim=dim
    #     )
    #     self.gate_up_bias = torch.zeros(
    #         self.gate_up_weights.size(0), device=self.gate_up_weights[0].device
    #     )

    #     self.down_weights = weights.get_multi_weights_row(
    #         f"{prefix}.down_proj", quantize=config.quantize
    #     )
    #     self.down_bias = torch.zeros(
    #         self.down_weights.size(0), device=self.down_weights.device
    #     )

    def __init__(self, gate_up_weights, down_weights):
        super(CustomMLP, self).__init__()
        self.gate_up_weights = gate_up_weights
        self.gate_up_bias = torch.zeros(
            self.gate_up_weights.size(0),
            device=self.gate_up_weights[0].device,
            dtype=self.gate_up_weights[0].dtype,
        )
        self.down_weights = down_weights
        self.down_bias = torch.zeros(
            self.down_weights.size(0),
            device=self.down_weights.device,
            dtype=self.down_weights.dtype,
        )

    def forward(self, x):
        return call(
            [
                self.gate_up_weights,
                self.gate_up_bias,
                self.down_weights,
                self.down_bias,
                x,
            ]
        )[0]
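

# Hedged usage sketch (not called anywhere in the server): builds the module with
# random fp16 weights at the exact shapes this kernel was compiled for and runs a
# single token. The helper name `_example_usage` is illustrative only.
def _example_usage():
    gate_up_weights = torch.randn(22016, 4096, device="cuda", dtype=torch.float16)
    down_weights = torch.randn(4096, 11008, device="cuda", dtype=torch.float16)
    mlp = CustomMLP(gate_up_weights, down_weights)
    hidden_states = torch.randn(1, 4096, device="cuda", dtype=torch.float16)
    return mlp(hidden_states)  # (1, 4096) fp16 output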


@@ -37,6 +37,8 @@ from text_generation_server.utils.layers import (
    FastRMSNorm,
)

from .custom_triton_kernels.mlp import CustomMLP


class LlamaConfig(PretrainedConfig):
    def __init__(


@@ -287,6 +289,10 @@ class FlashLlamaLayer(nn.Module):
            prefix=f"{prefix}.self_attn", config=config, weights=weights
        )
        self.mlp = LlamaMLP(prefix=f"{prefix}.mlp", config=config, weights=weights)
        # TODO: reuse weights until we can use the precompiled mlp in all cases
        self.pre_compiled_mlp = CustomMLP(
            self.mlp.gate_up_proj.linear.weight, self.mlp.down_proj.linear.weight
        )

        self.input_layernorm = FastRMSNorm.load(
            prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps


@@ -330,7 +336,10 @@ class FlashLlamaLayer(nn.Module):
            attn_output, res
        )

-        mlp_output = self.mlp(normed_attn_res_output)
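+        # The vendored kernel is specialized for a single decode token
+        # (M == 1, hidden size 4096, intermediate size 11008), so take the
+        # precompiled path only for batch size 1.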
+        if normed_attn_res_output.size(0) == 1:
+            mlp_output = self.pre_compiled_mlp(normed_attn_res_output)
+        else:
+            mlp_output = self.mlp(normed_attn_res_output)

        return mlp_output, attn_res