Use FP8 KV cache when specified by compressed-tensors (#2761)
The compressed-tensors configuration can specify the configuration of the KV cache as well. Use an FP8 KV cache when the configuration tells us to do so (all other options and types are ignored for now).
This commit is contained in:
parent
289aa48554
commit
72ab60fdd5
|
@ -1,6 +1,11 @@
|
|||
# ruff: noqa: F821
|
||||
# the above line disables the `undefined-name` rule for the model type variables
|
||||
|
||||
from compressed_tensors.compressors.model_compressors.model_compressor import (
|
||||
QuantizationConfig,
|
||||
)
|
||||
from compressed_tensors.quantization import QuantizationType
|
||||
from pydantic import ValidationError
|
||||
import torch
|
||||
import enum
|
||||
import os
|
||||
|
@ -23,6 +28,7 @@ from text_generation_server.models.bloom import BloomCausalLMBatch
|
|||
from text_generation_server.models.custom_modeling.bloom_modeling import (
|
||||
BloomForCausalLM,
|
||||
)
|
||||
from text_generation_server.models.globals import ATTENTION
|
||||
from text_generation_server.models.seq2seq_lm import Seq2SeqLM
|
||||
from text_generation_server.models.galactica import GalacticaCausalLMBatch
|
||||
from text_generation_server.models.custom_modeling.neox_modeling import (
|
||||
|
@ -367,7 +373,8 @@ def get_model(
|
|||
model_type = config_dict.get("model_type", None)
|
||||
|
||||
quantization_config = config_dict.get("quantization_config", None)
|
||||
compression_config = config_dict.get("compression_config", None)
|
||||
if quantization_config is None:
|
||||
quantization_config = config_dict.get("compression_config", None)
|
||||
if quantization_config is not None and quantize is None:
|
||||
method = quantization_config.get("quant_method", None)
|
||||
if method in {"gptq", "awq", "exl2"}:
|
||||
|
@ -381,12 +388,9 @@ def get_model(
|
|||
logger.info, "Auto selecting quantization method compressed-tensors"
|
||||
)
|
||||
quantize = "compressed-tensors"
|
||||
|
||||
else:
|
||||
log_master(logger.warning, f"Unknown quantization method {method}")
|
||||
elif compression_config is not None:
|
||||
# `compression_config` renamed to `quantization_config`; support retained for backward compatibility.
|
||||
log_master(logger.info, "Auto selecting quantization method compressed-tensors")
|
||||
quantize = "compressed-tensors"
|
||||
|
||||
if dtype is None:
|
||||
if quantize in ["awq", "exl2", "gptq", "marlin"]:
|
||||
|
@ -408,7 +412,30 @@ def get_model(
|
|||
else:
|
||||
raise RuntimeError(f"Unknown dtype {dtype}")
|
||||
|
||||
compressed_tensors_config = None
|
||||
if quantize == "compressed-tensors":
|
||||
try:
|
||||
compressed_tensors_config = QuantizationConfig.model_validate(
|
||||
quantization_config
|
||||
)
|
||||
except ValidationError as e:
|
||||
raise ValueError("Cannot parse compressed-tensors configuration") from e
|
||||
|
||||
if kv_cache_dtype is None:
|
||||
kv_cache_scheme = (
|
||||
compressed_tensors_config.kv_cache_scheme
|
||||
if isinstance(compressed_tensors_config, QuantizationConfig)
|
||||
else None
|
||||
)
|
||||
if (
|
||||
kv_cache_scheme is not None
|
||||
and kv_cache_scheme.type == QuantizationType.FLOAT
|
||||
and kv_cache_scheme.num_bits == 8
|
||||
and SYSTEM == "cuda"
|
||||
and ATTENTION == "flashinfer"
|
||||
):
|
||||
kv_cache_dtype = torch.float8_e4m3fn
|
||||
else:
|
||||
kv_cache_dtype = dtype
|
||||
elif kv_cache_dtype == "fp8_e4m3fn":
|
||||
kv_cache_dtype = torch.float8_e4m3fn
|
||||
|
|
Loading…
Reference in New Issue