BNB 4bits.

This commit is contained in:
Nicolas Patry 2023-07-12 12:42:43 +00:00
parent a1c23f3823
commit bfa3920aec
1 changed files with 45 additions and 4 deletions

View File

@ -8,7 +8,6 @@ from typing import List
HAS_BITS_AND_BYTES = True HAS_BITS_AND_BYTES = True
try: try:
import bitsandbytes as bnb import bitsandbytes as bnb
from bitsandbytes.nn import LinearNF4
except ImportError: except ImportError:
HAS_BITS_AND_BYTES = False HAS_BITS_AND_BYTES = False
@ -71,16 +70,58 @@ class FastLinear(nn.Module):
def forward(self, input: torch.Tensor) -> torch.Tensor: def forward(self, input: torch.Tensor) -> torch.Tensor:
return F.linear(input, self.weight, self.bias) return F.linear(input, self.weight, self.bias)
class Linear4bit(nn.Module):
def __init__(
self,
weight,
bias,
):
super().__init__()
compute_dtype = None
compress_statistics = True
quant_type = "nf4"
self.weight = bnb.nn.modules.Params4bit(
weight.data,
requires_grad=False,
compress_statistics=compress_statistics,
quant_type=quant_type,
).cuda("cuda")
self.bias = bias
self.compute_dtype = compute_dtype
def forward(self, x: torch.Tensor):
# weights are cast automatically as Int8Params, but the bias has to be cast manually
if self.bias is not None and self.bias.dtype != x.dtype:
self.bias.data = self.bias.data.to(x.dtype)
if getattr(self.weight, "quant_state", None) is None:
print(
"FP4 quantization state not initialized. Please call .cuda() or .to(device) on the LinearFP4 layer first."
)
inp_dtype = x.dtype
if self.compute_dtype is not None:
x = x.to(self.compute_dtype)
bias = None if self.bias is None else self.bias.to(self.compute_dtype)
out = bnb.matmul_4bit(
x, self.weight.t(), bias=bias, quant_state=self.weight.quant_state
)
out = out.to(inp_dtype)
return out
def get_linear(weight, bias, quantize): def get_linear(weight, bias, quantize):
if quantize is None: if quantize is None:
linear = FastLinear(weight, bias) linear = FastLinear(weight, bias)
elif quantize == "bitsandbytes": elif quantize == "bitsandbytes":
linear = LinearNF4( linear = Linear4bit(
weight, weight,
bias, bias,
) )
if bias is not None:
linear.bias = nn.Parameter(bias)
elif quantize == "gptq": elif quantize == "gptq":
try: try:
qweight, qzeros, scales, g_idx, bits, groupsize = weight qweight, qzeros, scales, g_idx, bits, groupsize = weight