From d0bc603fe6ea505e8c86634794852b0bf9e70ff0 Mon Sep 17 00:00:00 2001
From: drbh
Date: Tue, 9 Apr 2024 02:36:09 +0000
Subject: [PATCH] feat: explore compiled MLP bench

---
 server/bench/mlp-compare.py | 196 ++++++++++++++++++++++++++++++++++++
 1 file changed, 196 insertions(+)
 create mode 100644 server/bench/mlp-compare.py

diff --git a/server/bench/mlp-compare.py b/server/bench/mlp-compare.py
new file mode 100644
index 00000000..0d47d2cb
--- /dev/null
+++ b/server/bench/mlp-compare.py
@@ -0,0 +1,196 @@
import torch
import torch._dynamo.config
import torch._inductor.config
from torch.nn import functional as F

# prefer using pre-compiled ops

# from text_generation_server.utils.layers import (
#     TensorParallelColumnLinear,
#     TensorParallelRowLinear,
# )


torch._inductor.config.coordinate_descent_tuning = True
torch._inductor.config.triton.unique_kernel_names = True
torch._inductor.config.fx_graph_cache = True

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

torch.manual_seed(0)
# shared input for both cases
x = torch.randn(4096, 4096, device=device)


class DummyLayer(torch.nn.Module):

    def __init__(self):
        super().__init__()
        self.gate_up_proj = torch.nn.Linear(4096, 4096).to(device)
        self.down_proj = torch.nn.Linear(4096, 4096).to(device)
        self.act = torch.nn.GELU().to(device)

    def forward(self, x):
        y = self.gate_up_proj(x)
        y = self.act(y)
        y = self.down_proj(y)
        return y


class DummyModule(torch.nn.Module):

    def __init__(self):
        super().__init__()
        # add multiple layers to magnify the effect
        N = 10
        self.layer = torch.nn.Sequential(*[DummyLayer() for _ in range(N)]).to(device)

    def forward(self, x):
        return self.layer(x)


model = DummyModule()

print("Model")
print(model)


# run the model via a forward pass
def forward_pass(x):
    return model(x)


# same as above but compiled
forward_pass_compiled = torch.compile(
    forward_pass, mode="reduce-overhead", fullgraph=True
)


# warm-up pass so compilation overhead is not included in the profile
y = forward_pass_compiled(x)

# start profiling
torch.profiler._utils._init_for_cuda_graphs()
prof = torch.profiler.profile()

# run the compiled model
with prof:
    y = forward_pass_compiled(x)
    prof.step()

print("Compiled")
print(prof.key_averages().table(sort_by="self_cuda_time_total"))

# warm-up pass here as well, just to keep the two runs comparable
y = forward_pass(x)

# drop the previous profiling data to avoid any contamination
del prof

# start a new profiling session
torch.profiler._utils._init_for_cuda_graphs()
prof = torch.profiler.profile()

# run the eager (non-compiled) model
with prof:
    y = forward_pass(x)
    prof.step()

print("")
print("Not Compiled")
print(prof.key_averages().table(sort_by="self_cuda_time_total"))


# Expected optimized code (Inductor-generated Triton kernel for the fused GELU):

# {"XBLOCK": 256, "num_warps": 8, "num_stages": 1, "configs_hash": "3ca5c3e34d35093f3c9ab2829a9faeebad5e61c4ca13d5ed6053d7b71ce60d5a", "found_by_coordesc": true}
# coordesc: https://github.com/pytorch/pytorch/blob/main/torch/_inductor/triton_heuristics.py#L652

# import triton
# import triton.language as tl
# from torch._inductor.ir import ReductionHint
# from torch._inductor.ir import TileHint
# from torch._inductor.triton_heuristics import AutotuneHint, pointwise
# from torch._inductor.utils import instance_descriptor
# from torch._inductor import triton_helpers

# @pointwise(
#     size_hints=[16777216],
#     filename=__file__,
#     triton_meta={'signature': {0: '*fp32', 1: '*fp32', 2: 'i32'}, 'device': 0, 'device_type': 'cuda', 'constants': {}, 'configs': [instance_descriptor(divisible_by_16=(0, 1, 2), equal_to_1=(), ids_of_folded_args=(), divisible_by_8=(2,))]},
#     inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_gelu_0', 'mutated_arg_names': []},
#     min_elem_per_thread=0
# )
# @triton.jit
# def triton_poi_fused_gelu_0(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
#     xnumel = 16777216
#     xoffset = tl.program_id(0) * XBLOCK
#     xindex = xoffset + tl.arange(0, XBLOCK)[:]
#     xmask = xindex < xnumel
#     x0 = xindex
#     tmp0 = tl.load(in_ptr0 + (x0), None)
#     tmp1 = 0.5
#     tmp2 = tmp0 * tmp1
#     tmp3 = 0.7071067811865476
#     tmp4 = tmp0 * tmp3
#     tmp5 = tl.math.erf(tmp4)
#     tmp6 = 1.0
#     tmp7 = tmp5 + tmp6
#     tmp8 = tmp2 * tmp7
#     tl.store(out_ptr0 + (x0), tmp8, None)


# Compiled
# ------------------------------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
# Name Self CPU % Self CPU CPU total % CPU total CPU time avg Self CUDA Self CUDA % CUDA total CUDA time avg # of Calls
# ------------------------------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
# aten::addmm 0.39% 586.000us 0.63% 945.000us 47.250us 144.767ms 98.15% 145.039ms 7.252ms 20
# ampere_sgemm_128x128_tn 0.00% 0.000us 0.00% 0.000us 0.000us 144.748ms 98.13% 144.748ms 7.237ms 20
# cudaMalloc 3.20% 4.804ms 3.20% 4.804ms 160.133us 7.329ms 4.97% 7.329ms 244.300us 30
# cudaEventRecord 0.00% 5.000us 0.00% 5.000us 5.000us 6.778ms 4.60% 6.778ms 6.778ms 1
# triton_poi_fused_gelu_0 0.18% 271.000us 0.25% 372.000us 37.200us 2.732ms 1.85% 2.732ms 273.200us 10
# triton_poi_fused_gelu_0_0d1d2de 0.00% 0.000us 0.00% 0.000us 0.000us 2.732ms 1.85% 2.732ms 273.200us 10
# cudaStreamIsCapturing 0.02% 30.000us 0.02% 30.000us 1.000us 272.000us 0.18% 272.000us 9.067us 30
# cudaOccupancyMaxActiveBlocksPerMultiprocessor 0.01% 22.000us 0.01% 22.000us 1.100us 272.000us 0.18% 272.000us 13.600us 20
# Memset (Device) 0.00% 0.000us 0.00% 0.000us 0.000us 19.000us 0.01% 19.000us 0.950us 20
# TorchDynamo Cache Lookup 0.03% 41.000us 0.03% 41.000us 41.000us 0.000us 0.00% 0.000us 0.000us 1
# Torch-Compiled Region 0.09% 141.000us 99.97% 150.195ms 150.195ms 0.000us 0.00% 162.150ms 162.150ms 1
# aten::detach 0.00% 4.000us 0.01% 14.000us 14.000us 0.000us 0.00% 0.000us 0.000us 1
# detach 0.01% 10.000us 0.01% 10.000us 10.000us 0.000us 0.00% 0.000us 0.000us 1
# CompiledFunction 2.37% 3.566ms 99.87% 150.040ms 150.040ms 0.000us 0.00% 162.150ms 162.150ms 1
# cudaDeviceSynchronize 93.08% 139.840ms 93.08% 139.840ms 46.613ms 0.000us 0.00% 0.000us 0.000us 3
# cudaStreamWaitEvent 0.00% 2.000us 0.00% 2.000us 2.000us 0.000us 0.00% 0.000us 0.000us 1
# aten::empty 0.28% 426.000us 3.50% 5.260ms 175.333us 0.000us 0.00% 7.601ms 253.367us 30
# inductor::_reinterpret_tensor 0.04% 55.000us 0.04% 55.000us 1.410us 0.000us 0.00% 0.000us 0.000us 39
# cudaMemsetAsync 0.09% 136.000us 0.09% 136.000us 6.800us 0.000us 0.00% 0.000us 0.000us 20
# cudaLaunchKernel 0.13% 201.000us 0.13% 201.000us 10.050us 0.000us 0.00% 0.000us 0.000us 20
# cuLaunchKernel 0.07% 101.000us 0.07% 101.000us 10.100us 0.000us 0.00% 0.000us 0.000us 10
# ------------------------------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
# Self CPU time total: 150.241ms
# Self CUDA time total: 147.499ms


# Not Compiled
# ------------------------------------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
# Name Self CPU % Self CPU CPU total % CPU total CPU time avg Self CUDA Self CUDA % CUDA total CUDA time avg # of Calls
# ------------------------------------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
# aten::addmm 0.67% 1.885ms 1.93% 5.415ms 270.750us 145.866ms 98.15% 145.866ms 7.293ms 20
# ampere_sgemm_128x128_tn 0.00% 0.000us 0.00% 0.000us 0.000us 145.847ms 98.14% 145.847ms 7.292ms 20
# aten::gelu 0.09% 245.000us 0.70% 1.955ms 195.500us 2.747ms 1.85% 2.747ms 274.700us 10
# void at::native::vectorized_elementwise_kernel<4, at... 0.00% 0.000us 0.00% 0.000us 0.000us 2.747ms 1.85% 2.747ms 274.700us 10
# Memset (Device) 0.00% 0.000us 0.00% 0.000us 0.000us 19.000us 0.01% 19.000us 0.950us 20
# aten::linear 0.03% 82.000us 2.02% 5.671ms 283.550us 0.000us 0.00% 145.866ms 7.293ms 20
# aten::t 0.04% 110.000us 0.06% 174.000us 8.700us 0.000us 0.00% 0.000us 0.000us 20
# aten::transpose 0.01% 42.000us 0.02% 64.000us 3.200us 0.000us 0.00% 0.000us 0.000us 20
# aten::as_strided 0.01% 22.000us 0.01% 22.000us 1.100us 0.000us 0.00% 0.000us 0.000us 20
# cudaStreamIsCapturing 0.01% 30.000us 0.01% 30.000us 1.000us 0.000us 0.00% 0.000us 0.000us 30
# cudaMalloc 1.70% 4.770ms 1.70% 4.770ms 159.000us 0.000us 0.00% 0.000us 0.000us 30
# cudaMemsetAsync 0.04% 124.000us 0.04% 124.000us 6.200us 0.000us 0.00% 0.000us 0.000us 20
# cudaOccupancyMaxActiveBlocksPerMultiprocessor 0.01% 20.000us 0.01% 20.000us 1.000us 0.000us 0.00% 0.000us 0.000us 20
# cudaLaunchKernel 0.11% 296.000us 0.11% 296.000us 9.867us 0.000us 0.00% 0.000us 0.000us 30
# cudaDeviceSynchronize 97.28% 272.916ms 97.28% 272.916ms 272.916ms 0.000us 0.00% 0.000us 0.000us 1
# ------------------------------------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
# Self CPU time total: 280.542ms
# Self CUDA time total: 148.613ms
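

# Optional (commented-out) follow-up, not part of the profiler output above:
# a rough wall-clock comparison of the two paths using CUDA events. This is
# only a sketch that reuses `x`, `forward_pass`, and `forward_pass_compiled`
# from above; `time_fn` is a hypothetical helper, the iteration count is
# arbitrary, and it assumes a CUDA device is available.
#
# def time_fn(fn, iters=10):
#     start = torch.cuda.Event(enable_timing=True)
#     end = torch.cuda.Event(enable_timing=True)
#     torch.cuda.synchronize()
#     start.record()
#     for _ in range(iters):
#         fn(x)
#     end.record()
#     torch.cuda.synchronize()
#     # elapsed_time returns milliseconds between the recorded events
#     return start.elapsed_time(end) / iters
#
# print("eager    ms/iter:", time_fn(forward_pass))
# print("compiled ms/iter:", time_fn(forward_pass_compiled))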