Add support for exl2 quantization

Mostly straightforward, changes to existing code:

* Wrap quantizer parameters in a small wrapper to avoid passing
  around untyped tuples and needing to repack them as a dict.
* Move scratch space computation to warmup, because we need the
  maximum input sequence length to avoid allocating huge
  scratch buffers that OOM.
This commit is contained in:
Daniël de Kok 2024-05-28 09:51:31 +00:00 committed by Daniël de Kok
parent cbced7f0f9
commit 36dd16017c
23 changed files with 972 additions and 177 deletions

View File

@ -62,6 +62,7 @@ Options:
Possible values:
- awq: 4 bit quantization. Requires a specific AWQ quantized model: <https://hf.co/models?search=awq>. Should replace GPTQ models wherever possible because of the better latency
- eetq: 8 bit quantization, doesn't require specific model. Should be a drop-in replacement to bitsandbytes with much better performance. Kernels are from <https://github.com/NetEase-FuXi/EETQ.git>
- exl2: Variable bit quantization. Requires a specific EXL2 quantized model: <https://hf.co/models?search=exl2>. Requires exllama2 kernels and does not support tensor parallelism (num_shard > 1)
- gptq: 4 bit quantization. Requires a specific GTPQ quantized model: <https://hf.co/models?search=gptq>. text-generation-inference will use exllama (faster) kernels wherever possible, and use triton kernel (wider support) when it's not. AWQ has faster kernels
- bitsandbytes: Bitsandbytes 8bit. Can be applied on any model, will cut the memory requirement in half, but it is known that the model will be much slower to run than the native f16
- bitsandbytes-nf4: Bitsandbytes 4bit. Can be applied on any model, will cut the memory requirement by 4x, but it is known that the model will be much slower to run than the native f16

View File

@ -2,7 +2,6 @@
## What is Guidance?
Guidance is a feature that allows users to constrain the generation of a large language model with a specified grammar. This feature is particularly useful when you want to generate text that follows a specific structure or uses a specific set of words or produce output in a specific format. A prominent example is JSON grammar, where the model is forced to output valid JSON.
## How is it used?

View File

@ -38,6 +38,7 @@ DOCKER_VOLUME = os.getenv("DOCKER_VOLUME", "/data")
class ResponseComparator(JSONSnapshotExtension):
rtol = 0.2
ignore_logprob = False
def serialize(
self,
@ -95,7 +96,10 @@ class ResponseComparator(JSONSnapshotExtension):
return (
token.id == other.id
and token.text == other.text
and math.isclose(token.logprob, other.logprob, rel_tol=self.rtol)
and (
self.ignore_logprob
or math.isclose(token.logprob, other.logprob, rel_tol=self.rtol)
)
and token.special == other.special
)
@ -105,8 +109,11 @@ class ResponseComparator(JSONSnapshotExtension):
prefill_token.id == other.id
and prefill_token.text == other.text
and (
math.isclose(
prefill_token.logprob, other.logprob, rel_tol=self.rtol
self.ignore_logprob
or math.isclose(
prefill_token.logprob,
other.logprob,
rel_tol=self.rtol,
)
if prefill_token.logprob is not None
else prefill_token.logprob == other.logprob
@ -223,6 +230,10 @@ class GenerousResponseComparator(ResponseComparator):
rtol = 0.75
class IgnoreLogProbResponseComparator(ResponseComparator):
ignore_logprob = True
class LauncherHandle:
def __init__(self, port: int):
self.client = AsyncClient(f"http://localhost:{port}")
@ -274,6 +285,11 @@ def generous_response_snapshot(snapshot):
return snapshot.use_extension(GenerousResponseComparator)
@pytest.fixture
def ignore_logprob_response_snapshot(snapshot):
return snapshot.use_extension(IgnoreLogProbResponseComparator)
@pytest.fixture(scope="module")
def event_loop():
loop = asyncio.get_event_loop()

View File

@ -0,0 +1,84 @@
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 2323,
"logprob": null,
"text": "Test"
},
{
"id": 1715,
"logprob": -11.4375,
"text": " request"
}
],
"seed": null,
"tokens": [
{
"id": 25,
"logprob": -2.9316406,
"special": false,
"text": ":"
},
{
"id": 330,
"logprob": -3.5136719,
"special": false,
"text": " \""
},
{
"id": 489,
"logprob": -0.7783203,
"special": false,
"text": " +"
},
{
"id": 1715,
"logprob": -1.2314453,
"special": false,
"text": " request"
},
{
"id": 489,
"logprob": -2.0019531,
"special": false,
"text": " +"
},
{
"id": 2990,
"logprob": -1.5009766,
"special": false,
"text": " \"\\"
},
{
"id": 77,
"logprob": -0.057434082,
"special": false,
"text": "n"
},
{
"id": 702,
"logprob": -1.4912109,
"special": false,
"text": "\"\n"
},
{
"id": 262,
"logprob": -1.2636719,
"special": false,
"text": " "
},
{
"id": 557,
"logprob": -2.4042969,
"special": false,
"text": " }\n\n"
}
],
"top_tokens": null
},
"generated_text": ": \" + request + \"\\n\"\n }\n\n"
}

View File

