commit ef99678798
parent bf700e7eef

    wip not faster
(flash-attention v2 commit pin)
@@ -1,4 +1,4 @@
-flash_att_v2_commit_cuda := 02ac572f3ffc4f402e4183aaa6824b45859d3ed3
+flash_att_v2_commit_cuda := 54e80a3829c6d2337570d01e78ebd9529c02d342
 flash_att_v2_commit_rocm := 8736558c287ff2ef28b24878e42828c595ac3e69
(KV cache manager)
@@ -1,9 +1,11 @@
 import math
 import torch
+import os

 from typing import Optional, List, Tuple

-BLOCK_SIZE: int = 16
+USE_VLLM = os.getenv("USE_VLLM", "False") == "True"
+BLOCK_SIZE: int = 256 if not USE_VLLM else 16
 # Will be set in warmup
 CACHE_MANAGER: Optional["CacheManager"] = None
@@ -26,15 +28,22 @@ class CacheManager:
         element_size = torch.tensor([], dtype=dtype).element_size()
         x = self.block_size // element_size

+        if USE_VLLM:
+            k_shape = (num_blocks, num_heads, head_size // x, self.block_size, x)
+            v_shape = (num_blocks, num_heads, head_size, self.block_size)
+        else:
+            k_shape = (num_blocks, BLOCK_SIZE, num_heads, head_size)
+            v_shape = (num_blocks, BLOCK_SIZE, num_heads, head_size)
+
         self.kv_cache = [
             (
                 torch.empty(
-                    (num_blocks, num_heads, head_size // x, self.block_size, x),
+                    k_shape,
                     dtype=dtype,
                     device=device,
                 ),
                 torch.empty(
-                    (num_blocks, num_heads, head_size, self.block_size),
+                    v_shape,
                     dtype=dtype,
                     device=device,
                 ),
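Note (not part of the commit, inferred from the shapes): the vLLM paged-attention kernels expect the packed (num_blocks, num_heads, head_size // x, block_size, x) key layout, while the non-vLLM path allocates plain (num_blocks, BLOCK_SIZE, num_heads, head_size) blocks, so a flat slot index can address the cache directly. A minimal sketch of that slot-indexed write, with illustrative sizes:

import torch

# Illustrative sizes, not taken from the commit
num_blocks, block_size, num_heads, head_size = 8, 256, 4, 64

# Non-vLLM layout: one contiguous (block_size, num_heads, head_size) slab per block
k_cache = torch.empty(num_blocks, block_size, num_heads, head_size)

# A "slot" is an absolute position: block_id * block_size + offset_in_block
slots = torch.tensor([0, 1, 256, 300])
k_new = torch.randn(len(slots), num_heads, head_size)

# Flattening the first two dims gives one row per slot, so new keys can be
# written with a single indexed assignment, as the prefill path below does
k_cache.view(-1, num_heads, head_size)[slots] = k_new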
(flash Llama modeling)
@@ -17,6 +17,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import os

 import torch
 import torch.distributed
(formatting-only change to the LlamaConfig.__init__ signature; parameters unchanged)
@@ -40,26 +41,26 @@ from text_generation_server.utils.layers import (
 class LlamaConfig(PretrainedConfig):
     def __init__(
         self,
         vocab_size=32000,
         hidden_size=4096,
         intermediate_size=11008,
         num_hidden_layers=32,
         num_attention_heads=32,
         num_key_value_heads=None,
         hidden_act="silu",
         max_position_embeddings=2048,
         initializer_range=0.02,
         rms_norm_eps=1e-6,
         use_cache=True,
         pad_token_id=0,
         bos_token_id=1,
         eos_token_id=2,
         pretraining_tp=1,
         tie_word_embeddings=False,
         rope_scaling=None,
         rope_theta=10000.0,
         **kwargs,
     ):
         self.vocab_size = vocab_size
         self.max_position_embeddings = max_position_embeddings
(formatting-only change to the __init__ signature)
@@ -139,10 +140,10 @@ def _load_gqa(config, prefix: str, weights):

 class FlashLlamaAttention(torch.nn.Module):
     def __init__(
         self,
         prefix: str,
         config,
         weights,
     ):
         super().__init__()
         self.num_heads = config.num_attention_heads
@@ -156,7 +157,7 @@ class FlashLlamaAttention(torch.nn.Module):
             device=weights.device,
         )

-        self.softmax_scale = self.head_size**-0.5
+        self.softmax_scale = self.head_size ** -0.5

         if self.num_heads % weights.process_group.size() != 0:
             raise ValueError(
(formatting-only change)
@@ -165,7 +166,7 @@ class FlashLlamaAttention(torch.nn.Module):
         )
         self.num_heads = self.num_heads // weights.process_group.size()
         self.num_key_value_heads = (
             config.num_key_value_heads // weights.process_group.size()
         )

         self.query_key_value = load_attention(config, prefix, weights)
(formatting-only change to the forward signature)
@@ -182,16 +183,16 @@ class FlashLlamaAttention(torch.nn.Module):
         ).repeat_interleave(self.num_groups)

     def forward(
         self,
         hidden_states,
         cos,
         sin,
         cu_seqlen_prefill,
         kv_cache,
         block_tables,
         slots,
         input_lengths,
         max_s,
     ):
         qkv = self.query_key_value(hidden_states)
         query, kv = qkv.split(
@@ -204,17 +205,24 @@ class FlashLlamaAttention(torch.nn.Module):
         query = query.view(-1, self.num_heads, self.head_size)
         kv = kv.view(-1, 2, self.num_key_value_heads, self.head_size)

-        self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)
+        use_vllm = os.getenv("USE_VLLM", "False") == "True"

-        paged_attention.reshape_and_cache(
-            kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots
-        )
+        if use_vllm:
+            paged_attention.reshape_and_cache(
+                kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots
+            )
+            self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)

         # output tensor
         attn_output = torch.empty_like(query)

         # Prefill
         if cu_seqlen_prefill is not None:
+            if not use_vllm:
+                self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)
+                kv_cache[0].view(-1, self.num_key_value_heads, self.head_size)[slots] = kv[:, 0]
+                kv_cache[1].view(-1, self.num_key_value_heads, self.head_size)[slots] = kv[:, 1]
+
             # flash attention
             flash_attn.attention(
                 query,
@@ -227,17 +235,41 @@ class FlashLlamaAttention(torch.nn.Module):
             )
         # Decode
         else:
-            paged_attention.attention(
-                attn_output,
-                query,
-                kv_cache[0],
-                kv_cache[1],
-                self.kv_head_mapping,
-                self.softmax_scale,
-                block_tables,
-                input_lengths,
-                max_s,
-            )
+            if not use_vllm:
+                import flash_attn_2_cuda
+                flash_attn_2_cuda.fwd_kvcache(
+                    query.unsqueeze(1),  # q
+                    kv_cache[0],  # kcache
+                    kv_cache[1],  # vcache
+                    torch.select(kv, dim=1, index=0).unsqueeze(1),  # k
+                    torch.select(kv, dim=1, index=1).unsqueeze(1),  # v
+                    input_lengths,  # seqlens_k
+                    self.rotary_emb._cos_cached,  # rotary_cos
+                    self.rotary_emb._sin_cached,  # rotary_sin
+                    # None,None,
+                    None,  # cache_batch_idx
+                    block_tables,  # block_tables
+                    None,  # alibi_slopes
+                    attn_output.unsqueeze(1),  # out
+                    self.softmax_scale,  # softmax_scale
+                    True,  # is_causal
+                    -1,  # window_size_left
+                    0,  # window_size_right
+                    False,  # is_rotary_interleaved
+                    0,  # num_splits
+                )
+            else:
+                paged_attention.attention(
+                    attn_output,
+                    query,
+                    kv_cache[0],
+                    kv_cache[1],
+                    self.kv_head_mapping,
+                    self.softmax_scale,
+                    block_tables,
+                    input_lengths,
+                    max_s,
+                )

         return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))
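Note (my reading, not stated in the commit): when rotary_cos/rotary_sin are passed, the flash-attention kv-cache kernel applies the rotary embedding to the query and to the new key before appending it to the cache, which is why this decode branch skips the explicit self.rotary_emb(...) call. The same kernel is normally reached through the Python wrapper flash_attn.flash_attn_with_kvcache; a hedged sketch of an equivalent call, assuming the pinned build exposes block_table (paged KV), with cos_cached/sin_cached and softmax_scale standing in for the module attributes used in the diff:

from flash_attn import flash_attn_with_kvcache

# Sketch only: tensor names follow the surrounding diff, shapes assumed
attn_output = flash_attn_with_kvcache(
    query.unsqueeze(1),           # q: (batch, 1, num_heads, head_size)
    kv_cache[0],                  # k_cache: (num_blocks, block_size, num_kv_heads, head_size)
    kv_cache[1],                  # v_cache
    k=kv[:, 0].unsqueeze(1),      # new keys for this step
    v=kv[:, 1].unsqueeze(1),      # new values for this step
    rotary_cos=cos_cached,        # rotary tables; the kernel rotates q and the new k itself
    rotary_sin=sin_cached,
    cache_seqlens=input_lengths,  # current length of each sequence in the cache
    block_table=block_tables,     # maps (sequence, block index) -> physical cache block
    softmax_scale=softmax_scale,
    causal=True,
    rotary_interleaved=False,
).squeeze(1)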
(formatting-only change)
@@ -271,7 +303,7 @@ class LlamaMLP(nn.Module):
             bias=False,
         )
         self.intermediate_size = (
             config.intermediate_size // weights.process_group.size()
         )

     def forward(self, hidden_states):
(formatting-only change to the forward signature)
@@ -299,17 +331,17 @@ class FlashLlamaLayer(nn.Module):
         )

     def forward(
         self,
         hidden_states,
         residual,
         cos,
         sin,
         cu_seqlen_prefill,
         kv_cache,
         block_tables,
         slots,
         input_lengths,
         max_s,
     ):
         normed_hidden_states, res = self.input_layernorm(hidden_states, residual)
(the forward signature lines are formatting-only changes; the substantive change is the conditional cos/sin below)
@@ -367,23 +399,27 @@ class FlashLlamaModel(torch.nn.Module):
         self.num_key_value_heads = self.layers[0].self_attn.num_key_value_heads

     def forward(
         self,
         input_ids: torch.Tensor,
         position_ids: torch.Tensor,
         cu_seqlen_prefill: Optional[torch.Tensor],
         kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
         block_tables: torch.Tensor,
         slots: torch.Tensor,
         input_lengths: torch.Tensor,
         max_s: int,
     ) -> torch.Tensor:
         hidden_states = self.embed_tokens(input_ids)

         # Get rotary cos and sin for this forward
         # Avoid to index in each layer
-        cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(
-            position_ids, max_s, hidden_states.dtype
-        )
+        use_vllm = os.getenv("USE_VLLM", "False") == "True"
+        if cu_seqlen_prefill is not None or use_vllm:
+            cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(
+                position_ids, max_s, hidden_states.dtype
+            )
+        else:
+            cos, sin = None, None

         residual = None
         for i, layer in enumerate(self.layers):
(gRPC server)
@@ -127,6 +127,8 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
         )

     async def Decode(self, request, context):
+        from torch.profiler import profile, ProfilerActivity
+
         start = time.time_ns()
         if len(request.batches) == 0:
             raise ValueError("Must provide at least one batch")
@@ -149,7 +151,9 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
             batch = batches[0]
             concat_ns = None

-        generations, next_batch, timings = self.model.generate_token(batch)
+        with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prefill_prof:
+            generations, next_batch, timings = self.model.generate_token(batch)
+        prefill_prof.export_chrome_trace("new_decode.json")
         self.cache.set(next_batch)

         return generate_pb2.DecodeResponse(
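Note (not part of the commit's intent beyond what the diff shows): this wrapper dumps a Chrome trace (new_decode.json) on every Decode call, presumably throwaway instrumentation for this WIP. A small self-contained sketch of the same pattern, plus the aggregated table view that is often quicker to scan than the raw trace (assumes a CUDA device; the model call is a stand-in):

import torch
from torch.profiler import profile, ProfilerActivity

def generate_token_stub():
    # Stand-in for self.model.generate_token(batch)
    x = torch.randn(2048, 2048, device="cuda")
    return x @ x

with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
    generate_token_stub()

# Same export as the diff; open in chrome://tracing or ui.perfetto.dev
prof.export_chrome_trace("new_decode.json")
# Aggregated per-op view
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))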
(flash attention utils; formatting-only change)
@@ -57,7 +57,7 @@ except ImportError as e:
 elif IS_ROCM_SYSTEM:
     for idx in range(torch.cuda.device_count()):
         if "MI210" not in torch.cuda.get_device_name(
             idx
         ) and "MI250" not in torch.cuda.get_device_name(idx):
             raise ImportError(
                 f"AMD GPU {torch.cuda.get_device_name(idx)} does not support flash-attention"
(the attention() signature lines are formatting-only changes)
@@ -68,27 +68,29 @@ except ImportError as e:

 def attention(
     q,
     k,
     v,
     out,
     cu_seqlens,
     max_s,
     softmax_scale,
     window_size_left=-1,
 ):
     if window_size_left <= 0 and window_size_left != -1:
         raise ValueError("`window_size_left` must be > 0 or -1")

     if HAS_FLASH_ATTN_V2_CUDA:
         return flash_attn_2_cuda.varlen_fwd(
-            q,
-            k,
-            v,
-            out,
-            cu_seqlens,
-            cu_seqlens,
-            max_s,
+            q,  # q
+            k,  # k
+            v,  # v
+            out,  # out
+            cu_seqlens,  # cu_seqlens_q
+            cu_seqlens,  # cu_seqlens_k
+            None,
+            None,
+            max_s,  #
             max_s,
             0.0,
             softmax_scale,
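Note (my reading of the diff, not something the commit states): the two inserted None arguments suggest the newly pinned flash-attention commit added optional parameters to varlen_fwd between cu_seqlens_k and the max sequence lengths, likely slots such as seqused_k and a paged-KV block table that this call leaves unset. Calling through the stable Python wrapper avoids tracking the raw C++ signature; a sketch, assuming flash-attn v2 is installed, with names taken from the surrounding function (the wrapper returns the output instead of writing into a pre-allocated out):

from flash_attn import flash_attn_varlen_func

attn_out = flash_attn_varlen_func(
    q,                        # (total_tokens, num_heads, head_size)
    k,                        # (total_tokens, num_kv_heads, head_size)
    v,
    cu_seqlens_q=cu_seqlens,  # cumulative sequence lengths, shared by q and k here
    cu_seqlens_k=cu_seqlens,
    max_seqlen_q=max_s,
    max_seqlen_k=max_s,
    dropout_p=0.0,
    softmax_scale=softmax_scale,
    causal=True,
)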