add fp8 kv cache for rocm
This commit is contained in:
parent
8f66d323d0
commit
fa14d71ac8
|
@ -52,13 +52,22 @@ class KVCache:
|
|||
device: torch.device,
|
||||
):
|
||||
"""Construct the key-value cache for a layer."""
|
||||
if dtype in {torch.float8_e5m2, torch.float8_e4m3fn}:
|
||||
if (ATTENTION == "flashinfer" and SYSTEM == "cuda") or not (
|
||||
ATTENTION == "paged" and SYSTEM == "rocm"
|
||||
):
|
||||
raise ValueError(
|
||||
"FP8 KV cache is currently only supported for flashinfer on CUDA and paged attention on ROCM"
|
||||
)
|
||||
if SYSTEM == "rocm" and dtype == torch.float8_e5m2:
|
||||
raise ValueError(
|
||||
"float8_e5m2 FP8 KV cache is not supported on AMD Rocm"
|
||||
)
|
||||
|
||||
if dtype in {torch.float8_e5m2, torch.float8_e4m3fn} and (
|
||||
ATTENTION != "flashinfer" or SYSTEM != "cuda"
|
||||
):
|
||||
raise ValueError(
|
||||
"FP8 KV cache is currently only supported for flashinfer on CUDA"
|
||||
)
|
||||
self.kv_cache_dtype_str = "auto"
|
||||
if SYSTEM == "rocm" and dtype == torch.float8_e4m3fn:
|
||||
self.kv_cache_dtype_str = "fp8"
|
||||
dtype = torch.uint8
|
||||
|
||||
element_size = torch.tensor([], dtype=dtype).element_size()
|
||||
if SYSTEM == "ipex" and device.type == "xpu":
|
||||
|
@ -120,6 +129,16 @@ class KVCache:
|
|||
"Using FP8 KV cache scales",
|
||||
)
|
||||
return True
|
||||
elif (
|
||||
self.kv_cache_dtype_str == "fp8"
|
||||
and ATTENTION == "paged"
|
||||
and SYSTEM == "rocm"
|
||||
):
|
||||
log_once(
|
||||
logger.info,
|
||||
"Using FP8 KV cache scales",
|
||||
)
|
||||
return True
|
||||
else:
|
||||
# We have scales, but not the correct FP8 cache type, so warn once.
|
||||
log_once(
|
||||
|
@ -158,7 +177,7 @@ class KVCache:
|
|||
key_cache = self.kv_cache[0]
|
||||
value_cache = self.kv_cache[1]
|
||||
|
||||
if self.can_scale(kv_scales):
|
||||
if self.can_scale(kv_scales) and SYSTEM == "cuda":
|
||||
if kv_scales.key_scale_cpu != 1.0:
|
||||
key = fp8_quantize(
|
||||
key.float(),
|
||||
|
@ -188,7 +207,16 @@ class KVCache:
|
|||
key_cache.view(-1, shape[-2], shape[-1])[slots] = key
|
||||
value_cache.view(-1, shape[-2], shape[-1])[slots] = value
|
||||
else:
|
||||
paged_reshape_and_cache(key, value, key_cache, value_cache, slots)
|
||||
paged_reshape_and_cache(
|
||||
key,
|
||||
value,
|
||||
key_cache,
|
||||
value_cache,
|
||||
slots,
|
||||
self.kv_cache_dtype_str,
|
||||
kv_scales.key_scale_cpu,
|
||||
kv_scales.value_scale_cpu,
|
||||
)
|
||||
|
||||
|
||||
def paged_reshape_and_cache(
|
||||
|
@ -197,7 +225,11 @@ def paged_reshape_and_cache(
|
|||
key_cache: torch.Tensor,
|
||||
value_cache: torch.Tensor,
|
||||
slots: torch.Tensor,
|
||||
kv_cache_dtype: str = "auto",
|
||||
k_scale: float = 1.0,
|
||||
v_scale: float = 1.0,
|
||||
):
|
||||
|
||||
if SYSTEM == "cuda":
|
||||
try:
|
||||
import attention_kernels
|
||||
|
@ -216,7 +248,7 @@ def paged_reshape_and_cache(
|
|||
f"Could not import vllm paged attention. Make sure your installation is correct. Complete error: {e}"
|
||||
)
|
||||
ops.reshape_and_cache(
|
||||
key, value, key_cache, value_cache, slots, "auto", 1.0, 1.0
|
||||
key, value, key_cache, value_cache, slots, kv_cache_dtype, k_scale, v_scale
|
||||
)
|
||||
elif SYSTEM == "ipex":
|
||||
import intel_extension_for_pytorch as ipex
|
||||
|
|
|
@ -119,9 +119,9 @@ def paged_attention(
|
|||
block_size,
|
||||
max_s,
|
||||
None,
|
||||
"auto",
|
||||
1.0,
|
||||
1.0,
|
||||
kv_cache.kv_cache_dtype_str,
|
||||
kv_scales.key_scale_cpu,
|
||||
kv_scales.value_scale_cpu,
|
||||
)
|
||||
else:
|
||||
# Run PagedAttention V2.
|
||||
|
@ -154,9 +154,9 @@ def paged_attention(
|
|||
block_size,
|
||||
max_s,
|
||||
None,
|
||||
"auto",
|
||||
1.0,
|
||||
1.0,
|
||||
kv_cache.kv_cache_dtype_str,
|
||||
kv_scales.key_scale_cpu,
|
||||
kv_scales.value_scale_cpu,
|
||||
)
|
||||
else:
|
||||
ops.paged_attention_rocm(
|
||||
|
@ -174,9 +174,9 @@ def paged_attention(
|
|||
block_size,
|
||||
max_s,
|
||||
None,
|
||||
"auto",
|
||||
1.0,
|
||||
1.0,
|
||||
kv_cache.kv_cache_dtype_str,
|
||||
kv_scales.key_scale_cpu,
|
||||
kv_scales.value_scale_cpu,
|
||||
None,
|
||||
_PARTITION_SIZE,
|
||||
)
|
||||
|
|
|
@ -398,10 +398,16 @@ class LlamaMLP(nn.Module):
|
|||
return self.down_proj(out, adapter_data)
|
||||
else:
|
||||
gate_up_states = self.gate_up_proj(hidden_states, adapter_data)
|
||||
gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size)
|
||||
return self.down_proj(
|
||||
self.act(gate_up_states[:, 0]) * gate_up_states[:, 1], adapter_data
|
||||
output_shape = gate_up_states.shape[:-1] + (self.intermediate_size,)
|
||||
out = torch.empty(
|
||||
output_shape, dtype=gate_up_states.dtype, device=gate_up_states.device
|
||||
)
|
||||
ops.silu_and_mul(out, gate_up_states)
|
||||
return self.down_proj(out, adapter_data)
|
||||
# gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size)
|
||||
# return self.down_proj(
|
||||
# self.act(gate_up_states[:, 0]) * gate_up_states[:, 1], adapter_data
|
||||
# )
|
||||
|
||||
|
||||
class FlashLlamaLayer(nn.Module):
|
||||
|
|
|
@ -520,28 +520,68 @@ class FlashMixtralForCausalLM(torch.nn.Module):
|
|||
lm_head_indices: Optional[torch.Tensor] = None,
|
||||
adapter_data: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
true_max_s = max_s
|
||||
if prefill_cache_indices is not None:
|
||||
# Slots also need to be sliced as it has the same size as the whole kv tensor
|
||||
slots = slots[prefill_cache_indices]
|
||||
elif self.max_past is not None:
|
||||
# Clamp in decode mode as paged attention requires clamped values whereas the flash attention
|
||||
# kernel requires the true values
|
||||
seqlen = seqlen.clamp(max=self.max_past_tensor)
|
||||
|
||||
hidden_states = self.model(
|
||||
input_ids,
|
||||
position_ids,
|
||||
cu_seqlen_prefill,
|
||||
kv_cache,
|
||||
block_tables,
|
||||
slots,
|
||||
seqlen,
|
||||
max_s,
|
||||
true_max_s,
|
||||
prefill_cache_indices,
|
||||
)
|
||||
if lm_head_indices is not None:
|
||||
hidden_states = hidden_states[lm_head_indices]
|
||||
logits = self.lm_head(hidden_states)
|
||||
if (
|
||||
torch.distributed.get_rank() == 0
|
||||
and input_ids.shape[0] == 262144
|
||||
and cu_seqlen_prefill is not None
|
||||
):
|
||||
with torch.profiler.profile(
|
||||
activities=[
|
||||
torch.profiler.ProfilerActivity.CPU,
|
||||
torch.profiler.ProfilerActivity.CUDA,
|
||||
],
|
||||
record_shapes=True,
|
||||
) as prof:
|
||||
true_max_s = max_s
|
||||
if prefill_cache_indices is not None:
|
||||
# Slots also need to be sliced as it has the same size as the whole kv tensor
|
||||
slots = slots[prefill_cache_indices]
|
||||
elif self.max_past is not None:
|
||||
# Clamp in decode mode as paged attention requires clamped values whereas the flash attention
|
||||
# kernel requires the true values
|
||||
seqlen = seqlen.clamp(max=self.max_past_tensor)
|
||||
|
||||
hidden_states = self.model(
|
||||
input_ids,
|
||||
position_ids,
|
||||
cu_seqlen_prefill,
|
||||
kv_cache,
|
||||
block_tables,
|
||||
slots,
|
||||
seqlen,
|
||||
max_s,
|
||||
true_max_s,
|
||||
prefill_cache_indices,
|
||||
)
|
||||
if lm_head_indices is not None:
|
||||
hidden_states = hidden_states[lm_head_indices]
|
||||
logits = self.lm_head(hidden_states)
|
||||
|
||||
prof.export_chrome_trace("/tgi/trace_mistral_prefill.json")
|
||||
else:
|
||||
true_max_s = max_s
|
||||
if prefill_cache_indices is not None:
|
||||
# Slots also need to be sliced as it has the same size as the whole kv tensor
|
||||
slots = slots[prefill_cache_indices]
|
||||
elif self.max_past is not None:
|
||||
# Clamp in decode mode as paged attention requires clamped values whereas the flash attention
|
||||
# kernel requires the true values
|
||||
seqlen = seqlen.clamp(max=self.max_past_tensor)
|
||||
|
||||
hidden_states = self.model(
|
||||
input_ids,
|
||||
position_ids,
|
||||
cu_seqlen_prefill,
|
||||
kv_cache,
|
||||
block_tables,
|
||||
slots,
|
||||
seqlen,
|
||||
max_s,
|
||||
true_max_s,
|
||||
prefill_cache_indices,
|
||||
)
|
||||
if lm_head_indices is not None:
|
||||
hidden_states = hidden_states[lm_head_indices]
|
||||
logits = self.lm_head(hidden_states)
|
||||
return logits
|
||||
|
|
Loading…
Reference in New Issue