wip not faster

OlivierDehaene 2024-01-25 15:26:51 +01:00
parent bf700e7eef
commit ef99678798
5 changed files with 147 additions and 96 deletions

View File

@@ -1,4 +1,4 @@
-flash_att_v2_commit_cuda := 02ac572f3ffc4f402e4183aaa6824b45859d3ed3
+flash_att_v2_commit_cuda := 54e80a3829c6d2337570d01e78ebd9529c02d342
flash_att_v2_commit_rocm := 8736558c287ff2ef28b24878e42828c595ac3e69

View File

@@ -1,9 +1,11 @@
import math
import torch
+import os
from typing import Optional, List, Tuple
-BLOCK_SIZE: int = 16
+USE_VLLM = os.getenv("USE_VLLM", "False") == "True"
+BLOCK_SIZE: int = 256 if not USE_VLLM else 16
# Will be set in warmup
CACHE_MANAGER: Optional["CacheManager"] = None
@@ -26,15 +28,22 @@ class CacheManager:
element_size = torch.tensor([], dtype=dtype).element_size()
x = self.block_size // element_size
+if USE_VLLM:
+k_shape = (num_blocks, num_heads, head_size // x, self.block_size, x)
+v_shape = (num_blocks, num_heads, head_size, self.block_size)
+else:
+k_shape = (num_blocks, BLOCK_SIZE, num_heads, head_size)
+v_shape = (num_blocks, BLOCK_SIZE, num_heads, head_size)
self.kv_cache = [
(
torch.empty(
-(num_blocks, num_heads, head_size // x, self.block_size, x),
+k_shape,
dtype=dtype,
device=device,
),
torch.empty(
-(num_blocks, num_heads, head_size, self.block_size),
+v_shape,
dtype=dtype,
device=device,
),

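For orientation, here is a small standalone sketch of the two KV-cache layouts this file now switches between. The tensor shapes are copied from the hunk above, but the concrete sizes, dtype and the slot-write example are invented for illustration and are not part of the commit:

import torch

# Example sizes only; the real values come from the model config and warmup.
num_blocks, num_heads, head_size = 64, 8, 128
dtype = torch.float16
element_size = torch.tensor([], dtype=dtype).element_size()

# vLLM paged layout (USE_VLLM=True, BLOCK_SIZE = 16): the key cache is split
# along head_size into chunks of x elements, the layout the paged-attention
# kernel expects.
block_size = 16
x = block_size // element_size
k_vllm = torch.empty(num_blocks, num_heads, head_size // x, block_size, x, dtype=dtype)
v_vllm = torch.empty(num_blocks, num_heads, head_size, block_size, dtype=dtype)

# flash-attention layout (USE_VLLM=False, BLOCK_SIZE = 256): plain paged tensors
# that can be viewed as (total_slots, num_heads, head_size) and written by slot,
# as the prefill path in the modeling code below does.
BLOCK_SIZE = 256
k_fa = torch.empty(num_blocks, BLOCK_SIZE, num_heads, head_size, dtype=dtype)
v_fa = torch.empty(num_blocks, BLOCK_SIZE, num_heads, head_size, dtype=dtype)

slots = torch.tensor([0, 1, 2])
new_k = torch.randn(len(slots), num_heads, head_size).to(dtype)
k_fa.view(-1, num_heads, head_size)[slots] = new_k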
View File

@@ -17,6 +17,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import os
import torch
import torch.distributed
@@ -40,26 +41,26 @@ from text_generation_server.utils.layers import (
class LlamaConfig(PretrainedConfig):
def __init__(
-self,
-vocab_size=32000,
-hidden_size=4096,
-intermediate_size=11008,
-num_hidden_layers=32,
-num_attention_heads=32,
-num_key_value_heads=None,
-hidden_act="silu",
-max_position_embeddings=2048,
-initializer_range=0.02,
-rms_norm_eps=1e-6,
-use_cache=True,
-pad_token_id=0,
-bos_token_id=1,
-eos_token_id=2,
-pretraining_tp=1,
-tie_word_embeddings=False,
-rope_scaling=None,
-rope_theta=10000.0,
-**kwargs,
+self,
+vocab_size=32000,
+hidden_size=4096,
+intermediate_size=11008,
+num_hidden_layers=32,
+num_attention_heads=32,
+num_key_value_heads=None,
+hidden_act="silu",
+max_position_embeddings=2048,
+initializer_range=0.02,
+rms_norm_eps=1e-6,
+use_cache=True,
+pad_token_id=0,
+bos_token_id=1,
+eos_token_id=2,
+pretraining_tp=1,
+tie_word_embeddings=False,
+rope_scaling=None,
+rope_theta=10000.0,
+**kwargs,
):
self.vocab_size = vocab_size
self.max_position_embeddings = max_position_embeddings
@@ -139,10 +140,10 @@ def _load_gqa(config, prefix: str, weights):
class FlashLlamaAttention(torch.nn.Module):
def __init__(
-self,
-prefix: str,
-config,
-weights,
+self,
+prefix: str,
+config,
+weights,
):
super().__init__()
self.num_heads = config.num_attention_heads
@@ -156,7 +157,7 @@ class FlashLlamaAttention(torch.nn.Module):
device=weights.device,
)
-self.softmax_scale = self.head_size**-0.5
+self.softmax_scale = self.head_size ** -0.5
if self.num_heads % weights.process_group.size() != 0:
raise ValueError(
@@ -165,7 +166,7 @@ class FlashLlamaAttention(torch.nn.Module):
)
self.num_heads = self.num_heads // weights.process_group.size()
self.num_key_value_heads = (
-config.num_key_value_heads // weights.process_group.size()
+config.num_key_value_heads // weights.process_group.size()
)
self.query_key_value = load_attention(config, prefix, weights)
@@ -182,16 +183,16 @@ class FlashLlamaAttention(torch.nn.Module):
).repeat_interleave(self.num_groups)
def forward(
-self,
-hidden_states,
-cos,
-sin,
-cu_seqlen_prefill,
-kv_cache,
-block_tables,
-slots,
-input_lengths,
-max_s,
+self,
+hidden_states,
+cos,
+sin,
+cu_seqlen_prefill,
+kv_cache,
+block_tables,
+slots,
+input_lengths,
+max_s,
):
qkv = self.query_key_value(hidden_states)
query, kv = qkv.split(
@@ -204,17 +205,24 @@ class FlashLlamaAttention(torch.nn.Module):
query = query.view(-1, self.num_heads, self.head_size)
kv = kv.view(-1, 2, self.num_key_value_heads, self.head_size)
-self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)
+use_vllm = os.getenv("USE_VLLM", "False") == "True"
-paged_attention.reshape_and_cache(
-kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots
-)
+if use_vllm:
+paged_attention.reshape_and_cache(
+kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots
+)
+self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)
# output tensor
attn_output = torch.empty_like(query)
# Prefill
if cu_seqlen_prefill is not None:
+if not use_vllm:
+self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)
+kv_cache[0].view(-1, self.num_key_value_heads, self.head_size)[slots] = kv[:, 0]
+kv_cache[1].view(-1, self.num_key_value_heads, self.head_size)[slots] = kv[:, 1]
# flash attention
flash_attn.attention(
query,
@@ -227,17 +235,41 @@ class FlashLlamaAttention(torch.nn.Module):
)
# Decode
else:
-paged_attention.attention(
-attn_output,
-query,
-kv_cache[0],
-kv_cache[1],
-self.kv_head_mapping,
-self.softmax_scale,
-block_tables,
-input_lengths,
-max_s,
-)
+if not use_vllm:
+import flash_attn_2_cuda
+flash_attn_2_cuda.fwd_kvcache(
+query.unsqueeze(1), # q
+kv_cache[0], # kcache
+kv_cache[1], # vcache
+torch.select(kv, dim=1, index=0).unsqueeze(1), # k
+torch.select(kv, dim=1, index=1).unsqueeze(1), # v
+input_lengths, # seqlens_k
+self.rotary_emb._cos_cached, # rotary_cos
+self.rotary_emb._sin_cached, # rotary_sin
+# None,None,
+None, # cache_batch_idx
+block_tables, # block_tables
+None, # alibi_slopes
+attn_output.unsqueeze(1), # out
+self.softmax_scale, # softmax_scale
+True, # is_causal
+-1, # window_size_left
+0, # window_size_right
+False, # is_rotary_interleaved
+0, # num_splits
+)
+else:
+paged_attention.attention(
+attn_output,
+query,
+kv_cache[0],
+kv_cache[1],
+self.kv_head_mapping,
+self.softmax_scale,
+block_tables,
+input_lengths,
+max_s,
+)
return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))
@@ -271,7 +303,7 @@ class LlamaMLP(nn.Module):
bias=False,
)
self.intermediate_size = (
-config.intermediate_size // weights.process_group.size()
+config.intermediate_size // weights.process_group.size()
)
def forward(self, hidden_states):
@@ -299,17 +331,17 @@ class FlashLlamaLayer(nn.Module):
)
def forward(
-self,
-hidden_states,
-residual,
-cos,
-sin,
-cu_seqlen_prefill,
-kv_cache,
-block_tables,
-slots,
-input_lengths,
-max_s,
+self,
+hidden_states,
+residual,
+cos,
+sin,
+cu_seqlen_prefill,
+kv_cache,
+block_tables,
+slots,
+input_lengths,
+max_s,
):
normed_hidden_states, res = self.input_layernorm(hidden_states, residual)
@@ -367,23 +399,27 @@ class FlashLlamaModel(torch.nn.Module):
self.num_key_value_heads = self.layers[0].self_attn.num_key_value_heads
def forward(
-self,
-input_ids: torch.Tensor,
-position_ids: torch.Tensor,
-cu_seqlen_prefill: Optional[torch.Tensor],
-kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
-block_tables: torch.Tensor,
-slots: torch.Tensor,
-input_lengths: torch.Tensor,
-max_s: int,
+self,
+input_ids: torch.Tensor,
+position_ids: torch.Tensor,
+cu_seqlen_prefill: Optional[torch.Tensor],
+kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
+block_tables: torch.Tensor,
+slots: torch.Tensor,
+input_lengths: torch.Tensor,
+max_s: int,
) -> torch.Tensor:
hidden_states = self.embed_tokens(input_ids)
# Get rotary cos and sin for this forward
# Avoid to index in each layer
-cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(
-position_ids, max_s, hidden_states.dtype
-)
+use_vllm = os.getenv("USE_VLLM", "False") == "True"
+if cu_seqlen_prefill is not None or use_vllm:
+cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(
+position_ids, max_s, hidden_states.dtype
+)
+else:
+cos, sin = None, None
residual = None
for i, layer in enumerate(self.layers):

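To make the new decode branch easier to read, here is a pure-PyTorch reference, runnable on CPU, of the attention that the fused flash_attn_2_cuda.fwd_kvcache call computes over the paged cache. All sizes and names below are invented for the example, and the in-kernel rotary update, the append of the new key/value, and GQA head mapping are deliberately left out:

import torch

def decode_attention_reference(q, k_cache, v_cache, block_tables, seqlens, softmax_scale):
    # One query token per sequence attends over its cached keys/values,
    # gathered through the block table. Cache layout follows the non-vLLM
    # branch: (num_blocks, block_size, num_heads, head_size).
    num_blocks, block_size, num_heads, head_size = k_cache.shape
    outputs = []
    for i in range(q.shape[0]):
        seqlen = int(seqlens[i])
        blocks = block_tables[i][: (seqlen + block_size - 1) // block_size]
        k = k_cache[blocks].reshape(-1, num_heads, head_size)[:seqlen]
        v = v_cache[blocks].reshape(-1, num_heads, head_size)[:seqlen]
        scores = torch.einsum("hd,shd->hs", q[i], k) * softmax_scale
        probs = torch.softmax(scores, dim=-1)
        outputs.append(torch.einsum("hs,shd->hd", probs, v))
    return torch.stack(outputs)

# Tiny usage example with made-up sizes: 3 sequences, 1 new token each.
k_cache = torch.randn(4, 256, 2, 8)
v_cache = torch.randn(4, 256, 2, 8)
q = torch.randn(3, 2, 8)
block_tables = torch.tensor([[0], [1], [2]])
seqlens = torch.tensor([5, 17, 42])
out = decode_attention_reference(q, k_cache, v_cache, block_tables, seqlens, 8 ** -0.5)
print(out.shape)  # torch.Size([3, 2, 8])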
View File

@@ -127,6 +127,8 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
)
async def Decode(self, request, context):
+from torch.profiler import profile, ProfilerActivity
start = time.time_ns()
if len(request.batches) == 0:
raise ValueError("Must provide at least one batch")
@@ -149,7 +151,9 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
batch = batches[0]
concat_ns = None
-generations, next_batch, timings = self.model.generate_token(batch)
+with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prefill_prof:
+generations, next_batch, timings = self.model.generate_token(batch)
+prefill_prof.export_chrome_trace("new_decode.json")
self.cache.set(next_batch)
return generate_pb2.DecodeResponse(

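The Decode change above wraps a single generate_token call in torch.profiler and dumps a Chrome trace. Here is a standalone sketch of that pattern, with the model call replaced by a dummy workload and an arbitrary output file name:

import torch
from torch.profiler import profile, ProfilerActivity

def dummy_generate_token():
    # Stand-in for self.model.generate_token(batch).
    a = torch.randn(1024, 1024)
    return a @ a

activities = [ProfilerActivity.CPU]
if torch.cuda.is_available():
    activities.append(ProfilerActivity.CUDA)

with profile(activities=activities) as prof:
    dummy_generate_token()

prof.export_chrome_trace("decode_trace.json")  # open in chrome://tracing or Perfetto
print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=5))

Exporting a trace on every request is expensive, so this pattern only makes sense for one-off debugging runs.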
View File

@@ -57,7 +57,7 @@ except ImportError as e:
elif IS_ROCM_SYSTEM:
for idx in range(torch.cuda.device_count()):
if "MI210" not in torch.cuda.get_device_name(
-idx
+idx
) and "MI250" not in torch.cuda.get_device_name(idx):
raise ImportError(
f"AMD GPU {torch.cuda.get_device_name(idx)} does not support flash-attention"
@@ -68,27 +68,29 @@ except ImportError as e:
def attention(
-q,
-k,
-v,
-out,
-cu_seqlens,
-max_s,
-softmax_scale,
-window_size_left=-1,
+q,
+k,
+v,
+out,
+cu_seqlens,
+max_s,
+softmax_scale,
+window_size_left=-1,
):
if window_size_left <= 0 and window_size_left != -1:
raise ValueError("`window_size_left` must be > 0 or -1")
if HAS_FLASH_ATTN_V2_CUDA:
return flash_attn_2_cuda.varlen_fwd(
-q,
-k,
-v,
-out,
-cu_seqlens,
-cu_seqlens,
-max_s,
+q, # q
+k, # k
+v, # v
+out, # out
+cu_seqlens, # cu_seqlens_q
+cu_seqlens, # cu_seqlens_k
+None,
+None,
+max_s, #
max_s,
0.0,
softmax_scale,