feat: integrate Triton compilation demo

drbh 2024-04-12 21:47:15 +00:00
parent c38a7d7ddd
commit 8ebb560f2f
8 changed files with 996 additions and 8 deletions

script/throughput.py (new file)

@@ -0,0 +1,107 @@
from concurrent import futures
from typing import Any, Dict, List
import time
import requests
import json
# copied from other benchmarking script
def simple_throughput(
predictor: Any, payloads: List[Dict[str, Any]], concurrent_requests: int
):
time_start = time.time()
with futures.ThreadPoolExecutor(max_workers=concurrent_requests) as executor:
responses = list(executor.map(predictor.predict, payloads))
if isinstance(responses[0], list):
responses = [response[0] for response in responses]
total_tokens = sum(
[
len([token for token in x["details"]["tokens"] if token["id"] != -1])
for x in responses
]
)
token_throughput = total_tokens / (time.time() - time_start)
generated_tokens_per_request = total_tokens / len(payloads)
print(f"concurrent requests: {concurrent_requests}")
print(f"throughput: {token_throughput:.2f}")
print(
f"generated tokens per request: {generated_tokens_per_request:.2f}", end="\n\n"
)
# setup to help run: set real = False to mock responses without a running server
real = True
class Predictor:
def predict(self, payload: Dict[str, Any]) -> Dict[str, Any]:
response_json = {"details": {"tokens": []}}
if real:
url = "http://localhost:3000/generate"
headers = {"Content-Type": "application/json"}
response = requests.post(url, data=json.dumps(payload), headers=headers)
response_json = response.json()
else:
time.sleep(0.1)
return response_json
max_new_tokens = 100
if __name__ == "__main__":
print("Running throughput test")
predictor = Predictor()
payloads = [
{
"inputs": "<s>[INST] I am making mayonnaise, it was starting to thicken but now it has become runny and liquid again, is there any way to salvage it? [/INST]Yes, it's possible to fix runny mayonnaise! The most common reason for mayonnaise becoming runny is because the oil was added too quickly or the egg yolk wasn't emulsified properly. Here are some steps you can take to fix it:\n\n1. Separate another egg yolk and place it in a clean, dry bowl.\n2. Slowly add the runny mayonnaise to the egg yolk while whisking vigorously.\n3. Once all the runny mayonnaise has been added, continue whisking until the mixture has emulsified and thickened.\n4. If the mayonnaise is still too runny, you can add another egg yolk and repeat the process.\n\nIf the mayonnaise still won't thicken, you can try adding a small amount of dijon mustard or vinegar to the mixture, which can act as emulsifiers and help stabilize the mayonnaise. It's important to add these ingredients slowly and in small amounts to avoid over-thinning the mixture.</s>[INST] What is optimal Mayonnaise thickness? [/INST]",
"parameters": {"details": True, "max_new_tokens": max_new_tokens},
},
{
"inputs": "<s>[INST] Why Aristotelian view of physics (impetus and stuff) is wrong? [/INST]",
"parameters": {"details": True, "max_new_tokens": max_new_tokens},
},
{
"inputs": "<s>[INST] If you were to image the experience of eating from only others descriptions and never having done it yourself, how would you describe it. [/INST]",
"parameters": {"details": True, "max_new_tokens": max_new_tokens},
},
{
"inputs": "<s>[INST] What is the best way to cook a steak? [/INST]",
"parameters": {"details": True, "max_new_tokens": max_new_tokens},
},
{
"inputs": "<s>[INST] How do you make a perfect omelette? [/INST]",
"parameters": {"details": True, "max_new_tokens": max_new_tokens},
},
{
"inputs": "<s>[INST] What is the secret to a good pizza dough? [/INST]",
"parameters": {"details": True, "max_new_tokens": max_new_tokens},
},
{
"inputs": "<s>[INST] How do you make a classic French onion soup? [/INST]",
"parameters": {"details": True, "max_new_tokens": max_new_tokens},
},
{
"inputs": "<s>[INST] What is the best way to roast a chicken? [/INST]",
"parameters": {"details": True, "max_new_tokens": max_new_tokens},
},
{
"inputs": "<s>[INST] How do you make a perfect chocolate cake? [/INST]",
"parameters": {"details": True, "max_new_tokens": max_new_tokens},
},
{
"inputs": "<s>[INST] What is the secret to a good risotto? [/INST]",
"parameters": {"details": True, "max_new_tokens": max_new_tokens},
},
{
"inputs": "<s>[INST] How do you make a classic spaghetti carbonara? [/INST]",
"parameters": {"details": True, "max_new_tokens": max_new_tokens},
},
]
for concurrent_requests in [1, 2, 4, 8]:
simple_throughput(predictor, payloads, concurrent_requests)
print("Throughput test complete")


@@ -40,6 +40,7 @@ def serve(
logger_level: str = "INFO",
json_output: bool = False,
otlp_endpoint: Optional[str] = None,
post_attn_method: Optional[str] = None,
):
if sharded:
assert (
@@ -96,6 +97,7 @@ def serve(
dtype,
trust_remote_code,
uds_path,
post_attn_method,
)
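
Note: assuming the CLI entrypoint is typer-based (as in upstream text-generation-inference), the new keyword argument surfaces as a --post-attn-method option without further wiring. A minimal, self-contained illustration of that mapping; this is a standalone sketch, not the project's CLI module, and the model_id argument and print are placeholders:

# sketch_cli.py -- run as: python sketch_cli.py some-model --post-attn-method compile
from typing import Optional
import typer

app = typer.Typer()

@app.command()
def serve(model_id: str, post_attn_method: Optional[str] = None):
    # typer converts the snake_case keyword into a kebab-case flag automatically
    print(f"serving {model_id} with post_attn_method={post_attn_method}")

if __name__ == "__main__":
    app()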


@@ -114,6 +114,7 @@ def get_model(
speculate: Optional[int],
dtype: Optional[str],
trust_remote_code: bool,
post_attn_method: Optional[str] = None,
) -> Model:
if dtype is None:
# Keep it as default for now and let
@@ -357,6 +358,7 @@ def get_model(
use_medusa=use_medusa,
dtype=dtype,
trust_remote_code=trust_remote_code,
post_attn_method=post_attn_method,
)
elif sharded:
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Gemma"))


@@ -0,0 +1,354 @@
from ctypes import c_void_p, c_long
import torch
import math
import random
import os
import tempfile
from math import inf, nan
from torch._inductor.hooks import run_intermediate_hooks
from torch._inductor.utils import maybe_profile
from torch._inductor.codegen.memory_planning import _align as align
from torch import device, empty, empty_strided
from torch._inductor.codecache import AsyncCompile
from torch._inductor.select_algorithm import extern_kernels
aten = torch.ops.aten
inductor_ops = torch.ops.inductor
assert_size_stride = torch._C._dynamo.guards.assert_size_stride
alloc_from_pool = torch.ops.inductor._alloc_from_pool
reinterpret_tensor = torch.ops.inductor._reinterpret_tensor
async_compile = AsyncCompile()
# kernel path: /tmp/torchinductor_ubuntu/sb/csbqhwspefuy7jp6pcbyvalomk4wvqyt4mplgbphfugdwu66zbp7.py
# Source Nodes: [add, attn_res, hidden_states_1, hidden_states_2, hidden_states_3, normed_attn_res_output, pow_1, rsqrt, variance], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, aten.pow, aten.rsqrt]
# add => add_1
# attn_res => add
# hidden_states_1 => convert_element_type
# hidden_states_2 => mul
# hidden_states_3 => mul_1
# normed_attn_res_output => convert_element_type_1
# pow_1 => pow_1
# rsqrt => rsqrt
# variance => mean
triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_0 = async_compile.triton(
"triton_",
"""
import triton
import triton.language as tl
from torch._inductor.ir import ReductionHint
from torch._inductor.ir import TileHint
from torch._inductor.triton_heuristics import AutotuneHint, reduction
from torch._inductor.utils import instance_descriptor
from torch._inductor import triton_helpers
@reduction(
size_hints=[256, 4096],
reduction_hint=ReductionHint.INNER,
filename=__file__,
triton_meta={'signature': {0: '*bf16', 1: '*bf16', 2: '*bf16', 3: '*bf16', 4: '*bf16', 5: '*bf16', 6: 'i32', 7: 'i32'}, 'device': 0, 'device_type': 'cuda', 'constants': {}, 'configs': [instance_descriptor(divisible_by_16=(0, 1, 2, 3, 4, 5, 7), equal_to_1=(), ids_of_folded_args=(), divisible_by_8=(7,))]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_0', 'mutated_arg_names': ['in_ptr0', 'out_ptr3']}
)
@triton.jit
def triton_(in_ptr0, in_ptr1, in_ptr2, out_ptr0, out_ptr2, out_ptr3, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr):
rnumel = 3072
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
xmask = xindex < xnumel
rbase = tl.arange(0, RBLOCK)[None, :]
x0 = xindex
_tmp6 = tl.full([XBLOCK, RBLOCK], 0, tl.float32)
for roffset in range(0, rnumel, RBLOCK):
rindex = roffset + rbase
rmask = rindex < rnumel
r1 = rindex
tmp0 = tl.load(in_ptr0 + (r1 + (3072*x0)), rmask & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
tmp1 = tl.load(in_ptr1 + (r1 + (3072*x0)), rmask & xmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
tmp2 = tmp0 + tmp1
tmp3 = tmp2.to(tl.float32)
tmp4 = tmp3 * tmp3
tmp5 = tl.broadcast_to(tmp4, [XBLOCK, RBLOCK])
tmp7 = _tmp6 + tmp5
_tmp6 = tl.where(rmask & xmask, tmp7, _tmp6)
tl.store(out_ptr0 + (r1 + (3072*x0)), tmp2, rmask & xmask)
tmp6 = tl.sum(_tmp6, 1)[:, None]
for roffset in range(0, rnumel, RBLOCK):
rindex = roffset + rbase
rmask = rindex < rnumel
r1 = rindex
tmp8 = tl.load(out_ptr0 + (r1 + (3072*x0)), rmask & xmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
tmp16 = tl.load(in_ptr2 + (r1), rmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
tmp9 = tmp8.to(tl.float32)
tmp10 = 3072.0
tmp11 = tmp6 / tmp10
tmp12 = 1e-06
tmp13 = tmp11 + tmp12
tmp14 = tl.math.rsqrt(tmp13)
tmp15 = tmp9 * tmp14
tmp17 = tmp16.to(tl.float32)
tmp18 = tmp15 * tmp17
tmp19 = tmp18.to(tl.float32)
tl.store(out_ptr2 + (r1 + (3072*x0)), tmp19, rmask & xmask)
tl.store(out_ptr3 + (r1 + (3072*x0)), tmp8, rmask & xmask)
""",
)
import triton
import triton.language as tl
from torch._inductor.triton_heuristics import grid, start_graph, end_graph
from torch._C import _cuda_getCurrentRawStream as get_cuda_stream
# kernel path: /tmp/torchinductor_ubuntu/3y/c3ycepp3q64t65lfi2wcqq5skonqvz6t4okgw4dboyrxwdspdaff.py
# Source Nodes: [add, gate_up_states, hidden_states_1, hidden_states_2, hidden_states_3, normed_attn_res_output, pow_1, rsqrt, variance], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mm, aten.mul, aten.pow, aten.rsqrt]
# add => add_1
# gate_up_states => mm
# hidden_states_1 => convert_element_type
# hidden_states_2 => mul
# hidden_states_3 => mul_1
# normed_attn_res_output => convert_element_type_1
# pow_1 => pow_1
# rsqrt => rsqrt
# variance => mean
triton_tem_fused__to_copy_add_mean_mm_mul_pow_rsqrt_1 = async_compile.triton(
"triton_",
"""
import triton.language as tl
import triton
from torch._inductor.triton_heuristics import template
from torch._inductor.utils import instance_descriptor
from torch._inductor import triton_helpers
@template(
num_stages=2,
num_warps=4,
triton_meta={'signature': {0: '*bf16', 1: '*bf16', 2: '*bf16', 3: 'i32'}, 'device': 0, 'device_type': 'cuda', 'constants': {}, 'configs': [instance_descriptor(divisible_by_16=(0, 1, 2), equal_to_1=(), ids_of_folded_args=(), divisible_by_8=())]},
inductor_meta={'kernel_name': 'triton_tem_fused__to_copy_add_mean_mm_mul_pow_rsqrt_1'},
)
@triton.jit
def triton_(arg_A, arg_B, out_ptr0, ks0):
GROUP_M : tl.constexpr = 8
EVEN_K : tl.constexpr = True
ALLOW_TF32 : tl.constexpr = True
ACC_TYPE : tl.constexpr = tl.float32
B_PROLOGUE_CAST_TYPE : tl.constexpr = None
BLOCK_M : tl.constexpr = 64
BLOCK_N : tl.constexpr = 64
BLOCK_K : tl.constexpr = 32
A = arg_A
B = arg_B
M = ks0
N = 49152
K = 3072
if M * N == 0:
# early exit due to zero-size input(s)
return
stride_am = 3072
stride_ak = 1
stride_bk = 1
stride_bn = 3072
# based on triton.ops.matmul
pid = tl.program_id(0)
grid_m = (M + BLOCK_M - 1) // BLOCK_M
grid_n = (N + BLOCK_N - 1) // BLOCK_N
# re-order program ID for better L2 performance
width = GROUP_M * grid_n
group_id = pid // width
group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
pid_m = group_id * GROUP_M + (pid % group_size)
pid_n = (pid % width) // (group_size)
rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)
rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)
rk = tl.arange(0, BLOCK_K)
A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak)
B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn)
acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
for k in range(K, 0, -BLOCK_K):
if EVEN_K:
a = tl.load(A)
b = tl.load(B)
else:
a = tl.load(A, mask=rk[None, :] < k, other=0.)
b = tl.load(B, mask=rk[:, None] < k, other=0.)
if B_PROLOGUE_CAST_TYPE is not None:
b = b.to(B_PROLOGUE_CAST_TYPE)
acc += tl.dot(a, b, allow_tf32=ALLOW_TF32)
A += BLOCK_K * stride_ak
B += BLOCK_K * stride_bk
# rematerialize rm and rn to save registers
rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
idx_m = rm[:, None]
idx_n = rn[None, :]
mask = (idx_m < M) & (idx_n < N)
# inductor generates a suffix
xindex = idx_n + (49152*idx_m)
tl.store(out_ptr0 + (tl.broadcast_to(xindex, mask.shape)), acc, mask)
""",
)
import torch._inductor.kernel.mm_common
meta0 = {
"GROUP_M": 8,
"EVEN_K": True,
"ALLOW_TF32": True,
"ACC_TYPE": "tl.float32",
"B_PROLOGUE_CAST_TYPE": None,
"BLOCK_M": 64,
"BLOCK_N": 64,
"BLOCK_K": 32,
}
# kernel path: /tmp/torchinductor_ubuntu/hz/chzpreui7cisscnbe7lrfewue7t4pg7pbnbouuqqq5d2jjbc3wsv.py
# Source Nodes: [gelu, inter], Original ATen: [aten.gelu, aten.mul]
# gelu => add_2, convert_element_type_4, convert_element_type_5, erf, mul_2, mul_3, mul_4
# inter => mul_5
triton_poi_fused_gelu_mul_2 = async_compile.triton(
"triton_",
"""
import triton
import triton.language as tl
from torch._inductor.ir import ReductionHint
from torch._inductor.ir import TileHint
from torch._inductor.triton_heuristics import AutotuneHint, pointwise
from torch._inductor.utils import instance_descriptor
from torch._inductor import triton_helpers
@pointwise(
size_hints=[8388608],
filename=__file__,
triton_meta={'signature': {0: '*bf16', 1: '*bf16', 2: 'i32'}, 'device': 0, 'device_type': 'cuda', 'constants': {}, 'configs': [instance_descriptor(divisible_by_16=(0, 1, 2), equal_to_1=(), ids_of_folded_args=(), divisible_by_8=(2,))]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_gelu_mul_2', 'mutated_arg_names': []},
min_elem_per_thread=0
)
@triton.jit
def triton_(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:]
xmask = xindex < xnumel
x0 = xindex % 24576
x1 = (xindex // 24576)
x2 = xindex
tmp0 = tl.load(in_ptr0 + (x0 + (49152*x1)), None).to(tl.float32)
tmp11 = tl.load(in_ptr0 + (24576 + x0 + (49152*x1)), None).to(tl.float32)
tmp1 = tmp0.to(tl.float32)
tmp2 = 0.5
tmp3 = tmp1 * tmp2
tmp4 = 0.7071067811865476
tmp5 = tmp1 * tmp4
tmp6 = tl.math.erf(tmp5)
tmp7 = 1.0
tmp8 = tmp6 + tmp7
tmp9 = tmp3 * tmp8
tmp10 = tmp9.to(tl.float32)
tmp12 = tmp10 * tmp11
tl.store(out_ptr0 + (x2), tmp12, None)
""",
)
async_compile.wait(globals())
del async_compile
def call(args):
arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1 = args
args.clear()
s0 = arg3_1
assert_size_stride(arg0_1, (3072,), (1,))
assert_size_stride(arg1_1, (49152, 3072), (3072, 1))
assert_size_stride(arg2_1, (3072, 24576), (24576, 1))
assert_size_stride(arg4_1, (s0, 3072), (3072, 1))
assert_size_stride(arg5_1, (s0, 3072), (3072, 1))
with torch.cuda._DeviceGuard(0):
torch.cuda.set_device(0) # no-op to ensure context
buf0 = empty((s0, 3072), device="cuda", dtype=torch.bfloat16)
buf2 = empty((s0, 3072), device="cuda", dtype=torch.bfloat16)
# Source Nodes: [add, attn_res, hidden_states_1, hidden_states_2, hidden_states_3, normed_attn_res_output, pow_1, rsqrt, variance], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, aten.pow, aten.rsqrt]
stream0 = get_cuda_stream(0)
triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_0.run(
arg5_1,
arg4_1,
arg0_1,
buf0,
buf2,
arg5_1,
s0,
3072,
grid=grid(s0),
stream=stream0,
)
del arg0_1
del arg4_1
del arg5_1
buf3 = empty((s0, 49152), device="cuda", dtype=torch.bfloat16)
# Source Nodes: [add, gate_up_states, hidden_states_1, hidden_states_2, hidden_states_3, normed_attn_res_output, pow_1, rsqrt, variance], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mm, aten.mul, aten.pow, aten.rsqrt]
triton_tem_fused__to_copy_add_mean_mm_mul_pow_rsqrt_1.run(
buf2,
arg1_1,
buf3,
s0,
grid=torch._inductor.kernel.mm_common.mm_grid(s0, 49152, meta0),
stream=stream0,
)
del arg1_1
buf4 = empty((s0, 24576), device="cuda", dtype=torch.bfloat16)
# Source Nodes: [gelu, inter], Original ATen: [aten.gelu, aten.mul]
triton_poi_fused_gelu_mul_2_xnumel = 24576 * s0
triton_poi_fused_gelu_mul_2.run(
buf3,
buf4,
triton_poi_fused_gelu_mul_2_xnumel,
grid=grid(triton_poi_fused_gelu_mul_2_xnumel),
stream=stream0,
)
del buf3
buf5 = buf2
del buf2 # reuse
# Source Nodes: [gelu, inter, mlp_output], Original ATen: [aten.gelu, aten.mm, aten.mul]
extern_kernels.mm(
buf4, reinterpret_tensor(arg2_1, (24576, 3072), (1, 24576), 0), out=buf5
)
del arg2_1
return (
buf5,
buf0,
)
def benchmark_compiled_module(times=10, repeat=10):
from torch._dynamo.testing import rand_strided
from torch._inductor.utils import print_performance
arg0_1 = rand_strided((3072,), (1,), device="cuda:0", dtype=torch.bfloat16)
arg1_1 = rand_strided(
(49152, 3072), (3072, 1), device="cuda:0", dtype=torch.bfloat16
)
arg2_1 = rand_strided(
(3072, 24576), (24576, 1), device="cuda:0", dtype=torch.bfloat16
)
arg3_1 = 256
arg4_1 = rand_strided((256, 3072), (3072, 1), device="cuda:0", dtype=torch.bfloat16)
arg5_1 = rand_strided((256, 3072), (3072, 1), device="cuda:0", dtype=torch.bfloat16)
fn = lambda: call([arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1])
return print_performance(fn, times=times, repeat=repeat)
if __name__ == "__main__":
from torch._inductor.wrapper_benchmark import compiled_module_main
compiled_module_main("None", benchmark_compiled_module)
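
Note: the generated module above fuses Gemma's post-attention residual add, RMSNorm, gate/up projection, GELU gating, and down projection into three Triton kernels plus one extern mm. A rough eager-mode PyTorch equivalent of what call(...) returns, written as a readability sketch of the traced computation rather than code taken from the repository; weight shapes follow the assert_size_stride checks above:

import torch

def eager_reference(norm_weight, gate_up_weight, down_weight, res, attn_output, eps=1e-6):
    # residual add (returned as attn_res, i.e. buf0)
    attn_res = attn_output + res
    # RMSNorm in fp32, as in the reduction kernel: x * rsqrt(mean(x^2) + eps) * weight
    x = attn_res.to(torch.float32)
    normed = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)
    normed = (normed * norm_weight.to(torch.float32)).to(attn_res.dtype)
    # gate/up projection: (M, 3072) @ (3072, 49152) -> (M, 49152)
    gate_up = normed @ gate_up_weight.t()
    gate, up = gate_up.chunk(2, dim=-1)
    # exact (erf-based) GELU on the gate half, multiplied by the up half
    inter = torch.nn.functional.gelu(gate, approximate="none") * up
    # down projection: (M, 24576) @ (24576, 3072) -> (M, 3072)
    mlp_output = inter @ down_weight.t()
    return mlp_output, attn_res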


@@ -0,0 +1,444 @@
from ctypes import c_void_p, c_long
import torch
import math
import random
import os
import tempfile
from math import inf, nan
from torch._inductor.hooks import run_intermediate_hooks
from torch._inductor.utils import maybe_profile
from torch._inductor.codegen.memory_planning import _align as align
from torch import device, empty, empty_strided
from torch._inductor.codecache import AsyncCompile
from torch._inductor.select_algorithm import extern_kernels
aten = torch.ops.aten
inductor_ops = torch.ops.inductor
assert_size_stride = torch._C._dynamo.guards.assert_size_stride
alloc_from_pool = torch.ops.inductor._alloc_from_pool
reinterpret_tensor = torch.ops.inductor._reinterpret_tensor
async_compile = AsyncCompile()
# kernel path: /tmp/torchinductor_ubuntu/gp/cgpizln3o4666auqh5ql3ey4lrkrhsxsjzl4au7u2bbiig7f5mda.py
# Source Nodes: [add, attn_res, hidden_states_1, hidden_states_2, hidden_states_3, normed_attn_res_output, pow_1, rsqrt, variance], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, aten.pow, aten.rsqrt]
# add => add_1
# attn_res => add
# hidden_states_1 => convert_element_type
# hidden_states_2 => mul
# hidden_states_3 => mul_1
# normed_attn_res_output => convert_element_type_1
# pow_1 => pow_1
# rsqrt => rsqrt
# variance => mean
triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_0 = async_compile.triton(
"triton_",
"""
import triton
import triton.language as tl
from torch._inductor.ir import ReductionHint
from torch._inductor.ir import TileHint
from torch._inductor.triton_heuristics import AutotuneHint, reduction
from torch._inductor.utils import instance_descriptor
from torch._inductor import triton_helpers
@reduction(
size_hints=[1, 4096],
reduction_hint=ReductionHint.INNER,
filename=__file__,
triton_meta={'signature': {0: '*bf16', 1: '*bf16', 2: '*bf16', 3: '*bf16', 4: '*bf16', 5: '*bf16', 6: 'i32', 7: 'i32'}, 'device': 0, 'device_type': 'cuda', 'constants': {}, 'configs': [instance_descriptor(divisible_by_16=(0, 1, 2, 3, 4, 5, 7), equal_to_1=(), ids_of_folded_args=(), divisible_by_8=(7,))]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_0', 'mutated_arg_names': ['in_ptr0', 'out_ptr3']}
)
@triton.jit
def triton_(in_ptr0, in_ptr1, in_ptr2, out_ptr0, out_ptr2, out_ptr3, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr):
xnumel = 1
rnumel = 3072
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
xmask = xindex < xnumel
rbase = tl.arange(0, RBLOCK)[None, :]
_tmp6 = tl.full([XBLOCK, RBLOCK], 0, tl.float32)
for roffset in range(0, rnumel, RBLOCK):
rindex = roffset + rbase
rmask = rindex < rnumel
r0 = rindex
tmp0 = tl.load(in_ptr0 + (r0), rmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
tmp1 = tl.load(in_ptr1 + (r0), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
tmp2 = tmp0 + tmp1
tmp3 = tmp2.to(tl.float32)
tmp4 = tmp3 * tmp3
tmp5 = tl.broadcast_to(tmp4, [XBLOCK, RBLOCK])
tmp7 = _tmp6 + tmp5
_tmp6 = tl.where(rmask, tmp7, _tmp6)
tl.store(out_ptr0 + (tl.broadcast_to(r0, [XBLOCK, RBLOCK])), tmp2, rmask)
tmp6 = tl.sum(_tmp6, 1)[:, None]
for roffset in range(0, rnumel, RBLOCK):
rindex = roffset + rbase
rmask = rindex < rnumel
r0 = rindex
tmp8 = tl.load(out_ptr0 + (r0), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
tmp16 = tl.load(in_ptr2 + (r0), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
tmp9 = tmp8.to(tl.float32)
tmp10 = 3072.0
tmp11 = tmp6 / tmp10
tmp12 = 1e-06
tmp13 = tmp11 + tmp12
tmp14 = tl.math.rsqrt(tmp13)
tmp15 = tmp9 * tmp14
tmp17 = tmp16.to(tl.float32)
tmp18 = tmp15 * tmp17
tmp19 = tmp18.to(tl.float32)
tl.store(out_ptr2 + (tl.broadcast_to(r0, [XBLOCK, RBLOCK])), tmp19, rmask)
tl.store(out_ptr3 + (tl.broadcast_to(r0, [XBLOCK, RBLOCK])), tmp8, rmask)
""",
)
import triton
import triton.language as tl
from torch._inductor.triton_heuristics import grid, start_graph, end_graph
from torch._C import _cuda_getCurrentRawStream as get_cuda_stream
# kernel path: /tmp/torchinductor_ubuntu/x7/cx7girdkwd3novmgykt7ibcsyg4y34ib5e4yrg2amurt7jhrddpa.py
# Source Nodes: [add, gate_up_states, hidden_states_1, hidden_states_2, hidden_states_3, normed_attn_res_output, pow_1, rsqrt, variance], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mm, aten.mul, aten.pow, aten.rsqrt]
# add => add_1
# gate_up_states => mm
# hidden_states_1 => convert_element_type
# hidden_states_2 => mul
# hidden_states_3 => mul_1
# normed_attn_res_output => convert_element_type_1
# pow_1 => pow_1
# rsqrt => rsqrt
# variance => mean
triton_tem_fused__to_copy_add_mean_mm_mul_pow_rsqrt_1 = async_compile.triton(
"triton_",
"""
import triton.language as tl
import triton
from torch._inductor.triton_heuristics import template
from torch._inductor.utils import instance_descriptor
from torch._inductor import triton_helpers
@template(
num_stages=5,
num_warps=4,
triton_meta={'signature': {0: '*bf16', 1: '*bf16', 2: '*bf16'}, 'device': 0, 'device_type': 'cuda', 'constants': {}, 'configs': [instance_descriptor(divisible_by_16=(0, 1, 2), equal_to_1=(), ids_of_folded_args=(), divisible_by_8=())]},
inductor_meta={'kernel_name': 'triton_tem_fused__to_copy_add_mean_mm_mul_pow_rsqrt_1'},
)
@triton.jit
def triton_(arg_A, arg_B, out_ptr0):
GROUP_M : tl.constexpr = 8
EVEN_K : tl.constexpr = True
ALLOW_TF32 : tl.constexpr = True
ACC_TYPE : tl.constexpr = tl.float32
B_PROLOGUE_CAST_TYPE : tl.constexpr = None
BLOCK_M : tl.constexpr = 16
BLOCK_N : tl.constexpr = 64
BLOCK_K : tl.constexpr = 32
A = arg_A
B = arg_B
M = 1
N = 49152
K = 3072
if M * N == 0:
# early exit due to zero-size input(s)
return
stride_am = 3072
stride_ak = 1
stride_bk = 1
stride_bn = 3072
# based on triton.ops.matmul
pid = tl.program_id(0)
grid_m = (M + BLOCK_M - 1) // BLOCK_M
grid_n = (N + BLOCK_N - 1) // BLOCK_N
# re-order program ID for better L2 performance
width = GROUP_M * grid_n
group_id = pid // width
group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
pid_m = group_id * GROUP_M + (pid % group_size)
pid_n = (pid % width) // (group_size)
rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)
rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)
rk = tl.arange(0, BLOCK_K)
A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak)
B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn)
acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
for k in range(K, 0, -BLOCK_K):
if EVEN_K:
a = tl.load(A)
b = tl.load(B)
else:
a = tl.load(A, mask=rk[None, :] < k, other=0.)
b = tl.load(B, mask=rk[:, None] < k, other=0.)
if B_PROLOGUE_CAST_TYPE is not None:
b = b.to(B_PROLOGUE_CAST_TYPE)
acc += tl.dot(a, b, allow_tf32=ALLOW_TF32)
A += BLOCK_K * stride_ak
B += BLOCK_K * stride_bk
# rematerialize rm and rn to save registers
rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
idx_m = rm[:, None]
idx_n = rn[None, :]
mask = (idx_m < M) & (idx_n < N)
# inductor generates a suffix
xindex = idx_n + (49152*idx_m)
tl.store(out_ptr0 + (tl.broadcast_to(idx_n, mask.shape)), acc, mask)
""",
)
import torch._inductor.kernel.mm_common
meta0 = {
"GROUP_M": 8,
"EVEN_K": True,
"ALLOW_TF32": True,
"ACC_TYPE": "tl.float32",
"B_PROLOGUE_CAST_TYPE": None,
"BLOCK_M": 16,
"BLOCK_N": 64,
"BLOCK_K": 32,
}
# kernel path: /tmp/torchinductor_ubuntu/as/caswlqxmo7j7f3e26h3tbfwmhcdowf44zrv6mggnh5t7q5w2izpr.py
# Source Nodes: [gelu, mul_2], Original ATen: [aten.gelu, aten.mul]
# gelu => add_2, convert_element_type_4, convert_element_type_5, erf, mul_2, mul_3, mul_4
# mul_2 => mul_5
triton_poi_fused_gelu_mul_2 = async_compile.triton(
"triton_",
"""
import triton
import triton.language as tl
from torch._inductor.ir import ReductionHint
from torch._inductor.ir import TileHint
from torch._inductor.triton_heuristics import AutotuneHint, pointwise
from torch._inductor.utils import instance_descriptor
from torch._inductor import triton_helpers
@pointwise(
size_hints=[32768],
filename=__file__,
triton_meta={'signature': {0: '*bf16', 1: '*bf16', 2: 'i32'}, 'device': 0, 'device_type': 'cuda', 'constants': {}, 'configs': [instance_descriptor(divisible_by_16=(0, 1, 2), equal_to_1=(), ids_of_folded_args=(), divisible_by_8=(2,))]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_gelu_mul_2', 'mutated_arg_names': []},
min_elem_per_thread=0
)
@triton.jit
def triton_(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
xnumel = 24576
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:]
xmask = xindex < xnumel
x0 = xindex
tmp0 = tl.load(in_ptr0 + (x0), None).to(tl.float32)
tmp11 = tl.load(in_ptr0 + (24576 + x0), None).to(tl.float32)
tmp1 = tmp0.to(tl.float32)
tmp2 = 0.5
tmp3 = tmp1 * tmp2
tmp4 = 0.7071067811865476
tmp5 = tmp1 * tmp4
tmp6 = tl.math.erf(tmp5)
tmp7 = 1.0
tmp8 = tmp6 + tmp7
tmp9 = tmp3 * tmp8
tmp10 = tmp9.to(tl.float32)
tmp12 = tmp10 * tmp11
tl.store(out_ptr0 + (x0), tmp12, None)
""",
)
# kernel path: /tmp/torchinductor_ubuntu/ux/cuxfeb5redhx2vl5744hlfuhmstxy2yqiomik5glekkt3jn6vzq5.py
# Source Nodes: [gelu, mlp_output, mul_2], Original ATen: [aten.gelu, aten.mm, aten.mul]
# gelu => add_2, convert_element_type_4, convert_element_type_5, erf, mul_2, mul_3, mul_4
# mlp_output => mm_1
# mul_2 => mul_5
triton_tem_fused_gelu_mm_mul_3 = async_compile.triton(
"triton_",
"""
import triton.language as tl
import triton
from torch._inductor.triton_heuristics import template
from torch._inductor.utils import instance_descriptor
from torch._inductor import triton_helpers
@template(
num_stages=5,
num_warps=4,
triton_meta={'signature': {0: '*bf16', 1: '*bf16', 2: '*bf16'}, 'device': 0, 'device_type': 'cuda', 'constants': {}, 'configs': [instance_descriptor(divisible_by_16=(0, 1, 2), equal_to_1=(), ids_of_folded_args=(), divisible_by_8=())]},
inductor_meta={'kernel_name': 'triton_tem_fused_gelu_mm_mul_3'},
)
@triton.jit
def triton_(arg_A, arg_B, out_ptr0):
GROUP_M : tl.constexpr = 8
EVEN_K : tl.constexpr = True
ALLOW_TF32 : tl.constexpr = True
ACC_TYPE : tl.constexpr = tl.float32
B_PROLOGUE_CAST_TYPE : tl.constexpr = None
BLOCK_M : tl.constexpr = 16
BLOCK_N : tl.constexpr = 64
BLOCK_K : tl.constexpr = 32
A = arg_A
B = arg_B
M = 1
N = 3072
K = 24576
if M * N == 0:
# early exit due to zero-size input(s)
return
stride_am = 24576
stride_ak = 1
stride_bk = 1
stride_bn = 24576
# based on triton.ops.matmul
pid = tl.program_id(0)
grid_m = (M + BLOCK_M - 1) // BLOCK_M
grid_n = (N + BLOCK_N - 1) // BLOCK_N
# re-order program ID for better L2 performance
width = GROUP_M * grid_n
group_id = pid // width
group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
pid_m = group_id * GROUP_M + (pid % group_size)
pid_n = (pid % width) // (group_size)
rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)
rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)
rk = tl.arange(0, BLOCK_K)
A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak)
B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn)
acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
for k in range(K, 0, -BLOCK_K):
if EVEN_K:
a = tl.load(A)
b = tl.load(B)
else:
a = tl.load(A, mask=rk[None, :] < k, other=0.)
b = tl.load(B, mask=rk[:, None] < k, other=0.)
if B_PROLOGUE_CAST_TYPE is not None:
b = b.to(B_PROLOGUE_CAST_TYPE)
acc += tl.dot(a, b, allow_tf32=ALLOW_TF32)
A += BLOCK_K * stride_ak
B += BLOCK_K * stride_bk
# rematerialize rm and rn to save registers
rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
idx_m = rm[:, None]
idx_n = rn[None, :]
mask = (idx_m < M) & (idx_n < N)
# inductor generates a suffix
xindex = idx_n + (3072*idx_m)
tl.store(out_ptr0 + (tl.broadcast_to(idx_n, mask.shape)), acc, mask)
""",
)
async_compile.wait(globals())
del async_compile
def call(args):
arg0_1, arg1_1, arg2_1, arg3_1, arg4_1 = args
args.clear()
assert_size_stride(arg0_1, (3072,), (1,))
assert_size_stride(arg1_1, (49152, 3072), (3072, 1))
assert_size_stride(arg2_1, (3072, 24576), (24576, 1))
assert_size_stride(arg3_1, (1, 3072), (3072, 1))
assert_size_stride(arg4_1, (1, 3072), (3072, 1))
with torch.cuda._DeviceGuard(0):
torch.cuda.set_device(0) # no-op to ensure context
buf0 = empty((1, 3072), device="cuda", dtype=torch.bfloat16)
buf2 = empty((1, 3072), device="cuda", dtype=torch.bfloat16)
# Source Nodes: [add, attn_res, hidden_states_1, hidden_states_2, hidden_states_3, normed_attn_res_output, pow_1, rsqrt, variance], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, aten.pow, aten.rsqrt]
stream0 = get_cuda_stream(0)
triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_0.run(
arg4_1,
arg3_1,
arg0_1,
buf0,
buf2,
arg4_1,
1,
3072,
grid=grid(1),
stream=stream0,
)
del arg0_1
del arg3_1
del arg4_1
buf3 = empty((1, 49152), device="cuda", dtype=torch.bfloat16)
# Source Nodes: [add, gate_up_states, hidden_states_1, hidden_states_2, hidden_states_3, normed_attn_res_output, pow_1, rsqrt, variance], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mm, aten.mul, aten.pow, aten.rsqrt]
triton_tem_fused__to_copy_add_mean_mm_mul_pow_rsqrt_1.run(
buf2,
arg1_1,
buf3,
grid=torch._inductor.kernel.mm_common.mm_grid(1, 49152, meta0),
stream=stream0,
)
del arg1_1
buf4 = empty((1, 24576), device="cuda", dtype=torch.bfloat16)
# Source Nodes: [gelu, mul_2], Original ATen: [aten.gelu, aten.mul]
triton_poi_fused_gelu_mul_2.run(
buf3, buf4, 24576, grid=grid(24576), stream=stream0
)
del buf3
buf5 = buf2
del buf2 # reuse
# Source Nodes: [gelu, mlp_output, mul_2], Original ATen: [aten.gelu, aten.mm, aten.mul]
triton_tem_fused_gelu_mm_mul_3.run(
buf4,
arg2_1,
buf5,
grid=torch._inductor.kernel.mm_common.mm_grid(1, 3072, meta0),
stream=stream0,
)
del arg2_1
return (
buf5,
buf0,
)
def benchmark_compiled_module(times=10, repeat=10):
from torch._dynamo.testing import rand_strided
from torch._inductor.utils import print_performance
arg0_1 = rand_strided((3072,), (1,), device="cuda:0", dtype=torch.bfloat16)
arg1_1 = rand_strided(
(49152, 3072), (3072, 1), device="cuda:0", dtype=torch.bfloat16
)
arg2_1 = rand_strided(
(3072, 24576), (24576, 1), device="cuda:0", dtype=torch.bfloat16
)
arg3_1 = rand_strided((1, 3072), (3072, 1), device="cuda:0", dtype=torch.bfloat16)
arg4_1 = rand_strided((1, 3072), (3072, 1), device="cuda:0", dtype=torch.bfloat16)
fn = lambda: call([arg0_1, arg1_1, arg2_1, arg3_1, arg4_1])
return print_performance(fn, times=times, repeat=repeat)
if __name__ == "__main__":
from torch._inductor.wrapper_benchmark import compiled_module_main
compiled_module_main("None", benchmark_compiled_module)
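
Note: this second generated module is the same fused post-attention block specialized to a single decode row (M fixed at 1), so the dynamic size argument drops out of call's argument list. Illustrative usage with hypothetical tensor names; both modules clear the argument list they receive and write attn_res back into the attn_output buffer in place (out_ptr3 above):

import torch
# assumes the two generated modules are importable, mirroring the relative imports
# added in the modeling file further down in this commit
from custom_triton_kernels.gemma_kernel import call
from custom_triton_kernels.gemma_kernel_decode_one import call as call_decode_one

norm_w = torch.randn(3072, device="cuda", dtype=torch.bfloat16)
gate_up_w = torch.randn(49152, 3072, device="cuda", dtype=torch.bfloat16)
down_w = torch.randn(3072, 24576, device="cuda", dtype=torch.bfloat16)

# variable-batch module: the row count (here 256) is passed explicitly
res = torch.randn(256, 3072, device="cuda", dtype=torch.bfloat16)
attn_out = torch.randn(256, 3072, device="cuda", dtype=torch.bfloat16)
mlp_output, attn_res = call([norm_w, gate_up_w, down_w, 256, res, attn_out])

# decode-one module: the (1, 3072) shapes are baked in, so no size argument
res1 = torch.randn(1, 3072, device="cuda", dtype=torch.bfloat16)
attn_out1 = torch.randn(1, 3072, device="cuda", dtype=torch.bfloat16)
mlp_output1, attn_res1 = call_decode_one([norm_w, gate_up_w, down_w, res1, attn_out1])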


@@ -36,6 +36,8 @@ from text_generation_server.utils.layers import (
get_linear,
FastRMSNorm,
)
from .custom_triton_kernels.gemma_kernel import call
from .custom_triton_kernels.gemma_kernel_decode_one import call as call_decode_one
class GemmaConfig(PretrainedConfig):
@@ -61,6 +63,7 @@ class GemmaConfig(PretrainedConfig):
rope_scaling=None,
attention_bias=False,
attention_dropout=0.0,
post_attn_method="post_attn_layer_norm_and_mlp_fused",
**kwargs,
):
self.vocab_size = vocab_size
@@ -84,6 +87,7 @@ class GemmaConfig(PretrainedConfig):
self.rope_scaling = rope_scaling
self.attention_bias = attention_bias
self.attention_dropout = attention_dropout
self.post_attn_method = post_attn_method
super().__init__(
pad_token_id=pad_token_id,
@@ -310,6 +314,67 @@ class FlashGemmaLayer(nn.Module):
eps=config.rms_norm_eps,
)
if config.post_attn_method == "post_attn_layer_norm_and_mlp_fused":
arg0_1 = self.post_attention_layernorm.weight
arg1_1 = self.mlp.gate_up_proj.linear.weight
arg2_1 = self.mlp.down_proj.linear.weight
# wrap the variable sized kernel call and use it as inner_forward
def _wrapped_inner_forward(
res,
attn_output,
):
arg3_1 = res.size(0)
arg4_1 = res
arg5_1 = attn_output
return call(
[arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1],
)
self.inner_forward = _wrapped_inner_forward
# now wrap the kernel for a fixed size input and use it as decode_one_forward
def _wrapped_decode_one_forward(
res,
attn_output,
):
arg3_1 = res
arg4_1 = attn_output
return call_decode_one(
[arg0_1, arg1_1, arg2_1, arg3_1, arg4_1],
)
self.decode_one_forward = _wrapped_decode_one_forward
elif config.post_attn_method == "compile":
            # compile the function specifically for a fixed-size input (the variable-sized one is compiled outside this example)
self.decode_one_forward = torch.compile(
# the original forward function
self._inner_forward,
# avoid "reduce-overhead" mode, as it caches the graph and misleads timing
mode="max-autotune",
# avoid graph breaks
fullgraph=True,
                # allow compilation for different input shapes
dynamic=True,
)
self.inner_forward = self._inner_forward
else:
            # if nothing is specified, use the original forward function in both cases
self.decode_one_forward = self._inner_forward
self.inner_forward = self._inner_forward
def _inner_forward(
self,
res,
attn_output,
):
normed_attn_res_output, attn_res = self.post_attention_layernorm(
attn_output, res
)
mlp_output = self.mlp(normed_attn_res_output)
return mlp_output, attn_res
def forward(
self,
hidden_states,
@@ -338,15 +403,17 @@
max_s,
)
-        # faster post attention rms norm
-        normed_attn_res_output, attn_res = self.post_attention_layernorm(
-            attn_output, res
-        )
-        mlp_output = self.mlp(normed_attn_res_output)
-        return mlp_output, attn_res
+        if hidden_states.shape[0] == 1:
+            return self.decode_one_forward(
+                res,
+                attn_output,
+            )
+        return self.inner_forward(
+            res,
+            attn_output,
+        )
class FlashGemmaModel(torch.nn.Module):
def __init__(self, config, weights):
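
Note: one way to sanity-check this wiring is to compare the selected decode-one path (fused kernel, torch.compile, or plain eager) against the original eager path on identical inputs. A minimal sketch, assuming a constructed FlashGemmaLayer named layer and matching (1, 3072) bfloat16 CUDA activations; the tolerances are placeholders:

import torch

@torch.inference_mode()
def check_decode_one_path(layer, res, attn_output, atol=2e-2, rtol=2e-2):
    # eager reference: the original post_attention_layernorm + MLP path
    ref_mlp, ref_attn_res = layer._inner_forward(res.clone(), attn_output.clone())
    # whichever decode-one path __init__ selected; clone because the fused kernel
    # writes attn_res into the attn_output buffer in place
    out_mlp, out_attn_res = layer.decode_one_forward(res.clone(), attn_output.clone())
    torch.testing.assert_close(out_attn_res, ref_attn_res, atol=atol, rtol=rtol)
    torch.testing.assert_close(out_mlp, ref_mlp, atol=atol, rtol=rtol)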


@@ -28,6 +28,7 @@ class FlashGemma(FlashCausalLM):
use_medusa: Optional[str] = None,
dtype: Optional[torch.dtype] = None,
trust_remote_code: bool = False,
post_attn_method: Optional[str] = None,
):
self.process_group, rank, world_size = initialize_torch_distributed()
if torch.cuda.is_available():
@@ -51,6 +52,7 @@
)
config.quantize = quantize
config.use_medusa = use_medusa
config.post_attn_method = post_attn_method
torch.distributed.barrier(group=self.process_group)


@@ -176,6 +176,7 @@ def serve(
dtype: Optional[str],
trust_remote_code: bool,
uds_path: Path,
post_attn_method: Optional[str],
):
async def serve_inner(
model_id: str,
@@ -185,6 +186,7 @@ def serve(
speculate: Optional[int] = None,
dtype: Optional[str] = None,
trust_remote_code: bool = False,
post_attn_method: Optional[str] = None,
):
unix_socket_template = "unix://{}-{}"
if sharded:
@@ -206,6 +208,7 @@ def serve(
speculate,
dtype,
trust_remote_code,
post_attn_method,
)
except Exception:
logger.exception("Error when initializing model")
@@ -239,6 +242,13 @@ def serve(
asyncio.run(
serve_inner(
-            model_id, revision, sharded, quantize, speculate, dtype, trust_remote_code
+            model_id,
+            revision,
+            sharded,
+            quantize,
+            speculate,
+            dtype,
+            trust_remote_code,
+            post_attn_method,
)
)