diff --git a/script/throughput.py b/script/throughput.py
new file mode 100644
index 00000000..da3399c3
--- /dev/null
+++ b/script/throughput.py
@@ -0,0 +1,107 @@
+from concurrent import futures
+from typing import Any, Dict, List
+import time
+
+import requests
+import json
+
+
+# copied from other benchmarking script
+def simple_throughput(
+ predictor: Any, payloads: List[Dict[str, Any]], concurrent_requests: int
+):
+ time_start = time.time()
+ with futures.ThreadPoolExecutor(max_workers=concurrent_requests) as executor:
+ responses = list(executor.map(predictor.predict, payloads))
+
+ if isinstance(responses[0], list):
+ responses = [response[0] for response in responses]
+
+ total_tokens = sum(
+ [
+ len([token for token in x["details"]["tokens"] if token["id"] != -1])
+ for x in responses
+ ]
+ )
+ token_throughput = total_tokens / (time.time() - time_start)
+ generated_tokens_per_request = total_tokens / len(payloads)
+ print(f"concurrent requests: {concurrent_requests}")
+ print(f"throughput: {token_throughput:.2f}")
+ print(
+ f"generated tokens per request: {generated_tokens_per_request:.2f}", end="\n\n"
+ )
+
+
+# setup to help run
+real = True
+
+
+class Predictor:
+ def predict(self, payload: Dict[str, Any]) -> Dict[str, Any]:
+ response_json = {"details": {"tokens": []}}
+ if real:
+ url = "http://localhost:3000/generate"
+ headers = {"Content-Type": "application/json"}
+ response = requests.post(url, data=json.dumps(payload), headers=headers)
+ response_json = response.json()
+ else:
+ time.sleep(0.1)
+
+ return response_json
+
+
+max_new_tokens = 100
+
+if __name__ == "__main__":
+ print("Running throughput test")
+ predictor = Predictor()
+ payloads = [
+ {
+ "inputs": "[INST] I am making mayonnaise, it was starting to thicken but now it has become runny and liquid again, is there any way to salvage it? [/INST]Yes, it's possible to fix runny mayonnaise! The most common reason for mayonnaise becoming runny is because the oil was added too quickly or the egg yolk wasn't emulsified properly. Here are some steps you can take to fix it:\n\n1. Separate another egg yolk and place it in a clean, dry bowl.\n2. Slowly add the runny mayonnaise to the egg yolk while whisking vigorously.\n3. Once all the runny mayonnaise has been added, continue whisking until the mixture has emulsified and thickened.\n4. If the mayonnaise is still too runny, you can add another egg yolk and repeat the process.\n\nIf the mayonnaise still won't thicken, you can try adding a small amount of dijon mustard or vinegar to the mixture, which can act as emulsifiers and help stabilize the mayonnaise. It's important to add these ingredients slowly and in small amounts to avoid over-thinning the mixture.[INST] What is optimal Mayonnaise thickness? [/INST]",
+ "parameters": {"details": True, "max_new_tokens": max_new_tokens},
+ },
+ {
+ "inputs": "[INST] Why Aristotelian view of physics (impetus and stuff) is wrong? [/INST]",
+ "parameters": {"details": True, "max_new_tokens": max_new_tokens},
+ },
+ {
+ "inputs": "[INST] If you were to image the experience of eating from only others descriptions and never having done it yourself, how would you describe it. [/INST]",
+ "parameters": {"details": True, "max_new_tokens": max_new_tokens},
+ },
+ {
+ "inputs": "[INST] What is the best way to cook a steak? [/INST]",
+ "parameters": {"details": True, "max_new_tokens": max_new_tokens},
+ },
+ {
+ "inputs": "[INST] How do you make a perfect omelette? [/INST]",
+ "parameters": {"details": True, "max_new_tokens": max_new_tokens},
+ },
+ {
+ "inputs": "[INST] What is the secret to a good pizza dough? [/INST]",
+ "parameters": {"details": True, "max_new_tokens": max_new_tokens},
+ },
+ {
+ "inputs": "[INST] How do you make a classic French onion soup? [/INST]",
+ "parameters": {"details": True, "max_new_tokens": max_new_tokens},
+ },
+ {
+ "inputs": "[INST] What is the best way to roast a chicken? [/INST]",
+ "parameters": {"details": True, "max_new_tokens": max_new_tokens},
+ },
+ {
+ "inputs": "[INST] How do you make a perfect chocolate cake? [/INST]",
+ "parameters": {"details": True, "max_new_tokens": max_new_tokens},
+ },
+ {
+ "inputs": "[INST] What is the secret to a good risotto? [/INST]",
+ "parameters": {"details": True, "max_new_tokens": max_new_tokens},
+ },
+ {
+ "inputs": "[INST] How do you make a classic spaghetti carbonara? [/INST]",
+ "parameters": {"details": True, "max_new_tokens": max_new_tokens},
+ },
+ ]
+
+ for concurrent_requests in [1, 2, 4, 8]:
+ simple_throughput(predictor, payloads, concurrent_requests)
+ print("Throughput test complete")
diff --git a/server/text_generation_server/cli.py b/server/text_generation_server/cli.py
index bb0963d4..f565677b 100644
--- a/server/text_generation_server/cli.py
+++ b/server/text_generation_server/cli.py
@@ -40,6 +40,7 @@ def serve(
logger_level: str = "INFO",
json_output: bool = False,
otlp_endpoint: Optional[str] = None,
+ post_attn_method: Optional[str] = None,
):
if sharded:
assert (
@@ -96,6 +97,7 @@ def serve(
dtype,
trust_remote_code,
uds_path,
+ post_attn_method,
)
diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py
index 06792b0d..ca5f7b22 100644
--- a/server/text_generation_server/models/__init__.py
+++ b/server/text_generation_server/models/__init__.py
@@ -114,6 +114,7 @@ def get_model(
speculate: Optional[int],
dtype: Optional[str],
trust_remote_code: bool,
+ post_attn_method: Optional[str] = None,
) -> Model:
if dtype is None:
# Keep it as default for now and let
@@ -357,6 +358,7 @@ def get_model(
use_medusa=use_medusa,
dtype=dtype,
trust_remote_code=trust_remote_code,
+ post_attn_method=post_attn_method,
)
elif sharded:
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Gemma"))
diff --git a/server/text_generation_server/models/custom_modeling/custom_triton_kernels/gemma_kernel.py b/server/text_generation_server/models/custom_modeling/custom_triton_kernels/gemma_kernel.py
new file mode 100644
index 00000000..f1879d08
--- /dev/null
+++ b/server/text_generation_server/models/custom_modeling/custom_triton_kernels/gemma_kernel.py
@@ -0,0 +1,354 @@
+from ctypes import c_void_p, c_long
+import torch
+import math
+import random
+import os
+import tempfile
+from math import inf, nan
+from torch._inductor.hooks import run_intermediate_hooks
+from torch._inductor.utils import maybe_profile
+from torch._inductor.codegen.memory_planning import _align as align
+
+from torch import device, empty, empty_strided
+from torch._inductor.codecache import AsyncCompile
+from torch._inductor.select_algorithm import extern_kernels
+
+aten = torch.ops.aten
+inductor_ops = torch.ops.inductor
+assert_size_stride = torch._C._dynamo.guards.assert_size_stride
+alloc_from_pool = torch.ops.inductor._alloc_from_pool
+reinterpret_tensor = torch.ops.inductor._reinterpret_tensor
+async_compile = AsyncCompile()
+
+
+# kernel path: /tmp/torchinductor_ubuntu/sb/csbqhwspefuy7jp6pcbyvalomk4wvqyt4mplgbphfugdwu66zbp7.py
+# Source Nodes: [add, attn_res, hidden_states_1, hidden_states_2, hidden_states_3, normed_attn_res_output, pow_1, rsqrt, variance], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, aten.pow, aten.rsqrt]
+# add => add_1
+# attn_res => add
+# hidden_states_1 => convert_element_type
+# hidden_states_2 => mul
+# hidden_states_3 => mul_1
+# normed_attn_res_output => convert_element_type_1
+# pow_1 => pow_1
+# rsqrt => rsqrt
+# variance => mean
+triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_0 = async_compile.triton(
+ "triton_",
+ """
+import triton
+import triton.language as tl
+from torch._inductor.ir import ReductionHint
+from torch._inductor.ir import TileHint
+from torch._inductor.triton_heuristics import AutotuneHint, reduction
+from torch._inductor.utils import instance_descriptor
+from torch._inductor import triton_helpers
+
+@reduction(
+ size_hints=[256, 4096],
+ reduction_hint=ReductionHint.INNER,
+ filename=__file__,
+ triton_meta={'signature': {0: '*bf16', 1: '*bf16', 2: '*bf16', 3: '*bf16', 4: '*bf16', 5: '*bf16', 6: 'i32', 7: 'i32'}, 'device': 0, 'device_type': 'cuda', 'constants': {}, 'configs': [instance_descriptor(divisible_by_16=(0, 1, 2, 3, 4, 5, 7), equal_to_1=(), ids_of_folded_args=(), divisible_by_8=(7,))]},
+ inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_0', 'mutated_arg_names': ['in_ptr0', 'out_ptr3']}
+)
+@triton.jit
+def triton_(in_ptr0, in_ptr1, in_ptr2, out_ptr0, out_ptr2, out_ptr3, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr):
+ rnumel = 3072
+ xoffset = tl.program_id(0) * XBLOCK
+ xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
+ xmask = xindex < xnumel
+ rbase = tl.arange(0, RBLOCK)[None, :]
+ x0 = xindex
+ _tmp6 = tl.full([XBLOCK, RBLOCK], 0, tl.float32)
+ for roffset in range(0, rnumel, RBLOCK):
+ rindex = roffset + rbase
+ rmask = rindex < rnumel
+ r1 = rindex
+ tmp0 = tl.load(in_ptr0 + (r1 + (3072*x0)), rmask & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
+ tmp1 = tl.load(in_ptr1 + (r1 + (3072*x0)), rmask & xmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
+ tmp2 = tmp0 + tmp1
+ tmp3 = tmp2.to(tl.float32)
+ tmp4 = tmp3 * tmp3
+ tmp5 = tl.broadcast_to(tmp4, [XBLOCK, RBLOCK])
+ tmp7 = _tmp6 + tmp5
+ _tmp6 = tl.where(rmask & xmask, tmp7, _tmp6)
+ tl.store(out_ptr0 + (r1 + (3072*x0)), tmp2, rmask & xmask)
+ tmp6 = tl.sum(_tmp6, 1)[:, None]
+ for roffset in range(0, rnumel, RBLOCK):
+ rindex = roffset + rbase
+ rmask = rindex < rnumel
+ r1 = rindex
+ tmp8 = tl.load(out_ptr0 + (r1 + (3072*x0)), rmask & xmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
+ tmp16 = tl.load(in_ptr2 + (r1), rmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
+ tmp9 = tmp8.to(tl.float32)
+ tmp10 = 3072.0
+ tmp11 = tmp6 / tmp10
+ tmp12 = 1e-06
+ tmp13 = tmp11 + tmp12
+ tmp14 = tl.math.rsqrt(tmp13)
+ tmp15 = tmp9 * tmp14
+ tmp17 = tmp16.to(tl.float32)
+ tmp18 = tmp15 * tmp17
+ tmp19 = tmp18.to(tl.float32)
+ tl.store(out_ptr2 + (r1 + (3072*x0)), tmp19, rmask & xmask)
+ tl.store(out_ptr3 + (r1 + (3072*x0)), tmp8, rmask & xmask)
+""",
+)
+
+import triton
+import triton.language as tl
+from torch._inductor.triton_heuristics import grid, start_graph, end_graph
+from torch._C import _cuda_getCurrentRawStream as get_cuda_stream
+
+
+# kernel path: /tmp/torchinductor_ubuntu/3y/c3ycepp3q64t65lfi2wcqq5skonqvz6t4okgw4dboyrxwdspdaff.py
+# Source Nodes: [add, gate_up_states, hidden_states_1, hidden_states_2, hidden_states_3, normed_attn_res_output, pow_1, rsqrt, variance], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mm, aten.mul, aten.pow, aten.rsqrt]
+# add => add_1
+# gate_up_states => mm
+# hidden_states_1 => convert_element_type
+# hidden_states_2 => mul
+# hidden_states_3 => mul_1
+# normed_attn_res_output => convert_element_type_1
+# pow_1 => pow_1
+# rsqrt => rsqrt
+# variance => mean
+triton_tem_fused__to_copy_add_mean_mm_mul_pow_rsqrt_1 = async_compile.triton(
+ "triton_",
+ """
+import triton.language as tl
+import triton
+from torch._inductor.triton_heuristics import template
+from torch._inductor.utils import instance_descriptor
+from torch._inductor import triton_helpers
+
+
+@template(
+ num_stages=2,
+ num_warps=4,
+ triton_meta={'signature': {0: '*bf16', 1: '*bf16', 2: '*bf16', 3: 'i32'}, 'device': 0, 'device_type': 'cuda', 'constants': {}, 'configs': [instance_descriptor(divisible_by_16=(0, 1, 2), equal_to_1=(), ids_of_folded_args=(), divisible_by_8=())]},
+ inductor_meta={'kernel_name': 'triton_tem_fused__to_copy_add_mean_mm_mul_pow_rsqrt_1'},
+)
+@triton.jit
+
+def triton_(arg_A, arg_B, out_ptr0, ks0):
+ GROUP_M : tl.constexpr = 8
+ EVEN_K : tl.constexpr = True
+ ALLOW_TF32 : tl.constexpr = True
+ ACC_TYPE : tl.constexpr = tl.float32
+ B_PROLOGUE_CAST_TYPE : tl.constexpr = None
+ BLOCK_M : tl.constexpr = 64
+ BLOCK_N : tl.constexpr = 64
+ BLOCK_K : tl.constexpr = 32
+
+ A = arg_A
+ B = arg_B
+
+ M = ks0
+ N = 49152
+ K = 3072
+ if M * N == 0:
+ # early exit due to zero-size input(s)
+ return
+ stride_am = 3072
+ stride_ak = 1
+ stride_bk = 1
+ stride_bn = 3072
+
+ # based on triton.ops.matmul
+ pid = tl.program_id(0)
+ grid_m = (M + BLOCK_M - 1) // BLOCK_M
+ grid_n = (N + BLOCK_N - 1) // BLOCK_N
+
+ # re-order program ID for better L2 performance
+ width = GROUP_M * grid_n
+ group_id = pid // width
+ group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
+ pid_m = group_id * GROUP_M + (pid % group_size)
+ pid_n = (pid % width) // (group_size)
+
+ rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
+ rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
+ ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)
+ rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)
+ rk = tl.arange(0, BLOCK_K)
+ A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak)
+ B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn)
+
+ acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
+ for k in range(K, 0, -BLOCK_K):
+ if EVEN_K:
+ a = tl.load(A)
+ b = tl.load(B)
+ else:
+ a = tl.load(A, mask=rk[None, :] < k, other=0.)
+ b = tl.load(B, mask=rk[:, None] < k, other=0.)
+ if B_PROLOGUE_CAST_TYPE is not None:
+ b = b.to(B_PROLOGUE_CAST_TYPE)
+ acc += tl.dot(a, b, allow_tf32=ALLOW_TF32)
+ A += BLOCK_K * stride_ak
+ B += BLOCK_K * stride_bk
+
+ # rematerialize rm and rn to save registers
+ rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
+ rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
+ idx_m = rm[:, None]
+ idx_n = rn[None, :]
+ mask = (idx_m < M) & (idx_n < N)
+
+ # inductor generates a suffix
+ xindex = idx_n + (49152*idx_m)
+ tl.store(out_ptr0 + (tl.broadcast_to(xindex, mask.shape)), acc, mask)
+""",
+)
+import torch._inductor.kernel.mm_common
+
+meta0 = {
+ "GROUP_M": 8,
+ "EVEN_K": True,
+ "ALLOW_TF32": True,
+ "ACC_TYPE": "tl.float32",
+ "B_PROLOGUE_CAST_TYPE": None,
+ "BLOCK_M": 64,
+ "BLOCK_N": 64,
+ "BLOCK_K": 32,
+}
+
+
+# kernel path: /tmp/torchinductor_ubuntu/hz/chzpreui7cisscnbe7lrfewue7t4pg7pbnbouuqqq5d2jjbc3wsv.py
+# Source Nodes: [gelu, inter], Original ATen: [aten.gelu, aten.mul]
+# gelu => add_2, convert_element_type_4, convert_element_type_5, erf, mul_2, mul_3, mul_4
+# inter => mul_5
+triton_poi_fused_gelu_mul_2 = async_compile.triton(
+ "triton_",
+ """
+import triton
+import triton.language as tl
+from torch._inductor.ir import ReductionHint
+from torch._inductor.ir import TileHint
+from torch._inductor.triton_heuristics import AutotuneHint, pointwise
+from torch._inductor.utils import instance_descriptor
+from torch._inductor import triton_helpers
+
+@pointwise(
+ size_hints=[8388608],
+ filename=__file__,
+ triton_meta={'signature': {0: '*bf16', 1: '*bf16', 2: 'i32'}, 'device': 0, 'device_type': 'cuda', 'constants': {}, 'configs': [instance_descriptor(divisible_by_16=(0, 1, 2), equal_to_1=(), ids_of_folded_args=(), divisible_by_8=(2,))]},
+ inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_gelu_mul_2', 'mutated_arg_names': []},
+ min_elem_per_thread=0
+)
+@triton.jit
+def triton_(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
+ xoffset = tl.program_id(0) * XBLOCK
+ xindex = xoffset + tl.arange(0, XBLOCK)[:]
+ xmask = xindex < xnumel
+ x0 = xindex % 24576
+ x1 = (xindex // 24576)
+ x2 = xindex
+ tmp0 = tl.load(in_ptr0 + (x0 + (49152*x1)), None).to(tl.float32)
+ tmp11 = tl.load(in_ptr0 + (24576 + x0 + (49152*x1)), None).to(tl.float32)
+ tmp1 = tmp0.to(tl.float32)
+ tmp2 = 0.5
+ tmp3 = tmp1 * tmp2
+ tmp4 = 0.7071067811865476
+ tmp5 = tmp1 * tmp4
+ tmp6 = tl.math.erf(tmp5)
+ tmp7 = 1.0
+ tmp8 = tmp6 + tmp7
+ tmp9 = tmp3 * tmp8
+ tmp10 = tmp9.to(tl.float32)
+ tmp12 = tmp10 * tmp11
+ tl.store(out_ptr0 + (x2), tmp12, None)
+""",
+)
+
+
+async_compile.wait(globals())
+del async_compile
+
+
+def call(args):
+ arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1 = args
+ args.clear()
+ s0 = arg3_1
+ assert_size_stride(arg0_1, (3072,), (1,))
+ assert_size_stride(arg1_1, (49152, 3072), (3072, 1))
+ assert_size_stride(arg2_1, (3072, 24576), (24576, 1))
+ assert_size_stride(arg4_1, (s0, 3072), (3072, 1))
+ assert_size_stride(arg5_1, (s0, 3072), (3072, 1))
+ with torch.cuda._DeviceGuard(0):
+ torch.cuda.set_device(0) # no-op to ensure context
+ buf0 = empty((s0, 3072), device="cuda", dtype=torch.bfloat16)
+ buf2 = empty((s0, 3072), device="cuda", dtype=torch.bfloat16)
+ # Source Nodes: [add, attn_res, hidden_states_1, hidden_states_2, hidden_states_3, normed_attn_res_output, pow_1, rsqrt, variance], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, aten.pow, aten.rsqrt]
+ stream0 = get_cuda_stream(0)
+ triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_0.run(
+ arg5_1,
+ arg4_1,
+ arg0_1,
+ buf0,
+ buf2,
+ arg5_1,
+ s0,
+ 3072,
+ grid=grid(s0),
+ stream=stream0,
+ )
+ del arg0_1
+ del arg4_1
+ del arg5_1
+ buf3 = empty((s0, 49152), device="cuda", dtype=torch.bfloat16)
+ # Source Nodes: [add, gate_up_states, hidden_states_1, hidden_states_2, hidden_states_3, normed_attn_res_output, pow_1, rsqrt, variance], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mm, aten.mul, aten.pow, aten.rsqrt]
+ triton_tem_fused__to_copy_add_mean_mm_mul_pow_rsqrt_1.run(
+ buf2,
+ arg1_1,
+ buf3,
+ s0,
+ grid=torch._inductor.kernel.mm_common.mm_grid(s0, 49152, meta0),
+ stream=stream0,
+ )
+ del arg1_1
+ buf4 = empty((s0, 24576), device="cuda", dtype=torch.bfloat16)
+ # Source Nodes: [gelu, inter], Original ATen: [aten.gelu, aten.mul]
+ triton_poi_fused_gelu_mul_2_xnumel = 24576 * s0
+ triton_poi_fused_gelu_mul_2.run(
+ buf3,
+ buf4,
+ triton_poi_fused_gelu_mul_2_xnumel,
+ grid=grid(triton_poi_fused_gelu_mul_2_xnumel),
+ stream=stream0,
+ )
+ del buf3
+ buf5 = buf2
+ del buf2 # reuse
+ # Source Nodes: [gelu, inter, mlp_output], Original ATen: [aten.gelu, aten.mm, aten.mul]
+ extern_kernels.mm(
+ buf4, reinterpret_tensor(arg2_1, (24576, 3072), (1, 24576), 0), out=buf5
+ )
+ del arg2_1
+ return (
+ buf5,
+ buf0,
+ )
+
+
+def benchmark_compiled_module(times=10, repeat=10):
+ from torch._dynamo.testing import rand_strided
+ from torch._inductor.utils import print_performance
+
+ arg0_1 = rand_strided((3072,), (1,), device="cuda:0", dtype=torch.bfloat16)
+ arg1_1 = rand_strided(
+ (49152, 3072), (3072, 1), device="cuda:0", dtype=torch.bfloat16
+ )
+ arg2_1 = rand_strided(
+ (3072, 24576), (24576, 1), device="cuda:0", dtype=torch.bfloat16
+ )
+ arg3_1 = 256
+ arg4_1 = rand_strided((256, 3072), (3072, 1), device="cuda:0", dtype=torch.bfloat16)
+ arg5_1 = rand_strided((256, 3072), (3072, 1), device="cuda:0", dtype=torch.bfloat16)
+ fn = lambda: call([arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1])
+ return print_performance(fn, times=times, repeat=repeat)
+
+
+if __name__ == "__main__":
+ from torch._inductor.wrapper_benchmark import compiled_module_main
+
+ compiled_module_main("None", benchmark_compiled_module)
diff --git a/server/text_generation_server/models/custom_modeling/custom_triton_kernels/gemma_kernel_decode_one.py b/server/text_generation_server/models/custom_modeling/custom_triton_kernels/gemma_kernel_decode_one.py
new file mode 100644
index 00000000..77519d6b
--- /dev/null
+++ b/server/text_generation_server/models/custom_modeling/custom_triton_kernels/gemma_kernel_decode_one.py
@@ -0,0 +1,444 @@
+from ctypes import c_void_p, c_long
+import torch
+import math
+import random
+import os
+import tempfile
+from math import inf, nan
+from torch._inductor.hooks import run_intermediate_hooks
+from torch._inductor.utils import maybe_profile
+from torch._inductor.codegen.memory_planning import _align as align
+
+from torch import device, empty, empty_strided
+from torch._inductor.codecache import AsyncCompile
+from torch._inductor.select_algorithm import extern_kernels
+
+aten = torch.ops.aten
+inductor_ops = torch.ops.inductor
+assert_size_stride = torch._C._dynamo.guards.assert_size_stride
+alloc_from_pool = torch.ops.inductor._alloc_from_pool
+reinterpret_tensor = torch.ops.inductor._reinterpret_tensor
+async_compile = AsyncCompile()
+
+
+# kernel path: /tmp/torchinductor_ubuntu/gp/cgpizln3o4666auqh5ql3ey4lrkrhsxsjzl4au7u2bbiig7f5mda.py
+# Source Nodes: [add, attn_res, hidden_states_1, hidden_states_2, hidden_states_3, normed_attn_res_output, pow_1, rsqrt, variance], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, aten.pow, aten.rsqrt]
+# add => add_1
+# attn_res => add
+# hidden_states_1 => convert_element_type
+# hidden_states_2 => mul
+# hidden_states_3 => mul_1
+# normed_attn_res_output => convert_element_type_1
+# pow_1 => pow_1
+# rsqrt => rsqrt
+# variance => mean
+triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_0 = async_compile.triton(
+ "triton_",
+ """
+import triton
+import triton.language as tl
+from torch._inductor.ir import ReductionHint
+from torch._inductor.ir import TileHint
+from torch._inductor.triton_heuristics import AutotuneHint, reduction
+from torch._inductor.utils import instance_descriptor
+from torch._inductor import triton_helpers
+
+@reduction(
+ size_hints=[1, 4096],
+ reduction_hint=ReductionHint.INNER,
+ filename=__file__,
+ triton_meta={'signature': {0: '*bf16', 1: '*bf16', 2: '*bf16', 3: '*bf16', 4: '*bf16', 5: '*bf16', 6: 'i32', 7: 'i32'}, 'device': 0, 'device_type': 'cuda', 'constants': {}, 'configs': [instance_descriptor(divisible_by_16=(0, 1, 2, 3, 4, 5, 7), equal_to_1=(), ids_of_folded_args=(), divisible_by_8=(7,))]},
+ inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_0', 'mutated_arg_names': ['in_ptr0', 'out_ptr3']}
+)
+@triton.jit
+def triton_(in_ptr0, in_ptr1, in_ptr2, out_ptr0, out_ptr2, out_ptr3, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr):
+ xnumel = 1
+ rnumel = 3072
+ xoffset = tl.program_id(0) * XBLOCK
+ xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
+ xmask = xindex < xnumel
+ rbase = tl.arange(0, RBLOCK)[None, :]
+ _tmp6 = tl.full([XBLOCK, RBLOCK], 0, tl.float32)
+ for roffset in range(0, rnumel, RBLOCK):
+ rindex = roffset + rbase
+ rmask = rindex < rnumel
+ r0 = rindex
+ tmp0 = tl.load(in_ptr0 + (r0), rmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
+ tmp1 = tl.load(in_ptr1 + (r0), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
+ tmp2 = tmp0 + tmp1
+ tmp3 = tmp2.to(tl.float32)
+ tmp4 = tmp3 * tmp3
+ tmp5 = tl.broadcast_to(tmp4, [XBLOCK, RBLOCK])
+ tmp7 = _tmp6 + tmp5
+ _tmp6 = tl.where(rmask, tmp7, _tmp6)
+ tl.store(out_ptr0 + (tl.broadcast_to(r0, [XBLOCK, RBLOCK])), tmp2, rmask)
+ tmp6 = tl.sum(_tmp6, 1)[:, None]
+ for roffset in range(0, rnumel, RBLOCK):
+ rindex = roffset + rbase
+ rmask = rindex < rnumel
+ r0 = rindex
+ tmp8 = tl.load(out_ptr0 + (r0), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
+ tmp16 = tl.load(in_ptr2 + (r0), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
+ tmp9 = tmp8.to(tl.float32)
+ tmp10 = 3072.0
+ tmp11 = tmp6 / tmp10
+ tmp12 = 1e-06
+ tmp13 = tmp11 + tmp12
+ tmp14 = tl.math.rsqrt(tmp13)
+ tmp15 = tmp9 * tmp14
+ tmp17 = tmp16.to(tl.float32)
+ tmp18 = tmp15 * tmp17
+ tmp19 = tmp18.to(tl.float32)
+ tl.store(out_ptr2 + (tl.broadcast_to(r0, [XBLOCK, RBLOCK])), tmp19, rmask)
+ tl.store(out_ptr3 + (tl.broadcast_to(r0, [XBLOCK, RBLOCK])), tmp8, rmask)
+""",
+)
+
+import triton
+import triton.language as tl
+from torch._inductor.triton_heuristics import grid, start_graph, end_graph
+from torch._C import _cuda_getCurrentRawStream as get_cuda_stream
+
+
+# kernel path: /tmp/torchinductor_ubuntu/x7/cx7girdkwd3novmgykt7ibcsyg4y34ib5e4yrg2amurt7jhrddpa.py
+# Source Nodes: [add, gate_up_states, hidden_states_1, hidden_states_2, hidden_states_3, normed_attn_res_output, pow_1, rsqrt, variance], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mm, aten.mul, aten.pow, aten.rsqrt]
+# add => add_1
+# gate_up_states => mm
+# hidden_states_1 => convert_element_type
+# hidden_states_2 => mul
+# hidden_states_3 => mul_1
+# normed_attn_res_output => convert_element_type_1
+# pow_1 => pow_1
+# rsqrt => rsqrt
+# variance => mean
+triton_tem_fused__to_copy_add_mean_mm_mul_pow_rsqrt_1 = async_compile.triton(
+ "triton_",
+ """
+import triton.language as tl
+import triton
+from torch._inductor.triton_heuristics import template
+from torch._inductor.utils import instance_descriptor
+from torch._inductor import triton_helpers
+
+
+@template(
+ num_stages=5,
+ num_warps=4,
+ triton_meta={'signature': {0: '*bf16', 1: '*bf16', 2: '*bf16'}, 'device': 0, 'device_type': 'cuda', 'constants': {}, 'configs': [instance_descriptor(divisible_by_16=(0, 1, 2), equal_to_1=(), ids_of_folded_args=(), divisible_by_8=())]},
+ inductor_meta={'kernel_name': 'triton_tem_fused__to_copy_add_mean_mm_mul_pow_rsqrt_1'},
+)
+@triton.jit
+
+def triton_(arg_A, arg_B, out_ptr0):
+ GROUP_M : tl.constexpr = 8
+ EVEN_K : tl.constexpr = True
+ ALLOW_TF32 : tl.constexpr = True
+ ACC_TYPE : tl.constexpr = tl.float32
+ B_PROLOGUE_CAST_TYPE : tl.constexpr = None
+ BLOCK_M : tl.constexpr = 16
+ BLOCK_N : tl.constexpr = 64
+ BLOCK_K : tl.constexpr = 32
+
+ A = arg_A
+ B = arg_B
+
+ M = 1
+ N = 49152
+ K = 3072
+ if M * N == 0:
+ # early exit due to zero-size input(s)
+ return
+ stride_am = 3072
+ stride_ak = 1
+ stride_bk = 1
+ stride_bn = 3072
+
+ # based on triton.ops.matmul
+ pid = tl.program_id(0)
+ grid_m = (M + BLOCK_M - 1) // BLOCK_M
+ grid_n = (N + BLOCK_N - 1) // BLOCK_N
+
+ # re-order program ID for better L2 performance
+ width = GROUP_M * grid_n
+ group_id = pid // width
+ group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
+ pid_m = group_id * GROUP_M + (pid % group_size)
+ pid_n = (pid % width) // (group_size)
+
+ rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
+ rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
+ ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)
+ rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)
+ rk = tl.arange(0, BLOCK_K)
+ A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak)
+ B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn)
+
+ acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
+ for k in range(K, 0, -BLOCK_K):
+ if EVEN_K:
+ a = tl.load(A)
+ b = tl.load(B)
+ else:
+ a = tl.load(A, mask=rk[None, :] < k, other=0.)
+ b = tl.load(B, mask=rk[:, None] < k, other=0.)
+ if B_PROLOGUE_CAST_TYPE is not None:
+ b = b.to(B_PROLOGUE_CAST_TYPE)
+ acc += tl.dot(a, b, allow_tf32=ALLOW_TF32)
+ A += BLOCK_K * stride_ak
+ B += BLOCK_K * stride_bk
+
+ # rematerialize rm and rn to save registers
+ rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
+ rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
+ idx_m = rm[:, None]
+ idx_n = rn[None, :]
+ mask = (idx_m < M) & (idx_n < N)
+
+ # inductor generates a suffix
+ xindex = idx_n + (49152*idx_m)
+ tl.store(out_ptr0 + (tl.broadcast_to(idx_n, mask.shape)), acc, mask)
+""",
+)
+import torch._inductor.kernel.mm_common
+
+meta0 = {
+ "GROUP_M": 8,
+ "EVEN_K": True,
+ "ALLOW_TF32": True,
+ "ACC_TYPE": "tl.float32",
+ "B_PROLOGUE_CAST_TYPE": None,
+ "BLOCK_M": 16,
+ "BLOCK_N": 64,
+ "BLOCK_K": 32,
+}
+
+
+# kernel path: /tmp/torchinductor_ubuntu/as/caswlqxmo7j7f3e26h3tbfwmhcdowf44zrv6mggnh5t7q5w2izpr.py
+# Source Nodes: [gelu, mul_2], Original ATen: [aten.gelu, aten.mul]
+# gelu => add_2, convert_element_type_4, convert_element_type_5, erf, mul_2, mul_3, mul_4
+# mul_2 => mul_5
+triton_poi_fused_gelu_mul_2 = async_compile.triton(
+ "triton_",
+ """
+import triton
+import triton.language as tl
+from torch._inductor.ir import ReductionHint
+from torch._inductor.ir import TileHint
+from torch._inductor.triton_heuristics import AutotuneHint, pointwise
+from torch._inductor.utils import instance_descriptor
+from torch._inductor import triton_helpers
+
+@pointwise(
+ size_hints=[32768],
+ filename=__file__,
+ triton_meta={'signature': {0: '*bf16', 1: '*bf16', 2: 'i32'}, 'device': 0, 'device_type': 'cuda', 'constants': {}, 'configs': [instance_descriptor(divisible_by_16=(0, 1, 2), equal_to_1=(), ids_of_folded_args=(), divisible_by_8=(2,))]},
+ inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_gelu_mul_2', 'mutated_arg_names': []},
+ min_elem_per_thread=0
+)
+@triton.jit
+def triton_(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
+ xnumel = 24576
+ xoffset = tl.program_id(0) * XBLOCK
+ xindex = xoffset + tl.arange(0, XBLOCK)[:]
+ xmask = xindex < xnumel
+ x0 = xindex
+ tmp0 = tl.load(in_ptr0 + (x0), None).to(tl.float32)
+ tmp11 = tl.load(in_ptr0 + (24576 + x0), None).to(tl.float32)
+ tmp1 = tmp0.to(tl.float32)
+ tmp2 = 0.5
+ tmp3 = tmp1 * tmp2
+ tmp4 = 0.7071067811865476
+ tmp5 = tmp1 * tmp4
+ tmp6 = tl.math.erf(tmp5)
+ tmp7 = 1.0
+ tmp8 = tmp6 + tmp7
+ tmp9 = tmp3 * tmp8
+ tmp10 = tmp9.to(tl.float32)
+ tmp12 = tmp10 * tmp11
+ tl.store(out_ptr0 + (x0), tmp12, None)
+""",
+)
+
+
+# kernel path: /tmp/torchinductor_ubuntu/ux/cuxfeb5redhx2vl5744hlfuhmstxy2yqiomik5glekkt3jn6vzq5.py
+# Source Nodes: [gelu, mlp_output, mul_2], Original ATen: [aten.gelu, aten.mm, aten.mul]
+# gelu => add_2, convert_element_type_4, convert_element_type_5, erf, mul_2, mul_3, mul_4
+# mlp_output => mm_1
+# mul_2 => mul_5
+triton_tem_fused_gelu_mm_mul_3 = async_compile.triton(
+ "triton_",
+ """
+import triton.language as tl
+import triton
+from torch._inductor.triton_heuristics import template
+from torch._inductor.utils import instance_descriptor
+from torch._inductor import triton_helpers
+
+
+@template(
+ num_stages=5,
+ num_warps=4,
+ triton_meta={'signature': {0: '*bf16', 1: '*bf16', 2: '*bf16'}, 'device': 0, 'device_type': 'cuda', 'constants': {}, 'configs': [instance_descriptor(divisible_by_16=(0, 1, 2), equal_to_1=(), ids_of_folded_args=(), divisible_by_8=())]},
+ inductor_meta={'kernel_name': 'triton_tem_fused_gelu_mm_mul_3'},
+)
+@triton.jit
+
+def triton_(arg_A, arg_B, out_ptr0):
+ GROUP_M : tl.constexpr = 8
+ EVEN_K : tl.constexpr = True
+ ALLOW_TF32 : tl.constexpr = True
+ ACC_TYPE : tl.constexpr = tl.float32
+ B_PROLOGUE_CAST_TYPE : tl.constexpr = None
+ BLOCK_M : tl.constexpr = 16
+ BLOCK_N : tl.constexpr = 64
+ BLOCK_K : tl.constexpr = 32
+
+ A = arg_A
+ B = arg_B
+
+ M = 1
+ N = 3072
+ K = 24576
+ if M * N == 0:
+ # early exit due to zero-size input(s)
+ return
+ stride_am = 24576
+ stride_ak = 1
+ stride_bk = 1
+ stride_bn = 24576
+
+ # based on triton.ops.matmul
+ pid = tl.program_id(0)
+ grid_m = (M + BLOCK_M - 1) // BLOCK_M
+ grid_n = (N + BLOCK_N - 1) // BLOCK_N
+
+ # re-order program ID for better L2 performance
+ width = GROUP_M * grid_n
+ group_id = pid // width
+ group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
+ pid_m = group_id * GROUP_M + (pid % group_size)
+ pid_n = (pid % width) // (group_size)
+
+ rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
+ rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
+ ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)
+ rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)
+ rk = tl.arange(0, BLOCK_K)
+ A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak)
+ B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn)
+
+ acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
+ for k in range(K, 0, -BLOCK_K):
+ if EVEN_K:
+ a = tl.load(A)
+ b = tl.load(B)
+ else:
+ a = tl.load(A, mask=rk[None, :] < k, other=0.)
+ b = tl.load(B, mask=rk[:, None] < k, other=0.)
+ if B_PROLOGUE_CAST_TYPE is not None:
+ b = b.to(B_PROLOGUE_CAST_TYPE)
+ acc += tl.dot(a, b, allow_tf32=ALLOW_TF32)
+ A += BLOCK_K * stride_ak
+ B += BLOCK_K * stride_bk
+
+ # rematerialize rm and rn to save registers
+ rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
+ rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
+ idx_m = rm[:, None]
+ idx_n = rn[None, :]
+ mask = (idx_m < M) & (idx_n < N)
+
+ # inductor generates a suffix
+ xindex = idx_n + (3072*idx_m)
+ tl.store(out_ptr0 + (tl.broadcast_to(idx_n, mask.shape)), acc, mask)
+""",
+)
+
+
+async_compile.wait(globals())
+del async_compile
+
+
+def call(args):
+ arg0_1, arg1_1, arg2_1, arg3_1, arg4_1 = args
+ args.clear()
+ assert_size_stride(arg0_1, (3072,), (1,))
+ assert_size_stride(arg1_1, (49152, 3072), (3072, 1))
+ assert_size_stride(arg2_1, (3072, 24576), (24576, 1))
+ assert_size_stride(arg3_1, (1, 3072), (3072, 1))
+ assert_size_stride(arg4_1, (1, 3072), (3072, 1))
+ with torch.cuda._DeviceGuard(0):
+ torch.cuda.set_device(0) # no-op to ensure context
+ buf0 = empty((1, 3072), device="cuda", dtype=torch.bfloat16)
+ buf2 = empty((1, 3072), device="cuda", dtype=torch.bfloat16)
+ # Source Nodes: [add, attn_res, hidden_states_1, hidden_states_2, hidden_states_3, normed_attn_res_output, pow_1, rsqrt, variance], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, aten.pow, aten.rsqrt]
+ stream0 = get_cuda_stream(0)
+ triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_0.run(
+ arg4_1,
+ arg3_1,
+ arg0_1,
+ buf0,
+ buf2,
+ arg4_1,
+ 1,
+ 3072,
+ grid=grid(1),
+ stream=stream0,
+ )
+ del arg0_1
+ del arg3_1
+ del arg4_1
+ buf3 = empty((1, 49152), device="cuda", dtype=torch.bfloat16)
+ # Source Nodes: [add, gate_up_states, hidden_states_1, hidden_states_2, hidden_states_3, normed_attn_res_output, pow_1, rsqrt, variance], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mm, aten.mul, aten.pow, aten.rsqrt]
+ triton_tem_fused__to_copy_add_mean_mm_mul_pow_rsqrt_1.run(
+ buf2,
+ arg1_1,
+ buf3,
+ grid=torch._inductor.kernel.mm_common.mm_grid(1, 49152, meta0),
+ stream=stream0,
+ )
+ del arg1_1
+ buf4 = empty((1, 24576), device="cuda", dtype=torch.bfloat16)
+ # Source Nodes: [gelu, mul_2], Original ATen: [aten.gelu, aten.mul]
+ triton_poi_fused_gelu_mul_2.run(
+ buf3, buf4, 24576, grid=grid(24576), stream=stream0
+ )
+ del buf3
+ buf5 = buf2
+ del buf2 # reuse
+ # Source Nodes: [gelu, mlp_output, mul_2], Original ATen: [aten.gelu, aten.mm, aten.mul]
+ triton_tem_fused_gelu_mm_mul_3.run(
+ buf4,
+ arg2_1,
+ buf5,
+ grid=torch._inductor.kernel.mm_common.mm_grid(1, 3072, meta0),
+ stream=stream0,
+ )
+ del arg2_1
+ return (
+ buf5,
+ buf0,
+ )
+
+
+def benchmark_compiled_module(times=10, repeat=10):
+ from torch._dynamo.testing import rand_strided
+ from torch._inductor.utils import print_performance
+
+ arg0_1 = rand_strided((3072,), (1,), device="cuda:0", dtype=torch.bfloat16)
+ arg1_1 = rand_strided(
+ (49152, 3072), (3072, 1), device="cuda:0", dtype=torch.bfloat16
+ )
+ arg2_1 = rand_strided(
+ (3072, 24576), (24576, 1), device="cuda:0", dtype=torch.bfloat16
+ )
+ arg3_1 = rand_strided((1, 3072), (3072, 1), device="cuda:0", dtype=torch.bfloat16)
+ arg4_1 = rand_strided((1, 3072), (3072, 1), device="cuda:0", dtype=torch.bfloat16)
+ fn = lambda: call([arg0_1, arg1_1, arg2_1, arg3_1, arg4_1])
+ return print_performance(fn, times=times, repeat=repeat)
+
+
+if __name__ == "__main__":
+ from torch._inductor.wrapper_benchmark import compiled_module_main
+
+ compiled_module_main("None", benchmark_compiled_module)
diff --git a/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py b/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py
index bd7596db..abb6d1f3 100644
--- a/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py
@@ -36,6 +36,8 @@ from text_generation_server.utils.layers import (
get_linear,
FastRMSNorm,
)
+from .custom_triton_kernels.gemma_kernel import call
+from .custom_triton_kernels.gemma_kernel_decode_one import call as call_decode_one
class GemmaConfig(PretrainedConfig):
@@ -61,6 +63,7 @@ class GemmaConfig(PretrainedConfig):
rope_scaling=None,
attention_bias=False,
attention_dropout=0.0,
+ post_attn_method="post_attn_layer_norm_and_mlp_fused",
**kwargs,
):
self.vocab_size = vocab_size
@@ -84,6 +87,7 @@ class GemmaConfig(PretrainedConfig):
self.rope_scaling = rope_scaling
self.attention_bias = attention_bias
self.attention_dropout = attention_dropout
+ self.post_attn_method = post_attn_method
super().__init__(
pad_token_id=pad_token_id,
@@ -310,6 +314,67 @@ class FlashGemmaLayer(nn.Module):
eps=config.rms_norm_eps,
)
+ if config.post_attn_method == "post_attn_layer_norm_and_mlp_fused":
+ arg0_1 = self.post_attention_layernorm.weight
+ arg1_1 = self.mlp.gate_up_proj.linear.weight
+ arg2_1 = self.mlp.down_proj.linear.weight
+
+ # wrap the variable sized kernel call and use it as inner_forward
+ def _wrapped_inner_forward(
+ res,
+ attn_output,
+ ):
+ arg3_1 = res.size(0)
+ arg4_1 = res
+ arg5_1 = attn_output
+ return call(
+ [arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1],
+ )
+
+ self.inner_forward = _wrapped_inner_forward
+
+ # now wrap the kernel for a fixed size input and use it as decode_one_forward
+ def _wrapped_decode_one_forward(
+ res,
+ attn_output,
+ ):
+ arg3_1 = res
+ arg4_1 = attn_output
+ return call_decode_one(
+ [arg0_1, arg1_1, arg2_1, arg3_1, arg4_1],
+ )
+
+ self.decode_one_forward = _wrapped_decode_one_forward
+
+ elif config.post_attn_method == "compile":
+ # compile the function specifically for a fixed size input (variable one compiled outside this example)
+ self.decode_one_forward = torch.compile(
+ # the original forward function
+ self._inner_forward,
+ # avoid "reduce-overhead" mode, as it caches the graph and misleads timing
+ mode="max-autotune",
+ # avoid graph breaks
+ fullgraph=True,
+ # allow compilation for differnt input shapes
+ dynamic=True,
+ )
+ self.inner_forward = self._inner_forward
+ else:
+ # if nothing specified use the original forward function in both cases
+ self.decode_one_forward = self._inner_forward
+ self.inner_forward = self._inner_forward
+
+ def _inner_forward(
+ self,
+ res,
+ attn_output,
+ ):
+ normed_attn_res_output, attn_res = self.post_attention_layernorm(
+ attn_output, res
+ )
+ mlp_output = self.mlp(normed_attn_res_output)
+ return mlp_output, attn_res
+
def forward(
self,
hidden_states,
@@ -338,15 +403,17 @@ class FlashGemmaLayer(nn.Module):
max_s,
)
- # faster post attention rms norm
- normed_attn_res_output, attn_res = self.post_attention_layernorm(
- attn_output, res
+ if hidden_states.shape[0] == 1:
+ return self.decode_one_forward(
+ res,
+ attn_output,
+ )
+
+ return self.inner_forward(
+ res,
+ attn_output,
)
- mlp_output = self.mlp(normed_attn_res_output)
-
- return mlp_output, attn_res
-
class FlashGemmaModel(torch.nn.Module):
def __init__(self, config, weights):
diff --git a/server/text_generation_server/models/flash_gemma.py b/server/text_generation_server/models/flash_gemma.py
index 7259b820..fbbf70dd 100644
--- a/server/text_generation_server/models/flash_gemma.py
+++ b/server/text_generation_server/models/flash_gemma.py
@@ -28,6 +28,7 @@ class FlashGemma(FlashCausalLM):
use_medusa: Optional[str] = None,
dtype: Optional[torch.dtype] = None,
trust_remote_code: bool = False,
+ post_attn_method: Optional[str] = None,
):
self.process_group, rank, world_size = initialize_torch_distributed()
if torch.cuda.is_available():
@@ -51,6 +52,7 @@ class FlashGemma(FlashCausalLM):
)
config.quantize = quantize
config.use_medusa = use_medusa
+ config.post_attn_method = post_attn_method
torch.distributed.barrier(group=self.process_group)
diff --git a/server/text_generation_server/server.py b/server/text_generation_server/server.py
index 495c2c0c..111dfcf1 100644
--- a/server/text_generation_server/server.py
+++ b/server/text_generation_server/server.py
@@ -176,6 +176,7 @@ def serve(
dtype: Optional[str],
trust_remote_code: bool,
uds_path: Path,
+ post_attn_method: Optional[str],
):
async def serve_inner(
model_id: str,
@@ -185,6 +186,7 @@ def serve(
speculate: Optional[int] = None,
dtype: Optional[str] = None,
trust_remote_code: bool = False,
+ post_attn_method: Optional[str] = None,
):
unix_socket_template = "unix://{}-{}"
if sharded:
@@ -206,6 +208,7 @@ def serve(
speculate,
dtype,
trust_remote_code,
+ post_attn_method,
)
except Exception:
logger.exception("Error when initializing model")
@@ -239,6 +242,13 @@ def serve(
asyncio.run(
serve_inner(
- model_id, revision, sharded, quantize, speculate, dtype, trust_remote_code
+ model_id,
+ revision,
+ sharded,
+ quantize,
+ speculate,
+ dtype,
+ trust_remote_code,
+ post_attn_method,
)
)