diff --git a/server/text_generation_server/cli.py b/server/text_generation_server/cli.py index 14d2df01..a96bc22c 100644 --- a/server/text_generation_server/cli.py +++ b/server/text_generation_server/cli.py @@ -91,12 +91,12 @@ def serve( raise RuntimeError( "Only 1 can be set between `dtype` and `quantize`, as they both decide how goes the final model." ) - - if kv_cache_dtype in {"fp8", "fp8_e5m2"} and SYSTEM not in {"cuda", "rocm"}: - raise RuntimeError(f"{kv_cache_dtype} KV cache is only supported on Nvidia and AMD GPUs.") - - if kv_cache_dtype == "fp8_e5m2" and SYSTEM != "cuda": - raise RuntimeError(f"fp8_e5m2 KV cache is only supported on Nvidia GPUs.") + + if kv_cache_dtype in {"fp8", "fp8_e5m2"}: + if SYSTEM not in {"cuda", "rocm"}: + raise RuntimeError(f"`{kv_cache_dtype}` KV cache is only supported on Nvidia and AMD GPUs.") + if kv_cache_dtype == "fp8_e5m2" and SYSTEM != "cuda": + raise RuntimeError(f"`fp8_e5m2` KV cache is only supported on Nvidia GPUs.") server.serve( model_id, diff --git a/server/text_generation_server/layers/schema.py b/server/text_generation_server/layers/schema.py deleted file mode 100644 index ca7d81a3..00000000 --- a/server/text_generation_server/layers/schema.py +++ /dev/null @@ -1,90 +0,0 @@ -""" -This file contains the Pydantic schemas for various quantization-related -parameters. When a relevant quantization technique is specified, these -parameters are loaded in the form of a JSON alongside the model weights -and augment the model with additional information needed for use of that -technique. The format of this JSON should be specified by one or more -schemas contained here. - -For example, when the KV cache is quantized to FP8-E4M3 (currently only -possible on ROCm), the model can be optionally augmented with KV cache -scaling factors. -""" - -from typing import Dict, Optional - -from pydantic import BaseModel, ConfigDict, ValidationInfo, model_validator - - -class KVCacheQuantSchema(BaseModel): - dtype: str - # Each key is a TP rank. Each value is a dictionary mapping a TP rank's - # layer indices to their per-tensor KV cache scaling factor. - # TODO: Consider pulling this and its validation methods out into its - # own schema class (tricky as its members are variable) - scaling_factor: Dict[int, Dict[int, float]] - - @model_validator(mode="after") - def check_is_fp8(self) -> "KVCacheQuantSchema": - assert self.dtype == "float8_e4m3fn", ( - "Loaded scaling factors intended for KV cache dtype = " - f"{self.dtype} rather than float8_e4m3fn!" - ) - return self - - @model_validator(mode="after") - def check_tp_ranks(self, info: ValidationInfo) -> "KVCacheQuantSchema": - context = info.context - if context: - tp_size = context["tp_size"] - num_hidden_layers = context["num_hidden_layers"] - assert len(self.scaling_factor) == tp_size, ( - f"Loaded dictionary has TP size {len(self.scaling_factor)} " - f"but LLM engine is currently running with TP size {tp_size}." - ) - for tp_rank, layer_maps in self.scaling_factor.items(): - assert len(layer_maps) == num_hidden_layers, ( - f"KV cache scales map for TP rank {tp_rank} is malformed. " - f"Expected {num_hidden_layers} layers, got " - f"{len(layer_maps)}." - ) - for i in range(tp_size): - assert ( - i in self.scaling_factor - ), f"KV cache scales map for TP rank {i} not found." - return self - - @model_validator(mode="after") - def check_current_rank(self, info: ValidationInfo) -> "KVCacheQuantSchema": - context = info.context - if context: - tp_rank = context["tp_rank"] - num_hidden_layers = context["num_hidden_layers"] - layer_scales_map = self.scaling_factor[tp_rank] - for i in range(num_hidden_layers): - assert i in layer_scales_map, ( - f"Could not find KV cache scales for layer {i} in " - f"TP rank {tp_rank}." - ) - return self - - -class QuantParamSchema(BaseModel): - # TODO: Generalize and extend with more fields - # (e.g. weights/activations params) once functionality is enabled - model_config = ConfigDict(protected_namespaces=()) - model_type: Optional[str] - kv_cache: KVCacheQuantSchema - - @model_validator(mode="after") - def check_model_type(self, info: ValidationInfo) -> "QuantParamSchema": - context = info.context - if context: - model_type = context.get("model_type", None) - if model_type is not None: - assert model_type == self.model_type, ( - f"Model type is {model_type} but loaded " - f"scaling factors belonging to different " - f"model type {self.model_type}!" - ) - return self diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py index 4a8813a7..5eb532e8 100644 --- a/server/text_generation_server/models/__init__.py +++ b/server/text_generation_server/models/__init__.py @@ -294,7 +294,7 @@ def get_model( if model_type not in FP8_KVCACHE_SUPPORTED_MODELS and kv_cache_dtype != "auto": raise RuntimeError( - f"kv_cache_dtype is only supported for Llama models. Got model_type: {model_type}, kv_cache_dtype: {kv_cache_dtype}" + f"kv_cache_dtype is only supported for {", ".join(FP8_KVCACHE_SUPPORTED_MODELS)} models. Got model_type: {model_type}, kv_cache_dtype: {kv_cache_dtype}" ) speculator = None diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py index cc9af126..d16d3710 100644 --- a/server/text_generation_server/models/flash_causal_lm.py +++ b/server/text_generation_server/models/flash_causal_lm.py @@ -726,8 +726,6 @@ class FlashCausalLM(Model): head_size: int, dtype: torch.dtype, device: torch.device, - kv_cache_dtype: str = "auto", - quantization_param_path: Optional[str] = None, rank: int = 0, world_size: int = 1, sliding_window: Optional[int] = None, @@ -750,37 +748,6 @@ class FlashCausalLM(Model): sliding_window=sliding_window, ) - if kv_cache_dtype == "fp8": - self.kv_cache_dtype = torch.uint8 - else: - self.kv_cache_dtype = self.dtype - - if kv_cache_dtype == "fp8" and SYSTEM == "rocm": - logger.info(f"Using KV cache data type: {kv_cache_dtype}") - # Currently scaled KV cache is only enabled on ROCm - if quantization_param_path is not None: - if callable(getattr(self.model, "load_kv_cache_scales", None)): - self.model.load_kv_cache_scales(quantization_param_path) - else: - raise RuntimeError( - "Using FP8 KV cache and scaling " - "factors provided but model " - f"{self.model.__class__} does not " - "support loading scaling factors." - ) - else: - logger.info( - "Using FP8 KV cache but no scaling factors " - "provided. Defaulting to scaling factors of 1.0. " - "This may lead to less accurate results!" - ) - elif quantization_param_path is not None: - logger.info( - "KV cache scaling factors provided, " - "but the KV cache data type is not FP8. " - "KV cache scaling factors will not be used." - ) - @property def batch_type(self) -> Type[FlashCausalLMBatch]: return FlashCausalLMBatch @@ -906,7 +873,7 @@ class FlashCausalLM(Model): # Inspired by the original implementation in [vllm](https://github.com/vllm-project/vllm) # Calculate the number of blocks that can be allocated with the free memory - dtype_size = torch.tensor([], dtype=self.kv_cache_dtype).element_size() + dtype_size = torch.tensor([], dtype=self.dtype).element_size() cache_block_size = BLOCK_SIZE * self.num_kv_heads * self.head_size total_cache_size = self.num_layers * cache_block_size * 2 * dtype_size @@ -986,7 +953,7 @@ class FlashCausalLM(Model): if self.speculate is None or self.speculate + 1 <= bs: self.cuda_graph_warmup(bs, max_s, max_bt) except torch.cuda.OutOfMemoryError: - logger.exception("Decode cuda graph warmup failed") + logger.exception(f"Decode cuda graph warmup failed") else: logger.info(f"Cuda Graphs are disabled (CUDA_GRAPHS={CUDA_GRAPHS}).") diff --git a/server/text_generation_server/utils/paged_attention.py b/server/text_generation_server/utils/paged_attention.py deleted file mode 100644 index af1ff016..00000000 --- a/server/text_generation_server/utils/paged_attention.py +++ /dev/null @@ -1,142 +0,0 @@ -from loguru import logger -import torch -from text_generation_server.utils.import_utils import SYSTEM - -_PARTITION_SIZE = 512 - -if SYSTEM == "xpu": - import intel_extension_for_pytorch as ipex -else: - try: - from vllm._C import cache_ops - from vllm._C import ops - except Exception as e: - raise ImportError( - f"Could not import vllm paged attention. Make sure your installation is correct. Complete error: {e}" - ) - - -def reshape_and_cache( - key: torch.Tensor, - value: torch.Tensor, - key_cache: torch.Tensor, - value_cache: torch.Tensor, - slots: torch.Tensor, - kv_cache_dtype: str = "auto", - kv_scale: int = 1.0, -): - if SYSTEM == "xpu": - ipex.llm.modules.PagedAttention.reshape_and_cache( - key, value, key_cache, value_cache, slots - ) - else: - cache_ops.reshape_and_cache( - key, value, key_cache, value_cache, slots, kv_cache_dtype, kv_scale - ) - - -def attention( - out: torch.Tensor, - query: torch.Tensor, - key_cache: torch.Tensor, - value_cache: torch.Tensor, - kv_head_mapping: torch.Tensor, - softmax_scale: float, - block_tables: torch.Tensor, - input_lengths: torch.Tensor, - max_s: int, - kv_cache_dtype: str = "auto", - kv_scale: int = 1.0, -): - # Adapted from: https://github.com/vllm-project/vllm/blob/f8a1e39fae05ca610be8d5a78be9d40f5274e5fc/vllm/model_executor/layers/attention.py - # Copyright 2023 The vLLM team. All rights - # reserved. - # - # Licensed under the Apache License, Version 2.0 (the "License"); - # you may not use this file except in compliance with the License. - # You may obtain a copy of the License at - # - # http://www.apache.org/licenses/LICENSE-2.0 - # - # Unless required by applicable law or agreed to in writing, software - # distributed under the License is distributed on an "AS IS" BASIS, - # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - # See the License for the specific language governing permissions and - # limitations under the License. - # - - # value_cache => [num_blocks, num_heads, head_size, block_size] - block_size = value_cache.shape[3] - num_seqs, num_heads, head_size = query.shape - max_num_partitions = (max_s + _PARTITION_SIZE - 1) // _PARTITION_SIZE - if SYSTEM == "xpu": - query = query.contiguous() - return ipex.llm.modules.PagedAttention.single_query_cached_kv_attention( - out, - query, - key_cache, - value_cache, - kv_head_mapping, - softmax_scale, - block_tables, - input_lengths, - block_size, - max_s, - None, - ) - - # NOTE(woosuk): We use a simple heuristic to decide whether to use - # PagedAttention V1 or V2. If the number of partitions is 1, we use - # V1 to avoid the overhead of reduction. Also, if the number of - # sequences or heads is large, we use V1 since there is enough work - # to parallelize. - use_v1 = max_s <= 8192 and (max_num_partitions == 1 or num_seqs * num_heads > 512) - if use_v1: - ops.paged_attention_v1( - out, - query, - key_cache, - value_cache, - kv_head_mapping, - softmax_scale, - block_tables, - input_lengths, - block_size, - max_s, - None, - kv_cache_dtype, - kv_scale, - ) - else: - # Run PagedAttention V2. - assert _PARTITION_SIZE % block_size == 0 - tmp_output = torch.empty( - size=(num_seqs, num_heads, max_num_partitions, head_size), - dtype=out.dtype, - device=out.device, - ) - exp_sums = torch.empty( - size=(num_seqs, num_heads, max_num_partitions), - dtype=torch.float32, - device=out.device, - ) - max_logits = torch.empty_like(exp_sums) - - ops.paged_attention_v2( - out, - exp_sums, - max_logits, - tmp_output, - query, - key_cache, - value_cache, - kv_head_mapping, - softmax_scale, - block_tables, - input_lengths, - block_size, - max_s, - None, - kv_cache_dtype, - kv_scale, - ) diff --git a/server/text_generation_server/utils/weights_utils.py b/server/text_generation_server/utils/weights_utils.py deleted file mode 100644 index f96d1a6d..00000000 --- a/server/text_generation_server/utils/weights_utils.py +++ /dev/null @@ -1,48 +0,0 @@ -from typing import Optional, Tuple, Iterable -from loguru import logger -import json -from text_generation_server.layers.schema import QuantParamSchema - - -def kv_cache_scales_loader( - filename: str, - tp_rank: int, - tp_size: int, - num_hidden_layers: int, - model_type: Optional[str], -) -> Iterable[Tuple[int, float]]: - """ - A simple utility to read in KV cache scaling factors that have been - previously serialized to disk. Used by the model to populate the appropriate - KV cache scaling factors. The serialization should represent a dictionary - whose keys are the TP ranks and values are another dictionary mapping layers - to their KV cache scaling factors. - """ - try: - with open(filename) as f: - context = { - "model_type": model_type, - "num_hidden_layers": num_hidden_layers, - "tp_rank": tp_rank, - "tp_size": tp_size, - } - schema_dct = json.load(f) - schema = QuantParamSchema.model_validate(schema_dct, context=context) - layer_scales_map = schema.kv_cache.scaling_factor[tp_rank] - return layer_scales_map.items() - - except FileNotFoundError: - logger.error(f"File or directory '{filename}' not found.") - except json.JSONDecodeError: - logger.error(f"Error decoding JSON in file '{filename}'.") - except Exception as e: - logger.error(f"An error occurred while reading '{filename}': {e}") - # This section is reached if and only if any of the excepts are hit - # Return an empty iterable (list) => no KV cache scales are loaded - # which ultimately defaults to 1.0 scales - logger.warning( - "Defaulting to KV cache scaling factors = 1.0 " - f"for all layers in TP rank {tp_rank} " - "as an error occurred during loading." - ) - return []