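"""Support for exl2 (exllamav2) quantized checkpoints.

`Exl2Weight` is the typed container for the tensors stored in an exl2
checkpoint, and `Exl2WeightsLoader` is the `WeightsLoader` implementation
that reads them.
"""
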
from dataclasses import dataclass
from typing import List, Union

import torch

from text_generation_server.utils.weights import Weight, Weights, WeightsLoader


@dataclass
class Exl2Weight(Weight):
    """
    Exllama2 exl2 quantized weights.
    """

    q_weight: torch.Tensor
    q_scale: torch.Tensor
    q_invperm: torch.Tensor
    q_scale_max: torch.Tensor
    q_groups: torch.Tensor
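
    # __post_init__ puts the raw checkpoint tensors into the form the
    # exllamav2 kernels consume: q_scale_max is divided by 256 (the exl2
    # format stores it scaled up) and q_invperm is narrowed to int16.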
    def __post_init__(self):
        self.q_scale_max /= 256
        self.q_invperm = self.q_invperm.short()

    @property
    def device(self) -> torch.device:
        return self.q_weight.device
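
    # The compute path lives in ExllamaQuantLinear
    # (text_generation_server.layers.gptq); the import is kept local to this
    # method.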
    def get_linear(self, bias: torch.Tensor):
        from text_generation_server.layers.gptq import ExllamaQuantLinear

        return ExllamaQuantLinear(self, bias)


class Exl2WeightsLoader(WeightsLoader):
    """Loader for exl2-quantized weights."""
    def get_weights_col_packed(
        self,
        weights: Weights,
        prefix: str,
        block_sizes: Union[int, List[int]],
    ):
        raise RuntimeError("Column-packed weights are not supported for exl2")

    def get_weights_col(self, weights: Weights, prefix: str):
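        # Probe q_weight first so that a non-exl2 checkpoint fails with a
        # clear message rather than an arbitrary missing-tensor error.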
        try:
            q_weight = weights.get_tensor(f"{prefix}.q_weight")
        except RuntimeError:
            raise RuntimeError(
                "Cannot load `exl2`-quantized weight, make sure the model is already quantized."
            )

        q_scale = weights.get_tensor(f"{prefix}.q_scale")
        q_invperm = weights.get_tensor(f"{prefix}.q_invperm")
        q_scale_max = weights.get_tensor(f"{prefix}.q_scale_max")
        q_groups = weights.get_tensor(f"{prefix}.q_groups")

        return Exl2Weight(
            q_weight=q_weight,
            q_scale=q_scale,
            q_invperm=q_invperm,
            q_scale_max=q_scale_max,
            q_groups=q_groups,
        )

    def get_multi_weights_col(self, weights: Weights, prefixes: List[str], dim: int):
        raise ValueError("get_multi_weights_col is not supported for exl2")

    def get_weights_row(self, weights: Weights, prefix: str):
        try:
            q_weight = weights.get_tensor(f"{prefix}.q_weight")
        except RuntimeError:
            raise RuntimeError(
                "Cannot load `exl2`-quantized weight, make sure the model is already quantized."
            )

        q_scale = weights.get_tensor(f"{prefix}.q_scale")
        q_invperm = weights.get_tensor(f"{prefix}.q_invperm")
        q_scale_max = weights.get_tensor(f"{prefix}.q_scale_max")
        q_groups = weights.get_tensor(f"{prefix}.q_groups")

        return Exl2Weight(
            q_weight=q_weight,
            q_scale=q_scale,
            q_invperm=q_invperm,
            q_scale_max=q_scale_max,
            q_groups=q_groups,
        )
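

# Usage sketch (illustrative: the `Weights` constructor arguments and the
# layer prefix below are examples, not the exact call sites in this repo):
#
#     loader = Exl2WeightsLoader()
#     weights = Weights(
#         filenames, device=device, dtype=dtype, process_group=group, weights_loader=loader
#     )
#     weight = loader.get_weights_col(weights, "model.layers.0.self_attn.q_proj")
#     linear = weight.get_linear(bias=None)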