2024-07-09 12:04:03 -06:00
|
|
|
import json
|
Improve the handling of quantized weights (#2250)
* Improve the handling of quantized weights
Handling of quantized weights was split between two mechanisms:
- For quantized checkpoints, we used the new weight loader
infrastructure.
- For quantization while loading (EETQ, FP8, bitsandbytes) we
instead relied on conditionals in `get_linear`.
Weight loaders support context managers to selectively load
particular layers with different weight loaders, which is useful
for models like Idefics2 AWQ, which uses a quantized text model,
but unquantized vision and connector models. However, the context
manager would be overridden by `get_linear`, which string-checks
`quantizer`. Also, the context manager would not work with
EETQ, FP8, and bitsandbytes.
This change migrates all quantizers to the weight loader infrastructure.
This has several benefits:
- We can use context managers with all quantizers.
- All the implementation details move down to the quantizer layers,
`get_linear` does not need to know how to handle quantized linear
layers.
- All quantizer weights are strongly typed, we don't pass around
raw tensors.
- We don't have to pass around the `quantizer` string everywhere.
* Exclude non-MLP layers when using FP8 quantization with Llama
2024-07-19 01:37:39 -06:00
|
|
|
import os
|
2024-07-09 12:04:03 -06:00
|
|
|
from dataclasses import dataclass
|
Improve the handling of quantized weights (#2250)
* Improve the handling of quantized weights
Handling of quantized weights was split between two mechanisms:
- For quantized checkpoints, we used the new weight loader
infrastructure.
- For quantization while loading (EETQ, FP8, bitsandbytes) we
instead relied on conditionals in `get_linear`.
Weight loaders support context managers to selectively load
particular layers with different weight loaders, which is useful
for models like Idefics2 AWQ, which uses a quantized text model,
but unquantized vision and connector models. However, the context
manager would be overridden by `get_linear`, which string-checks
`quantizer`. Also, the context manager would not work with
EETQ, FP8, and bitsandbytes.
This change migrates all quantizers to the weight loader infrastructure.
This has several benefits:
- We can use context managers with all quantizers.
- All the implementation details move down to the quantizer layers,
`get_linear` does not need to know how to handle quantized linear
layers.
- All quantizer weights are strongly typed, we don't pass around
raw tensors.
- We don't have to pass around the `quantizer` string everywhere.
* Exclude non-MLP layers when using FP8 quantization with Llama
2024-07-19 01:37:39 -06:00
|
|
|
from typing import Optional
|
2024-07-09 12:04:03 -06:00
|
|
|
|
|
|
|
from huggingface_hub import hf_hub_download
|
2024-07-31 05:08:41 -06:00
|
|
|
from text_generation_server.layers.marlin.gptq import can_use_gptq_marlin
|
Improve the handling of quantized weights (#2250)
* Improve the handling of quantized weights
Handling of quantized weights was split between two mechanisms:
- For quantized checkpoints, we used the new weight loader
infrastructure.
- For quantization while loading (EETQ, FP8, bitsandbytes) we
instead relied on conditionals in `get_linear`.
Weight loaders support context managers to selectively load
particular layers with different weight loaders, which is useful
for models like Idefics2 AWQ, which uses a quantized text model,
but unquantized vision and connector models. However, the context
manager would be overridden by `get_linear`, which string-checks
`quantizer`. Also, the context manager would not work with
EETQ, FP8, and bitsandbytes.
This change migrates all quantizers to the weight loader infrastructure.
This has several benefits:
- We can use context managers with all quantizers.
- All the implementation details move down to the quantizer layers,
`get_linear` does not need to know how to handle quantized linear
layers.
- All quantizer weights are strongly typed, we don't pass around
raw tensors.
- We don't have to pass around the `quantizer` string everywhere.
* Exclude non-MLP layers when using FP8 quantization with Llama
2024-07-19 01:37:39 -06:00
|
|
|
from text_generation_server.utils.weights import (
|
|
|
|
DefaultWeightsLoader,
|
|
|
|
WeightsLoader,
|
|
|
|
)
|
2024-07-09 12:04:03 -06:00
|
|
|
|
|
|
|
|
2024-07-20 11:02:04 -06:00
|
|
|
# TODO: Split this config to have a single config type per quant method
|
2024-07-09 12:04:03 -06:00
|
|
|
@dataclass
class _QuantizerConfig:
    """Quantization parameters detected for a GPTQ/AWQ/Marlin-style checkpoint.

    Instances are produced by ``_get_quantizer_config`` from ``config.json``,
    ``quantize_config.json`` or ``quant_config.json``. The field order below
    is also the positional argument order of the generated ``__init__``, so
    it must not be changed.
    """

    # Bit width of the quantized weights (read from ``bits`` or ``w_bit``).
    bits: int
    # ``checkpoint_format`` of the quantization config, or None when absent;
    # ``get_loader`` compares it against ``"marlin_24"``.
    checkpoint_format: Optional[str]
    # ``desc_act`` flag of the quantization config (False when not found).
    desc_act: bool
    # Group size (read from ``group_size`` or ``q_group_size``; -1 when not found).
    groupsize: int
    # Detected method name, e.g. ``"gptq"`` (the fallback) or ``"awq"``.
    quant_method: str
    # Symmetric quantization flag (``sym``, or ``not zero_point`` for AWQ).
    sym: bool
|
|
|
|
|
|
|
|
|
2024-07-20 11:02:04 -06:00
|
|
|
@dataclass
class _FP8QuantizerConfig:
    """Quantization parameters of an ``fbgemm_fp8`` checkpoint.

    Produced by ``_get_quantizer_config`` from the ``quantization_config``
    section of ``config.json`` when its ``quant_method`` is ``fbgemm_fp8``.
    """

    # ``activation_scale_ub`` value copied from the checkpoint's quantization
    # config; presumably an upper bound on activation scales — confirm in the
    # FP8 layer implementation.
    activation_scale_ub: float
|
|
|
|
|
|
|
|
|
2024-07-09 12:04:03 -06:00
|
|
|
# We should probably do this with Pydantic JSON deserialization,
# but for now we'll stay close to the old _set_gptq_params.
def _get_quantizer_config(model_id, revision):
    """Detect the quantization parameters of a checkpoint.

    The parameters are looked up, in order, in:

    1. the ``quantization_config`` section of ``config.json``,
    2. ``quantize_config.json`` (keys ``bits`` / ``group_size``),
    3. ``quant_config.json`` (keys ``w_bit`` / ``q_group_size``).

    Each file is read from the local directory ``model_id`` when it exists
    there, otherwise it is downloaded from the Hugging Face Hub at
    ``revision``. Any exception while reading a file (missing file, download
    failure, missing key) makes the lookup fall through to the next file.
    Values already assigned from an earlier file are kept, so the order of
    the assignments below matters.

    Args:
        model_id: Local model directory or Hub repository id.
        revision: Hub revision used when a file has to be downloaded.

    Returns:
        A ``_FP8QuantizerConfig`` when ``config.json`` declares
        ``fbgemm_fp8``, otherwise a ``_QuantizerConfig``. The latter holds
        GPTQ defaults (4 bits, group size -1) when nothing was found.
    """
    bits = 4
    groupsize = -1
    quant_method = "gptq"
    checkpoint_format = None
    sym = False
    desc_act = False

    try:
        data = _load_json_config(model_id, "config.json", revision)
        quantization_config = data["quantization_config"]

        # FP8 config
        if quantization_config["quant_method"] == "fbgemm_fp8":
            return _FP8QuantizerConfig(
                activation_scale_ub=quantization_config["activation_scale_ub"]
            )

        if "zero_point" in quantization_config:
            sym = not quantization_config["zero_point"]
            quant_method = "awq"
        elif "sym" in quantization_config:
            sym = quantization_config["sym"]

        bits = quantization_config["bits"]
        groupsize = quantization_config["group_size"]
        # Order is important here: desc_act is missing on some real models.
        # It is read last so that the resulting KeyError falls through to
        # quantize_config.json while keeping the values parsed so far.
        quant_method = quantization_config["quant_method"]
        checkpoint_format = quantization_config.get("checkpoint_format")
        desc_act = quantization_config["desc_act"]
    except Exception:
        try:
            data = _load_json_config(model_id, "quantize_config.json", revision)
            bits = data["bits"]
            groupsize = data["group_size"]

            if "zero_point" in data:
                sym = not data["zero_point"]
                quant_method = "awq"
            elif "sym" in data:
                sym = data["sym"]

            # Detect AWQ before reading the optional desc_act key, so that a
            # missing desc_act cannot skip the detection.
            if data.get("version") == "GEMM":
                quant_method = "awq"
            desc_act = data["desc_act"]
        except Exception:
            try:
                data = _load_json_config(model_id, "quant_config.json", revision)
                bits = data["w_bit"]
                groupsize = data["q_group_size"]

                # Same zero_point handling as the two files above.
                if "zero_point" in data:
                    sym = not data["zero_point"]
                    quant_method = "awq"

                # Detect AWQ before reading the optional desc_act key.
                if data.get("version") == "GEMM":
                    quant_method = "awq"
                desc_act = data["desc_act"]
            except Exception:
                # Best effort: keep whatever was detected so far.
                pass

    return _QuantizerConfig(
        bits=bits,
        groupsize=groupsize,
        quant_method=quant_method,
        checkpoint_format=checkpoint_format,
        sym=sym,
        desc_act=desc_act,
    )


def _load_json_config(model_id, filename, revision):
    """Load and parse ``filename`` from a local model dir or the Hub.

    Args:
        model_id: Local model directory or Hub repository id.
        filename: Name of the JSON file inside the model repository.
        revision: Hub revision used when the file is not present locally.

    Returns:
        The parsed JSON document.

    Raises:
        Whatever ``hf_hub_download``, ``open`` or ``json.load`` raise; the
        caller treats any exception as "file not usable".
    """
    path = os.path.join(model_id, filename)
    if not os.path.exists(path):
        path = hf_hub_download(model_id, filename=filename, revision=revision)
    # JSON config files are UTF-8; don't depend on the process locale.
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)
|
|
|
|
|
|
|
|
|
|
|
|
def _require_gptq_style_config(quantize, model_id, revision):
    """Load the checkpoint quantizer config and require a ``_QuantizerConfig``.

    Used by the quantizers that need GPTQ-style parameters (AWQ, GPTQ, Marlin).

    Raises:
        ValueError: if the detected config is of another type (e.g. an FP8
            config) than ``_QuantizerConfig``.
    """
    quantizer_config = _get_quantizer_config(model_id, revision)
    # TODO: improve check once we have one config type per quantize value
    if not isinstance(quantizer_config, _QuantizerConfig):
        raise ValueError(
            f"Quantize is set to `{quantize}` but received a `{quantizer_config.__class__.__name__}` config."
        )
    return quantizer_config


def get_loader(
    quantize: Optional[str], model_id: str, revision: Optional[str]
) -> WeightsLoader:
    """Return the weights loader for the requested quantization method.

    The checkpoint's quantizer config is only read (possibly downloading
    config files from the Hub) for the methods that use it: AWQ/GPTQ,
    Marlin, FP8 and ``None``.

    Args:
        quantize: One of ``"awq"``, ``"gptq"``, ``"bitsandbytes"``,
            ``"bitsandbytes-fp4"``, ``"bitsandbytes-nf4"``, ``"eetq"``,
            ``"exl2"``, ``"marlin"``, ``"fp8"``, or ``None`` (handled by
            ``HybridFP8UnquantLoader`` with ``to_fp8=False``).
        model_id: Local model directory or Hub repository id.
        revision: Hub revision used when config files must be downloaded.

    Returns:
        A ``WeightsLoader`` for the selected method.

    Raises:
        ValueError: if ``quantize`` is unknown, or if AWQ/GPTQ/Marlin is
            requested but the checkpoint config is not a ``_QuantizerConfig``.
    """
    if quantize in {"awq", "gptq"}:
        from text_generation_server.layers.gptq import GPTQWeightsLoader

        quantizer_config = _require_gptq_style_config(quantize, model_id, revision)
        # The Marlin and plain GPTQ loaders take the same arguments.
        loader_kwargs = dict(
            bits=quantizer_config.bits,
            desc_act=quantizer_config.desc_act,
            groupsize=quantizer_config.groupsize,
            quant_method=quantizer_config.quant_method,
            quantize=quantize,
            sym=quantizer_config.sym,
        )
        if can_use_gptq_marlin(
            bits=quantizer_config.bits,
            groupsize=quantizer_config.groupsize,
            quant_method=quantizer_config.quant_method,
            quantize=quantize,
            sym=quantizer_config.sym,
        ):
            from text_generation_server.layers.marlin import GPTQMarlinWeightsLoader

            return GPTQMarlinWeightsLoader(**loader_kwargs)
        return GPTQWeightsLoader(**loader_kwargs)

    elif quantize == "bitsandbytes":
        from text_generation_server.layers.bnb import BNBWeight

        return DefaultWeightsLoader(BNBWeight)

    elif quantize == "bitsandbytes-fp4":
        from text_generation_server.layers.bnb import BNBFP4Weight

        return DefaultWeightsLoader(BNBFP4Weight)

    elif quantize == "bitsandbytes-nf4":
        from text_generation_server.layers.bnb import BNBNF4Weight

        return DefaultWeightsLoader(BNBNF4Weight)

    elif quantize == "eetq":
        from text_generation_server.layers.eetq import EETQWeight

        return DefaultWeightsLoader(EETQWeight)

    elif quantize == "exl2":
        from text_generation_server.layers.exl2 import Exl2WeightsLoader

        return Exl2WeightsLoader()

    elif quantize == "marlin":
        from text_generation_server.layers.marlin import MarlinWeightsLoader

        quantizer_config = _require_gptq_style_config(quantize, model_id, revision)
        return MarlinWeightsLoader(
            bits=quantizer_config.bits,
            is_marlin_24=quantizer_config.checkpoint_format == "marlin_24",
        )

    elif quantize == "fp8" or quantize is None:
        from text_generation_server.layers.fp8 import HybridFP8UnquantLoader

        quantizer_config = _get_quantizer_config(model_id, revision)
        # Since the default for the quantize config is _QuantizerConfig,
        # only read activation_scale_ub when an FP8 config was detected;
        # otherwise there is no such attribute.
        activation_scale_ub = None
        if isinstance(quantizer_config, _FP8QuantizerConfig):
            activation_scale_ub = quantizer_config.activation_scale_ub

        return HybridFP8UnquantLoader(activation_scale_ub, to_fp8=quantize == "fp8")

    else:
        raise ValueError(f"Unknown quantization method: {quantize}")
|