Flashinfer test.

commit 3c74cf9cd4 (parent 01e4442ef6)
@@ -5,7 +5,8 @@ from typing import Optional, List, Tuple
 from text_generation_server.utils.import_utils import SYSTEM
 from text_generation_server.models.globals import FLASH_DECODING
 
-BLOCK_SIZE: int = 256 if FLASH_DECODING else 16
+# BLOCK_SIZE: int = 256 if FLASH_DECODING else 16
+BLOCK_SIZE: int = 16
 # Will be set in warmup
 CACHE_MANAGER: Optional["CacheManager"] = None
 
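The fixed BLOCK_SIZE of 16 has to line up with the page_size=16 that the decode wrapper is planned with further down in this commit: each cache block is one KV page of 16 token slots. A small illustrative sketch of that relationship (the numbers below are made up, not taken from the diff):

BLOCK_SIZE = 16  # tokens (slots) per KV-cache block / flashinfer page

seq_len = 20  # illustrative sequence length
num_blocks_needed = (seq_len + BLOCK_SIZE - 1) // BLOCK_SIZE  # 2 blocks
last_block_used = (seq_len - 1) % BLOCK_SIZE + 1              # 4 slots used in block 2
slot_of_last_token = 1 * BLOCK_SIZE + 3                       # global slot 19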
@@ -33,18 +34,21 @@ class CacheManager:
 
         if FLASH_DECODING:
             self.kv_cache = [
-                (
-                    torch.empty(
-                        (num_blocks, self.block_size, num_heads, head_size),
-                        dtype=dtype,
-                        device=device,
-                    ),
-                    torch.empty(
-                        (num_blocks, self.block_size, num_heads, head_size),
-                        dtype=dtype,
-                        device=device,
-                    ),
+                torch.empty(
+                    (num_blocks, 2, self.block_size, num_heads, head_size),
+                    dtype=dtype,
+                    device=device,
                 )
+                # torch.empty(
+                #     (num_blocks, self.block_size, num_heads, head_size),
+                #     dtype=dtype,
+                #     device=device,
+                # ),
+                # torch.empty(
+                #     (num_blocks, self.block_size, num_heads, head_size),
+                #     dtype=dtype,
+                #     device=device,
+                # ),
                 for _ in range(num_layers)
             ]
         else:
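For reference, a minimal sketch of the layout change above: instead of one (key, value) pair of tensors per layer, each layer gets a single 5-D tensor whose second dimension stacks K and V. This matches the (num_pages, 2, page_size, num_heads, head_dim) arrangement flashinfer documents for its "NHD" paged KV layout; the sizes below are illustrative assumptions, not values from the diff.

import torch

# Illustrative sizes only.
num_blocks, block_size, num_heads, head_size = 8, 16, 4, 64

# Old layout: a (key_cache, value_cache) tuple per layer.
old_layer_cache = (
    torch.empty(num_blocks, block_size, num_heads, head_size),
    torch.empty(num_blocks, block_size, num_heads, head_size),
)

# New layout: one tensor per layer, K at index 0 and V at index 1 of dim 1.
new_layer_cache = torch.empty(num_blocks, 2, block_size, num_heads, head_size)
key_plane = new_layer_cache[:, 0]    # (num_blocks, block_size, num_heads, head_size)
value_plane = new_layer_cache[:, 1]  # a view into the same storage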
@@ -137,6 +137,7 @@ class FlashLlamaAttention(torch.nn.Module):
         slots,
         input_lengths,
         max_s,
+        prefill_wrapper,
     ):
         qkv = self.query_key_value(hidden_states)
         query, kv = qkv.split(
@@ -152,37 +153,40 @@ class FlashLlamaAttention(torch.nn.Module):
         self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)
 
         paged_attention.reshape_and_cache(
-            kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots
+            kv[:, 0], kv[:, 1], kv_cache[:, 0], kv_cache[:, 1], slots
         )
 
         # output tensor
         attn_output = torch.empty_like(query)
 
-        # Prefill
         if cu_seqlen_prefill is not None:
             # flash attention
-            flash_attn.attention(
-                query,
-                torch.select(kv, dim=1, index=0),
-                torch.select(kv, dim=1, index=1),
-                attn_output,
-                cu_seqlen_prefill,
-                max_s,
-                self.softmax_scale,
+            attn_output = prefill_wrapper.forward(
+                query.contiguous(), kv[:, 0].contiguous(), kv[:, 1].contiguous()
             )
+            # flash_attn.attention(
+            #     query,
+            #     torch.select(kv, dim=1, index=0),
+            #     torch.select(kv, dim=1, index=1),
+            #     attn_output,
+            #     cu_seqlen_prefill,
+            #     max_s,
+            #     self.softmax_scale,
+            # )
         # Decode
         else:
-            paged_attention.attention(
-                attn_output,
-                query,
-                kv_cache[0],
-                kv_cache[1],
-                self.kv_head_mapping,
-                self.softmax_scale,
-                block_tables,
-                input_lengths,
-                max_s,
-            )
+            attn_output = prefill_wrapper.forward(query, kv_cache)
+            # paged_attention.attention(
+            #     attn_output,
+            #     query,
+            #     kv_cache[0],
+            #     kv_cache[1],
+            #     self.kv_head_mapping,
+            #     self.softmax_scale,
+            #     block_tables,
+            #     input_lengths,
+            #     max_s,
+            # )
 
         return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))
 
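Both branches above now call a flashinfer wrapper that is built once per forward pass in FlashLlamaModel (the decode branch reuses the same variable name prefill_wrapper and passes the whole paged kv_cache tensor). Because the cache is now a single 5-D tensor, the K/V planes handed to reshape_and_cache are selected along dim 1, kv_cache[:, 0] and kv_cache[:, 1], rather than kv_cache[0] and kv_cache[1]. A tiny self-contained sketch of that indexing difference, with illustrative sizes:

import torch

num_blocks, block_size, num_heads, head_size = 4, 16, 2, 8
kv_cache = torch.empty(num_blocks, 2, block_size, num_heads, head_size)

# Indexing dim 0 picks a single block (both K and V for that block):
assert kv_cache[0].shape == (2, block_size, num_heads, head_size)

# Indexing dim 1 picks the K or V plane across all blocks:
assert kv_cache[:, 0].shape == (num_blocks, block_size, num_heads, head_size)
assert kv_cache[:, 1].shape == (num_blocks, block_size, num_heads, head_size)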
@@ -283,6 +287,7 @@ class FlashLlamaLayer(nn.Module):
         slots,
         input_lengths,
         max_s,
+        prefill_wrapper,
     ):
         normed_hidden_states, res = self.input_layernorm(hidden_states, residual)
 
@@ -297,6 +302,7 @@ class FlashLlamaLayer(nn.Module):
             slots,
             input_lengths,
             max_s,
+            prefill_wrapper,
         )
 
         # faster post attention rms norm
@@ -362,6 +368,54 @@ class FlashLlamaModel(torch.nn.Module):
         cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(
             position_ids, max_s, hidden_states.dtype
         )
+        workspace_buffer = torch.empty(
+            16 * 1024 * 1024, dtype=torch.uint8, device=inputs_embeds.device
+        )
+        import flashinfer
+
+        if cu_seqlen_prefill is None:
+            prefill_wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper(
+                workspace_buffer, "NHD"
+            )
+            cu_seqlen_q = torch.arange(
+                input_lengths.shape[0] + 1,
+                device=inputs_embeds.device,
+                dtype=torch.int32,
+            )
+            cu_seqlen_k = torch.cat(
+                [
+                    torch.zeros(
+                        (1,), device=input_lengths.device, dtype=input_lengths.dtype
+                    ),
+                    input_lengths.cumsum(dim=-1),
+                ]
+            ).to(dtype=torch.int32)
+
+            prefill_wrapper.begin_forward(
+                indptr=cu_seqlen_k,
+                indices=block_tables.view(-1),
+                last_page_len=slots.to(dtype=torch.int32),
+                num_qo_heads=self.layers[0].self_attn.num_heads,
+                num_kv_heads=self.layers[0].self_attn.num_key_value_heads,
+                head_dim=self.layers[0].self_attn.head_size,
+                page_size=16,
+                pos_encoding_mode="NONE",
+                data_type=inputs_embeds.dtype,
+            )
+        else:
+            prefill_wrapper = flashinfer.BatchPrefillWithRaggedKVCacheWrapper(
+                workspace_buffer, "NHD"
+            )
+            cu_seqlen_q = cu_seqlen_prefill
+            cu_seqlen_k = cu_seqlen_prefill
+
+            prefill_wrapper.begin_forward(
+                qo_indptr=cu_seqlen_q,
+                kv_indptr=cu_seqlen_k,
+                num_qo_heads=self.layers[0].self_attn.num_heads,
+                num_kv_heads=self.layers[0].self_attn.num_key_value_heads,
+                head_dim=self.layers[0].self_attn.head_size,
+            )
+
         residual = None
         for i, layer in enumerate(self.layers):
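For intuition, here is a small CPU-only worked example of the page-table metadata a flashinfer paged-KV decode wrapper is usually planned with: a CSR-style indptr over pages per sequence, the flattened page indices, and the number of valid slots in each sequence's last page. The values and formulas below are illustrative assumptions about that metadata format; they are not what the hunk above passes (it hands begin_forward the cumulative token counts and the slots tensor instead).

import torch

page_size = 16  # same as BLOCK_SIZE and the page_size passed above

# Two sequences with 20 and 5 tokens already in the cache (illustrative).
input_lengths = torch.tensor([20, 5], dtype=torch.int32)
# Row-padded block table: pages assigned to each sequence (illustrative).
block_tables = torch.tensor([[0, 1], [2, 0]], dtype=torch.int32)

pages_per_seq = (input_lengths + page_size - 1) // page_size        # [2, 1]
kv_indptr = torch.cat(
    [torch.zeros(1, dtype=torch.int32), pages_per_seq.cumsum(0).to(torch.int32)]
)                                                                   # [0, 2, 3]
kv_indices = torch.cat(
    [block_tables[i, : int(pages_per_seq[i])] for i in range(len(input_lengths))]
)                                                                   # [0, 1, 2]
kv_last_page_len = (input_lengths - 1) % page_size + 1              # [4, 5]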
@@ -376,8 +430,11 @@ class FlashLlamaModel(torch.nn.Module):
                 slots,
                 input_lengths,
                 max_s,
+                prefill_wrapper,
             )
 
+        prefill_wrapper.end_forward()
+
         hidden_states, _ = self.norm(hidden_states, residual)
 
         return hidden_states
@@ -214,6 +214,8 @@ class MistralAttention(torch.nn.Module):
 
         self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)
 
+        attn_output = torch.empty_like(query)
+
         if prefill_cache_indices is not None:
             kv_to_cache = kv[prefill_cache_indices]
         else:
@@ -222,10 +224,6 @@ class MistralAttention(torch.nn.Module):
         paged_attention.reshape_and_cache(
             kv_to_cache[:, 0], kv_to_cache[:, 1], kv_cache[0], kv_cache[1], slots
         )
-
-        # output tensor
-        attn_output = torch.empty_like(query)
-
         # Prefill
         if cu_seqlen_prefill is not None:
             # flash attention
@@ -30,8 +30,8 @@ def reshape_and_cache(
     else:
         if FLASH_DECODING:
             shape = key_cache.shape
-            key_cache.view(-1, shape[-2], shape[-1])[slots] = key
-            value_cache.view(-1, shape[-2], shape[-1])[slots] = value
+            # key_cache.view(-1, shape[-2], shape[-1])[slots] = key
+            # value_cache.view(-1, shape[-2], shape[-1])[slots] = value
         else:
             cache_ops.reshape_and_cache(
                 key, value, key_cache, value_cache, slots, "auto", 1.0
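For context, a short self-contained sketch of what the two lines commented out here were doing: collapsing the (num_blocks, block_size) leading dimensions into a single slot axis and scattering the freshly computed keys (and, identically, values) into the cache by global slot index. Sizes are illustrative.

import torch

num_blocks, block_size, num_heads, head_size = 4, 16, 2, 8
key_cache = torch.zeros(num_blocks, block_size, num_heads, head_size)

# Three new tokens landing in global slots 5, 16 and 17
# (slot = block_id * block_size + offset_within_block).
slots = torch.tensor([5, 16, 17])
key = torch.randn(3, num_heads, head_size)

shape = key_cache.shape
key_cache.view(-1, shape[-2], shape[-1])[slots] = key  # value_cache is handled the same way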