From 46a5a7e73e8b1adde4a279c7d68818ed9e17f607 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Dani=C3=ABl=20de=20Kok?=
Date: Wed, 20 Nov 2024 18:25:23 +0100
Subject: [PATCH] Add support for wNa16 int 2:4 compressed-tensors checkpoints
 (#2758)

This change adds support for wNa16 int checkpoints with 2:4 sparsity
using Marlin 2:4 kernels.
---
 .../test_compressed_tensors_wna16_int_24.json | 104 +++++
 ...essed_tensors_wna16_int_24_all_params.json |  99 +++++
 ..._compressed_tensors_wna16_int_24_load.json | 418 ++++++++++++++++++
 .../test_compressed_tensors_wna16_int_24.py   |  90 ++++
 .../layers/compressed_tensors/loader.py       |  14 +-
 .../layers/compressed_tensors/wna16_int.py    |   4 +-
 .../layers/compressed_tensors/wna16_int_24.py | 101 +++++
 .../layers/marlin/marlin.py                   |  56 ++-
 8 files changed, 860 insertions(+), 26 deletions(-)
 create mode 100644 integration-tests/models/__snapshots__/test_compressed_tensors_wna16_int_24/test_compressed_tensors_wna16_int_24.json
 create mode 100644 integration-tests/models/__snapshots__/test_compressed_tensors_wna16_int_24/test_compressed_tensors_wna16_int_24_all_params.json
 create mode 100644 integration-tests/models/__snapshots__/test_compressed_tensors_wna16_int_24/test_compressed_tensors_wna16_int_24_load.json
 create mode 100644 integration-tests/models/test_compressed_tensors_wna16_int_24.py
 create mode 100644 server/text_generation_server/layers/compressed_tensors/wna16_int_24.py

diff --git a/integration-tests/models/__snapshots__/test_compressed_tensors_wna16_int_24/test_compressed_tensors_wna16_int_24.json b/integration-tests/models/__snapshots__/test_compressed_tensors_wna16_int_24/test_compressed_tensors_wna16_int_24.json
new file mode 100644
index 00000000..74e74801
--- /dev/null
+++ b/integration-tests/models/__snapshots__/test_compressed_tensors_wna16_int_24/test_compressed_tensors_wna16_int_24.json
@@ -0,0 +1,104 @@
+{
+  "details": {
+    "best_of_sequences": null,
+    "finish_reason": "length",
+    "generated_tokens": 10,
+    "prefill": [
+      {
+        "id": 128000,
+        "logprob": null,
+        "text": "<|begin_of_text|>"
+      },
+      {
+        "id": 3923,
+        "logprob": -7.5390625,
+        "text": "What"
+      },
+      {
+        "id": 374,
+        "logprob": -0.86035156,
+        "text": " is"
+      },
+      {
+        "id": 5655,
+        "logprob": -8.828125,
+        "text": " deep"
+      },
+      {
+        "id": 6975,
+        "logprob": -1.4912109,
+        "text": " learning"
+      },
+      {
+        "id": 30,
+        "logprob": -2.1152344,
+        "text": "?"
+      }
+    ],
+    "seed": null,
+    "tokens": [
+      {
+        "id": 34564,
+        "logprob": -1.765625,
+        "special": false,
+        "text": "Deep"
+      },
+      {
+        "id": 6975,
+        "logprob": -0.023864746,
+        "special": false,
+        "text": " learning"
+      },
+      {
+        "id": 374,
+        "logprob": -0.1060791,
+        "special": false,
+        "text": " is"
+      },
+      {
+        "id": 264,
+        "logprob": -0.1940918,
+        "special": false,
+        "text": " a"
+      },
+      {
+        "id": 27084,
+        "logprob": -0.79785156,
+        "special": false,
+        "text": " subset"
+      },
+      {
+        "id": 315,
+        "logprob": -0.008262634,
+        "special": false,
+        "text": " of"
+      },
+      {
+        "id": 5780,
+        "logprob": -0.046569824,
+        "special": false,
+        "text": " machine"
+      },
+      {
+        "id": 6975,
+        "logprob": -0.0023479462,
+        "special": false,
+        "text": " learning"
+      },
+      {
+        "id": 430,
+        "logprob": -0.7626953,
+        "special": false,
+        "text": " that"
+      },
+      {
+        "id": 5829,
+        "logprob": -1.0107422,
+        "special": false,
+        "text": " uses"
+      }
+    ],
+    "top_tokens": null
+  },
+  "generated_text": "Deep learning is a subset of machine learning that uses"
+}
diff --git a/integration-tests/models/__snapshots__/test_compressed_tensors_wna16_int_24/test_compressed_tensors_wna16_int_24_all_params.json b/integration-tests/models/__snapshots__/test_compressed_tensors_wna16_int_24/test_compressed_tensors_wna16_int_24_all_params.json
new file mode 100644
index 00000000..596736ff
--- /dev/null
+++ b/integration-tests/models/__snapshots__/test_compressed_tensors_wna16_int_24/test_compressed_tensors_wna16_int_24_all_params.json
@@ -0,0 +1,99 @@
+{
+  "details": {
+    "best_of_sequences": null,
+    "finish_reason": "length",
+    "generated_tokens": 10,
+    "prefill": [
+      {
+        "id": 128000,
+        "logprob": null,
+        "text": "<|begin_of_text|>"
+      },
+      {
+        "id": 3923,
+        "logprob": -7.5390625,
+        "text": "What"
+      },
+      {
+        "id": 374,
+        "logprob": -0.86035156,
+        "text": " is"
+      },
+      {
+        "id": 5655,
+        "logprob": -8.828125,
+        "text": " deep"
+      },
+      {
+        "id": 6975,
+        "logprob": -1.4912109,
+        "text": " learning"
+      }
+    ],
+    "seed": 0,
+    "tokens": [
+      {
+        "id": 5380,
+        "logprob": 0.0,
+        "special": false,
+        "text": "?\n"
+      },
+      {
+        "id": 34564,
+        "logprob": 0.0,
+        "special": false,
+        "text": "Deep"
+      },
+      {
+        "id": 6975,
+        "logprob": 0.0,
+        "special": false,
+        "text": " learning"
+      },
+      {
+        "id": 320,
+        "logprob": -0.19580078,
+        "special": false,
+        "text": " ("
+      },
+      {
+        "id": 16931,
+        "logprob": -1.7783203,
+        "special": false,
+        "text": "DL"
+      },
+      {
+        "id": 8,
+        "logprob": 0.0,
+        "special": false,
+        "text": ")"
+      },
+      {
+        "id": 374,
+        "logprob": -1.4287109,
+        "special": false,
+        "text": " is"
+      },
+      {
+        "id": 264,
+        "logprob": 0.0,
+        "special": false,
+        "text": " a"
+      },
+      {
+        "id": 27084,
+        "logprob": 0.0,
+        "special": false,
+        "text": " subset"
+      },
+      {
+        "id": 315,
+        "logprob": 0.0,
+        "special": false,
+        "text": " of"
+      }
+    ],
+    "top_tokens": null
+  },
+  "generated_text": "What is deep learning?\nDeep learning (DL) is a subset of"
+}
diff --git a/integration-tests/models/__snapshots__/test_compressed_tensors_wna16_int_24/test_compressed_tensors_wna16_int_24_load.json b/integration-tests/models/__snapshots__/test_compressed_tensors_wna16_int_24/test_compressed_tensors_wna16_int_24_load.json
new file mode 100644
index 00000000..c32c80cc
--- /dev/null
+++ b/integration-tests/models/__snapshots__/test_compressed_tensors_wna16_int_24/test_compressed_tensors_wna16_int_24_load.json
@@ -0,0 +1,418 @@
+[
+  {
+    "details": {
+      "best_of_sequences": null,
+      "finish_reason": "length",
+      "generated_tokens": 10,
+      "prefill": [
+        {
+          "id": 128000,
+          "logprob": null,
+          "text": "<|begin_of_text|>"
+        },
+        {
+          "id": 3923,
+          "logprob": -7.5390625,
+          "text": "What"
+        },
+        {
+          "id": 374,
+          "logprob": -0.86035156,
+          "text": " is"
+        },
+        {
+          "id": 5655,
+          "logprob": -8.828125,
+          "text": " deep"
+        },
+        {
+          "id": 6975,
+          "logprob": -1.4912109,
+          "text": " learning"
+        },
+        {
+          "id": 30,
+          "logprob": -2.1152344,
+          "text": "?"
+        }
+      ],
+      "seed": null,
+      "tokens": [
+        {
+          "id": 34564,
+          "logprob": -1.765625,
+          "special": false,
+          "text": "Deep"
+        },
+        {
+          "id": 6975,
+          "logprob": -0.024002075,
+          "special": false,
+          "text": " learning"
+        },
+        {
+          "id": 374,
+          "logprob": -0.10760498,
+          "special": false,
+          "text": " is"
+        },
+        {
+          "id": 264,
+          "logprob": -0.19580078,
+          "special": false,
+          "text": " a"
+        },
+        {
+          "id": 27084,
+          "logprob": -0.7993164,
+          "special": false,
+          "text": " subset"
+        },
+        {
+          "id": 315,
+          "logprob": -0.008300781,
+          "special": false,
+          "text": " of"
+        },
+        {
+          "id": 5780,
+          "logprob": -0.046295166,
+          "special": false,
+          "text": " machine"
+        },
+        {
+          "id": 6975,
+          "logprob": -0.002374649,
+          "special": false,
+          "text": " learning"
+        },
+        {
+          "id": 430,
+          "logprob": -0.7651367,
+          "special": false,
+          "text": " that"
+        },
+        {
+          "id": 5829,
+          "logprob": -1.0107422,
+          "special": false,
+          "text": " uses"
+        }
+      ],
+      "top_tokens": null
+    },
+    "generated_text": "Deep learning is a subset of machine learning that uses"
+  },
+  {
+    "details": {
+      "best_of_sequences": null,
+      "finish_reason": "length",
+      "generated_tokens": 10,
+      "prefill": [
+        {
+          "id": 128000,
+          "logprob": null,
+          "text": "<|begin_of_text|>"
+        },
+        {
+          "id": 3923,
+          "logprob": -7.5351562,
+          "text": "What"
+        },
+        {
+          "id": 374,
+          "logprob": -0.85791016,
+          "text": " is"
+        },
+        {
+          "id": 5655,
+          "logprob": -8.828125,
+          "text": " deep"
+        },
+        {
+          "id": 6975,
+          "logprob": -1.4882812,
+          "text": " learning"
+        },
+        {
+          "id": 30,
+          "logprob": -2.1210938,
+          "text": "?"
+        }
+      ],
+      "seed": null,
+      "tokens": [
+        {
+          "id": 34564,
+          "logprob": -1.7597656,
+          "special": false,
+          "text": "Deep"
+        },
+        {
+          "id": 6975,
+          "logprob": -0.024032593,
+          "special": false,
+          "text": " learning"
+        },
+        {
+          "id": 374,
+          "logprob": -0.10748291,
+          "special": false,
+          "text": " is"
+        },
+        {
+          "id": 264,
+          "logprob": -0.19592285,
+          "special": false,
+          "text": " a"
+        },
+        {
+          "id": 27084,
+          "logprob": -0.7988281,
+          "special": false,
+          "text": " subset"
+        },
+        {
+          "id": 315,
+          "logprob": -0.008354187,
+          "special": false,
+          "text": " of"
+        },
+        {
+          "id": 5780,
+          "logprob": -0.046569824,
+          "special": false,
+          "text": " machine"
+        },
+        {
+          "id": 6975,
+          "logprob": -0.0023517609,
+          "special": false,
+          "text": " learning"
+        },
+        {
+          "id": 430,
+          "logprob": -0.7661133,
+          "special": false,
+          "text": " that"
+        },
+        {
+          "id": 5829,
+          "logprob": -1.0107422,
+          "special": false,
+          "text": " uses"
+        }
+      ],
+      "top_tokens": null
+    },
+    "generated_text": "Deep learning is a subset of machine learning that uses"
+  },
+  {
+    "details": {
+      "best_of_sequences": null,
+      "finish_reason": "length",
+      "generated_tokens": 10,
+      "prefill": [
+        {
+          "id": 128000,
+          "logprob": null,
+          "text": "<|begin_of_text|>"
+        },
+        {
+          "id": 3923,
+          "logprob": -7.5351562,
+          "text": "What"
+        },
+        {
+          "id": 374,
+          "logprob": -0.85791016,
+          "text": " is"
+        },
+        {
+          "id": 5655,
+          "logprob": -8.828125,
+          "text": " deep"
+        },
+        {
+          "id": 6975,
+          "logprob": -1.4882812,
+          "text": " learning"
+        },
+        {
+          "id": 30,
+          "logprob": -2.1210938,
+          "text": "?"
+        }
+      ],
+      "seed": null,
+      "tokens": [
+        {
+          "id": 34564,
+          "logprob": -1.7597656,
+          "special": false,
+          "text": "Deep"
+        },
+        {
+          "id": 6975,
+          "logprob": -0.024032593,
+          "special": false,
+          "text": " learning"
+        },
+        {
+          "id": 374,
+          "logprob": -0.10748291,
+          "special": false,
+          "text": " is"
+        },
+        {
+          "id": 264,
+          "logprob": -0.19592285,
+          "special": false,
+          "text": " a"
+        },
+        {
+          "id": 27084,
+          "logprob": -0.7988281,
+          "special": false,
+          "text": " subset"
+        },
+        {
+          "id": 315,
+          "logprob": -0.008354187,
+          "special": false,
+          "text": " of"
+        },
+        {
+          "id": 5780,
+          "logprob": -0.046569824,
+          "special": false,
+          "text": " machine"
+        },
+        {
+          "id": 6975,
+          "logprob": -0.0023517609,
+          "special": false,
+          "text": " learning"
+        },
+        {
+          "id": 430,
+          "logprob": -0.7661133,
+          "special": false,
+          "text": " that"
+        },
+        {
+          "id": 5829,
+          "logprob": -1.0107422,
+          "special": false,
+          "text": " uses"
+        }
+      ],
+      "top_tokens": null
+    },
+    "generated_text": "Deep learning is a subset of machine learning that uses"
+  },
+  {
+    "details": {
+      "best_of_sequences": null,
+      "finish_reason": "length",
+      "generated_tokens": 10,
+      "prefill": [
+        {
+          "id": 128000,
+          "logprob": null,
+          "text": "<|begin_of_text|>"
+        },
+        {
+          "id": 3923,
+          "logprob": -7.5351562,
+          "text": "What"
+        },
+        {
+          "id": 374,
+          "logprob": -0.85791016,
+          "text": " is"
+        },
+        {
+          "id": 5655,
+          "logprob": -8.828125,
+          "text": " deep"
+        },
+        {
+          "id": 6975,
+          "logprob": -1.4882812,
+          "text": " learning"
+        },
+        {
+          "id": 30,
+          "logprob": -2.1210938,
+          "text": "?"
+        }
+      ],
+      "seed": null,
+      "tokens": [
+        {
+          "id": 34564,
+          "logprob": -1.7597656,
+          "special": false,
+          "text": "Deep"
+        },
+        {
+          "id": 6975,
+          "logprob": -0.024032593,
+          "special": false,
+          "text": " learning"
+        },
+        {
+          "id": 374,
+          "logprob": -0.10748291,
+          "special": false,
+          "text": " is"
+        },
+        {
+          "id": 264,
+          "logprob": -0.19592285,
+          "special": false,
+          "text": " a"
+        },
+        {
+          "id": 27084,
+          "logprob": -0.7988281,
+          "special": false,
+          "text": " subset"
+        },
+        {
+          "id": 315,
+          "logprob": -0.008354187,
+          "special": false,
+          "text": " of"
+        },
+        {
+          "id": 5780,
+          "logprob": -0.046569824,
+          "special": false,
+          "text": " machine"
+        },
+        {
+          "id": 6975,
+          "logprob": -0.0023517609,
+          "special": false,
+          "text": " learning"
+        },
+        {
+          "id": 430,
+          "logprob": -0.7661133,
+          "special": false,
+          "text": " that"
+        },
+        {
+          "id": 5829,
+          "logprob": -1.0107422,
+          "special": false,
+          "text": " uses"
+        }
+      ],
+      "top_tokens": null
+    },
+    "generated_text": "Deep learning is a subset of machine learning that uses"
+  }
+]
diff --git a/integration-tests/models/test_compressed_tensors_wna16_int_24.py b/integration-tests/models/test_compressed_tensors_wna16_int_24.py
new file mode 100644
index 00000000..0f76f6a8
--- /dev/null
+++ b/integration-tests/models/test_compressed_tensors_wna16_int_24.py
@@ -0,0 +1,90 @@
+import pytest
+
+
+@pytest.fixture(scope="module")
+def compressed_tensors_wna16_int_24_handle(launcher):
+    with launcher(
+        "danieldk/Llama-3.1-8B-w4a16-int-24",
+        num_shard=2,
+        quantize="compressed-tensors",
+    ) as handle:
+        yield handle
+
+
+@pytest.fixture(scope="module")
+async def compressed_tensors_wna16_int_24(compressed_tensors_wna16_int_24_handle):
+    await compressed_tensors_wna16_int_24_handle.health(300)
+    return compressed_tensors_wna16_int_24_handle.client
+
+
+@pytest.mark.release
+@pytest.mark.asyncio
+@pytest.mark.private
+async def test_compressed_tensors_wna16_int_24(
+    compressed_tensors_wna16_int_24, response_snapshot
+):
+    response = await compressed_tensors_wna16_int_24.generate(
+        "What is deep learning?",
+        max_new_tokens=10,
+        decoder_input_details=True,
+    )
+
+    assert (
+        response.generated_text
+        == "Deep learning is a subset of machine learning that uses"
+    )
+    assert response.details.generated_tokens == 10
+    assert response == response_snapshot
+
+
+@pytest.mark.release
+@pytest.mark.asyncio
+@pytest.mark.private
+async def test_compressed_tensors_wna16_int_24_all_params(
+    compressed_tensors_wna16_int_24, response_snapshot
+):
+    response = await compressed_tensors_wna16_int_24.generate(
+        "What is deep learning",
+        max_new_tokens=10,
+        repetition_penalty=1.2,
+        return_full_text=True,
+        stop_sequences=["test"],
+        temperature=0.5,
+        top_p=0.9,
+        top_k=10,
+        truncate=5,
+        typical_p=0.9,
+        watermark=True,
+        decoder_input_details=True,
+        seed=0,
+    )
+
+    assert response.details.generated_tokens == 10
+    assert (
+        response.generated_text
+        == "What is deep learning?\nDeep learning (DL) is a subset of"
+    )
+    assert response == response_snapshot
+
+
+@pytest.mark.release
+@pytest.mark.asyncio
+@pytest.mark.private
+async def test_compressed_tensors_wna16_int_24_load(
+    compressed_tensors_wna16_int_24, generate_load, response_snapshot
+):
+    responses = await generate_load(
+        compressed_tensors_wna16_int_24,
+        "What is deep learning?",
+        max_new_tokens=10,
+        n=4,
+    )
+
+    assert (
+        responses[0].generated_text
+        == "Deep learning is a subset of machine learning that uses"
+    )
+    assert len(responses) == 4
+    assert all([r.generated_text == responses[0].generated_text for r in responses])
+
+    assert responses == response_snapshot
diff --git a/server/text_generation_server/layers/compressed_tensors/loader.py b/server/text_generation_server/layers/compressed_tensors/loader.py
index 957277bf..17d0224e 100644
--- a/server/text_generation_server/layers/compressed_tensors/loader.py
+++ b/server/text_generation_server/layers/compressed_tensors/loader.py
@@ -13,7 +13,10 @@ from torch import nn

 from text_generation_server.layers.compressed_tensors.w8an_fp import W8ANFpLoader
 from text_generation_server.layers.compressed_tensors.w8a8_int import W8A8IntLoader
-from text_generation_server.layers.compressed_tensors.wna16_int import WNA16Loader
+from text_generation_server.layers.compressed_tensors.wna16_int_24 import (
+    WNA16Int24Loader,
+)
+from text_generation_server.layers.compressed_tensors.wna16_int import WNA16IntLoader
 from text_generation_server.utils.log import log_once
 from text_generation_server.utils.weights import (
     DefaultWeightsLoader,
@@ -151,7 +154,14 @@ class CompressedTensorsLoader(WeightsLoader):
                 and weights.num_bits in (4, 8)
             ):
                 # INT W4A16 or W8A16 (GPTQ/AWQ-like).
-                return WNA16IntLoader(weights)
+                return WNA16IntLoader(weights)
+            elif (
+                format == CompressionFormat.marlin_24.value
+                and weights is not None
+                and weights.type == QuantizationType.INT
+                and weights.num_bits in (4, 8)
+            ):
+                return WNA16Int24Loader(weights)
             elif (
                 format
                 in {
diff --git a/server/text_generation_server/layers/compressed_tensors/wna16_int.py b/server/text_generation_server/layers/compressed_tensors/wna16_int.py
index a616867a..bb69c6b5 100644
--- a/server/text_generation_server/layers/compressed_tensors/wna16_int.py
+++ b/server/text_generation_server/layers/compressed_tensors/wna16_int.py
@@ -9,7 +9,7 @@ from text_generation_server.utils.log import log_once
 from text_generation_server.utils.weights import Weights, WeightsLoader


-class WNA16Loader(WeightsLoader):
+class WNA16IntLoader(WeightsLoader):
     """
     Loader for W4A16/W8A16 INT compressed-tensors parameters.
     """
@@ -22,7 +22,7 @@ class WNA16Loader(WeightsLoader):
         )

     def __str__(self) -> str:
-        quantization_type = f"W{self.weights.num_bits}8A16"
+        quantization_type = f"W{self.weights.num_bits}A16"

         return f"{self.__class__.__name__} ({quantization_type})"

diff --git a/server/text_generation_server/layers/compressed_tensors/wna16_int_24.py b/server/text_generation_server/layers/compressed_tensors/wna16_int_24.py
new file mode 100644
index 00000000..27b8614c
--- /dev/null
+++ b/server/text_generation_server/layers/compressed_tensors/wna16_int_24.py
@@ -0,0 +1,101 @@
+from typing import List, Union
+
+import torch
+
+
+from compressed_tensors.quantization import QuantizationArgs, QuantizationType
+from text_generation_server.layers.marlin.marlin import GPTQMarlin24Weight
+from text_generation_server.utils.weights import Weights, WeightsLoader
+
+
+class WNA16Int24Loader(WeightsLoader):
+    """
+    Loader for W4A16/W8A16 INT 2:4 sparsity compressed-tensors checkpoints.
+    """
+
+    def __init__(self, weight_args: QuantizationArgs):
+        super().__init__()
+
+        if weight_args.type != QuantizationType.INT:
+            raise ValueError(
+                f"{type(self).__name__} only supports wNa16 int checkpoints"
+            )
+
+        if weight_args.strategy == "group" and weight_args.group_size is None:
+            raise ValueError("`group_size` must be set when `strategy` is `group`")
+
+        self.bits = weight_args.num_bits
+        self.group_size = weight_args.group_size
+
+    def __str__(self) -> str:
+        quantization_type = f"W{self.bits}A16 2:4 sparsity"
+
+        return f"{self.__class__.__name__} ({quantization_type})"
+
+    def get_weights(self, weights: Weights, prefix: str):
+        """
+        Get weights at the given prefix and apply without tensor parallelism.
+ """ + weight_packed = weights.get_tensor(f"{prefix}.weight_packed") + meta = weights.get_tensor(f"{prefix}.meta") + scale_packed = weights.get_tensor(f"{prefix}.scale_packed") + return GPTQMarlin24Weight( + weight_packed=weight_packed, + meta=meta, + scale_packed=scale_packed, + bits=self.bits, + ) + + def get_weights_col_packed( + self, + weights: Weights, + prefix: str, + block_sizes: Union[int, List[int]], + ): + weight_packed = weights.get_packed_sharded( + f"{prefix}.weight_packed", dim=1, block_sizes=block_sizes + ) + meta = weights.get_packed_sharded( + f"{prefix}.meta", dim=1, block_sizes=block_sizes + ) + scale_packed = weights.get_packed_sharded( + f"{prefix}.scale_packed", dim=1, block_sizes=block_sizes + ) + return GPTQMarlin24Weight( + weight_packed=weight_packed, + meta=meta, + scale_packed=scale_packed, + bits=self.bits, + ) + + def get_multi_weights_col(self, weights: Weights, prefixes: List[str], dim: int): + weight_packed = torch.cat( + [weights.get_sharded(f"{p}.weight_packed", dim=1) for p in prefixes], dim=1 + ) + meta = torch.cat( + [weights.get_sharded(f"{p}.meta", dim=1) for p in prefixes], dim=1 + ) + scale_packed = torch.cat( + [weights.get_sharded(f"{p}.scale_packed", dim=1) for p in prefixes], dim=1 + ) + return GPTQMarlin24Weight( + weight_packed=weight_packed, + meta=meta, + scale_packed=scale_packed, + bits=self.bits, + ) + + def get_weights_row(self, weights: Weights, prefix: str): + weight_packed = weights.get_sharded(f"{prefix}.weight_packed", dim=0) + meta = weights.get_sharded(f"{prefix}.meta", dim=0) + if self.group_size is None: + scale_packed = weights.get_tensor(f"{prefix}.scale_packed") + else: + scale_packed = weights.get_sharded(f"{prefix}.scale_packed", dim=0) + + return GPTQMarlin24Weight( + weight_packed=weight_packed, + meta=meta, + scale_packed=scale_packed, + bits=self.bits, + ) diff --git a/server/text_generation_server/layers/marlin/marlin.py b/server/text_generation_server/layers/marlin/marlin.py index 89ebaca6..1c80e31e 100644 --- a/server/text_generation_server/layers/marlin/marlin.py +++ b/server/text_generation_server/layers/marlin/marlin.py @@ -34,7 +34,9 @@ class MarlinWeightsLoader(WeightsLoader): B_meta = weights.get_tensor(f"{prefix}.B_meta") s = weights.get_tensor(f"{prefix}.s") - weight = GPTQMarlin24Weight(B=B, B_meta=B_meta, s=s, bits=self.bits) + weight = GPTQMarlin24Weight( + weight_packed=B, meta=B_meta, scale_packed=s, bits=self.bits + ) else: try: B = weights.get_tensor(f"{prefix}.B") @@ -65,7 +67,9 @@ class MarlinWeightsLoader(WeightsLoader): f"{prefix}.s", dim=1, block_sizes=block_sizes ) - weight = GPTQMarlin24Weight(B=B, B_meta=B_meta, s=s, bits=self.bits) + weight = GPTQMarlin24Weight( + weight_packed=B, meta=B_meta, scale_packed=s, bits=self.bits + ) else: B = weights.get_packed_sharded( f"{prefix}.B", dim=1, block_sizes=block_sizes @@ -96,7 +100,9 @@ class MarlinWeightsLoader(WeightsLoader): [weights.get_sharded(f"{p}.s", dim=1) for p in prefixes], dim=1 ) - weight = GPTQMarlin24Weight(B=B, B_meta=B_meta, s=s, bits=self.bits) + weight = GPTQMarlin24Weight( + weight_packed=B, meta=B_meta, scale_packed=s, bits=self.bits + ) else: try: B = torch.cat( @@ -132,7 +138,9 @@ class MarlinWeightsLoader(WeightsLoader): else: s = weights.get_sharded(f"{prefix}.s", dim=0) - weight = GPTQMarlin24Weight(B=B, B_meta=B_meta, s=s, bits=self.bits) + weight = GPTQMarlin24Weight( + weight_packed=B, meta=B_meta, scale_packed=s, bits=self.bits + ) else: try: B = weights.get_sharded(f"{prefix}.B", dim=0) @@ -247,15 +255,15 @@ class 
     bits: quantized weight size.
     """

-    B: torch.Tensor
-    B_meta: torch.Tensor
-    s: torch.Tensor
+    weight_packed: torch.Tensor
+    meta: torch.Tensor
+    scale_packed: torch.Tensor
     bits: int

     def __post_init__(self):
-        assert self.B.dtype == torch.int32
-        assert self.B_meta.dtype == torch.int16
-        assert self.s.dtype == torch.float16
+        assert self.weight_packed.dtype == torch.int32
+        assert self.meta.dtype == torch.int16
+        assert self.scale_packed.dtype == torch.float16

     def get_linear(self, bias: torch.Tensor):
         return GPTQMarlin24Linear(
@@ -279,9 +287,13 @@ class GPTQMarlin24Linear(nn.Module):
                 f"{weight.bits}-bit GPTQ Sparse 2:4 Marlin is not supported, must be one of: {supported_bits}"
             )

-        in_features = weight.B.shape[0] * MARLIN_TILE_SIZE * 2
-        out_features = weight.s.shape[1]
-        groupsize = -1 if weight.s.shape[0] == 1 else in_features // weight.s.shape[0]
+        in_features = weight.weight_packed.shape[0] * MARLIN_TILE_SIZE * 2
+        out_features = weight.scale_packed.shape[1]
+        groupsize = (
+            -1
+            if weight.scale_packed.shape[0] == 1
+            else in_features // weight.scale_packed.shape[0]
+        )

         if groupsize not in GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES:
             supported_sizes = ", ".join(
@@ -309,9 +321,9 @@ class GPTQMarlin24Linear(nn.Module):
                 f"Number of input features ({in_features}) not divisible by group size ({groupsize})"
             )

-        self.B = weight.B
-        self.B_meta = weight.B_meta
-        self.s = weight.s
+        self.weight_packed = weight.weight_packed
+        self.meta = weight.meta
+        self.scale_packed = weight.scale_packed
         if bias is not None:
             self.bias = bias
         else:
@@ -320,7 +332,7 @@ class GPTQMarlin24Linear(nn.Module):
         self.workspace = torch.zeros(
             (out_features // GPTQ_MARLIN_24_MIN_THREAD_N) * GPTQ_MARLIN_24_MAX_PARALLEL,
             dtype=torch.int,
-            device=weight.B.device,
+            device=weight.weight_packed.device,
         )

     def forward(self, A: torch.Tensor) -> torch.Tensor:
@@ -328,17 +340,17 @@ class GPTQMarlin24Linear(nn.Module):

         C = marlin_kernels.gptq_marlin_24_gemm(
             A.view(-1, A.shape[-1]),
-            self.B,
-            self.B_meta,
-            self.s,
+            self.weight_packed,
+            self.meta,
+            self.scale_packed,
             self.workspace,
             self.bits,
             A.shape[0],
-            self.s.shape[1],
+            self.scale_packed.shape[1],
             A.shape[1],
         )

-        C = C.reshape(A.shape[:-1] + (self.s.shape[1],))
+        C = C.reshape(A.shape[:-1] + (self.scale_packed.shape[1],))

         if self.bias is not None:
             C += self.bias
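
For reviewers who want to exercise this path manually rather than through the
snapshot tests above, a minimal sketch follows. It is not part of the patch:
the launcher flags mirror the integration-test fixture's arguments, and the
local endpoint assumes the router's default `--port 3000`.

    # Start a server on the 2:4-sparse checkpoint used by the tests:
    #   text-generation-launcher \
    #     --model-id danieldk/Llama-3.1-8B-w4a16-int-24 \
    #     --quantize compressed-tensors \
    #     --num-shard 2
    import requests  # illustrative client-side check only

    resp = requests.post(
        "http://localhost:3000/generate",  # assumed default local port
        json={
            "inputs": "What is deep learning?",
            "parameters": {"max_new_tokens": 10, "decoder_input_details": True},
        },
        timeout=60,
    )
    # Per the snapshots above, greedy decoding should produce:
    # "Deep learning is a subset of machine learning that uses"
    print(resp.json()["generated_text"])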