restoring commit from dev branch, rebase on current master

This commit is contained in:
michaelfeil 2023-08-01 18:15:18 +02:00
parent afd04dc71e
commit 44fa36b5bf
6 changed files with 88 additions and 20 deletions

View File

@ -254,6 +254,8 @@ You can also quantize the weights with bitsandbytes to reduce the VRAM requireme
make run-falcon-7b-instruct-quantize
```
4bit quantization is available using the [NF4 and FP4 data types from bitsandbytes](https://arxiv.org/pdf/2305.14314.pdf). It can be enabled by providing `--quantize bitsandbytes-nf4` or `--quantize bitsandbytes-fp4` as a command line argument to `text-generation-launcher`.
## Develop
```shell

View File

@ -22,6 +22,8 @@ mod env_runtime;
#[derive(Clone, Copy, Debug, ValueEnum)]
enum Quantization {
Bitsandbytes,
BitsandbytesNF4,
BitsandbytesFP4,
Gptq,
}
@ -32,6 +34,12 @@ impl std::fmt::Display for Quantization {
Quantization::Bitsandbytes => {
write!(f, "bitsandbytes")
}
Quantization::BitsandbytesNF4 => {
write!(f, "bitsandbytes-nf4")
}
Quantization::BitsandbytesFP4 => {
write!(f, "bitsandbytes-fp4")
}
Quantization::Gptq => {
write!(f, "gptq")
}
@ -96,7 +104,8 @@ struct Args {
num_shard: Option<usize>,
/// Whether you want the model to be quantized. This will use `bitsandbytes` for
/// quantization on the fly, or `gptq`.
/// quantization on the fly, or `gptq`. 4bit quantization is available through
/// `bitsandbytes` by providing the `bitsandbytes-fp4` or `bitsandbytes-nf4` options.
#[clap(long, env, value_enum)]
quantize: Option<Quantization>,

View File

@ -8,7 +8,7 @@ authors = ["Olivier Dehaene <olivier@huggingface.co>"]
text-generation-server = 'text_generation_server.cli:app'
[tool.poetry.dependencies]
python = "^3.9"
python = >=3.9,<3.13"
protobuf = "^4.21.7"
grpcio = "^1.51.1"
grpcio-status = "^1.51.1"
@ -16,7 +16,7 @@ grpcio-reflection = "^1.51.1"
grpc-interceptor = "^0.15.0"
typer = "^0.6.1"
accelerate = { version = "^0.19.0", optional = true }
bitsandbytes = { version = "^0.38.1", optional = true }
bitsandbytes = { version = "^0.40.0", optional = true }
safetensors = "0.3.1"
loguru = "^0.6.0"
opentelemetry-api = "^1.15.0"

View File

@ -13,6 +13,8 @@ app = typer.Typer()
class Quantization(str, Enum):
bitsandbytes = "bitsandbytes"
bitsandbytes_nf4 = "bitsandbytes-nf4"
bitsandbytes_fp4 = "bitsandbytes-fp4"
gptq = "gptq"

View File

@ -256,6 +256,11 @@ def get_model(
"gptq quantization is not supported for AutoModel, you can try to quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`"
)
elif (quantize == "bitsandbytes-fp4") or (quantize == "bitsandbytes-nf4"):
raise ValueError(
"4bit quantization is not supported for AutoModel"
)
if model_type in modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
return CausalLM(
model_id,

View File

@ -9,7 +9,7 @@ from typing import List
HAS_BITS_AND_BYTES = True
try:
import bitsandbytes as bnb
from bitsandbytes.nn import Int8Params
from bitsandbytes.nn import Int8Params, Params4bit
except ImportError:
HAS_BITS_AND_BYTES = False
@ -140,6 +140,39 @@ class Linear8bitLt(nn.Module):
return out
class Linear4bit(nn.Module):
def __init__(self, weight, bias, quant_type):
super().__init__()
self.weight = Params4bit(
weight.data, requires_grad=False, compress_statistics=True, quant_type=quant_type
)
self.compute_dtype = None
self.weight.cuda(weight.device)
self.bias = bias
def forward(self, x: torch.Tensor):
# weights are cast automatically as Int8Params, but the bias has to be cast manually
if self.bias is not None and self.bias.dtype != x.dtype:
self.bias.data = self.bias.data.to(x.dtype)
if getattr(self.weight, "quant_state", None) is None:
print(
"FP4 quantization state not initialized. Please call .cuda() or .to(device) on the LinearFP4 layer first."
)
inp_dtype = x.dtype
if self.compute_dtype is not None:
x = x.to(self.compute_dtype)
bias = None if self.bias is None else self.bias.to(self.compute_dtype)
out = bnb.matmul_4bit(
x, self.weight.t(), bias=bias, quant_state=self.weight.quant_state
)
out = out.to(inp_dtype)
return out
def get_linear(weight, bias, quantize):
if quantize is None:
linear = FastLinear(weight, bias)
@ -152,9 +185,21 @@ def get_linear(weight, bias, quantize):
)
if bias is not None:
linear.bias = nn.Parameter(bias)
elif quantize == "bitsandbytes-fp4":
linear = Linear4bit(
weight,
bias,
quant_type="fp4",
)
elif quantize == "bitsandbytes-nf4":
linear = Linear4bit(
weight,
bias,
quant_type="nf4",
)
elif quantize == "gptq":
try:
qweight, qzeros, scales, g_idx, bits, groupsize, use_exllama = weight
qweight, qzeros, scales, g_idx, bits, groupsize = weight
except Exception:
raise NotImplementedError(
f"The passed weight is not `gptq` compatible, loader needs to be updated."
@ -219,31 +264,36 @@ class TensorParallelHead(SuperLayer):
)
def forward(self, input: torch.Tensor) -> torch.Tensor:
if not self.should_gather:
return super().forward(input)
world_size = self.process_group.size()
# Fast branch for single requests
if (
self.should_gather
and len(input.shape) == 2
and isinstance(self.linear, FastLinear)
and input.shape[0] == 1
):
if len(input.shape) == 2 and isinstance(self.linear, FastLinear):
out_dim = self.linear.weight.shape[0]
world_out = input.new_empty(1, out_dim * world_size)
local_out = input.new_empty(1, out_dim)
if input.shape[0] == 1:
world_out = input.new_empty(1, out_dim * world_size)
local_out = input.new_empty(1, out_dim)
gather_input = local_out
else:
world_out = input.new_empty(out_dim * world_size, input.shape[0])
gather_input = input.new_empty(out_dim, input.shape[0])
local_out = gather_input.T
torch.mm(input, self.linear.weight.T, out=local_out)
torch.distributed.all_gather_into_tensor(
world_out, local_out, group=self.process_group
world_out, gather_input, group=self.process_group
)
return world_out
if input.shape[0] == 1:
return world_out
return world_out.T
output = super().forward(input)
if not self.should_gather:
return output
world_output = [torch.empty_like(output) for _ in range(world_size)]
world_output = [
torch.empty_like(output) for _ in range(self.process_group.size())
]
torch.distributed.all_gather(world_output, output, group=self.process_group)
world_output = torch.cat(world_output, dim=-1)
return world_output