Flashinfer test.

Nicolas Patry 2024-05-24 15:32:24 +00:00
parent 01e4442ef6
commit 3c74cf9cd4
4 changed files with 98 additions and 39 deletions

View File

@@ -5,7 +5,8 @@ from typing import Optional, List, Tuple
 from text_generation_server.utils.import_utils import SYSTEM
 from text_generation_server.models.globals import FLASH_DECODING

-BLOCK_SIZE: int = 256 if FLASH_DECODING else 16
+# BLOCK_SIZE: int = 256 if FLASH_DECODING else 16
+BLOCK_SIZE: int = 16

 # Will be set in warmup
 CACHE_MANAGER: Optional["CacheManager"] = None
@@ -33,18 +34,21 @@ class CacheManager:
         if FLASH_DECODING:
             self.kv_cache = [
-                (
-                    torch.empty(
-                        (num_blocks, self.block_size, num_heads, head_size),
-                        dtype=dtype,
-                        device=device,
-                    ),
-                    torch.empty(
-                        (num_blocks, self.block_size, num_heads, head_size),
-                        dtype=dtype,
-                        device=device,
-                    ),
+                torch.empty(
+                    (num_blocks, 2, self.block_size, num_heads, head_size),
+                    dtype=dtype,
+                    device=device,
                 )
+                # torch.empty(
+                #     (num_blocks, self.block_size, num_heads, head_size),
+                #     dtype=dtype,
+                #     device=device,
+                # ),
+                # torch.empty(
+                #     (num_blocks, self.block_size, num_heads, head_size),
+                #     dtype=dtype,
+                #     device=device,
+                # ),
                 for _ in range(num_layers)
             ]
         else:
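For reference: the per-layer (key, value) tuple is replaced by one fused allocation whose second dimension selects key (0) or value (1), which is the paged layout the flashinfer wrappers consume. A minimal standalone sketch of the shape bookkeeping, with made-up sizes (not part of this commit):

    import torch

    num_blocks, block_size, num_heads, head_size = 8, 16, 4, 64

    # one fused allocation per layer: dim 1 selects key (0) or value (1)
    kv_cache = torch.empty(
        (num_blocks, 2, block_size, num_heads, head_size), dtype=torch.float16
    )

    key_cache = kv_cache[:, 0]    # (num_blocks, block_size, num_heads, head_size)
    value_cache = kv_cache[:, 1]  # same shape, shares storage with kv_cache

    assert key_cache.shape == (num_blocks, block_size, num_heads, head_size)
    assert key_cache.data_ptr() == kv_cache.data_ptr()  # views, no copy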

View File

@@ -137,6 +137,7 @@ class FlashLlamaAttention(torch.nn.Module):
         slots,
         input_lengths,
         max_s,
+        prefill_wrapper,
     ):
         qkv = self.query_key_value(hidden_states)
         query, kv = qkv.split(
@@ -152,37 +153,40 @@ class FlashLlamaAttention(torch.nn.Module):
         self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)

         paged_attention.reshape_and_cache(
-            kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots
+            kv[:, 0], kv[:, 1], kv_cache[:, 0], kv_cache[:, 1], slots
         )

         # output tensor
         attn_output = torch.empty_like(query)

-        # Prefill
         if cu_seqlen_prefill is not None:
             # flash attention
-            flash_attn.attention(
-                query,
-                torch.select(kv, dim=1, index=0),
-                torch.select(kv, dim=1, index=1),
-                attn_output,
-                cu_seqlen_prefill,
-                max_s,
-                self.softmax_scale,
+            attn_output = prefill_wrapper.forward(
+                query.contiguous(), kv[:, 0].contiguous(), kv[:, 1].contiguous()
             )
+            # flash_attn.attention(
+            #     query,
+            #     torch.select(kv, dim=1, index=0),
+            #     torch.select(kv, dim=1, index=1),
+            #     attn_output,
+            #     cu_seqlen_prefill,
+            #     max_s,
+            #     self.softmax_scale,
+            # )
         # Decode
         else:
-            paged_attention.attention(
-                attn_output,
-                query,
-                kv_cache[0],
-                kv_cache[1],
-                self.kv_head_mapping,
-                self.softmax_scale,
-                block_tables,
-                input_lengths,
-                max_s,
-            )
+            attn_output = prefill_wrapper.forward(query, kv_cache)
+            # paged_attention.attention(
+            #     attn_output,
+            #     query,
+            #     kv_cache[0],
+            #     kv_cache[1],
+            #     self.kv_head_mapping,
+            #     self.softmax_scale,
+            #     block_tables,
+            #     input_lengths,
+            #     max_s,
+            # )

         return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))
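For reference: in the prefill branch the ragged flashinfer wrapper stands in for the previous flash_attn.attention call; both compute causal attention over the freshly projected query/key/value in the "NHD" (tokens, heads, head_dim) layout. A rough single-sequence reference in plain PyTorch (no GQA, made-up sizes), only to pin down the math being swapped, not the kernel used here:

    import torch
    import torch.nn.functional as F

    seq_len, num_heads, head_size = 7, 4, 64

    # "NHD" layout: (tokens, heads, head_dim)
    query = torch.randn(seq_len, num_heads, head_size)
    key = torch.randn(seq_len, num_heads, head_size)
    value = torch.randn(seq_len, num_heads, head_size)

    # causal attention with heads as the batch-like dim; the default scale
    # is head_size ** -0.5, i.e. the usual softmax_scale for Llama
    ref = F.scaled_dot_product_attention(
        query.transpose(0, 1),
        key.transpose(0, 1),
        value.transpose(0, 1),
        is_causal=True,
    ).transpose(0, 1)

    print(ref.shape)  # torch.Size([7, 4, 64])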
@@ -283,6 +287,7 @@ class FlashLlamaLayer(nn.Module):
         slots,
         input_lengths,
         max_s,
+        prefill_wrapper,
     ):
         normed_hidden_states, res = self.input_layernorm(hidden_states, residual)
@@ -297,6 +302,7 @@ class FlashLlamaLayer(nn.Module):
             slots,
             input_lengths,
             max_s,
+            prefill_wrapper,
         )

         # faster post attention rms norm
@@ -362,6 +368,54 @@ class FlashLlamaModel(torch.nn.Module):
         cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(
             position_ids, max_s, hidden_states.dtype
         )
+        workspace_buffer = torch.empty(
+            16 * 1024 * 1024, dtype=torch.uint8, device=inputs_embeds.device
+        )
+        import flashinfer
+
+        if cu_seqlen_prefill is None:
+            prefill_wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper(
+                workspace_buffer, "NHD"
+            )
+            cu_seqlen_q = torch.arange(
+                input_lengths.shape[0] + 1,
+                device=inputs_embeds.device,
+                dtype=torch.int32,
+            )
+            cu_seqlen_k = torch.cat(
+                [
+                    torch.zeros(
+                        (1,), device=input_lengths.device, dtype=input_lengths.dtype
+                    ),
+                    input_lengths.cumsum(dim=-1),
+                ]
+            ).to(dtype=torch.int32)
+            prefill_wrapper.begin_forward(
+                indptr=cu_seqlen_k,
+                indices=block_tables.view(-1),
+                last_page_len=slots.to(dtype=torch.int32),
+                num_qo_heads=self.layers[0].self_attn.num_heads,
+                num_kv_heads=self.layers[0].self_attn.num_key_value_heads,
+                head_dim=self.layers[0].self_attn.head_size,
+                page_size=16,
+                pos_encoding_mode="NONE",
+                data_type=inputs_embeds.dtype,
+            )
+        else:
+            prefill_wrapper = flashinfer.BatchPrefillWithRaggedKVCacheWrapper(
+                workspace_buffer, "NHD"
+            )
+            cu_seqlen_q = cu_seqlen_prefill
+            cu_seqlen_k = cu_seqlen_prefill
+            prefill_wrapper.begin_forward(
+                qo_indptr=cu_seqlen_q,
+                kv_indptr=cu_seqlen_k,
+                num_qo_heads=self.layers[0].self_attn.num_heads,
+                num_kv_heads=self.layers[0].self_attn.num_key_value_heads,
+                head_dim=self.layers[0].self_attn.head_size,
+            )

         residual = None
         for i, layer in enumerate(self.layers):
@@ -376,8 +430,11 @@ class FlashLlamaModel(torch.nn.Module):
                 slots,
                 input_lengths,
                 max_s,
+                prefill_wrapper,
             )
+        prefill_wrapper.end_forward()

         hidden_states, _ = self.norm(hidden_states, residual)

         return hidden_states
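One detail worth pulling out: both begin_forward calls describe the batch with CSR-style index pointers. For decode, cu_seqlen_k is built from input_lengths as a leading zero followed by a running sum, so request i owns entries indptr[i]:indptr[i+1] of the flattened block tables; for prefill, cu_seqlen_prefill already has that shape and is reused for both qo_indptr and kv_indptr. A small standalone illustration of that construction (made-up lengths):

    import torch

    # per-request sequence lengths for a batch of three requests
    input_lengths = torch.tensor([5, 2, 9], dtype=torch.int32)

    # leading zero + running sum -> CSR-style indptr over the flattened tables
    cu_seqlen_k = torch.cat(
        [
            torch.zeros((1,), dtype=input_lengths.dtype),
            input_lengths.cumsum(dim=-1),
        ]
    ).to(dtype=torch.int32)

    print(cu_seqlen_k.tolist())  # [0, 5, 7, 16]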

View File

@@ -214,6 +214,8 @@ class MistralAttention(torch.nn.Module):
         self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)

+        attn_output = torch.empty_like(query)
+
         if prefill_cache_indices is not None:
             kv_to_cache = kv[prefill_cache_indices]
         else:
@@ -222,10 +224,6 @@ class MistralAttention(torch.nn.Module):
         paged_attention.reshape_and_cache(
             kv_to_cache[:, 0], kv_to_cache[:, 1], kv_cache[0], kv_cache[1], slots
         )
-
-        # output tensor
-        attn_output = torch.empty_like(query)
-
         # Prefill
         if cu_seqlen_prefill is not None:
             # flash attention

View File

@@ -30,8 +30,8 @@ def reshape_and_cache(
     else:
         if FLASH_DECODING:
             shape = key_cache.shape
-            key_cache.view(-1, shape[-2], shape[-1])[slots] = key
-            value_cache.view(-1, shape[-2], shape[-1])[slots] = value
+            # key_cache.view(-1, shape[-2], shape[-1])[slots] = key
+            # value_cache.view(-1, shape[-2], shape[-1])[slots] = value
         else:
             cache_ops.reshape_and_cache(
                 key, value, key_cache, value_cache, slots, "auto", 1.0
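For reference: the two lines commented out here are the slot-indexed scatter write used by the flat FLASH_DECODING layout: view the cache as (num_blocks * block_size, heads, head_dim) and write each token at its flat slot id. A standalone illustration with made-up sizes (not the fused layout this commit switches the llama path to):

    import torch

    num_blocks, block_size, num_heads, head_size = 4, 16, 2, 8
    key_cache = torch.zeros(num_blocks, block_size, num_heads, head_size)

    # three new tokens and their flat slot ids (block_id * block_size + offset)
    slots = torch.tensor([3, 16, 17])
    key = torch.randn(3, num_heads, head_size)

    # collapse (num_blocks, block_size) into one slot axis; the view shares
    # storage, so indexed assignment writes straight into key_cache
    shape = key_cache.shape
    key_cache.view(-1, shape[-2], shape[-1])[slots] = key

    assert torch.equal(key_cache[0, 3], key[0])   # slot 3  -> block 0, offset 3
    assert torch.equal(key_cache[1, 0], key[1])   # slot 16 -> block 1, offset 0
    assert torch.equal(key_cache[1, 1], key[2])   # slot 17 -> block 1, offset 1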