Add FP8 KV cache for ROCm

Mohit Sharma 2024-12-18 14:55:53 +00:00
parent 8f66d323d0
commit fa14d71ac8
4 changed files with 122 additions and 44 deletions

View File

@ -52,13 +52,22 @@ class KVCache:
device: torch.device,
):
"""Construct the key-value cache for a layer."""
if dtype in {torch.float8_e5m2, torch.float8_e4m3fn}:
if not (
    (ATTENTION == "flashinfer" and SYSTEM == "cuda")
    or (ATTENTION == "paged" and SYSTEM == "rocm")
):
    raise ValueError(
        "FP8 KV cache is currently only supported for flashinfer on CUDA and paged attention on ROCm"
    )
if SYSTEM == "rocm" and dtype == torch.float8_e5m2:
    raise ValueError(
        "float8_e5m2 FP8 KV cache is not supported on AMD ROCm"
    )
if dtype in {torch.float8_e5m2, torch.float8_e4m3fn} and (
ATTENTION != "flashinfer" or SYSTEM != "cuda"
):
raise ValueError(
"FP8 KV cache is currently only supported for flashinfer on CUDA"
)
self.kv_cache_dtype_str = "auto"
if SYSTEM == "rocm" and dtype == torch.float8_e4m3fn:
self.kv_cache_dtype_str = "fp8"
dtype = torch.uint8
element_size = torch.tensor([], dtype=dtype).element_size()
if SYSTEM == "ipex" and device.type == "xpu":
@ -120,6 +129,16 @@ class KVCache:
"Using FP8 KV cache scales",
)
return True
elif (
self.kv_cache_dtype_str == "fp8"
and ATTENTION == "paged"
and SYSTEM == "rocm"
):
log_once(
logger.info,
"Using FP8 KV cache scales",
)
return True
else:
# We have scales, but not the correct FP8 cache type, so warn once.
log_once(
@ -158,7 +177,7 @@ class KVCache:
key_cache = self.kv_cache[0]
value_cache = self.kv_cache[1]
if self.can_scale(kv_scales):
if self.can_scale(kv_scales) and SYSTEM == "cuda":
if kv_scales.key_scale_cpu != 1.0:
key = fp8_quantize(
key.float(),
@ -188,7 +207,16 @@ class KVCache:
key_cache.view(-1, shape[-2], shape[-1])[slots] = key
value_cache.view(-1, shape[-2], shape[-1])[slots] = value
else:
paged_reshape_and_cache(key, value, key_cache, value_cache, slots)
paged_reshape_and_cache(
key,
value,
key_cache,
value_cache,
slots,
self.kv_cache_dtype_str,
kv_scales.key_scale_cpu,
kv_scales.value_scale_cpu,
)
def paged_reshape_and_cache(
@ -197,7 +225,11 @@ def paged_reshape_and_cache(
key_cache: torch.Tensor,
value_cache: torch.Tensor,
slots: torch.Tensor,
kv_cache_dtype: str = "auto",
k_scale: float = 1.0,
v_scale: float = 1.0,
):
if SYSTEM == "cuda":
try:
import attention_kernels
@ -216,7 +248,7 @@ def paged_reshape_and_cache(
f"Could not import vllm paged attention. Make sure your installation is correct. Complete error: {e}"
)
ops.reshape_and_cache(
key, value, key_cache, value_cache, slots, "auto", 1.0, 1.0
key, value, key_cache, value_cache, slots, kv_cache_dtype, k_scale, v_scale
)
elif SYSTEM == "ipex":
import intel_extension_for_pytorch as ipex
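
To make the new k_scale/v_scale arguments concrete: when kv_cache_dtype is "fp8", the kernel conceptually divides by the per-tensor scale, narrows to e4m3, and stores the raw bytes. The snippet below is an illustrative, self-contained stand-in for that behaviour, not the actual reshape_and_cache kernel:

import torch

def fp8_scaled_store(x: torch.Tensor, scale: float) -> torch.Tensor:
    # Conceptual equivalent of the fp8 path: scale down, cast to e4m3,
    # keep the raw bytes in a uint8-backed cache.
    return (x / scale).to(torch.float8_e4m3fn).view(torch.uint8)

key = torch.randn(8, 4, 128, dtype=torch.float16)
k_scale = key.abs().amax().item() / 448.0  # 448 is the largest finite e4m3fn value
key_bytes = fp8_scaled_store(key, k_scale)
assert key_bytes.dtype == torch.uint8 and key_bytes.shape == key.shape
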

View File

@ -119,9 +119,9 @@ def paged_attention(
block_size,
max_s,
None,
"auto",
1.0,
1.0,
kv_cache.kv_cache_dtype_str,
kv_scales.key_scale_cpu,
kv_scales.value_scale_cpu,
)
else:
# Run PagedAttention V2.
@ -154,9 +154,9 @@ def paged_attention(
block_size,
max_s,
None,
"auto",
1.0,
1.0,
kv_cache.kv_cache_dtype_str,
kv_scales.key_scale_cpu,
kv_scales.value_scale_cpu,
)
else:
ops.paged_attention_rocm(
@ -174,9 +174,9 @@ def paged_attention(
block_size,
max_s,
None,
"auto",
1.0,
1.0,
kv_cache.kv_cache_dtype_str,
kv_scales.key_scale_cpu,
kv_scales.value_scale_cpu,
None,
_PARTITION_SIZE,
)
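
The three ROCm call sites above now forward the same triple: the cache dtype string plus the key and value scales. A hedged sketch of the scale container they read from; the dataclass here is an assumption for illustration, not the TGI KVScales definition:

from dataclasses import dataclass

@dataclass
class KVScalesSketch:
    # Scalar, CPU-side dequantization scales carried next to the FP8 KV cache.
    key_scale_cpu: float = 1.0
    value_scale_cpu: float = 1.0

kv_scales = KVScalesSketch(key_scale_cpu=0.021, value_scale_cpu=0.017)
kv_cache_dtype_str = "fp8"  # "auto" restores the previous unscaled behaviour

# All three paged-attention variants receive the same arguments:
print(kv_cache_dtype_str, kv_scales.key_scale_cpu, kv_scales.value_scale_cpu)
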

View File

@ -398,10 +398,16 @@ class LlamaMLP(nn.Module):
return self.down_proj(out, adapter_data)
else:
gate_up_states = self.gate_up_proj(hidden_states, adapter_data)
gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size)
return self.down_proj(
self.act(gate_up_states[:, 0]) * gate_up_states[:, 1], adapter_data
output_shape = gate_up_states.shape[:-1] + (self.intermediate_size,)
out = torch.empty(
output_shape, dtype=gate_up_states.dtype, device=gate_up_states.device
)
ops.silu_and_mul(out, gate_up_states)
return self.down_proj(out, adapter_data)
# gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size)
# return self.down_proj(
# self.act(gate_up_states[:, 0]) * gate_up_states[:, 1], adapter_data
# )
class FlashLlamaLayer(nn.Module):
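
The LlamaMLP change swaps the view-and-chunk activation for a fused ops.silu_and_mul call, leaving the old path commented out above. A small self-contained check of why the two are interchangeable, assuming the vLLM convention that silu_and_mul computes silu(x[..., :d]) * x[..., d:]:

import torch
import torch.nn.functional as F

d = 16
gate_up = torch.randn(4, 2 * d)

# Old path (commented out in the diff): split into gate/up halves, then silu(gate) * up.
chunked = gate_up.view(-1, 2, d)
reference = F.silu(chunked[:, 0]) * chunked[:, 1]

# Fused path, emulating ops.silu_and_mul(out, gate_up) under the assumed layout.
fused = F.silu(gate_up[:, :d]) * gate_up[:, d:]

torch.testing.assert_close(reference, fused)
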

View File

@ -520,28 +520,68 @@ class FlashMixtralForCausalLM(torch.nn.Module):
lm_head_indices: Optional[torch.Tensor] = None,
adapter_data: Optional[torch.Tensor] = None,
) -> torch.Tensor:
true_max_s = max_s
if prefill_cache_indices is not None:
# Slots also need to be sliced as it has the same size as the whole kv tensor
slots = slots[prefill_cache_indices]
elif self.max_past is not None:
# Clamp in decode mode as paged attention requires clamped values whereas the flash attention
# kernel requires the true values
seqlen = seqlen.clamp(max=self.max_past_tensor)
hidden_states = self.model(
input_ids,
position_ids,
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,
seqlen,
max_s,
true_max_s,
prefill_cache_indices,
)
if lm_head_indices is not None:
hidden_states = hidden_states[lm_head_indices]
logits = self.lm_head(hidden_states)
if (
torch.distributed.get_rank() == 0
and input_ids.shape[0] == 262144
and cu_seqlen_prefill is not None
):
with torch.profiler.profile(
activities=[
torch.profiler.ProfilerActivity.CPU,
torch.profiler.ProfilerActivity.CUDA,
],
record_shapes=True,
) as prof:
true_max_s = max_s
if prefill_cache_indices is not None:
# Slots also need to be sliced as it has the same size as the whole kv tensor
slots = slots[prefill_cache_indices]
elif self.max_past is not None:
# Clamp in decode mode as paged attention requires clamped values whereas the flash attention
# kernel requires the true values
seqlen = seqlen.clamp(max=self.max_past_tensor)
hidden_states = self.model(
input_ids,
position_ids,
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,
seqlen,
max_s,
true_max_s,
prefill_cache_indices,
)
if lm_head_indices is not None:
hidden_states = hidden_states[lm_head_indices]
logits = self.lm_head(hidden_states)
prof.export_chrome_trace("/tgi/trace_mistral_prefill.json")
else:
true_max_s = max_s
if prefill_cache_indices is not None:
# Slots also need to be sliced as it has the same size as the whole kv tensor
slots = slots[prefill_cache_indices]
elif self.max_past is not None:
# Clamp in decode mode as paged attention requires clamped values whereas the flash attention
# kernel requires the true values
seqlen = seqlen.clamp(max=self.max_past_tensor)
hidden_states = self.model(
input_ids,
position_ids,
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,
seqlen,
max_s,
true_max_s,
prefill_cache_indices,
)
if lm_head_indices is not None:
hidden_states = hidden_states[lm_head_indices]
logits = self.lm_head(hidden_states)
return logits
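
The Mixtral change wraps one specific prefill shape in torch.profiler and duplicates the forward body across the two branches. A hedged sketch of the same capture pattern factored into a reusable helper (the trigger condition and trace path simply mirror the diff):

import contextlib
import torch

@contextlib.contextmanager
def maybe_profile(enabled: bool, trace_path: str):
    # No-op when disabled; otherwise profile the wrapped block and export a Chrome trace.
    if not enabled:
        yield None
        return
    with torch.profiler.profile(
        activities=[
            torch.profiler.ProfilerActivity.CPU,
            torch.profiler.ProfilerActivity.CUDA,
        ],
        record_shapes=True,
    ) as prof:
        yield prof
    prof.export_chrome_trace(trace_path)

# Usage sketch mirroring the diff's trigger (rank 0, hard-coded prefill size):
# should_profile = (
#     torch.distributed.get_rank() == 0
#     and input_ids.shape[0] == 262144
#     and cu_seqlen_prefill is not None
# )
# with maybe_profile(should_profile, "/tgi/trace_mistral_prefill.json"):
#     ...  # the single, unduplicated forward body goes here
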