add torch dtype
commit a7909e6f94
parent f4714a8f98
@@ -140,7 +140,7 @@ class FlashLlamaAttention(torch.nn.Module):
         if self.kv_cache_dtype == "fp8":
             self.kv_scale = weights.get_kv_cache_scaling_factor(
-                prefix, self.kv_cache_dtype
+                prefix, self.kv_cache_dtype, config.kv_cache_torch_dtype
             )
         else:
             self.kv_scale = 1.0
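Note: the new third argument threads the checkpoint's calibration dtype from the model config into the weights loader. A hedged sketch of the values involved (attribute names are from the diff; the example value is an assumption):

    # Assumed example: an FP8 KV-cache checkpoint whose scales were calibrated
    # with the Nvidia fp8 format.
    config.kv_cache_dtype = "fp8"
    config.kv_cache_torch_dtype = "float8_e4m3fn"
    # The attention layer then forwards both to the weights loader:
    #   weights.get_kv_cache_scaling_factor(prefix, "fp8", "float8_e4m3fn")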
@@ -64,6 +64,8 @@ class FlashLlama(FlashCausalLM):
         config.quantize = quantize
         config.speculator = speculator
         config.kv_cache_dtype = kv_cache_dtype
+        if not hasattr(config, "kv_cache_torch_dtype"):
+            config.kv_cache_torch_dtype = None
 
         torch.distributed.barrier(group=self.process_group)
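Note: the `hasattr` guard keeps older configs, which predate `kv_cache_torch_dtype`, loadable. A minimal runnable sketch, using `types.SimpleNamespace` as a stand-in for the real config object (which this diff does not show):

    from types import SimpleNamespace

    # Stand-in for a config deserialized from an older checkpoint.
    config = SimpleNamespace(quantize=None, speculator=None, kv_cache_dtype="fp8")
    if not hasattr(config, "kv_cache_torch_dtype"):
        config.kv_cache_torch_dtype = None  # older configs default to "unspecified"
    assert config.kv_cache_torch_dtype is None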
@@ -767,7 +767,9 @@ class Weights:
         except Exception:
             pass
 
-    def get_kv_cache_scaling_factor(self, prefix: str, kv_cache_dtype: str):
+    def get_kv_cache_scaling_factor(
+        self, prefix: str, kv_cache_dtype: str, kv_cache_torch_dtype: str
+    ):
         try:
             kv_scale = self.get_tensor(f"{prefix}.kv_scale").cpu().tolist()
         except RuntimeError:
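Note: callers must now pass the calibration dtype explicitly. A hypothetical call with the new signature (the `weights` instance and the layer prefix are illustrative placeholders, not taken from this diff):

    kv_scale = weights.get_kv_cache_scaling_factor(
        "model.layers.0.self_attn", "fp8", "float8_e4m3fn"
    )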
@@ -791,13 +793,21 @@ class Weights:
                 "Only support per-tensor scaling factor for `fp8 (fp8_e4m3)` KV cache"
             )
 
+        if kv_cache_torch_dtype not in {"float8_e4m3fn", "float8_e4m3fnuz"}:
+            raise RuntimeError(
+                f"Found `kv_scale` in the checkpoint, so the config must specify the `kv_cache_torch_dtype` "
+                f"used to generate the kv scales. Expected 'float8_e4m3fn' or 'float8_e4m3fnuz', but got '{kv_cache_torch_dtype}'."
+            )
+
         # ROCm uses the FP8 format fp8_e4m3fnuz, whereas Nvidia GPUs use fp8_e4m3fn.
         # The multiplication by 2 compensates for the different numeric representation
         # between ROCm and Nvidia GPUs, ensuring consistent effective scaling across platforms.
         # After this adjustment, the overall effect is equivalent to the scaling applied without
         # it on Nvidia GPUs.
-        if SYSTEM == "rocm":
+        if SYSTEM == "rocm" and kv_cache_torch_dtype == "float8_e4m3fn":
             kv_scale *= 2.0
+        elif SYSTEM == "cuda" and kv_cache_torch_dtype == "float8_e4m3fnuz":
+            kv_scale /= 2.0
 
         return kv_scale
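Note: the factor of 2 comes from the exponent-bias difference between the two fp8 variants: float8_e4m3fnuz uses bias 8 where float8_e4m3fn uses bias 7, so the same bit pattern decodes to half the value in fnuz. A self-contained sketch of the adjustment, decoupled from the `Weights` class (the function name is hypothetical; the branch logic mirrors the hunk above):

    def adjust_kv_scale(kv_scale: float, system: str, kv_cache_torch_dtype: str) -> float:
        # Rescale a checkpoint kv scale for the platform's native fp8 format.
        if kv_cache_torch_dtype not in {"float8_e4m3fn", "float8_e4m3fnuz"}:
            raise RuntimeError(f"Unsupported kv_cache_torch_dtype: {kv_cache_torch_dtype!r}")
        if system == "rocm" and kv_cache_torch_dtype == "float8_e4m3fn":
            # Scale calibrated for fn, running on fnuz hardware: double it.
            kv_scale *= 2.0
        elif system == "cuda" and kv_cache_torch_dtype == "float8_e4m3fnuz":
            # Scale calibrated for fnuz, running on fn hardware: halve it.
            kv_scale /= 2.0
        return kv_scale

    # The same fn-calibrated checkpoint yields consistent effective scaling:
    assert adjust_kv_scale(0.5, "rocm", "float8_e4m3fn") == 1.0
    assert adjust_kv_scale(0.5, "cuda", "float8_e4m3fn") == 0.5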