Add support for wNa16 int 2:4 compressed-tensors checkpoints (#2758)

This change adds support for wNa16 int compressed-tensors checkpoints with
2:4 sparsity, using the Marlin 2:4 kernels.
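
For context, 2:4 sparsity means that in every contiguous group of four weights at most two are non-zero; this is the structured pattern the Marlin 2:4 GEMM kernels exploit, storing only half of the weights plus metadata recording their positions. A minimal illustrative sketch of that invariant (not part of this change; the helper name is made up):

import torch

def satisfies_2_4_sparsity(weight: torch.Tensor) -> bool:
    # Group the last dimension into blocks of four and count the
    # non-zero entries per block; 2:4 allows at most two per block.
    blocks = weight.reshape(-1, 4)
    return bool((blocks != 0).sum(dim=1).le(2).all())

dense = torch.tensor([0.1, 0.2, 0.3, 0.4])
pruned = torch.tensor([0.1, 0.0, 0.0, 0.4])
assert not satisfies_2_4_sparsity(dense)
assert satisfies_2_4_sparsity(pruned)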
Daniël de Kok 2024-11-20 18:25:23 +01:00 committed by GitHub
parent 2fda8845a7
commit 46a5a7e73e
8 changed files with 860 additions and 26 deletions

View File

@@ -0,0 +1,104 @@
{
  "details": {
    "best_of_sequences": null,
    "finish_reason": "length",
    "generated_tokens": 10,
    "prefill": [
      {
        "id": 128000,
        "logprob": null,
        "text": "<|begin_of_text|>"
      },
      {
        "id": 3923,
        "logprob": -7.5390625,
        "text": "What"
      },
      {
        "id": 374,
        "logprob": -0.86035156,
        "text": " is"
      },
      {
        "id": 5655,
        "logprob": -8.828125,
        "text": " deep"
      },
      {
        "id": 6975,
        "logprob": -1.4912109,
        "text": " learning"
      },
      {
        "id": 30,
        "logprob": -2.1152344,
        "text": "?"
      }
    ],
    "seed": null,
    "tokens": [
      {
        "id": 34564,
        "logprob": -1.765625,
        "special": false,
        "text": "Deep"
      },
      {
        "id": 6975,
        "logprob": -0.023864746,
        "special": false,
        "text": " learning"
      },
      {
        "id": 374,
        "logprob": -0.1060791,
        "special": false,
        "text": " is"
      },
      {
        "id": 264,
        "logprob": -0.1940918,
        "special": false,
        "text": " a"
      },
      {
        "id": 27084,
        "logprob": -0.79785156,
        "special": false,
        "text": " subset"
      },
      {
        "id": 315,
        "logprob": -0.008262634,
        "special": false,
        "text": " of"
      },
      {
        "id": 5780,
        "logprob": -0.046569824,
        "special": false,
        "text": " machine"
      },
      {
        "id": 6975,
        "logprob": -0.0023479462,
        "special": false,
        "text": " learning"
      },
      {
        "id": 430,
        "logprob": -0.7626953,
        "special": false,
        "text": " that"
      },
      {
        "id": 5829,
        "logprob": -1.0107422,
        "special": false,
        "text": " uses"
      }
    ],
    "top_tokens": null
  },
  "generated_text": "Deep learning is a subset of machine learning that uses"
}

View File

@@ -0,0 +1,99 @@
{
  "details": {
    "best_of_sequences": null,
    "finish_reason": "length",
    "generated_tokens": 10,
    "prefill": [
      {
        "id": 128000,
        "logprob": null,
        "text": "<|begin_of_text|>"
      },
      {
        "id": 3923,
        "logprob": -7.5390625,
        "text": "What"
      },
      {
        "id": 374,
        "logprob": -0.86035156,
        "text": " is"
      },
      {
        "id": 5655,
        "logprob": -8.828125,
        "text": " deep"
      },
      {
        "id": 6975,
        "logprob": -1.4912109,
        "text": " learning"
      }
    ],
    "seed": 0,
    "tokens": [
      {
        "id": 5380,
        "logprob": 0.0,
        "special": false,
        "text": "?\n"
      },
      {
        "id": 34564,
        "logprob": 0.0,
        "special": false,
        "text": "Deep"
      },
      {
        "id": 6975,
        "logprob": 0.0,
        "special": false,
        "text": " learning"
      },
      {
        "id": 320,
        "logprob": -0.19580078,
        "special": false,
        "text": " ("
      },
      {
        "id": 16931,
        "logprob": -1.7783203,
        "special": false,
        "text": "DL"
      },
      {
        "id": 8,
        "logprob": 0.0,
        "special": false,
        "text": ")"
      },
      {
        "id": 374,
        "logprob": -1.4287109,
        "special": false,
        "text": " is"
      },
      {
        "id": 264,
        "logprob": 0.0,
        "special": false,
        "text": " a"
      },
      {
        "id": 27084,
        "logprob": 0.0,
        "special": false,
        "text": " subset"
      },
      {
        "id": 315,
        "logprob": 0.0,
        "special": false,
        "text": " of"
      }
    ],
    "top_tokens": null
  },
  "generated_text": "What is deep learning?\nDeep learning (DL) is a subset of"
}

View File

@@ -0,0 +1,418 @@
[
  {
    "details": {
      "best_of_sequences": null,
      "finish_reason": "length",
      "generated_tokens": 10,
      "prefill": [
        {
          "id": 128000,
          "logprob": null,
          "text": "<|begin_of_text|>"
        },
        {
          "id": 3923,
          "logprob": -7.5390625,
          "text": "What"
        },
        {
          "id": 374,
          "logprob": -0.86035156,
          "text": " is"
        },
        {
          "id": 5655,
          "logprob": -8.828125,
          "text": " deep"
        },
        {
          "id": 6975,
          "logprob": -1.4912109,
          "text": " learning"
        },
        {
          "id": 30,
          "logprob": -2.1152344,
          "text": "?"
        }
      ],
      "seed": null,
      "tokens": [
        {
          "id": 34564,
          "logprob": -1.765625,
          "special": false,
          "text": "Deep"
        },
        {
          "id": 6975,
          "logprob": -0.024002075,
          "special": false,
          "text": " learning"
        },
        {
          "id": 374,
          "logprob": -0.10760498,
          "special": false,
          "text": " is"
        },
        {
          "id": 264,
          "logprob": -0.19580078,
          "special": false,
          "text": " a"
        },
        {
          "id": 27084,
          "logprob": -0.7993164,
          "special": false,
          "text": " subset"
        },
        {
          "id": 315,
          "logprob": -0.008300781,
          "special": false,
          "text": " of"
        },
        {
          "id": 5780,
          "logprob": -0.046295166,
          "special": false,
          "text": " machine"
        },
        {
          "id": 6975,
          "logprob": -0.002374649,
          "special": false,
          "text": " learning"
        },
        {
          "id": 430,
          "logprob": -0.7651367,
          "special": false,
          "text": " that"
        },
        {
          "id": 5829,
          "logprob": -1.0107422,
          "special": false,
          "text": " uses"
        }
      ],
      "top_tokens": null
    },
    "generated_text": "Deep learning is a subset of machine learning that uses"
  },
  {
    "details": {
      "best_of_sequences": null,
      "finish_reason": "length",
      "generated_tokens": 10,
      "prefill": [
        {
          "id": 128000,
          "logprob": null,
          "text": "<|begin_of_text|>"
        },
        {
          "id": 3923,
          "logprob": -7.5351562,
          "text": "What"
        },
        {
          "id": 374,
          "logprob": -0.85791016,
          "text": " is"
        },
        {
          "id": 5655,
          "logprob": -8.828125,
          "text": " deep"
        },
        {
          "id": 6975,
          "logprob": -1.4882812,
          "text": " learning"
        },
        {
          "id": 30,
          "logprob": -2.1210938,
          "text": "?"
        }
      ],
      "seed": null,
      "tokens": [
        {
          "id": 34564,
          "logprob": -1.7597656,
          "special": false,
          "text": "Deep"
        },
        {
          "id": 6975,
          "logprob": -0.024032593,
          "special": false,
          "text": " learning"
        },
        {
          "id": 374,
          "logprob": -0.10748291,
          "special": false,
          "text": " is"
        },
        {
          "id": 264,
          "logprob": -0.19592285,
          "special": false,
          "text": " a"
        },
        {
          "id": 27084,
          "logprob": -0.7988281,
          "special": false,
          "text": " subset"
        },
        {
          "id": 315,
          "logprob": -0.008354187,
          "special": false,
          "text": " of"
        },
        {
          "id": 5780,
          "logprob": -0.046569824,
          "special": false,
          "text": " machine"
        },
        {
          "id": 6975,
          "logprob": -0.0023517609,
          "special": false,
          "text": " learning"
        },
        {
          "id": 430,
          "logprob": -0.7661133,
          "special": false,
          "text": " that"
        },
        {
          "id": 5829,
          "logprob": -1.0107422,
          "special": false,
          "text": " uses"
        }
      ],
      "top_tokens": null
    },
    "generated_text": "Deep learning is a subset of machine learning that uses"
  },
  {
    "details": {
      "best_of_sequences": null,
      "finish_reason": "length",
      "generated_tokens": 10,
      "prefill": [
        {
          "id": 128000,
          "logprob": null,
          "text": "<|begin_of_text|>"
        },
        {
          "id": 3923,
          "logprob": -7.5351562,
          "text": "What"
        },
        {
          "id": 374,
          "logprob": -0.85791016,
          "text": " is"
        },
        {
          "id": 5655,
          "logprob": -8.828125,
          "text": " deep"
        },
        {
          "id": 6975,
          "logprob": -1.4882812,
          "text": " learning"
        },
        {
          "id": 30,
          "logprob": -2.1210938,
          "text": "?"
        }
      ],
      "seed": null,
      "tokens": [
        {
          "id": 34564,
          "logprob": -1.7597656,
          "special": false,
          "text": "Deep"
        },
        {
          "id": 6975,
          "logprob": -0.024032593,
          "special": false,
          "text": " learning"
        },
        {
          "id": 374,
          "logprob": -0.10748291,
          "special": false,
          "text": " is"
        },
        {
          "id": 264,
          "logprob": -0.19592285,
          "special": false,
          "text": " a"
        },
        {
          "id": 27084,
          "logprob": -0.7988281,
          "special": false,
          "text": " subset"
        },
        {
          "id": 315,
          "logprob": -0.008354187,
          "special": false,
          "text": " of"
        },
        {
          "id": 5780,
          "logprob": -0.046569824,
          "special": false,
          "text": " machine"
        },
        {
          "id": 6975,
          "logprob": -0.0023517609,
          "special": false,
          "text": " learning"
        },
        {
          "id": 430,
          "logprob": -0.7661133,
          "special": false,
          "text": " that"
        },
        {
          "id": 5829,
          "logprob": -1.0107422,
          "special": false,
          "text": " uses"
        }
      ],
      "top_tokens": null
    },
    "generated_text": "Deep learning is a subset of machine learning that uses"
  },
  {
    "details": {
      "best_of_sequences": null,
      "finish_reason": "length",
      "generated_tokens": 10,
      "prefill": [
        {
          "id": 128000,
          "logprob": null,
          "text": "<|begin_of_text|>"
        },
        {
          "id": 3923,
          "logprob": -7.5351562,
          "text": "What"
        },
        {
          "id": 374,
          "logprob": -0.85791016,
          "text": " is"
        },
        {
          "id": 5655,
          "logprob": -8.828125,
          "text": " deep"
        },
        {
          "id": 6975,
          "logprob": -1.4882812,
          "text": " learning"
        },
        {
          "id": 30,
          "logprob": -2.1210938,
          "text": "?"
        }
      ],
      "seed": null,
      "tokens": [
        {
          "id": 34564,
          "logprob": -1.7597656,
          "special": false,
          "text": "Deep"
        },
        {
          "id": 6975,
          "logprob": -0.024032593,
          "special": false,
          "text": " learning"
        },
        {
          "id": 374,
          "logprob": -0.10748291,
          "special": false,
          "text": " is"
        },
        {
          "id": 264,
          "logprob": -0.19592285,
          "special": false,
          "text": " a"
        },
        {
          "id": 27084,
          "logprob": -0.7988281,
          "special": false,
          "text": " subset"
        },
        {
          "id": 315,
          "logprob": -0.008354187,
          "special": false,
          "text": " of"
        },
        {
          "id": 5780,
          "logprob": -0.046569824,
          "special": false,
          "text": " machine"
        },
        {
          "id": 6975,
          "logprob": -0.0023517609,
          "special": false,
          "text": " learning"
        },
        {
          "id": 430,
          "logprob": -0.7661133,
          "special": false,
          "text": " that"
        },
        {
          "id": 5829,
          "logprob": -1.0107422,
          "special": false,
          "text": " uses"
        }
      ],
      "top_tokens": null
    },
    "generated_text": "Deep learning is a subset of machine learning that uses"
  }
]

View File

@@ -0,0 +1,90 @@
import pytest


@pytest.fixture(scope="module")
def compressed_tensors_wna16_int_24_handle(launcher):
    with launcher(
        "danieldk/Llama-3.1-8B-w4a16-int-24",
        num_shard=2,
        quantize="compressed-tensors",
    ) as handle:
        yield handle


@pytest.fixture(scope="module")
async def compressed_tensors_wna16_int_24(compressed_tensors_wna16_int_24_handle):
    await compressed_tensors_wna16_int_24_handle.health(300)
    return compressed_tensors_wna16_int_24_handle.client


@pytest.mark.release
@pytest.mark.asyncio
@pytest.mark.private
async def test_compressed_tensors_wna16_int_24(
    compressed_tensors_wna16_int_24, response_snapshot
):
    response = await compressed_tensors_wna16_int_24.generate(
        "What is deep learning?",
        max_new_tokens=10,
        decoder_input_details=True,
    )

    assert (
        response.generated_text
        == "Deep learning is a subset of machine learning that uses"
    )
    assert response.details.generated_tokens == 10
    assert response == response_snapshot


@pytest.mark.release
@pytest.mark.asyncio
@pytest.mark.private
async def test_compressed_tensors_wna16_int_24_all_params(
    compressed_tensors_wna16_int_24, response_snapshot
):
    response = await compressed_tensors_wna16_int_24.generate(
        "What is deep learning",
        max_new_tokens=10,
        repetition_penalty=1.2,
        return_full_text=True,
        stop_sequences=["test"],
        temperature=0.5,
        top_p=0.9,
        top_k=10,
        truncate=5,
        typical_p=0.9,
        watermark=True,
        decoder_input_details=True,
        seed=0,
    )

    assert response.details.generated_tokens == 10
    assert (
        response.generated_text
        == "What is deep learning?\nDeep learning (DL) is a subset of"
    )
    assert response == response_snapshot


@pytest.mark.release
@pytest.mark.asyncio
@pytest.mark.private
async def test_compressed_tensors_wna16_int_24_load(
    compressed_tensors_wna16_int_24, generate_load, response_snapshot
):
    responses = await generate_load(
        compressed_tensors_wna16_int_24,
        "What is deep learning?",
        max_new_tokens=10,
        n=4,
    )

    assert (
        responses[0].generated_text
        == "Deep learning is a subset of machine learning that uses"
    )
    assert len(responses) == 4
    assert all([r.generated_text == responses[0].generated_text for r in responses])
    assert responses == response_snapshot
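
The tests above drive the server through the `text_generation` AsyncClient exposed by the launcher fixture. A hedged sketch of the equivalent direct call against an already-running server (the URL and port are hypothetical):

import asyncio

from text_generation import AsyncClient


async def main():
    # Assumes a server launched with quantize="compressed-tensors" is
    # listening locally; the address is illustrative only.
    client = AsyncClient("http://localhost:8080")
    response = await client.generate(
        "What is deep learning?",
        max_new_tokens=10,
        decoder_input_details=True,
    )
    print(response.generated_text)


asyncio.run(main())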

View File

@@ -13,7 +13,10 @@ from torch import nn
 from text_generation_server.layers.compressed_tensors.w8an_fp import W8ANFpLoader
 from text_generation_server.layers.compressed_tensors.w8a8_int import W8A8IntLoader
-from text_generation_server.layers.compressed_tensors.wna16_int import WNA16Loader
+from text_generation_server.layers.compressed_tensors.wna16_int_24 import (
+    WNA16Int24Loader,
+)
+from text_generation_server.layers.compressed_tensors.wna16_int import WNA16IntLoader
 from text_generation_server.utils.log import log_once
 from text_generation_server.utils.weights import (
     DefaultWeightsLoader,
@@ -151,7 +154,14 @@ class CompressedTensorsLoader(WeightsLoader):
             and weights.num_bits in (4, 8)
         ):
             # INT W4A16 or W8A16 (GPTQ/AWQ-like).
-            return WNA16Loader(weights)
+            return WNA16IntLoader(weights)
+        elif (
+            format == CompressionFormat.marlin_24.value
+            and weights is not None
+            and weights.type == QuantizationType.INT
+            and weights.num_bits in (4, 8)
+        ):
+            return WNA16Int24Loader(weights)
         elif (
             format
             in {
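
For reference, a hedged sketch of the kind of compressed-tensors quantization config that would take the new marlin_24 branch above. The field names mirror the CompressionFormat and QuantizationType values checked in the dispatch; an actual checkpoint's config.json may differ:

# Hypothetical config fragment, shown as a Python dict for illustration.
quantization_config = {
    "format": "marlin-24",  # CompressionFormat.marlin_24.value
    "config_groups": {
        "group_0": {
            "weights": {
                "type": "int",  # QuantizationType.INT
                "num_bits": 4,  # 4 and 8 are accepted
                "strategy": "group",
                "group_size": 128,
            },
        },
    },
}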

View File

@@ -9,7 +9,7 @@ from text_generation_server.utils.log import log_once
 from text_generation_server.utils.weights import Weights, WeightsLoader
 
 
-class WNA16Loader(WeightsLoader):
+class WNA16IntLoader(WeightsLoader):
     """
     Loader for W4A16/W8A16 INT compressed-tensors parameters.
     """
@@ -22,7 +22,7 @@ class WNA16Loader(WeightsLoader):
         )
 
     def __str__(self) -> str:
-        quantization_type = f"W{self.weights.num_bits}8A16"
+        quantization_type = f"W{self.weights.num_bits}A16"
         return f"{self.__class__.__name__} ({quantization_type})"

View File

@@ -0,0 +1,101 @@
from typing import List, Union

import torch
from compressed_tensors.quantization import QuantizationArgs, QuantizationType

from text_generation_server.layers.marlin.marlin import GPTQMarlin24Weight
from text_generation_server.utils.weights import Weights, WeightsLoader


class WNA16Int24Loader(WeightsLoader):
    """
    Loader for W4A16/W8A16 INT 2:4 sparsity compressed-tensors checkpoints.
    """

    def __init__(self, weight_args: QuantizationArgs):
        super().__init__()

        if weight_args.type != QuantizationType.INT:
            raise ValueError(
                f"{type(self).__name__} only supports wNa16 int checkpoints"
            )

        if weight_args.strategy == "group" and weight_args.group_size is None:
            raise ValueError("`group_size` must be set when `strategy` is `group`")

        self.bits = weight_args.num_bits
        self.group_size = weight_args.group_size

    def __str__(self) -> str:
        quantization_type = f"W{self.bits}A16 2:4 sparsity"
        return f"{self.__class__.__name__} ({quantization_type})"

    def get_weights(self, weights: Weights, prefix: str):
        """
        Get weights at the given prefix and apply them without tensor parallelism.
        """
        weight_packed = weights.get_tensor(f"{prefix}.weight_packed")
        meta = weights.get_tensor(f"{prefix}.meta")
        scale_packed = weights.get_tensor(f"{prefix}.scale_packed")
        return GPTQMarlin24Weight(
            weight_packed=weight_packed,
            meta=meta,
            scale_packed=scale_packed,
            bits=self.bits,
        )

    def get_weights_col_packed(
        self,
        weights: Weights,
        prefix: str,
        block_sizes: Union[int, List[int]],
    ):
        weight_packed = weights.get_packed_sharded(
            f"{prefix}.weight_packed", dim=1, block_sizes=block_sizes
        )
        meta = weights.get_packed_sharded(
            f"{prefix}.meta", dim=1, block_sizes=block_sizes
        )
        scale_packed = weights.get_packed_sharded(
            f"{prefix}.scale_packed", dim=1, block_sizes=block_sizes
        )
        return GPTQMarlin24Weight(
            weight_packed=weight_packed,
            meta=meta,
            scale_packed=scale_packed,
            bits=self.bits,
        )

    def get_multi_weights_col(self, weights: Weights, prefixes: List[str], dim: int):
        weight_packed = torch.cat(
            [weights.get_sharded(f"{p}.weight_packed", dim=1) for p in prefixes], dim=1
        )
        meta = torch.cat(
            [weights.get_sharded(f"{p}.meta", dim=1) for p in prefixes], dim=1
        )
        scale_packed = torch.cat(
            [weights.get_sharded(f"{p}.scale_packed", dim=1) for p in prefixes], dim=1
        )
        return GPTQMarlin24Weight(
            weight_packed=weight_packed,
            meta=meta,
            scale_packed=scale_packed,
            bits=self.bits,
        )

    def get_weights_row(self, weights: Weights, prefix: str):
        weight_packed = weights.get_sharded(f"{prefix}.weight_packed", dim=0)
        meta = weights.get_sharded(f"{prefix}.meta", dim=0)
        if self.group_size is None:
            # Channel-wise scales cannot be split along the input
            # dimension, so every shard loads them whole.
            scale_packed = weights.get_tensor(f"{prefix}.scale_packed")
        else:
            scale_packed = weights.get_sharded(f"{prefix}.scale_packed", dim=0)
        return GPTQMarlin24Weight(
            weight_packed=weight_packed,
            meta=meta,
            scale_packed=scale_packed,
            bits=self.bits,
        )
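
A note on get_weights_row above: with channel-wise quantization (group_size is None) there is a single row of scales spanning all output channels, so every shard loads the scales whole, while group quantization shards the scales along the same input dimension as the weights. A small sketch of that shape arithmetic (hypothetical sizes, assuming 2-way tensor parallelism):

# Hypothetical 4096 -> 11008 linear layer, split row-wise over two shards.
num_shards = 2
in_features, out_features, group_size = 4096, 11008, 128

num_groups = in_features // group_size       # 32 rows of scales
groups_per_shard = num_groups // num_shards  # each shard loads 16 rows

# Channel-wise quantization instead stores scales of shape (1, out_features),
# which has no input dimension to shard and is therefore replicated.
channelwise_scale_shape = (1, out_features)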

View File

@@ -34,7 +34,9 @@ class MarlinWeightsLoader(WeightsLoader):
             B_meta = weights.get_tensor(f"{prefix}.B_meta")
             s = weights.get_tensor(f"{prefix}.s")
 
-            weight = GPTQMarlin24Weight(B=B, B_meta=B_meta, s=s, bits=self.bits)
+            weight = GPTQMarlin24Weight(
+                weight_packed=B, meta=B_meta, scale_packed=s, bits=self.bits
+            )
         else:
             try:
                 B = weights.get_tensor(f"{prefix}.B")
@@ -65,7 +67,9 @@ class MarlinWeightsLoader(WeightsLoader):
                 f"{prefix}.s", dim=1, block_sizes=block_sizes
             )
 
-            weight = GPTQMarlin24Weight(B=B, B_meta=B_meta, s=s, bits=self.bits)
+            weight = GPTQMarlin24Weight(
+                weight_packed=B, meta=B_meta, scale_packed=s, bits=self.bits
+            )
         else:
             B = weights.get_packed_sharded(
                 f"{prefix}.B", dim=1, block_sizes=block_sizes
@@ -96,7 +100,9 @@ class MarlinWeightsLoader(WeightsLoader):
                 [weights.get_sharded(f"{p}.s", dim=1) for p in prefixes], dim=1
             )
 
-            weight = GPTQMarlin24Weight(B=B, B_meta=B_meta, s=s, bits=self.bits)
+            weight = GPTQMarlin24Weight(
+                weight_packed=B, meta=B_meta, scale_packed=s, bits=self.bits
+            )
         else:
             try:
                 B = torch.cat(
@@ -132,7 +138,9 @@ class MarlinWeightsLoader(WeightsLoader):
             else:
                 s = weights.get_sharded(f"{prefix}.s", dim=0)
 
-            weight = GPTQMarlin24Weight(B=B, B_meta=B_meta, s=s, bits=self.bits)
+            weight = GPTQMarlin24Weight(
+                weight_packed=B, meta=B_meta, scale_packed=s, bits=self.bits
+            )
         else:
             try:
                 B = weights.get_sharded(f"{prefix}.B", dim=0)
@@ -247,15 +255,15 @@ class GPTQMarlin24Weight:
     bits: quantized weight size.
     """
 
-    B: torch.Tensor
-    B_meta: torch.Tensor
-    s: torch.Tensor
+    weight_packed: torch.Tensor
+    meta: torch.Tensor
+    scale_packed: torch.Tensor
     bits: int
 
     def __post_init__(self):
-        assert self.B.dtype == torch.int32
-        assert self.B_meta.dtype == torch.int16
-        assert self.s.dtype == torch.float16
+        assert self.weight_packed.dtype == torch.int32
+        assert self.meta.dtype == torch.int16
+        assert self.scale_packed.dtype == torch.float16
 
     def get_linear(self, bias: torch.Tensor):
         return GPTQMarlin24Linear(
@@ -279,9 +287,13 @@ class GPTQMarlin24Linear(nn.Module):
                 f"{weight.bits}-bit GPTQ Sparse 2:4 Marlin is not supported, must be one of: {supported_bits}"
             )
 
-        in_features = weight.B.shape[0] * MARLIN_TILE_SIZE * 2
-        out_features = weight.s.shape[1]
-        groupsize = -1 if weight.s.shape[0] == 1 else in_features // weight.s.shape[0]
+        in_features = weight.weight_packed.shape[0] * MARLIN_TILE_SIZE * 2
+        out_features = weight.scale_packed.shape[1]
+        groupsize = (
+            -1
+            if weight.scale_packed.shape[0] == 1
+            else in_features // weight.scale_packed.shape[0]
+        )
 
         if groupsize not in GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES:
             supported_sizes = ", ".join(
@@ -309,9 +321,9 @@ class GPTQMarlin24Linear(nn.Module):
                 f"Number of input features ({in_features}) not divisable by group size ({groupsize})"
             )
 
-        self.B = weight.B
-        self.B_meta = weight.B_meta
-        self.s = weight.s
+        self.weight_packed = weight.weight_packed
+        self.meta = weight.meta
+        self.scale_packed = weight.scale_packed
         if bias is not None:
             self.bias = bias
         else:
@@ -320,7 +332,7 @@ class GPTQMarlin24Linear(nn.Module):
         self.workspace = torch.zeros(
             (out_features // GPTQ_MARLIN_24_MIN_THREAD_N) * GPTQ_MARLIN_24_MAX_PARALLEL,
             dtype=torch.int,
-            device=weight.B.device,
+            device=weight.weight_packed.device,
         )
 
     def forward(self, A: torch.Tensor) -> torch.Tensor:
@@ -328,17 +340,17 @@ class GPTQMarlin24Linear(nn.Module):
         C = marlin_kernels.gptq_marlin_24_gemm(
             A.view(-1, A.shape[-1]),
-            self.B,
-            self.B_meta,
-            self.s,
+            self.weight_packed,
+            self.meta,
+            self.scale_packed,
             self.workspace,
             self.bits,
             A.shape[0],
-            self.s.shape[1],
+            self.scale_packed.shape[1],
             A.shape[1],
         )
 
-        C = C.reshape(A.shape[:-1] + (self.s.shape[1],))
+        C = C.reshape(A.shape[:-1] + (self.scale_packed.shape[1],))
 
         if self.bias is not None:
             C += self.bias
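
To sanity-check the renamed fields, a worked example of the shape arithmetic in GPTQMarlin24Linear (illustrative numbers; MARLIN_TILE_SIZE is 16 in the Marlin kernels, and the extra factor of two reflects that 2:4 sparsity stores only half of the weights):

MARLIN_TILE_SIZE = 16

weight_packed_rows = 128  # weight_packed.shape[0]
in_features = weight_packed_rows * MARLIN_TILE_SIZE * 2  # 4096

scale_rows, scale_cols = 32, 11008  # scale_packed.shape
out_features = scale_cols  # 11008
# A single row of scales means per-channel quantization (groupsize -1);
# otherwise each row covers in_features // scale_rows input features.
groupsize = -1 if scale_rows == 1 else in_features // scale_rows  # 128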