diff --git a/server/text_generation_server/models/cache_manager.py b/server/text_generation_server/models/cache_manager.py
index df6b1ade..a79b25be 100644
--- a/server/text_generation_server/models/cache_manager.py
+++ b/server/text_generation_server/models/cache_manager.py
@@ -5,7 +5,8 @@ from typing import Optional, List, Tuple
 from text_generation_server.utils.import_utils import SYSTEM
 from text_generation_server.models.globals import FLASH_DECODING
 
-BLOCK_SIZE: int = 256 if FLASH_DECODING else 16
+# BLOCK_SIZE: int = 256 if FLASH_DECODING else 16
+BLOCK_SIZE: int = 16  # Will be set in warmup
 CACHE_MANAGER: Optional["CacheManager"] = None
 
 
@@ -33,18 +34,21 @@ class CacheManager:
 
         if FLASH_DECODING:
             self.kv_cache = [
-                (
-                    torch.empty(
-                        (num_blocks, self.block_size, num_heads, head_size),
-                        dtype=dtype,
-                        device=device,
-                    ),
-                    torch.empty(
-                        (num_blocks, self.block_size, num_heads, head_size),
-                        dtype=dtype,
-                        device=device,
-                    ),
+                torch.empty(
+                    (num_blocks, 2, self.block_size, num_heads, head_size),
+                    dtype=dtype,
+                    device=device,
                 )
+                # torch.empty(
+                #     (num_blocks, self.block_size, num_heads, head_size),
+                #     dtype=dtype,
+                #     device=device,
+                # ),
+                # torch.empty(
+                #     (num_blocks, self.block_size, num_heads, head_size),
+                #     dtype=dtype,
+                #     device=device,
+                # ),
                 for _ in range(num_layers)
             ]
         else:
diff --git a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
index 6e23aa2b..01be171b 100644
--- a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
@@ -137,6 +137,7 @@ class FlashLlamaAttention(torch.nn.Module):
         slots,
         input_lengths,
         max_s,
+        prefill_wrapper,
     ):
         qkv = self.query_key_value(hidden_states)
         query, kv = qkv.split(
@@ -152,37 +153,40 @@
         self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)
 
         paged_attention.reshape_and_cache(
-            kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots
+            kv[:, 0], kv[:, 1], kv_cache[:, 0], kv_cache[:, 1], slots
         )
 
         # output tensor
         attn_output = torch.empty_like(query)
-
         # Prefill
         if cu_seqlen_prefill is not None:
             # flash attention
-            flash_attn.attention(
-                query,
-                torch.select(kv, dim=1, index=0),
-                torch.select(kv, dim=1, index=1),
-                attn_output,
-                cu_seqlen_prefill,
-                max_s,
-                self.softmax_scale,
+            attn_output = prefill_wrapper.forward(
+                query.contiguous(), kv[:, 0].contiguous(), kv[:, 1].contiguous()
             )
+            # flash_attn.attention(
+            #     query,
+            #     torch.select(kv, dim=1, index=0),
+            #     torch.select(kv, dim=1, index=1),
+            #     attn_output,
+            #     cu_seqlen_prefill,
+            #     max_s,
+            #     self.softmax_scale,
+            # )
         # Decode
         else:
-            paged_attention.attention(
-                attn_output,
-                query,
-                kv_cache[0],
-                kv_cache[1],
-                self.kv_head_mapping,
-                self.softmax_scale,
-                block_tables,
-                input_lengths,
-                max_s,
-            )
+            attn_output = prefill_wrapper.forward(query, kv_cache)
+            # paged_attention.attention(
+            #     attn_output,
+            #     query,
+            #     kv_cache[0],
+            #     kv_cache[1],
+            #     self.kv_head_mapping,
+            #     self.softmax_scale,
+            #     block_tables,
+            #     input_lengths,
+            #     max_s,
+            # )
 
         return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))
 
@@ -283,6 +287,7 @@ class FlashLlamaLayer(nn.Module):
         slots,
         input_lengths,
         max_s,
+        prefill_wrapper,
     ):
         normed_hidden_states, res = self.input_layernorm(hidden_states, residual)
 
@@ -297,6 +302,7 @@ class FlashLlamaLayer(nn.Module):
             slots,
             input_lengths,
             max_s,
+            prefill_wrapper,
         )
 
         # faster post attention rms norm
@@ -362,6 +368,54 @@
         cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(
             position_ids, max_s, hidden_states.dtype
         )
+        workspace_buffer = torch.empty(
+            16 * 1024 * 1024, dtype=torch.uint8, device=inputs_embeds.device
+        )
+        import flashinfer
+
+        if cu_seqlen_prefill is None:
+            prefill_wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper(
+                workspace_buffer, "NHD"
+            )
+            cu_seqlen_q = torch.arange(
+                input_lengths.shape[0] + 1,
+                device=inputs_embeds.device,
+                dtype=torch.int32,
+            )
+            cu_seqlen_k = torch.cat(
+                [
+                    torch.zeros(
+                        (1,), device=input_lengths.device, dtype=input_lengths.dtype
+                    ),
+                    input_lengths.cumsum(dim=-1),
+                ]
+            ).to(dtype=torch.int32)
+
+            prefill_wrapper.begin_forward(
+                indptr=cu_seqlen_k,
+                indices=block_tables.view(-1),
+                last_page_len=slots.to(dtype=torch.int32),
+                num_qo_heads=self.layers[0].self_attn.num_heads,
+                num_kv_heads=self.layers[0].self_attn.num_key_value_heads,
+                head_dim=self.layers[0].self_attn.head_size,
+                page_size=16,
+                pos_encoding_mode="NONE",
+                data_type=inputs_embeds.dtype,
+            )
+        else:
+            prefill_wrapper = flashinfer.BatchPrefillWithRaggedKVCacheWrapper(
+                workspace_buffer, "NHD"
+            )
+            cu_seqlen_q = cu_seqlen_prefill
+            cu_seqlen_k = cu_seqlen_prefill
+
+            prefill_wrapper.begin_forward(
+                qo_indptr=cu_seqlen_q,
+                kv_indptr=cu_seqlen_k,
+                num_qo_heads=self.layers[0].self_attn.num_heads,
+                num_kv_heads=self.layers[0].self_attn.num_key_value_heads,
+                head_dim=self.layers[0].self_attn.head_size,
+            )
 
         residual = None
         for i, layer in enumerate(self.layers):
@@ -376,8 +430,11 @@
                 slots,
                 input_lengths,
                 max_s,
+                prefill_wrapper,
             )
 
+        prefill_wrapper.end_forward()
+
         hidden_states, _ = self.norm(hidden_states, residual)
 
         return hidden_states
diff --git a/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py b/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py
index ef3777da..35f4c2f4 100644
--- a/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py
@@ -214,6 +214,8 @@ class MistralAttention(torch.nn.Module):
 
         self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)
 
+        attn_output = torch.empty_like(query)
+
         if prefill_cache_indices is not None:
             kv_to_cache = kv[prefill_cache_indices]
         else:
@@ -222,10 +224,6 @@ class MistralAttention(torch.nn.Module):
         paged_attention.reshape_and_cache(
             kv_to_cache[:, 0], kv_to_cache[:, 1], kv_cache[0], kv_cache[1], slots
         )
-
-        # output tensor
-        attn_output = torch.empty_like(query)
-
         # Prefill
         if cu_seqlen_prefill is not None:
             # flash attention
diff --git a/server/text_generation_server/utils/paged_attention.py b/server/text_generation_server/utils/paged_attention.py
index f8af5dc4..45bf9ecf 100644
--- a/server/text_generation_server/utils/paged_attention.py
+++ b/server/text_generation_server/utils/paged_attention.py
@@ -30,8 +30,8 @@ def reshape_and_cache(
     else:
         if FLASH_DECODING:
             shape = key_cache.shape
-            key_cache.view(-1, shape[-2], shape[-1])[slots] = key
-            value_cache.view(-1, shape[-2], shape[-1])[slots] = value
+            # key_cache.view(-1, shape[-2], shape[-1])[slots] = key
+            # value_cache.view(-1, shape[-2], shape[-1])[slots] = value
         else:
             cache_ops.reshape_and_cache(
                 key, value, key_cache, value_cache, slots, "auto", 1.0
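
For reference, here is a minimal standalone sketch of the decode path this patch wires up: one packed per-layer KV-cache tensor of shape `(num_blocks, 2, page_size, num_kv_heads, head_dim)` (the new `CacheManager` layout) consumed by `BatchDecodeWithPagedKVCacheWrapper`. It assumes the flashinfer 0.x wrapper API the diff calls (`begin_forward`/`forward`/`end_forward`; later releases rename these to `plan`/`run`), and the batch, head and page sizes below are illustrative, not taken from the patch.

```python
# Sketch only: decode over a paged KV cache with flashinfer, assuming the
# 0.x-style wrapper API used in the diff above. All sizes are made up.
import torch
import flashinfer

num_blocks, page_size = 64, 16
num_qo_heads, num_kv_heads, head_dim = 32, 8, 128
batch_size = 2
device = "cuda"

# One cache tensor per layer, K and V packed on dim 1 ("NHD" layout),
# matching the (num_blocks, 2, block_size, num_heads, head_size) allocation
# introduced in CacheManager.
kv_cache = torch.randn(
    num_blocks, 2, page_size, num_kv_heads, head_dim,
    dtype=torch.float16, device=device,
)

workspace_buffer = torch.empty(16 * 1024 * 1024, dtype=torch.uint8, device=device)
decode_wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper(workspace_buffer, "NHD")

# CSR-style page table: sequence i owns kv_indices[kv_indptr[i]:kv_indptr[i+1]],
# and its last page holds kv_last_page_len[i] valid entries.
kv_indptr = torch.tensor([0, 1, 2], dtype=torch.int32, device=device)
kv_indices = torch.tensor([0, 1], dtype=torch.int32, device=device)
kv_last_page_len = torch.tensor([5, 9], dtype=torch.int32, device=device)

decode_wrapper.begin_forward(
    indptr=kv_indptr,
    indices=kv_indices,
    last_page_len=kv_last_page_len,
    num_qo_heads=num_qo_heads,
    num_kv_heads=num_kv_heads,
    head_dim=head_dim,
    page_size=page_size,
    pos_encoding_mode="NONE",
    data_type=torch.float16,
)

# One query token per sequence during decode: (batch_size, num_qo_heads, head_dim).
q = torch.randn(batch_size, num_qo_heads, head_dim, dtype=torch.float16, device=device)
out = decode_wrapper.forward(q, kv_cache)  # same shape as q

decode_wrapper.end_forward()
```

The `indptr`/`indices`/`last_page_len` triple here corresponds to what the decode branch in the diff derives from `input_lengths`, `block_tables` and `slots`.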
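The prefill branch instead uses the ragged (un-paged) prefill wrapper, attending over the freshly computed K/V tensors before they are written into the paged cache. A similarly hedged sketch, again against the 0.x-style `begin_forward`/`forward`/`end_forward` API and with made-up sizes:

```python
# Sketch only: prefill over variable-length sequences packed back to back
# ("ragged" layout) with BatchPrefillWithRaggedKVCacheWrapper.
import torch
import flashinfer

num_qo_heads, num_kv_heads, head_dim = 32, 8, 128
device = "cuda"

# Two sequences of 7 and 5 tokens -> 12 packed tokens total.
# qo_indptr/kv_indptr play the role of cu_seqlen_prefill in the diff above.
qo_indptr = torch.tensor([0, 7, 12], dtype=torch.int32, device=device)
kv_indptr = qo_indptr  # self-attention over the same sequences

total_tokens = 12
q = torch.randn(total_tokens, num_qo_heads, head_dim, dtype=torch.float16, device=device)
k = torch.randn(total_tokens, num_kv_heads, head_dim, dtype=torch.float16, device=device)
v = torch.randn(total_tokens, num_kv_heads, head_dim, dtype=torch.float16, device=device)

workspace_buffer = torch.empty(16 * 1024 * 1024, dtype=torch.uint8, device=device)
prefill_wrapper = flashinfer.BatchPrefillWithRaggedKVCacheWrapper(workspace_buffer, "NHD")

prefill_wrapper.begin_forward(
    qo_indptr=qo_indptr,
    kv_indptr=kv_indptr,
    num_qo_heads=num_qo_heads,
    num_kv_heads=num_kv_heads,
    head_dim=head_dim,
)

# Mirrors attn_output = prefill_wrapper.forward(query, kv[:, 0], kv[:, 1]) above.
out = prefill_wrapper.forward(q, k, v)  # (total_tokens, num_qo_heads, head_dim)

prefill_wrapper.end_forward()
```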