From ef99678798aeac053ddda3e0b886c4bcea293e69 Mon Sep 17 00:00:00 2001 From: OlivierDehaene <23298448+OlivierDehaene@users.noreply.github.com> Date: Thu, 25 Jan 2024 15:26:51 +0100 Subject: [PATCH] wip not faster --- server/Makefile-flash-att-v2 | 2 +- .../models/cache_manager.py | 15 +- .../custom_modeling/flash_llama_modeling.py | 186 +++++++++++------- server/text_generation_server/server.py | 6 +- .../utils/flash_attn.py | 34 ++-- 5 files changed, 147 insertions(+), 96 deletions(-) diff --git a/server/Makefile-flash-att-v2 b/server/Makefile-flash-att-v2 index 71c6cabe..041564a7 100644 --- a/server/Makefile-flash-att-v2 +++ b/server/Makefile-flash-att-v2 @@ -1,4 +1,4 @@ -flash_att_v2_commit_cuda := 02ac572f3ffc4f402e4183aaa6824b45859d3ed3 +flash_att_v2_commit_cuda := 54e80a3829c6d2337570d01e78ebd9529c02d342 flash_att_v2_commit_rocm := 8736558c287ff2ef28b24878e42828c595ac3e69 diff --git a/server/text_generation_server/models/cache_manager.py b/server/text_generation_server/models/cache_manager.py index 2e6ae086..e94b42bc 100644 --- a/server/text_generation_server/models/cache_manager.py +++ b/server/text_generation_server/models/cache_manager.py @@ -1,9 +1,11 @@ import math import torch +import os from typing import Optional, List, Tuple -BLOCK_SIZE: int = 16 +USE_VLLM = os.getenv("USE_VLLM", "False") == "True" +BLOCK_SIZE: int = 256 if not USE_VLLM else 16 # Will be set in warmup CACHE_MANAGER: Optional["CacheManager"] = None @@ -26,15 +28,22 @@ class CacheManager: element_size = torch.tensor([], dtype=dtype).element_size() x = self.block_size // element_size + if USE_VLLM: + k_shape = (num_blocks, num_heads, head_size // x, self.block_size, x) + v_shape = (num_blocks, num_heads, head_size, self.block_size) + else: + k_shape = (num_blocks, BLOCK_SIZE, num_heads, head_size) + v_shape = (num_blocks, BLOCK_SIZE, num_heads, head_size) + self.kv_cache = [ ( torch.empty( - (num_blocks, num_heads, head_size // x, self.block_size, x), + k_shape, dtype=dtype, device=device, ), torch.empty( - (num_blocks, num_heads, head_size, self.block_size), + v_shape, dtype=dtype, device=device, ), diff --git a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py index 3a269fc0..c3eb8131 100644 --- a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py @@ -17,6 +17,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
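# A minimal sketch of the two KV-cache layouts selected in cache_manager.py
# above, under assumed illustrative sizes (fp16, 64 blocks, 8 heads, head
# size 128 are not taken from any real config). With USE_VLLM the paged
# layout splits head_size into x-element chunks; otherwise the cache is laid
# out as plain (block, position, head, dim) pages, the shape this patch later
# hands to the flash-attention kv-cache kernel.
import os
import torch

USE_VLLM = os.getenv("USE_VLLM", "False") == "True"
BLOCK_SIZE = 256 if not USE_VLLM else 16

num_blocks, num_heads, head_size = 64, 8, 128
dtype, device = torch.float16, "cpu"

if USE_VLLM:
    x = BLOCK_SIZE // torch.tensor([], dtype=dtype).element_size()
    k_shape = (num_blocks, num_heads, head_size // x, BLOCK_SIZE, x)
    v_shape = (num_blocks, num_heads, head_size, BLOCK_SIZE)
else:
    k_shape = (num_blocks, BLOCK_SIZE, num_heads, head_size)
    v_shape = (num_blocks, BLOCK_SIZE, num_heads, head_size)

k_cache = torch.empty(k_shape, dtype=dtype, device=device)
v_cache = torch.empty(v_shape, dtype=dtype, device=device)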
+import os import torch import torch.distributed @@ -40,26 +41,26 @@ from text_generation_server.utils.layers import ( class LlamaConfig(PretrainedConfig): def __init__( - self, - vocab_size=32000, - hidden_size=4096, - intermediate_size=11008, - num_hidden_layers=32, - num_attention_heads=32, - num_key_value_heads=None, - hidden_act="silu", - max_position_embeddings=2048, - initializer_range=0.02, - rms_norm_eps=1e-6, - use_cache=True, - pad_token_id=0, - bos_token_id=1, - eos_token_id=2, - pretraining_tp=1, - tie_word_embeddings=False, - rope_scaling=None, - rope_theta=10000.0, - **kwargs, + self, + vocab_size=32000, + hidden_size=4096, + intermediate_size=11008, + num_hidden_layers=32, + num_attention_heads=32, + num_key_value_heads=None, + hidden_act="silu", + max_position_embeddings=2048, + initializer_range=0.02, + rms_norm_eps=1e-6, + use_cache=True, + pad_token_id=0, + bos_token_id=1, + eos_token_id=2, + pretraining_tp=1, + tie_word_embeddings=False, + rope_scaling=None, + rope_theta=10000.0, + **kwargs, ): self.vocab_size = vocab_size self.max_position_embeddings = max_position_embeddings @@ -139,10 +140,10 @@ def _load_gqa(config, prefix: str, weights): class FlashLlamaAttention(torch.nn.Module): def __init__( - self, - prefix: str, - config, - weights, + self, + prefix: str, + config, + weights, ): super().__init__() self.num_heads = config.num_attention_heads @@ -156,7 +157,7 @@ class FlashLlamaAttention(torch.nn.Module): device=weights.device, ) - self.softmax_scale = self.head_size**-0.5 + self.softmax_scale = self.head_size ** -0.5 if self.num_heads % weights.process_group.size() != 0: raise ValueError( @@ -165,7 +166,7 @@ class FlashLlamaAttention(torch.nn.Module): ) self.num_heads = self.num_heads // weights.process_group.size() self.num_key_value_heads = ( - config.num_key_value_heads // weights.process_group.size() + config.num_key_value_heads // weights.process_group.size() ) self.query_key_value = load_attention(config, prefix, weights) @@ -182,16 +183,16 @@ class FlashLlamaAttention(torch.nn.Module): ).repeat_interleave(self.num_groups) def forward( - self, - hidden_states, - cos, - sin, - cu_seqlen_prefill, - kv_cache, - block_tables, - slots, - input_lengths, - max_s, + self, + hidden_states, + cos, + sin, + cu_seqlen_prefill, + kv_cache, + block_tables, + slots, + input_lengths, + max_s, ): qkv = self.query_key_value(hidden_states) query, kv = qkv.split( @@ -204,17 +205,24 @@ class FlashLlamaAttention(torch.nn.Module): query = query.view(-1, self.num_heads, self.head_size) kv = kv.view(-1, 2, self.num_key_value_heads, self.head_size) - self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin) + use_vllm = os.getenv("USE_VLLM", "False") == "True" - paged_attention.reshape_and_cache( - kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots - ) + if use_vllm: + paged_attention.reshape_and_cache( + kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots + ) + self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin) # output tensor attn_output = torch.empty_like(query) # Prefill if cu_seqlen_prefill is not None: + if not use_vllm: + self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin) + kv_cache[0].view(-1, self.num_key_value_heads, self.head_size)[slots] = kv[:, 0] + kv_cache[1].view(-1, self.num_key_value_heads, self.head_size)[slots] = kv[:, 1] + # flash attention flash_attn.attention( query, @@ -227,17 +235,41 @@ class FlashLlamaAttention(torch.nn.Module): ) # Decode else: - paged_attention.attention( - attn_output, - query, - kv_cache[0], 
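# Sketch of the prefill cache write introduced above, with assumed toy sizes:
# in the (num_blocks, BLOCK_SIZE, num_kv_heads, head_size) layout, flattening
# the first two dimensions lets per-token slot indices
# (block_id * BLOCK_SIZE + offset) scatter fresh keys/values straight into
# the cache, standing in for paged_attention.reshape_and_cache on this path.
import torch

num_blocks, block_size, num_kv_heads, head_size = 4, 256, 2, 64
k_cache = torch.zeros(num_blocks, block_size, num_kv_heads, head_size)

k_new = torch.randn(5, num_kv_heads, head_size)   # keys for 5 prefill tokens
slots = torch.tensor([0, 1, 2, 256, 257])         # last two tokens land in block 1

k_cache.view(-1, num_kv_heads, head_size)[slots] = k_new

assert torch.equal(k_cache[0, 2], k_new[2])       # block 0, position 2
assert torch.equal(k_cache[1, 0], k_new[3])       # block 1, position 0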
- kv_cache[1], - self.kv_head_mapping, - self.softmax_scale, - block_tables, - input_lengths, - max_s, - ) + if not use_vllm: + import flash_attn_2_cuda + flash_attn_2_cuda.fwd_kvcache( + query.unsqueeze(1), # q + kv_cache[0], # kcache + kv_cache[1], # vcache + torch.select(kv, dim=1, index=0).unsqueeze(1), # k + torch.select(kv, dim=1, index=1).unsqueeze(1), # v + input_lengths, # seqlens_k + self.rotary_emb._cos_cached, # rotary_cos + self.rotary_emb._sin_cached, # rotary_sin + # None,None, + None, # cache_batch_idx + block_tables, # block_tables + None, # alibi_slopes + attn_output.unsqueeze(1), # out + self.softmax_scale, # softmax_scale + True, # is_causal + -1, # window_size_left + 0, # window_size_right + False, # is_rotary_interleaved + 0, # num_splits + ) + else: + paged_attention.attention( + attn_output, + query, + kv_cache[0], + kv_cache[1], + self.kv_head_mapping, + self.softmax_scale, + block_tables, + input_lengths, + max_s, + ) return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size)) @@ -271,7 +303,7 @@ class LlamaMLP(nn.Module): bias=False, ) self.intermediate_size = ( - config.intermediate_size // weights.process_group.size() + config.intermediate_size // weights.process_group.size() ) def forward(self, hidden_states): @@ -299,17 +331,17 @@ class FlashLlamaLayer(nn.Module): ) def forward( - self, - hidden_states, - residual, - cos, - sin, - cu_seqlen_prefill, - kv_cache, - block_tables, - slots, - input_lengths, - max_s, + self, + hidden_states, + residual, + cos, + sin, + cu_seqlen_prefill, + kv_cache, + block_tables, + slots, + input_lengths, + max_s, ): normed_hidden_states, res = self.input_layernorm(hidden_states, residual) @@ -367,23 +399,27 @@ class FlashLlamaModel(torch.nn.Module): self.num_key_value_heads = self.layers[0].self_attn.num_key_value_heads def forward( - self, - input_ids: torch.Tensor, - position_ids: torch.Tensor, - cu_seqlen_prefill: Optional[torch.Tensor], - kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], - block_tables: torch.Tensor, - slots: torch.Tensor, - input_lengths: torch.Tensor, - max_s: int, + self, + input_ids: torch.Tensor, + position_ids: torch.Tensor, + cu_seqlen_prefill: Optional[torch.Tensor], + kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], + block_tables: torch.Tensor, + slots: torch.Tensor, + input_lengths: torch.Tensor, + max_s: int, ) -> torch.Tensor: hidden_states = self.embed_tokens(input_ids) # Get rotary cos and sin for this forward # Avoid to index in each layer - cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin( - position_ids, max_s, hidden_states.dtype - ) + use_vllm = os.getenv("USE_VLLM", "False") == "True" + if cu_seqlen_prefill is not None or use_vllm: + cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin( + position_ids, max_s, hidden_states.dtype + ) + else: + cos, sin = None, None residual = None for i, layer in enumerate(self.layers): diff --git a/server/text_generation_server/server.py b/server/text_generation_server/server.py index d5adbd32..c5d0affa 100644 --- a/server/text_generation_server/server.py +++ b/server/text_generation_server/server.py @@ -127,6 +127,8 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer): ) async def Decode(self, request, context): + from torch.profiler import profile, ProfilerActivity + start = time.time_ns() if len(request.batches) == 0: raise ValueError("Must provide at least one batch") @@ -149,7 +151,9 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer): batch = batches[0] 
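# Naive single-sequence reference for what the fwd_kvcache decode call above
# should compute, ignoring paging, GQA and the in-kernel rotary (the kernel
# receives _cos_cached/_sin_cached, which is why cos/sin stay None on this
# path): the one new query token attends over every cached key/value. Shapes
# are assumed for illustration; this is a sanity-check sketch, not the kernel.
import torch

num_heads, head_size, seq_len = 8, 64, 37          # seq_len ~ input_lengths[i]
q = torch.randn(num_heads, head_size)              # query of the new token
k = torch.randn(seq_len, num_heads, head_size)     # cached keys incl. the new one
v = torch.randn(seq_len, num_heads, head_size)

scores = torch.einsum("hd,shd->hs", q, k) * head_size ** -0.5
attn_out = torch.einsum("hs,shd->hd", scores.softmax(dim=-1), v)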
concat_ns = None - generations, next_batch, timings = self.model.generate_token(batch) + with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prefill_prof: + generations, next_batch, timings = self.model.generate_token(batch) + prefill_prof.export_chrome_trace("new_decode.json") self.cache.set(next_batch) return generate_pb2.DecodeResponse( diff --git a/server/text_generation_server/utils/flash_attn.py b/server/text_generation_server/utils/flash_attn.py index 48f8ef70..abb362fa 100644 --- a/server/text_generation_server/utils/flash_attn.py +++ b/server/text_generation_server/utils/flash_attn.py @@ -57,7 +57,7 @@ except ImportError as e: elif IS_ROCM_SYSTEM: for idx in range(torch.cuda.device_count()): if "MI210" not in torch.cuda.get_device_name( - idx + idx ) and "MI250" not in torch.cuda.get_device_name(idx): raise ImportError( f"AMD GPU {torch.cuda.get_device_name(idx)} does not support flash-attention" @@ -68,27 +68,29 @@ except ImportError as e: def attention( - q, - k, - v, - out, - cu_seqlens, - max_s, - softmax_scale, - window_size_left=-1, + q, + k, + v, + out, + cu_seqlens, + max_s, + softmax_scale, + window_size_left=-1, ): if window_size_left <= 0 and window_size_left != -1: raise ValueError("`window_size_left` must be > 0 or -1") if HAS_FLASH_ATTN_V2_CUDA: return flash_attn_2_cuda.varlen_fwd( - q, - k, - v, - out, - cu_seqlens, - cu_seqlens, - max_s, + q, # q + k, # k + v, # v + out, # out + cu_seqlens, # cu_seqlens_q + cu_seqlens, # cu_seqlens_k + None, + None, + max_s, # max_s, 0.0, softmax_scale,
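The varlen_fwd prefill call above packs the ragged batch into flat q/k/v tensors indexed by cumulative sequence offsets; the two added None arguments are assumed here to line up with optional parameters exposed by the newer flash-attention commit pinned in Makefile-flash-att-v2. A minimal sketch of how such cumulative offsets relate to per-sequence input lengths (names and values are illustrative, not the server's actual batching code):

import torch
import torch.nn.functional as F

input_lengths = torch.tensor([5, 3, 8], dtype=torch.int32)              # ragged batch
cu_seqlens = F.pad(input_lengths.cumsum(0, dtype=torch.int32), (1, 0))
# -> tensor([0, 5, 8, 16], dtype=torch.int32): sequence i owns rows
#    cu_seqlens[i]:cu_seqlens[i+1] of the packed q, k and v tensors
max_s = int(input_lengths.max())                                        # max_seqlen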