From 36dd16017c7211b7760d1daa188172bb902e486f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Danie=CC=88l=20de=20Kok?= Date: Tue, 28 May 2024 09:51:31 +0000 Subject: [PATCH] Add support for exl2 quantization Mostly straightforward, changes to existing code: * Wrap quantizer parameters in a small wrapper to avoid passing around untyped tuples and needing to repack them as a dict. * Move scratch space computation to warmup, because we need the maximum input sequence length to avoid allocating huge scratch buffers that OOM. --- docs/source/basic_tutorials/launcher.md | 1 + docs/source/conceptual/guidance.md | 1 - integration-tests/conftest.py | 22 +- .../test_flash_llama_exl2.json | 84 +++++ .../test_flash_llama_exl2_all_params.json | 84 +++++ .../test_flash_llama_exl2_load.json | 338 ++++++++++++++++++ .../models/test_flash_llama_exl2.py | 73 ++++ launcher/src/main.rs | 12 + server/text_generation_server/cli.py | 1 + server/text_generation_server/layers/exl2.py | 23 ++ .../layers/gptq/__init__.py | 22 ++ .../layers/gptq/exllama.py | 26 +- .../layers/gptq/exllamav2.py | 151 ++++---- .../text_generation_server/layers/linear.py | 49 +-- .../layers/tensor_parallel.py | 70 +++- .../text_generation_server/models/__init__.py | 11 +- .../custom_modeling/flash_dbrx_modeling.py | 11 +- .../custom_modeling/flash_llama_modeling.py | 2 +- .../custom_modeling/flash_mistral_modeling.py | 47 +-- .../flash_santacoder_modeling.py | 12 +- .../models/flash_llama.py | 2 +- server/text_generation_server/server.py | 2 +- .../text_generation_server/utils/weights.py | 105 +++++- 23 files changed, 972 insertions(+), 177 deletions(-) create mode 100644 integration-tests/models/__snapshots__/test_flash_llama_exl2/test_flash_llama_exl2.json create mode 100644 integration-tests/models/__snapshots__/test_flash_llama_exl2/test_flash_llama_exl2_all_params.json create mode 100644 integration-tests/models/__snapshots__/test_flash_llama_exl2/test_flash_llama_exl2_load.json create mode 100644 integration-tests/models/test_flash_llama_exl2.py create mode 100644 server/text_generation_server/layers/exl2.py diff --git a/docs/source/basic_tutorials/launcher.md b/docs/source/basic_tutorials/launcher.md index 1e5b6fd2..c00d2e1a 100644 --- a/docs/source/basic_tutorials/launcher.md +++ b/docs/source/basic_tutorials/launcher.md @@ -62,6 +62,7 @@ Options: Possible values: - awq: 4 bit quantization. Requires a specific AWQ quantized model: . Should replace GPTQ models wherever possible because of the better latency - eetq: 8 bit quantization, doesn't require specific model. Should be a drop-in replacement to bitsandbytes with much better performance. Kernels are from + - exl2: Variable bit quantization. Requires a specific EXL2 quantized model: . Requires exllama2 kernels and does not support tensor parallelism (num_shard > 1) - gptq: 4 bit quantization. Requires a specific GTPQ quantized model: . text-generation-inference will use exllama (faster) kernels wherever possible, and use triton kernel (wider support) when it's not. AWQ has faster kernels - bitsandbytes: Bitsandbytes 8bit. Can be applied on any model, will cut the memory requirement in half, but it is known that the model will be much slower to run than the native f16 - bitsandbytes-nf4: Bitsandbytes 4bit. Can be applied on any model, will cut the memory requirement by 4x, but it is known that the model will be much slower to run than the native f16 diff --git a/docs/source/conceptual/guidance.md b/docs/source/conceptual/guidance.md index ad1fc2ec..3059e3de 100644 --- a/docs/source/conceptual/guidance.md +++ b/docs/source/conceptual/guidance.md @@ -2,7 +2,6 @@ ## What is Guidance? - Guidance is a feature that allows users to constrain the generation of a large language model with a specified grammar. This feature is particularly useful when you want to generate text that follows a specific structure or uses a specific set of words or produce output in a specific format. A prominent example is JSON grammar, where the model is forced to output valid JSON. ## How is it used? diff --git a/integration-tests/conftest.py b/integration-tests/conftest.py index d81b8736..2ef85da6 100644 --- a/integration-tests/conftest.py +++ b/integration-tests/conftest.py @@ -38,6 +38,7 @@ DOCKER_VOLUME = os.getenv("DOCKER_VOLUME", "/data") class ResponseComparator(JSONSnapshotExtension): rtol = 0.2 + ignore_logprob = False def serialize( self, @@ -95,7 +96,10 @@ class ResponseComparator(JSONSnapshotExtension): return ( token.id == other.id and token.text == other.text - and math.isclose(token.logprob, other.logprob, rel_tol=self.rtol) + and ( + self.ignore_logprob + or math.isclose(token.logprob, other.logprob, rel_tol=self.rtol) + ) and token.special == other.special ) @@ -105,8 +109,11 @@ class ResponseComparator(JSONSnapshotExtension): prefill_token.id == other.id and prefill_token.text == other.text and ( - math.isclose( - prefill_token.logprob, other.logprob, rel_tol=self.rtol + self.ignore_logprob + or math.isclose( + prefill_token.logprob, + other.logprob, + rel_tol=self.rtol, ) if prefill_token.logprob is not None else prefill_token.logprob == other.logprob @@ -223,6 +230,10 @@ class GenerousResponseComparator(ResponseComparator): rtol = 0.75 +class IgnoreLogProbResponseComparator(ResponseComparator): + ignore_logprob = True + + class LauncherHandle: def __init__(self, port: int): self.client = AsyncClient(f"http://localhost:{port}") @@ -274,6 +285,11 @@ def generous_response_snapshot(snapshot): return snapshot.use_extension(GenerousResponseComparator) +@pytest.fixture +def ignore_logprob_response_snapshot(snapshot): + return snapshot.use_extension(IgnoreLogProbResponseComparator) + + @pytest.fixture(scope="module") def event_loop(): loop = asyncio.get_event_loop() diff --git a/integration-tests/models/__snapshots__/test_flash_llama_exl2/test_flash_llama_exl2.json b/integration-tests/models/__snapshots__/test_flash_llama_exl2/test_flash_llama_exl2.json new file mode 100644 index 00000000..f6e4bb90 --- /dev/null +++ b/integration-tests/models/__snapshots__/test_flash_llama_exl2/test_flash_llama_exl2.json @@ -0,0 +1,84 @@ +{ + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 2323, + "logprob": null, + "text": "Test" + }, + { + "id": 1715, + "logprob": -11.4375, + "text": " request" + } + ], + "seed": null, + "tokens": [ + { + "id": 25, + "logprob": -2.9316406, + "special": false, + "text": ":" + }, + { + "id": 330, + "logprob": -3.5136719, + "special": false, + "text": " \"" + }, + { + "id": 489, + "logprob": -0.7783203, + "special": false, + "text": " +" + }, + { + "id": 1715, + "logprob": -1.2314453, + "special": false, + "text": " request" + }, + { + "id": 489, + "logprob": -2.0019531, + "special": false, + "text": " +" + }, + { + "id": 2990, + "logprob": -1.5009766, + "special": false, + "text": " \"\\" + }, + { + "id": 77, + "logprob": -0.057434082, + "special": false, + "text": "n" + }, + { + "id": 702, + "logprob": -1.4912109, + "special": false, + "text": "\"\n" + }, + { + "id": 262, + "logprob": -1.2636719, + "special": false, + "text": " " + }, + { + "id": 557, + "logprob": -2.4042969, + "special": false, + "text": " }\n\n" + } + ], + "top_tokens": null + }, + "generated_text": ": \" + request + \"\\n\"\n }\n\n" +} diff --git a/integration-tests/models/__snapshots__/test_flash_llama_exl2/test_flash_llama_exl2_all_params.json b/integration-tests/models/__snapshots__/test_flash_llama_exl2/test_flash_llama_exl2_all_params.json new file mode 100644 index 00000000..6b38e709 --- /dev/null +++ b/integration-tests/models/__snapshots__/test_flash_llama_exl2/test_flash_llama_exl2_all_params.json @@ -0,0 +1,84 @@ +{ + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 2323, + "logprob": null, + "text": "Test" + }, + { + "id": 1715, + "logprob": -11.453125, + "text": " request" + } + ], + "seed": 0, + "tokens": [ + { + "id": 13, + "logprob": -1.9980469, + "special": false, + "text": "." + }, + { + "id": 578, + "logprob": -0.15795898, + "special": false, + "text": " The" + }, + { + "id": 3622, + "logprob": -1.0458984, + "special": false, + "text": " server" + }, + { + "id": 31680, + "logprob": -1.3623047, + "special": false, + "text": " responds" + }, + { + "id": 449, + "logprob": 0.0, + "special": false, + "text": " with" + }, + { + "id": 264, + "logprob": 0.0, + "special": false, + "text": " a" + }, + { + "id": 330, + "logprob": -0.5678711, + "special": false, + "text": " \"" + }, + { + "id": 1049, + "logprob": -0.12322998, + "special": false, + "text": "200" + }, + { + "id": 10619, + "logprob": 0.0, + "special": false, + "text": " OK" + }, + { + "id": 1, + "logprob": 0.0, + "special": false, + "text": "\"" + } + ], + "top_tokens": null + }, + "generated_text": "Test request. The server responds with a \"200 OK\"" +} diff --git a/integration-tests/models/__snapshots__/test_flash_llama_exl2/test_flash_llama_exl2_load.json b/integration-tests/models/__snapshots__/test_flash_llama_exl2/test_flash_llama_exl2_load.json new file mode 100644 index 00000000..ed369a87 --- /dev/null +++ b/integration-tests/models/__snapshots__/test_flash_llama_exl2/test_flash_llama_exl2_load.json @@ -0,0 +1,338 @@ +[ + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 2323, + "logprob": null, + "text": "Test" + }, + { + "id": 1715, + "logprob": -11.453125, + "text": " request" + } + ], + "seed": null, + "tokens": [ + { + "id": 25, + "logprob": -2.9785156, + "special": false, + "text": ":" + }, + { + "id": 330, + "logprob": -3.4941406, + "special": false, + "text": " \"" + }, + { + "id": 489, + "logprob": -0.79345703, + "special": false, + "text": " +" + }, + { + "id": 1715, + "logprob": -1.2324219, + "special": false, + "text": " request" + }, + { + "id": 489, + "logprob": -1.9794922, + "special": false, + "text": " +" + }, + { + "id": 2990, + "logprob": -1.4892578, + "special": false, + "text": " \"\\" + }, + { + "id": 77, + "logprob": -0.058258057, + "special": false, + "text": "n" + }, + { + "id": 702, + "logprob": -1.4892578, + "special": false, + "text": "\"\n" + }, + { + "id": 262, + "logprob": -1.2783203, + "special": false, + "text": " " + }, + { + "id": 557, + "logprob": -2.3945312, + "special": false, + "text": " }\n\n" + } + ], + "top_tokens": null + }, + "generated_text": ": \" + request + \"\\n\"\n }\n\n" + }, + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 2323, + "logprob": null, + "text": "Test" + }, + { + "id": 1715, + "logprob": -11.40625, + "text": " request" + } + ], + "seed": null, + "tokens": [ + { + "id": 25, + "logprob": -2.9433594, + "special": false, + "text": ":" + }, + { + "id": 330, + "logprob": -3.4726562, + "special": false, + "text": " \"" + }, + { + "id": 489, + "logprob": -0.8022461, + "special": false, + "text": " +" + }, + { + "id": 1715, + "logprob": -1.2509766, + "special": false, + "text": " request" + }, + { + "id": 489, + "logprob": -1.984375, + "special": false, + "text": " +" + }, + { + "id": 2990, + "logprob": -1.4677734, + "special": false, + "text": " \"\\" + }, + { + "id": 77, + "logprob": -0.059173584, + "special": false, + "text": "n" + }, + { + "id": 702, + "logprob": -1.4990234, + "special": false, + "text": "\"\n" + }, + { + "id": 262, + "logprob": -1.2822266, + "special": false, + "text": " " + }, + { + "id": 557, + "logprob": -2.3867188, + "special": false, + "text": " }\n\n" + } + ], + "top_tokens": null + }, + "generated_text": ": \" + request + \"\\n\"\n }\n\n" + }, + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 2323, + "logprob": null, + "text": "Test" + }, + { + "id": 1715, + "logprob": -11.421875, + "text": " request" + } + ], + "seed": null, + "tokens": [ + { + "id": 25, + "logprob": -2.9511719, + "special": false, + "text": ":" + }, + { + "id": 330, + "logprob": -3.46875, + "special": false, + "text": " \"" + }, + { + "id": 489, + "logprob": -0.77490234, + "special": false, + "text": " +" + }, + { + "id": 1715, + "logprob": -1.2558594, + "special": false, + "text": " request" + }, + { + "id": 489, + "logprob": -1.984375, + "special": false, + "text": " +" + }, + { + "id": 2990, + "logprob": -1.4990234, + "special": false, + "text": " \"\\" + }, + { + "id": 77, + "logprob": -0.059143066, + "special": false, + "text": "n" + }, + { + "id": 702, + "logprob": -1.4941406, + "special": false, + "text": "\"\n" + }, + { + "id": 262, + "logprob": -1.2578125, + "special": false, + "text": " " + }, + { + "id": 557, + "logprob": -2.3964844, + "special": false, + "text": " }\n\n" + } + ], + "top_tokens": null + }, + "generated_text": ": \" + request + \"\\n\"\n }\n\n" + }, + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 2323, + "logprob": null, + "text": "Test" + }, + { + "id": 1715, + "logprob": -11.4140625, + "text": " request" + } + ], + "seed": null, + "tokens": [ + { + "id": 25, + "logprob": -2.9101562, + "special": false, + "text": ":" + }, + { + "id": 330, + "logprob": -3.5039062, + "special": false, + "text": " \"" + }, + { + "id": 489, + "logprob": -0.8076172, + "special": false, + "text": " +" + }, + { + "id": 1715, + "logprob": -1.2236328, + "special": false, + "text": " request" + }, + { + "id": 489, + "logprob": -1.9853516, + "special": false, + "text": " +" + }, + { + "id": 2990, + "logprob": -1.4892578, + "special": false, + "text": " \"\\" + }, + { + "id": 77, + "logprob": -0.056671143, + "special": false, + "text": "n" + }, + { + "id": 702, + "logprob": -1.5107422, + "special": false, + "text": "\"\n" + }, + { + "id": 262, + "logprob": -1.2597656, + "special": false, + "text": " " + }, + { + "id": 557, + "logprob": -2.4042969, + "special": false, + "text": " }\n\n" + } + ], + "top_tokens": null + }, + "generated_text": ": \" + request + \"\\n\"\n }\n\n" + } +] diff --git a/integration-tests/models/test_flash_llama_exl2.py b/integration-tests/models/test_flash_llama_exl2.py new file mode 100644 index 00000000..18319f60 --- /dev/null +++ b/integration-tests/models/test_flash_llama_exl2.py @@ -0,0 +1,73 @@ +import pytest + + +@pytest.fixture(scope="module") +def flash_llama_exl2_handle(launcher): + with launcher( + "turboderp/Llama-3-8B-Instruct-exl2", + revision="2.5bpw", + # Set max input length to avoid OOM due to extremely large + # scratch buffer. + max_input_length=1024, + num_shard=1, + quantize="exl2", + ) as handle: + yield handle + + +@pytest.fixture(scope="module") +async def flash_llama_exl2(flash_llama_exl2_handle): + await flash_llama_exl2_handle.health(300) + return flash_llama_exl2_handle.client + + +@pytest.mark.asyncio +@pytest.mark.private +async def test_flash_llama_exl2(flash_llama_exl2, ignore_logprob_response_snapshot): + response = await flash_llama_exl2.generate( + "Test request", max_new_tokens=10, decoder_input_details=True + ) + + assert response.details.generated_tokens == 10 + assert response == ignore_logprob_response_snapshot + + +@pytest.mark.asyncio +@pytest.mark.private +async def test_flash_llama_exl2_all_params( + flash_llama_exl2, ignore_logprob_response_snapshot +): + response = await flash_llama_exl2.generate( + "Test request", + max_new_tokens=10, + repetition_penalty=1.2, + return_full_text=True, + temperature=0.5, + top_p=0.9, + top_k=10, + truncate=5, + typical_p=0.9, + watermark=True, + decoder_input_details=True, + seed=0, + ) + + assert ( + response.generated_text == 'Test request. The server responds with a "200 OK"' + ) + assert response == ignore_logprob_response_snapshot + + +@pytest.mark.asyncio +@pytest.mark.private +async def test_flash_llama_exl2_load( + flash_llama_exl2, generate_load, ignore_logprob_response_snapshot +): + responses = await generate_load( + flash_llama_exl2, "Test request", max_new_tokens=10, n=4 + ) + + assert len(responses) == 4 + assert all([r.generated_text == responses[0].generated_text for r in responses]) + + assert responses == ignore_logprob_response_snapshot diff --git a/launcher/src/main.rs b/launcher/src/main.rs index a97a75c0..125d9239 100644 --- a/launcher/src/main.rs +++ b/launcher/src/main.rs @@ -55,6 +55,10 @@ enum Quantization { /// Should be a drop-in replacement to bitsandbytes with much better performance. /// Kernels are from Eetq, + /// Variable bit quantization. Requires a specific EXL2 quantized model: + /// . Requires exllama2 kernels and does + /// not support tensor parallelism (num_shard > 1). + Exl2, /// 4 bit quantization. Requires a specific GTPQ quantized model: . /// text-generation-inference will use exllama (faster) kernels wherever possible, and use /// triton kernel (wider support) when it's not. @@ -95,6 +99,9 @@ impl std::fmt::Display for Quantization { Quantization::BitsandbytesFP4 => { write!(f, "bitsandbytes-fp4") } + Quantization::Exl2 => { + write!(f, "exl2") + } Quantization::Gptq => { write!(f, "gptq") } @@ -1461,6 +1468,11 @@ fn main() -> Result<(), LauncherError> { let num_shard = find_num_shards(args.sharded, args.num_shard)?; if num_shard > 1 { + if matches!(args.quantize, Some(Quantization::Exl2)) { + return Err(LauncherError::ArgumentValidation( + "Sharding is currently not supported with `exl2` quantization".into(), + )); + } tracing::info!("Sharding model on {num_shard} processes"); } diff --git a/server/text_generation_server/cli.py b/server/text_generation_server/cli.py index ad623ccc..16375ecd 100644 --- a/server/text_generation_server/cli.py +++ b/server/text_generation_server/cli.py @@ -19,6 +19,7 @@ class Quantization(str, Enum): gptq = "gptq" awq = "awq" eetq = "eetq" + exl2 = "exl2" fp8 = "fp8" diff --git a/server/text_generation_server/layers/exl2.py b/server/text_generation_server/layers/exl2.py new file mode 100644 index 00000000..f6cb729e --- /dev/null +++ b/server/text_generation_server/layers/exl2.py @@ -0,0 +1,23 @@ +import torch +from dataclasses import dataclass + + +@dataclass +class Exl2Weight: + """ + Exllama2 exl2 quantized weights. + """ + + q_weight: torch.Tensor + q_scale: torch.Tensor + q_invperm: torch.Tensor + q_scale_max: torch.Tensor + q_groups: torch.Tensor + + def __post_init__(self): + self.q_scale_max /= 256 + self.q_invperm = self.q_invperm.short() + + @property + def device(self) -> torch.device: + return self.q_weight.device diff --git a/server/text_generation_server/layers/gptq/__init__.py b/server/text_generation_server/layers/gptq/__init__.py index 1c46f493..1172775f 100644 --- a/server/text_generation_server/layers/gptq/__init__.py +++ b/server/text_generation_server/layers/gptq/__init__.py @@ -1,9 +1,31 @@ +from dataclasses import dataclass import os +from typing import Optional import torch from text_generation_server.utils.import_utils import ( SYSTEM, ) + +@dataclass +class GPTQWeight: + qweight: torch.Tensor + qzeros: torch.Tensor + scales: torch.Tensor + g_idx: Optional[torch.Tensor] + bits: int + groupsize: int + use_exllama: bool + + def __post_init__(self): + if self.scales.dtype == torch.float: + self.scales = self.scales.half() + + @property + def device(self) -> torch.device: + return self.qweight.device + + try: major, _minor = torch.cuda.get_device_capability() except Exception: diff --git a/server/text_generation_server/layers/gptq/exllama.py b/server/text_generation_server/layers/gptq/exllama.py index 32f817db..4875af38 100644 --- a/server/text_generation_server/layers/gptq/exllama.py +++ b/server/text_generation_server/layers/gptq/exllama.py @@ -1,3 +1,4 @@ +from text_generation_server.utils.weights import GPTQWeight import torch from exllama_kernels import make_q4, q4_matmul, prepare_buffers, set_tuning_params @@ -65,24 +66,25 @@ def create_exllama_buffers(max_total_tokens: int): class Ex4bitLinear(torch.nn.Module): """Linear layer implementation with per-group 4-bit quantization of the weights""" - def __init__(self, qweight, qzeros, scales, g_idx, bias, bits, groupsize): + def __init__(self, weight: GPTQWeight, bias): super().__init__() global MAX_DQ, MAX_INNER, ACT_ORDER, DEVICE - assert bits == 4 + assert weight.bits == 4 - self.device = qweight.device - self.qweight = qweight - self.qzeros = qzeros - self.scales = scales - self.g_idx = g_idx.cpu() if g_idx is not None else None + self.device = weight.qweight.device + self.qweight = weight.qweight + self.qzeros = weight.qzeros + self.scales = weight.scales + self.g_idx = weight.g_idx.cpu() if weight.g_idx is not None else None self.bias = bias if bias is not None else None if self.g_idx is not None and ( (self.g_idx == 0).all() or torch.equal( - g_idx.cpu(), + weight.g_idx.cpu(), torch.tensor( - [i // groupsize for i in range(g_idx.shape[0])], dtype=torch.int32 + [i // weight.groupsize for i in range(weight.g_idx.shape[0])], + dtype=torch.int32, ), ) ): @@ -96,8 +98,8 @@ class Ex4bitLinear(torch.nn.Module): self.qweight, self.qzeros, self.scales, self.g_idx, self.device.index ) - self.height = qweight.shape[0] * 8 - self.width = qweight.shape[1] + self.height = weight.qweight.shape[0] * 8 + self.width = weight.qweight.shape[1] # Infer groupsize from height of qzeros self.groupsize = None @@ -105,7 +107,7 @@ class Ex4bitLinear(torch.nn.Module): self.groupsize = (self.qweight.shape[0] * 8) // (self.qzeros.shape[0]) if self.groupsize is not None: - assert groupsize == self.groupsize + assert weight.groupsize == self.groupsize # Handle act-order matrix if self.g_idx is not None: diff --git a/server/text_generation_server/layers/gptq/exllamav2.py b/server/text_generation_server/layers/gptq/exllamav2.py index 321ced97..2ae9628a 100644 --- a/server/text_generation_server/layers/gptq/exllamav2.py +++ b/server/text_generation_server/layers/gptq/exllamav2.py @@ -1,10 +1,15 @@ # Adapted from turboderp exllama: https://github.com/turboderp/exllamav2 +from dataclasses import dataclass +from typing import Optional import torch import torch.nn as nn from loguru import logger +from text_generation_server.layers.exl2 import Exl2Weight +from text_generation_server.layers.gptq import GPTQWeight + try: from exllamav2_kernels import make_q_matrix, gemm_half_q_half except ImportError: @@ -15,6 +20,15 @@ except ImportError: none_tensor = torch.empty((1, 1), device="meta") +@dataclass +class _ExtraTensors: + """Additional generated quantizer tensors.""" + + q_group_map: Optional[torch.Tensor] = None + q_invperm: Optional[torch.Tensor] = None + q_perm: Optional[torch.Tensor] = None + + def ext_gemm_half_q_half(x, q_handle, q4_width, force_cuda): """Matrix multiplication, returns x @ q4""" output_shape = x.shape[:-1] + (q4_width,) @@ -24,11 +38,7 @@ def ext_gemm_half_q_half(x, q_handle, q4_width, force_cuda): return output.view(output_shape) -# Group map needed for irregular group sizes - - -def make_group_map(q_groups, num_qrows): - +def make_group_map(q_groups: torch.Tensor, num_qrows: int): gr = q_groups.tolist() group_map = [] num_groups = len(gr) // 2 @@ -50,72 +60,72 @@ def make_group_map(q_groups, num_qrows): # Create Q matrix -def ext_make_q_matrix(w: dict, temp_dq, key: str = None): +def ext_make_q_matrix( + w: Exl2Weight | GPTQWeight, + extra: _ExtraTensors, + temp_dq, + key: Optional[str] = None, +): """ Create Q matrix """ # EXL2 - # won't work as the moment because the tensors are not the same. - if "q_weight" in w: - w["q_scale_max"] /= 256 - w["q_perm"] = w["q_perm"].short() - w["q_invperm"] = w["q_invperm"].short() - - if "q_group_map" not in w: - w["q_group_map"] = make_group_map(w["q_groups"], w["q_weight"].shape[0]) + if isinstance(w, Exl2Weight): + extra.q_group_map = make_group_map(w.q_groups, w.q_weight.shape[0]) + extra.q_perm = torch.argsort(w.q_invperm).short() return make_q_matrix( - w["q_weight"], - w["q_perm"], - w["q_invperm"], - w["q_scale"], - w["q_scale_max"], - w["q_groups"], - w["q_group_map"], + w.q_weight, + extra.q_perm, + w.q_invperm, + w.q_scale, + w.q_scale_max, + w.q_groups, + extra.q_group_map, none_tensor, none_tensor, none_tensor, temp_dq, ) # GPTQ - elif "qweight" in w: - if w["scales"].dtype == torch.float: - w["scales"] = w["scales"].half() + elif isinstance(w, GPTQWeight): + if w.scales.dtype == torch.float: + w.scales = w.scales.half() # GPTQ with g_idx (act_order) - if w.get("g_idx", None) is not None and not (w["g_idx"] == 0).all().item(): - w["q_perm"] = torch.empty( - (w["qweight"].shape[0] * 8,), + if w.g_idx is not None and not (w.g_idx == 0).all().item(): + extra.q_perm = torch.empty( + (w.qweight.shape[0] * 8,), dtype=torch.short, - device=w["qweight"].device, + device=w.qweight.device, ) - w["q_invperm"] = torch.empty_like(w["q_perm"]) + extra.q_invperm = torch.empty_like(extra.q_perm) # make_q4 segfaults if g_idx is not on cpu in the act-order case. In the non act-order case, None needs to be passed for g_idx. return make_q_matrix( - w["qweight"], - w["q_perm"], - w["q_invperm"], + w.qweight, + extra.q_perm, + extra.q_invperm, none_tensor, none_tensor, none_tensor, none_tensor, - w["qzeros"], - w["scales"], - w["g_idx"].cpu(), + w.qzeros, + w.scales, + w.g_idx.cpu(), temp_dq, ) # GPTQ without g_idx else: return make_q_matrix( - w["qweight"], + w.qweight, none_tensor, none_tensor, none_tensor, none_tensor, none_tensor, none_tensor, - w["qzeros"], - w["scales"], + w.qzeros, + w.scales, none_tensor, temp_dq, ) @@ -124,7 +134,6 @@ def ext_make_q_matrix(w: dict, temp_dq, key: str = None): DEVICE = None -FIXED_BYTES = 0 LAYERS = [] @@ -134,8 +143,13 @@ def set_device(device): def create_exllama_buffers(max_total_tokens: int): - global FIXED_BYTES, LAYERS, DEVICE - temp_dq = ExLlamaV2DeviceTensors(DEVICE, FIXED_BYTES) + global LAYERS, DEVICE + + # Find the size of the scratch space. + scratch_bytes = max( + layer.scratch_space_fixed(max_input_len=max_total_tokens) for layer in LAYERS + ) + temp_dq = ExLlamaV2DeviceTensors(DEVICE, scratch_bytes) for layer in LAYERS: layer.post_init(temp_dq) @@ -146,49 +160,48 @@ class QuantLinear(nn.Module): """Linear layer implementation with per-group 4-bit quantization of the weights""" - # def __init__(self, bits, group_size, infeatures, outfeatures, bias, trainable=False, **kwargs): - def __init__(self, qweight, qzeros, scales, g_idx, bias, bits, groupsize): + def __init__( + self, + weight: Exl2Weight | GPTQWeight, + bias: torch.Tensor, + ): super().__init__() - if bits != 4: - raise ValueError( - f"Exllamav2 kernel supports only bits=4, requested bits={bits}. Something is wrong in the model initialization." - ) + self.q_handle = None - self.q_tensors = None - self.bits = bits - self.maxq = 2**self.bits - 1 - self.infeatures = qweight.shape[0] // self.bits * 32 - self.outfeatures = qweight.shape[1] + self.q_tensors = weight + self.extra_tensors = _ExtraTensors() + + if isinstance(weight, Exl2Weight): + self.infeatures = weight.q_invperm.shape[0] + self.outfeatures = weight.q_weight.shape[1] + elif isinstance(weight, GPTQWeight): + if weight.bits != 4: + raise ValueError( + f"Exllamav2 kernel supports only bits=4, requested bits={weight.bits}. Something is wrong in the model initialization." + ) + + self.infeatures = weight.qweight.shape[0] // weight.bits * 32 + self.outfeatures = weight.qweight.shape[1] + self.padding = -self.outfeatures % 32 self.outfeatures = self.outfeatures + self.padding - self.device = qweight.device - self.qweight = qweight - self.qzeros = qzeros - self.scales = scales - self.g_idx = g_idx + self.device = weight.device self.bias = bias if bias is not None else None - self.group_size = groupsize - global FIXED_BYTES, LAYERS - FIXED_BYTES = max(FIXED_BYTES, self.scratch_space_fixed()) + global LAYERS LAYERS.append(self) def post_init(self, temp_dq): - assert self.qweight.device.type == "cuda" - assert self.qweight.device.index is not None - self.q_tensors = { - "qweight": self.qweight, - "qzeros": self.qzeros, - "scales": self.scales, - "g_idx": self.g_idx, - } + device = self.q_tensors.device + assert device.type == "cuda" + assert device.index is not None temp_dq = temp_dq.get_scratch_slice(self.temp_dq_size()) # We NEED to keep a pointer on Python side, otherwise the garbage collector will mess with us, # and `Memory access fault by GPU node-2` will EAT you. self.temp_dq = temp_dq - self.q_handle = ext_make_q_matrix(self.q_tensors, temp_dq) + self.q_handle = ext_make_q_matrix(self.q_tensors, self.extra_tensors, temp_dq) def forward(self, x, force_cuda=False): output = ext_gemm_half_q_half(x, self.q_handle, self.outfeatures, force_cuda) diff --git a/server/text_generation_server/layers/linear.py b/server/text_generation_server/layers/linear.py index 5bd6aa95..570aa75c 100644 --- a/server/text_generation_server/layers/linear.py +++ b/server/text_generation_server/layers/linear.py @@ -1,6 +1,9 @@ +from typing import Optional import torch from torch.nn import functional as F from text_generation_server.utils.import_utils import SYSTEM +from text_generation_server.layers.exl2 import Exl2Weight +from text_generation_server.layers.gptq import GPTQWeight if SYSTEM == "rocm": try: @@ -151,15 +154,23 @@ def get_linear(weight, bias, quantize): bias, quant_type="nf4", ) + elif quantize == "exl2": + if not isinstance(weight, Exl2Weight): + raise NotImplementedError( + f"The passed weight is not `exl2` compatible, loader needs to be updated." + ) + + from text_generation_server.layers.gptq import ExllamaQuantLinear + + linear = ExllamaQuantLinear(weight, bias) + elif quantize == "gptq": - try: - qweight, qzeros, scales, g_idx, bits, groupsize, use_exllama = weight - except Exception: + if not isinstance(weight, GPTQWeight): raise NotImplementedError( f"The passed weight is not `gptq` compatible, loader needs to be updated." ) - if use_exllama: + if weight.use_exllama: try: from text_generation_server.layers.gptq import ( ExllamaQuantLinear, @@ -169,25 +180,21 @@ def get_linear(weight, bias, quantize): f"Exllama gptq kernels are not installed. Install them `cd server/exllama_kernels && python setup.py install && cd ../exllamav2_kernels && python setup.py install`" ) - linear = ExllamaQuantLinear( - qweight, qzeros, scales, g_idx, bias, bits, groupsize - ) + linear = ExllamaQuantLinear(weight, bias) else: from text_generation_server.layers.gptq.quant_linear import QuantLinear linear = QuantLinear( - qweight, - qzeros, - scales, - g_idx, + weight.qweight, + weight.qzeros, + weight.scales, + weight.g_idx, bias, - bits, - groupsize, + weight.bits, + weight.groupsize, ) elif quantize == "awq": - try: - qweight, qzeros, scales, _, bits, groupsize, _ = weight - except Exception: + if not isinstance(weight, GPTQWeight): raise NotImplementedError( f"The passed weight is not `awq` compatible, loader needs to be updated." ) @@ -200,11 +207,11 @@ def get_linear(weight, bias, quantize): from text_generation_server.layers.awq.quantize.qmodule import WQLinear linear = WQLinear( - w_bit=bits, - group_size=groupsize, - qweight=qweight, - qzeros=qzeros, - scales=scales, + w_bit=weight.bits, + group_size=weight.groupsize, + qweight=weight.qweight, + qzeros=weight.qzeros, + scales=weight.scales, bias=bias is not None, ) except ImportError: diff --git a/server/text_generation_server/layers/tensor_parallel.py b/server/text_generation_server/layers/tensor_parallel.py index 34b9c51e..afaaa1b8 100644 --- a/server/text_generation_server/layers/tensor_parallel.py +++ b/server/text_generation_server/layers/tensor_parallel.py @@ -1,7 +1,27 @@ import torch from torch.nn import functional as F -from typing import List +from typing import Iterable, List from text_generation_server.layers.linear import get_linear, FastLinear +from text_generation_server.layers.exl2 import Exl2Weight + + +class LayerConcat(torch.nn.Module): + """ + Apply multiple layers to the input and concatenate their + outputs. + """ + + def __init__(self, layers: Iterable[torch.nn.Module], dim: int = -1): + """ + `dim` is the dimension along which layer outputs are concatenated. + """ + super().__init__() + self.layers = layers + self.dim = dim + + def forward(self, x: torch.Tensor): + outputs = [layer(x) for layer in self.layers] + return torch.cat(outputs, self.dim) class SuperLayer(torch.nn.Module): @@ -21,7 +41,16 @@ class TensorParallelHead(SuperLayer): @staticmethod def load(config, prefix: str, weights): - if weights.process_group.size() > 1: + if config.quantize == "exl2": + try: + # If the piece and LM head embeddings are shared, we have + # non-quantized weights... + weight = weights.get_tensor(f"{prefix}.weight") + except: + # ...otherwise they are quantized. + weight = weights.get_weights_col(prefix, config.quantize) + should_gather = weights.process_group.size() > 1 + elif weights.process_group.size() > 1: try: weight = weights.get_sharded(f"{prefix}.weight", dim=0) should_gather = True @@ -37,8 +66,12 @@ class TensorParallelHead(SuperLayer): # GPTQ,AWQ,EETQ don't quantize heads (nor embeddings) if config.quantize in ["gptq", "awq", "eetq"]: quantize = None + # See above, exl2 LM head can be quantized or not. + elif config.quantize == "exl2" and not isinstance(weight, Exl2Weight): + quantize = None else: quantize = config.quantize + return TensorParallelHead( get_linear(weight, bias=None, quantize=quantize), process_group=weights.process_group, @@ -108,22 +141,35 @@ class TensorParallelColumnLinear(SuperLayer): @classmethod def load(cls, config, prefix: str, weights, bias: bool): - return cls.load_multi(config, [prefix], weights, bias, dim=0) - - @classmethod - def load_multi(cls, config, prefixes: List[str], weights, bias: bool, dim: int): - weight = weights.get_multi_weights_col( - prefixes, quantize=config.quantize, dim=dim - ) - + weight = weights.get_weights_col(prefix, config.quantize) if bias: - b = [weights.get_sharded(f"{p}.bias", dim=0) for p in prefixes] - bias = torch.cat(b, dim=dim) + bias = weights.get_sharded(f"{prefix}.bias", dim=0) else: bias = None linear = get_linear(weight, bias, config.quantize) return cls(linear) + @classmethod + def load_multi(cls, config, prefixes: List[str], weights, bias: bool, dim: int): + if config.quantize == "exl2": + linears = [] + for prefix in prefixes: + weight = weights.get_weights_col(prefix, config.quantize) + b = weights.get_tensor(f"{prefix}.bias") if bias else None + linears.append(get_linear(weight, b, config.quantize)) + linear = LayerConcat(linears) + else: + weight = weights.get_multi_weights_col( + prefixes, quantize=config.quantize, dim=dim + ) + if bias: + b = [weights.get_sharded(f"{p}.bias", dim=0) for p in prefixes] + bias = torch.cat(b, dim=dim) + else: + bias = None + linear = get_linear(weight, bias, config.quantize) + return cls(linear) + class TensorParallelRowLinear(SuperLayer): def __init__(self, linear, process_group): diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py index 92a20639..d086f87b 100644 --- a/server/text_generation_server/models/__init__.py +++ b/server/text_generation_server/models/__init__.py @@ -263,7 +263,7 @@ def get_model( trust_remote_code: bool, ) -> Model: if dtype is None: - if quantize in ["awq", "gptq"]: + if quantize in ["awq", "exl2", "gptq"]: # These quantizers only work with float16 params. dtype = torch.float16 else: @@ -402,12 +402,17 @@ def get_model( quantization_config = config_dict.get("quantization_config", None) if quantization_config is not None and quantize is None: method = quantization_config.get("quant_method", None) - if method in {"gptq", "awq"}: + if method in {"gptq", "awq", "exl2"}: logger.info(f"Auto selecting quantization method {method}") quantize = method else: logger.info(f"Unknown quantization method {method}") + if quantize == "exl2" and sharded: + raise RuntimeError( + "Sharding is currently not supported with `exl2` quantization" + ) + if model_type == MAMBA: return Mamba( model_id, @@ -881,6 +886,8 @@ def get_model( raise NotImplementedError("4bit quantization is not supported for AutoModel") elif quantize == "eetq": raise NotImplementedError("Eetq quantization is not supported for AutoModel") + elif quantize == "exl2": + raise NotImplementedError("exl2 quantization is not supported for AutoModel") if model_type in modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES: return CausalLM( model_id, diff --git a/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py b/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py index 9d652b67..56bfb9d0 100644 --- a/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py @@ -21,6 +21,7 @@ from transformers.activations import ACT2FN from transformers.configuration_utils import PretrainedConfig from typing import Optional, List, Tuple, Any from loguru import logger +from text_generation_server.layers.gptq import GPTQWeight from text_generation_server.utils.import_utils import SYSTEM if SYSTEM != "xpu": @@ -256,7 +257,15 @@ def _load_gqa(config, prefix: str, weights): else: g_idx = None - weight = (qweight, qzeros, scales, g_idx, bits, groupsize, use_exllama) + weight = GPTQWeight( + qweight=qweight, + qzeros=qzeros, + scales=scales, + g_idx=g_idx, + bits=bits, + groupsize=groupsize, + use_exllama=use_exllama, + ) else: qkv_slice = weights._get_slice(f"{prefix}.Wqkv.weight") q = qkv_slice[q_start:q_stop] diff --git a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py index f722bf73..fa3a78f8 100644 --- a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py @@ -395,7 +395,7 @@ class FlashLlamaForCausalLM(torch.nn.Module): self.lm_head = SpeculativeHead.load( config, - prefix=suffix if not prefix else f"{prefix}.suffix", + prefix=suffix if not prefix else f"{prefix}.{suffix}", weights=weights, ) diff --git a/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py b/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py index ef3777da..65043dee 100644 --- a/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py @@ -102,45 +102,6 @@ class MistralConfig(PretrainedConfig): ) -def load_attention(config, prefix, weights): - if config.num_attention_heads != config.num_key_value_heads: - return _load_gqa(config, prefix, weights) - else: - return TensorParallelColumnLinear.load_multi( - config, - prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"], - dim=0, - weights=weights, - bias=False, - ) - - -def _load_gqa(config, prefix: str, weights): - assert config.hidden_size % config.num_attention_heads == 0 - assert config.num_attention_heads % weights.process_group.size() == 0 - - weight = weights.get_multi_weights_col( - prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"], - quantize=config.quantize, - dim=0, - ) - - if config.quantize not in ["gptq", "awq"]: - weight = weight.to(dtype=weights.dtype).to(device=weights.device) - - head_size = config.hidden_size // config.num_attention_heads - num_heads = config.num_attention_heads // weights.process_group.size() - num_key_value_heads = config.num_key_value_heads // weights.process_group.size() - assert list(weight.shape) == [ - (num_heads + 2 * num_key_value_heads) * head_size, - config.hidden_size, - ], f"{list(weight.shape)} != {[(num_heads + 2 * config.num_key_value_heads) * head_size, config.hidden_size]}" - - return TensorParallelColumnLinear( - get_linear(weight, bias=None, quantize=config.quantize) - ) - - class MistralAttention(torch.nn.Module): def __init__( self, @@ -175,7 +136,13 @@ class MistralAttention(torch.nn.Module): config.num_key_value_heads // weights.process_group.size() ) - self.query_key_value = load_attention(config, prefix, weights) + self.query_key_value = TensorParallelColumnLinear.load_multi( + config, + prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"], + dim=0, + weights=weights, + bias=False, + ) self.o_proj = TensorParallelRowLinear.load( config, diff --git a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py index d2f6d9af..cfa4243f 100644 --- a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py @@ -5,6 +5,7 @@ from torch import nn from transformers.activations import ACT2FN from typing import Optional, List, Tuple +from text_generation_server.layers.gptq import GPTQWeight from text_generation_server.utils import paged_attention, flash_attn from text_generation_server.layers import ( TensorParallelRowLinear, @@ -90,8 +91,15 @@ def _load_multi_mqa_gptq( from text_generation_server.layers.gptq import HAS_EXLLAMA - use_exllama = HAS_EXLLAMA - weight = (qweight, qzeros, scales, g_idx, bits, groupsize, use_exllama) + weight = GPTQWeight( + qweight=qweight, + qzeros=qzeros, + scales=scales, + g_idx=g_idx, + bits=bits, + groupsize=groupsize, + use_exllama=HAS_EXLLAMA, + ) if bias: slice_ = weights._get_slice(f"{prefix}.c_attn.bias") diff --git a/server/text_generation_server/models/flash_llama.py b/server/text_generation_server/models/flash_llama.py index 9a7dfaee..c5cbd2b8 100644 --- a/server/text_generation_server/models/flash_llama.py +++ b/server/text_generation_server/models/flash_llama.py @@ -67,7 +67,7 @@ class FlashLlama(FlashCausalLM): filenames = weight_files(model_id, revision=revision, extension=".safetensors") weights = Weights(filenames, device, dtype, process_group=self.process_group) - if config.quantize in ["gptq", "awq"]: + if config.quantize in ["gptq", "awq", "exl2"]: weights._set_gptq_params(model_id, revision) prefix = "" diff --git a/server/text_generation_server/server.py b/server/text_generation_server/server.py index 37c46032..4118b3f6 100644 --- a/server/text_generation_server/server.py +++ b/server/text_generation_server/server.py @@ -89,7 +89,7 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer): return generate_pb2.FilterBatchResponse(batch=filtered_batch.to_pb()) async def Warmup(self, request, context): - if self.quantize == "gptq": + if self.quantize in {"exl2", "gptq"}: try: # When using GPTQ, Exllama kernels need some global kernels # For which we have the finale shapes only after the model has loaded diff --git a/server/text_generation_server/utils/weights.py b/server/text_generation_server/utils/weights.py index 6af7d3fb..710ea680 100644 --- a/server/text_generation_server/utils/weights.py +++ b/server/text_generation_server/utils/weights.py @@ -1,11 +1,14 @@ +from dataclasses import dataclass, field import os from pathlib import Path -from typing import List, Dict, Optional, Tuple +from typing import List, Dict, Optional, Set, Tuple, Union from safetensors import safe_open, SafetensorError import torch from loguru import logger from huggingface_hub import hf_hub_download import json +from text_generation_server.layers.exl2 import Exl2Weight +from text_generation_server.layers.gptq import GPTQWeight from text_generation_server.utils.log import log_once @@ -76,8 +79,9 @@ class Weights: f = self._get_handle(filename) tensor = f.get_tensor(tensor_name) # Special case for gptq which shouldn't convert - # u4 which are disguised as int32 - if tensor.dtype not in [torch.int32, torch.int64]: + # u4 which are disguised as int32. Exl2 uses int16 + # as well. + if tensor.dtype not in [torch.int16, torch.int32, torch.int64]: tensor = tensor.to(dtype=self.dtype) if to_device: tensor = tensor.to(device=self.device) @@ -102,8 +106,8 @@ class Weights: else: raise NotImplementedError("Let's make that generic when needed") # Special case for gptq which shouldn't convert - # u4 which are disguised as int32 - if tensor.dtype != torch.int32: + # u4 which are disguised as int32. exl2 uses int16. + if tensor.dtype not in (torch.int16, torch.int32): tensor = tensor.to(dtype=self.dtype) tensor = tensor.to(device=self.device) return tensor @@ -183,7 +187,15 @@ class Weights: else: g_idx = None - weight = (qweight, qzeros, scales, g_idx, bits, groupsize, False) + weight = GPTQWeight( + qweight=qweight, + qzeros=qzeros, + scales=scales, + g_idx=g_idx, + bits=bits, + groupsize=groupsize, + use_exllama=False, + ) else: slice_ = self._get_slice(f"{prefix}.weight") total_size = slice_.get_shape()[0] @@ -207,8 +219,34 @@ class Weights: weight = weight.to(dtype=self.dtype) return weight + def get_weights_col(self, prefix: str, quantize: str): + if quantize == "exl2": + try: + q_weight = self.get_tensor(f"{prefix}.q_weight") + except RuntimeError: + raise RuntimeError( + f"Cannot load `exl2`-quantized weight, make sure the model is already quantized." + ) + + q_scale = self.get_tensor(f"{prefix}.q_scale") + q_invperm = self.get_tensor(f"{prefix}.q_invperm") + q_scale_max = self.get_tensor(f"{prefix}.q_scale_max") + q_groups = self.get_tensor(f"{prefix}.q_groups") + + return Exl2Weight( + q_weight=q_weight, + q_scale=q_scale, + q_invperm=q_invperm, + q_scale_max=q_scale_max, + q_groups=q_groups, + ) + + return self.get_multi_weights_col([prefix], quantize, 0) + def get_multi_weights_col(self, prefixes: List[str], quantize: str, dim: int): - if quantize in ["gptq", "awq"]: + if quantize == "exl2": + raise ValueError("get_multi_weights_col is not supported for exl2") + elif quantize in ["gptq", "awq"]: try: qweight = torch.cat( [self.get_sharded(f"{p}.qweight", dim=1) for p in prefixes], dim=1 @@ -259,7 +297,15 @@ class Weights: else: g_idx = None - weight = (qweight, qzeros, scales, g_idx, bits, groupsize, use_exllama) + weight = GPTQWeight( + qweight=qweight, + qzeros=qzeros, + scales=scales, + g_idx=g_idx, + bits=bits, + groupsize=groupsize, + use_exllama=use_exllama, + ) else: w = [self.get_sharded(f"{p}.weight", dim=0) for p in prefixes] weight = torch.cat(w, dim=dim) @@ -282,7 +328,28 @@ class Weights: return tensor def get_multi_weights_row(self, prefix: str, quantize: str): - if quantize == "gptq": + if quantize == "exl2": + try: + q_weight = self.get_tensor(f"{prefix}.q_weight") + except RuntimeError: + raise RuntimeError( + f"Cannot load `exl2`-quantized weight, make sure the model is already quantized." + ) + + q_scale = self.get_tensor(f"{prefix}.q_scale") + q_invperm = self.get_tensor(f"{prefix}.q_invperm") + q_scale_max = self.get_tensor(f"{prefix}.q_scale_max") + q_groups = self.get_tensor(f"{prefix}.q_groups") + + return Exl2Weight( + q_weight=q_weight, + q_scale=q_scale, + q_invperm=q_invperm, + q_scale_max=q_scale_max, + q_groups=q_groups, + ) + + elif quantize == "gptq": use_exllama = True bits, groupsize, desc_act, quant_method = self._get_gptq_params() @@ -363,7 +430,15 @@ class Weights: // groupsize ).to(dtype=torch.int32) - weight = (qweight, qzeros, scales, g_idx, bits, groupsize, use_exllama) + weight = GPTQWeight( + qweight=qweight, + qzeros=qzeros, + scales=scales, + g_idx=g_idx, + bits=bits, + groupsize=groupsize, + use_exllama=use_exllama, + ) elif quantize == "awq": bits, groupsize, _, _ = self._get_gptq_params() @@ -379,7 +454,15 @@ class Weights: g_idx = None use_exllama = False - weight = (qweight, qzeros, scales, g_idx, bits, groupsize, use_exllama) + weight = GPTQWeight( + qweight=qweight, + qzeros=qzeros, + scales=scales, + g_idx=g_idx, + bits=bits, + groupsize=groupsize, + use_exllama=use_exllama, + ) else: weight = self.get_sharded(f"{prefix}.weight", dim=1) return weight