@ -0,0 +1,84 @@
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 2323,
"logprob": null,
"text": "Test"
},
{
"id": 1715,
"logprob": -11.453125,
"text": " request"
}
],
"seed": 0,
"tokens": [
{
"id": 13,
"logprob": -1.9980469,
"special": false,
"text": "."
},
{
"id": 578,
"logprob": -0.15795898,
"special": false,
"text": " The"
},
{
"id": 3622,
"logprob": -1.0458984,
"special": false,
"text": " server"
},
{
"id": 31680,
"logprob": -1.3623047,
"special": false,
"text": " responds"
},
{
"id": 449,
"logprob": 0.0,
"special": false,
"text": " with"
},
{
"id": 264,
"logprob": 0.0,
"special": false,
"text": " a"
},
{
"id": 330,
"logprob": -0.5678711,
"special": false,
"text": " \""
},
{
"id": 1049,
"logprob": -0.12322998,
"special": false,
"text": "200"
},
{
"id": 10619,
"logprob": 0.0,
"special": false,
"text": " OK"
},
{
"id": 1,
"logprob": 0.0,
"special": false,
"text": "\""
}
],
"top_tokens": null
},
"generated_text": "Test request. The server responds with a \"200 OK\""
}

View File

@ -0,0 +1,338 @@
[
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 2323,
"logprob": null,
"text": "Test"
},
{
"id": 1715,
"logprob": -11.453125,
"text": " request"
}
],
"seed": null,
"tokens": [
{
"id": 25,
"logprob": -2.9785156,
"special": false,
"text": ":"
},
{
"id": 330,
"logprob": -3.4941406,
"special": false,
"text": " \""
},
{
"id": 489,
"logprob": -0.79345703,
"special": false,
"text": " +"
},
{
"id": 1715,
"logprob": -1.2324219,
"special": false,
"text": " request"
},
{
"id": 489,
"logprob": -1.9794922,
"special": false,
"text": " +"
},
{
"id": 2990,
"logprob": -1.4892578,
"special": false,
"text": " \"\\"
},
{
"id": 77,
"logprob": -0.058258057,
"special": false,
"text": "n"
},
{
"id": 702,
"logprob": -1.4892578,
"special": false,
"text": "\"\n"
},
{
"id": 262,
"logprob": -1.2783203,
"special": false,
"text": " "
},
{
"id": 557,
"logprob": -2.3945312,
"special": false,
"text": " }\n\n"
}
],
"top_tokens": null
},
"generated_text": ": \" + request + \"\\n\"\n }\n\n"
},
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 2323,
"logprob": null,
"text": "Test"
},
{
"id": 1715,
"logprob": -11.40625,
"text": " request"
}
],
"seed": null,
"tokens": [
{
"id": 25,
"logprob": -2.9433594,
"special": false,
"text": ":"
},
{
"id": 330,
"logprob": -3.4726562,
"special": false,
"text": " \""
},
{
"id": 489,
"logprob": -0.8022461,
"special": false,
"text": " +"
},
{
"id": 1715,
"logprob": -1.2509766,
"special": false,
"text": " request"
},
{
"id": 489,
"logprob": -1.984375,
"special": false,
"text": " +"
},
{
"id": 2990,
"logprob": -1.4677734,
"special": false,
"text": " \"\\"
},
{
"id": 77,
"logprob": -0.059173584,
"special": false,
"text": "n"
},
{
"id": 702,
"logprob": -1.4990234,
"special": false,
"text": "\"\n"
},
{
"id": 262,
"logprob": -1.2822266,
"special": false,
"text": " "
},
{
"id": 557,
"logprob": -2.3867188,
"special": false,
"text": " }\n\n"
}
],
"top_tokens": null
},
"generated_text": ": \" + request + \"\\n\"\n }\n\n"
},
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 2323,
"logprob": null,
"text": "Test"
},
{
"id": 1715,
"logprob": -11.421875,
"text": " request"
}
],
"seed": null,
"tokens": [
{
"id": 25,
"logprob": -2.9511719,
"special": false,
"text": ":"
},
{
"id": 330,
"logprob": -3.46875,
"special": false,
"text": " \""
},
{
"id": 489,
"logprob": -0.77490234,
"special": false,
"text": " +"
},
{
"id": 1715,
"logprob": -1.2558594,
"special": false,
"text": " request"
},
{
"id": 489,
"logprob": -1.984375,
"special": false,
"text": " +"
},
{
"id": 2990,
"logprob": -1.4990234,
"special": false,
"text": " \"\\"
},
{
"id": 77,
"logprob": -0.059143066,
"special": false,
"text": "n"
},
{
"id": 702,
"logprob": -1.4941406,
"special": false,
"text": "\"\n"
},
{
"id": 262,
"logprob": -1.2578125,
"special": false,
"text": " "
},
{
"id": 557,
"logprob": -2.3964844,
"special": false,
"text": " }\n\n"
}
],
"top_tokens": null
},
"generated_text": ": \" + request + \"\\n\"\n }\n\n"
},
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 2323,
"logprob": null,
"text": "Test"
},
{
"id": 1715,
"logprob": -11.4140625,
"text": " request"
}
],
"seed": null,
"tokens": [
{
"id": 25,
"logprob": -2.9101562,
"special": false,
"text": ":"
},
{
"id": 330,
"logprob": -3.5039062,
"special": false,
"text": " \""
},
{
"id": 489,
"logprob": -0.8076172,
"special": false,
"text": " +"
},
{
"id": 1715,
"logprob": -1.2236328,
"special": false,
"text": " request"
},
{
"id": 489,
"logprob": -1.9853516,
"special": false,
"text": " +"
},
{
"id": 2990,
"logprob": -1.4892578,
"special": false,
"text": " \"\\"
},
{
"id": 77,
"logprob": -0.056671143,
"special": false,
"text": "n"
},
{
"id": 702,
"logprob": -1.5107422,
"special": false,
"text": "\"\n"
},
{
"id": 262,
"logprob": -1.2597656,
"special": false,
"text": " "
},
{
"id": 557,
"logprob": -2.4042969,
"special": false,
"text": " }\n\n"
}
],
"top_tokens": null
},
"generated_text": ": \" + request + \"\\n\"\n }\n\n"
}
]

