Merge pull request #4 from michaelfeil/bnb_4bit

4bit quantization with bitsandbytes
This commit is contained in:
Yang, Bo 2023-09-08 14:52:32 -07:00 committed by GitHub
commit f93012d59c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 92 additions and 22 deletions

View File

@ -34,7 +34,4 @@ members/contributors who may be interested in your PR.
<!-- Your PR will be replied to more quickly if you can figure out the right person to tag with @
@OlivierDehaene OR @Narsil
-->

View File

@ -2,6 +2,7 @@
This is Preemo's fork of `text-generation-inference`, originally developed by Hugging Face. The original README is at [README-HuggingFace.md](README-HuggingFace.md). Since Hugging Face's `text-generation-inference` is no longer open-source, we have forked it and will continue to develop it here.
Our goal is to create an open-source text generation inference server that is modularized to allow for easy add state-of-the-art models, functionalities and optimizations. Functionalities and optimizations should be composable, so that users can easily combine them to create a custom inference server that fits their needs.
## our plan
@ -13,3 +14,9 @@ We at Preemo are currently busy working on our first release of our other produc
- [ ] Modularizing the codebase, introducing a plugin system
Our long-term goal is to grow the community around this repository, as a playground for trying out new ideas and optimizations in LLM inference. We at Preemo will implement features that interest us, but we also welcome contributions from the community, as long as they are modularized and composable.
## Extra features in comparison to Hugging Face `text-generation-inference` v0.9.4
### 4bit quantization
4bit quantization is available using the [NF4 and FP4 data types from bitsandbytes](https://arxiv.org/pdf/2305.14314.pdf). It can be enabled by providing `--quantize bitsandbytes-nf4` or `--quantize bitsandbytes-fp4` as a command line argument to `text-generation-launcher`.

View File

@ -22,6 +22,8 @@ mod env_runtime;
#[derive(Clone, Copy, Debug, ValueEnum)]
enum Quantization {
Bitsandbytes,
BitsandbytesNF4,
BitsandbytesFP4,
Gptq,
}
@ -32,6 +34,12 @@ impl std::fmt::Display for Quantization {
Quantization::Bitsandbytes => {
write!(f, "bitsandbytes")
}
Quantization::BitsandbytesNF4 => {
write!(f, "bitsandbytes-nf4")
}
Quantization::BitsandbytesFP4 => {
write!(f, "bitsandbytes-fp4")
}
Quantization::Gptq => {
write!(f, "gptq")
}
@ -96,7 +104,8 @@ struct Args {
num_shard: Option<usize>,
/// Whether you want the model to be quantized. This will use `bitsandbytes` for
/// quantization on the fly, or `gptq`.
/// quantization on the fly, or `gptq`. 4bit quantization is available through
/// `bitsandbytes` by providing the `bitsandbytes-fp4` or `bitsandbytes-nf4` options.
#[clap(long, env, value_enum)]
quantize: Option<Quantization>,

View File

@ -8,7 +8,7 @@ authors = ["Olivier Dehaene <olivier@huggingface.co>"]
text-generation-server = 'text_generation_server.cli:app'
[tool.poetry.dependencies]
python = "^3.9"
python = ">=3.9,<3.13"
protobuf = "^4.21.7"
grpcio = "^1.51.1"
grpcio-status = "^1.51.1"
@ -16,7 +16,7 @@ grpcio-reflection = "^1.51.1"
grpc-interceptor = "^0.15.0"
typer = "^0.6.1"
accelerate = { version = "^0.19.0", optional = true }
bitsandbytes = { version = "^0.38.1", optional = true }
bitsandbytes = { version = "^0.40.0", optional = true }
safetensors = "0.3.1"
loguru = "^0.6.0"
opentelemetry-api = "^1.15.0"

View File

@ -13,6 +13,8 @@ app = typer.Typer()
class Quantization(str, Enum):
bitsandbytes = "bitsandbytes"
bitsandbytes_nf4 = "bitsandbytes-nf4"
bitsandbytes_fp4 = "bitsandbytes-fp4"
gptq = "gptq"

View File

@ -257,6 +257,11 @@ def get_model(
"gptq quantization is not supported for AutoModel, you can try to quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`"
)
elif (quantize == "bitsandbytes-fp4") or (quantize == "bitsandbytes-nf4"):
raise ValueError(
"4bit quantization is not supported for AutoModel"
)
if model_type in modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
return AutoCausalLM(
model_id,

View File

@ -9,7 +9,7 @@ from typing import List
HAS_BITS_AND_BYTES = True
try:
import bitsandbytes as bnb
from bitsandbytes.nn import Int8Params
from bitsandbytes.nn import Int8Params, Params4bit
except ImportError:
HAS_BITS_AND_BYTES = False
@ -140,6 +140,39 @@ class Linear8bitLt(nn.Module):
return out
class Linear4bit(nn.Module):
def __init__(self, weight, bias, quant_type):
super().__init__()
self.weight = Params4bit(
weight.data, requires_grad=False, compress_statistics=True, quant_type=quant_type
)
self.compute_dtype = None
self.weight.cuda(weight.device)
self.bias = bias
def forward(self, x: torch.Tensor):
# weights are cast automatically as Int8Params, but the bias has to be cast manually
if self.bias is not None and self.bias.dtype != x.dtype:
self.bias.data = self.bias.data.to(x.dtype)
if getattr(self.weight, "quant_state", None) is None:
print(
"FP4 quantization state not initialized. Please call .cuda() or .to(device) on the LinearFP4 layer first."
)
inp_dtype = x.dtype
if self.compute_dtype is not None:
x = x.to(self.compute_dtype)
bias = None if self.bias is None else self.bias.to(self.compute_dtype)
out = bnb.matmul_4bit(
x, self.weight.t(), bias=bias, quant_state=self.weight.quant_state
)
out = out.to(inp_dtype)
return out
def get_linear(weight, bias, quantize):
if quantize is None:
linear = FastLinear(weight, bias)
@ -152,6 +185,18 @@ def get_linear(weight, bias, quantize):
)
if bias is not None:
linear.bias = nn.Parameter(bias)
elif quantize == "bitsandbytes-fp4":
linear = Linear4bit(
weight,
bias,
quant_type="fp4",
)
elif quantize == "bitsandbytes-nf4":
linear = Linear4bit(
weight,
bias,
quant_type="nf4",
)
elif quantize == "gptq":
try:
qweight, qzeros, scales, g_idx, bits, groupsize, use_exllama = weight
@ -219,31 +264,36 @@ class TensorParallelHead(SuperLayer):
)
def forward(self, input: torch.Tensor) -> torch.Tensor:
if not self.should_gather:
return super().forward(input)
world_size = self.process_group.size()
# Fast branch for single requests
if (
self.should_gather
and len(input.shape) == 2
and isinstance(self.linear, FastLinear)
and input.shape[0] == 1
):
if len(input.shape) == 2 and isinstance(self.linear, FastLinear):
out_dim = self.linear.weight.shape[0]
if input.shape[0] == 1:
world_out = input.new_empty(1, out_dim * world_size)
local_out = input.new_empty(1, out_dim)
gather_input = local_out
else:
world_out = input.new_empty(out_dim * world_size, input.shape[0])
gather_input = input.new_empty(out_dim, input.shape[0])
local_out = gather_input.T
torch.mm(input, self.linear.weight.T, out=local_out)
torch.distributed.all_gather_into_tensor(
world_out, local_out, group=self.process_group
world_out, gather_input, group=self.process_group
)
if input.shape[0] == 1:
return world_out
return world_out.T
output = super().forward(input)
if not self.should_gather:
return output
world_output = [torch.empty_like(output) for _ in range(world_size)]
world_output = [
torch.empty_like(output) for _ in range(self.process_group.size())
]
torch.distributed.all_gather(world_output, output, group=self.process_group)
world_output = torch.cat(world_output, dim=-1)
return world_output