Add support for exl2 quantization
Mostly straightforward, changes to existing code: * Wrap quantizer parameters in a small wrapper to avoid passing around untyped tuples and needing to repack them as a dict. * Move scratch space computation to warmup, because we need the maximum input sequence length to avoid allocating huge scratch buffers that OOM.
This commit is contained in:
parent
cbced7f0f9
commit
36dd16017c
|
@ -62,6 +62,7 @@ Options:
|
|||
Possible values:
|
||||
- awq: 4 bit quantization. Requires a specific AWQ quantized model: <https://hf.co/models?search=awq>. Should replace GPTQ models wherever possible because of the better latency
|
||||
- eetq: 8 bit quantization, doesn't require specific model. Should be a drop-in replacement to bitsandbytes with much better performance. Kernels are from <https://github.com/NetEase-FuXi/EETQ.git>
|
||||
- exl2: Variable bit quantization. Requires a specific EXL2 quantized model: <https://hf.co/models?search=exl2>. Requires exllama2 kernels and does not support tensor parallelism (num_shard > 1)
|
||||
- gptq: 4 bit quantization. Requires a specific GTPQ quantized model: <https://hf.co/models?search=gptq>. text-generation-inference will use exllama (faster) kernels wherever possible, and use triton kernel (wider support) when it's not. AWQ has faster kernels
|
||||
- bitsandbytes: Bitsandbytes 8bit. Can be applied on any model, will cut the memory requirement in half, but it is known that the model will be much slower to run than the native f16
|
||||
- bitsandbytes-nf4: Bitsandbytes 4bit. Can be applied on any model, will cut the memory requirement by 4x, but it is known that the model will be much slower to run than the native f16
|
||||
|
|
|
@ -2,7 +2,6 @@
|
|||
|
||||
## What is Guidance?
|
||||
|
||||
|
||||
Guidance is a feature that allows users to constrain the generation of a large language model with a specified grammar. This feature is particularly useful when you want to generate text that follows a specific structure or uses a specific set of words or produce output in a specific format. A prominent example is JSON grammar, where the model is forced to output valid JSON.
|
||||
|
||||
## How is it used?
|
||||
|
|
|
@ -38,6 +38,7 @@ DOCKER_VOLUME = os.getenv("DOCKER_VOLUME", "/data")
|
|||
|
||||
class ResponseComparator(JSONSnapshotExtension):
|
||||
rtol = 0.2
|
||||
ignore_logprob = False
|
||||
|
||||
def serialize(
|
||||
self,
|
||||
|
@ -95,7 +96,10 @@ class ResponseComparator(JSONSnapshotExtension):
|
|||
return (
|
||||
token.id == other.id
|
||||
and token.text == other.text
|
||||
and math.isclose(token.logprob, other.logprob, rel_tol=self.rtol)
|
||||
and (
|
||||
self.ignore_logprob
|
||||
or math.isclose(token.logprob, other.logprob, rel_tol=self.rtol)
|
||||
)
|
||||
and token.special == other.special
|
||||
)
|
||||
|
||||
|
@ -105,8 +109,11 @@ class ResponseComparator(JSONSnapshotExtension):
|
|||
prefill_token.id == other.id
|
||||
and prefill_token.text == other.text
|
||||
and (
|
||||
math.isclose(
|
||||
prefill_token.logprob, other.logprob, rel_tol=self.rtol
|
||||
self.ignore_logprob
|
||||
or math.isclose(
|
||||
prefill_token.logprob,
|
||||
other.logprob,
|
||||
rel_tol=self.rtol,
|
||||
)
|
||||
if prefill_token.logprob is not None
|
||||
else prefill_token.logprob == other.logprob
|
||||
|
@ -223,6 +230,10 @@ class GenerousResponseComparator(ResponseComparator):
|
|||
rtol = 0.75
|
||||
|
||||
|
||||
class IgnoreLogProbResponseComparator(ResponseComparator):
|
||||
ignore_logprob = True
|
||||
|
||||
|
||||
class LauncherHandle:
|
||||
def __init__(self, port: int):
|
||||
self.client = AsyncClient(f"http://localhost:{port}")
|
||||
|
@ -274,6 +285,11 @@ def generous_response_snapshot(snapshot):
|
|||
return snapshot.use_extension(GenerousResponseComparator)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def ignore_logprob_response_snapshot(snapshot):
|
||||
return snapshot.use_extension(IgnoreLogProbResponseComparator)
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def event_loop():
|
||||
loop = asyncio.get_event_loop()
|
||||
|
|
|
@ -0,0 +1,84 @@
|
|||
{
|
||||
"details": {
|
||||
"best_of_sequences": null,
|
||||
"finish_reason": "length",
|
||||
"generated_tokens": 10,
|
||||
"prefill": [
|
||||
{
|
||||
"id": 2323,
|
||||
"logprob": null,
|
||||
"text": "Test"
|
||||
},
|
||||
{
|
||||
"id": 1715,
|
||||
"logprob": -11.4375,
|
||||
"text": " request"
|
||||
}
|
||||
],
|
||||
"seed": null,
|
||||
"tokens": [
|
||||
{
|
||||
"id": 25,
|
||||
"logprob": -2.9316406,
|
||||
"special": false,
|
||||
"text": ":"
|
||||
},
|
||||
{
|
||||
"id": 330,
|
||||
"logprob": -3.5136719,
|
||||
"special": false,
|
||||
"text": " \""
|
||||
},
|
||||
{
|
||||
"id": 489,
|
||||
"logprob": -0.7783203,
|
||||
"special": false,
|
||||
"text": " +"
|
||||
},
|
||||
{
|
||||
"id": 1715,
|
||||
"logprob": -1.2314453,
|
||||
"special": false,
|
||||
"text": " request"
|
||||
},
|
||||
{
|
||||
"id": 489,
|
||||
"logprob": -2.0019531,
|
||||
"special": false,
|
||||
"text": " +"
|
||||
},
|
||||
{
|
||||
"id": 2990,
|
||||
"logprob": -1.5009766,
|
||||
"special": false,
|
||||
"text": " \"\\"
|
||||
},
|
||||
{
|
||||
"id": 77,
|
||||
"logprob": -0.057434082,
|
||||
"special": false,
|
||||
"text": "n"
|
||||
},
|
||||
{
|
||||
"id": 702,
|
||||
"logprob": -1.4912109,
|
||||
"special": false,
|
||||
"text": "\"\n"
|
||||
},
|
||||
{
|
||||
"id": 262,
|
||||
"logprob": -1.2636719,
|
||||
"special": false,
|
||||
"text": " "
|
||||
},
|
||||
{
|
||||
"id": 557,
|
||||
"logprob": -2.4042969,
|
||||
"special": false,
|
||||
"text": " }\n\n"
|
||||
}
|
||||
],
|
||||
"top_tokens": null
|
||||
},
|
||||
"generated_text": ": \" + request + \"\\n\"\n }\n\n"
|
||||
}
|
|
@ -0,0 +1,84 @@
|
|||
{
|
||||
"details": {
|
||||
"best_of_sequences": null,
|
||||
"finish_reason": "length",
|
||||
"generated_tokens": 10,
|
||||
"prefill": [
|
||||
{
|
||||
"id": 2323,
|
||||
"logprob": null,
|
||||
"text": "Test"
|
||||
},
|
||||
{
|
||||
"id": 1715,
|
||||
"logprob": -11.453125,
|
||||
"text": " request"
|
||||
}
|
||||
],
|
||||
"seed": 0,
|
||||
"tokens": [
|
||||
{
|
||||
"id": 13,
|
||||
"logprob": -1.9980469,
|
||||
"special": false,
|
||||
"text": "."
|
||||
},
|
||||
{
|
||||
"id": 578,
|
||||
"logprob": -0.15795898,
|
||||
"special": false,
|
||||
"text": " The"
|
||||
},
|
||||
{
|
||||
"id": 3622,
|
||||
"logprob": -1.0458984,
|
||||
"special": false,
|
||||
"text": " server"
|
||||
},
|
||||
{
|
||||
"id": 31680,
|
||||
"logprob": -1.3623047,
|
||||
"special": false,
|
||||
"text": " responds"
|
||||
},
|
||||
{
|
||||
"id": 449,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": " with"
|
||||
},
|
||||
{
|
||||
"id": 264,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": " a"
|
||||
},
|
||||
{
|
||||
"id": 330,
|
||||
"logprob": -0.5678711,
|
||||
"special": false,
|
||||
"text": " \""
|
||||
},
|
||||
{
|
||||
"id": 1049,
|
||||
"logprob": -0.12322998,
|
||||
"special": false,
|
||||
"text": "200"
|
||||
},
|
||||
{
|
||||
"id": 10619,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": " OK"
|
||||
},
|
||||
{
|
||||
"id": 1,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": "\""
|
||||
}
|
||||
],
|
||||
"top_tokens": null
|
||||
},
|
||||
"generated_text": "Test request. The server responds with a \"200 OK\""
|
||||
}
|
|
@ -0,0 +1,338 @@
|
|||
[
|
||||
{
|
||||
"details": {
|
||||
"best_of_sequences": null,
|
||||
"finish_reason": "length",
|
||||
"generated_tokens": 10,
|
||||
"prefill": [
|
||||
{
|
||||
"id": 2323,
|
||||
"logprob": null,
|
||||
"text": "Test"
|
||||
},
|
||||
{
|
||||
"id": 1715,
|
||||
"logprob": -11.453125,
|
||||
"text": " request"
|
||||
}
|
||||
],
|
||||
"seed": null,
|
||||
"tokens": [
|
||||
{
|
||||
"id": 25,
|
||||
"logprob": -2.9785156,
|
||||
"special": false,
|
||||
"text": ":"
|
||||
},
|
||||
{
|
||||
"id": 330,
|
||||
"logprob": -3.4941406,
|
||||
"special": false,
|
||||
"text": " \""
|
||||
},
|
||||
{
|
||||
"id": 489,
|
||||
"logprob": -0.79345703,
|
||||
"special": false,
|
||||
"text": " +"
|
||||
},
|
||||
{
|
||||
"id": 1715,
|
||||
"logprob": -1.2324219,
|
||||
"special": false,
|
||||
"text": " request"
|
||||
},
|
||||
{
|
||||
"id": 489,
|
||||
"logprob": -1.9794922,
|
||||
"special": false,
|
||||
"text": " +"
|
||||
},
|
||||
{
|
||||
"id": 2990,
|
||||
"logprob": -1.4892578,
|
||||
"special": false,
|
||||
"text": " \"\\"
|
||||
},
|
||||
{
|
||||
"id": 77,
|
||||
"logprob": -0.058258057,
|
||||
"special": false,
|
||||
"text": "n"
|
||||
},
|
||||
{
|
||||
"id": 702,
|
||||
"logprob": -1.4892578,
|
||||
"special": false,
|
||||
"text": "\"\n"
|
||||
},
|
||||
{
|
||||
"id": 262,
|
||||
"logprob": -1.2783203,
|
||||
"special": false,
|
||||
"text": " "
|
||||
},
|
||||
{
|
||||
"id": 557,
|
||||
"logprob": -2.3945312,
|
||||
"special": false,
|
||||
"text": " }\n\n"
|
||||
}
|
||||
],
|
||||
"top_tokens": null
|
||||
},
|
||||
"generated_text": ": \" + request + \"\\n\"\n }\n\n"
|
||||
},
|
||||
{
|
||||
"details": {
|
||||
"best_of_sequences": null,
|
||||
"finish_reason": "length",
|
||||
"generated_tokens": 10,
|
||||
"prefill": [
|
||||
{
|
||||
"id": 2323,
|
||||
"logprob": null,
|
||||
"text": "Test"
|
||||
},
|
||||
{
|
||||
"id": 1715,
|
||||
"logprob": -11.40625,
|
||||
"text": " request"
|
||||
}
|
||||
],
|
||||
"seed": null,
|
||||
"tokens": [
|
||||
{
|
||||
"id": 25,
|
||||
"logprob": -2.9433594,
|
||||
"special": false,
|
||||
"text": ":"
|
||||
},
|
||||
{
|
||||
"id": 330,
|
||||
"logprob": -3.4726562,
|
||||
"special": false,
|
||||
"text": " \""
|
||||
},
|
||||
{
|
||||
"id": 489,
|
||||
"logprob": -0.8022461,
|
||||
"special": false,
|
||||
"text": " +"
|
||||
},
|
||||
{
|
||||
"id": 1715,
|
||||
"logprob": -1.2509766,
|
||||
"special": false,
|
||||
"text": " request"
|
||||
},
|
||||
{
|
||||
"id": 489,
|
||||
"logprob": -1.984375,
|
||||
"special": false,
|
||||
"text": " +"
|
||||
},
|
||||
{
|
||||
"id": 2990,
|
||||
"logprob": -1.4677734,
|
||||
"special": false,
|
||||
"text": " \"\\"
|
||||
},
|
||||
{
|
||||
"id": 77,
|
||||
"logprob": -0.059173584,
|
||||
"special": false,
|
||||
"text": "n"
|
||||
},
|
||||
{
|
||||
"id": 702,
|
||||
"logprob": -1.4990234,
|
||||
"special": false,
|
||||
"text": "\"\n"
|
||||
},
|
||||
{
|
||||
"id": 262,
|
||||
"logprob": -1.2822266,
|
||||
"special": false,
|
||||
"text": " "
|
||||
},
|
||||
{
|
||||
"id": 557,
|
||||
"logprob": -2.3867188,
|
||||
"special": false,
|
||||
"text": " }\n\n"
|
||||
}
|
||||
],
|
||||
"top_tokens": null
|
||||
},
|
||||
"generated_text": ": \" + request + \"\\n\"\n }\n\n"
|
||||
},
|
||||
{
|
||||
"details": {
|
||||
"best_of_sequences": null,
|
||||
"finish_reason": "length",
|
||||
"generated_tokens": 10,
|
||||
"prefill": [
|
||||
{
|
||||
"id": 2323,
|
||||
"logprob": null,
|
||||
"text": "Test"
|
||||
},
|
||||
{
|
||||
"id": 1715,
|
||||
"logprob": -11.421875,
|
||||
"text": " request"
|
||||
}
|
||||
],
|
||||
"seed": null,
|
||||
"tokens": [
|
||||
{
|
||||
"id": 25,
|
||||
"logprob": -2.9511719,
|
||||
"special": false,
|
||||
"text": ":"
|
||||
},
|
||||
{
|
||||
"id": 330,
|
||||
"logprob": -3.46875,
|
||||
"special": false,
|
||||
"text": " \""
|
||||
},
|
||||
{
|
||||
"id": 489,
|
||||
"logprob": -0.77490234,
|
||||
"special": false,
|
||||
"text": " +"
|
||||
},
|
||||
{
|
||||
"id": 1715,
|
||||
"logprob": -1.2558594,
|
||||
"special": false,
|
||||
"text": " request"
|
||||
},
|
||||
{
|
||||
"id": 489,
|
||||
"logprob": -1.984375,
|
||||
"special": false,
|
||||
"text": " +"
|
||||
},
|
||||
{
|
||||
"id": 2990,
|
||||
"logprob": -1.4990234,
|
||||
"special": false,
|
||||
"text": " \"\\"
|
||||
},
|
||||
{
|
||||
"id": 77,
|
||||
"logprob": -0.059143066,
|
||||
"special": false,
|
||||
"text": "n"
|
||||
},
|
||||
{
|
||||
"id": 702,
|
||||
"logprob": -1.4941406,
|
||||
"special": false,
|
||||
"text": "\"\n"
|
||||
},
|
||||
{
|
||||
"id": 262,
|
||||
"logprob": -1.2578125,
|
||||
"special": false,
|
||||
"text": " "
|
||||
},
|
||||
{
|
||||
"id": 557,
|
||||
"logprob": -2.3964844,
|
||||
"special": false,
|
||||
"text": " }\n\n"
|
||||
}
|
||||
],
|
||||
"top_tokens": null
|
||||
},
|
||||
"generated_text": ": \" + request + \"\\n\"\n }\n\n"
|
||||
},
|
||||
{
|
||||
"details": {
|
||||
"best_of_sequences": null,
|
||||
"finish_reason": "length",
|
||||
"generated_tokens": 10,
|
||||
"prefill": [
|
||||
{
|
||||
"id": 2323,
|
||||
"logprob": null,
|
||||
"text": "Test"
|
||||
},
|
||||
{
|
||||
"id": 1715,
|
||||
"logprob": -11.4140625,
|
||||
"text": " request"
|
||||
}
|
||||
],
|
||||
"seed": null,
|
||||
"tokens": [
|
||||
{
|
||||
"id": 25,
|
||||
"logprob": -2.9101562,
|
||||
"special": false,
|
||||
"text": ":"
|
||||
},
|
||||
{
|
||||
"id": 330,
|
||||
"logprob": -3.5039062,
|
||||
"special": false,
|
||||
"text": " \""
|
||||
},
|
||||
{
|
||||
"id": 489,
|
||||
"logprob": -0.8076172,
|
||||
"special": false,
|
||||
"text": " +"
|
||||
},
|
||||
{
|
||||
"id": 1715,
|
||||
"logprob": -1.2236328,
|
||||
"special": false,
|
||||
"text": " request"
|
||||
},
|
||||
{
|
||||
"id": 489,
|
||||
"logprob": -1.9853516,
|
||||
"special": false,
|
||||
"text": " +"
|
||||
},
|
||||
{
|
||||
"id": 2990,
|
||||
"logprob": -1.4892578,
|
||||
"special": false,
|
||||
"text": " \"\\"
|
||||
},
|
||||
{
|
||||
"id": 77,
|
||||
"logprob": -0.056671143,
|
||||
"special": false,
|
||||
"text": "n"
|
||||
},
|
||||
{
|
||||
"id": 702,
|
||||
"logprob": -1.5107422,
|
||||
"special": false,
|
||||
"text": "\"\n"
|
||||
},
|
||||
{
|
||||
"id": 262,
|
||||
"logprob": -1.2597656,
|
||||
"special": false,
|
||||
"text": " "
|
||||
},
|
||||
{
|
||||
"id": 557,
|
||||
"logprob": -2.4042969,
|
||||
"special": false,
|
||||
"text": " }\n\n"
|
||||
}
|
||||
],
|
||||
"top_tokens": null
|
||||
},
|
||||
"generated_text": ": \" + request + \"\\n\"\n }\n\n"
|
||||
}
|
||||
]
|
|
@ -0,0 +1,73 @@
|
|||
import pytest
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def flash_llama_exl2_handle(launcher):
|
||||
with launcher(
|
||||
"turboderp/Llama-3-8B-Instruct-exl2",
|
||||
revision="2.5bpw",
|
||||
# Set max input length to avoid OOM due to extremely large
|
||||
# scratch buffer.
|
||||
max_input_length=1024,
|
||||
num_shard=1,
|
||||
quantize="exl2",
|
||||
) as handle:
|
||||
yield handle
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
async def flash_llama_exl2(flash_llama_exl2_handle):
|
||||
await flash_llama_exl2_handle.health(300)
|
||||
return flash_llama_exl2_handle.client
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.private
|
||||
async def test_flash_llama_exl2(flash_llama_exl2, ignore_logprob_response_snapshot):
|
||||
response = await flash_llama_exl2.generate(
|
||||
"Test request", max_new_tokens=10, decoder_input_details=True
|
||||
)
|
||||
|
||||
assert response.details.generated_tokens == 10
|
||||
assert response == ignore_logprob_response_snapshot
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.private
|
||||
async def test_flash_llama_exl2_all_params(
|
||||
flash_llama_exl2, ignore_logprob_response_snapshot
|
||||
):
|
||||
response = await flash_llama_exl2.generate(
|
||||
"Test request",
|
||||
max_new_tokens=10,
|
||||
repetition_penalty=1.2,
|
||||
return_full_text=True,
|
||||
temperature=0.5,
|
||||
top_p=0.9,
|
||||
top_k=10,
|
||||
truncate=5,
|
||||
typical_p=0.9,
|
||||
watermark=True,
|
||||
decoder_input_details=True,
|
||||
seed=0,
|
||||
)
|
||||
|
||||
assert (
|
||||
response.generated_text == 'Test request. The server responds with a "200 OK"'
|
||||
)
|
||||
assert response == ignore_logprob_response_snapshot
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.private
|
||||
async def test_flash_llama_exl2_load(
|
||||
flash_llama_exl2, generate_load, ignore_logprob_response_snapshot
|
||||
):
|
||||
responses = await generate_load(
|
||||
flash_llama_exl2, "Test request", max_new_tokens=10, n=4
|
||||
)
|
||||
|
||||
assert len(responses) == 4
|
||||
assert all([r.generated_text == responses[0].generated_text for r in responses])
|
||||
|
||||
assert responses == ignore_logprob_response_snapshot
|
|
@ -55,6 +55,10 @@ enum Quantization {
|
|||
/// Should be a drop-in replacement to bitsandbytes with much better performance.
|
||||
/// Kernels are from <https://github.com/NetEase-FuXi/EETQ.git>
|
||||
Eetq,
|
||||
/// Variable bit quantization. Requires a specific EXL2 quantized model:
|
||||
/// <https://hf.co/models?search=exl2>. Requires exllama2 kernels and does
|
||||
/// not support tensor parallelism (num_shard > 1).
|
||||
Exl2,
|
||||
/// 4 bit quantization. Requires a specific GTPQ quantized model: <https://hf.co/models?search=gptq>.
|
||||
/// text-generation-inference will use exllama (faster) kernels wherever possible, and use
|
||||
/// triton kernel (wider support) when it's not.
|
||||
|
@ -95,6 +99,9 @@ impl std::fmt::Display for Quantization {
|
|||
Quantization::BitsandbytesFP4 => {
|
||||
write!(f, "bitsandbytes-fp4")
|
||||
}
|
||||
Quantization::Exl2 => {
|
||||
write!(f, "exl2")
|
||||
}
|
||||
Quantization::Gptq => {
|
||||
write!(f, "gptq")
|
||||
}
|
||||
|
@ -1461,6 +1468,11 @@ fn main() -> Result<(), LauncherError> {
|
|||
|
||||
let num_shard = find_num_shards(args.sharded, args.num_shard)?;
|
||||
if num_shard > 1 {
|
||||
if matches!(args.quantize, Some(Quantization::Exl2)) {
|
||||
return Err(LauncherError::ArgumentValidation(
|
||||
"Sharding is currently not supported with `exl2` quantization".into(),
|
||||
));
|
||||
}
|
||||
tracing::info!("Sharding model on {num_shard} processes");
|
||||
}
|
||||
|
||||
|
|
|
@ -19,6 +19,7 @@ class Quantization(str, Enum):
|
|||
gptq = "gptq"
|
||||
awq = "awq"
|
||||
eetq = "eetq"
|
||||
exl2 = "exl2"
|
||||
fp8 = "fp8"
|
||||
|
||||
|
||||
|
|
|
@ -0,0 +1,23 @@
|
|||
import torch
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
@dataclass
|
||||
class Exl2Weight:
|
||||
"""
|
||||
Exllama2 exl2 quantized weights.
|
||||
"""
|
||||
|
||||
q_weight: torch.Tensor
|
||||
q_scale: torch.Tensor
|
||||
q_invperm: torch.Tensor
|
||||
q_scale_max: torch.Tensor
|
||||
q_groups: torch.Tensor
|
||||
|
||||
def __post_init__(self):
|
||||
self.q_scale_max /= 256
|
||||
self.q_invperm = self.q_invperm.short()
|
||||
|
||||
@property
|
||||
def device(self) -> torch.device:
|
||||
return self.q_weight.device
|
|
@ -1,9 +1,31 @@
|
|||
from dataclasses import dataclass
|
||||
import os
|
||||
from typing import Optional
|
||||
import torch
|
||||
from text_generation_server.utils.import_utils import (
|
||||
SYSTEM,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class GPTQWeight:
|
||||
qweight: torch.Tensor
|
||||
qzeros: torch.Tensor
|
||||
scales: torch.Tensor
|
||||
g_idx: Optional[torch.Tensor]
|
||||
bits: int
|
||||
groupsize: int
|
||||
use_exllama: bool
|
||||
|
||||
def __post_init__(self):
|
||||
if self.scales.dtype == torch.float:
|
||||
self.scales = self.scales.half()
|
||||
|
||||
@property
|
||||
def device(self) -> torch.device:
|
||||
return self.qweight.device
|
||||
|
||||
|
||||
try:
|
||||
major, _minor = torch.cuda.get_device_capability()
|
||||
except Exception:
|
||||
|
|
|
@ -1,3 +1,4 @@
|
|||
from text_generation_server.utils.weights import GPTQWeight
|
||||
import torch
|
||||
from exllama_kernels import make_q4, q4_matmul, prepare_buffers, set_tuning_params
|
||||
|
||||
|
@ -65,24 +66,25 @@ def create_exllama_buffers(max_total_tokens: int):
|
|||
class Ex4bitLinear(torch.nn.Module):
|
||||
"""Linear layer implementation with per-group 4-bit quantization of the weights"""
|
||||
|
||||
def __init__(self, qweight, qzeros, scales, g_idx, bias, bits, groupsize):
|
||||
def __init__(self, weight: GPTQWeight, bias):
|
||||
super().__init__()
|
||||
global MAX_DQ, MAX_INNER, ACT_ORDER, DEVICE
|
||||
assert bits == 4
|
||||
assert weight.bits == 4
|
||||
|
||||
self.device = qweight.device
|
||||
self.qweight = qweight
|
||||
self.qzeros = qzeros
|
||||
self.scales = scales
|
||||
self.g_idx = g_idx.cpu() if g_idx is not None else None
|
||||
self.device = weight.qweight.device
|
||||
self.qweight = weight.qweight
|
||||
self.qzeros = weight.qzeros
|
||||
self.scales = weight.scales
|
||||
self.g_idx = weight.g_idx.cpu() if weight.g_idx is not None else None
|
||||
self.bias = bias if bias is not None else None
|
||||
|
||||
if self.g_idx is not None and (
|
||||
(self.g_idx == 0).all()
|
||||
or torch.equal(
|
||||
g_idx.cpu(),
|
||||
weight.g_idx.cpu(),
|
||||
torch.tensor(
|
||||
[i // groupsize for i in range(g_idx.shape[0])], dtype=torch.int32
|
||||
[i // weight.groupsize for i in range(weight.g_idx.shape[0])],
|
||||
dtype=torch.int32,
|
||||
),
|
||||
)
|
||||
):
|
||||
|
@ -96,8 +98,8 @@ class Ex4bitLinear(torch.nn.Module):
|
|||
self.qweight, self.qzeros, self.scales, self.g_idx, self.device.index
|
||||
)
|
||||
|
||||
self.height = qweight.shape[0] * 8
|
||||
self.width = qweight.shape[1]
|
||||
self.height = weight.qweight.shape[0] * 8
|
||||
self.width = weight.qweight.shape[1]
|
||||
|
||||
# Infer groupsize from height of qzeros
|
||||
self.groupsize = None
|
||||
|
@ -105,7 +107,7 @@ class Ex4bitLinear(torch.nn.Module):
|
|||
self.groupsize = (self.qweight.shape[0] * 8) // (self.qzeros.shape[0])
|
||||
|
||||
if self.groupsize is not None:
|
||||
assert groupsize == self.groupsize
|
||||
assert weight.groupsize == self.groupsize
|
||||
|
||||
# Handle act-order matrix
|
||||
if self.g_idx is not None:
|
||||
|
|
|
@ -1,10 +1,15 @@
|
|||
# Adapted from turboderp exllama: https://github.com/turboderp/exllamav2
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from text_generation_server.layers.exl2 import Exl2Weight
|
||||
from text_generation_server.layers.gptq import GPTQWeight
|
||||
|
||||
try:
|
||||
from exllamav2_kernels import make_q_matrix, gemm_half_q_half
|
||||
except ImportError:
|
||||
|
@ -15,6 +20,15 @@ except ImportError:
|
|||
none_tensor = torch.empty((1, 1), device="meta")
|
||||
|
||||
|
||||
@dataclass
|
||||
class _ExtraTensors:
|
||||
"""Additional generated quantizer tensors."""
|
||||
|
||||
q_group_map: Optional[torch.Tensor] = None
|
||||
q_invperm: Optional[torch.Tensor] = None
|
||||
q_perm: Optional[torch.Tensor] = None
|
||||
|
||||
|
||||
def ext_gemm_half_q_half(x, q_handle, q4_width, force_cuda):
|
||||
"""Matrix multiplication, returns x @ q4"""
|
||||
output_shape = x.shape[:-1] + (q4_width,)
|
||||
|
@ -24,11 +38,7 @@ def ext_gemm_half_q_half(x, q_handle, q4_width, force_cuda):
|
|||
return output.view(output_shape)
|
||||
|
||||
|
||||
# Group map needed for irregular group sizes
|
||||
|
||||
|
||||
def make_group_map(q_groups, num_qrows):
|
||||
|
||||
def make_group_map(q_groups: torch.Tensor, num_qrows: int):
|
||||
gr = q_groups.tolist()
|
||||
group_map = []
|
||||
num_groups = len(gr) // 2
|
||||
|
@ -50,72 +60,72 @@ def make_group_map(q_groups, num_qrows):
|
|||
# Create Q matrix
|
||||
|
||||
|
||||
def ext_make_q_matrix(w: dict, temp_dq, key: str = None):
|
||||
def ext_make_q_matrix(
|
||||
w: Exl2Weight | GPTQWeight,
|
||||
extra: _ExtraTensors,
|
||||
temp_dq,
|
||||
key: Optional[str] = None,
|
||||
):
|
||||
"""
|
||||
Create Q matrix
|
||||
"""
|
||||
# EXL2
|
||||
# won't work as the moment because the tensors are not the same.
|
||||
if "q_weight" in w:
|
||||
w["q_scale_max"] /= 256
|
||||
w["q_perm"] = w["q_perm"].short()
|
||||
w["q_invperm"] = w["q_invperm"].short()
|
||||
|
||||
if "q_group_map" not in w:
|
||||
w["q_group_map"] = make_group_map(w["q_groups"], w["q_weight"].shape[0])
|
||||
if isinstance(w, Exl2Weight):
|
||||
extra.q_group_map = make_group_map(w.q_groups, w.q_weight.shape[0])
|
||||
extra.q_perm = torch.argsort(w.q_invperm).short()
|
||||
|
||||
return make_q_matrix(
|
||||
w["q_weight"],
|
||||
w["q_perm"],
|
||||
w["q_invperm"],
|
||||
w["q_scale"],
|
||||
w["q_scale_max"],
|
||||
w["q_groups"],
|
||||
w["q_group_map"],
|
||||
w.q_weight,
|
||||
extra.q_perm,
|
||||
w.q_invperm,
|
||||
w.q_scale,
|
||||
w.q_scale_max,
|
||||
w.q_groups,
|
||||
extra.q_group_map,
|
||||
none_tensor,
|
||||
none_tensor,
|
||||
none_tensor,
|
||||
temp_dq,
|
||||
)
|
||||
# GPTQ
|
||||
elif "qweight" in w:
|
||||
if w["scales"].dtype == torch.float:
|
||||
w["scales"] = w["scales"].half()
|
||||
elif isinstance(w, GPTQWeight):
|
||||
if w.scales.dtype == torch.float:
|
||||
w.scales = w.scales.half()
|
||||
|
||||
# GPTQ with g_idx (act_order)
|
||||
if w.get("g_idx", None) is not None and not (w["g_idx"] == 0).all().item():
|
||||
w["q_perm"] = torch.empty(
|
||||
(w["qweight"].shape[0] * 8,),
|
||||
if w.g_idx is not None and not (w.g_idx == 0).all().item():
|
||||
extra.q_perm = torch.empty(
|
||||
(w.qweight.shape[0] * 8,),
|
||||
dtype=torch.short,
|
||||
device=w["qweight"].device,
|
||||
device=w.qweight.device,
|
||||
)
|
||||
w["q_invperm"] = torch.empty_like(w["q_perm"])
|
||||
extra.q_invperm = torch.empty_like(extra.q_perm)
|
||||
# make_q4 segfaults if g_idx is not on cpu in the act-order case. In the non act-order case, None needs to be passed for g_idx.
|
||||
return make_q_matrix(
|
||||
w["qweight"],
|
||||
w["q_perm"],
|
||||
w["q_invperm"],
|
||||
w.qweight,
|
||||
extra.q_perm,
|
||||
extra.q_invperm,
|
||||
none_tensor,
|
||||
none_tensor,
|
||||
none_tensor,
|
||||
none_tensor,
|
||||
w["qzeros"],
|
||||
w["scales"],
|
||||
w["g_idx"].cpu(),
|
||||
w.qzeros,
|
||||
w.scales,
|
||||
w.g_idx.cpu(),
|
||||
temp_dq,
|
||||
)
|
||||
# GPTQ without g_idx
|
||||
else:
|
||||
return make_q_matrix(
|
||||
w["qweight"],
|
||||
w.qweight,
|
||||
none_tensor,
|
||||
none_tensor,
|
||||
none_tensor,
|
||||
none_tensor,
|
||||
none_tensor,
|
||||
none_tensor,
|
||||
w["qzeros"],
|
||||
w["scales"],
|
||||
w.qzeros,
|
||||
w.scales,
|
||||
none_tensor,
|
||||
temp_dq,
|
||||
)
|
||||
|
@ -124,7 +134,6 @@ def ext_make_q_matrix(w: dict, temp_dq, key: str = None):
|
|||
|
||||
|
||||
DEVICE = None
|
||||
FIXED_BYTES = 0
|
||||
LAYERS = []
|
||||
|
||||
|
||||
|
@ -134,8 +143,13 @@ def set_device(device):
|
|||
|
||||
|
||||
def create_exllama_buffers(max_total_tokens: int):
|
||||
global FIXED_BYTES, LAYERS, DEVICE
|
||||
temp_dq = ExLlamaV2DeviceTensors(DEVICE, FIXED_BYTES)
|
||||
global LAYERS, DEVICE
|
||||
|
||||
# Find the size of the scratch space.
|
||||
scratch_bytes = max(
|
||||
layer.scratch_space_fixed(max_input_len=max_total_tokens) for layer in LAYERS
|
||||
)
|
||||
temp_dq = ExLlamaV2DeviceTensors(DEVICE, scratch_bytes)
|
||||
|
||||
for layer in LAYERS:
|
||||
layer.post_init(temp_dq)
|
||||
|
@ -146,49 +160,48 @@ class QuantLinear(nn.Module):
|
|||
|
||||
"""Linear layer implementation with per-group 4-bit quantization of the weights"""
|
||||
|
||||
# def __init__(self, bits, group_size, infeatures, outfeatures, bias, trainable=False, **kwargs):
|
||||
def __init__(self, qweight, qzeros, scales, g_idx, bias, bits, groupsize):
|
||||
def __init__(
|
||||
self,
|
||||
weight: Exl2Weight | GPTQWeight,
|
||||
bias: torch.Tensor,
|
||||
):
|
||||
super().__init__()
|
||||
if bits != 4:
|
||||
raise ValueError(
|
||||
f"Exllamav2 kernel supports only bits=4, requested bits={bits}. Something is wrong in the model initialization."
|
||||
)
|
||||
|
||||
self.q_handle = None
|
||||
self.q_tensors = None
|
||||
self.bits = bits
|
||||
self.maxq = 2**self.bits - 1
|
||||
self.infeatures = qweight.shape[0] // self.bits * 32
|
||||
self.outfeatures = qweight.shape[1]
|
||||
self.q_tensors = weight
|
||||
self.extra_tensors = _ExtraTensors()
|
||||
|
||||
if isinstance(weight, Exl2Weight):
|
||||
self.infeatures = weight.q_invperm.shape[0]
|
||||
self.outfeatures = weight.q_weight.shape[1]
|
||||
elif isinstance(weight, GPTQWeight):
|
||||
if weight.bits != 4:
|
||||
raise ValueError(
|
||||
f"Exllamav2 kernel supports only bits=4, requested bits={weight.bits}. Something is wrong in the model initialization."
|
||||
)
|
||||
|
||||
self.infeatures = weight.qweight.shape[0] // weight.bits * 32
|
||||
self.outfeatures = weight.qweight.shape[1]
|
||||
|
||||
self.padding = -self.outfeatures % 32
|
||||
self.outfeatures = self.outfeatures + self.padding
|
||||
|
||||
self.device = qweight.device
|
||||
self.qweight = qweight
|
||||
self.qzeros = qzeros
|
||||
self.scales = scales
|
||||
self.g_idx = g_idx
|
||||
self.device = weight.device
|
||||
self.bias = bias if bias is not None else None
|
||||
self.group_size = groupsize
|
||||
|
||||
global FIXED_BYTES, LAYERS
|
||||
FIXED_BYTES = max(FIXED_BYTES, self.scratch_space_fixed())
|
||||
global LAYERS
|
||||
LAYERS.append(self)
|
||||
|
||||
def post_init(self, temp_dq):
|
||||
assert self.qweight.device.type == "cuda"
|
||||
assert self.qweight.device.index is not None
|
||||
self.q_tensors = {
|
||||
"qweight": self.qweight,
|
||||
"qzeros": self.qzeros,
|
||||
"scales": self.scales,
|
||||
"g_idx": self.g_idx,
|
||||
}
|
||||
device = self.q_tensors.device
|
||||
assert device.type == "cuda"
|
||||
assert device.index is not None
|
||||
temp_dq = temp_dq.get_scratch_slice(self.temp_dq_size())
|
||||
|
||||
# We NEED to keep a pointer on Python side, otherwise the garbage collector will mess with us,
|
||||
# and `Memory access fault by GPU node-2` will EAT you.
|
||||
self.temp_dq = temp_dq
|
||||
self.q_handle = ext_make_q_matrix(self.q_tensors, temp_dq)
|
||||
self.q_handle = ext_make_q_matrix(self.q_tensors, self.extra_tensors, temp_dq)
|
||||
|
||||
def forward(self, x, force_cuda=False):
|
||||
output = ext_gemm_half_q_half(x, self.q_handle, self.outfeatures, force_cuda)
|
||||
|
|
|
@ -1,6 +1,9 @@
|
|||
from typing import Optional
|
||||
import torch
|
||||
from torch.nn import functional as F
|
||||
from text_generation_server.utils.import_utils import SYSTEM
|
||||
from text_generation_server.layers.exl2 import Exl2Weight
|
||||
from text_generation_server.layers.gptq import GPTQWeight
|
||||
|
||||
if SYSTEM == "rocm":
|
||||
try:
|
||||
|
@ -151,15 +154,23 @@ def get_linear(weight, bias, quantize):
|
|||
bias,
|
||||
quant_type="nf4",
|
||||
)
|
||||
elif quantize == "exl2":
|
||||
if not isinstance(weight, Exl2Weight):
|
||||
raise NotImplementedError(
|
||||
f"The passed weight is not `exl2` compatible, loader needs to be updated."
|
||||
)
|
||||
|
||||
from text_generation_server.layers.gptq import ExllamaQuantLinear
|
||||
|
||||
linear = ExllamaQuantLinear(weight, bias)
|
||||
|
||||
elif quantize == "gptq":
|
||||
try:
|
||||
qweight, qzeros, scales, g_idx, bits, groupsize, use_exllama = weight
|
||||
except Exception:
|
||||
if not isinstance(weight, GPTQWeight):
|
||||
raise NotImplementedError(
|
||||
f"The passed weight is not `gptq` compatible, loader needs to be updated."
|
||||
)
|
||||
|
||||
if use_exllama:
|
||||
if weight.use_exllama:
|
||||
try:
|
||||
from text_generation_server.layers.gptq import (
|
||||
ExllamaQuantLinear,
|
||||
|
@ -169,25 +180,21 @@ def get_linear(weight, bias, quantize):
|
|||
f"Exllama gptq kernels are not installed. Install them `cd server/exllama_kernels && python setup.py install && cd ../exllamav2_kernels && python setup.py install`"
|
||||
)
|
||||
|
||||
linear = ExllamaQuantLinear(
|
||||
qweight, qzeros, scales, g_idx, bias, bits, groupsize
|
||||
)
|
||||
linear = ExllamaQuantLinear(weight, bias)
|
||||
else:
|
||||
from text_generation_server.layers.gptq.quant_linear import QuantLinear
|
||||
|
||||
linear = QuantLinear(
|
||||
qweight,
|
||||
qzeros,
|
||||
scales,
|
||||
g_idx,
|
||||
weight.qweight,
|
||||
weight.qzeros,
|
||||
weight.scales,
|
||||
weight.g_idx,
|
||||
bias,
|
||||
bits,
|
||||
groupsize,
|
||||
weight.bits,
|
||||
weight.groupsize,
|
||||
)
|
||||
elif quantize == "awq":
|
||||
try:
|
||||
qweight, qzeros, scales, _, bits, groupsize, _ = weight
|
||||
except Exception:
|
||||
if not isinstance(weight, GPTQWeight):
|
||||
raise NotImplementedError(
|
||||
f"The passed weight is not `awq` compatible, loader needs to be updated."
|
||||
)
|
||||
|
@ -200,11 +207,11 @@ def get_linear(weight, bias, quantize):
|
|||
from text_generation_server.layers.awq.quantize.qmodule import WQLinear
|
||||
|
||||
linear = WQLinear(
|
||||
w_bit=bits,
|
||||
group_size=groupsize,
|
||||
qweight=qweight,
|
||||
qzeros=qzeros,
|
||||
scales=scales,
|
||||
w_bit=weight.bits,
|
||||
group_size=weight.groupsize,
|
||||
qweight=weight.qweight,
|
||||
qzeros=weight.qzeros,
|
||||
scales=weight.scales,
|
||||
bias=bias is not None,
|
||||
)
|
||||
except ImportError:
|
||||
|
|
|
@ -1,7 +1,27 @@
|
|||
import torch
|
||||
from torch.nn import functional as F
|
||||
from typing import List
|
||||
from typing import Iterable, List
|
||||
from text_generation_server.layers.linear import get_linear, FastLinear
|
||||
from text_generation_server.layers.exl2 import Exl2Weight
|
||||
|
||||
|
||||
class LayerConcat(torch.nn.Module):
|
||||
"""
|
||||
Apply multiple layers to the input and concatenate their
|
||||
outputs.
|
||||
"""
|
||||
|
||||
def __init__(self, layers: Iterable[torch.nn.Module], dim: int = -1):
|
||||
"""
|
||||
`dim` is the dimension along which layer outputs are concatenated.
|
||||
"""
|
||||
super().__init__()
|
||||
self.layers = layers
|
||||
self.dim = dim
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
outputs = [layer(x) for layer in self.layers]
|
||||
return torch.cat(outputs, self.dim)
|
||||
|
||||
|
||||
class SuperLayer(torch.nn.Module):
|
||||
|
@ -21,7 +41,16 @@ class TensorParallelHead(SuperLayer):
|
|||
|
||||
@staticmethod
|
||||
def load(config, prefix: str, weights):
|
||||
if weights.process_group.size() > 1:
|
||||
if config.quantize == "exl2":
|
||||
try:
|
||||
# If the piece and LM head embeddings are shared, we have
|
||||
# non-quantized weights...
|
||||
weight = weights.get_tensor(f"{prefix}.weight")
|
||||
except:
|
||||
# ...otherwise they are quantized.
|
||||
weight = weights.get_weights_col(prefix, config.quantize)
|
||||
should_gather = weights.process_group.size() > 1
|
||||
elif weights.process_group.size() > 1:
|
||||
try:
|
||||
weight = weights.get_sharded(f"{prefix}.weight", dim=0)
|
||||
should_gather = True
|
||||
|
@ -37,8 +66,12 @@ class TensorParallelHead(SuperLayer):
|
|||
# GPTQ,AWQ,EETQ don't quantize heads (nor embeddings)
|
||||
if config.quantize in ["gptq", "awq", "eetq"]:
|
||||
quantize = None
|
||||
# See above, exl2 LM head can be quantized or not.
|
||||
elif config.quantize == "exl2" and not isinstance(weight, Exl2Weight):
|
||||
quantize = None
|
||||
else:
|
||||
quantize = config.quantize
|
||||
|
||||
return TensorParallelHead(
|
||||
get_linear(weight, bias=None, quantize=quantize),
|
||||
process_group=weights.process_group,
|
||||
|
@ -108,14 +141,27 @@ class TensorParallelColumnLinear(SuperLayer):
|
|||
|
||||
@classmethod
|
||||
def load(cls, config, prefix: str, weights, bias: bool):
|
||||
return cls.load_multi(config, [prefix], weights, bias, dim=0)
|
||||
weight = weights.get_weights_col(prefix, config.quantize)
|
||||
if bias:
|
||||
bias = weights.get_sharded(f"{prefix}.bias", dim=0)
|
||||
else:
|
||||
bias = None
|
||||
linear = get_linear(weight, bias, config.quantize)
|
||||
return cls(linear)
|
||||
|
||||
@classmethod
|
||||
def load_multi(cls, config, prefixes: List[str], weights, bias: bool, dim: int):
|
||||
if config.quantize == "exl2":
|
||||
linears = []
|
||||
for prefix in prefixes:
|
||||
weight = weights.get_weights_col(prefix, config.quantize)
|
||||
b = weights.get_tensor(f"{prefix}.bias") if bias else None
|
||||
linears.append(get_linear(weight, b, config.quantize))
|
||||
linear = LayerConcat(linears)
|
||||
else:
|
||||
weight = weights.get_multi_weights_col(
|
||||
prefixes, quantize=config.quantize, dim=dim
|
||||
)
|
||||
|
||||
if bias:
|
||||
b = [weights.get_sharded(f"{p}.bias", dim=0) for p in prefixes]
|
||||
bias = torch.cat(b, dim=dim)
|
||||
|
|
|
@ -263,7 +263,7 @@ def get_model(
|
|||
trust_remote_code: bool,
|
||||
) -> Model:
|
||||
if dtype is None:
|
||||
if quantize in ["awq", "gptq"]:
|
||||
if quantize in ["awq", "exl2", "gptq"]:
|
||||
# These quantizers only work with float16 params.
|
||||
dtype = torch.float16
|
||||
else:
|
||||
|
@ -402,12 +402,17 @@ def get_model(
|
|||
quantization_config = config_dict.get("quantization_config", None)
|
||||
if quantization_config is not None and quantize is None:
|
||||
method = quantization_config.get("quant_method", None)
|
||||
if method in {"gptq", "awq"}:
|
||||
if method in {"gptq", "awq", "exl2"}:
|
||||
logger.info(f"Auto selecting quantization method {method}")
|
||||
quantize = method
|
||||
else:
|
||||
logger.info(f"Unknown quantization method {method}")
|
||||
|
||||
if quantize == "exl2" and sharded:
|
||||
raise RuntimeError(
|
||||
"Sharding is currently not supported with `exl2` quantization"
|
||||
)
|
||||
|
||||
if model_type == MAMBA:
|
||||
return Mamba(
|
||||
model_id,
|
||||
|
@ -881,6 +886,8 @@ def get_model(
|
|||
raise NotImplementedError("4bit quantization is not supported for AutoModel")
|
||||
elif quantize == "eetq":
|
||||
raise NotImplementedError("Eetq quantization is not supported for AutoModel")
|
||||
elif quantize == "exl2":
|
||||
raise NotImplementedError("exl2 quantization is not supported for AutoModel")
|
||||
if model_type in modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
|
||||
return CausalLM(
|
||||
model_id,
|
||||
|
|
|
@ -21,6 +21,7 @@ from transformers.activations import ACT2FN
|
|||
from transformers.configuration_utils import PretrainedConfig
|
||||
from typing import Optional, List, Tuple, Any
|
||||
from loguru import logger
|
||||
from text_generation_server.layers.gptq import GPTQWeight
|
||||
from text_generation_server.utils.import_utils import SYSTEM
|
||||
|
||||
if SYSTEM != "xpu":
|
||||
|
@ -256,7 +257,15 @@ def _load_gqa(config, prefix: str, weights):
|
|||
else:
|
||||
g_idx = None
|
||||
|
||||
weight = (qweight, qzeros, scales, g_idx, bits, groupsize, use_exllama)
|
||||
weight = GPTQWeight(
|
||||
qweight=qweight,
|
||||
qzeros=qzeros,
|
||||
scales=scales,
|
||||
g_idx=g_idx,
|
||||
bits=bits,
|
||||
groupsize=groupsize,
|
||||
use_exllama=use_exllama,
|
||||
)
|
||||
else:
|
||||
qkv_slice = weights._get_slice(f"{prefix}.Wqkv.weight")
|
||||
q = qkv_slice[q_start:q_stop]
|
||||
|
|
|
@ -395,7 +395,7 @@ class FlashLlamaForCausalLM(torch.nn.Module):
|
|||
|
||||
self.lm_head = SpeculativeHead.load(
|
||||
config,
|
||||
prefix=suffix if not prefix else f"{prefix}.suffix",
|
||||
prefix=suffix if not prefix else f"{prefix}.{suffix}",
|
||||
weights=weights,
|
||||
)
|
||||
|
||||
|
|
|
@ -102,45 +102,6 @@ class MistralConfig(PretrainedConfig):
|
|||
)
|
||||
|
||||
|
||||
def load_attention(config, prefix, weights):
|
||||
if config.num_attention_heads != config.num_key_value_heads:
|
||||
return _load_gqa(config, prefix, weights)
|
||||
else:
|
||||
return TensorParallelColumnLinear.load_multi(
|
||||
config,
|
||||
prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
|
||||
dim=0,
|
||||
weights=weights,
|
||||
bias=False,
|
||||
)
|
||||
|
||||
|
||||
def _load_gqa(config, prefix: str, weights):
|
||||
assert config.hidden_size % config.num_attention_heads == 0
|
||||
assert config.num_attention_heads % weights.process_group.size() == 0
|
||||
|
||||
weight = weights.get_multi_weights_col(
|
||||
prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
|
||||
quantize=config.quantize,
|
||||
dim=0,
|
||||
)
|
||||
|
||||
if config.quantize not in ["gptq", "awq"]:
|
||||
weight = weight.to(dtype=weights.dtype).to(device=weights.device)
|
||||
|
||||
head_size = config.hidden_size // config.num_attention_heads
|
||||
num_heads = config.num_attention_heads // weights.process_group.size()
|
||||
num_key_value_heads = config.num_key_value_heads // weights.process_group.size()
|
||||
assert list(weight.shape) == [
|
||||
(num_heads + 2 * num_key_value_heads) * head_size,
|
||||
config.hidden_size,
|
||||
], f"{list(weight.shape)} != {[(num_heads + 2 * config.num_key_value_heads) * head_size, config.hidden_size]}"
|
||||
|
||||
return TensorParallelColumnLinear(
|
||||
get_linear(weight, bias=None, quantize=config.quantize)
|
||||
)
|
||||
|
||||
|
||||
class MistralAttention(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
|
@ -175,7 +136,13 @@ class MistralAttention(torch.nn.Module):
|
|||
config.num_key_value_heads // weights.process_group.size()
|
||||
)
|
||||
|
||||
self.query_key_value = load_attention(config, prefix, weights)
|
||||
self.query_key_value = TensorParallelColumnLinear.load_multi(
|
||||
config,
|
||||
prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
|
||||
dim=0,
|
||||
weights=weights,
|
||||
bias=False,
|
||||
)
|
||||
|
||||
self.o_proj = TensorParallelRowLinear.load(
|
||||
config,
|
||||
|
|
|
@ -5,6 +5,7 @@ from torch import nn
|
|||
from transformers.activations import ACT2FN
|
||||
from typing import Optional, List, Tuple
|
||||
|
||||
from text_generation_server.layers.gptq import GPTQWeight
|
||||
from text_generation_server.utils import paged_attention, flash_attn
|
||||
from text_generation_server.layers import (
|
||||
TensorParallelRowLinear,
|
||||
|
@ -90,8 +91,15 @@ def _load_multi_mqa_gptq(
|
|||
|
||||
from text_generation_server.layers.gptq import HAS_EXLLAMA
|
||||
|
||||
use_exllama = HAS_EXLLAMA
|
||||
weight = (qweight, qzeros, scales, g_idx, bits, groupsize, use_exllama)
|
||||
weight = GPTQWeight(
|
||||
qweight=qweight,
|
||||
qzeros=qzeros,
|
||||
scales=scales,
|
||||
g_idx=g_idx,
|
||||
bits=bits,
|
||||
groupsize=groupsize,
|
||||
use_exllama=HAS_EXLLAMA,
|
||||
)
|
||||
|
||||
if bias:
|
||||
slice_ = weights._get_slice(f"{prefix}.c_attn.bias")
|
||||
|
|
|
@ -67,7 +67,7 @@ class FlashLlama(FlashCausalLM):
|
|||
|
||||
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
|
||||
weights = Weights(filenames, device, dtype, process_group=self.process_group)
|
||||
if config.quantize in ["gptq", "awq"]:
|
||||
if config.quantize in ["gptq", "awq", "exl2"]:
|
||||
weights._set_gptq_params(model_id, revision)
|
||||
|
||||
prefix = ""
|
||||
|
|
|
@ -89,7 +89,7 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
|
|||
return generate_pb2.FilterBatchResponse(batch=filtered_batch.to_pb())
|
||||
|
||||
async def Warmup(self, request, context):
|
||||
if self.quantize == "gptq":
|
||||
if self.quantize in {"exl2", "gptq"}:
|
||||
try:
|
||||
# When using GPTQ, Exllama kernels need some global kernels
|
||||
# For which we have the finale shapes only after the model has loaded
|
||||
|
|
|
@ -1,11 +1,14 @@
|
|||
from dataclasses import dataclass, field
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import List, Dict, Optional, Tuple
|
||||
from typing import List, Dict, Optional, Set, Tuple, Union
|
||||
from safetensors import safe_open, SafetensorError
|
||||
import torch
|
||||
from loguru import logger
|
||||
from huggingface_hub import hf_hub_download
|
||||
import json
|
||||
from text_generation_server.layers.exl2 import Exl2Weight
|
||||
from text_generation_server.layers.gptq import GPTQWeight
|
||||
from text_generation_server.utils.log import log_once
|
||||
|
||||
|
||||
|
@ -76,8 +79,9 @@ class Weights:
|
|||
f = self._get_handle(filename)
|
||||
tensor = f.get_tensor(tensor_name)
|
||||
# Special case for gptq which shouldn't convert
|
||||
# u4 which are disguised as int32
|
||||
if tensor.dtype not in [torch.int32, torch.int64]:
|
||||
# u4 which are disguised as int32. Exl2 uses int16
|
||||
# as well.
|
||||
if tensor.dtype not in [torch.int16, torch.int32, torch.int64]:
|
||||
tensor = tensor.to(dtype=self.dtype)
|
||||
if to_device:
|
||||
tensor = tensor.to(device=self.device)
|
||||
|
@ -102,8 +106,8 @@ class Weights:
|
|||
else:
|
||||
raise NotImplementedError("Let's make that generic when needed")
|
||||
# Special case for gptq which shouldn't convert
|
||||
# u4 which are disguised as int32
|
||||
if tensor.dtype != torch.int32:
|
||||
# u4 which are disguised as int32. exl2 uses int16.
|
||||
if tensor.dtype not in (torch.int16, torch.int32):
|
||||
tensor = tensor.to(dtype=self.dtype)
|
||||
tensor = tensor.to(device=self.device)
|
||||
return tensor
|
||||
|
@ -183,7 +187,15 @@ class Weights:
|
|||
else:
|
||||
g_idx = None
|
||||
|
||||
weight = (qweight, qzeros, scales, g_idx, bits, groupsize, False)
|
||||
weight = GPTQWeight(
|
||||
qweight=qweight,
|
||||
qzeros=qzeros,
|
||||
scales=scales,
|
||||
g_idx=g_idx,
|
||||
bits=bits,
|
||||
groupsize=groupsize,
|
||||
use_exllama=False,
|
||||
)
|
||||
else:
|
||||
slice_ = self._get_slice(f"{prefix}.weight")
|
||||
total_size = slice_.get_shape()[0]
|
||||
|
@ -207,8 +219,34 @@ class Weights:
|
|||
weight = weight.to(dtype=self.dtype)
|
||||
return weight
|
||||
|
||||
def get_weights_col(self, prefix: str, quantize: str):
|
||||
if quantize == "exl2":
|
||||
try:
|
||||
q_weight = self.get_tensor(f"{prefix}.q_weight")
|
||||
except RuntimeError:
|
||||
raise RuntimeError(
|
||||
f"Cannot load `exl2`-quantized weight, make sure the model is already quantized."
|
||||
)
|
||||
|
||||
q_scale = self.get_tensor(f"{prefix}.q_scale")
|
||||
q_invperm = self.get_tensor(f"{prefix}.q_invperm")
|
||||
q_scale_max = self.get_tensor(f"{prefix}.q_scale_max")
|
||||
q_groups = self.get_tensor(f"{prefix}.q_groups")
|
||||
|
||||
return Exl2Weight(
|
||||
q_weight=q_weight,
|
||||
q_scale=q_scale,
|
||||
q_invperm=q_invperm,
|
||||
q_scale_max=q_scale_max,
|
||||
q_groups=q_groups,
|
||||
)
|
||||
|
||||
return self.get_multi_weights_col([prefix], quantize, 0)
|
||||
|
||||
def get_multi_weights_col(self, prefixes: List[str], quantize: str, dim: int):
|
||||
if quantize in ["gptq", "awq"]:
|
||||
if quantize == "exl2":
|
||||
raise ValueError("get_multi_weights_col is not supported for exl2")
|
||||
elif quantize in ["gptq", "awq"]:
|
||||
try:
|
||||
qweight = torch.cat(
|
||||
[self.get_sharded(f"{p}.qweight", dim=1) for p in prefixes], dim=1
|
||||
|
@ -259,7 +297,15 @@ class Weights:
|
|||
else:
|
||||
g_idx = None
|
||||
|
||||
weight = (qweight, qzeros, scales, g_idx, bits, groupsize, use_exllama)
|
||||
weight = GPTQWeight(
|
||||
qweight=qweight,
|
||||
qzeros=qzeros,
|
||||
scales=scales,
|
||||
g_idx=g_idx,
|
||||
bits=bits,
|
||||
groupsize=groupsize,
|
||||
use_exllama=use_exllama,
|
||||
)
|
||||
else:
|
||||
w = [self.get_sharded(f"{p}.weight", dim=0) for p in prefixes]
|
||||
weight = torch.cat(w, dim=dim)
|
||||
|
@ -282,7 +328,28 @@ class Weights:
|
|||
return tensor
|
||||
|
||||
def get_multi_weights_row(self, prefix: str, quantize: str):
|
||||
if quantize == "gptq":
|
||||
if quantize == "exl2":
|
||||
try:
|
||||
q_weight = self.get_tensor(f"{prefix}.q_weight")
|
||||
except RuntimeError:
|
||||
raise RuntimeError(
|
||||
f"Cannot load `exl2`-quantized weight, make sure the model is already quantized."
|
||||
)
|
||||
|
||||
q_scale = self.get_tensor(f"{prefix}.q_scale")
|
||||
q_invperm = self.get_tensor(f"{prefix}.q_invperm")
|
||||
q_scale_max = self.get_tensor(f"{prefix}.q_scale_max")
|
||||
q_groups = self.get_tensor(f"{prefix}.q_groups")
|
||||
|
||||
return Exl2Weight(
|
||||
q_weight=q_weight,
|
||||
q_scale=q_scale,
|
||||
q_invperm=q_invperm,
|
||||
q_scale_max=q_scale_max,
|
||||
q_groups=q_groups,
|
||||
)
|
||||
|
||||
elif quantize == "gptq":
|
||||
use_exllama = True
|
||||
bits, groupsize, desc_act, quant_method = self._get_gptq_params()
|
||||
|
||||
|
@ -363,7 +430,15 @@ class Weights:
|
|||
// groupsize
|
||||
).to(dtype=torch.int32)
|
||||
|
||||
weight = (qweight, qzeros, scales, g_idx, bits, groupsize, use_exllama)
|
||||
weight = GPTQWeight(
|
||||
qweight=qweight,
|
||||
qzeros=qzeros,
|
||||
scales=scales,
|
||||
g_idx=g_idx,
|
||||
bits=bits,
|
||||
groupsize=groupsize,
|
||||
use_exllama=use_exllama,
|
||||
)
|
||||
elif quantize == "awq":
|
||||
bits, groupsize, _, _ = self._get_gptq_params()
|
||||
|
||||
|
@ -379,7 +454,15 @@ class Weights:
|
|||
g_idx = None
|
||||
use_exllama = False
|
||||
|
||||
weight = (qweight, qzeros, scales, g_idx, bits, groupsize, use_exllama)
|
||||
weight = GPTQWeight(
|
||||
qweight=qweight,
|
||||
qzeros=qzeros,
|
||||
scales=scales,
|
||||
g_idx=g_idx,
|
||||
bits=bits,
|
||||
groupsize=groupsize,
|
||||
use_exllama=use_exllama,
|
||||
)
|
||||
else:
|
||||
weight = self.get_sharded(f"{prefix}.weight", dim=1)
|
||||
return weight
|
||||
|
|
Loading…
Reference in New Issue