View File

@ -0,0 +1,73 @@
import pytest
@pytest.fixture(scope="module")
def flash_llama_exl2_handle(launcher):
with launcher(
"turboderp/Llama-3-8B-Instruct-exl2",
revision="2.5bpw",
# Set max input length to avoid OOM due to extremely large
# scratch buffer.
max_input_length=1024,
num_shard=1,
quantize="exl2",
) as handle:
yield handle
@pytest.fixture(scope="module")
async def flash_llama_exl2(flash_llama_exl2_handle):
await flash_llama_exl2_handle.health(300)
return flash_llama_exl2_handle.client
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_exl2(flash_llama_exl2, ignore_logprob_response_snapshot):
response = await flash_llama_exl2.generate(
"Test request", max_new_tokens=10, decoder_input_details=True
)
assert response.details.generated_tokens == 10
assert response == ignore_logprob_response_snapshot
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_exl2_all_params(
flash_llama_exl2, ignore_logprob_response_snapshot
):
response = await flash_llama_exl2.generate(
"Test request",
max_new_tokens=10,
repetition_penalty=1.2,
return_full_text=True,
temperature=0.5,
top_p=0.9,
top_k=10,
truncate=5,
typical_p=0.9,
watermark=True,
decoder_input_details=True,
seed=0,
)
assert (
response.generated_text == 'Test request. The server responds with a "200 OK"'
)
assert response == ignore_logprob_response_snapshot
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_exl2_load(
flash_llama_exl2, generate_load, ignore_logprob_response_snapshot
):
responses = await generate_load(
flash_llama_exl2, "Test request", max_new_tokens=10, n=4
)
assert len(responses) == 4
assert all([r.generated_text == responses[0].generated_text for r in responses])
assert responses == ignore_logprob_response_snapshot

View File

