fix(l4): fix fp8 logic on l4 (#2277)

* fix(l4): fix fp8 logic on l4

* also quant weights with single scale

* use marlin even on 89
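
The "single scale" mentioned above refers to per-tensor FP8 quantization: one scale for the whole weight tensor, as opposed to FBGEMM's per-row scales. The snippet below is a minimal sketch of that scheme, assuming PyTorch 2.1+ with float8 support; the function name fp8_quantize_scalar and the epsilon clamp are illustrative and not part of this commit.

    import torch

    def fp8_quantize_scalar(weight: torch.Tensor, qdtype=torch.float8_e4m3fn):
        # Per-tensor ("single scale") quantization: one scale derived from the
        # tensor's absolute maximum relative to the representable fp8 range.
        finfo = torch.finfo(qdtype)
        scale = finfo.max / weight.abs().max().clamp(min=1e-12)
        # Scale into range, clamp, then cast to the fp8 dtype.
        qweight = (weight * scale).clamp(min=finfo.min, max=finfo.max).to(qdtype)
        # Return the dequantization scale (reciprocal), which is applied to the
        # matmul output to recover the original magnitude.
        return qweight, scale.float().reciprocal()

Per-row scales generally retain more precision, which is why the FBGEMM path stays preferred when its kernels are available; the scalar path exists so that torch._scaled_mm, which at the time of this commit takes a single scale per operand, can be used everywhere else.
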
OlivierDehaene 2024-07-23 09:24:29 +00:00 committed by GitHub
parent abc32537ea
commit 5fca30ee15
1 changed file with 11 additions and 6 deletions
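
The first hunk below changes how the FP8 linear implementation is selected from the device's compute capability. As an illustration (not code from this commit, and assuming a CUDA-capable PyTorch build): the L4, like other Ada GPUs, reports capability (8, 9), so dropping the "minor < 9" condition routes it to the Marlin FP8 kernel as well.

    import torch

    # Illustration only: an NVIDIA L4 (Ada) reports compute capability (8, 9).
    # The old check (major == 8 and minor < 9) excluded it; the new check
    # (major == 8) sends every SM 8.x device to the Marlin FP8 kernel.
    major, minor = torch.cuda.get_device_capability()
    use_marlin_fp8 = major == 8
    print(f"sm_{major}{minor}: use Marlin FP8 = {use_marlin_fp8}")
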


@@ -32,8 +32,8 @@ def get_fp8_linear() -> torch.nn.Module:
     """
 
     if SYSTEM == "cuda":
-        major, minor = torch.cuda.get_device_capability()
-        if major == 8 and minor < 9:
+        major, _ = torch.cuda.get_device_capability()
+        if major == 8:
             from text_generation_server.layers.marlin import GPTQMarlinFP8Linear
 
             return GPTQMarlinFP8Linear
@@ -42,8 +42,10 @@ def get_fp8_linear() -> torch.nn.Module:
     return Fp8Linear
 
 
-def fp8_quantize(weight, scale_upper_bound=None, qdtype=torch.float8_e4m3fn):
-    if FBGEMM_DYN_AVAILABLE:
+def fp8_quantize(
+    weight, scale_upper_bound=None, qdtype=torch.float8_e4m3fn, scalar=False
+):
+    if FBGEMM_DYN_AVAILABLE and not scalar:
         qweight, scale = torch.ops.fbgemm.quantize_fp8_per_row(
             weight, bs=None, scale_ub=scale_upper_bound, output_dtype=qdtype
         )
@@ -186,6 +188,9 @@ class Fp8Linear(torch.nn.Module):
         dtype,
     ) -> None:
         super().__init__()
+        if FBGEMM_MM_AVAILABLE:
+            log_once(logger.info, "Using FBGEMM fp8 optimized kernels")
+
         self.dtype = dtype
         self.qweight = qweight
         self.scale = scale
@@ -201,7 +206,7 @@ class Fp8Linear(torch.nn.Module):
 
     @classmethod
     def from_unquant(cls, weight, bias, dtype):
-        qweight, scale = fp8_quantize(weight)
+        qweight, scale = fp8_quantize(weight, scalar=not FBGEMM_MM_AVAILABLE)
         return cls(
             qweight=qweight, scale=scale, scale_upper_bound=None, bias=bias, dtype=dtype
         )
@@ -232,7 +237,7 @@ class Fp8Linear(torch.nn.Module):
             )
             return y.to(self.dtype)
 
-        qinput, scale = fp8_quantize(input)
+        qinput, scale = fp8_quantize(input, scalar=True)
         output, _ = torch._scaled_mm(
             qinput,
             self.qweight.t(),