add torch dtype

Mohit Sharma 2024-06-25 14:12:29 +00:00
parent f4714a8f98
commit a7909e6f94
3 changed files with 15 additions and 3 deletions

@@ -140,7 +140,7 @@ class FlashLlamaAttention(torch.nn.Module):
        if self.kv_cache_dtype == "fp8":
            self.kv_scale = weights.get_kv_cache_scaling_factor(
-                prefix, self.kv_cache_dtype
+                prefix, self.kv_cache_dtype, config.kv_cache_torch_dtype
            )
        else:
            self.kv_scale = 1.0
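
For context, `kv_scale` is the per-tensor scaling factor applied to this layer's FP8 KV cache. Below is a minimal, illustrative sketch of one common quantize/dequantize convention; it is not part of this commit, the 0.05 value and tensor shape are made up, and it assumes a PyTorch build with float8 dtypes:

import torch

kv_scale = 0.05  # example value, as would be returned by get_kv_cache_scaling_factor
key = torch.randn(16, 128, dtype=torch.float16)

# Quantize on the way into the cache, dequantize on the way out; the scale is what
# keeps the stored values inside the representable range of the FP8 format.
key_fp8 = (key / kv_scale).to(torch.float8_e4m3fn)
key_deq = key_fp8.to(torch.float16) * kv_scale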

@@ -64,6 +64,8 @@ class FlashLlama(FlashCausalLM):
        config.quantize = quantize
        config.speculator = speculator
        config.kv_cache_dtype = kv_cache_dtype
+        if not hasattr(config, "kv_cache_torch_dtype"):
+            config.kv_cache_torch_dtype = None
        torch.distributed.barrier(group=self.process_group)
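
The `hasattr` guard keeps older checkpoints working: a config that never recorded how its KV scales were generated simply gets `kv_cache_torch_dtype = None`. A short sketch of the two cases, assuming a standard `transformers` config object (the model id is a placeholder):

from transformers import AutoConfig

config = AutoConfig.from_pretrained("org/llama-fp8-checkpoint")  # placeholder id

# Checkpoints that ship FP8 KV scales are expected to declare the torch dtype the
# scales were generated for ("float8_e4m3fn" or "float8_e4m3fnuz") in config.json;
# older configs without the field simply fall back to None.
if not hasattr(config, "kv_cache_torch_dtype"):
    config.kv_cache_torch_dtype = None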

@@ -767,7 +767,9 @@ class Weights:
        except Exception:
            pass

-    def get_kv_cache_scaling_factor(self, prefix: str, kv_cache_dtype: str):
+    def get_kv_cache_scaling_factor(
+        self, prefix: str, kv_cache_dtype: str, kv_cache_torch_dtype: str
+    ):
        try:
            kv_scale = self.get_tensor(f"{prefix}.kv_scale").cpu().tolist()
        except RuntimeError:
@@ -791,13 +793,21 @@ class Weights:
                "Only support per-tensor scaling factor for `fp8 (fp8_e4m3)` KV cache"
            )

+        if kv_cache_torch_dtype not in {"float8_e4m3fn", "float8_e4m3fnuz"}:
+            raise RuntimeError(
+                f"Found `kv_scale` in the checkpoint, the config must specify the `kv_cache_torch_dtype` "
+                f"used for generating kv scales. Expected 'float8_e4m3fn' or 'float8_e4m3fnuz', but got '{kv_cache_torch_dtype}'."
+            )
+
        # ROCm uses the FP8 format fp8_e4m3fnuz, whereas Nvidia GPUs use fp8_e4m3fn (fp8_e4m3).
        # The multiplication by 2 compensates for the different numeric representation
        # between ROCm and Nvidia GPUs, ensuring consistent effective scaling across platforms.
        # After this adjustment, the overall effect is equivalent to the scaling applied without
        # it on Nvidia GPUs.
-        if SYSTEM == "rocm":
+        if SYSTEM == "rocm" and kv_cache_torch_dtype == "float8_e4m3fn":
            kv_scale *= 2.0
+        elif SYSTEM == "cuda" and kv_cache_torch_dtype == "float8_e4m3fnuz":
+            kv_scale /= 2.0

        return kv_scale
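
The dtype check and the extra `elif` make the platform adjustment symmetric: `float8_e4m3fnuz` (the FP8 variant used on ROCm) can represent only about half the magnitude of `float8_e4m3fn` (max ≈ 240 vs ≈ 448), so a scale calibrated for one format is doubled or halved when the cache is stored in the other. A standalone sketch of the resulting behaviour, with a worked example; the helper name and `system` argument are illustrative, not part of the codebase:

def adjust_kv_scale(kv_scale: float, kv_cache_torch_dtype: str, system: str) -> float:
    """Convert a checkpoint's kv_scale to the FP8 format used at runtime."""
    if kv_cache_torch_dtype not in {"float8_e4m3fn", "float8_e4m3fnuz"}:
        raise RuntimeError(f"Unexpected kv_cache_torch_dtype: {kv_cache_torch_dtype}")
    if system == "rocm" and kv_cache_torch_dtype == "float8_e4m3fn":
        return kv_scale * 2.0  # calibrated for e4m3fn, cache stored as e4m3fnuz
    if system == "cuda" and kv_cache_torch_dtype == "float8_e4m3fnuz":
        return kv_scale / 2.0  # calibrated for e4m3fnuz, cache stored as e4m3fn
    return kv_scale  # formats already match

# A checkpoint calibrated on NVIDIA hardware with kv_scale = 0.05 keeps 0.05 on
# CUDA and becomes 0.10 when served on ROCm.
print(adjust_kv_scale(0.05, "float8_e4m3fn", "cuda"))  # 0.05
print(adjust_kv_scale(0.05, "float8_e4m3fn", "rocm"))  # 0.1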