@ -55,6 +55,10 @@ enum Quantization {
/// Should be a drop-in replacement to bitsandbytes with much better performance.
/// Kernels are from <https://github.com/NetEase-FuXi/EETQ.git>
Eetq,
/// Variable bit quantization. Requires a specific EXL2 quantized model:
/// <https://hf.co/models?search=exl2>. Requires exllama2 kernels and does
/// not support tensor parallelism (num_shard > 1).
Exl2,
/// 4 bit quantization. Requires a specific GTPQ quantized model: <https://hf.co/models?search=gptq>.
/// text-generation-inference will use exllama (faster) kernels wherever possible, and use
/// triton kernel (wider support) when it's not.
@ -95,6 +99,9 @@ impl std::fmt::Display for Quantization {
Quantization::BitsandbytesFP4 => {
write!(f, "bitsandbytes-fp4")
}
Quantization::Exl2 => {
write!(f, "exl2")
}
Quantization::Gptq => {
write!(f, "gptq")
}
@ -1461,6 +1468,11 @@ fn main() -> Result<(), LauncherError> {
let num_shard = find_num_shards(args.sharded, args.num_shard)?;
if num_shard > 1 {
if matches!(args.quantize, Some(Quantization::Exl2)) {
return Err(LauncherError::ArgumentValidation(
"Sharding is currently not supported with `exl2` quantization".into(),
));
}
tracing::info!("Sharding model on {num_shard} processes");
}

View File

@ -19,6 +19,7 @@ class Quantization(str, Enum):
gptq = "gptq"
awq = "awq"
eetq = "eetq"
exl2 = "exl2"
fp8 = "fp8"

View File

@ -0,0 +1,23 @@
import torch
from dataclasses import dataclass
@dataclass
class Exl2Weight:
"""
Exllama2 exl2 quantized weights.
"""
q_weight: torch.Tensor
q_scale: torch.Tensor
q_invperm: torch.Tensor
q_scale_max: torch.Tensor
q_groups: torch.Tensor
def __post_init__(self):
self.q_scale_max /= 256
self.q_invperm = self.q_invperm.short()
@property
def device(self) -> torch.device:
return self.q_weight.device

View File

@ -1,9 +1,31 @@
from dataclasses import dataclass
import os
from typing import Optional
import torch
from text_generation_server.utils.import_utils import (
SYSTEM,
)
@dataclass
class GPTQWeight:
qweight: torch.Tensor
qzeros: torch.Tensor
scales: torch.Tensor
g_idx: Optional[torch.Tensor]
bits: int
groupsize: int
use_exllama: bool
def __post_init__(self):
if self.scales.dtype == torch.float:
self.scales = self.scales.half()
@property
def device(self) -> torch.device:
return self.qweight.device
try:
major, _minor = torch.cuda.get_device_capability()
except Exception:

View File

@ -1,3 +1,4 @@
from text_generation_server.utils.weights import GPTQWeight
import torch
from exllama_kernels import make_q4, q4_matmul, prepare_buffers, set_tuning_params
@ -65,24 +66,25 @@ def create_exllama_buffers(max_total_tokens: int):
class Ex4bitLinear(torch.nn.Module):
"""Linear layer implementation with per-group 4-bit quantization of the weights"""
def __init__(self, qweight, qzeros, scales, g_idx, bias, bits, groupsize):
def __init__(self, weight: GPTQWeight, bias):
super().__init__()
global MAX_DQ, MAX_INNER, ACT_ORDER, DEVICE
assert bits == 4
assert weight.bits == 4
self.device = qweight.device
self.qweight = qweight
self.qzeros = qzeros
self.scales = scales
self.g_idx = g_idx.cpu() if g_idx is not None else None
self.device = weight.qweight.device
self.qweight = weight.qweight
self.qzeros = weight.qzeros
self.scales = weight.scales
self.g_idx = weight.g_idx.cpu() if weight.g_idx is not None else None
self.bias = bias if bias is not None else None
if self.g_idx is not None and (
(self.g_idx == 0).all()
or torch.equal(
g_idx.cpu(),
weight.g_idx.cpu(),
torch.tensor(
[i // groupsize for i in range(g_idx.shape[0])], dtype=torch.int32
[i // weight.groupsize for i in range(weight.g_idx.shape[0])],
dtype=torch.int32,
),
)
):
@ -96,8 +98,8 @@ class Ex4bitLinear(torch.nn.Module):
self.qweight, self.qzeros, self.scales, self.g_idx, self.device.index
)
self.height = qweight.shape[0] * 8
self.width = qweight.shape[1]
self.height = weight.qweight.shape[0] * 8
self.width = weight.qweight.shape[1]
# Infer groupsize from height of qzeros
self.groupsize = None
@ -105,7 +107,7 @@ class Ex4bitLinear(torch.nn.Module):
self.groupsize = (self.qweight.shape[0] * 8) // (self.qzeros.shape[0])
if self.groupsize is not None:
assert groupsize == self.groupsize
assert weight.groupsize == self.groupsize
# Handle act-order matrix
if self.g_idx is not None:

View File

@ -1,10 +1,15 @@
# Adapted from turboderp exllama: https://github.com/turboderp/exllamav2
from dataclasses import dataclass
from typing import Optional
import torch
import torch.nn as nn
from loguru import logger
from text_generation_server.layers.exl2 import Exl2Weight
from text_generation_server.layers.gptq import GPTQWeight
try:
from exllamav2_kernels import make_q_matrix, gemm_half_q_half
except ImportError:
@ -15,6 +20,15 @@ except ImportError:
none_tensor = torch.empty((1, 1), device="meta")
@dataclass
class _ExtraTensors:
"""Additional generated quantizer tensors."""
q_group_map: Optional[torch.Tensor] = None
q_invperm: Optional[torch.Tensor] = None
q_perm: Optional[torch.Tensor] = None
def ext_gemm_half_q_half(x, q_handle, q4_width, force_cuda):
"""Matrix multiplication, returns x @ q4"""
output_shape = x.shape[:-1] + (q4_width,)
@ -24,11 +38,7 @@ def ext_gemm_half_q_half(x, q_handle, q4_width, force_cuda):
return output.view(output_shape)
# Group map needed for irregular group sizes
def make_group_map(q_groups, num_qrows):
def make_group_map(q_groups: torch.Tensor, num_qrows: int):
gr = q_groups.tolist()
group_map = []
num_groups = len(gr) // 2
@ -50,72 +60,72 @@ def make_group_map(q_groups, num_qrows):
# Create Q matrix
def ext_make_q_matrix(w: dict, temp_dq, key: str = None):
def ext_make_q_matrix(
w: Exl2Weight | GPTQWeight,
extra: _ExtraTensors,
temp_dq,
key: Optional[str] = None,
):
"""
Create Q matrix
"""
# EXL2
# won't work as the moment because the tensors are not the same.
if "q_weight" in w:
w["q_scale_max"] /= 256
w["q_perm"] = w["q_perm"].short()
w["q_invperm"] = w["q_invperm"].short()
if "q_group_map" not in w:
w["q_group_map"] = make_group_map(w["q_groups"], w["q_weight"].shape[0])
if isinstance(w, Exl2Weight):
extra.q_group_map = make_group_map(w.q_groups, w.q_weight.shape[0])
extra.q_perm = torch.argsort(w.q_invperm).short()
return make_q_matrix(
w["q_weight"],
w["q_perm"],
w["q_invperm"],
w["q_scale"],
w["q_scale_max"],
w["q_groups"],
w["q_group_map"],
w.q_weight,
extra.q_perm,
w.q_invperm,
w.q_scale,
w.q_scale_max,
w.q_groups,
extra.q_group_map,
none_tensor,
none_tensor,
none_tensor,
temp_dq,
)
# GPTQ
elif "qweight" in w:
if w["scales"].dtype == torch.float:
w["scales"] = w["scales"].half()
elif isinstance(w, GPTQWeight):
if w.scales.dtype == torch.float:
w.scales = w.scales.half()
# GPTQ with g_idx (act_order)
if w.get("g_idx", None) is not None and not (w["g_idx"] == 0).all().item():
w["q_perm"] = torch.empty(
(w["qweight"].shape[0] * 8,),
if w.g_idx is not None and not (w.g_idx == 0).all().item():
extra.q_perm = torch.empty(
(w.qweight.shape[0] * 8,),
dtype=torch.short,
device=w["qweight"].device,
device=w.qweight.device,
)
w["q_invperm"] = torch.empty_like(w["q_perm"])
extra.q_invperm = torch.empty_like(extra.q_perm)
# make_q4 segfaults if g_idx is not on cpu in the act-order case. In the non act-order case, None needs to be passed for g_idx.
return make_q_matrix(
w["qweight"],
w["q_perm"],
w["q_invperm"],
w.qweight,
extra.q_perm,
extra.q_invperm,
none_tensor,
none_tensor,
none_tensor,
none_tensor,
w["qzeros"],
w["scales"],
w["g_idx"].cpu(),
w.qzeros,
w.scales,
w.g_idx.cpu(),
temp_dq,
)
# GPTQ without g_idx
else:
return make_q_matrix(
w["qweight"],
w.qweight,
none_tensor,
none_tensor,
none_tensor,
none_tensor,
none_tensor,
none_tensor,
w["qzeros"],
w["scales"],
w.qzeros,
w.scales,
none_tensor,
temp_dq,
)
@ -124,7 +134,6 @@ def ext_make_q_matrix(w: dict, temp_dq, key: str = None):
DEVICE = None
FIXED_BYTES = 0
LAYERS = []
@ -134,8 +143,13 @@ def set_device(device):
def create_exllama_buffers(max_total_tokens: int):
global FIXED_BYTES, LAYERS, DEVICE
temp_dq = ExLlamaV2DeviceTensors(DEVICE, FIXED_BYTES)
global LAYERS, DEVICE
# Find the size of the scratch space.
scratch_bytes = max(
layer.scratch_space_fixed(max_input_len=max_total_tokens) for layer in LAYERS
)
temp_dq = ExLlamaV2DeviceTensors(DEVICE, scratch_bytes)
for layer in LAYERS:
layer.post_init(temp_dq)
@ -146,49 +160,48 @@ class QuantLinear(nn.Module):
"""Linear layer implementation with per-group 4-bit quantization of the weights"""
# def __init__(self, bits, group_size, infeatures, outfeatures, bias, trainable=False, **kwargs):
def __init__(self, qweight, qzeros, scales, g_idx, bias, bits, groupsize):
def __init__(
self,
weight: Exl2Weight | GPTQWeight,
bias: torch.Tensor,
):
super().__init__()
if bits != 4:
raise ValueError(
f"Exllamav2 kernel supports only bits=4, requested bits={bits}. Something is wrong in the model initialization."
)
self.q_handle = None
self.q_tensors = None
self.bits = bits
self.maxq = 2**self.bits - 1
self.infeatures = qweight.shape[0] // self.bits * 32
self.outfeatures = qweight.shape[1]
self.q_tensors = weight
self.extra_tensors = _ExtraTensors()
if isinstance(weight, Exl2Weight):
self.infeatures = weight.q_invperm.shape[0]
self.outfeatures = weight.q_weight.shape[1]
elif isinstance(weight, GPTQWeight):
if weight.bits != 4:
raise ValueError(
f"Exllamav2 kernel supports only bits=4, requested bits={weight.bits}. Something is wrong in the model initialization."
)
self.infeatures = weight.qweight.shape[0] // weight.bits * 32
self.outfeatures = weight.qweight.shape[1]
self.padding = -self.outfeatures % 32
self.outfeatures = self.outfeatures + self.padding
self.device = qweight.device
self.qweight = qweight
self.qzeros = qzeros
self.scales = scales
self.g_idx = g_idx
self.device = weight.device
self.bias = bias if bias is not None else None
self.group_size = groupsize
global FIXED_BYTES, LAYERS
FIXED_BYTES = max(FIXED_BYTES, self.scratch_space_fixed())
global LAYERS
LAYERS.append(self)
def post_init(self, temp_dq):
assert self.qweight.device.type == "cuda"
assert self.qweight.device.index is not None
self.q_tensors = {
"qweight": self.qweight,
"qzeros": self.qzeros,
"scales": self.scales,
"g_idx": self.g_idx,
}
device = self.q_tensors.device
assert device.type == "cuda"
assert device.index is not None
temp_dq = temp_dq.get_scratch_slice(self.temp_dq_size())
# We NEED to keep a pointer on Python side, otherwise the garbage collector will mess with us,
# and `Memory access fault by GPU node-2` will EAT you.
self.temp_dq = temp_dq
self.q_handle = ext_make_q_matrix(self.q_tensors, temp_dq)
self.q_handle = ext_make_q_matrix(self.q_tensors, self.extra_tensors, temp_dq)
def forward(self, x, force_cuda=False):
output = ext_gemm_half_q_half(x, self.q_handle, self.outfeatures, force_cuda)

View File

@ -1,6 +1,9 @@
from typing import Optional
import torch
from torch.nn import functional as F
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.layers.exl2 import Exl2Weight
from text_generation_server.layers.gptq import GPTQWeight
if SYSTEM == "rocm":
try:
@ -151,15 +154,23 @@ def get_linear(weight, bias, quantize):
bias,
quant_type="nf4",
)
elif quantize == "exl2":
if not isinstance(weight, Exl2Weight):
raise NotImplementedError(
f"The passed weight is not `exl2` compatible, loader needs to be updated."
)
from text_generation_server.layers.gptq import ExllamaQuantLinear
linear = ExllamaQuantLinear(weight, bias)
elif quantize == "gptq":
try:
qweight, qzeros, scales, g_idx, bits, groupsize, use_exllama = weight
except Exception:
if not isinstance(weight, GPTQWeight):
raise NotImplementedError(
f"The passed weight is not `gptq` compatible, loader needs to be updated."
)
if use_exllama:
if weight.use_exllama:
try:
from text_generation_server.layers.gptq import (
ExllamaQuantLinear,
@ -169,25 +180,21 @@ def get_linear(weight, bias, quantize):
f"Exllama gptq kernels are not installed. Install them `cd server/exllama_kernels && python setup.py install && cd ../exllamav2_kernels && python setup.py install`"
)
linear = ExllamaQuantLinear(
qweight, qzeros, scales, g_idx, bias, bits, groupsize
)
linear = ExllamaQuantLinear(weight, bias)
else:
from text_generation_server.layers.gptq.quant_linear import QuantLinear
linear = QuantLinear(
qweight,
qzeros,
scales,
g_idx,
weight.qweight,
weight.qzeros,
weight.scales,
weight.g_idx,
bias,
bits,
groupsize,
weight.bits,
weight.groupsize,
)
elif quantize == "awq":
try:
qweight, qzeros, scales, _, bits, groupsize, _ = weight
except Exception:
if not isinstance(weight, GPTQWeight):
raise NotImplementedError(
f"The passed weight is not `awq` compatible, loader needs to be updated."
)
@ -200,11 +207,11 @@ def get_linear(weight, bias, quantize):
from text_generation_server.layers.awq.quantize.qmodule import WQLinear
linear = WQLinear(
w_bit=bits,
group_size=groupsize,
qweight=qweight,
qzeros=qzeros,
scales=scales,
w_bit=weight.bits,
group_size=weight.groupsize,
qweight=weight.qweight,
qzeros=weight.qzeros,
scales=weight.scales,
bias=bias is not None,
)
except ImportError:

View File

@ -1,7 +1,27 @@
import torch
from torch.nn import functional as F
from typing import List
from typing import Iterable, List
from text_generation_server.layers.linear import get_linear, FastLinear
from text_generation_server.layers.exl2 import Exl2Weight
class LayerConcat(torch.nn.Module):
"""
Apply multiple layers to the input and concatenate their
outputs.
"""
def __init__(self, layers: Iterable[torch.nn.Module], dim: int = -1):
"""
`dim` is the dimension along which layer outputs are concatenated.
"""
super().__init__()
self.layers = layers
self.dim = dim
def forward(self, x: torch.Tensor):
outputs = [layer(x) for layer in self.layers]
return torch.cat(outputs, self.dim)
class SuperLayer(torch.nn.Module):
@ -21,7 +41,16 @@ class TensorParallelHead(SuperLayer):
@staticmethod
def load(config, prefix: str, weights):
if weights.process_group.size() > 1:
if config.quantize == "exl2":
try:
# If the piece and LM head embeddings are shared, we have
# non-quantized weights...
weight = weights.get_tensor(f"{prefix}.weight")
except:
# ...otherwise they are quantized.
weight = weights.get_weights_col(prefix, config.quantize)
should_gather = weights.process_group.size() > 1
elif weights.process_group.size() > 1:
try:
weight = weights.get_sharded(f"{prefix}.weight", dim=0)
should_gather = True
@ -37,8 +66,12 @@ class TensorParallelHead(SuperLayer):
# GPTQ,AWQ,EETQ don't quantize heads (nor embeddings)
if config.quantize in ["gptq", "awq", "eetq"]:
quantize = None
# See above, exl2 LM head can be quantized or not.
elif config.quantize == "exl2" and not isinstance(weight, Exl2Weight):
quantize = None
else:
quantize = config.quantize
return TensorParallelHead(
get_linear(weight, bias=None, quantize=quantize),
process_group=weights.process_group,
@ -108,14 +141,27 @@ class TensorParallelColumnLinear(SuperLayer):
@classmethod
def load(cls, config, prefix: str, weights, bias: bool):
return cls.load_multi(config, [prefix], weights, bias, dim=0)
weight = weights.get_weights_col(prefix, config.quantize)
if bias:
bias = weights.get_sharded(f"{prefix}.bias", dim=0)
else:
bias = None
linear = get_linear(weight, bias, config.quantize)
return cls(linear)
@classmethod
def load_multi(cls, config, prefixes: List[str], weights, bias: bool, dim: int):
if config.quantize == "exl2":
linears = []
for prefix in prefixes:
weight = weights.get_weights_col(prefix, config.quantize)
b = weights.get_tensor(f"{prefix}.bias") if bias else None
linears.append(get_linear(weight, b, config.quantize))
linear = LayerConcat(linears)
else:
weight = weights.get_multi_weights_col(
prefixes, quantize=config.quantize, dim=dim
)
if bias:
b = [weights.get_sharded(f"{p}.bias", dim=0) for p in prefixes]
bias = torch.cat(b, dim=dim)

View File

@ -263,7 +263,7 @@ def get_model(
trust_remote_code: bool,
) -> Model:
if dtype is None:
if quantize in ["awq", "gptq"]:
if quantize in ["awq", "exl2", "gptq"]:
# These quantizers only work with float16 params.
dtype = torch.float16
else:
@ -402,12 +402,17 @@ def get_model(
quantization_config = config_dict.get("quantization_config", None)
if quantization_config is not None and quantize is None:
method = quantization_config.get("quant_method", None)
if method in {"gptq", "awq"}:
if method in {"gptq", "awq", "exl2"}:
logger.info(f"Auto selecting quantization method {method}")
quantize = method
else:
logger.info(f"Unknown quantization method {method}")
if quantize == "exl2" and sharded:
raise RuntimeError(
"Sharding is currently not supported with `exl2` quantization"
)
if model_type == MAMBA:
return Mamba(
model_id,
@ -881,6 +886,8 @@ def get_model(
raise NotImplementedError("4bit quantization is not supported for AutoModel")
elif quantize == "eetq":
raise NotImplementedError("Eetq quantization is not supported for AutoModel")
elif quantize == "exl2":
raise NotImplementedError("exl2 quantization is not supported for AutoModel")
if model_type in modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
return CausalLM(
model_id,

View File

@ -21,6 +21,7 @@ from transformers.activations import ACT2FN
from transformers.configuration_utils import PretrainedConfig
from typing import Optional, List, Tuple, Any
from loguru import logger
from text_generation_server.layers.gptq import GPTQWeight
from text_generation_server.utils.import_utils import SYSTEM
if SYSTEM != "xpu":
@ -256,7 +257,15 @@ def _load_gqa(config, prefix: str, weights):
else:
g_idx = None
weight = (qweight, qzeros, scales, g_idx, bits, groupsize, use_exllama)
weight = GPTQWeight(
qweight=qweight,
qzeros=qzeros,
scales=scales,
g_idx=g_idx,
bits=bits,
groupsize=groupsize,
use_exllama=use_exllama,
)
else:
qkv_slice = weights._get_slice(f"{prefix}.Wqkv.weight")
q = qkv_slice[q_start:q_stop]

View File

@ -395,7 +395,7 @@ class FlashLlamaForCausalLM(torch.nn.Module):
self.lm_head = SpeculativeHead.load(
config,
prefix=suffix if not prefix else f"{prefix}.suffix",
prefix=suffix if not prefix else f"{prefix}.{suffix}",
weights=weights,
)

View File

@ -102,45 +102,6 @@ class MistralConfig(PretrainedConfig):
)
def load_attention(config, prefix, weights):
if config.num_attention_heads != config.num_key_value_heads:
return _load_gqa(config, prefix, weights)
else:
return TensorParallelColumnLinear.load_multi(
config,
prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
dim=0,
weights=weights,
bias=False,
)
def _load_gqa(config, prefix: str, weights):
assert config.hidden_size % config.num_attention_heads == 0
assert config.num_attention_heads % weights.process_group.size() == 0
weight = weights.get_multi_weights_col(
prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
quantize=config.quantize,
dim=0,
)
if config.quantize not in ["gptq", "awq"]:
weight = weight.to(dtype=weights.dtype).to(device=weights.device)
head_size = config.hidden_size // config.num_attention_heads
num_heads = config.num_attention_heads // weights.process_group.size()
num_key_value_heads = config.num_key_value_heads // weights.process_group.size()
assert list(weight.shape) == [
(num_heads + 2 * num_key_value_heads) * head_size,
config.hidden_size,
], f"{list(weight.shape)} != {[(num_heads + 2 * config.num_key_value_heads) * head_size, config.hidden_size]}"
return TensorParallelColumnLinear(
get_linear(weight, bias=None, quantize=config.quantize)
)
class MistralAttention(torch.nn.Module):
def __init__(
self,
@ -175,7 +136,13 @@ class MistralAttention(torch.nn.Module):
config.num_key_value_heads // weights.process_group.size()
)
self.query_key_value = load_attention(config, prefix, weights)
self.query_key_value = TensorParallelColumnLinear.load_multi(
config,
prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
dim=0,
weights=weights,
bias=False,
)
self.o_proj = TensorParallelRowLinear.load(
config,

View File

@ -5,6 +5,7 @@ from torch import nn
from transformers.activations import ACT2FN
from typing import Optional, List, Tuple
from text_generation_server.layers.gptq import GPTQWeight
from text_generation_server.utils import paged_attention, flash_attn
from text_generation_server.layers import (
TensorParallelRowLinear,
@ -90,8 +91,15 @@ def _load_multi_mqa_gptq(
from text_generation_server.layers.gptq import HAS_EXLLAMA
use_exllama = HAS_EXLLAMA
weight = (qweight, qzeros, scales, g_idx, bits, groupsize, use_exllama)
weight = GPTQWeight(
qweight=qweight,
qzeros=qzeros,
scales=scales,
g_idx=g_idx,
bits=bits,
groupsize=groupsize,
use_exllama=HAS_EXLLAMA,
)
if bias:
slice_ = weights._get_slice(f"{prefix}.c_attn.bias")

View File

@ -67,7 +67,7 @@ class FlashLlama(FlashCausalLM):
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
weights = Weights(filenames, device, dtype, process_group=self.process_group)
if config.quantize in ["gptq", "awq"]:
if config.quantize in ["gptq", "awq", "exl2"]:
weights._set_gptq_params(model_id, revision)
prefix = ""

View File

@ -89,7 +89,7 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
return generate_pb2.FilterBatchResponse(batch=filtered_batch.to_pb())
async def Warmup(self, request, context):
if self.quantize == "gptq":
if self.quantize in {"exl2", "gptq"}:
try:
# When using GPTQ, Exllama kernels need some global kernels
# For which we have the finale shapes only after the model has loaded

View File

@ -1,11 +1,14 @@
from dataclasses import dataclass, field
import os
from pathlib import Path
from typing import List, Dict, Optional, Tuple
from typing import List, Dict, Optional, Set, Tuple, Union
from safetensors import safe_open, SafetensorError
import torch
from loguru import logger
from huggingface_hub import hf_hub_download
import json
from text_generation_server.layers.exl2 import Exl2Weight
from text_generation_server.layers.gptq import GPTQWeight
from text_generation_server.utils.log import log_once
@ -76,8 +79,9 @@ class Weights:
f = self._get_handle(filename)
tensor = f.get_tensor(tensor_name)
# Special case for gptq which shouldn't convert
# u4 which are disguised as int32
if tensor.dtype not in [torch.int32, torch.int64]:
# u4 which are disguised as int32. Exl2 uses int16
# as well.
if tensor.dtype not in [torch.int16, torch.int32, torch.int64]:
tensor = tensor.to(dtype=self.dtype)
if to_device:
tensor = tensor.to(device=self.device)
@ -102,8 +106,8 @@ class Weights:
else:
raise NotImplementedError("Let's make that generic when needed")
# Special case for gptq which shouldn't convert
# u4 which are disguised as int32
if tensor.dtype != torch.int32:
# u4 which are disguised as int32. exl2 uses int16.
if tensor.dtype not in (torch.int16, torch.int32):
tensor = tensor.to(dtype=self.dtype)
tensor = tensor.to(device=self.device)
return tensor
@ -183,7 +187,15 @@ class Weights:
else:
g_idx = None
weight = (qweight, qzeros, scales, g_idx, bits, groupsize, False)
weight = GPTQWeight(
qweight=qweight,
qzeros=qzeros,
scales=scales,
g_idx=g_idx,
bits=bits,
groupsize=groupsize,
use_exllama=False,
)
else:
slice_ = self._get_slice(f"{prefix}.weight")
total_size = slice_.get_shape()[0]
@ -207,8 +219,34 @@ class Weights:
weight = weight.to(dtype=self.dtype)
return weight
def get_weights_col(self, prefix: str, quantize: str):
if quantize == "exl2":
try:
q_weight = self.get_tensor(f"{prefix}.q_weight")
except RuntimeError:
raise RuntimeError(
f"Cannot load `exl2`-quantized weight, make sure the model is already quantized."
)
q_scale = self.get_tensor(f"{prefix}.q_scale")
q_invperm = self.get_tensor(f"{prefix}.q_invperm")
q_scale_max = self.get_tensor(f"{prefix}.q_scale_max")
q_groups = self.get_tensor(f"{prefix}.q_groups")
return Exl2Weight(
q_weight=q_weight,
q_scale=q_scale,
q_invperm=q_invperm,
q_scale_max=q_scale_max,
q_groups=q_groups,
)
return self.get_multi_weights_col([prefix], quantize, 0)
def get_multi_weights_col(self, prefixes: List[str], quantize: str, dim: int):
if quantize in ["gptq", "awq"]:
if quantize == "exl2":
raise ValueError("get_multi_weights_col is not supported for exl2")
elif quantize in ["gptq", "awq"]:
try:
qweight = torch.cat(
[self.get_sharded(f"{p}.qweight", dim=1) for p in prefixes], dim=1
@ -259,7 +297,15 @@ class Weights:
else:
g_idx = None
weight = (qweight, qzeros, scales, g_idx, bits, groupsize, use_exllama)
weight = GPTQWeight(
qweight=qweight,
qzeros=qzeros,
scales=scales,
g_idx=g_idx,
bits=bits,
groupsize=groupsize,
use_exllama=use_exllama,
)
else:
w = [self.get_sharded(f"{p}.weight", dim=0) for p in prefixes]
weight = torch.cat(w, dim=dim)
@ -282,7 +328,28 @@ class Weights:
return tensor
def get_multi_weights_row(self, prefix: str, quantize: str):
if quantize == "gptq":
if quantize == "exl2":
try:
q_weight = self.get_tensor(f"{prefix}.q_weight")
except RuntimeError:
raise RuntimeError(
f"Cannot load `exl2`-quantized weight, make sure the model is already quantized."
)
q_scale = self.get_tensor(f"{prefix}.q_scale")
q_invperm = self.get_tensor(f"{prefix}.q_invperm")
q_scale_max = self.get_tensor(f"{prefix}.q_scale_max")
q_groups = self.get_tensor(f"{prefix}.q_groups")
return Exl2Weight(
q_weight=q_weight,
q_scale=q_scale,
q_invperm=q_invperm,
q_scale_max=q_scale_max,
q_groups=q_groups,
)
elif quantize == "gptq":
use_exllama = True
bits, groupsize, desc_act, quant_method = self._get_gptq_params()
@ -363,7 +430,15 @@ class Weights:
// groupsize
).to(dtype=torch.int32)
weight = (qweight, qzeros, scales, g_idx, bits, groupsize, use_exllama)
weight = GPTQWeight(
qweight=qweight,
qzeros=qzeros,
scales=scales,
g_idx=g_idx,
bits=bits,
groupsize=groupsize,
use_exllama=use_exllama,
)
elif quantize == "awq":
bits, groupsize, _, _ = self._get_gptq_params()
@ -379,7 +454,15 @@ class Weights:
g_idx = None
use_exllama = False
weight = (qweight, qzeros, scales, g_idx, bits, groupsize, use_exllama)
weight = GPTQWeight(
qweight=qweight,
qzeros=qzeros,
scales=scales,
g_idx=g_idx,
bits=bits,
groupsize=groupsize,
use_exllama=use_exllama,
)
else:
weight = self.get_sharded(f"{prefix}.weight", dim=1)
return weight