diff --git a/docs/source/basic_tutorials/launcher.md b/docs/source/basic_tutorials/launcher.md
index 86394ff7..40ee55d7 100644
--- a/docs/source/basic_tutorials/launcher.md
+++ b/docs/source/basic_tutorials/launcher.md
@@ -66,6 +66,7 @@ Options:
       - bitsandbytes: Bitsandbytes 8bit. Can be applied on any model, will cut the memory requirement in half, but it is known that the model will be much slower to run than the native f16
       - bitsandbytes-nf4: Bitsandbytes 4bit. Can be applied on any model, will cut the memory requirement by 4x, but it is known that the model will be much slower to run than the native f16
       - bitsandbytes-fp4: Bitsandbytes 4bit. nf4 should be preferred in most cases but maybe this one has better perplexity performance for you model
+      - fp8: [FP8](https://developer.nvidia.com/blog/nvidia-arm-and-intel-publish-fp8-specification-for-standardization-as-an-interchange-format-for-ai/) (e4m3) works on H100 and above. This dtype has native ops and should be the fastest if available. It is currently not the fastest because of local unpacking + padding to satisfy matrix multiplication limitations

 ```

 ## SPECULATE
diff --git a/launcher/src/main.rs b/launcher/src/main.rs
index 836b0381..cf876fbd 100644
--- a/launcher/src/main.rs
+++ b/launcher/src/main.rs
@@ -47,6 +47,11 @@ enum Quantization {
     /// Bitsandbytes 4bit. nf4 should be preferred in most cases but maybe this one has better
     /// perplexity performance for you model
     BitsandbytesFP4,
+    /// [FP8](https://developer.nvidia.com/blog/nvidia-arm-and-intel-publish-fp8-specification-for-standardization-as-an-interchange-format-for-ai/) (e4m3) works on H100 and above.
+    /// This dtype has native ops and should be the fastest if available.
+    /// It is currently not the fastest because of local unpacking + padding to satisfy matrix
+    /// multiplication limitations.
+    Fp8,
 }

 impl std::fmt::Display for Quantization {
@@ -73,6 +78,9 @@ impl std::fmt::Display for Quantization {
             Quantization::Eetq => {
                 write!(f, "eetq")
             }
+            Quantization::Fp8 => {
+                write!(f, "fp8")
+            }
         }
     }
 }
diff --git a/server/text_generation_server/cli.py b/server/text_generation_server/cli.py
index e8b126d9..bb0963d4 100644
--- a/server/text_generation_server/cli.py
+++ b/server/text_generation_server/cli.py
@@ -19,6 +19,7 @@ class Quantization(str, Enum):
     gptq = "gptq"
     awq = "awq"
     eetq = "eetq"
+    fp8 = "fp8"


 class Dtype(str, Enum):
diff --git a/server/text_generation_server/utils/layers.py b/server/text_generation_server/utils/layers.py
index f29e55c5..2b95bc74 100644
--- a/server/text_generation_server/utils/layers.py
+++ b/server/text_generation_server/utils/layers.py
@@ -182,6 +182,48 @@ class EETQLinear(nn.Module):
         return output


+def fp8_quantize(weight, qdtype=torch.float8_e4m3fn):
+    device = weight.device
+    # weight, scale = quant_weights(weight, torch.int8, False)
+    finfo = torch.finfo(qdtype)
+    # Calculate the scale as dtype max divided by absmax
+    scale = finfo.max / weight.abs().max().clamp(min=1e-12)
+    # Scale and clamp the tensor to bring it into
+    # the representable range of the float8 data type
+    # (as the default cast is unsaturated)
+    qweight = (weight * scale).clamp(min=finfo.min, max=finfo.max)
+    # Return both the float8 data and the inverse scale (as float),
+    # as both are required as inputs to torch._scaled_mm
+    qweight = qweight.to(qdtype)
+    scale = scale.float().reciprocal()
+    return qweight, scale
+
+
+class Fp8Linear(nn.Module):
+    def __init__(
+        self,
+        weight,
+        bias,
+    ) -> None:
+        super().__init__()
+        self.dtype = weight.dtype
+        self.qweight, self.scale = fp8_quantize(weight)
+
+        self.bias = bias if bias is not None else None
+
+    def forward(self, input: torch.Tensor) -> torch.Tensor:
+        qinput, scale = fp8_quantize(input)
+        output, _ = torch._scaled_mm(
+            qinput,
+            self.qweight.t(),
+            out_dtype=self.dtype,
+            scale_a=scale,
+            scale_b=self.scale,
+            bias=self.bias,
+        )
+        return output
+
+
 class Linear8bitLt(nn.Module):
     def __init__(
         self,
@@ -293,6 +335,8 @@ def get_linear(weight, bias, quantize):
             raise ImportError(
                 "Please install EETQ from https://github.com/NetEase-FuXi/EETQ"
             )
+    elif quantize == "fp8":
+        linear = Fp8Linear(weight, bias)
     elif quantize == "bitsandbytes":
         warn_deprecate_bnb()
         linear = Linear8bitLt(
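
For reference, below is a minimal, self-contained sketch of the same per-tensor FP8 scheme used by `fp8_quantize` and `Fp8Linear` above, written as a standalone script rather than server code. It assumes a PyTorch build with float8 support (roughly 2.1 or newer) running on an H100-class GPU; note that `torch._scaled_mm` is a private API, and older releases return an `(output, amax)` tuple (which is what the PR's `Fp8Linear.forward` unpacks), while newer ones return a single tensor.

```python
import torch


def fp8_quantize(t: torch.Tensor, qdtype=torch.float8_e4m3fn):
    finfo = torch.finfo(qdtype)
    # Per-tensor scale: dtype max over absmax, clamped to avoid division by zero.
    scale = finfo.max / t.abs().max().clamp(min=1e-12)
    # Saturate before casting, since the default float8 cast is unsaturated.
    q = (t * scale).clamp(min=finfo.min, max=finfo.max).to(qdtype)
    # Return the inverse scale, which is what torch._scaled_mm consumes.
    return q, scale.float().reciprocal()


if __name__ == "__main__":
    device = "cuda"
    # _scaled_mm requires matrix dimensions that are multiples of 16,
    # hence the "padding to satisfy matrix multiplication limitations" above.
    x = torch.randn(16, 64, dtype=torch.float16, device=device)
    w = torch.randn(128, 64, dtype=torch.float16, device=device)  # (out_features, in_features)

    qx, x_inv_scale = fp8_quantize(x)
    qw, w_inv_scale = fp8_quantize(w)

    # The second operand must be column-major, which .t() of a contiguous
    # row-major tensor provides, mirroring Fp8Linear.forward in the diff.
    out = torch._scaled_mm(
        qx, qw.t(), scale_a=x_inv_scale, scale_b=w_inv_scale, out_dtype=torch.float16
    )
    # Older PyTorch releases return (output, amax) instead of a single tensor.
    if isinstance(out, tuple):
        out, _ = out

    ref = x @ w.t()
    print((out.float() - ref.float()).abs().max())
```

Returning the reciprocal of the scale keeps the hot path cheap: `torch._scaled_mm` multiplies by `scale_a * scale_b` to dequantize the accumulator, so the division happens once at quantization time instead of inside the matmul.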