fix(l4): fix fp8 logic on l4 (#2277)
* fix(l4): fix fp8 logic on l4
* also quant weights with single scale
* use marlin even on 89
This commit is contained in:
parent abc32537ea
commit 5fca30ee15
@@ -32,8 +32,8 @@ def get_fp8_linear() -> torch.nn.Module:
     """
 
     if SYSTEM == "cuda":
-        major, minor = torch.cuda.get_device_capability()
-        if major == 8 and minor < 9:
+        major, _ = torch.cuda.get_device_capability()
+        if major == 8:
             from text_generation_server.layers.marlin import GPTQMarlinFP8Linear
 
             return GPTQMarlinFP8Linear
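Context for the change above: torch.cuda.get_device_capability() reports (8, 0) on A100, (8, 9) on L4/Ada and (9, 0) on H100, so the old `major == 8 and minor < 9` test excluded L4 from the Marlin path. A minimal sketch of the new dispatch, using a hypothetical helper name:

import torch

def pick_fp8_backend() -> str:
    # Hypothetical helper mirroring the new check: every SM 8.x GPU
    # (A100 at 8.0, L4 at 8.9) now takes the Marlin FP8 kernel, while
    # other capabilities fall through to the plain Fp8Linear path.
    major, _ = torch.cuda.get_device_capability()
    return "marlin" if major == 8 else "fp8_linear"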
@@ -42,8 +42,10 @@ def get_fp8_linear() -> torch.nn.Module:
     return Fp8Linear
 
 
-def fp8_quantize(weight, scale_upper_bound=None, qdtype=torch.float8_e4m3fn):
-    if FBGEMM_DYN_AVAILABLE:
+def fp8_quantize(
+    weight, scale_upper_bound=None, qdtype=torch.float8_e4m3fn, scalar=False
+):
+    if FBGEMM_DYN_AVAILABLE and not scalar:
         qweight, scale = torch.ops.fbgemm.quantize_fp8_per_row(
             weight, bs=None, scale_ub=scale_upper_bound, output_dtype=qdtype
         )
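The new scalar flag skips FBGEMM's per-row quantization and quantizes with one scale for the whole tensor. A minimal sketch of such a per-tensor ("single scale") fp8 quantization, assuming the same (qweight, inverse-scale) return convention as the function above; the helper name is made up for illustration:

import torch

def fp8_quantize_per_tensor(weight: torch.Tensor, qdtype=torch.float8_e4m3fn):
    finfo = torch.finfo(qdtype)
    # One scale for the whole tensor: dtype max over the absolute maximum.
    scale = finfo.max / weight.abs().max().clamp(min=1e-12)
    # Scale and saturate before casting, since the default fp8 cast does not saturate.
    qweight = (weight * scale).clamp(min=finfo.min, max=finfo.max).to(qdtype)
    # Return the inverse scale as float32, the form torch._scaled_mm consumes.
    return qweight, scale.float().reciprocal()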
@@ -186,6 +188,9 @@ class Fp8Linear(torch.nn.Module):
         dtype,
     ) -> None:
         super().__init__()
+        if FBGEMM_MM_AVAILABLE:
+            log_once(logger.info, "Using FBGEMM fp8 optimized kernels")
+
         self.dtype = dtype
         self.qweight = qweight
         self.scale = scale
@@ -201,7 +206,7 @@ class Fp8Linear(torch.nn.Module):
 
     @classmethod
     def from_unquant(cls, weight, bias, dtype):
-        qweight, scale = fp8_quantize(weight)
+        qweight, scale = fp8_quantize(weight, scalar=not FBGEMM_MM_AVAILABLE)
         return cls(
             qweight=qweight, scale=scale, scale_upper_bound=None, bias=bias, dtype=dtype
         )
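A usage illustration for the change above (hypothetical tensors, and assuming this class lives in text_generation_server.layers.fp8): with scalar=not FBGEMM_MM_AVAILABLE, from_unquant falls back to per-tensor quantization whenever the FBGEMM matmul kernels are unavailable, as on L4.

import torch
from text_generation_server.layers.fp8 import Fp8Linear  # assumed module path

# Hypothetical weight; the shape is only for illustration.
weight = torch.randn(4096, 4096, dtype=torch.bfloat16, device="cuda")
linear = Fp8Linear.from_unquant(weight, bias=None, dtype=torch.bfloat16)
# Without FBGEMM, linear.scale now holds a single per-tensor scale instead
# of one scale per output row.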
@@ -232,7 +237,7 @@ class Fp8Linear(torch.nn.Module):
             )
             return y.to(self.dtype)
 
-        qinput, scale = fp8_quantize(input)
+        qinput, scale = fp8_quantize(input, scalar=True)
         output, _ = torch._scaled_mm(
             qinput,
             self.qweight.t(),
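The non-FBGEMM branch above then feeds both per-tensor scales into torch._scaled_mm. A rough sketch of that call, assuming the keyword signature and (output, amax) tuple return of the PyTorch releases this diff targets (later releases take the scales positionally and return a single tensor):

import torch

def scaled_mm_per_tensor(qinput, qweight, input_scale, weight_scale, bias, out_dtype):
    # qinput/input_scale come from fp8_quantize(input, scalar=True);
    # qweight/weight_scale were produced at load time, as in the diff above.
    output, _ = torch._scaled_mm(
        qinput,
        qweight.t(),
        scale_a=input_scale,   # single scale for the activations
        scale_b=weight_scale,  # single scale for the weights
        out_dtype=out_dtype,
        bias=bias,
    )
    return output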