Move quantized weight handling out of the `Weights` class (#2194)

Quantized weights were loaded in the `Weights` class, but this was
getting quite unwieldy: every higher-level method for loading weights
had turned into a long conditional covering all the different quantizers.

This change moves loading of quantized weights out of the `Weights`
class. This is done by defining a simple `WeightsLoader` interface
that is implemented by `Exl2WeightsLoader`, `GPTQWeightsLoader`,
and `MarlinWeightsLoader`. These implementations are in the quantizers'
respective modules. The `Weights` class provides the low-level load
operations (such as loading tensors or sharded tensors), but delegates
loads that need quantizer-specific weight processing to a loader. The
loaders still use the low-level functionality provided by `Weights`.
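
In rough outline (a simplified sketch: the `WeightsLoader` methods mirror the
new interface in `utils/weights.py`, while the `Weights` delegation shown here
is only illustrative of the pattern, not the full class):

from abc import ABC, abstractmethod
from typing import List, Union


class WeightsLoader(ABC):
    # Quantizer-specific interpretation of the stored tensors.
    @abstractmethod
    def get_weights_col_packed(
        self, weights: "Weights", prefix: str, block_sizes: Union[int, List[int]]
    ): ...

    @abstractmethod
    def get_multi_weights_col(self, weights: "Weights", prefixes: List[str], dim: int): ...

    @abstractmethod
    def get_weights_row(self, weights: "Weights", prefix: str): ...


class Weights:
    # Low-level loads (get_tensor, get_sharded, ...) live here; anything that
    # needs quantizer-specific handling is delegated to the configured loader.
    def __init__(self, weights_loader: WeightsLoader):
        self.weights_loader = weights_loader

    def get_weights_row(self, prefix: str):
        return self.weights_loader.get_weights_row(self, prefix)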

I initially tried making a hierarchy in which a class like `GPTQWeights`
would inherit from `Weights`. However, that approach was not very flexible
(e.g. it does not work well with the new weight storage mock used in tests),
and the implicit indirection made the code harder to follow.
Daniël de Kok 2024-07-09 20:04:03 +02:00 committed by GitHub
parent 4c976fb406
commit 8511669cb2
24 changed files with 896 additions and 731 deletions

View File

@@ -2,6 +2,7 @@ import torch
from text_generation_server.layers import (
TensorParallelEmbedding,
)
from text_generation_server.utils.weights import DefaultWeightsLoader
class ProcessGroup:
@@ -42,7 +43,12 @@ class Weights:
def test_weight_hub_files_offline_error():
vocab_size = 17
-weights = Weights(rank=0, world_size=1, vocab_size=vocab_size, hidden_dim=256)
weights = Weights(
rank=0,
world_size=1,
vocab_size=vocab_size,
hidden_dim=256,
)
embeddings = TensorParallelEmbedding("", weights)
input_ids = torch.arange(vocab_size)

View File

@@ -1,13 +1,47 @@
import pytest
import torch
-from text_generation_server.utils.weights import Weights
-from text_generation_server.layers.gptq import GPTQWeight
-from text_generation_server.layers.exl2 import Exl2Weight
-from text_generation_server.layers.marlin import MarlinWeight
from text_generation_server.utils.weights import (
DefaultWeightsLoader,
Weights,
WeightsLoader,
)
from text_generation_server.layers.gptq import GPTQWeight, GPTQWeightsLoader
from text_generation_server.layers.exl2 import Exl2Weight, Exl2WeightsLoader
from text_generation_server.layers.marlin import MarlinWeight, MarlinWeightsLoader
from types import SimpleNamespace
from typing import List, Optional, Dict, Union
from pathlib import Path
@pytest.fixture
def gptq_weights_loader():
return GPTQWeightsLoader(
bits=4,
groupsize=-1,
desc_act=False,
quant_method="gptq",
quantize="gptq",
sym=True,
)
@pytest.fixture
def gptq_weights_loader_awq():
return GPTQWeightsLoader(
bits=4,
groupsize=-1,
desc_act=False,
quant_method="awq",
quantize="awq",
sym=True,
)
@pytest.fixture
def marlin_weights_loader():
return MarlinWeightsLoader(bits=4, is_marlin_24=False)
dummy_file_system = {
"test_weights": {
"layer.0.weight": torch.tensor(
@@ -58,7 +92,7 @@ dummy_file_system = {
dtype=torch.float32,
),
},
-"test_get_multi_weights_row": {
"test_get_weights_row": {
"weight.weight": torch.tensor(
[
[1, 2],
@@ -101,7 +135,7 @@ dummy_file_system = {
"weight.B": torch.tensor([[1, 2], [3, 4]], dtype=torch.int32),
"weight.s": torch.tensor([[0.5000], [0.2500]], dtype=torch.float16),
},
-"test_get_multi_weights_row_gptq": {
"test_get_weights_row_gptq": {
"weight.qweight": torch.tensor(
[
[1, 2],
@@ -200,7 +234,7 @@ dummy_file_system = {
"weight.q_scale_max": torch.tensor([100], dtype=torch.float16),
"weight.q_groups": torch.tensor([4], dtype=torch.int16),
},
-"test_get_multi_weights_row_exl2": {
"test_get_weights_row_exl2": {
"weight.q_weight": torch.tensor(
[
[1, 2],
@@ -245,7 +279,7 @@ dummy_file_system = {
"weight.q_scale_max": torch.tensor([100], dtype=torch.float16),
"weight.q_groups": torch.tensor([4], dtype=torch.int16),
},
-"test_get_multi_weights_row_marlin": {
"test_get_weights_row_marlin": {
"weight.B": torch.tensor([[1, 2], [3, 4]], dtype=torch.int32),
"weight.s": torch.tensor([[0.5], [0.25]], dtype=torch.float16),
},
@@ -308,6 +342,7 @@ class MockWeights(Weights):
dummy_fs,
aliases: Optional[Dict[str, List[str]]] = None,
prefix: Optional[str] = None,
weights_loader: Optional[WeightsLoader] = None,
):
routing = {}
self.dummy_fs = dummy_fs
@@ -327,6 +362,9 @@ class MockWeights(Weights):
self.dtype = dtype
self.process_group = process_group
self.prefix = prefix
self.weights_loader = (
DefaultWeightsLoader() if weights_loader is None else weights_loader
)
self._handles = {}
def _get_handle(self, filename: Union[Path, str]):
@@ -412,12 +450,10 @@ def test_get_weights_col_packed():
)
prefix = "weight"
-quantize = None
block_sizes = 1
w = weights.get_weights_col_packed(
prefix=prefix,
-quantize=quantize,
block_sizes=block_sizes,
)
@@ -448,12 +484,10 @@ def test_get_weights_col_packed_block_size():
)
prefix = "weight"
-quantize = None
block_sizes = 2
w = weights.get_weights_col_packed(
prefix=prefix,
-quantize=quantize,
block_sizes=block_sizes,
)
@@ -484,12 +518,10 @@ def test_get_weights_col_packed_block_size_arr():
)
prefix = "weight"
-quantize = None
block_sizes = [1, 1]
w = weights.get_weights_col_packed(
prefix=prefix,
-quantize=quantize,
block_sizes=block_sizes,
)
@@ -519,11 +551,9 @@ def test_get_multi_weights_col():
)
prefixes = ["weight", "weight"]
-quantize = None
w = weights.get_multi_weights_col(
prefixes=prefixes,
-quantize=quantize,
dim=0,
)
@@ -545,10 +575,10 @@ def test_get_multi_weights_col():
)
-def test_get_multi_weights_row():
def test_get_weights_row():
weights = MockWeights(
[
-"test_get_multi_weights_row",
"test_get_weights_row",
],
device="cpu",
dtype=torch.float32,
@@ -557,11 +587,9 @@ def test_get_multi_weights_row():
)
prefix = "weight"
-quantize = None
-w = weights.get_multi_weights_row(
w = weights.get_weights_row(
prefix=prefix,
-quantize=quantize,
)
assert torch.allclose(
@@ -576,7 +604,7 @@ def test_get_multi_weights_row():
# test_get_weights_col
-def test_get_weights_col_awq():
def test_get_weights_col_awq(gptq_weights_loader_awq):
weights = MockWeights(
[
"test_get_weights_col_gptq",
@@ -585,14 +613,13 @@ def test_get_weights_col_awq():
dtype=torch.float32,
process_group=dummy_process_group,
dummy_fs=dummy_file_system,
weights_loader=gptq_weights_loader_awq,
)
prefix = "weight"
-quantize = "awq"
w = weights.get_weights_col(
prefix=prefix,
-quantize=quantize,
)
expected_weight = GPTQWeight(
@@ -617,7 +644,7 @@ def test_get_weights_col_awq():
assert w.use_exllama == expected_weight.use_exllama, "use_exllama mismatch"
-def test_get_weights_col_gtpq():
def test_get_weights_col_gtpq(gptq_weights_loader):
weights = MockWeights(
[
"test_get_weights_col_gptq",
@@ -626,14 +653,13 @@ def test_get_weights_col_gtpq():
dtype=torch.float32,
process_group=dummy_process_group,
dummy_fs=dummy_file_system,
weights_loader=gptq_weights_loader,
)
prefix = "weight"
-quantize = "gptq"
w = weights.get_weights_col(
prefix=prefix,
-quantize=quantize,
)
expected_weight = GPTQWeight(
@@ -664,14 +690,13 @@ def test_get_weights_col_exl2():
dtype=torch.float32,
process_group=dummy_process_group,
dummy_fs=dummy_file_system,
weights_loader=Exl2WeightsLoader(),
)
prefix = "weight"
-quantize = "exl2"
w = weights.get_weights_col(
prefix=prefix,
-quantize=quantize,
)
scaled_scale_max = 0.3906 * 256
@@ -692,7 +717,7 @@ def test_get_weights_col_exl2():
assert torch.allclose(w.q_groups, expected_weight.q_groups), "q_groups mismatch"
-def test_get_weights_col_marlin():
def test_get_weights_col_marlin(marlin_weights_loader):
weights = MockWeights(
[
"test_get_weights_col_marlin",
@@ -701,14 +726,13 @@ def test_get_weights_col_marlin():
dtype=torch.float16,
process_group=dummy_process_group,
dummy_fs=dummy_file_system,
weights_loader=marlin_weights_loader,
)
prefix = "weight"
-quantize = "marlin"
w = weights.get_weights_col(
prefix=prefix,
-quantize=quantize,
)
expected_weight = MarlinWeight(
@@ -723,7 +747,7 @@ def test_get_weights_col_marlin():
# test_get_weights_col_packed
-def test_get_weights_col_packed_awq():
def test_get_weights_col_packed_awq(gptq_weights_loader_awq):
weights = MockWeights(
[
"test_get_weights_col_packed_gptq",
@@ -732,15 +756,14 @@ def test_get_weights_col_packed_awq():
dtype=torch.float32,
process_group=dummy_process_group,
dummy_fs=dummy_file_system,
weights_loader=gptq_weights_loader_awq,
)
prefix = "weight"
-quantize = "awq"
block_sizes = 1
w = weights.get_weights_col_packed(
prefix=prefix,
-quantize=quantize,
block_sizes=block_sizes,
)
@@ -773,15 +796,14 @@ def test_get_weights_col_packed_exl2():
dtype=torch.float32,
process_group=dummy_process_group,
dummy_fs=dummy_file_system,
weights_loader=Exl2WeightsLoader(),
)
prefix = "weight"
-quantize = "exl2"
block_sizes = 1
w = weights.get_weights_col_packed(
prefix=prefix,
-quantize=quantize,
block_sizes=block_sizes,
)
@@ -803,7 +825,7 @@ def test_get_weights_col_packed_exl2():
assert torch.allclose(w.q_groups, expected_weight.q_groups), "q_groups mismatch"
-def test_get_weights_col_packed_gptq():
def test_get_weights_col_packed_gptq(gptq_weights_loader):
weights = MockWeights(
[
"test_get_weights_col_packed_gptq",
@@ -812,14 +834,13 @@ def test_get_weights_col_packed_gptq():
dtype=torch.float32,
process_group=dummy_process_group,
dummy_fs=dummy_file_system,
weights_loader=gptq_weights_loader,
)
prefixes = ["weight"]
-quantize = "gptq"
w = weights.get_multi_weights_col(
prefixes=prefixes,
-quantize=quantize,
dim=0,
)
@@ -842,7 +863,7 @@ def test_get_weights_col_packed_gptq():
assert w.use_exllama == expected_weight.use_exllama, "use_exllama mismatch"
-def test_get_weights_col_packed_marlin():
def test_get_weights_col_packed_marlin(marlin_weights_loader):
weights = MockWeights(
[
"test_get_weights_col_packed_marlin",
@@ -851,14 +872,13 @@ def test_get_weights_col_packed_marlin():
dtype=torch.float16,
process_group=dummy_process_group,
dummy_fs=dummy_file_system,
weights_loader=marlin_weights_loader,
)
prefix = "weight"
-quantize = "marlin"
w = weights.get_multi_weights_col(
prefixes=[prefix],
-quantize=quantize,
dim=0,
)
@@ -876,7 +896,7 @@ def test_get_weights_col_packed_marlin():
# test_get_multi_weights_col
-def test_get_multi_weights_col_awq():
def test_get_multi_weights_col_awq(gptq_weights_loader_awq):
weights = MockWeights(
[
"test_get_multi_weights_col_gptq",
@@ -885,14 +905,13 @@ def test_get_multi_weights_col_awq():
dtype=torch.float32,
process_group=dummy_process_group,
dummy_fs=dummy_file_system,
weights_loader=gptq_weights_loader_awq,
)
prefixes = ["weight"]
-quantize = "awq"
w = weights.get_multi_weights_col(
prefixes=prefixes,
-quantize=quantize,
dim=0,
)
@@ -924,22 +943,21 @@ def test_get_multi_weights_col_exl2():
dtype=torch.float32,
process_group=dummy_process_group,
dummy_fs=dummy_file_system,
weights_loader=Exl2WeightsLoader(),
)
prefix = "weight"
-quantize = "exl2"
try:
w = weights.get_multi_weights_col(
prefixes=[prefix],
-quantize=quantize,
dim=0,
)
except ValueError as e:
assert e.args[0] == "get_multi_weights_col is not supported for exl2"
-def test_get_multi_weights_col_gptq():
def test_get_multi_weights_col_gptq(gptq_weights_loader):
weights = MockWeights(
[
"test_get_multi_weights_col_gptq",
@@ -948,14 +966,13 @@ def test_get_multi_weights_col_gptq():
dtype=torch.float32,
process_group=dummy_process_group,
dummy_fs=dummy_file_system,
weights_loader=gptq_weights_loader,
)
prefixes = ["weight"]
-quantize = "gptq"
w = weights.get_multi_weights_col(
prefixes=prefixes,
-quantize=quantize,
dim=0,
)
@@ -978,7 +995,7 @@ def test_get_multi_weights_col_gptq():
assert w.use_exllama == expected_weight.use_exllama, "use_exllama mismatch"
-def test_get_multi_weights_col_marlin():
def test_get_multi_weights_col_marlin(marlin_weights_loader):
weights = MockWeights(
[
"test_get_multi_weights_col_marlin",
@@ -987,14 +1004,13 @@ def test_get_multi_weights_col_marlin():
dtype=torch.float16,
process_group=dummy_process_group,
dummy_fs=dummy_file_system,
weights_loader=marlin_weights_loader,
)
prefix = "weight"
-quantize = "marlin"
w = weights.get_multi_weights_col(
prefixes=[prefix],
-quantize=quantize,
dim=0,
)
@@ -1007,26 +1023,25 @@ def test_get_multi_weights_col_marlin():
assert torch.allclose(w.s, expected_weight.s), "s mismatch"
-# test_get_multi_weights_row
# test_get_weights_row
-def test_get_multi_weights_row_awq():
def test_get_weights_row_awq(gptq_weights_loader_awq):
weights = MockWeights(
[
-"test_get_multi_weights_row_gptq",
"test_get_weights_row_gptq",
],
device="cpu",
dtype=torch.float32,
process_group=dummy_process_group,
dummy_fs=dummy_file_system,
weights_loader=gptq_weights_loader_awq,
)
prefix = "weight"
-quantize = "awq"
-w = weights.get_multi_weights_row(
w = weights.get_weights_row(
prefix=prefix,
-quantize=quantize,
)
expected_weight = GPTQWeight(
@@ -1048,23 +1063,22 @@ def test_get_multi_weights_row_awq():
assert w.use_exllama == expected_weight.use_exllama, "use_exllama mismatch"
-def test_get_multi_weights_row_exl2():
def test_get_weights_row_exl2():
weights = MockWeights(
[
-"test_get_multi_weights_row_exl2",
"test_get_weights_row_exl2",
],
device="cpu",
dtype=torch.float32,
process_group=dummy_process_group,
dummy_fs=dummy_file_system,
weights_loader=Exl2WeightsLoader(),
)
prefix = "weight"
-quantize = "exl2"
-w = weights.get_multi_weights_row(
w = weights.get_weights_row(
prefix=prefix,
-quantize=quantize,
)
print(w)
@@ -1086,23 +1100,22 @@ def test_get_multi_weights_row_exl2():
assert torch.allclose(w.q_groups, expected_weight.q_groups), "q_groups mismatch"
-def test_get_multi_weights_row_gptq():
def test_get_weights_row_gptq(gptq_weights_loader):
weights = MockWeights(
[
-"test_get_multi_weights_row_gptq",
"test_get_weights_row_gptq",
],
device="cpu",
dtype=torch.float32,
process_group=dummy_process_group,
dummy_fs=dummy_file_system,
weights_loader=gptq_weights_loader,
)
prefix = "weight"
-quantize = "gptq"
-w = weights.get_multi_weights_row(
w = weights.get_weights_row(
prefix=prefix,
-quantize=quantize,
)
expected_weight = GPTQWeight(
@@ -1124,23 +1137,22 @@ def test_get_multi_weights_row_gptq():
assert w.use_exllama == expected_weight.use_exllama, "use_exllama mismatch"
-def test_get_multi_weights_row_marlin():
def test_get_weights_row_marlin(marlin_weights_loader):
weights = MockWeights(
[
-"test_get_multi_weights_row_marlin",
"test_get_weights_row_marlin",
],
device="cpu",
dtype=torch.float16,
process_group=dummy_process_group,
dummy_fs=dummy_file_system,
weights_loader=marlin_weights_loader,
)
prefix = "weight"
-quantize = "marlin"
-w = weights.get_multi_weights_row(
w = weights.get_weights_row(
prefix=prefix,
-quantize=quantize,
)
expected_weight = MarlinWeight(

View File

@@ -1,6 +1,9 @@
import torch
from typing import List, Union
from dataclasses import dataclass
from text_generation_server.utils.weights import WeightsLoader, Weights
@dataclass
class Exl2Weight:
@@ -21,3 +24,60 @@ class Exl2Weight:
@property
def device(self) -> torch.device:
return self.q_weight.device
class Exl2WeightsLoader(WeightsLoader):
"""Loader for exl2-quantized weights."""
def get_weights_col_packed(
self,
weights: Weights,
prefix: str,
block_sizes: Union[int, List[int]],
):
raise RuntimeError("Column-packed weights are not supported for exl")
def get_weights_col(self, weights: Weights, prefix: str):
try:
q_weight = weights.get_tensor(f"{prefix}.q_weight")
except RuntimeError:
raise RuntimeError(
"Cannot load `exl2`-quantized weight, make sure the model is already quantized."
)
q_scale = weights.get_tensor(f"{prefix}.q_scale")
q_invperm = weights.get_tensor(f"{prefix}.q_invperm")
q_scale_max = weights.get_tensor(f"{prefix}.q_scale_max")
q_groups = weights.get_tensor(f"{prefix}.q_groups")
return Exl2Weight(
q_weight=q_weight,
q_scale=q_scale,
q_invperm=q_invperm,
q_scale_max=q_scale_max,
q_groups=q_groups,
)
def get_multi_weights_col(self, weights: Weights, prefixes: List[str], dim: int):
raise ValueError("get_multi_weights_col is not supported for exl2")
def get_weights_row(self, weights: Weights, prefix: str):
try:
q_weight = weights.get_tensor(f"{prefix}.q_weight")
except RuntimeError:
raise RuntimeError(
"Cannot load `exl2`-quantized weight, make sure the model is already quantized."
)
q_scale = weights.get_tensor(f"{prefix}.q_scale")
q_invperm = weights.get_tensor(f"{prefix}.q_invperm")
q_scale_max = weights.get_tensor(f"{prefix}.q_scale_max")
q_groups = weights.get_tensor(f"{prefix}.q_groups")
return Exl2Weight(
q_weight=q_weight,
q_scale=q_scale,
q_invperm=q_invperm,
q_scale_max=q_scale_max,
q_groups=q_groups,
)

View File

@@ -1,20 +1,14 @@
from dataclasses import dataclass
from loguru import logger
import os
-from typing import Optional
from typing import List, Optional, Union
from safetensors import SafetensorError
from text_generation_server.utils.weights import Weights, WeightsLoader
import torch
from text_generation_server.utils.import_utils import (
SYSTEM,
)
from text_generation_server.utils.log import log_once
-@dataclass
-class GPTQParams:
-bits: int
-checkpoint_format: Optional[str]
-groupsize: int
-desc_act: bool
-quant_method: str
-sym: bool
@dataclass
@@ -69,3 +63,341 @@ elif CAN_EXLLAMA:
pass
from text_generation_server.layers.gptq.quant_linear import QuantLinear
class GPTQWeightsLoader(WeightsLoader):
"""
Loader for GPTQ- and AWQ-quantized weights.
"""
def __init__(
self,
*,
bits: int,
desc_act: bool,
groupsize: int,
quant_method: str,
quantize: str,
sym: bool,
):
self.bits = bits
self.desc_act = desc_act
self.groupsize = groupsize
self.quant_method = quant_method
self.quantize = quantize
self.sym = sym
def get_weights_col_packed(
self,
weights: Weights,
prefix: str,
block_sizes: Union[int, List[int]],
):
from text_generation_server.layers.marlin import (
can_use_gptq_marlin,
repack_gptq_for_marlin,
)
try:
qweight = weights.get_packed_sharded(
f"{prefix}.qweight", dim=1, block_sizes=block_sizes
)
except RuntimeError:
raise RuntimeError(
f"Cannot load `{self.quantize}` weight, make sure the model is already quantized."
)
scales = weights.get_packed_sharded(
f"{prefix}.scales", dim=1, block_sizes=block_sizes
)
scales = scales.to(dtype=weights.dtype)
self._get_gptq_params(weights)
if can_use_gptq_marlin(
bits=self.bits,
groupsize=self.groupsize,
quant_method=self.quant_method,
quantize=self.quantize,
sym=self.sym,
):
g_idx = weights.get_tensor(f"{prefix}.g_idx")
return repack_gptq_for_marlin(
qweight=qweight,
scales=scales,
g_idx=g_idx,
bits=self.bits,
desc_act=self.desc_act,
groupsize=self.groupsize,
sym=self.sym,
sharded_infeatures=False,
)
qzeros = weights.get_packed_sharded(
f"{prefix}.qzeros", dim=1, block_sizes=block_sizes
)
if self.quantize == "gptq" and self.quant_method == "gptq":
g_idx = weights.get_tensor(f"{prefix}.g_idx")
elif self.quantize == "gptq" and self.quant_method == "awq":
log_once(
logger.info, "Converting AWQ model to Exllama/GPTQ packing format."
)
from text_generation_server.layers.awq.conversion_utils import (
fast_awq_to_gptq,
)
qweight, qzeros = fast_awq_to_gptq(qweight, qzeros)
g_idx = (
torch.arange(
qweight.shape[0] * (32 // self.bits),
device=qweight.device,
)
// self.groupsize
).to(dtype=torch.int32)
else:
g_idx = None
return GPTQWeight(
qweight=qweight,
qzeros=qzeros,
scales=scales,
g_idx=g_idx,
bits=self.bits,
groupsize=self.groupsize,
use_exllama=False,
)
def get_multi_weights_col(self, weights: Weights, prefixes: List[str], dim: int):
from text_generation_server.layers.marlin import (
can_use_gptq_marlin,
repack_gptq_for_marlin,
)
try:
qweight = torch.cat(
[weights.get_sharded(f"{p}.qweight", dim=1) for p in prefixes], dim=1
)
except RuntimeError:
raise RuntimeError(
f"Cannot load `{self.quantize}` weight, make sure the model is already quantized"
)
scales = torch.cat(
[weights.get_sharded(f"{p}.scales", dim=1) for p in prefixes], dim=1
)
self._get_gptq_params(weights)
if can_use_gptq_marlin(
bits=self.bits,
groupsize=self.groupsize,
quant_method=self.quant_method,
quantize=self.quantize,
sym=self.sym,
):
w = [weights.get_tensor(f"{p}.g_idx") for p in prefixes]
for w2 in w[1:]:
torch.testing.assert_close(w2, w[0])
g_idx = w[0]
return repack_gptq_for_marlin(
qweight=qweight,
scales=scales,
g_idx=g_idx,
bits=self.bits,
desc_act=self.desc_act,
groupsize=self.groupsize,
sym=self.sym,
sharded_infeatures=False,
)
qzeros = torch.cat(
[weights.get_sharded(f"{p}.qzeros", dim=1) for p in prefixes], dim=1
)
from text_generation_server.layers.gptq import HAS_EXLLAMA
use_exllama = (
self.bits == 4
and HAS_EXLLAMA
and self.quantize == "gptq"
and not self.desc_act
)
if self.quantize == "gptq" and self.quant_method == "gptq":
w = [weights.get_tensor(f"{p}.g_idx") for p in prefixes]
for w2 in w[1:]:
torch.testing.assert_close(w2, w[0])
g_idx = w[0]
elif self.quantize == "gptq" and self.quant_method == "awq":
log_once(
logger.info, "Converting AWQ model to Exllama/GPTQ packing format."
)
from text_generation_server.layers.awq.conversion_utils import (
fast_awq_to_gptq,
)
qweight, qzeros = fast_awq_to_gptq(qweight, qzeros)
if use_exllama:
g_idx = None
else:
g_idx = (
torch.arange(
qweight.shape[0] * (32 // self.bits),
device=qweight.device,
)
// self.groupsize
).to(dtype=torch.int32)
else:
g_idx = None
return GPTQWeight(
qweight=qweight,
qzeros=qzeros,
scales=scales,
g_idx=g_idx,
bits=self.bits,
groupsize=self.groupsize,
use_exllama=use_exllama,
)
def get_weights_row(self, weights: Weights, prefix: str):
from text_generation_server.layers.marlin import (
can_use_gptq_marlin,
repack_gptq_for_marlin,
)
self._get_gptq_params(weights)
if can_use_gptq_marlin(
bits=self.bits,
groupsize=self.groupsize,
quant_method=self.quant_method,
quantize=self.quantize,
sym=self.sym,
):
log_once(logger.info, "Using GPTQ-Marlin kernels")
try:
qweight = weights.get_sharded(f"{prefix}.qweight", dim=0)
except RuntimeError:
raise RuntimeError(
f"Cannot load `{self.quantize}` weight for GPTQ -> Marlin repacking, make sure the model is already quantized"
)
g_idx = weights.get_sharded(f"{prefix}.g_idx", dim=0)
if self.desc_act or self.groupsize == -1:
scales = weights.get_tensor(f"{prefix}.scales")
else:
scales = weights.get_sharded(f"{prefix}.scales", dim=0)
sharded_in_features = weights.process_group.size() > 1
return repack_gptq_for_marlin(
qweight=qweight,
scales=scales,
g_idx=g_idx,
bits=self.bits,
desc_act=self.desc_act,
groupsize=self.groupsize,
sym=self.sym,
sharded_infeatures=sharded_in_features,
)
use_exllama = True
if self.bits != 4:
use_exllama = False
if self.desc_act:
log_once(logger.warning, "Disabling exllama because desc_act=True")
use_exllama = False
try:
qweight = weights.get_sharded(f"{prefix}.qweight", dim=0)
except RuntimeError:
raise RuntimeError(
"Cannot load `gptq` weight, make sure the model is already quantized, or quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`"
)
if self.quantize == "gptq" and self.quant_method == "gptq":
g_idx = weights.get_sharded(f"{prefix}.g_idx", dim=0)
else:
g_idx = None
if weights.process_group.size() > 1:
if g_idx is not None:
if (
not torch.equal(
g_idx.cpu(),
torch.tensor(
[i // self.groupsize for i in range(g_idx.shape[0])],
dtype=torch.int32,
),
)
and not (g_idx == 0).all()
):
# Exllama implementation does not support row tensor parallelism with act-order, as
# it would require to reorder input activations that are split unto several GPUs
use_exllama = False
from text_generation_server.layers.gptq import (
HAS_EXLLAMA,
CAN_EXLLAMA,
GPTQWeight,
)
if use_exllama:
if not HAS_EXLLAMA:
if CAN_EXLLAMA:
log_once(
logger.warning,
"Exllama GPTQ cuda kernels (which are faster) could have been used, but are not currently installed, try using BUILD_EXTENSIONS=True",
)
use_exllama = False
else:
log_once(logger.info, f"Using exllama kernels v{HAS_EXLLAMA}")
if use_exllama and self.groupsize != -1:
qzeros = weights.get_sharded(f"{prefix}.qzeros", dim=0)
scales = weights.get_sharded(f"{prefix}.scales", dim=0)
else:
qzeros = weights.get_tensor(f"{prefix}.qzeros")
scales = weights.get_tensor(f"{prefix}.scales")
if use_exllama and g_idx is not None:
g_idx = g_idx - g_idx[0]
if self.quantize == "gptq" and self.quant_method == "awq":
log_once(
logger.info, "Converting AWQ model to Exllama/GPTQ packing format."
)
from text_generation_server.layers.awq.conversion_utils import (
fast_awq_to_gptq,
)
qweight, qzeros = fast_awq_to_gptq(qweight, qzeros)
if use_exllama:
g_idx = None
else:
g_idx = (
torch.arange(
qweight.shape[0] * (32 // self.bits),
device=qweight.device,
)
// self.groupsize
).to(dtype=torch.int32)
return GPTQWeight(
qweight=qweight,
qzeros=qzeros,
scales=scales,
g_idx=g_idx,
bits=self.bits,
groupsize=self.groupsize,
use_exllama=use_exllama,
)
def _get_gptq_params(self, weights: Weights):
try:
self.bits = weights.get_tensor("gptq_bits").item()
self.groupsize = weights.get_tensor("gptq_groupsize").item()
self.desc_act = False
self.sym = False
self.quant_method = "gptq"
except (SafetensorError, RuntimeError) as e:
pass

View File

@@ -16,6 +16,8 @@ from text_generation_server.layers.gptq.quant_linear import QuantLinear
from loguru import logger
from typing import Optional
from text_generation_server.utils.weights import DefaultWeightsLoader
DEV = torch.device("cuda:0")
@@ -891,6 +893,7 @@ def quantize(
dtype=torch.float16,
process_group=process_group,
aliases={"embed_tokens.weight": ["lm_head.weight"]},
weights_loader=DefaultWeightsLoader(),
)
hooks = []
for name, module in model.named_modules():

View File

@@ -1,10 +1,10 @@
from dataclasses import dataclass
-from typing import List, Optional, Tuple
from typing import List, Optional, Tuple, Union
from text_generation_server.utils.weights import Weights, WeightsLoader
import torch
import torch.nn as nn
-from text_generation_server.layers.gptq import GPTQParams
from text_generation_server.utils.import_utils import SYSTEM
try:
@@ -24,16 +24,132 @@ GPTQ_MARLIN_GROUP_SIZES = [-1, 32, 64, 128]
MARLIN_TILE_SIZE = 16
-def can_use_gptq_marlin(gptq_params: GPTQParams, quantize: str) -> bool:
class MarlinWeightsLoader(WeightsLoader):
"""Loader for Marlin-quantized weights."""
def __init__(self, *, bits: int, is_marlin_24: bool):
self.bits = bits
self.is_marlin_24 = is_marlin_24
def get_weights_col_packed(
self,
weights: Weights,
prefix: str,
block_sizes: Union[int, List[int]],
):
if self.is_marlin_24:
B = weights.get_packed_sharded(
f"{prefix}.B_24", dim=1, block_sizes=block_sizes
)
B_meta = weights.get_packed_sharded(
f"{prefix}.B_meta", dim=1, block_sizes=block_sizes
)
s = weights.get_packed_sharded(
f"{prefix}.s", dim=1, block_sizes=block_sizes
)
weight = GPTQMarlin24Weight(B=B, B_meta=B_meta, s=s, bits=self.bits)
else:
B = weights.get_packed_sharded(
f"{prefix}.B", dim=1, block_sizes=block_sizes
)
s = weights.get_packed_sharded(
f"{prefix}.s", dim=1, block_sizes=block_sizes
)
weight = MarlinWeight(B=B, s=s)
return weight
def get_multi_weights_col(self, weights: Weights, prefixes: List[str], dim: int):
is_marlin_24 = getattr(self, "gptq_checkpoint_format", None) == "marlin_24"
if is_marlin_24:
try:
B = torch.cat(
[weights.get_sharded(f"{p}.B_24", dim=1) for p in prefixes], dim=1
)
except RuntimeError:
raise RuntimeError(
f"Cannot load `marlin` weight, make sure the model is already quantized"
)
B_meta = torch.cat(
[weights.get_sharded(f"{p}.B_meta", dim=1) for p in prefixes], dim=1
)
s = torch.cat(
[weights.get_sharded(f"{p}.s", dim=1) for p in prefixes], dim=1
)
weight = GPTQMarlin24Weight(B=B, B_meta=B_meta, s=s, bits=self.bits)
else:
try:
B = torch.cat(
[weights.get_sharded(f"{p}.B", dim=1) for p in prefixes], dim=1
)
except RuntimeError:
raise RuntimeError(
f"Cannot load `marlin` weight, make sure the model is already quantized"
)
s = torch.cat(
[weights.get_sharded(f"{p}.s", dim=1) for p in prefixes], dim=1
)
weight = MarlinWeight(B=B, s=s)
return weight
def get_weights_row(self, weights: Weights, prefix: str):
is_marlin_24 = getattr(self, "gptq_checkpoint_format", None) == "marlin_24"
if is_marlin_24:
try:
B = weights.get_sharded(f"{prefix}.B_24", dim=0)
except RuntimeError:
raise RuntimeError(
"Cannot load `marlin` 2:4 sparsity weight, make sure the model is already quantized."
)
B_meta = weights.get_sharded(f"{prefix}.B_meta", dim=0)
num_groups = weights._get_slice(f"{prefix}.s").get_shape()[0]
if num_groups == 1:
# The number of groups is 1 when groupsize == -1. share
# scales between all shards in this case.
s = weights.get_tensor(f"{prefix}.s")
else:
s = weights.get_sharded(f"{prefix}.s", dim=0)
weight = GPTQMarlin24Weight(B=B, B_meta=B_meta, s=s, bits=self.bits)
else:
try:
B = weights.get_sharded(f"{prefix}.B", dim=0)
except RuntimeError:
raise RuntimeError(
"Cannot load `marlin` weight, make sure the model is already quantized."
)
num_groups = weights._get_slice(f"{prefix}.s").get_shape()[0]
if num_groups == 1:
# The number of groups is 1 when groupsize == -1. share
# scales between all shards in this case.
s = weights.get_tensor(f"{prefix}.s")
else:
s = weights.get_sharded(f"{prefix}.s", dim=0)
weight = MarlinWeight(B=B, s=s)
return weight
def can_use_gptq_marlin(
*, bits: int, groupsize: int, quant_method: str, quantize: str, sym: bool
) -> bool:
return (
SYSTEM == "cuda"
and marlin_kernels is not None
and has_sm_8_0
and quantize == "gptq"
-and gptq_params.quant_method == "gptq"
-and gptq_params.bits in GPTQ_MARLIN_BITS
-and gptq_params.groupsize in GPTQ_MARLIN_GROUP_SIZES
-and gptq_params.sym
and quant_method == "gptq"
and bits in GPTQ_MARLIN_BITS
and groupsize in GPTQ_MARLIN_GROUP_SIZES
and sym
)

View File

@@ -52,7 +52,7 @@ class TensorParallelHead(SuperLayer):
weight = weights.get_tensor(f"{prefix}.weight")
except:
# ...otherwise they are quantized.
-weight = weights.get_weights_col(prefix, config.quantize)
weight = weights.get_weights_col(prefix)
should_gather = weights.process_group.size() > 1
elif weights.process_group.size() > 1:
try:
@@ -129,9 +129,7 @@ class TensorParallelColumnLinear(SuperLayer):
@classmethod
def load_gate_up(cls, config, prefix: str, weights, bias: bool):
"""Specific method when the QKV was joined after the fact"""
-weight = weights.get_weights_col_packed_gate_up(
-prefix, quantize=config.quantize
-)
weight = weights.get_weights_col_packed_gate_up(prefix)
if bias:
raise NotImplementedError("packed_gate_up only implemented without bias")
else:
@@ -152,7 +150,6 @@ class TensorParallelColumnLinear(SuperLayer):
"""Specific method when the QKV was joined after the fact"""
weight = weights.get_weights_col_packed_qkv(
prefix,
-quantize=config.quantize,
num_heads=num_heads,
num_key_value_heads=num_key_value_heads,
)
@@ -165,7 +162,7 @@ class TensorParallelColumnLinear(SuperLayer):
@classmethod
def load(cls, config, prefix: str, weights, bias: bool):
-weight = weights.get_weights_col(prefix, config.quantize)
weight = weights.get_weights_col(prefix)
if bias:
bias = weights.get_sharded(f"{prefix}.bias", dim=0)
else:
@@ -178,14 +175,12 @@ class TensorParallelColumnLinear(SuperLayer):
if config.quantize == "exl2":
linears = []
for prefix in prefixes:
-weight = weights.get_weights_col(prefix, config.quantize)
weight = weights.get_weights_col(prefix)
b = weights.get_tensor(f"{prefix}.bias") if bias else None
linears.append(get_linear(weight, b, config.quantize))
linear = LayerConcat(linears)
else:
-weight = weights.get_multi_weights_col(
-prefixes, quantize=config.quantize, dim=dim
-)
weight = weights.get_multi_weights_col(prefixes, dim=dim)
if bias:
b = [weights.get_sharded(f"{p}.bias", dim=0) for p in prefixes]
bias = torch.cat(b, dim=dim)
@@ -202,7 +197,7 @@ class TensorParallelRowLinear(SuperLayer):
@classmethod
def load(cls, config, prefix: str, weights, bias: bool):
-weight = weights.get_multi_weights_row(prefix, quantize=config.quantize)
weight = weights.get_weights_row(prefix)
if bias and weights.process_group.rank() == 0:
# Rank is only on the first rank process

View File

@@ -20,6 +20,7 @@ from text_generation_server.utils import (
from text_generation_server.models import Model
from text_generation_server.utils.chunks import concat_text_chunks
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.utils.quantization import get_loader
from text_generation_server.utils.tokens import batch_top_tokens
from text_generation_server.models.types import (
Batch,
@@ -546,12 +547,17 @@ class CausalLM(Model):
tokenizer.pad_token_id = config.pad_token_id
torch.distributed.barrier(group=self.process_group)
weights_loader = get_loader(
quantize=quantize, model_id=model_id, revision=revision
)
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
weights = Weights(
-filenames, device=device, dtype=dtype, process_group=self.process_group
filenames,
device=device,
dtype=dtype,
process_group=self.process_group,
weights_loader=weights_loader,
)
-if config.quantize in ["awq", "exl2", "gptq", "marlin"]:
-weights._set_gptq_params(model_id, revision)
prefix = ""
model = model_class(prefix, config, weights)

View File

@@ -163,7 +163,6 @@ def _load_gqa(config, prefix: str, weights):
weight = weights.get_multi_weights_col(
prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
-quantize=config.quantize,
dim=0,
)

View File

@@ -141,7 +141,6 @@ def _load_gqa(config, prefix: str, weights):
weight = weights.get_multi_weights_col(
prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
-quantize=config.quantize,
dim=0,
)

View File

@@ -141,7 +141,6 @@ def _load_gqa(config, prefix: str, weights):
weight = weights.get_multi_weights_col(
prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
-quantize=config.quantize,
dim=0,
)

View File

@@ -61,7 +61,6 @@ def _load_qkv_gptq(config, prefix: str, weights):
# Weights
weight = weights.get_weights_col_packed_qkv(
f"{prefix}.c_attn",
-config.quantize,
config.num_attention_heads,
config.num_attention_heads,
)
@@ -137,7 +136,7 @@ def load_row(config, prefix: str, weights, bias: bool):
"""load_row, but with transposed weight matrices."""
if config.quantize == "gptq":
-weight = weights.get_multi_weights_row(prefix, quantize=config.quantize)
weight = weights.get_weights_row(prefix)
else:
weight = weights.get_sharded(f"{prefix}.weight", dim=0).T
@@ -155,9 +154,7 @@ def load_row(config, prefix: str, weights, bias: bool):
def load_col(config, prefix: str, weights, bias: bool):
"""load_col, but with transposed weight matrices."""
if config.quantize == "gptq":
-weight = weights.get_multi_weights_col(
-[prefix], quantize=config.quantize, dim=1
-)
weight = weights.get_multi_weights_col([prefix], dim=1)
else:
weight = weights.get_sharded(f"{prefix}.weight", dim=1).T

View File

@@ -135,7 +135,6 @@ def _load_gqa(config, prefix: str, weights):
weight = weights.get_multi_weights_col(
prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
-quantize=config.quantize,
dim=0,
)

View File

@@ -48,7 +48,7 @@ from text_generation_server.layers.rotary import (
def load_row(config, prefix: str, weights, bias: bool):
-weight = weights.get_multi_weights_row(prefix, quantize=config.quantize)
weight = weights.get_weights_row(prefix)
if bias and weights.process_group.rank() == 0:
# Rank is only on the first rank process
@@ -64,7 +64,7 @@ def load_row(config, prefix: str, weights, bias: bool):
def load_qkv(config, prefix: str, weights, num_heads, head_size, hidden_size):
-weight = weights.get_multi_weights_col([prefix], quantize=config.quantize, dim=0)
weight = weights.get_multi_weights_col([prefix], dim=0)
if isinstance(weight, torch.Tensor):
# Only on non quantized versions
weight = (

View File

@@ -85,7 +85,6 @@ def _load_gqa(config, prefix: str, weights):
weight = weights.get_multi_weights_col(
prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
-quantize=config.quantize,
dim=0,
)

View File

@@ -23,7 +23,7 @@ from text_generation_server.layers.attention import (
def load_row(config, prefix: str, weights, bias: bool):
-weight = weights.get_multi_weights_row(prefix, quantize=config.quantize)
weight = weights.get_weights_row(prefix)
if bias and weights.process_group.rank() == 0:
# Rank is only on the first rank process

View File

@@ -17,6 +17,7 @@ from text_generation_server.layers import (
TensorParallelEmbedding,
get_linear,
)
from text_generation_server.layers.gptq import GPTQWeightsLoader
from text_generation_server.layers.layernorm import (
FastLayerNorm,
)
@@ -81,11 +82,13 @@ def _load_multi_mqa_gptq(
qzeros = torch.cat([q_tensor, kv_tensor], dim=1)
qzeros = qzeros.to(device=weights.device)
-gptq_params = weights._get_gptq_params()
-if gptq_params.quant_method == "gptq":
loader = weights.weights_loader
assert isinstance(loader, GPTQWeightsLoader)
loader._get_gptq_params(weights)
if loader.quant_method == "gptq":
g_idx = weights.get_tensor(f"{prefix}.c_attn.g_idx")
g_idx = g_idx.to(device=weights.device)
-elif gptq_params.quant_method == "awq":
elif loader.quant_method == "awq":
g_idx = None
from text_generation_server.layers.awq.conversion_utils import (
fast_awq_to_gptq,
@@ -100,8 +103,8 @@ def _load_multi_mqa_gptq(
qzeros=qzeros,
scales=scales,
g_idx=g_idx,
-bits=gptq_params.bits,
-groupsize=gptq_params.groupsize,
bits=loader.bits,
groupsize=loader.groupsize,
use_exllama=HAS_EXLLAMA,
)
@@ -197,9 +200,7 @@ def load_col(config, prefix: str, weights, bias: bool):
if config.transpose:
weight = weights.get_sharded(f"{prefix}.weight", dim=1).T
else:
-weight = weights.get_multi_weights_col(
-[prefix], quantize=config.quantize, dim=0
-)
weight = weights.get_multi_weights_col([prefix], dim=0)
if bias:
bias = weights.get_sharded(f"{prefix}.bias", dim=0)
@@ -212,7 +213,7 @@ def load_row(config, prefix: str, weights, bias: bool):
if config.transpose:
weight = weights.get_sharded(f"{prefix}.weight", dim=0).T
else:
-weight = weights.get_multi_weights_row(prefix, quantize=config.quantize)
weight = weights.get_weights_row(prefix)
if bias and weights.process_group.rank() == 0:
# Rank is only on the first rank process

View File

@@ -126,7 +126,6 @@ def _load_gqa(config, prefix: str, weights):
weight = weights.get_multi_weights_col(
prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
-quantize=config.quantize,
dim=0,
)

View File

@@ -50,6 +50,7 @@ from text_generation_server.models.globals import (
from text_generation_server.layers.attention import Seqlen
from text_generation_server.utils import StoppingCriteria, HeterogeneousNextTokenChooser
from text_generation_server.utils.dist import MEMORY_FRACTION
from text_generation_server.utils.quantization import get_loader
from text_generation_server.utils.segments import SegmentConcatBuilder, find_segments
from text_generation_server.utils.import_utils import (
@@ -881,12 +882,16 @@ class FlashCausalLM(Model):
torch.distributed.barrier(group=self.process_group)
weights_loader = get_loader(quantize, model_id, revision)
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
weights = Weights(
-filenames, device, dtype, process_group=self.process_group, aliases=aliases
filenames,
device,
dtype,
process_group=self.process_group,
aliases=aliases,
weights_loader=weights_loader,
)
-if config.quantize in ["awq", "exl2", "gptq", "marlin"]:
-weights._set_gptq_params(model_id, revision)
prefix = ""
model = model_class(prefix, config, weights)

View File

@@ -23,6 +23,7 @@ from text_generation_server.utils import (
weight_files,
Weights,
)
from text_generation_server.utils.quantization import get_loader
class IDEFICSSharded(IdeficsCausalLM):
@@ -70,6 +71,9 @@ class IDEFICSSharded(IdeficsCausalLM):
trust_remote_code=trust_remote_code,
)
weights_loader = get_loader(
quantize=quantize, model_id=model_id, revision=revision
)
torch.distributed.barrier(group=self.process_group)
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
weights = Weights(
@@ -77,6 +81,7 @@ class IDEFICSSharded(IdeficsCausalLM):
device=device,
dtype=dtype,
process_group=self.process_group,
weights_loader=weights_loader,
)
model = IdeficsForVisionText2Text(config, weights)

View File

@@ -28,6 +28,7 @@ from text_generation_server.models.types import (
GeneratedText,
)
from text_generation_server.utils.chunks import concat_text_chunks
from text_generation_server.utils.quantization import get_loader
from text_generation_server.utils.tokens import batch_top_tokens, Sampling
from dataclasses import dataclass
from text_generation_server.utils import NextTokenChooser, StoppingCriteria, Sampling
@@ -448,8 +449,17 @@ class Mamba(Model):
config.quantize = quantize
config.speculator = speculator
torch.distributed.barrier(group=self.process_group)
weights_loader = get_loader(
quantize=quantize, model_id=model_id, revision=revision
)
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
-weights = Weights(filenames, device, dtype, process_group=self.process_group)
weights = Weights(
filenames,
device,
dtype,
process_group=self.process_group,
weights_loader=weights_loader,
)
model = MambaModel(config, weights)
torch.distributed.barrier(group=self.process_group)
super(Mamba, self).__init__(

View File

@@ -18,6 +18,7 @@ from text_generation_server.utils import (
Weights,
)
from text_generation_server.utils.chunks import concat_text_chunks
from text_generation_server.utils.quantization import get_loader
from text_generation_server.utils.tokens import batch_top_tokens
from text_generation_server.models import Model
from text_generation_server.models.types import (
@@ -586,6 +587,9 @@ class Seq2SeqLM(Model):
)
tokenizer.bos_token_id = config.decoder_start_token_id
weights_loader = get_loader(
quantize=quantize, model_id=model_id, revision=revision
)
torch.distributed.barrier(group=self.process_group)
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
weights = Weights(
@@ -594,6 +598,7 @@ class Seq2SeqLM(Model):
dtype=dtype,
process_group=self.process_group,
aliases=aliases,
weights_loader=weights_loader,
)
if config.quantize in ["awq", "exl2", "gptq", "marlin"]:
weights._set_gptq_params(model_id, revision)

View File

@@ -0,0 +1,119 @@
from typing import Optional
import os
import json
from dataclasses import dataclass
from huggingface_hub import hf_hub_download
from text_generation_server.utils.weights import DefaultWeightsLoader, WeightsLoader
@dataclass
class _QuantizerConfig:
bits: int
checkpoint_format: Optional[str]
desc_act: bool
groupsize: int
quant_method: str
sym: bool
# We should probably do this with Pydantic JSON deserialization,
# but for now we'll stay close to the old _set_gptq_params.
def _get_quantizer_config(model_id, revision):
bits = 4
groupsize = -1
quant_method = "gptq"
checkpoint_format = None
sym = True
desc_act = False
filename = "config.json"
try:
if os.path.exists(os.path.join(model_id, filename)):
filename = os.path.join(model_id, filename)
else:
filename = hf_hub_download(model_id, filename=filename, revision=revision)
with open(filename, "r") as f:
data = json.load(f)
bits = data["quantization_config"]["bits"]
groupsize = data["quantization_config"]["group_size"]
# Order is important here, desc_act is missing on some real models
quant_method = data["quantization_config"]["quant_method"]
checkpoint_format = data["quantization_config"].get("checkpoint_format")
sym = data["quantization_config"]["sym"]
desc_act = data["quantization_config"]["desc_act"]
except Exception:
filename = "quantize_config.json"
try:
if os.path.exists(os.path.join(model_id, filename)):
filename = os.path.join(model_id, filename)
else:
filename = hf_hub_download(
model_id, filename=filename, revision=revision
)
with open(filename, "r") as f:
data = json.load(f)
bits = data["bits"]
groupsize = data["group_size"]
sym = data["sym"]
desc_act = data["desc_act"]
if "version" in data and data["version"] == "GEMM":
quant_method = "awq"
except Exception:
filename = "quant_config.json"
try:
if os.path.exists(os.path.join(model_id, filename)):
filename = os.path.join(model_id, filename)
else:
filename = hf_hub_download(
model_id, filename=filename, revision=revision
)
with open(filename, "r") as f:
data = json.load(f)
bits = data["w_bit"]
groupsize = data["q_group_size"]
desc_act = data["desc_act"]
if "version" in data and data["version"] == "GEMM":
quant_method = "awq"
except Exception:
pass
return _QuantizerConfig(
bits=bits,
groupsize=groupsize,
quant_method=quant_method,
checkpoint_format=checkpoint_format,
sym=sym,
desc_act=desc_act,
)
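
To illustrate the first branch of the fallback chain above: a GPTQ-style config.json typically carries a quantization_config block with exactly the keys read by _get_quantizer_config. The values below are made up and only show the shape of the data.

# Hypothetical excerpt of a config.json for a GPTQ checkpoint.
config = {
    "quantization_config": {
        "bits": 4,
        "group_size": 128,
        "quant_method": "gptq",
        "checkpoint_format": None,
        "sym": True,
        "desc_act": False,
    }
}

# _get_quantizer_config would fold this into:
# _QuantizerConfig(bits=4, groupsize=128, quant_method="gptq",
#                  checkpoint_format=None, sym=True, desc_act=False)

If config.json is missing or lacks a quantization_config block, the function falls back to quantize_config.json and then quant_config.json, keeping the defaults for anything it cannot read.
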
def get_loader(
quantize: Optional[str], model_id: str, revision: Optional[str]
) -> WeightsLoader:
quantizer_config = _get_quantizer_config(model_id, revision)
if quantize in {"awq", "gptq"}:
from text_generation_server.layers.gptq import GPTQWeightsLoader
return GPTQWeightsLoader(
bits=quantizer_config.bits,
desc_act=quantizer_config.desc_act,
groupsize=quantizer_config.groupsize,
quant_method=quantizer_config.quant_method,
quantize=quantize,
sym=quantizer_config.sym,
)
elif quantize == "exl2":
from text_generation_server.layers.exl2 import Exl2WeightsLoader
return Exl2WeightsLoader()
elif quantize == "marlin":
from text_generation_server.layers.marlin import MarlinWeightsLoader
return MarlinWeightsLoader(
bits=quantizer_config.bits,
is_marlin_24=quantizer_config.checkpoint_format == "marlin_24",
)
else:
return DefaultWeightsLoader()
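
The dispatch above maps the CLI-level quantize value to a loader. A quick usage sketch follows; the model ids are placeholders, and the exl2 case assumes the corresponding kernels are importable on the machine. Because _get_quantizer_config swallows lookup failures, the call still returns a loader with default parameters when no config file can be fetched.

from text_generation_server.utils.quantization import get_loader

# Placeholder model ids; any checkpoint with a matching quantization config would do.
gptq_loader = get_loader(quantize="gptq", model_id="org/model-gptq", revision=None)
marlin_loader = get_loader(quantize="marlin", model_id="org/model-marlin", revision=None)
unquantized = get_loader(quantize=None, model_id="org/model", revision=None)

# -> GPTQWeightsLoader, MarlinWeightsLoader and DefaultWeightsLoader respectively.
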

View File

@ -1,13 +1,88 @@
import os from abc import ABC, abstractmethod
from pathlib import Path
from typing import Dict, List, Optional, Union
from safetensors import safe_open, SafetensorError from safetensors import safe_open
import torch
from loguru import logger
from huggingface_hub import hf_hub_download
import json class WeightsLoader(ABC):
from text_generation_server.layers.gptq import GPTQParams """
from text_generation_server.utils.log import log_once Instances of this type implement higher-level weight loading.
At a low-level, every weight is stored in the Safetensors format.
However, the interpretation of the weights may differ; for instance,
they could be packed or quantized weights. Loaders are responsible for
interpreting the raw tensors, sharding tensors in a manner compatible
with the format, etc.
"""
@abstractmethod
def get_weights_col_packed(
self,
weights: "Weights",
prefix: str,
block_sizes: Union[int, List[int]],
):
"""
Get the packed weights at the given prefix with column-splitting for
tensor parallelism. This method should be used when multiple different
weights are packed into a tensor, for instance, query/key/value
weights or a gate/up projection.
The `block_sizes` argument determines the proportions of the packed tensors.
The columns are split in equally sized blocks when `block_sizes` is an
`int`, or in blocks proportional to the given sizes. For instance
`[2, 1, 1]` will divide an input with dimensionality `1024` in
`[512, 256, 256]`.
"""
...
def get_weights_col(self, weights: "Weights", prefix: str):
"""
Get weights at the given prefix and apply column-splitting for tensor
parallelism.
"""
return weights.get_multi_weights_col([prefix], 0)
@abstractmethod
def get_multi_weights_col(self, weights: "Weights", prefixes: List[str], dim: int):
"""
Get the weights at the given prefixes, column-split them for tensor
parallelism, and then concatenate the weights along the given dimension.
"""
...
@abstractmethod
def get_weights_row(self, weights: "Weights", prefix: str):
"""
Get the weights at the given prefix and apply row-splitting for tensor
parallelism.
"""
...
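
To make the block_sizes semantics concrete: the split described in the docstring of get_weights_col_packed can be computed as below. This is a standalone sketch of that behaviour, consistent with the _blocks_to_block_sizes signature further down in this file, not the library's exact implementation; the function name split_block_sizes is invented.

from typing import List, Union

def split_block_sizes(total_size: int, blocks: Union[int, List[int]]) -> List[int]:
    # Sketch: an int means equally sized blocks, a list means blocks
    # proportional to the entries (see the docstring above).
    if isinstance(blocks, int):
        assert total_size % blocks == 0
        return [total_size // blocks] * blocks
    total_blocks = sum(blocks)
    assert total_size % total_blocks == 0
    single_size = total_size // total_blocks
    return [single_size * b for b in blocks]

assert split_block_sizes(1024, [2, 1, 1]) == [512, 256, 256]  # e.g. packed Q/K/V
assert split_block_sizes(1024, 2) == [512, 512]               # e.g. packed gate/up
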
class DefaultWeightsLoader(WeightsLoader):
"""
Loader that uses tensors as-is with the exception of applying sharding
and/or concatenation.
"""
def get_weights_col_packed(
self,
weights: "Weights",
prefix: str,
block_sizes: Union[int, List[int]],
):
return weights.get_packed_sharded(
f"{prefix}.weight", dim=0, block_sizes=block_sizes
)
def get_multi_weights_col(self, weights: "Weights", prefixes: List[str], dim: int):
w = [weights.get_sharded(f"{p}.weight", dim=0) for p in prefixes]
return torch.cat(w, dim=dim)
def get_weights_row(self, weights: "Weights", prefix: str):
return weights.get_sharded(f"{prefix}.weight", dim=1)
class Weights:
@ -17,6 +92,7 @@ class Weights:
device,
dtype,
process_group,
weights_loader: WeightsLoader,
aliases: Optional[Dict[str, List[str]]] = None,
prefix: Optional[str] = None,
):
@ -37,6 +113,7 @@ class Weights:
self.dtype = dtype
self.process_group = process_group
self.prefix = prefix
self.weights_loader = weights_loader
self._handles = {}
def _get_handle(self, filename):
@ -181,295 +258,27 @@ class Weights:
num_key_value_heads: int,
):
return self.get_weights_col_packed(
prefix, quantize, [num_heads, num_key_value_heads, num_key_value_heads] prefix, [num_heads, num_key_value_heads, num_key_value_heads]
)
def get_weights_col_packed_gate_up(self, prefix: str, quantize: str): def get_weights_col_packed_gate_up(self, prefix: str):
return self.get_weights_col_packed(prefix, quantize, 2) return self.get_weights_col_packed(prefix, 2)
def get_weights_col_packed( def get_weights_col_packed(self, prefix: str, block_sizes: Union[int, List[int]]):
self, prefix: str, quantize: str, block_sizes: Union[int, List[int]]
):
""" """
Highly specific when the underlying tensor is a simple cat of Q,K,V instead of being
already alternating Q,K,V within the main tensor.
The columns are split in equally sized blocks when blocks is an `int`, or
in blocks proportional to the given sizes. For instance `[2, 1, 1]` will
divide an input with dimensionality `1024` in `[512, 256, 256]`. This is
convenient for e.g. splitting QKV without knowing the storage details of
quantized weights.
"""
if quantize in ["gptq", "awq"]: return self.weights_loader.get_weights_col_packed(self, prefix, block_sizes)
from text_generation_server.layers.gptq import GPTQWeight
from text_generation_server.layers.marlin import (
can_use_gptq_marlin,
repack_gptq_for_marlin,
)
try: def get_weights_col(self, prefix: str):
qweight = self.get_packed_sharded( return self.weights_loader.get_weights_col(self, prefix)
f"{prefix}.qweight", dim=1, block_sizes=block_sizes
)
except RuntimeError:
raise RuntimeError(
f"Cannot load `{quantize}` weight, make sure the model is already quantized."
)
scales = self.get_packed_sharded(
f"{prefix}.scales", dim=1, block_sizes=block_sizes
)
scales = scales.to(dtype=self.dtype)
gptq_params = self._get_gptq_params() def get_multi_weights_col(self, prefixes: List[str], dim: int):
if can_use_gptq_marlin(gptq_params, quantize): return self.weights_loader.get_multi_weights_col(self, prefixes, dim)
g_idx = self.get_tensor(f"{prefix}.g_idx")
return repack_gptq_for_marlin(
qweight=qweight,
scales=scales,
g_idx=g_idx,
bits=gptq_params.bits,
desc_act=gptq_params.desc_act,
groupsize=gptq_params.groupsize,
sym=gptq_params.sym,
sharded_infeatures=False,
)
qzeros = self.get_packed_sharded(
f"{prefix}.qzeros", dim=1, block_sizes=block_sizes
)
if quantize == "gptq" and gptq_params.quant_method == "gptq":
g_idx = self.get_tensor(f"{prefix}.g_idx")
elif quantize == "gptq" and gptq_params.quant_method == "awq":
log_once(
logger.info, "Converting AWQ model to Exllama/GPTQ packing format."
)
from text_generation_server.layers.awq.conversion_utils import (
fast_awq_to_gptq,
)
qweight, qzeros = fast_awq_to_gptq(qweight, qzeros)
g_idx = (
torch.arange(
qweight.shape[0] * (32 // gptq_params.bits),
device=qweight.device,
)
// gptq_params.groupsize
).to(dtype=torch.int32)
else:
g_idx = None
weight = GPTQWeight(
qweight=qweight,
qzeros=qzeros,
scales=scales,
g_idx=g_idx,
bits=gptq_params.bits,
groupsize=gptq_params.groupsize,
use_exllama=False,
)
elif quantize == "marlin":
from text_generation_server.layers.marlin import (
GPTQMarlin24Weight,
MarlinWeight,
repack_gptq_for_marlin,
)
is_marlin_24 = getattr(self, "gptq_checkpoint_format", None) == "marlin_24"
if is_marlin_24:
B = self.get_packed_sharded(
f"{prefix}.B_24", dim=1, block_sizes=block_sizes
)
B_meta = self.get_packed_sharded(
f"{prefix}.B_meta", dim=1, block_sizes=block_sizes
)
s = self.get_packed_sharded(
f"{prefix}.s", dim=1, block_sizes=block_sizes
)
gptq_params = self._get_gptq_params()
weight = GPTQMarlin24Weight(
B=B, B_meta=B_meta, s=s, bits=gptq_params.bits
)
else:
B = self.get_packed_sharded(
f"{prefix}.B", dim=1, block_sizes=block_sizes
)
s = self.get_packed_sharded(
f"{prefix}.s", dim=1, block_sizes=block_sizes
)
weight = MarlinWeight(B=B, s=s)
else:
weight = self.get_packed_sharded(
f"{prefix}.weight", dim=0, block_sizes=block_sizes
)
return weight
def get_weights_col(self, prefix: str, quantize: str):
if quantize == "exl2":
from text_generation_server.layers.exl2 import Exl2Weight
try:
q_weight = self.get_tensor(f"{prefix}.q_weight")
except RuntimeError:
raise RuntimeError(
f"Cannot load `exl2`-quantized weight, make sure the model is already quantized."
)
q_scale = self.get_tensor(f"{prefix}.q_scale")
q_invperm = self.get_tensor(f"{prefix}.q_invperm")
q_scale_max = self.get_tensor(f"{prefix}.q_scale_max")
q_groups = self.get_tensor(f"{prefix}.q_groups")
return Exl2Weight(
q_weight=q_weight,
q_scale=q_scale,
q_invperm=q_invperm,
q_scale_max=q_scale_max,
q_groups=q_groups,
)
return self.get_multi_weights_col([prefix], quantize, 0)
def get_multi_weights_col(self, prefixes: List[str], quantize: str, dim: int):
if quantize == "exl2":
raise ValueError("get_multi_weights_col is not supported for exl2")
elif quantize in ["gptq", "awq"]:
from text_generation_server.layers.gptq import GPTQWeight
from text_generation_server.layers.marlin import (
can_use_gptq_marlin,
repack_gptq_for_marlin,
)
try:
qweight = torch.cat(
[self.get_sharded(f"{p}.qweight", dim=1) for p in prefixes], dim=1
)
except RuntimeError:
raise RuntimeError(
f"Cannot load `{quantize}` weight, make sure the model is already quantized"
)
scales = torch.cat(
[self.get_sharded(f"{p}.scales", dim=1) for p in prefixes], dim=1
)
gptq_params = self._get_gptq_params()
if can_use_gptq_marlin(gptq_params, quantize):
w = [self.get_tensor(f"{p}.g_idx") for p in prefixes]
for w2 in w[1:]:
torch.testing.assert_close(w2, w[0])
g_idx = w[0]
return repack_gptq_for_marlin(
qweight=qweight,
scales=scales,
g_idx=g_idx,
bits=gptq_params.bits,
desc_act=gptq_params.desc_act,
groupsize=gptq_params.groupsize,
sym=gptq_params.sym,
sharded_infeatures=False,
)
qzeros = torch.cat(
[self.get_sharded(f"{p}.qzeros", dim=1) for p in prefixes], dim=1
)
from text_generation_server.layers.gptq import HAS_EXLLAMA
use_exllama = (
gptq_params.bits == 4
and HAS_EXLLAMA
and quantize == "gptq"
and not gptq_params.desc_act
)
if quantize == "gptq" and gptq_params.quant_method == "gptq":
w = [self.get_tensor(f"{p}.g_idx") for p in prefixes]
for w2 in w[1:]:
torch.testing.assert_close(w2, w[0])
g_idx = w[0]
elif quantize == "gptq" and gptq_params.quant_method == "awq":
log_once(
logger.info, "Converting AWQ model to Exllama/GPTQ packing format."
)
from text_generation_server.layers.awq.conversion_utils import (
fast_awq_to_gptq,
)
qweight, qzeros = fast_awq_to_gptq(qweight, qzeros)
if use_exllama:
g_idx = None
else:
g_idx = (
torch.arange(
qweight.shape[0] * (32 // gptq_params.bits),
device=qweight.device,
)
// gptq_params.groupsize
).to(dtype=torch.int32)
else:
g_idx = None
weight = GPTQWeight(
qweight=qweight,
qzeros=qzeros,
scales=scales,
g_idx=g_idx,
bits=gptq_params.bits,
groupsize=gptq_params.groupsize,
use_exllama=use_exllama,
)
elif quantize == "marlin":
from text_generation_server.layers.gptq import GPTQWeight
from text_generation_server.layers.marlin import (
GPTQMarlin24Weight,
MarlinWeight,
)
is_marlin_24 = getattr(self, "gptq_checkpoint_format", None) == "marlin_24"
if is_marlin_24:
try:
B = torch.cat(
[self.get_sharded(f"{p}.B_24", dim=1) for p in prefixes], dim=1
)
except RuntimeError:
raise RuntimeError(
f"Cannot load `{quantize}` weight, make sure the model is already quantized"
)
B_meta = torch.cat(
[self.get_sharded(f"{p}.B_meta", dim=1) for p in prefixes], dim=1
)
s = torch.cat(
[self.get_sharded(f"{p}.s", dim=1) for p in prefixes], dim=1
)
gptq_params = self._get_gptq_params()
weight = GPTQMarlin24Weight(
B=B, B_meta=B_meta, s=s, bits=gptq_params.bits
)
else:
try:
B = torch.cat(
[self.get_sharded(f"{p}.B", dim=1) for p in prefixes], dim=1
)
except RuntimeError:
raise RuntimeError(
f"Cannot load `{quantize}` weight, make sure the model is already quantized"
)
s = torch.cat(
[self.get_sharded(f"{p}.s", dim=1) for p in prefixes], dim=1
)
weight = MarlinWeight(B=B, s=s)
else:
w = [self.get_sharded(f"{p}.weight", dim=0) for p in prefixes]
weight = torch.cat(w, dim=dim)
return weight
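
The net effect of the removals above on call sites is that layer code no longer threads a quantize string through every load; the loader chosen at startup decides how tensors are interpreted. A minimal sketch of the change, with an invented helper name and invented prefixes:

from typing import List

from text_generation_server.utils.weights import Weights


def load_packed_columns(weights: Weights, prefixes: List[str]):
    # Before this change, call sites passed the quantization scheme through:
    #     weights.get_multi_weights_col(prefixes, quantize, 0)
    # Now the scheme is fixed by the WeightsLoader the Weights instance was
    # built with, so the call shrinks to:
    return weights.get_multi_weights_col(prefixes, dim=0)
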
def get_tensor_shard(self, var, dim):
world_size = self.process_group.size()
@ -487,318 +296,8 @@ class Weights:
tensor = tensor.to(device=self.device)
return tensor
def get_multi_weights_row(self, prefix: str, quantize: str): def get_weights_row(self, prefix: str):
if quantize == "exl2": return self.weights_loader.get_weights_row(self, prefix)
from text_generation_server.layers.exl2 import Exl2Weight
try:
q_weight = self.get_tensor(f"{prefix}.q_weight")
except RuntimeError:
raise RuntimeError(
f"Cannot load `exl2`-quantized weight, make sure the model is already quantized."
)
q_scale = self.get_tensor(f"{prefix}.q_scale")
q_invperm = self.get_tensor(f"{prefix}.q_invperm")
q_scale_max = self.get_tensor(f"{prefix}.q_scale_max")
q_groups = self.get_tensor(f"{prefix}.q_groups")
return Exl2Weight(
q_weight=q_weight,
q_scale=q_scale,
q_invperm=q_invperm,
q_scale_max=q_scale_max,
q_groups=q_groups,
)
elif quantize == "gptq":
from text_generation_server.layers.marlin import (
can_use_gptq_marlin,
repack_gptq_for_marlin,
)
gptq_params = self._get_gptq_params()
if can_use_gptq_marlin(gptq_params, quantize):
log_once(logger.info, "Using GPTQ-Marlin kernels")
try:
qweight = self.get_sharded(f"{prefix}.qweight", dim=0)
except RuntimeError:
raise RuntimeError(
f"Cannot load `{quantize}` weight for GPTQ -> Marlin repacking, make sure the model is already quantized"
)
g_idx = self.get_sharded(f"{prefix}.g_idx", dim=0)
if gptq_params.desc_act or gptq_params.groupsize == -1:
scales = self.get_tensor(f"{prefix}.scales")
else:
scales = self.get_sharded(f"{prefix}.scales", dim=0)
sharded_in_features = self.process_group.size() > 1
return repack_gptq_for_marlin(
qweight=qweight,
scales=scales,
g_idx=g_idx,
bits=gptq_params.bits,
desc_act=gptq_params.desc_act,
groupsize=gptq_params.groupsize,
sym=gptq_params.sym,
sharded_infeatures=sharded_in_features,
)
use_exllama = True
if gptq_params.bits != 4:
use_exllama = False
if gptq_params.desc_act:
log_once(logger.warning, "Disabling exllama because desc_act=True")
use_exllama = False
try:
qweight = self.get_sharded(f"{prefix}.qweight", dim=0)
except RuntimeError:
raise RuntimeError(
"Cannot load `gptq` weight, make sure the model is already quantized, or quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`"
)
if gptq_params.quant_method == "gptq":
g_idx = self.get_sharded(f"{prefix}.g_idx", dim=0)
elif gptq_params.quant_method == "awq":
g_idx = None
if self.process_group.size() > 1:
if g_idx is not None:
if (
not torch.equal(
g_idx.cpu(),
torch.tensor(
[
i // gptq_params.groupsize
for i in range(g_idx.shape[0])
],
dtype=torch.int32,
),
)
and not (g_idx == 0).all()
):
# Exllama implementation does not support row tensor parallelism with act-order, as
# it would require to reorder input activations that are split unto several GPUs
use_exllama = False
from text_generation_server.layers.gptq import (
HAS_EXLLAMA,
CAN_EXLLAMA,
GPTQWeight,
)
if use_exllama:
if not HAS_EXLLAMA:
if CAN_EXLLAMA:
log_once(
logger.warning,
"Exllama GPTQ cuda kernels (which are faster) could have been used, but are not currently installed, try using BUILD_EXTENSIONS=True",
)
use_exllama = False
else:
log_once(logger.info, f"Using exllama kernels v{HAS_EXLLAMA}")
if use_exllama and gptq_params.groupsize != -1:
qzeros = self.get_sharded(f"{prefix}.qzeros", dim=0)
scales = self.get_sharded(f"{prefix}.scales", dim=0)
else:
qzeros = self.get_tensor(f"{prefix}.qzeros")
scales = self.get_tensor(f"{prefix}.scales")
if use_exllama and g_idx is not None:
g_idx = g_idx - g_idx[0]
if gptq_params.quant_method == "awq":
log_once(
logger.info, "Converting AWQ model to Exllama/GPTQ packing format."
)
from text_generation_server.layers.awq.conversion_utils import (
fast_awq_to_gptq,
)
qweight, qzeros = fast_awq_to_gptq(qweight, qzeros)
if use_exllama:
g_idx = None
else:
g_idx = (
torch.arange(
qweight.shape[0] * (32 // gptq_params.bits),
device=qweight.device,
)
// gptq_params.groupsize
).to(dtype=torch.int32)
weight = GPTQWeight(
qweight=qweight,
qzeros=qzeros,
scales=scales,
g_idx=g_idx,
bits=gptq_params.bits,
groupsize=gptq_params.groupsize,
use_exllama=use_exllama,
)
elif quantize == "awq":
from text_generation_server.layers.gptq import GPTQWeight
gptq_params = self._get_gptq_params()
try:
qweight = self.get_sharded(f"{prefix}.qweight", dim=0)
except RuntimeError:
raise RuntimeError(
"Cannot load `awq` weight, make sure the model is already quantized"
)
qzeros = self.get_sharded(f"{prefix}.qzeros", dim=0)
scales = self.get_sharded(f"{prefix}.scales", dim=0)
g_idx = None
use_exllama = False
weight = GPTQWeight(
qweight=qweight,
qzeros=qzeros,
scales=scales,
g_idx=g_idx,
bits=gptq_params.bits,
groupsize=gptq_params.groupsize,
use_exllama=use_exllama,
)
elif quantize == "marlin":
from text_generation_server.layers.gptq import GPTQWeight
from text_generation_server.layers.marlin import (
GPTQMarlin24Weight,
MarlinWeight,
)
is_marlin_24 = getattr(self, "gptq_checkpoint_format", None) == "marlin_24"
if is_marlin_24:
try:
B = self.get_sharded(f"{prefix}.B_24", dim=0)
except RuntimeError:
raise RuntimeError(
"Cannot load `marlin` 2:4 sparsity weight, make sure the model is already quantized."
)
B_meta = self.get_sharded(f"{prefix}.B_meta", dim=0)
num_groups = self._get_slice(f"{prefix}.s").get_shape()[0]
if num_groups == 1:
# The number of groups is 1 when groupsize == -1. share
# scales between all shards in this case.
s = self.get_tensor(f"{prefix}.s")
else:
s = self.get_sharded(f"{prefix}.s", dim=0)
gptq_params = self._get_gptq_params()
weight = GPTQMarlin24Weight(
B=B, B_meta=B_meta, s=s, bits=gptq_params.bits
)
else:
try:
B = self.get_sharded(f"{prefix}.B", dim=0)
except RuntimeError:
raise RuntimeError(
"Cannot load `marlin` weight, make sure the model is already quantized."
)
num_groups = self._get_slice(f"{prefix}.s").get_shape()[0]
if num_groups == 1:
# The number of groups is 1 when groupsize == -1. share
# scales between all shards in this case.
s = self.get_tensor(f"{prefix}.s")
else:
s = self.get_sharded(f"{prefix}.s", dim=0)
weight = MarlinWeight(B=B, s=s)
else:
weight = self.get_sharded(f"{prefix}.weight", dim=1)
return weight
def _get_gptq_params(self) -> GPTQParams:
try:
bits = self.get_tensor("gptq_bits").item()
groupsize = self.get_tensor("gptq_groupsize").item()
checkpoint_format = getattr(self, "gptq_checkpoint_format", None)
desc_act = False
sym = False
quant_method = "gptq"
except (SafetensorError, RuntimeError) as e:
try:
bits = self.gptq_bits
groupsize = self.gptq_groupsize
checkpoint_format = getattr(self, "gptq_checkpoint_format", None)
desc_act = getattr(self, "gptq_desc_act", False)
quant_method = getattr(self, "quant_method", "gptq")
sym = getattr(self, "sym", True)
except Exception:
raise e
return GPTQParams(
bits=bits,
checkpoint_format=checkpoint_format,
desc_act=desc_act,
groupsize=groupsize,
quant_method=quant_method,
sym=sym,
)
def _set_gptq_params(self, model_id, revision):
filename = "config.json"
try:
if os.path.exists(os.path.join(model_id, filename)):
filename = os.path.join(model_id, filename)
else:
filename = hf_hub_download(
model_id, filename=filename, revision=revision
)
with open(filename, "r") as f:
data = json.load(f)
self.gptq_bits = data["quantization_config"]["bits"]
self.gptq_groupsize = data["quantization_config"]["group_size"]
# Order is important here, desc_act is missing on some real models
self.quant_method = data["quantization_config"]["quant_method"]
self.gptq_checkpoint_format = data["quantization_config"].get(
"checkpoint_format"
)
self.gptq_sym = data["quantization_config"]["sym"]
self.gptq_desc_act = data["quantization_config"]["desc_act"]
except Exception:
filename = "quantize_config.json"
try:
if os.path.exists(os.path.join(model_id, filename)):
filename = os.path.join(model_id, filename)
else:
filename = hf_hub_download(
model_id, filename=filename, revision=revision
)
with open(filename, "r") as f:
data = json.load(f)
self.gptq_bits = data["bits"]
self.gptq_groupsize = data["group_size"]
self.gptq_sym = data["sym"]
self.gptq_desc_act = data["desc_act"]
if "version" in data and data["version"] == "GEMM":
self.quant_method = "awq"
except Exception:
filename = "quant_config.json"
try:
if os.path.exists(os.path.join(model_id, filename)):
filename = os.path.join(model_id, filename)
else:
filename = hf_hub_download(
model_id, filename=filename, revision=revision
)
with open(filename, "r") as f:
data = json.load(f)
self.gptq_bits = data["w_bit"]
self.gptq_groupsize = data["q_group_size"]
self.gptq_desc_act = data["desc_act"]
if "version" in data and data["version"] == "GEMM":
self.quant_method = "awq"
except Exception:
pass
def _blocks_to_block_sizes(total_size: int, blocks: Union[int, List[int]]) -> List[int]:
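
The row-wise path changes in the same way as the column-wise one: the removed get_multi_weights_row(prefix, quantize) with all of its quantizer branches becomes get_weights_row(prefix), which forwards to the configured loader. A short sketch; the helper name and prefix are invented.

from text_generation_server.utils.weights import Weights


def load_row_parallel(weights: Weights, prefix: str = "model.layers.0.mlp.down_proj"):
    # Formerly: weights.get_multi_weights_row(prefix, quantize)
    # The Weights object now forwards to its loader's get_weights_row().
    return weights.get_weights_row(prefix)
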