Merge c07f54aac2
into 88702d8763
This commit is contained in:
commit
9b8a7efb4b
|
@ -212,6 +212,8 @@ class Fp8Linear(nn.Module):
|
|||
self.bias = bias if bias is not None else None
|
||||
|
||||
def forward(self, input: torch.Tensor) -> torch.Tensor:
|
||||
if (bsz := input.shape[0]) & 15:
|
||||
input = F.pad(input,(0, 0, 0, 16 - (bsz & 15)))
|
||||
qinput, scale = fp8_quantize(input)
|
||||
output, _ = torch._scaled_mm(
|
||||
qinput,
|
||||
|
@ -221,7 +223,7 @@ class Fp8Linear(nn.Module):
|
|||
scale_b=self.scale,
|
||||
bias=self.bias,
|
||||
)
|
||||
return output
|
||||
return output[:bsz]
|
||||
|
||||
|
||||
class Linear8bitLt(nn.Module):
|
||||
|
|
Loading…
Reference in New Issue