Improve the handling of quantized weights (#2250)

* Improve the handling of quantized weights

Handling of quantized weights was split between two mechanisms:

- For quantized checkpoints, we used the new weight loader
  infrastructure.
- For quantization while loading (EETQ, FP8, bitsandbytes) we
  instead relied on conditionals in `get_linear`.

Weight loaders support context managers to selectively load
particular layers with different weight loaders, which is useful
for models like Idefics2 AWQ, which uses a quantized text model
but unquantized vision and connector models (see the sketch
below). However, the context manager would be overridden by
`get_linear`, which string-checks `quantizer`. Moreover, the
context manager did not work with EETQ, FP8, and bitsandbytes.

This change migrates all quantizers to the weight loader
infrastructure. This has several benefits:

- We can use context managers with all quantizers.
- All the implementation details move down to the quantizer
  layers; `get_linear` does not need to know how to handle
  quantized linear layers.
- All quantizer weights are strongly typed; we don't pass around
  raw tensors.
- We don't have to pass around the `quantizer` string everywhere.

* Exclude non-MLP layers when using FP8 quantization with Llama
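The context-manager pattern described above can be illustrated with a small, self-contained sketch. The names used here (`Weights.use_loader`, `DefaultLoader`, `AWQLoader`) are stand-ins invented for the example, not the exact API introduced by this change; the point is only that the active loader can be swapped for a scoped block, e.g. to load Idefics2's vision and connector weights unquantized while the text model stays AWQ-quantized.

```python
from contextlib import contextmanager


class DefaultLoader:
    """Stand-in for a loader that returns plain, unquantized weights."""

    def get_weights(self, prefix: str) -> str:
        return f"unquantized weights for {prefix}"


class AWQLoader:
    """Stand-in for a loader that returns AWQ-quantized weights."""

    def get_weights(self, prefix: str) -> str:
        return f"AWQ weights for {prefix}"


class Weights:
    """Holds the currently active loader and lets callers override it."""

    def __init__(self, loader):
        self.loader = loader

    @contextmanager
    def use_loader(self, loader):
        # Temporarily swap the active loader and restore it afterwards.
        previous, self.loader = self.loader, loader
        try:
            yield
        finally:
            self.loader = previous

    def get_weights(self, prefix: str) -> str:
        return self.loader.get_weights(prefix)


weights = Weights(AWQLoader())
print(weights.get_weights("text_model.layers.0.self_attn.q_proj"))

# Vision and connector weights are loaded unquantized, scoped by the context.
with weights.use_loader(DefaultLoader()):
    print(weights.get_weights("vision_model.encoder.layers.0.fc1"))

# Outside the block the quantized loader is active again.
print(weights.get_weights("text_model.layers.1.mlp.gate_proj"))
```

Because the override is scoped by the `with` block, each sub-model can pick its own loader without threading the `quantizer` string through every call.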
from dataclasses import dataclass

import torch
from EETQ import quant_weights, w8_a16_gemm

from text_generation_server.utils.weights import UnquantizedWeight


@dataclass
class EETQWeight(UnquantizedWeight):
    weight: torch.Tensor

    def get_linear(self, bias: torch.Tensor):
        try:
            from text_generation_server.layers.eetq import EETQLinear

            return EETQLinear(self.weight, bias)
        except ImportError:
            # Point the user at the install instructions when the EETQ
            # kernels cannot be imported.
            raise ImportError(
                "Please install EETQ from https://github.com/NetEase-FuXi/EETQ"
            )


class EETQLinear(torch.nn.Module):
    def __init__(
        self,
        weight,
        bias,
    ) -> None:
        super().__init__()
        device = weight.device
        # EETQ quantizes from float16, so cast the weight if necessary.
        if weight.dtype != torch.float16:
            weight = weight.to(dtype=torch.float16)
        # Quantize on the CPU from the transposed float16 weight;
        # quant_weights returns the int8 weight and its scales.
        weight = torch.t(weight).contiguous().cpu()
        weight, scale = quant_weights(weight, torch.int8, False)

        self.weight = weight.cuda(device)
        self.scale = scale.cuda(device)
        self.bias = bias.cuda(device) if bias is not None else None

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        # int8-weight x 16-bit-activation GEMM provided by EETQ.
        output = w8_a16_gemm(input, self.weight, self.scale)
        output = output + self.bias if self.bias is not None else output
        return output
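

# Usage sketch: a minimal smoke test for the classes above. It assumes a
# CUDA device is available, the EETQ kernels are installed, and the
# text_generation_server package is importable; the shapes are arbitrary
# and only serve the illustration.
if __name__ == "__main__":
    out_features, in_features = 128, 256
    weight = torch.randn(
        out_features, in_features, dtype=torch.float16, device="cuda"
    )
    bias = torch.zeros(out_features, dtype=torch.float16, device="cuda")

    # EETQWeight wraps the raw checkpoint tensor; get_linear quantizes it
    # to int8 and returns a ready-to-use EETQLinear module.
    linear = EETQWeight(weight).get_linear(bias)

    x = torch.randn(4, in_features, dtype=torch.float16, device="cuda")
    y = linear(x)
    print(y.shape)  # expected: torch.Size([4, 128])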