Merge pull request #4 from michaelfeil/bnb_4bit
4bit quantization with bitsandbytes
This commit is contained in:
commit
f93012d59c
|
@ -34,7 +34,4 @@ members/contributors who may be interested in your PR.
|
|||
|
||||
<!-- Your PR will be replied to more quickly if you can figure out the right person to tag with @
|
||||
|
||||
|
||||
@OlivierDehaene OR @Narsil
|
||||
|
||||
-->
|
||||
|
|
|
@ -2,6 +2,7 @@
|
|||
|
||||
This is Preemo's fork of `text-generation-inference`, originally developed by Hugging Face. The original README is at [README-HuggingFace.md](README-HuggingFace.md). Since Hugging Face's `text-generation-inference` is no longer open-source, we have forked it and will continue to develop it here.
|
||||
|
||||
|
||||
Our goal is to create an open-source text generation inference server that is modularized to allow for easy add state-of-the-art models, functionalities and optimizations. Functionalities and optimizations should be composable, so that users can easily combine them to create a custom inference server that fits their needs.
|
||||
|
||||
## our plan
|
||||
|
@ -13,3 +14,9 @@ We at Preemo are currently busy working on our first release of our other produc
|
|||
- [ ] Modularizing the codebase, introducing a plugin system
|
||||
|
||||
Our long-term goal is to grow the community around this repository, as a playground for trying out new ideas and optimizations in LLM inference. We at Preemo will implement features that interest us, but we also welcome contributions from the community, as long as they are modularized and composable.
|
||||
|
||||
## Extra features in comparison to Hugging Face `text-generation-inference` v0.9.4
|
||||
|
||||
### 4bit quantization
|
||||
|
||||
4bit quantization is available using the [NF4 and FP4 data types from bitsandbytes](https://arxiv.org/pdf/2305.14314.pdf). It can be enabled by providing `--quantize bitsandbytes-nf4` or `--quantize bitsandbytes-fp4` as a command line argument to `text-generation-launcher`.
|
||||
|
|
|
@ -22,6 +22,8 @@ mod env_runtime;
|
|||
#[derive(Clone, Copy, Debug, ValueEnum)]
|
||||
enum Quantization {
|
||||
Bitsandbytes,
|
||||
BitsandbytesNF4,
|
||||
BitsandbytesFP4,
|
||||
Gptq,
|
||||
}
|
||||
|
||||
|
@ -32,6 +34,12 @@ impl std::fmt::Display for Quantization {
|
|||
Quantization::Bitsandbytes => {
|
||||
write!(f, "bitsandbytes")
|
||||
}
|
||||
Quantization::BitsandbytesNF4 => {
|
||||
write!(f, "bitsandbytes-nf4")
|
||||
}
|
||||
Quantization::BitsandbytesFP4 => {
|
||||
write!(f, "bitsandbytes-fp4")
|
||||
}
|
||||
Quantization::Gptq => {
|
||||
write!(f, "gptq")
|
||||
}
|
||||
|
@ -96,7 +104,8 @@ struct Args {
|
|||
num_shard: Option<usize>,
|
||||
|
||||
/// Whether you want the model to be quantized. This will use `bitsandbytes` for
|
||||
/// quantization on the fly, or `gptq`.
|
||||
/// quantization on the fly, or `gptq`. 4bit quantization is available through
|
||||
/// `bitsandbytes` by providing the `bitsandbytes-fp4` or `bitsandbytes-nf4` options.
|
||||
#[clap(long, env, value_enum)]
|
||||
quantize: Option<Quantization>,
|
||||
|
||||
|
|
|
@ -8,7 +8,7 @@ authors = ["Olivier Dehaene <olivier@huggingface.co>"]
|
|||
text-generation-server = 'text_generation_server.cli:app'
|
||||
|
||||
[tool.poetry.dependencies]
|
||||
python = "^3.9"
|
||||
python = ">=3.9,<3.13"
|
||||
protobuf = "^4.21.7"
|
||||
grpcio = "^1.51.1"
|
||||
grpcio-status = "^1.51.1"
|
||||
|
@ -16,7 +16,7 @@ grpcio-reflection = "^1.51.1"
|
|||
grpc-interceptor = "^0.15.0"
|
||||
typer = "^0.6.1"
|
||||
accelerate = { version = "^0.19.0", optional = true }
|
||||
bitsandbytes = { version = "^0.38.1", optional = true }
|
||||
bitsandbytes = { version = "^0.40.0", optional = true }
|
||||
safetensors = "0.3.1"
|
||||
loguru = "^0.6.0"
|
||||
opentelemetry-api = "^1.15.0"
|
||||
|
|
|
@ -13,6 +13,8 @@ app = typer.Typer()
|
|||
|
||||
class Quantization(str, Enum):
|
||||
bitsandbytes = "bitsandbytes"
|
||||
bitsandbytes_nf4 = "bitsandbytes-nf4"
|
||||
bitsandbytes_fp4 = "bitsandbytes-fp4"
|
||||
gptq = "gptq"
|
||||
|
||||
|
||||
|
|
|
@ -257,6 +257,11 @@ def get_model(
|
|||
"gptq quantization is not supported for AutoModel, you can try to quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`"
|
||||
)
|
||||
|
||||
elif (quantize == "bitsandbytes-fp4") or (quantize == "bitsandbytes-nf4"):
|
||||
raise ValueError(
|
||||
"4bit quantization is not supported for AutoModel"
|
||||
)
|
||||
|
||||
if model_type in modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
|
||||
return AutoCausalLM(
|
||||
model_id,
|
||||
|
|
|
@ -9,7 +9,7 @@ from typing import List
|
|||
HAS_BITS_AND_BYTES = True
|
||||
try:
|
||||
import bitsandbytes as bnb
|
||||
from bitsandbytes.nn import Int8Params
|
||||
from bitsandbytes.nn import Int8Params, Params4bit
|
||||
|
||||
except ImportError:
|
||||
HAS_BITS_AND_BYTES = False
|
||||
|
@ -140,6 +140,39 @@ class Linear8bitLt(nn.Module):
|
|||
return out
|
||||
|
||||
|
||||
class Linear4bit(nn.Module):
|
||||
def __init__(self, weight, bias, quant_type):
|
||||
super().__init__()
|
||||
self.weight = Params4bit(
|
||||
weight.data, requires_grad=False, compress_statistics=True, quant_type=quant_type
|
||||
)
|
||||
self.compute_dtype = None
|
||||
self.weight.cuda(weight.device)
|
||||
self.bias = bias
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
# weights are cast automatically as Int8Params, but the bias has to be cast manually
|
||||
if self.bias is not None and self.bias.dtype != x.dtype:
|
||||
self.bias.data = self.bias.data.to(x.dtype)
|
||||
|
||||
if getattr(self.weight, "quant_state", None) is None:
|
||||
print(
|
||||
"FP4 quantization state not initialized. Please call .cuda() or .to(device) on the LinearFP4 layer first."
|
||||
)
|
||||
inp_dtype = x.dtype
|
||||
if self.compute_dtype is not None:
|
||||
x = x.to(self.compute_dtype)
|
||||
|
||||
bias = None if self.bias is None else self.bias.to(self.compute_dtype)
|
||||
out = bnb.matmul_4bit(
|
||||
x, self.weight.t(), bias=bias, quant_state=self.weight.quant_state
|
||||
)
|
||||
|
||||
out = out.to(inp_dtype)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
def get_linear(weight, bias, quantize):
|
||||
if quantize is None:
|
||||
linear = FastLinear(weight, bias)
|
||||
|
@ -152,6 +185,18 @@ def get_linear(weight, bias, quantize):
|
|||
)
|
||||
if bias is not None:
|
||||
linear.bias = nn.Parameter(bias)
|
||||
elif quantize == "bitsandbytes-fp4":
|
||||
linear = Linear4bit(
|
||||
weight,
|
||||
bias,
|
||||
quant_type="fp4",
|
||||
)
|
||||
elif quantize == "bitsandbytes-nf4":
|
||||
linear = Linear4bit(
|
||||
weight,
|
||||
bias,
|
||||
quant_type="nf4",
|
||||
)
|
||||
elif quantize == "gptq":
|
||||
try:
|
||||
qweight, qzeros, scales, g_idx, bits, groupsize, use_exllama = weight
|
||||
|
@ -219,31 +264,36 @@ class TensorParallelHead(SuperLayer):
|
|||
)
|
||||
|
||||
def forward(self, input: torch.Tensor) -> torch.Tensor:
|
||||
if not self.should_gather:
|
||||
return super().forward(input)
|
||||
|
||||
world_size = self.process_group.size()
|
||||
# Fast branch for single requests
|
||||
if (
|
||||
self.should_gather
|
||||
and len(input.shape) == 2
|
||||
and isinstance(self.linear, FastLinear)
|
||||
and input.shape[0] == 1
|
||||
):
|
||||
if len(input.shape) == 2 and isinstance(self.linear, FastLinear):
|
||||
out_dim = self.linear.weight.shape[0]
|
||||
|
||||
if input.shape[0] == 1:
|
||||
world_out = input.new_empty(1, out_dim * world_size)
|
||||
local_out = input.new_empty(1, out_dim)
|
||||
gather_input = local_out
|
||||
else:
|
||||
world_out = input.new_empty(out_dim * world_size, input.shape[0])
|
||||
gather_input = input.new_empty(out_dim, input.shape[0])
|
||||
local_out = gather_input.T
|
||||
|
||||
torch.mm(input, self.linear.weight.T, out=local_out)
|
||||
|
||||
torch.distributed.all_gather_into_tensor(
|
||||
world_out, local_out, group=self.process_group
|
||||
world_out, gather_input, group=self.process_group
|
||||
)
|
||||
|
||||
if input.shape[0] == 1:
|
||||
return world_out
|
||||
return world_out.T
|
||||
|
||||
output = super().forward(input)
|
||||
if not self.should_gather:
|
||||
return output
|
||||
|
||||
world_output = [torch.empty_like(output) for _ in range(world_size)]
|
||||
world_output = [
|
||||
torch.empty_like(output) for _ in range(self.process_group.size())
|
||||
]
|
||||
torch.distributed.all_gather(world_output, output, group=self.process_group)
|
||||
world_output = torch.cat(world_output, dim=-1)
|
||||
return world_output
|
||||
|
|
Loading…
Reference in New Issue