BNB 4bits.

Drops the bitsandbytes LinearNF4 import in favor of a local Linear4bit module that wraps NF4 Params4bit weights and dispatches through bnb.matmul_4bit; the "bitsandbytes" path in get_linear now builds Linear4bit.
commit bfa3920aec
parent a1c23f3823
@@ -8,7 +8,6 @@ from typing import List

 HAS_BITS_AND_BYTES = True
 try:
     import bitsandbytes as bnb
-    from bitsandbytes.nn import LinearNF4
 except ImportError:
     HAS_BITS_AND_BYTES = False
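HAS_BITS_AND_BYTES is the flag the rest of the file consults before taking a bitsandbytes code path. As a minimal illustrative sketch (the helper name below is not from this commit), a caller would guard the quantized path like this:

    def require_bnb(quantize):
        # Fail fast when the optional bitsandbytes dependency did not import.
        if quantize == "bitsandbytes" and not HAS_BITS_AND_BYTES:
            raise ImportError(
                "quantize='bitsandbytes' requested but the bitsandbytes package is not installed"
            )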
@@ -71,16 +70,58 @@ class FastLinear(nn.Module):
     def forward(self, input: torch.Tensor) -> torch.Tensor:
         return F.linear(input, self.weight, self.bias)

+
+class Linear4bit(nn.Module):
+    def __init__(
+        self,
+        weight,
+        bias,
+    ):
+        super().__init__()
+
+        compute_dtype = None
+        compress_statistics = True
+        quant_type = "nf4"
+        self.weight = bnb.nn.modules.Params4bit(
+            weight.data,
+            requires_grad=False,
+            compress_statistics=compress_statistics,
+            quant_type=quant_type,
+        ).cuda("cuda")
+        self.bias = bias
+        self.compute_dtype = compute_dtype
+
+    def forward(self, x: torch.Tensor):
+        # weights are cast automatically as Int8Params, but the bias has to be cast manually
+        if self.bias is not None and self.bias.dtype != x.dtype:
+            self.bias.data = self.bias.data.to(x.dtype)
+
+        if getattr(self.weight, "quant_state", None) is None:
+            print(
+                "FP4 quantization state not initialized. Please call .cuda() or .to(device) on the LinearFP4 layer first."
+            )
+        inp_dtype = x.dtype
+        if self.compute_dtype is not None:
+            x = x.to(self.compute_dtype)
+
+        bias = None if self.bias is None else self.bias.to(self.compute_dtype)
+        out = bnb.matmul_4bit(
+            x, self.weight.t(), bias=bias, quant_state=self.weight.quant_state
+        )
+
+        out = out.to(inp_dtype)
+
+        return out
+
+
 def get_linear(weight, bias, quantize):
     if quantize is None:
         linear = FastLinear(weight, bias)
     elif quantize == "bitsandbytes":
-        linear = LinearNF4(
+        linear = Linear4bit(
             weight,
             bias,
         )
         if bias is not None:
             linear.bias = nn.Parameter(bias)
     elif quantize == "gptq":
         try:
             qweight, qzeros, scales, g_idx, bits, groupsize = weight
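For context, a minimal usage sketch of the new layer (not part of the commit; the names and sizes are illustrative, and it assumes bitsandbytes with 4-bit support plus an available CUDA device):

    import torch
    import torch.nn as nn

    # Start from an ordinary fp16 linear layer whose weights we want packed to NF4.
    base = nn.Linear(4096, 4096, bias=True).half()

    # Linear4bit quantizes the weights inside __init__ via Params4bit(...).cuda(...).
    layer = Linear4bit(base.weight, nn.Parameter(base.bias.data))

    x = torch.randn(2, 4096, dtype=torch.float16, device="cuda")
    y = layer(x)  # forward dispatches to bnb.matmul_4bit against the packed weights
    print(y.shape, y.dtype)  # torch.Size([2, 4096]) torch.float16

Two design notes visible in the diff itself: quant_type="nf4" selects the NormalFloat4 data type from the QLoRA paper rather than plain FP4, and compress_statistics=True enables double quantization, i.e. the per-block quantization constants are themselves quantized to save memory. compute_dtype is left as None, so activations stay in the caller's dtype.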