Move quantized weight handling out of the `Weights` class (#2194)
Quantized weights were loaded in the `Weights` class, but this was getting quite unwieldy, where every higher level method to load weights was a long conditional to cover all the different quantizers. This change moves loading of quantized weights out of the `Weights` class. This is done by defining a simple `WeightsLoader` interface that is implemented by `Exl2WeightsLoader`, `GPTQWeightsLoader`, and `MarlinWeightsLoader`. These implementations are in the quantizers' respective modules. The `Weights` class provides the low-level load operations (such as loading tensors or sharded tensors), but delegates loads that need quantizer-specific weight processing to a loader. The loaders still use the low-level functionality provided by `Weights`. I initially tried making a hierarchy where a class like `GPTQWeights` would inherit from `Weights`. But it is not very flexible (e.g. does not work well with the new weight storage mock used in tests) and the implicit indirections made the code harder to follow.
This commit is contained in:
parent
4c976fb406
commit
8511669cb2
|
@ -2,6 +2,7 @@ import torch
|
|||
from text_generation_server.layers import (
|
||||
TensorParallelEmbedding,
|
||||
)
|
||||
from text_generation_server.utils.weights import DefaultWeightsLoader
|
||||
|
||||
|
||||
class ProcessGroup:
|
||||
|
@ -42,7 +43,12 @@ class Weights:
|
|||
def test_weight_hub_files_offline_error():
|
||||
|
||||
vocab_size = 17
|
||||
weights = Weights(rank=0, world_size=1, vocab_size=vocab_size, hidden_dim=256)
|
||||
weights = Weights(
|
||||
rank=0,
|
||||
world_size=1,
|
||||
vocab_size=vocab_size,
|
||||
hidden_dim=256,
|
||||
)
|
||||
embeddings = TensorParallelEmbedding("", weights)
|
||||
|
||||
input_ids = torch.arange(vocab_size)
|
||||
|
|
|
@ -1,13 +1,47 @@
|
|||
import pytest
|
||||
import torch
|
||||
from text_generation_server.utils.weights import Weights
|
||||
from text_generation_server.layers.gptq import GPTQWeight
|
||||
from text_generation_server.layers.exl2 import Exl2Weight
|
||||
from text_generation_server.layers.marlin import MarlinWeight
|
||||
from text_generation_server.utils.weights import (
|
||||
DefaultWeightsLoader,
|
||||
Weights,
|
||||
WeightsLoader,
|
||||
)
|
||||
from text_generation_server.layers.gptq import GPTQWeight, GPTQWeightsLoader
|
||||
from text_generation_server.layers.exl2 import Exl2Weight, Exl2WeightsLoader
|
||||
from text_generation_server.layers.marlin import MarlinWeight, MarlinWeightsLoader
|
||||
from types import SimpleNamespace
|
||||
from typing import List, Optional, Dict, Union
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def gptq_weights_loader():
|
||||
return GPTQWeightsLoader(
|
||||
bits=4,
|
||||
groupsize=-1,
|
||||
desc_act=False,
|
||||
quant_method="gptq",
|
||||
quantize="gptq",
|
||||
sym=True,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def gptq_weights_loader_awq():
|
||||
return GPTQWeightsLoader(
|
||||
bits=4,
|
||||
groupsize=-1,
|
||||
desc_act=False,
|
||||
quant_method="awq",
|
||||
quantize="awq",
|
||||
sym=True,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def marlin_weights_loader():
|
||||
return MarlinWeightsLoader(bits=4, is_marlin_24=False)
|
||||
|
||||
|
||||
dummy_file_system = {
|
||||
"test_weights": {
|
||||
"layer.0.weight": torch.tensor(
|
||||
|
@ -58,7 +92,7 @@ dummy_file_system = {
|
|||
dtype=torch.float32,
|
||||
),
|
||||
},
|
||||
"test_get_multi_weights_row": {
|
||||
"test_get_weights_row": {
|
||||
"weight.weight": torch.tensor(
|
||||
[
|
||||
[1, 2],
|
||||
|
@ -101,7 +135,7 @@ dummy_file_system = {
|
|||
"weight.B": torch.tensor([[1, 2], [3, 4]], dtype=torch.int32),
|
||||
"weight.s": torch.tensor([[0.5000], [0.2500]], dtype=torch.float16),
|
||||
},
|
||||
"test_get_multi_weights_row_gptq": {
|
||||
"test_get_weights_row_gptq": {
|
||||
"weight.qweight": torch.tensor(
|
||||
[
|
||||
[1, 2],
|
||||
|
@ -200,7 +234,7 @@ dummy_file_system = {
|
|||
"weight.q_scale_max": torch.tensor([100], dtype=torch.float16),
|
||||
"weight.q_groups": torch.tensor([4], dtype=torch.int16),
|
||||
},
|
||||
"test_get_multi_weights_row_exl2": {
|
||||
"test_get_weights_row_exl2": {
|
||||
"weight.q_weight": torch.tensor(
|
||||
[
|
||||
[1, 2],
|
||||
|
@ -245,7 +279,7 @@ dummy_file_system = {
|
|||
"weight.q_scale_max": torch.tensor([100], dtype=torch.float16),
|
||||
"weight.q_groups": torch.tensor([4], dtype=torch.int16),
|
||||
},
|
||||
"test_get_multi_weights_row_marlin": {
|
||||
"test_get_weights_row_marlin": {
|
||||
"weight.B": torch.tensor([[1, 2], [3, 4]], dtype=torch.int32),
|
||||
"weight.s": torch.tensor([[0.5], [0.25]], dtype=torch.float16),
|
||||
},
|
||||
|
@ -308,6 +342,7 @@ class MockWeights(Weights):
|
|||
dummy_fs,
|
||||
aliases: Optional[Dict[str, List[str]]] = None,
|
||||
prefix: Optional[str] = None,
|
||||
weights_loader: Optional[WeightsLoader] = None,
|
||||
):
|
||||
routing = {}
|
||||
self.dummy_fs = dummy_fs
|
||||
|
@ -327,6 +362,9 @@ class MockWeights(Weights):
|
|||
self.dtype = dtype
|
||||
self.process_group = process_group
|
||||
self.prefix = prefix
|
||||
self.weights_loader = (
|
||||
DefaultWeightsLoader() if weights_loader is None else weights_loader
|
||||
)
|
||||
self._handles = {}
|
||||
|
||||
def _get_handle(self, filename: Union[Path, str]):
|
||||
|
@ -412,12 +450,10 @@ def test_get_weights_col_packed():
|
|||
)
|
||||
|
||||
prefix = "weight"
|
||||
quantize = None
|
||||
block_sizes = 1
|
||||
|
||||
w = weights.get_weights_col_packed(
|
||||
prefix=prefix,
|
||||
quantize=quantize,
|
||||
block_sizes=block_sizes,
|
||||
)
|
||||
|
||||
|
@ -448,12 +484,10 @@ def test_get_weights_col_packed_block_size():
|
|||
)
|
||||
|
||||
prefix = "weight"
|
||||
quantize = None
|
||||
block_sizes = 2
|
||||
|
||||
w = weights.get_weights_col_packed(
|
||||
prefix=prefix,
|
||||
quantize=quantize,
|
||||
block_sizes=block_sizes,
|
||||
)
|
||||
|
||||
|
@ -484,12 +518,10 @@ def test_get_weights_col_packed_block_size_arr():
|
|||
)
|
||||
|
||||
prefix = "weight"
|
||||
quantize = None
|
||||
block_sizes = [1, 1]
|
||||
|
||||
w = weights.get_weights_col_packed(
|
||||
prefix=prefix,
|
||||
quantize=quantize,
|
||||
block_sizes=block_sizes,
|
||||
)
|
||||
|
||||
|
@ -519,11 +551,9 @@ def test_get_multi_weights_col():
|
|||
)
|
||||
|
||||
prefixes = ["weight", "weight"]
|
||||
quantize = None
|
||||
|
||||
w = weights.get_multi_weights_col(
|
||||
prefixes=prefixes,
|
||||
quantize=quantize,
|
||||
dim=0,
|
||||
)
|
||||
|
||||
|
@ -545,10 +575,10 @@ def test_get_multi_weights_col():
|
|||
)
|
||||
|
||||
|
||||
def test_get_multi_weights_row():
|
||||
def test_get_weights_row():
|
||||
weights = MockWeights(
|
||||
[
|
||||
"test_get_multi_weights_row",
|
||||
"test_get_weights_row",
|
||||
],
|
||||
device="cpu",
|
||||
dtype=torch.float32,
|
||||
|
@ -557,11 +587,9 @@ def test_get_multi_weights_row():
|
|||
)
|
||||
|
||||
prefix = "weight"
|
||||
quantize = None
|
||||
|
||||
w = weights.get_multi_weights_row(
|
||||
w = weights.get_weights_row(
|
||||
prefix=prefix,
|
||||
quantize=quantize,
|
||||
)
|
||||
|
||||
assert torch.allclose(
|
||||
|
@ -576,7 +604,7 @@ def test_get_multi_weights_row():
|
|||
# test_get_weights_col
|
||||
|
||||
|
||||
def test_get_weights_col_awq():
|
||||
def test_get_weights_col_awq(gptq_weights_loader_awq):
|
||||
weights = MockWeights(
|
||||
[
|
||||
"test_get_weights_col_gptq",
|
||||
|
@ -585,14 +613,13 @@ def test_get_weights_col_awq():
|
|||
dtype=torch.float32,
|
||||
process_group=dummy_process_group,
|
||||
dummy_fs=dummy_file_system,
|
||||
weights_loader=gptq_weights_loader_awq,
|
||||
)
|
||||
|
||||
prefix = "weight"
|
||||
quantize = "awq"
|
||||
|
||||
w = weights.get_weights_col(
|
||||
prefix=prefix,
|
||||
quantize=quantize,
|
||||
)
|
||||
|
||||
expected_weight = GPTQWeight(
|
||||
|
@ -617,7 +644,7 @@ def test_get_weights_col_awq():
|
|||
assert w.use_exllama == expected_weight.use_exllama, "use_exllama mismatch"
|
||||
|
||||
|
||||
def test_get_weights_col_gtpq():
|
||||
def test_get_weights_col_gtpq(gptq_weights_loader):
|
||||
weights = MockWeights(
|
||||
[
|
||||
"test_get_weights_col_gptq",
|
||||
|
@ -626,14 +653,13 @@ def test_get_weights_col_gtpq():
|
|||
dtype=torch.float32,
|
||||
process_group=dummy_process_group,
|
||||
dummy_fs=dummy_file_system,
|
||||
weights_loader=gptq_weights_loader,
|
||||
)
|
||||
|
||||
prefix = "weight"
|
||||
quantize = "gptq"
|
||||
|
||||
w = weights.get_weights_col(
|
||||
prefix=prefix,
|
||||
quantize=quantize,
|
||||
)
|
||||
|
||||
expected_weight = GPTQWeight(
|
||||
|
@ -664,14 +690,13 @@ def test_get_weights_col_exl2():
|
|||
dtype=torch.float32,
|
||||
process_group=dummy_process_group,
|
||||
dummy_fs=dummy_file_system,
|
||||
weights_loader=Exl2WeightsLoader(),
|
||||
)
|
||||
|
||||
prefix = "weight"
|
||||
quantize = "exl2"
|
||||
|
||||
w = weights.get_weights_col(
|
||||
prefix=prefix,
|
||||
quantize=quantize,
|
||||
)
|
||||
|
||||
scaled_scale_max = 0.3906 * 256
|
||||
|
@ -692,7 +717,7 @@ def test_get_weights_col_exl2():
|
|||
assert torch.allclose(w.q_groups, expected_weight.q_groups), "q_groups mismatch"
|
||||
|
||||
|
||||
def test_get_weights_col_marlin():
|
||||
def test_get_weights_col_marlin(marlin_weights_loader):
|
||||
weights = MockWeights(
|
||||
[
|
||||
"test_get_weights_col_marlin",
|
||||
|
@ -701,14 +726,13 @@ def test_get_weights_col_marlin():
|
|||
dtype=torch.float16,
|
||||
process_group=dummy_process_group,
|
||||
dummy_fs=dummy_file_system,
|
||||
weights_loader=marlin_weights_loader,
|
||||
)
|
||||
|
||||
prefix = "weight"
|
||||
quantize = "marlin"
|
||||
|
||||
w = weights.get_weights_col(
|
||||
prefix=prefix,
|
||||
quantize=quantize,
|
||||
)
|
||||
|
||||
expected_weight = MarlinWeight(
|
||||
|
@ -723,7 +747,7 @@ def test_get_weights_col_marlin():
|
|||
# test_get_weights_col_packed
|
||||
|
||||
|
||||
def test_get_weights_col_packed_awq():
|
||||
def test_get_weights_col_packed_awq(gptq_weights_loader_awq):
|
||||
weights = MockWeights(
|
||||
[
|
||||
"test_get_weights_col_packed_gptq",
|
||||
|
@ -732,15 +756,14 @@ def test_get_weights_col_packed_awq():
|
|||
dtype=torch.float32,
|
||||
process_group=dummy_process_group,
|
||||
dummy_fs=dummy_file_system,
|
||||
weights_loader=gptq_weights_loader_awq,
|
||||
)
|
||||
|
||||
prefix = "weight"
|
||||
quantize = "awq"
|
||||
block_sizes = 1
|
||||
|
||||
w = weights.get_weights_col_packed(
|
||||
prefix=prefix,
|
||||
quantize=quantize,
|
||||
block_sizes=block_sizes,
|
||||
)
|
||||
|
||||
|
@ -773,15 +796,14 @@ def test_get_weights_col_packed_exl2():
|
|||
dtype=torch.float32,
|
||||
process_group=dummy_process_group,
|
||||
dummy_fs=dummy_file_system,
|
||||
weights_loader=Exl2WeightsLoader(),
|
||||
)
|
||||
|
||||
prefix = "weight"
|
||||
quantize = "exl2"
|
||||
block_sizes = 1
|
||||
|
||||
w = weights.get_weights_col_packed(
|
||||
prefix=prefix,
|
||||
quantize=quantize,
|
||||
block_sizes=block_sizes,
|
||||
)
|
||||
|
||||
|
@ -803,7 +825,7 @@ def test_get_weights_col_packed_exl2():
|
|||
assert torch.allclose(w.q_groups, expected_weight.q_groups), "q_groups mismatch"
|
||||
|
||||
|
||||
def test_get_weights_col_packed_gptq():
|
||||
def test_get_weights_col_packed_gptq(gptq_weights_loader):
|
||||
weights = MockWeights(
|
||||
[
|
||||
"test_get_weights_col_packed_gptq",
|
||||
|
@ -812,14 +834,13 @@ def test_get_weights_col_packed_gptq():
|
|||
dtype=torch.float32,
|
||||
process_group=dummy_process_group,
|
||||
dummy_fs=dummy_file_system,
|
||||
weights_loader=gptq_weights_loader,
|
||||
)
|
||||
|
||||
prefixes = ["weight"]
|
||||
quantize = "gptq"
|
||||
|
||||
w = weights.get_multi_weights_col(
|
||||
prefixes=prefixes,
|
||||
quantize=quantize,
|
||||
dim=0,
|
||||
)
|
||||
|
||||
|
@ -842,7 +863,7 @@ def test_get_weights_col_packed_gptq():
|
|||
assert w.use_exllama == expected_weight.use_exllama, "use_exllama mismatch"
|
||||
|
||||
|
||||
def test_get_weights_col_packed_marlin():
|
||||
def test_get_weights_col_packed_marlin(marlin_weights_loader):
|
||||
weights = MockWeights(
|
||||
[
|
||||
"test_get_weights_col_packed_marlin",
|
||||
|
@ -851,14 +872,13 @@ def test_get_weights_col_packed_marlin():
|
|||
dtype=torch.float16,
|
||||
process_group=dummy_process_group,
|
||||
dummy_fs=dummy_file_system,
|
||||
weights_loader=marlin_weights_loader,
|
||||
)
|
||||
|
||||
prefix = "weight"
|
||||
quantize = "marlin"
|
||||
|
||||
w = weights.get_multi_weights_col(
|
||||
prefixes=[prefix],
|
||||
quantize=quantize,
|
||||
dim=0,
|
||||
)
|
||||
|
||||
|
@ -876,7 +896,7 @@ def test_get_weights_col_packed_marlin():
|
|||
# test_get_multi_weights_col
|
||||
|
||||
|
||||
def test_get_multi_weights_col_awq():
|
||||
def test_get_multi_weights_col_awq(gptq_weights_loader_awq):
|
||||
weights = MockWeights(
|
||||
[
|
||||
"test_get_multi_weights_col_gptq",
|
||||
|
@ -885,14 +905,13 @@ def test_get_multi_weights_col_awq():
|
|||
dtype=torch.float32,
|
||||
process_group=dummy_process_group,
|
||||
dummy_fs=dummy_file_system,
|
||||
weights_loader=gptq_weights_loader_awq,
|
||||
)
|
||||
|
||||
prefixes = ["weight"]
|
||||
quantize = "awq"
|
||||
|
||||
w = weights.get_multi_weights_col(
|
||||
prefixes=prefixes,
|
||||
quantize=quantize,
|
||||
dim=0,
|
||||
)
|
||||
|
||||
|
@ -924,22 +943,21 @@ def test_get_multi_weights_col_exl2():
|
|||
dtype=torch.float32,
|
||||
process_group=dummy_process_group,
|
||||
dummy_fs=dummy_file_system,
|
||||
weights_loader=Exl2WeightsLoader(),
|
||||
)
|
||||
|
||||
prefix = "weight"
|
||||
quantize = "exl2"
|
||||
|
||||
try:
|
||||
w = weights.get_multi_weights_col(
|
||||
prefixes=[prefix],
|
||||
quantize=quantize,
|
||||
dim=0,
|
||||
)
|
||||
except ValueError as e:
|
||||
assert e.args[0] == "get_multi_weights_col is not supported for exl2"
|
||||
|
||||
|
||||
def test_get_multi_weights_col_gptq():
|
||||
def test_get_multi_weights_col_gptq(gptq_weights_loader):
|
||||
weights = MockWeights(
|
||||
[
|
||||
"test_get_multi_weights_col_gptq",
|
||||
|
@ -948,14 +966,13 @@ def test_get_multi_weights_col_gptq():
|
|||
dtype=torch.float32,
|
||||
process_group=dummy_process_group,
|
||||
dummy_fs=dummy_file_system,
|
||||
weights_loader=gptq_weights_loader,
|
||||
)
|
||||
|
||||
prefixes = ["weight"]
|
||||
quantize = "gptq"
|
||||
|
||||
w = weights.get_multi_weights_col(
|
||||
prefixes=prefixes,
|
||||
quantize=quantize,
|
||||
dim=0,
|
||||
)
|
||||
|
||||
|
@ -978,7 +995,7 @@ def test_get_multi_weights_col_gptq():
|
|||
assert w.use_exllama == expected_weight.use_exllama, "use_exllama mismatch"
|
||||
|
||||
|
||||
def test_get_multi_weights_col_marlin():
|
||||
def test_get_multi_weights_col_marlin(marlin_weights_loader):
|
||||
weights = MockWeights(
|
||||
[
|
||||
"test_get_multi_weights_col_marlin",
|
||||
|
@ -987,14 +1004,13 @@ def test_get_multi_weights_col_marlin():
|
|||
dtype=torch.float16,
|
||||
process_group=dummy_process_group,
|
||||
dummy_fs=dummy_file_system,
|
||||
weights_loader=marlin_weights_loader,
|
||||
)
|
||||
|
||||
prefix = "weight"
|
||||
quantize = "marlin"
|
||||
|
||||
w = weights.get_multi_weights_col(
|
||||
prefixes=[prefix],
|
||||
quantize=quantize,
|
||||
dim=0,
|
||||
)
|
||||
|
||||
|
@ -1007,26 +1023,25 @@ def test_get_multi_weights_col_marlin():
|
|||
assert torch.allclose(w.s, expected_weight.s), "s mismatch"
|
||||
|
||||
|
||||
# test_get_multi_weights_row
|
||||
# test_get_weights_row
|
||||
|
||||
|
||||
def test_get_multi_weights_row_awq():
|
||||
def test_get_weights_row_awq(gptq_weights_loader_awq):
|
||||
weights = MockWeights(
|
||||
[
|
||||
"test_get_multi_weights_row_gptq",
|
||||
"test_get_weights_row_gptq",
|
||||
],
|
||||
device="cpu",
|
||||
dtype=torch.float32,
|
||||
process_group=dummy_process_group,
|
||||
dummy_fs=dummy_file_system,
|
||||
weights_loader=gptq_weights_loader_awq,
|
||||
)
|
||||
|
||||
prefix = "weight"
|
||||
quantize = "awq"
|
||||
|
||||
w = weights.get_multi_weights_row(
|
||||
w = weights.get_weights_row(
|
||||
prefix=prefix,
|
||||
quantize=quantize,
|
||||
)
|
||||
|
||||
expected_weight = GPTQWeight(
|
||||
|
@ -1048,23 +1063,22 @@ def test_get_multi_weights_row_awq():
|
|||
assert w.use_exllama == expected_weight.use_exllama, "use_exllama mismatch"
|
||||
|
||||
|
||||
def test_get_multi_weights_row_exl2():
|
||||
def test_get_weights_row_exl2():
|
||||
weights = MockWeights(
|
||||
[
|
||||
"test_get_multi_weights_row_exl2",
|
||||
"test_get_weights_row_exl2",
|
||||
],
|
||||
device="cpu",
|
||||
dtype=torch.float32,
|
||||
process_group=dummy_process_group,
|
||||
dummy_fs=dummy_file_system,
|
||||
weights_loader=Exl2WeightsLoader(),
|
||||
)
|
||||
|
||||
prefix = "weight"
|
||||
quantize = "exl2"
|
||||
|
||||
w = weights.get_multi_weights_row(
|
||||
w = weights.get_weights_row(
|
||||
prefix=prefix,
|
||||
quantize=quantize,
|
||||
)
|
||||
print(w)
|
||||
|
||||
|
@ -1086,23 +1100,22 @@ def test_get_multi_weights_row_exl2():
|
|||
assert torch.allclose(w.q_groups, expected_weight.q_groups), "q_groups mismatch"
|
||||
|
||||
|
||||
def test_get_multi_weights_row_gptq():
|
||||
def test_get_weights_row_gptq(gptq_weights_loader):
|
||||
weights = MockWeights(
|
||||
[
|
||||
"test_get_multi_weights_row_gptq",
|
||||
"test_get_weights_row_gptq",
|
||||
],
|
||||
device="cpu",
|
||||
dtype=torch.float32,
|
||||
process_group=dummy_process_group,
|
||||
dummy_fs=dummy_file_system,
|
||||
weights_loader=gptq_weights_loader,
|
||||
)
|
||||
|
||||
prefix = "weight"
|
||||
quantize = "gptq"
|
||||
|
||||
w = weights.get_multi_weights_row(
|
||||
w = weights.get_weights_row(
|
||||
prefix=prefix,
|
||||
quantize=quantize,
|
||||
)
|
||||
|
||||
expected_weight = GPTQWeight(
|
||||
|
@ -1124,23 +1137,22 @@ def test_get_multi_weights_row_gptq():
|
|||
assert w.use_exllama == expected_weight.use_exllama, "use_exllama mismatch"
|
||||
|
||||
|
||||
def test_get_multi_weights_row_marlin():
|
||||
def test_get_weights_row_marlin(marlin_weights_loader):
|
||||
weights = MockWeights(
|
||||
[
|
||||
"test_get_multi_weights_row_marlin",
|
||||
"test_get_weights_row_marlin",
|
||||
],
|
||||
device="cpu",
|
||||
dtype=torch.float16,
|
||||
process_group=dummy_process_group,
|
||||
dummy_fs=dummy_file_system,
|
||||
weights_loader=marlin_weights_loader,
|
||||
)
|
||||
|
||||
prefix = "weight"
|
||||
quantize = "marlin"
|
||||
|
||||
w = weights.get_multi_weights_row(
|
||||
w = weights.get_weights_row(
|
||||
prefix=prefix,
|
||||
quantize=quantize,
|
||||
)
|
||||
|
||||
expected_weight = MarlinWeight(
|
||||
|
|
|
@ -1,6 +1,9 @@
|
|||
import torch
|
||||
from typing import List, Union
|
||||
from dataclasses import dataclass
|
||||
|
||||
from text_generation_server.utils.weights import WeightsLoader, Weights
|
||||
|
||||
|
||||
@dataclass
|
||||
class Exl2Weight:
|
||||
|
@ -21,3 +24,60 @@ class Exl2Weight:
|
|||
@property
|
||||
def device(self) -> torch.device:
|
||||
return self.q_weight.device
|
||||
|
||||
|
||||
class Exl2WeightsLoader(WeightsLoader):
|
||||
"""Loader for exl2-quantized weights."""
|
||||
|
||||
def get_weights_col_packed(
|
||||
self,
|
||||
weights: Weights,
|
||||
prefix: str,
|
||||
block_sizes: Union[int, List[int]],
|
||||
):
|
||||
raise RuntimeError("Column-packed weights are not supported for exl")
|
||||
|
||||
def get_weights_col(self, weights: Weights, prefix: str):
|
||||
try:
|
||||
q_weight = weights.get_tensor(f"{prefix}.q_weight")
|
||||
except RuntimeError:
|
||||
raise RuntimeError(
|
||||
"Cannot load `exl2`-quantized weight, make sure the model is already quantized."
|
||||
)
|
||||
|
||||
q_scale = weights.get_tensor(f"{prefix}.q_scale")
|
||||
q_invperm = weights.get_tensor(f"{prefix}.q_invperm")
|
||||
q_scale_max = weights.get_tensor(f"{prefix}.q_scale_max")
|
||||
q_groups = weights.get_tensor(f"{prefix}.q_groups")
|
||||
|
||||
return Exl2Weight(
|
||||
q_weight=q_weight,
|
||||
q_scale=q_scale,
|
||||
q_invperm=q_invperm,
|
||||
q_scale_max=q_scale_max,
|
||||
q_groups=q_groups,
|
||||
)
|
||||
|
||||
def get_multi_weights_col(self, weights: Weights, prefixes: List[str], dim: int):
|
||||
raise ValueError("get_multi_weights_col is not supported for exl2")
|
||||
|
||||
def get_weights_row(self, weights: Weights, prefix: str):
|
||||
try:
|
||||
q_weight = weights.get_tensor(f"{prefix}.q_weight")
|
||||
except RuntimeError:
|
||||
raise RuntimeError(
|
||||
"Cannot load `exl2`-quantized weight, make sure the model is already quantized."
|
||||
)
|
||||
|
||||
q_scale = weights.get_tensor(f"{prefix}.q_scale")
|
||||
q_invperm = weights.get_tensor(f"{prefix}.q_invperm")
|
||||
q_scale_max = weights.get_tensor(f"{prefix}.q_scale_max")
|
||||
q_groups = weights.get_tensor(f"{prefix}.q_groups")
|
||||
|
||||
return Exl2Weight(
|
||||
q_weight=q_weight,
|
||||
q_scale=q_scale,
|
||||
q_invperm=q_invperm,
|
||||
q_scale_max=q_scale_max,
|
||||
q_groups=q_groups,
|
||||
)
|
||||
|
|
|
@ -1,20 +1,14 @@
|
|||
from dataclasses import dataclass
|
||||
from loguru import logger
|
||||
import os
|
||||
from typing import Optional
|
||||
from typing import List, Optional, Union
|
||||
from safetensors import SafetensorError
|
||||
from text_generation_server.utils.weights import Weights, WeightsLoader
|
||||
import torch
|
||||
from text_generation_server.utils.import_utils import (
|
||||
SYSTEM,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class GPTQParams:
|
||||
bits: int
|
||||
checkpoint_format: Optional[str]
|
||||
groupsize: int
|
||||
desc_act: bool
|
||||
quant_method: str
|
||||
sym: bool
|
||||
from text_generation_server.utils.log import log_once
|
||||
|
||||
|
||||
@dataclass
|
||||
|
@ -69,3 +63,341 @@ elif CAN_EXLLAMA:
|
|||
pass
|
||||
|
||||
from text_generation_server.layers.gptq.quant_linear import QuantLinear
|
||||
|
||||
|
||||
class GPTQWeightsLoader(WeightsLoader):
|
||||
"""
|
||||
Loader for GPTQ- and AWQ-quantized weights.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
bits: int,
|
||||
desc_act: bool,
|
||||
groupsize: int,
|
||||
quant_method: str,
|
||||
quantize: str,
|
||||
sym: bool,
|
||||
):
|
||||
self.bits = bits
|
||||
self.desc_act = desc_act
|
||||
self.groupsize = groupsize
|
||||
self.quant_method = quant_method
|
||||
self.quantize = quantize
|
||||
self.sym = sym
|
||||
|
||||
def get_weights_col_packed(
|
||||
self,
|
||||
weights: Weights,
|
||||
prefix: str,
|
||||
block_sizes: Union[int, List[int]],
|
||||
):
|
||||
from text_generation_server.layers.marlin import (
|
||||
can_use_gptq_marlin,
|
||||
repack_gptq_for_marlin,
|
||||
)
|
||||
|
||||
try:
|
||||
qweight = weights.get_packed_sharded(
|
||||
f"{prefix}.qweight", dim=1, block_sizes=block_sizes
|
||||
)
|
||||
except RuntimeError:
|
||||
raise RuntimeError(
|
||||
f"Cannot load `{self.quantize}` weight, make sure the model is already quantized."
|
||||
)
|
||||
scales = weights.get_packed_sharded(
|
||||
f"{prefix}.scales", dim=1, block_sizes=block_sizes
|
||||
)
|
||||
scales = scales.to(dtype=weights.dtype)
|
||||
|
||||
self._get_gptq_params(weights)
|
||||
if can_use_gptq_marlin(
|
||||
bits=self.bits,
|
||||
groupsize=self.groupsize,
|
||||
quant_method=self.quant_method,
|
||||
quantize=self.quantize,
|
||||
sym=self.sym,
|
||||
):
|
||||
g_idx = weights.get_tensor(f"{prefix}.g_idx")
|
||||
return repack_gptq_for_marlin(
|
||||
qweight=qweight,
|
||||
scales=scales,
|
||||
g_idx=g_idx,
|
||||
bits=self.bits,
|
||||
desc_act=self.desc_act,
|
||||
groupsize=self.groupsize,
|
||||
sym=self.sym,
|
||||
sharded_infeatures=False,
|
||||
)
|
||||
|
||||
qzeros = weights.get_packed_sharded(
|
||||
f"{prefix}.qzeros", dim=1, block_sizes=block_sizes
|
||||
)
|
||||
if self.quantize == "gptq" and self.quant_method == "gptq":
|
||||
g_idx = weights.get_tensor(f"{prefix}.g_idx")
|
||||
elif self.quantize == "gptq" and self.quant_method == "awq":
|
||||
log_once(
|
||||
logger.info, "Converting AWQ model to Exllama/GPTQ packing format."
|
||||
)
|
||||
from text_generation_server.layers.awq.conversion_utils import (
|
||||
fast_awq_to_gptq,
|
||||
)
|
||||
|
||||
qweight, qzeros = fast_awq_to_gptq(qweight, qzeros)
|
||||
g_idx = (
|
||||
torch.arange(
|
||||
qweight.shape[0] * (32 // self.bits),
|
||||
device=qweight.device,
|
||||
)
|
||||
// self.groupsize
|
||||
).to(dtype=torch.int32)
|
||||
else:
|
||||
g_idx = None
|
||||
|
||||
return GPTQWeight(
|
||||
qweight=qweight,
|
||||
qzeros=qzeros,
|
||||
scales=scales,
|
||||
g_idx=g_idx,
|
||||
bits=self.bits,
|
||||
groupsize=self.groupsize,
|
||||
use_exllama=False,
|
||||
)
|
||||
|
||||
def get_multi_weights_col(self, weights: Weights, prefixes: List[str], dim: int):
|
||||
from text_generation_server.layers.marlin import (
|
||||
can_use_gptq_marlin,
|
||||
repack_gptq_for_marlin,
|
||||
)
|
||||
|
||||
try:
|
||||
qweight = torch.cat(
|
||||
[weights.get_sharded(f"{p}.qweight", dim=1) for p in prefixes], dim=1
|
||||
)
|
||||
except RuntimeError:
|
||||
raise RuntimeError(
|
||||
f"Cannot load `{self.quantize}` weight, make sure the model is already quantized"
|
||||
)
|
||||
|
||||
scales = torch.cat(
|
||||
[weights.get_sharded(f"{p}.scales", dim=1) for p in prefixes], dim=1
|
||||
)
|
||||
|
||||
self._get_gptq_params(weights)
|
||||
if can_use_gptq_marlin(
|
||||
bits=self.bits,
|
||||
groupsize=self.groupsize,
|
||||
quant_method=self.quant_method,
|
||||
quantize=self.quantize,
|
||||
sym=self.sym,
|
||||
):
|
||||
w = [weights.get_tensor(f"{p}.g_idx") for p in prefixes]
|
||||
for w2 in w[1:]:
|
||||
torch.testing.assert_close(w2, w[0])
|
||||
g_idx = w[0]
|
||||
|
||||
return repack_gptq_for_marlin(
|
||||
qweight=qweight,
|
||||
scales=scales,
|
||||
g_idx=g_idx,
|
||||
bits=self.bits,
|
||||
desc_act=self.desc_act,
|
||||
groupsize=self.groupsize,
|
||||
sym=self.sym,
|
||||
sharded_infeatures=False,
|
||||
)
|
||||
|
||||
qzeros = torch.cat(
|
||||
[weights.get_sharded(f"{p}.qzeros", dim=1) for p in prefixes], dim=1
|
||||
)
|
||||
|
||||
from text_generation_server.layers.gptq import HAS_EXLLAMA
|
||||
|
||||
use_exllama = (
|
||||
self.bits == 4
|
||||
and HAS_EXLLAMA
|
||||
and self.quantize == "gptq"
|
||||
and not self.desc_act
|
||||
)
|
||||
|
||||
if self.quantize == "gptq" and self.quant_method == "gptq":
|
||||
w = [weights.get_tensor(f"{p}.g_idx") for p in prefixes]
|
||||
for w2 in w[1:]:
|
||||
torch.testing.assert_close(w2, w[0])
|
||||
g_idx = w[0]
|
||||
elif self.quantize == "gptq" and self.quant_method == "awq":
|
||||
log_once(
|
||||
logger.info, "Converting AWQ model to Exllama/GPTQ packing format."
|
||||
)
|
||||
from text_generation_server.layers.awq.conversion_utils import (
|
||||
fast_awq_to_gptq,
|
||||
)
|
||||
|
||||
qweight, qzeros = fast_awq_to_gptq(qweight, qzeros)
|
||||
if use_exllama:
|
||||
g_idx = None
|
||||
else:
|
||||
g_idx = (
|
||||
torch.arange(
|
||||
qweight.shape[0] * (32 // self.bits),
|
||||
device=qweight.device,
|
||||
)
|
||||
// self.groupsize
|
||||
).to(dtype=torch.int32)
|
||||
else:
|
||||
g_idx = None
|
||||
|
||||
return GPTQWeight(
|
||||
qweight=qweight,
|
||||
qzeros=qzeros,
|
||||
scales=scales,
|
||||
g_idx=g_idx,
|
||||
bits=self.bits,
|
||||
groupsize=self.groupsize,
|
||||
use_exllama=use_exllama,
|
||||
)
|
||||
|
||||
def get_weights_row(self, weights: Weights, prefix: str):
|
||||
from text_generation_server.layers.marlin import (
|
||||
can_use_gptq_marlin,
|
||||
repack_gptq_for_marlin,
|
||||
)
|
||||
|
||||
self._get_gptq_params(weights)
|
||||
if can_use_gptq_marlin(
|
||||
bits=self.bits,
|
||||
groupsize=self.groupsize,
|
||||
quant_method=self.quant_method,
|
||||
quantize=self.quantize,
|
||||
sym=self.sym,
|
||||
):
|
||||
log_once(logger.info, "Using GPTQ-Marlin kernels")
|
||||
try:
|
||||
qweight = weights.get_sharded(f"{prefix}.qweight", dim=0)
|
||||
except RuntimeError:
|
||||
raise RuntimeError(
|
||||
f"Cannot load `{self.quantize}` weight for GPTQ -> Marlin repacking, make sure the model is already quantized"
|
||||
)
|
||||
|
||||
g_idx = weights.get_sharded(f"{prefix}.g_idx", dim=0)
|
||||
if self.desc_act or self.groupsize == -1:
|
||||
scales = weights.get_tensor(f"{prefix}.scales")
|
||||
else:
|
||||
scales = weights.get_sharded(f"{prefix}.scales", dim=0)
|
||||
|
||||
sharded_in_features = weights.process_group.size() > 1
|
||||
|
||||
return repack_gptq_for_marlin(
|
||||
qweight=qweight,
|
||||
scales=scales,
|
||||
g_idx=g_idx,
|
||||
bits=self.bits,
|
||||
desc_act=self.desc_act,
|
||||
groupsize=self.groupsize,
|
||||
sym=self.sym,
|
||||
sharded_infeatures=sharded_in_features,
|
||||
)
|
||||
|
||||
use_exllama = True
|
||||
if self.bits != 4:
|
||||
use_exllama = False
|
||||
|
||||
if self.desc_act:
|
||||
log_once(logger.warning, "Disabling exllama because desc_act=True")
|
||||
use_exllama = False
|
||||
|
||||
try:
|
||||
qweight = weights.get_sharded(f"{prefix}.qweight", dim=0)
|
||||
except RuntimeError:
|
||||
raise RuntimeError(
|
||||
"Cannot load `gptq` weight, make sure the model is already quantized, or quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`"
|
||||
)
|
||||
|
||||
if self.quantize == "gptq" and self.quant_method == "gptq":
|
||||
g_idx = weights.get_sharded(f"{prefix}.g_idx", dim=0)
|
||||
else:
|
||||
g_idx = None
|
||||
|
||||
if weights.process_group.size() > 1:
|
||||
if g_idx is not None:
|
||||
if (
|
||||
not torch.equal(
|
||||
g_idx.cpu(),
|
||||
torch.tensor(
|
||||
[i // self.groupsize for i in range(g_idx.shape[0])],
|
||||
dtype=torch.int32,
|
||||
),
|
||||
)
|
||||
and not (g_idx == 0).all()
|
||||
):
|
||||
# Exllama implementation does not support row tensor parallelism with act-order, as
|
||||
# it would require to reorder input activations that are split unto several GPUs
|
||||
use_exllama = False
|
||||
|
||||
from text_generation_server.layers.gptq import (
|
||||
HAS_EXLLAMA,
|
||||
CAN_EXLLAMA,
|
||||
GPTQWeight,
|
||||
)
|
||||
|
||||
if use_exllama:
|
||||
if not HAS_EXLLAMA:
|
||||
if CAN_EXLLAMA:
|
||||
log_once(
|
||||
logger.warning,
|
||||
"Exllama GPTQ cuda kernels (which are faster) could have been used, but are not currently installed, try using BUILD_EXTENSIONS=True",
|
||||
)
|
||||
use_exllama = False
|
||||
else:
|
||||
log_once(logger.info, f"Using exllama kernels v{HAS_EXLLAMA}")
|
||||
|
||||
if use_exllama and self.groupsize != -1:
|
||||
qzeros = weights.get_sharded(f"{prefix}.qzeros", dim=0)
|
||||
scales = weights.get_sharded(f"{prefix}.scales", dim=0)
|
||||
else:
|
||||
qzeros = weights.get_tensor(f"{prefix}.qzeros")
|
||||
scales = weights.get_tensor(f"{prefix}.scales")
|
||||
|
||||
if use_exllama and g_idx is not None:
|
||||
g_idx = g_idx - g_idx[0]
|
||||
|
||||
if self.quantize == "gptq" and self.quant_method == "awq":
|
||||
log_once(
|
||||
logger.info, "Converting AWQ model to Exllama/GPTQ packing format."
|
||||
)
|
||||
from text_generation_server.layers.awq.conversion_utils import (
|
||||
fast_awq_to_gptq,
|
||||
)
|
||||
|
||||
qweight, qzeros = fast_awq_to_gptq(qweight, qzeros)
|
||||
if use_exllama:
|
||||
g_idx = None
|
||||
else:
|
||||
g_idx = (
|
||||
torch.arange(
|
||||
qweight.shape[0] * (32 // self.bits),
|
||||
device=qweight.device,
|
||||
)
|
||||
// self.groupsize
|
||||
).to(dtype=torch.int32)
|
||||
|
||||
return GPTQWeight(
|
||||
qweight=qweight,
|
||||
qzeros=qzeros,
|
||||
scales=scales,
|
||||
g_idx=g_idx,
|
||||
bits=self.bits,
|
||||
groupsize=self.groupsize,
|
||||
use_exllama=use_exllama,
|
||||
)
|
||||
|
||||
def _get_gptq_params(self, weights: Weights):
|
||||
try:
|
||||
self.bits = weights.get_tensor("gptq_bits").item()
|
||||
self.groupsize = weights.get_tensor("gptq_groupsize").item()
|
||||
self.desc_act = False
|
||||
self.sym = False
|
||||
self.quant_method = "gptq"
|
||||
except (SafetensorError, RuntimeError) as e:
|
||||
pass
|
||||
|
|
|
@ -16,6 +16,8 @@ from text_generation_server.layers.gptq.quant_linear import QuantLinear
|
|||
from loguru import logger
|
||||
from typing import Optional
|
||||
|
||||
from text_generation_server.utils.weights import DefaultWeightsLoader
|
||||
|
||||
DEV = torch.device("cuda:0")
|
||||
|
||||
|
||||
|
@ -891,6 +893,7 @@ def quantize(
|
|||
dtype=torch.float16,
|
||||
process_group=process_group,
|
||||
aliases={"embed_tokens.weight": ["lm_head.weight"]},
|
||||
weights_loader=DefaultWeightsLoader(),
|
||||
)
|
||||
hooks = []
|
||||
for name, module in model.named_modules():
|
||||
|
|
|
@ -1,10 +1,10 @@
|
|||
from dataclasses import dataclass
|
||||
from typing import List, Optional, Tuple
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
from text_generation_server.utils.weights import Weights, WeightsLoader
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from text_generation_server.layers.gptq import GPTQParams
|
||||
from text_generation_server.utils.import_utils import SYSTEM
|
||||
|
||||
try:
|
||||
|
@ -24,16 +24,132 @@ GPTQ_MARLIN_GROUP_SIZES = [-1, 32, 64, 128]
|
|||
MARLIN_TILE_SIZE = 16
|
||||
|
||||
|
||||
def can_use_gptq_marlin(gptq_params: GPTQParams, quantize: str) -> bool:
|
||||
class MarlinWeightsLoader(WeightsLoader):
|
||||
"""Loader for Marlin-quantized weights."""
|
||||
|
||||
def __init__(self, *, bits: int, is_marlin_24: bool):
|
||||
self.bits = bits
|
||||
self.is_marlin_24 = is_marlin_24
|
||||
|
||||
def get_weights_col_packed(
|
||||
self,
|
||||
weights: Weights,
|
||||
prefix: str,
|
||||
block_sizes: Union[int, List[int]],
|
||||
):
|
||||
if self.is_marlin_24:
|
||||
B = weights.get_packed_sharded(
|
||||
f"{prefix}.B_24", dim=1, block_sizes=block_sizes
|
||||
)
|
||||
B_meta = weights.get_packed_sharded(
|
||||
f"{prefix}.B_meta", dim=1, block_sizes=block_sizes
|
||||
)
|
||||
s = weights.get_packed_sharded(
|
||||
f"{prefix}.s", dim=1, block_sizes=block_sizes
|
||||
)
|
||||
|
||||
weight = GPTQMarlin24Weight(B=B, B_meta=B_meta, s=s, bits=self.bits)
|
||||
else:
|
||||
B = weights.get_packed_sharded(
|
||||
f"{prefix}.B", dim=1, block_sizes=block_sizes
|
||||
)
|
||||
s = weights.get_packed_sharded(
|
||||
f"{prefix}.s", dim=1, block_sizes=block_sizes
|
||||
)
|
||||
weight = MarlinWeight(B=B, s=s)
|
||||
|
||||
return weight
|
||||
|
||||
def get_multi_weights_col(self, weights: Weights, prefixes: List[str], dim: int):
|
||||
is_marlin_24 = getattr(self, "gptq_checkpoint_format", None) == "marlin_24"
|
||||
if is_marlin_24:
|
||||
try:
|
||||
B = torch.cat(
|
||||
[weights.get_sharded(f"{p}.B_24", dim=1) for p in prefixes], dim=1
|
||||
)
|
||||
except RuntimeError:
|
||||
raise RuntimeError(
|
||||
f"Cannot load `marlin` weight, make sure the model is already quantized"
|
||||
)
|
||||
|
||||
B_meta = torch.cat(
|
||||
[weights.get_sharded(f"{p}.B_meta", dim=1) for p in prefixes], dim=1
|
||||
)
|
||||
|
||||
s = torch.cat(
|
||||
[weights.get_sharded(f"{p}.s", dim=1) for p in prefixes], dim=1
|
||||
)
|
||||
|
||||
weight = GPTQMarlin24Weight(B=B, B_meta=B_meta, s=s, bits=self.bits)
|
||||
else:
|
||||
try:
|
||||
B = torch.cat(
|
||||
[weights.get_sharded(f"{p}.B", dim=1) for p in prefixes], dim=1
|
||||
)
|
||||
except RuntimeError:
|
||||
raise RuntimeError(
|
||||
f"Cannot load `marlin` weight, make sure the model is already quantized"
|
||||
)
|
||||
s = torch.cat(
|
||||
[weights.get_sharded(f"{p}.s", dim=1) for p in prefixes], dim=1
|
||||
)
|
||||
|
||||
weight = MarlinWeight(B=B, s=s)
|
||||
|
||||
return weight
|
||||
|
||||
def get_weights_row(self, weights: Weights, prefix: str):
|
||||
is_marlin_24 = getattr(self, "gptq_checkpoint_format", None) == "marlin_24"
|
||||
if is_marlin_24:
|
||||
try:
|
||||
B = weights.get_sharded(f"{prefix}.B_24", dim=0)
|
||||
except RuntimeError:
|
||||
raise RuntimeError(
|
||||
"Cannot load `marlin` 2:4 sparsity weight, make sure the model is already quantized."
|
||||
)
|
||||
|
||||
B_meta = weights.get_sharded(f"{prefix}.B_meta", dim=0)
|
||||
num_groups = weights._get_slice(f"{prefix}.s").get_shape()[0]
|
||||
if num_groups == 1:
|
||||
# The number of groups is 1 when groupsize == -1. share
|
||||
# scales between all shards in this case.
|
||||
s = weights.get_tensor(f"{prefix}.s")
|
||||
else:
|
||||
s = weights.get_sharded(f"{prefix}.s", dim=0)
|
||||
|
||||
weight = GPTQMarlin24Weight(B=B, B_meta=B_meta, s=s, bits=self.bits)
|
||||
else:
|
||||
try:
|
||||
B = weights.get_sharded(f"{prefix}.B", dim=0)
|
||||
except RuntimeError:
|
||||
raise RuntimeError(
|
||||
"Cannot load `marlin` weight, make sure the model is already quantized."
|
||||
)
|
||||
|
||||
num_groups = weights._get_slice(f"{prefix}.s").get_shape()[0]
|
||||
if num_groups == 1:
|
||||
# The number of groups is 1 when groupsize == -1. share
|
||||
# scales between all shards in this case.
|
||||
s = weights.get_tensor(f"{prefix}.s")
|
||||
else:
|
||||
s = weights.get_sharded(f"{prefix}.s", dim=0)
|
||||
weight = MarlinWeight(B=B, s=s)
|
||||
|
||||
return weight
|
||||
|
||||
|
||||
def can_use_gptq_marlin(
|
||||
*, bits: int, groupsize: int, quant_method: str, quantize: str, sym: bool
|
||||
) -> bool:
|
||||
return (
|
||||
SYSTEM == "cuda"
|
||||
and marlin_kernels is not None
|
||||
and has_sm_8_0
|
||||
and quantize == "gptq"
|
||||
and gptq_params.quant_method == "gptq"
|
||||
and gptq_params.bits in GPTQ_MARLIN_BITS
|
||||
and gptq_params.groupsize in GPTQ_MARLIN_GROUP_SIZES
|
||||
and gptq_params.sym
|
||||
and quant_method == "gptq"
|
||||
and bits in GPTQ_MARLIN_BITS
|
||||
and groupsize in GPTQ_MARLIN_GROUP_SIZES
|
||||
and sym
|
||||
)
|
||||
|
||||
|
||||
|
|
|
@ -52,7 +52,7 @@ class TensorParallelHead(SuperLayer):
|
|||
weight = weights.get_tensor(f"{prefix}.weight")
|
||||
except:
|
||||
# ...otherwise they are quantized.
|
||||
weight = weights.get_weights_col(prefix, config.quantize)
|
||||
weight = weights.get_weights_col(prefix)
|
||||
should_gather = weights.process_group.size() > 1
|
||||
elif weights.process_group.size() > 1:
|
||||
try:
|
||||
|
@ -129,9 +129,7 @@ class TensorParallelColumnLinear(SuperLayer):
|
|||
@classmethod
|
||||
def load_gate_up(cls, config, prefix: str, weights, bias: bool):
|
||||
"""Specific method when the QKV was joined after the fact"""
|
||||
weight = weights.get_weights_col_packed_gate_up(
|
||||
prefix, quantize=config.quantize
|
||||
)
|
||||
weight = weights.get_weights_col_packed_gate_up(prefix)
|
||||
if bias:
|
||||
raise NotImplementedError("packed_gate_up only implemented without bias")
|
||||
else:
|
||||
|
@ -152,7 +150,6 @@ class TensorParallelColumnLinear(SuperLayer):
|
|||
"""Specific method when the QKV was joined after the fact"""
|
||||
weight = weights.get_weights_col_packed_qkv(
|
||||
prefix,
|
||||
quantize=config.quantize,
|
||||
num_heads=num_heads,
|
||||
num_key_value_heads=num_key_value_heads,
|
||||
)
|
||||
|
@ -165,7 +162,7 @@ class TensorParallelColumnLinear(SuperLayer):
|
|||
|
||||
@classmethod
|
||||
def load(cls, config, prefix: str, weights, bias: bool):
|
||||
weight = weights.get_weights_col(prefix, config.quantize)
|
||||
weight = weights.get_weights_col(prefix)
|
||||
if bias:
|
||||
bias = weights.get_sharded(f"{prefix}.bias", dim=0)
|
||||
else:
|
||||
|
@ -178,14 +175,12 @@ class TensorParallelColumnLinear(SuperLayer):
|
|||
if config.quantize == "exl2":
|
||||
linears = []
|
||||
for prefix in prefixes:
|
||||
weight = weights.get_weights_col(prefix, config.quantize)
|
||||
weight = weights.get_weights_col(prefix)
|
||||
b = weights.get_tensor(f"{prefix}.bias") if bias else None
|
||||
linears.append(get_linear(weight, b, config.quantize))
|
||||
linear = LayerConcat(linears)
|
||||
else:
|
||||
weight = weights.get_multi_weights_col(
|
||||
prefixes, quantize=config.quantize, dim=dim
|
||||
)
|
||||
weight = weights.get_multi_weights_col(prefixes, dim=dim)
|
||||
if bias:
|
||||
b = [weights.get_sharded(f"{p}.bias", dim=0) for p in prefixes]
|
||||
bias = torch.cat(b, dim=dim)
|
||||
|
@ -202,7 +197,7 @@ class TensorParallelRowLinear(SuperLayer):
|
|||
|
||||
@classmethod
|
||||
def load(cls, config, prefix: str, weights, bias: bool):
|
||||
weight = weights.get_multi_weights_row(prefix, quantize=config.quantize)
|
||||
weight = weights.get_weights_row(prefix)
|
||||
|
||||
if bias and weights.process_group.rank() == 0:
|
||||
# Rank is only on the first rank process
|
||||
|
|
|
@ -20,6 +20,7 @@ from text_generation_server.utils import (
|
|||
from text_generation_server.models import Model
|
||||
from text_generation_server.utils.chunks import concat_text_chunks
|
||||
from text_generation_server.utils.import_utils import SYSTEM
|
||||
from text_generation_server.utils.quantization import get_loader
|
||||
from text_generation_server.utils.tokens import batch_top_tokens
|
||||
from text_generation_server.models.types import (
|
||||
Batch,
|
||||
|
@ -546,12 +547,17 @@ class CausalLM(Model):
|
|||
tokenizer.pad_token_id = config.pad_token_id
|
||||
|
||||
torch.distributed.barrier(group=self.process_group)
|
||||
weights_loader = get_loader(
|
||||
quantize=quantize, model_id=model_id, revision=revision
|
||||
)
|
||||
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
|
||||
weights = Weights(
|
||||
filenames, device=device, dtype=dtype, process_group=self.process_group
|
||||
filenames,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
process_group=self.process_group,
|
||||
weights_loader=weights_loader,
|
||||
)
|
||||
if config.quantize in ["awq", "exl2", "gptq", "marlin"]:
|
||||
weights._set_gptq_params(model_id, revision)
|
||||
|
||||
prefix = ""
|
||||
model = model_class(prefix, config, weights)
|
||||
|
|
|
@ -163,7 +163,6 @@ def _load_gqa(config, prefix: str, weights):
|
|||
|
||||
weight = weights.get_multi_weights_col(
|
||||
prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
|
||||
quantize=config.quantize,
|
||||
dim=0,
|
||||
)
|
||||
|
||||
|
|
|
@ -141,7 +141,6 @@ def _load_gqa(config, prefix: str, weights):
|
|||
|
||||
weight = weights.get_multi_weights_col(
|
||||
prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
|
||||
quantize=config.quantize,
|
||||
dim=0,
|
||||
)
|
||||
|
||||
|
|
|
@ -141,7 +141,6 @@ def _load_gqa(config, prefix: str, weights):
|
|||
|
||||
weight = weights.get_multi_weights_col(
|
||||
prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
|
||||
quantize=config.quantize,
|
||||
dim=0,
|
||||
)
|
||||
|
||||
|
|
|
@ -61,7 +61,6 @@ def _load_qkv_gptq(config, prefix: str, weights):
|
|||
# Weights
|
||||
weight = weights.get_weights_col_packed_qkv(
|
||||
f"{prefix}.c_attn",
|
||||
config.quantize,
|
||||
config.num_attention_heads,
|
||||
config.num_attention_heads,
|
||||
)
|
||||
|
@ -137,7 +136,7 @@ def load_row(config, prefix: str, weights, bias: bool):
|
|||
"""load_row, but with transposed weight matrices."""
|
||||
|
||||
if config.quantize == "gptq":
|
||||
weight = weights.get_multi_weights_row(prefix, quantize=config.quantize)
|
||||
weight = weights.get_weights_row(prefix)
|
||||
else:
|
||||
weight = weights.get_sharded(f"{prefix}.weight", dim=0).T
|
||||
|
||||
|
@ -155,9 +154,7 @@ def load_row(config, prefix: str, weights, bias: bool):
|
|||
def load_col(config, prefix: str, weights, bias: bool):
|
||||
"""load_col, but with transposed weight matrices."""
|
||||
if config.quantize == "gptq":
|
||||
weight = weights.get_multi_weights_col(
|
||||
[prefix], quantize=config.quantize, dim=1
|
||||
)
|
||||
weight = weights.get_multi_weights_col([prefix], dim=1)
|
||||
else:
|
||||
weight = weights.get_sharded(f"{prefix}.weight", dim=1).T
|
||||
|
||||
|
|
|
@ -135,7 +135,6 @@ def _load_gqa(config, prefix: str, weights):
|
|||
|
||||
weight = weights.get_multi_weights_col(
|
||||
prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
|
||||
quantize=config.quantize,
|
||||
dim=0,
|
||||
)
|
||||
|
||||
|
|
|
@ -48,7 +48,7 @@ from text_generation_server.layers.rotary import (
|
|||
|
||||
|
||||
def load_row(config, prefix: str, weights, bias: bool):
|
||||
weight = weights.get_multi_weights_row(prefix, quantize=config.quantize)
|
||||
weight = weights.get_weights_row(prefix)
|
||||
|
||||
if bias and weights.process_group.rank() == 0:
|
||||
# Rank is only on the first rank process
|
||||
|
@ -64,7 +64,7 @@ def load_row(config, prefix: str, weights, bias: bool):
|
|||
|
||||
|
||||
def load_qkv(config, prefix: str, weights, num_heads, head_size, hidden_size):
|
||||
weight = weights.get_multi_weights_col([prefix], quantize=config.quantize, dim=0)
|
||||
weight = weights.get_multi_weights_col([prefix], dim=0)
|
||||
if isinstance(weight, torch.Tensor):
|
||||
# Only on non quantized versions
|
||||
weight = (
|
||||
|
|
|
@ -85,7 +85,6 @@ def _load_gqa(config, prefix: str, weights):
|
|||
|
||||
weight = weights.get_multi_weights_col(
|
||||
prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
|
||||
quantize=config.quantize,
|
||||
dim=0,
|
||||
)
|
||||
|
||||
|
|
|
@ -23,7 +23,7 @@ from text_generation_server.layers.attention import (
|
|||
|
||||
|
||||
def load_row(config, prefix: str, weights, bias: bool):
|
||||
weight = weights.get_multi_weights_row(prefix, quantize=config.quantize)
|
||||
weight = weights.get_weights_row(prefix)
|
||||
|
||||
if bias and weights.process_group.rank() == 0:
|
||||
# Rank is only on the first rank process
|
||||
|
|
|
@ -17,6 +17,7 @@ from text_generation_server.layers import (
|
|||
TensorParallelEmbedding,
|
||||
get_linear,
|
||||
)
|
||||
from text_generation_server.layers.gptq import GPTQWeightsLoader
|
||||
from text_generation_server.layers.layernorm import (
|
||||
FastLayerNorm,
|
||||
)
|
||||
|
@ -81,11 +82,13 @@ def _load_multi_mqa_gptq(
|
|||
qzeros = torch.cat([q_tensor, kv_tensor], dim=1)
|
||||
qzeros = qzeros.to(device=weights.device)
|
||||
|
||||
gptq_params = weights._get_gptq_params()
|
||||
if gptq_params.quant_method == "gptq":
|
||||
loader = weights.weights_loader
|
||||
assert isinstance(loader, GPTQWeightsLoader)
|
||||
loader._get_gptq_params(weights)
|
||||
if loader.quant_method == "gptq":
|
||||
g_idx = weights.get_tensor(f"{prefix}.c_attn.g_idx")
|
||||
g_idx = g_idx.to(device=weights.device)
|
||||
elif gptq_params.quant_method == "awq":
|
||||
elif loader.quant_method == "awq":
|
||||
g_idx = None
|
||||
from text_generation_server.layers.awq.conversion_utils import (
|
||||
fast_awq_to_gptq,
|
||||
|
@ -100,8 +103,8 @@ def _load_multi_mqa_gptq(
|
|||
qzeros=qzeros,
|
||||
scales=scales,
|
||||
g_idx=g_idx,
|
||||
bits=gptq_params.bits,
|
||||
groupsize=gptq_params.groupsize,
|
||||
bits=loader.bits,
|
||||
groupsize=loader.groupsize,
|
||||
use_exllama=HAS_EXLLAMA,
|
||||
)
|
||||
|
||||
|
@ -197,9 +200,7 @@ def load_col(config, prefix: str, weights, bias: bool):
|
|||
if config.transpose:
|
||||
weight = weights.get_sharded(f"{prefix}.weight", dim=1).T
|
||||
else:
|
||||
weight = weights.get_multi_weights_col(
|
||||
[prefix], quantize=config.quantize, dim=0
|
||||
)
|
||||
weight = weights.get_multi_weights_col([prefix], dim=0)
|
||||
|
||||
if bias:
|
||||
bias = weights.get_sharded(f"{prefix}.bias", dim=0)
|
||||
|
@ -212,7 +213,7 @@ def load_row(config, prefix: str, weights, bias: bool):
|
|||
if config.transpose:
|
||||
weight = weights.get_sharded(f"{prefix}.weight", dim=0).T
|
||||
else:
|
||||
weight = weights.get_multi_weights_row(prefix, quantize=config.quantize)
|
||||
weight = weights.get_weights_row(prefix)
|
||||
|
||||
if bias and weights.process_group.rank() == 0:
|
||||
# Rank is only on the first rank process
|
||||
|
|
|
@ -126,7 +126,6 @@ def _load_gqa(config, prefix: str, weights):
|
|||
|
||||
weight = weights.get_multi_weights_col(
|
||||
prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
|
||||
quantize=config.quantize,
|
||||
dim=0,
|
||||
)
|
||||
|
||||
|
|
|
@ -50,6 +50,7 @@ from text_generation_server.models.globals import (
|
|||
from text_generation_server.layers.attention import Seqlen
|
||||
from text_generation_server.utils import StoppingCriteria, HeterogeneousNextTokenChooser
|
||||
from text_generation_server.utils.dist import MEMORY_FRACTION
|
||||
from text_generation_server.utils.quantization import get_loader
|
||||
from text_generation_server.utils.segments import SegmentConcatBuilder, find_segments
|
||||
|
||||
from text_generation_server.utils.import_utils import (
|
||||
|
@ -881,12 +882,16 @@ class FlashCausalLM(Model):
|
|||
|
||||
torch.distributed.barrier(group=self.process_group)
|
||||
|
||||
weights_loader = get_loader(quantize, model_id, revision)
|
||||
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
|
||||
weights = Weights(
|
||||
filenames, device, dtype, process_group=self.process_group, aliases=aliases
|
||||
filenames,
|
||||
device,
|
||||
dtype,
|
||||
process_group=self.process_group,
|
||||
aliases=aliases,
|
||||
weights_loader=weights_loader,
|
||||
)
|
||||
if config.quantize in ["awq", "exl2", "gptq", "marlin"]:
|
||||
weights._set_gptq_params(model_id, revision)
|
||||
|
||||
prefix = ""
|
||||
model = model_class(prefix, config, weights)
|
||||
|
|
|
@ -23,6 +23,7 @@ from text_generation_server.utils import (
|
|||
weight_files,
|
||||
Weights,
|
||||
)
|
||||
from text_generation_server.utils.quantization import get_loader
|
||||
|
||||
|
||||
class IDEFICSSharded(IdeficsCausalLM):
|
||||
|
@ -70,6 +71,9 @@ class IDEFICSSharded(IdeficsCausalLM):
|
|||
trust_remote_code=trust_remote_code,
|
||||
)
|
||||
|
||||
weights_loader = get_loader(
|
||||
quantize=quantize, model_id=model_id, revision=revision
|
||||
)
|
||||
torch.distributed.barrier(group=self.process_group)
|
||||
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
|
||||
weights = Weights(
|
||||
|
@ -77,6 +81,7 @@ class IDEFICSSharded(IdeficsCausalLM):
|
|||
device=device,
|
||||
dtype=dtype,
|
||||
process_group=self.process_group,
|
||||
weights_loader=weights_loader,
|
||||
)
|
||||
|
||||
model = IdeficsForVisionText2Text(config, weights)
|
||||
|
|
|
@ -28,6 +28,7 @@ from text_generation_server.models.types import (
|
|||
GeneratedText,
|
||||
)
|
||||
from text_generation_server.utils.chunks import concat_text_chunks
|
||||
from text_generation_server.utils.quantization import get_loader
|
||||
from text_generation_server.utils.tokens import batch_top_tokens, Sampling
|
||||
from dataclasses import dataclass
|
||||
from text_generation_server.utils import NextTokenChooser, StoppingCriteria, Sampling
|
||||
|
@ -448,8 +449,17 @@ class Mamba(Model):
|
|||
config.quantize = quantize
|
||||
config.speculator = speculator
|
||||
torch.distributed.barrier(group=self.process_group)
|
||||
weights_loader = get_loader(
|
||||
quantize=quantize, model_id=model_id, revision=revision
|
||||
)
|
||||
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
|
||||
weights = Weights(filenames, device, dtype, process_group=self.process_group)
|
||||
weights = Weights(
|
||||
filenames,
|
||||
device,
|
||||
dtype,
|
||||
process_group=self.process_group,
|
||||
weights_loader=weights_loader,
|
||||
)
|
||||
model = MambaModel(config, weights)
|
||||
torch.distributed.barrier(group=self.process_group)
|
||||
super(Mamba, self).__init__(
|
||||
|
|
|
@ -18,6 +18,7 @@ from text_generation_server.utils import (
|
|||
Weights,
|
||||
)
|
||||
from text_generation_server.utils.chunks import concat_text_chunks
|
||||
from text_generation_server.utils.quantization import get_loader
|
||||
from text_generation_server.utils.tokens import batch_top_tokens
|
||||
from text_generation_server.models import Model
|
||||
from text_generation_server.models.types import (
|
||||
|
@ -586,6 +587,9 @@ class Seq2SeqLM(Model):
|
|||
)
|
||||
tokenizer.bos_token_id = config.decoder_start_token_id
|
||||
|
||||
weights_loader = get_loader(
|
||||
quantize=quantize, model_id=model_id, revision=revision
|
||||
)
|
||||
torch.distributed.barrier(group=self.process_group)
|
||||
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
|
||||
weights = Weights(
|
||||
|
@ -594,6 +598,7 @@ class Seq2SeqLM(Model):
|
|||
dtype=dtype,
|
||||
process_group=self.process_group,
|
||||
aliases=aliases,
|
||||
weights_loader=weights_loader,
|
||||
)
|
||||
if config.quantize in ["awq", "exl2", "gptq", "marlin"]:
|
||||
weights._set_gptq_params(model_id, revision)
|
||||
|
|
|
@ -0,0 +1,119 @@
|
|||
from typing import Optional
|
||||
import os
|
||||
import json
|
||||
from dataclasses import dataclass
|
||||
|
||||
from huggingface_hub import hf_hub_download
|
||||
|
||||
from text_generation_server.utils.weights import DefaultWeightsLoader, WeightsLoader
|
||||
|
||||
|
||||
@dataclass
|
||||
class _QuantizerConfig:
|
||||
bits: int
|
||||
checkpoint_format: Optional[str]
|
||||
desc_act: bool
|
||||
groupsize: int
|
||||
quant_method: str
|
||||
sym: bool
|
||||
|
||||
|
||||
# We should probably do this with Pytantic JSON deserialization,
|
||||
# but for now we'll stay close to the old _set_gptq_params.
|
||||
def _get_quantizer_config(model_id, revision):
|
||||
bits = 4
|
||||
groupsize = -1
|
||||
quant_method = "gptq"
|
||||
checkpoint_format = None
|
||||
sym = True
|
||||
desc_act = False
|
||||
|
||||
filename = "config.json"
|
||||
try:
|
||||
if os.path.exists(os.path.join(model_id, filename)):
|
||||
filename = os.path.join(model_id, filename)
|
||||
else:
|
||||
filename = hf_hub_download(model_id, filename=filename, revision=revision)
|
||||
with open(filename, "r") as f:
|
||||
data = json.load(f)
|
||||
bits = data["quantization_config"]["bits"]
|
||||
groupsize = data["quantization_config"]["group_size"]
|
||||
# Order is important here, desc_act is missing on some real models
|
||||
quant_method = data["quantization_config"]["quant_method"]
|
||||
checkpoint_format = data["quantization_config"].get("checkpoint_format")
|
||||
sym = data["quantization_config"]["sym"]
|
||||
desc_act = data["quantization_config"]["desc_act"]
|
||||
except Exception:
|
||||
filename = "quantize_config.json"
|
||||
try:
|
||||
if os.path.exists(os.path.join(model_id, filename)):
|
||||
filename = os.path.join(model_id, filename)
|
||||
else:
|
||||
filename = hf_hub_download(
|
||||
model_id, filename=filename, revision=revision
|
||||
)
|
||||
with open(filename, "r") as f:
|
||||
data = json.load(f)
|
||||
bits = data["bits"]
|
||||
groupsize = data["group_size"]
|
||||
sym = data["sym"]
|
||||
desc_act = data["desc_act"]
|
||||
if "version" in data and data["version"] == "GEMM":
|
||||
quant_method = "awq"
|
||||
except Exception:
|
||||
filename = "quant_config.json"
|
||||
try:
|
||||
if os.path.exists(os.path.join(model_id, filename)):
|
||||
filename = os.path.join(model_id, filename)
|
||||
else:
|
||||
filename = hf_hub_download(
|
||||
model_id, filename=filename, revision=revision
|
||||
)
|
||||
with open(filename, "r") as f:
|
||||
data = json.load(f)
|
||||
bits = data["w_bit"]
|
||||
groupsize = data["q_group_size"]
|
||||
desc_act = data["desc_act"]
|
||||
if "version" in data and data["version"] == "GEMM":
|
||||
quant_method = "awq"
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return _QuantizerConfig(
|
||||
bits=bits,
|
||||
groupsize=groupsize,
|
||||
quant_method=quant_method,
|
||||
checkpoint_format=checkpoint_format,
|
||||
sym=sym,
|
||||
desc_act=desc_act,
|
||||
)
|
||||
|
||||
|
||||
def get_loader(
|
||||
quantize: Optional[str], model_id: str, revision: Optional[str]
|
||||
) -> WeightsLoader:
|
||||
quantizer_config = _get_quantizer_config(model_id, revision)
|
||||
if quantize in {"awq", "gptq"}:
|
||||
from text_generation_server.layers.gptq import GPTQWeightsLoader
|
||||
|
||||
return GPTQWeightsLoader(
|
||||
bits=quantizer_config.bits,
|
||||
desc_act=quantizer_config.desc_act,
|
||||
groupsize=quantizer_config.groupsize,
|
||||
quant_method=quantizer_config.quant_method,
|
||||
quantize=quantize,
|
||||
sym=quantizer_config.sym,
|
||||
)
|
||||
elif quantize == "exl2":
|
||||
from text_generation_server.layers.exl2 import Exl2WeightsLoader
|
||||
|
||||
return Exl2WeightsLoader()
|
||||
elif quantize == "marlin":
|
||||
from text_generation_server.layers.marlin import MarlinWeightsLoader
|
||||
|
||||
return MarlinWeightsLoader(
|
||||
bits=quantizer_config.bits,
|
||||
is_marlin_24=quantizer_config.checkpoint_format == "marlin_24",
|
||||
)
|
||||
else:
|
||||
return DefaultWeightsLoader()
|
|
@ -1,13 +1,88 @@
|
|||
import os
|
||||
from abc import ABC, abstractmethod
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Union
|
||||
from safetensors import safe_open, SafetensorError
|
||||
from safetensors import safe_open
|
||||
import torch
|
||||
from loguru import logger
|
||||
from huggingface_hub import hf_hub_download
|
||||
import json
|
||||
from text_generation_server.layers.gptq import GPTQParams
|
||||
from text_generation_server.utils.log import log_once
|
||||
|
||||
|
||||
class WeightsLoader(ABC):
|
||||
"""
|
||||
Instances of this type implement higher-level weight loading.
|
||||
|
||||
At a low-level, every weight is stored in the Safetensors format.
|
||||
The interpretation of weights may be different however, for instance
|
||||
could be packed, quantized weights. Loaders are responsible for
|
||||
interpreting the raw tensors, sharding tensors in a manner compatible
|
||||
with the format, etc.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def get_weights_col_packed(
|
||||
self,
|
||||
weights: "Weights",
|
||||
prefix: str,
|
||||
block_sizes: Union[int, List[int]],
|
||||
):
|
||||
"""
|
||||
Get the packed weights at the given prefix with column-splitting for
|
||||
tensor parallelism. This method should be used when multiple different
|
||||
weights are packed into a tensor, for instance, query/key/value
|
||||
weights or a gate/up projection.
|
||||
|
||||
The `block_sizes` determines the proportions of the packed tensors.
|
||||
The columns are split in equally sized blocks when `block_sizes` is an
|
||||
`int`, or in blocks proportional given to the sizes. For instance
|
||||
`[2, 1, 1]` will divide an input with dimensionality `1024` in
|
||||
`[512, 256, 256]`.
|
||||
"""
|
||||
...
|
||||
|
||||
def get_weights_col(self, weights: "Weights", prefix: str):
|
||||
"""
|
||||
Get weights at the given prefix and apply column-splitting for tensor
|
||||
paralllism.
|
||||
"""
|
||||
return weights.get_multi_weights_col([prefix], 0)
|
||||
|
||||
@abstractmethod
|
||||
def get_multi_weights_col(self, weights: "Weights", prefixes: List[str], dim: int):
|
||||
"""
|
||||
Get the weights at the given prefixes, column-split them for tensor
|
||||
parallelim, and then concatenate the weights along the given dimension.
|
||||
"""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def get_weights_row(self, weights: "Weights", prefix: str):
|
||||
"""
|
||||
Get the weights at the given prefix and apply row-splitting for tensor
|
||||
parallism.
|
||||
"""
|
||||
...
|
||||
|
||||
|
||||
class DefaultWeightsLoader(WeightsLoader):
|
||||
"""
|
||||
Loader that uses tensors as-is with the exception of applying sharding
|
||||
and/or concatenation.
|
||||
"""
|
||||
|
||||
def get_weights_col_packed(
|
||||
self,
|
||||
weights: "Weights",
|
||||
prefix: str,
|
||||
block_sizes: Union[int, List[int]],
|
||||
):
|
||||
return weights.get_packed_sharded(
|
||||
f"{prefix}.weight", dim=0, block_sizes=block_sizes
|
||||
)
|
||||
|
||||
def get_multi_weights_col(self, weights: "Weights", prefixes: List[str], dim: int):
|
||||
w = [weights.get_sharded(f"{p}.weight", dim=0) for p in prefixes]
|
||||
return torch.cat(w, dim=dim)
|
||||
|
||||
def get_weights_row(self, weights: "Weights", prefix: str):
|
||||
return weights.get_sharded(f"{prefix}.weight", dim=1)
|
||||
|
||||
|
||||
class Weights:
|
||||
|
@ -17,6 +92,7 @@ class Weights:
|
|||
device,
|
||||
dtype,
|
||||
process_group,
|
||||
weights_loader: WeightsLoader,
|
||||
aliases: Optional[Dict[str, List[str]]] = None,
|
||||
prefix: Optional[str] = None,
|
||||
):
|
||||
|
@ -37,6 +113,7 @@ class Weights:
|
|||
self.dtype = dtype
|
||||
self.process_group = process_group
|
||||
self.prefix = prefix
|
||||
self.weights_loader = weights_loader
|
||||
self._handles = {}
|
||||
|
||||
def _get_handle(self, filename):
|
||||
|
@ -181,295 +258,27 @@ class Weights:
|
|||
num_key_value_heads: int,
|
||||
):
|
||||
return self.get_weights_col_packed(
|
||||
prefix, quantize, [num_heads, num_key_value_heads, num_key_value_heads]
|
||||
prefix, [num_heads, num_key_value_heads, num_key_value_heads]
|
||||
)
|
||||
|
||||
def get_weights_col_packed_gate_up(self, prefix: str, quantize: str):
|
||||
return self.get_weights_col_packed(prefix, quantize, 2)
|
||||
def get_weights_col_packed_gate_up(self, prefix: str):
|
||||
return self.get_weights_col_packed(prefix, 2)
|
||||
|
||||
def get_weights_col_packed(
|
||||
self, prefix: str, quantize: str, block_sizes: Union[int, List[int]]
|
||||
):
|
||||
def get_weights_col_packed(self, prefix: str, block_sizes: Union[int, List[int]]):
|
||||
"""
|
||||
Highly specific when the underlying tensor is a simple cat of Q,K,V instead of being
|
||||
already alternating Q,K,V within the main tensor.
|
||||
|
||||
The columns are split in equally sized blocks when blocks is an `int`, or
|
||||
in blocks proportional given to the sizes. For instance `[2, 1, 1]` will
|
||||
divide an input with dimensionality `1024` in `[512, 256, 256]`. This is
|
||||
convenient for e.g. splitting QKV without knowing the storage details of
|
||||
quantized weights.
|
||||
"""
|
||||
if quantize in ["gptq", "awq"]:
|
||||
from text_generation_server.layers.gptq import GPTQWeight
|
||||
from text_generation_server.layers.marlin import (
|
||||
can_use_gptq_marlin,
|
||||
repack_gptq_for_marlin,
|
||||
)
|
||||
return self.weights_loader.get_weights_col_packed(self, prefix, block_sizes)
|
||||
|
||||
try:
|
||||
qweight = self.get_packed_sharded(
|
||||
f"{prefix}.qweight", dim=1, block_sizes=block_sizes
|
||||
)
|
||||
except RuntimeError:
|
||||
raise RuntimeError(
|
||||
f"Cannot load `{quantize}` weight, make sure the model is already quantized."
|
||||
)
|
||||
scales = self.get_packed_sharded(
|
||||
f"{prefix}.scales", dim=1, block_sizes=block_sizes
|
||||
)
|
||||
scales = scales.to(dtype=self.dtype)
|
||||
def get_weights_col(self, prefix: str):
|
||||
return self.weights_loader.get_weights_col(self, prefix)
|
||||
|
||||
gptq_params = self._get_gptq_params()
|
||||
if can_use_gptq_marlin(gptq_params, quantize):
|
||||
g_idx = self.get_tensor(f"{prefix}.g_idx")
|
||||
return repack_gptq_for_marlin(
|
||||
qweight=qweight,
|
||||
scales=scales,
|
||||
g_idx=g_idx,
|
||||
bits=gptq_params.bits,
|
||||
desc_act=gptq_params.desc_act,
|
||||
groupsize=gptq_params.groupsize,
|
||||
sym=gptq_params.sym,
|
||||
sharded_infeatures=False,
|
||||
)
|
||||
|
||||
qzeros = self.get_packed_sharded(
|
||||
f"{prefix}.qzeros", dim=1, block_sizes=block_sizes
|
||||
)
|
||||
if quantize == "gptq" and gptq_params.quant_method == "gptq":
|
||||
g_idx = self.get_tensor(f"{prefix}.g_idx")
|
||||
elif quantize == "gptq" and gptq_params.quant_method == "awq":
|
||||
log_once(
|
||||
logger.info, "Converting AWQ model to Exllama/GPTQ packing format."
|
||||
)
|
||||
from text_generation_server.layers.awq.conversion_utils import (
|
||||
fast_awq_to_gptq,
|
||||
)
|
||||
|
||||
qweight, qzeros = fast_awq_to_gptq(qweight, qzeros)
|
||||
g_idx = (
|
||||
torch.arange(
|
||||
qweight.shape[0] * (32 // gptq_params.bits),
|
||||
device=qweight.device,
|
||||
)
|
||||
// gptq_params.groupsize
|
||||
).to(dtype=torch.int32)
|
||||
else:
|
||||
g_idx = None
|
||||
|
||||
weight = GPTQWeight(
|
||||
qweight=qweight,
|
||||
qzeros=qzeros,
|
||||
scales=scales,
|
||||
g_idx=g_idx,
|
||||
bits=gptq_params.bits,
|
||||
groupsize=gptq_params.groupsize,
|
||||
use_exllama=False,
|
||||
)
|
||||
elif quantize == "marlin":
|
||||
from text_generation_server.layers.marlin import (
|
||||
GPTQMarlin24Weight,
|
||||
MarlinWeight,
|
||||
repack_gptq_for_marlin,
|
||||
)
|
||||
|
||||
is_marlin_24 = getattr(self, "gptq_checkpoint_format", None) == "marlin_24"
|
||||
if is_marlin_24:
|
||||
B = self.get_packed_sharded(
|
||||
f"{prefix}.B_24", dim=1, block_sizes=block_sizes
|
||||
)
|
||||
B_meta = self.get_packed_sharded(
|
||||
f"{prefix}.B_meta", dim=1, block_sizes=block_sizes
|
||||
)
|
||||
s = self.get_packed_sharded(
|
||||
f"{prefix}.s", dim=1, block_sizes=block_sizes
|
||||
)
|
||||
|
||||
gptq_params = self._get_gptq_params()
|
||||
weight = GPTQMarlin24Weight(
|
||||
B=B, B_meta=B_meta, s=s, bits=gptq_params.bits
|
||||
)
|
||||
else:
|
||||
B = self.get_packed_sharded(
|
||||
f"{prefix}.B", dim=1, block_sizes=block_sizes
|
||||
)
|
||||
s = self.get_packed_sharded(
|
||||
f"{prefix}.s", dim=1, block_sizes=block_sizes
|
||||
)
|
||||
weight = MarlinWeight(B=B, s=s)
|
||||
else:
|
||||
weight = self.get_packed_sharded(
|
||||
f"{prefix}.weight", dim=0, block_sizes=block_sizes
|
||||
)
|
||||
return weight
|
||||
|
||||
def get_weights_col(self, prefix: str, quantize: str):
|
||||
if quantize == "exl2":
|
||||
from text_generation_server.layers.exl2 import Exl2Weight
|
||||
|
||||
try:
|
||||
q_weight = self.get_tensor(f"{prefix}.q_weight")
|
||||
except RuntimeError:
|
||||
raise RuntimeError(
|
||||
f"Cannot load `exl2`-quantized weight, make sure the model is already quantized."
|
||||
)
|
||||
|
||||
q_scale = self.get_tensor(f"{prefix}.q_scale")
|
||||
q_invperm = self.get_tensor(f"{prefix}.q_invperm")
|
||||
q_scale_max = self.get_tensor(f"{prefix}.q_scale_max")
|
||||
q_groups = self.get_tensor(f"{prefix}.q_groups")
|
||||
|
||||
return Exl2Weight(
|
||||
q_weight=q_weight,
|
||||
q_scale=q_scale,
|
||||
q_invperm=q_invperm,
|
||||
q_scale_max=q_scale_max,
|
||||
q_groups=q_groups,
|
||||
)
|
||||
|
||||
return self.get_multi_weights_col([prefix], quantize, 0)
|
||||
|
||||
def get_multi_weights_col(self, prefixes: List[str], quantize: str, dim: int):
|
||||
if quantize == "exl2":
|
||||
raise ValueError("get_multi_weights_col is not supported for exl2")
|
||||
elif quantize in ["gptq", "awq"]:
|
||||
from text_generation_server.layers.gptq import GPTQWeight
|
||||
from text_generation_server.layers.marlin import (
|
||||
can_use_gptq_marlin,
|
||||
repack_gptq_for_marlin,
|
||||
)
|
||||
|
||||
try:
|
||||
qweight = torch.cat(
|
||||
[self.get_sharded(f"{p}.qweight", dim=1) for p in prefixes], dim=1
|
||||
)
|
||||
except RuntimeError:
|
||||
raise RuntimeError(
|
||||
f"Cannot load `{quantize}` weight, make sure the model is already quantized"
|
||||
)
|
||||
|
||||
scales = torch.cat(
|
||||
[self.get_sharded(f"{p}.scales", dim=1) for p in prefixes], dim=1
|
||||
)
|
||||
|
||||
gptq_params = self._get_gptq_params()
|
||||
if can_use_gptq_marlin(gptq_params, quantize):
|
||||
w = [self.get_tensor(f"{p}.g_idx") for p in prefixes]
|
||||
for w2 in w[1:]:
|
||||
torch.testing.assert_close(w2, w[0])
|
||||
g_idx = w[0]
|
||||
|
||||
return repack_gptq_for_marlin(
|
||||
qweight=qweight,
|
||||
scales=scales,
|
||||
g_idx=g_idx,
|
||||
bits=gptq_params.bits,
|
||||
desc_act=gptq_params.desc_act,
|
||||
groupsize=gptq_params.groupsize,
|
||||
sym=gptq_params.sym,
|
||||
sharded_infeatures=False,
|
||||
)
|
||||
|
||||
qzeros = torch.cat(
|
||||
[self.get_sharded(f"{p}.qzeros", dim=1) for p in prefixes], dim=1
|
||||
)
|
||||
|
||||
from text_generation_server.layers.gptq import HAS_EXLLAMA
|
||||
|
||||
use_exllama = (
|
||||
gptq_params.bits == 4
|
||||
and HAS_EXLLAMA
|
||||
and quantize == "gptq"
|
||||
and not gptq_params.desc_act
|
||||
)
|
||||
|
||||
if quantize == "gptq" and gptq_params.quant_method == "gptq":
|
||||
w = [self.get_tensor(f"{p}.g_idx") for p in prefixes]
|
||||
for w2 in w[1:]:
|
||||
torch.testing.assert_close(w2, w[0])
|
||||
g_idx = w[0]
|
||||
elif quantize == "gptq" and gptq_params.quant_method == "awq":
|
||||
log_once(
|
||||
logger.info, "Converting AWQ model to Exllama/GPTQ packing format."
|
||||
)
|
||||
from text_generation_server.layers.awq.conversion_utils import (
|
||||
fast_awq_to_gptq,
|
||||
)
|
||||
|
||||
qweight, qzeros = fast_awq_to_gptq(qweight, qzeros)
|
||||
if use_exllama:
|
||||
g_idx = None
|
||||
else:
|
||||
g_idx = (
|
||||
torch.arange(
|
||||
qweight.shape[0] * (32 // gptq_params.bits),
|
||||
device=qweight.device,
|
||||
)
|
||||
// gptq_params.groupsize
|
||||
).to(dtype=torch.int32)
|
||||
else:
|
||||
g_idx = None
|
||||
|
||||
weight = GPTQWeight(
|
||||
qweight=qweight,
|
||||
qzeros=qzeros,
|
||||
scales=scales,
|
||||
g_idx=g_idx,
|
||||
bits=gptq_params.bits,
|
||||
groupsize=gptq_params.groupsize,
|
||||
use_exllama=use_exllama,
|
||||
)
|
||||
elif quantize == "marlin":
|
||||
from text_generation_server.layers.gptq import GPTQWeight
|
||||
from text_generation_server.layers.marlin import (
|
||||
GPTQMarlin24Weight,
|
||||
MarlinWeight,
|
||||
)
|
||||
|
||||
is_marlin_24 = getattr(self, "gptq_checkpoint_format", None) == "marlin_24"
|
||||
if is_marlin_24:
|
||||
try:
|
||||
B = torch.cat(
|
||||
[self.get_sharded(f"{p}.B_24", dim=1) for p in prefixes], dim=1
|
||||
)
|
||||
except RuntimeError:
|
||||
raise RuntimeError(
|
||||
f"Cannot load `{quantize}` weight, make sure the model is already quantized"
|
||||
)
|
||||
|
||||
B_meta = torch.cat(
|
||||
[self.get_sharded(f"{p}.B_meta", dim=1) for p in prefixes], dim=1
|
||||
)
|
||||
|
||||
s = torch.cat(
|
||||
[self.get_sharded(f"{p}.s", dim=1) for p in prefixes], dim=1
|
||||
)
|
||||
|
||||
gptq_params = self._get_gptq_params()
|
||||
weight = GPTQMarlin24Weight(
|
||||
B=B, B_meta=B_meta, s=s, bits=gptq_params.bits
|
||||
)
|
||||
else:
|
||||
try:
|
||||
B = torch.cat(
|
||||
[self.get_sharded(f"{p}.B", dim=1) for p in prefixes], dim=1
|
||||
)
|
||||
except RuntimeError:
|
||||
raise RuntimeError(
|
||||
f"Cannot load `{quantize}` weight, make sure the model is already quantized"
|
||||
)
|
||||
s = torch.cat(
|
||||
[self.get_sharded(f"{p}.s", dim=1) for p in prefixes], dim=1
|
||||
)
|
||||
|
||||
weight = MarlinWeight(B=B, s=s)
|
||||
|
||||
else:
|
||||
w = [self.get_sharded(f"{p}.weight", dim=0) for p in prefixes]
|
||||
weight = torch.cat(w, dim=dim)
|
||||
|
||||
return weight
|
||||
def get_multi_weights_col(self, prefixes: List[str], dim: int):
|
||||
return self.weights_loader.get_multi_weights_col(self, prefixes, dim)
|
||||
|
||||
def get_tensor_shard(self, var, dim):
|
||||
world_size = self.process_group.size()
|
||||
|
@ -487,318 +296,8 @@ class Weights:
|
|||
tensor = tensor.to(device=self.device)
|
||||
return tensor
|
||||
|
||||
def get_multi_weights_row(self, prefix: str, quantize: str):
|
||||
if quantize == "exl2":
|
||||
from text_generation_server.layers.exl2 import Exl2Weight
|
||||
|
||||
try:
|
||||
q_weight = self.get_tensor(f"{prefix}.q_weight")
|
||||
except RuntimeError:
|
||||
raise RuntimeError(
|
||||
f"Cannot load `exl2`-quantized weight, make sure the model is already quantized."
|
||||
)
|
||||
|
||||
q_scale = self.get_tensor(f"{prefix}.q_scale")
|
||||
q_invperm = self.get_tensor(f"{prefix}.q_invperm")
|
||||
q_scale_max = self.get_tensor(f"{prefix}.q_scale_max")
|
||||
q_groups = self.get_tensor(f"{prefix}.q_groups")
|
||||
|
||||
return Exl2Weight(
|
||||
q_weight=q_weight,
|
||||
q_scale=q_scale,
|
||||
q_invperm=q_invperm,
|
||||
q_scale_max=q_scale_max,
|
||||
q_groups=q_groups,
|
||||
)
|
||||
|
||||
elif quantize == "gptq":
|
||||
from text_generation_server.layers.marlin import (
|
||||
can_use_gptq_marlin,
|
||||
repack_gptq_for_marlin,
|
||||
)
|
||||
|
||||
gptq_params = self._get_gptq_params()
|
||||
if can_use_gptq_marlin(gptq_params, quantize):
|
||||
log_once(logger.info, "Using GPTQ-Marlin kernels")
|
||||
try:
|
||||
qweight = self.get_sharded(f"{prefix}.qweight", dim=0)
|
||||
except RuntimeError:
|
||||
raise RuntimeError(
|
||||
f"Cannot load `{quantize}` weight for GPTQ -> Marlin repacking, make sure the model is already quantized"
|
||||
)
|
||||
|
||||
g_idx = self.get_sharded(f"{prefix}.g_idx", dim=0)
|
||||
if gptq_params.desc_act or gptq_params.groupsize == -1:
|
||||
scales = self.get_tensor(f"{prefix}.scales")
|
||||
else:
|
||||
scales = self.get_sharded(f"{prefix}.scales", dim=0)
|
||||
|
||||
sharded_in_features = self.process_group.size() > 1
|
||||
|
||||
return repack_gptq_for_marlin(
|
||||
qweight=qweight,
|
||||
scales=scales,
|
||||
g_idx=g_idx,
|
||||
bits=gptq_params.bits,
|
||||
desc_act=gptq_params.desc_act,
|
||||
groupsize=gptq_params.groupsize,
|
||||
sym=gptq_params.sym,
|
||||
sharded_infeatures=sharded_in_features,
|
||||
)
|
||||
|
||||
use_exllama = True
|
||||
if gptq_params.bits != 4:
|
||||
use_exllama = False
|
||||
|
||||
if gptq_params.desc_act:
|
||||
log_once(logger.warning, "Disabling exllama because desc_act=True")
|
||||
use_exllama = False
|
||||
|
||||
try:
|
||||
qweight = self.get_sharded(f"{prefix}.qweight", dim=0)
|
||||
except RuntimeError:
|
||||
raise RuntimeError(
|
||||
"Cannot load `gptq` weight, make sure the model is already quantized, or quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`"
|
||||
)
|
||||
|
||||
if gptq_params.quant_method == "gptq":
|
||||
g_idx = self.get_sharded(f"{prefix}.g_idx", dim=0)
|
||||
elif gptq_params.quant_method == "awq":
|
||||
g_idx = None
|
||||
|
||||
if self.process_group.size() > 1:
|
||||
if g_idx is not None:
|
||||
if (
|
||||
not torch.equal(
|
||||
g_idx.cpu(),
|
||||
torch.tensor(
|
||||
[
|
||||
i // gptq_params.groupsize
|
||||
for i in range(g_idx.shape[0])
|
||||
],
|
||||
dtype=torch.int32,
|
||||
),
|
||||
)
|
||||
and not (g_idx == 0).all()
|
||||
):
|
||||
# Exllama implementation does not support row tensor parallelism with act-order, as
|
||||
# it would require to reorder input activations that are split unto several GPUs
|
||||
use_exllama = False
|
||||
|
||||
from text_generation_server.layers.gptq import (
|
||||
HAS_EXLLAMA,
|
||||
CAN_EXLLAMA,
|
||||
GPTQWeight,
|
||||
)
|
||||
|
||||
if use_exllama:
|
||||
if not HAS_EXLLAMA:
|
||||
if CAN_EXLLAMA:
|
||||
log_once(
|
||||
logger.warning,
|
||||
"Exllama GPTQ cuda kernels (which are faster) could have been used, but are not currently installed, try using BUILD_EXTENSIONS=True",
|
||||
)
|
||||
use_exllama = False
|
||||
else:
|
||||
log_once(logger.info, f"Using exllama kernels v{HAS_EXLLAMA}")
|
||||
|
||||
if use_exllama and gptq_params.groupsize != -1:
|
||||
qzeros = self.get_sharded(f"{prefix}.qzeros", dim=0)
|
||||
scales = self.get_sharded(f"{prefix}.scales", dim=0)
|
||||
else:
|
||||
qzeros = self.get_tensor(f"{prefix}.qzeros")
|
||||
scales = self.get_tensor(f"{prefix}.scales")
|
||||
|
||||
if use_exllama and g_idx is not None:
|
||||
g_idx = g_idx - g_idx[0]
|
||||
|
||||
if gptq_params.quant_method == "awq":
|
||||
log_once(
|
||||
logger.info, "Converting AWQ model to Exllama/GPTQ packing format."
|
||||
)
|
||||
from text_generation_server.layers.awq.conversion_utils import (
|
||||
fast_awq_to_gptq,
|
||||
)
|
||||
|
||||
qweight, qzeros = fast_awq_to_gptq(qweight, qzeros)
|
||||
if use_exllama:
|
||||
g_idx = None
|
||||
else:
|
||||
g_idx = (
|
||||
torch.arange(
|
||||
qweight.shape[0] * (32 // gptq_params.bits),
|
||||
device=qweight.device,
|
||||
)
|
||||
// gptq_params.groupsize
|
||||
).to(dtype=torch.int32)
|
||||
|
||||
weight = GPTQWeight(
|
||||
qweight=qweight,
|
||||
qzeros=qzeros,
|
||||
scales=scales,
|
||||
g_idx=g_idx,
|
||||
bits=gptq_params.bits,
|
||||
groupsize=gptq_params.groupsize,
|
||||
use_exllama=use_exllama,
|
||||
)
|
||||
elif quantize == "awq":
|
||||
from text_generation_server.layers.gptq import GPTQWeight
|
||||
|
||||
gptq_params = self._get_gptq_params()
|
||||
|
||||
try:
|
||||
qweight = self.get_sharded(f"{prefix}.qweight", dim=0)
|
||||
except RuntimeError:
|
||||
raise RuntimeError(
|
||||
"Cannot load `awq` weight, make sure the model is already quantized"
|
||||
)
|
||||
|
||||
qzeros = self.get_sharded(f"{prefix}.qzeros", dim=0)
|
||||
scales = self.get_sharded(f"{prefix}.scales", dim=0)
|
||||
g_idx = None
|
||||
use_exllama = False
|
||||
|
||||
weight = GPTQWeight(
|
||||
qweight=qweight,
|
||||
qzeros=qzeros,
|
||||
scales=scales,
|
||||
g_idx=g_idx,
|
||||
bits=gptq_params.bits,
|
||||
groupsize=gptq_params.groupsize,
|
||||
use_exllama=use_exllama,
|
||||
)
|
||||
elif quantize == "marlin":
|
||||
from text_generation_server.layers.gptq import GPTQWeight
|
||||
from text_generation_server.layers.marlin import (
|
||||
GPTQMarlin24Weight,
|
||||
MarlinWeight,
|
||||
)
|
||||
|
||||
is_marlin_24 = getattr(self, "gptq_checkpoint_format", None) == "marlin_24"
|
||||
if is_marlin_24:
|
||||
try:
|
||||
B = self.get_sharded(f"{prefix}.B_24", dim=0)
|
||||
except RuntimeError:
|
||||
raise RuntimeError(
|
||||
"Cannot load `marlin` 2:4 sparsity weight, make sure the model is already quantized."
|
||||
)
|
||||
|
||||
B_meta = self.get_sharded(f"{prefix}.B_meta", dim=0)
|
||||
num_groups = self._get_slice(f"{prefix}.s").get_shape()[0]
|
||||
if num_groups == 1:
|
||||
# The number of groups is 1 when groupsize == -1. share
|
||||
# scales between all shards in this case.
|
||||
s = self.get_tensor(f"{prefix}.s")
|
||||
else:
|
||||
s = self.get_sharded(f"{prefix}.s", dim=0)
|
||||
|
||||
gptq_params = self._get_gptq_params()
|
||||
weight = GPTQMarlin24Weight(
|
||||
B=B, B_meta=B_meta, s=s, bits=gptq_params.bits
|
||||
)
|
||||
else:
|
||||
try:
|
||||
B = self.get_sharded(f"{prefix}.B", dim=0)
|
||||
except RuntimeError:
|
||||
raise RuntimeError(
|
||||
"Cannot load `marlin` weight, make sure the model is already quantized."
|
||||
)
|
||||
|
||||
num_groups = self._get_slice(f"{prefix}.s").get_shape()[0]
|
||||
if num_groups == 1:
|
||||
# The number of groups is 1 when groupsize == -1. share
|
||||
# scales between all shards in this case.
|
||||
s = self.get_tensor(f"{prefix}.s")
|
||||
else:
|
||||
s = self.get_sharded(f"{prefix}.s", dim=0)
|
||||
weight = MarlinWeight(B=B, s=s)
|
||||
else:
|
||||
weight = self.get_sharded(f"{prefix}.weight", dim=1)
|
||||
return weight
|
||||
|
||||
def _get_gptq_params(self) -> GPTQParams:
|
||||
try:
|
||||
bits = self.get_tensor("gptq_bits").item()
|
||||
groupsize = self.get_tensor("gptq_groupsize").item()
|
||||
checkpoint_format = getattr(self, "gptq_checkpoint_format", None)
|
||||
desc_act = False
|
||||
sym = False
|
||||
quant_method = "gptq"
|
||||
except (SafetensorError, RuntimeError) as e:
|
||||
try:
|
||||
bits = self.gptq_bits
|
||||
groupsize = self.gptq_groupsize
|
||||
checkpoint_format = getattr(self, "gptq_checkpoint_format", None)
|
||||
desc_act = getattr(self, "gptq_desc_act", False)
|
||||
quant_method = getattr(self, "quant_method", "gptq")
|
||||
sym = getattr(self, "sym", True)
|
||||
except Exception:
|
||||
raise e
|
||||
|
||||
return GPTQParams(
|
||||
bits=bits,
|
||||
checkpoint_format=checkpoint_format,
|
||||
desc_act=desc_act,
|
||||
groupsize=groupsize,
|
||||
quant_method=quant_method,
|
||||
sym=sym,
|
||||
)
|
||||
|
||||
def _set_gptq_params(self, model_id, revision):
|
||||
filename = "config.json"
|
||||
try:
|
||||
if os.path.exists(os.path.join(model_id, filename)):
|
||||
filename = os.path.join(model_id, filename)
|
||||
else:
|
||||
filename = hf_hub_download(
|
||||
model_id, filename=filename, revision=revision
|
||||
)
|
||||
with open(filename, "r") as f:
|
||||
data = json.load(f)
|
||||
self.gptq_bits = data["quantization_config"]["bits"]
|
||||
self.gptq_groupsize = data["quantization_config"]["group_size"]
|
||||
# Order is important here, desc_act is missing on some real models
|
||||
self.quant_method = data["quantization_config"]["quant_method"]
|
||||
self.gptq_checkpoint_format = data["quantization_config"].get(
|
||||
"checkpoint_format"
|
||||
)
|
||||
self.gptq_sym = data["quantization_config"]["sym"]
|
||||
self.gptq_desc_act = data["quantization_config"]["desc_act"]
|
||||
except Exception:
|
||||
filename = "quantize_config.json"
|
||||
try:
|
||||
if os.path.exists(os.path.join(model_id, filename)):
|
||||
filename = os.path.join(model_id, filename)
|
||||
else:
|
||||
filename = hf_hub_download(
|
||||
model_id, filename=filename, revision=revision
|
||||
)
|
||||
with open(filename, "r") as f:
|
||||
data = json.load(f)
|
||||
self.gptq_bits = data["bits"]
|
||||
self.gptq_groupsize = data["group_size"]
|
||||
self.gptq_sym = data["sym"]
|
||||
self.gptq_desc_act = data["desc_act"]
|
||||
if "version" in data and data["version"] == "GEMM":
|
||||
self.quant_method = "awq"
|
||||
except Exception:
|
||||
filename = "quant_config.json"
|
||||
try:
|
||||
if os.path.exists(os.path.join(model_id, filename)):
|
||||
filename = os.path.join(model_id, filename)
|
||||
else:
|
||||
filename = hf_hub_download(
|
||||
model_id, filename=filename, revision=revision
|
||||
)
|
||||
with open(filename, "r") as f:
|
||||
data = json.load(f)
|
||||
self.gptq_bits = data["w_bit"]
|
||||
self.gptq_groupsize = data["q_group_size"]
|
||||
self.gptq_desc_act = data["desc_act"]
|
||||
if "version" in data and data["version"] == "GEMM":
|
||||
self.quant_method = "awq"
|
||||
except Exception:
|
||||
pass
|
||||
def get_weights_row(self, prefix: str):
|
||||
return self.weights_loader.get_weights_row(self, prefix)
|
||||
|
||||
|
||||
def _blocks_to_block_sizes(total_size: int, blocks: Union[int, List[int]]) -> List[int]:
|
||||
|
|
Loading…
Reference in New Issue