diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py
index 89164577..fcc79608 100644
--- a/server/text_generation_server/models/__init__.py
+++ b/server/text_generation_server/models/__init__.py
@@ -1,6 +1,11 @@
 # ruff: noqa: F821
 # the above line disables the `undefined-name` rule for the model type variables
+from compressed_tensors.compressors.model_compressors.model_compressor import (
+    QuantizationConfig,
+)
+from compressed_tensors.quantization import QuantizationType
+from pydantic import ValidationError
 import torch
 import enum
 import os
@@ -23,6 +28,7 @@ from text_generation_server.models.bloom import BloomCausalLMBatch
 from text_generation_server.models.custom_modeling.bloom_modeling import (
     BloomForCausalLM,
 )
+from text_generation_server.models.globals import ATTENTION
 from text_generation_server.models.seq2seq_lm import Seq2SeqLM
 from text_generation_server.models.galactica import GalacticaCausalLMBatch
 from text_generation_server.models.custom_modeling.neox_modeling import (
@@ -367,7 +373,8 @@ def get_model(
     model_type = config_dict.get("model_type", None)
 
     quantization_config = config_dict.get("quantization_config", None)
-    compression_config = config_dict.get("compression_config", None)
+    if quantization_config is None:
+        quantization_config = config_dict.get("compression_config", None)
     if quantization_config is not None and quantize is None:
         method = quantization_config.get("quant_method", None)
         if method in {"gptq", "awq", "exl2"}:
@@ -381,12 +388,9 @@ def get_model(
                 logger.info, "Auto selecting quantization method compressed-tensors"
             )
             quantize = "compressed-tensors"
+        else:
             log_master(logger.warning, f"Unknown quantization method {method}")
-    elif compression_config is not None:
-        # `compression_config` renamed to `quantization_config`; support retained for backward compatibility.
-        log_master(logger.info, "Auto selecting quantization method compressed-tensors")
-        quantize = "compressed-tensors"
 
     if dtype is None:
         if quantize in ["awq", "exl2", "gptq", "marlin"]:
@@ -408,8 +412,31 @@ def get_model(
     else:
         raise RuntimeError(f"Unknown dtype {dtype}")
 
+    compressed_tensors_config = None
+    if quantize == "compressed-tensors":
+        try:
+            compressed_tensors_config = QuantizationConfig.model_validate(
+                quantization_config
+            )
+        except ValidationError as e:
+            raise ValueError("Cannot parse compressed-tensors configuration") from e
+
     if kv_cache_dtype is None:
-        kv_cache_dtype = dtype
+        kv_cache_scheme = (
+            compressed_tensors_config.kv_cache_scheme
+            if isinstance(compressed_tensors_config, QuantizationConfig)
+            else None
+        )
+        if (
+            kv_cache_scheme is not None
+            and kv_cache_scheme.type == QuantizationType.FLOAT
+            and kv_cache_scheme.num_bits == 8
+            and SYSTEM == "cuda"
+            and ATTENTION == "flashinfer"
+        ):
+            kv_cache_dtype = torch.float8_e4m3fn
+        else:
+            kv_cache_dtype = dtype
     elif kv_cache_dtype == "fp8_e4m3fn":
         kv_cache_dtype = torch.float8_e4m3fn
     elif kv_cache_dtype == "fp8_e5m2":
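
For readers of this patch, the short sketch below (not part of the diff) illustrates the shape of `quantization_config` the new code path consumes and the `kv_cache_scheme` fields the fp8 auto-detection checks. The example dict and its values are assumptions modeled on a compressed-tensors checkpoint with an 8-bit float kv-cache scheme; only `QuantizationConfig`, `QuantizationType`, and their import paths are taken from the diff above.

# Illustrative sketch; the example config is hypothetical, not from the patch.
from compressed_tensors.compressors.model_compressors.model_compressor import (
    QuantizationConfig,
)
from compressed_tensors.quantization import QuantizationType

# Assumed "quantization_config" as it might appear in a checkpoint's config.json;
# config_groups is left empty purely for brevity.
example_quantization_config = {
    "quant_method": "compressed-tensors",
    "config_groups": {},
    "kv_cache_scheme": {"type": "float", "num_bits": 8},
}

config = QuantizationConfig.model_validate(example_quantization_config)
scheme = config.kv_cache_scheme

# Mirrors the condition added in get_model(): an 8-bit float kv_cache_scheme
# maps to torch.float8_e4m3fn (the patch additionally requires SYSTEM == "cuda"
# and ATTENTION == "flashinfer"); anything else falls back to the model dtype.
assert scheme is not None
assert scheme.type == QuantizationType.FLOAT and scheme.num_bits == 8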