fix: pad fp8 inputs to a batch size divisible by 16

fp8 quantization is currently limited to tensors whose dimensions are both divisible by 16. This commit pads the batch dimension of the input up to the next multiple of 16 before quantization and slices the padded rows off the output afterwards.
Dong Shin 2024-04-13 17:27:13 +09:00 committed by GitHub
parent c38a7d7ddd
commit c07f54aac2
1 changed file with 3 additions and 1 deletion


@@ -212,6 +212,8 @@ class Fp8Linear(nn.Module):
         self.bias = bias if bias is not None else None

     def forward(self, input: torch.Tensor) -> torch.Tensor:
+        if (bsz := input.shape[0]) & 15:
+            input = F.pad(input, (0, 0, 0, 16 - (bsz & 15)))
         qinput, scale = fp8_quantize(input)
         output, _ = torch._scaled_mm(
             qinput,
@@ -221,7 +223,7 @@ class Fp8Linear(nn.Module):
             scale_b=self.scale,
             bias=self.bias,
         )
-        return output
+        return output[:bsz]
 class Linear8bitLt(nn.Module):
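
For context, a minimal standalone sketch of the same padding trick, assuming plain PyTorch; the helper name pad_rows_to_16 is hypothetical and not part of the patch:

import torch
import torch.nn.functional as F

def pad_rows_to_16(x: torch.Tensor) -> tuple[torch.Tensor, int]:
    # 16 is a power of two, so (bsz & 15) equals bsz % 16: it is non-zero
    # exactly when the batch dimension is not divisible by 16.
    bsz = x.shape[0]
    if bsz & 15:
        # F.pad pads the last dimension first; (0, 0, 0, n) leaves the
        # feature dimension untouched and appends n zero rows to dim 0.
        x = F.pad(x, (0, 0, 0, 16 - (bsz & 15)))
    return x, bsz

# Usage: pad before the fp8 matmul, then slice the result back to bsz rows,
# mirroring the `return output[:bsz]` change above.
x = torch.randn(13, 4096)
padded, bsz = pad_rows_to_16(x)
assert padded.shape == (16, 4096) and bsz == 13

Slicing the output back to bsz rows drops the zero rows contributed by the padding, so callers still see the original batch size.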