Use FP8 KV cache when specified by compressed-tensors

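The compressed-tensors configuration can also specify the configuration
of the KV cache. When no KV cache dtype is given explicitly and the
configuration requests an 8-bit float KV cache scheme, use an FP8
(`float8_e4m3fn`) KV cache, provided we are running on CUDA with the
flashinfer attention backend; otherwise fall back to the model dtype
(all other options and types are ignored for now).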
This commit is contained in:
Daniël de Kok 2024-11-20 12:31:47 +00:00
parent 45013b60a4
commit 74a8a820ad
1 changed file with 33 additions and 6 deletions

View File

@ -1,6 +1,11 @@
# ruff: noqa: F821
# the above line disables the `undefined-name` rule for the model type variables
from compressed_tensors.compressors.model_compressors.model_compressor import (
QuantizationConfig,
)
from compressed_tensors.quantization import QuantizationType
from pydantic import ValidationError
import torch
import enum
import os
@ -23,6 +28,7 @@ from text_generation_server.models.bloom import BloomCausalLMBatch
from text_generation_server.models.custom_modeling.bloom_modeling import (
BloomForCausalLM,
)
from text_generation_server.models.globals import ATTENTION
from text_generation_server.models.seq2seq_lm import Seq2SeqLM
from text_generation_server.models.galactica import GalacticaCausalLMBatch
from text_generation_server.models.custom_modeling.neox_modeling import (
@ -367,7 +373,8 @@ def get_model(
model_type = config_dict.get("model_type", None)
quantization_config = config_dict.get("quantization_config", None)
compression_config = config_dict.get("compression_config", None)
if quantization_config is None:
quantization_config = config_dict.get("compression_config", None)
if quantization_config is not None and quantize is None:
method = quantization_config.get("quant_method", None)
if method in {"gptq", "awq", "exl2"}:
@ -381,12 +388,9 @@ def get_model(
logger.info, "Auto selecting quantization method compressed-tensors"
)
quantize = "compressed-tensors"
else:
log_master(logger.warning, f"Unknown quantization method {method}")
elif compression_config is not None:
# `compression_config` renamed to `quantization_config`; support retained for backward compatibility.
log_master(logger.info, "Auto selecting quantization method compressed-tensors")
quantize = "compressed-tensors"
if dtype is None:
if quantize in ["awq", "exl2", "gptq", "marlin"]:
@ -408,8 +412,31 @@ def get_model(
else:
raise RuntimeError(f"Unknown dtype {dtype}")
compressed_tensors_config = None
if quantize == "compressed-tensors":
try:
compressed_tensors_config = QuantizationConfig.model_validate(
quantization_config
)
except ValidationError as e:
raise ValueError("Cannot parse compressed-tensors configuration") from e
if kv_cache_dtype is None:
kv_cache_dtype = dtype
kv_cache_scheme = (
compressed_tensors_config.kv_cache_scheme
if isinstance(compressed_tensors_config, QuantizationConfig)
else None
)
if (
kv_cache_scheme is not None
and kv_cache_scheme.type == QuantizationType.FLOAT
and kv_cache_scheme.num_bits == 8
and SYSTEM == "cuda"
and ATTENTION == "flashinfer"
):
kv_cache_dtype = torch.float8_e4m3fn
else:
kv_cache_dtype = dtype
elif kv_cache_dtype == "fp8_e4m3fn":
kv_cache_dtype = torch.float8_e4m3fn
elif kv_cache_dtype == "fp8_e5m2":