Fp8 support.

Nicolas Patry 2024-04-11 11:09:13 +00:00
parent c31cb32dd6
commit b24bdb9f8c
2 changed files with 0 additions and 6 deletions


@@ -47,7 +47,6 @@ enum Quantization {
/// Bitsandbytes 4bit. nf4 should be preferred in most cases but maybe this one has better
/// perplexity performance for your model
BitsandbytesFP4,
/// [BETA]
/// [FP8](https://developer.nvidia.com/blog/nvidia-arm-and-intel-publish-fp8-specification-for-standardization-as-an-interchange-format-for-ai/) (e4m3) works on H100 and above
/// This dtype has native ops and should be the fastest if available.
/// This is currently not the fastest because of local unpacking + padding to satisfy matrix
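For reference, PyTorch exposes this format as torch.float8_e4m3fn. A quick check of its numeric range (a minimal sketch, assuming a PyTorch build with float8 support; this snippet is not part of the diff) shows why tensors have to be rescaled before being cast to this dtype:

import torch

# e4m3 has 4 exponent bits and 3 mantissa bits; the largest finite value is 448,
# so values are scaled into [-448, 448] before quantization.
finfo = torch.finfo(torch.float8_e4m3fn)
print(finfo.max, finfo.min, finfo.eps)  # 448.0 -448.0 0.125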


@@ -210,13 +210,8 @@ class Fp8Linear(nn.Module):
    def forward(self, input: torch.Tensor) -> torch.Tensor:
        qinput, scale = fp8_quantize(input)
        seqlen = qinput.shape[0]
        if seqlen % 16 != 0:
            missing = 16 - seqlen % 16
            qinput = F.pad(qinput, (0, 0, 0, missing), "constant", value=0)
        output, _ = torch._scaled_mm(qinput, self.qweight.t(), out_dtype=self.dtype,
                                     scale_a=scale, scale_b=self.scale, bias=self.bias)
        output = output[:seqlen]
        return output
class Linear8bitLt(nn.Module):
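The padding above exists because torch._scaled_mm requires matrix dimensions that are multiples of 16 on H100, so the input is padded up to the next multiple and the extra rows are sliced off again after the matmul; this is the "local unpacking + padding" cost mentioned in the launcher doc comment. The fp8_quantize helper is not part of this hunk; the sketch below is one plausible per-tensor implementation, written as an assumption for illustration rather than the repository's actual code:

import torch

def fp8_quantize(weight: torch.Tensor):
    # Hypothetical sketch: scale the tensor so its largest magnitude maps to
    # the largest finite e4m3 value, cast to float8, and return the inverse
    # scale that torch._scaled_mm later applies to undo the scaling.
    finfo = torch.finfo(torch.float8_e4m3fn)
    scale = finfo.max / weight.abs().max().clamp(min=1e-12)
    qweight = (weight * scale).clamp(min=finfo.min, max=finfo.max)
    qweight = qweight.to(torch.float8_e4m3fn)
    scale = scale.float().reciprocal()
    return qweight, scale

With the weights quantized the same way at load time, the matmul runs entirely in float8 and the two inverse scales restore the result to out_dtype magnitudes.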