fix
This commit is contained in:
parent
81fd601c44
commit
fb83e3416b
|
@ -92,11 +92,11 @@ def serve(
|
||||||
"Only 1 can be set between `dtype` and `quantize`, as they both decide how goes the final model."
|
"Only 1 can be set between `dtype` and `quantize`, as they both decide how goes the final model."
|
||||||
)
|
)
|
||||||
|
|
||||||
if kv_cache_dtype in {"fp8", "fp8_e5m2"} and SYSTEM not in {"cuda", "rocm"}:
|
if kv_cache_dtype in {"fp8", "fp8_e5m2"}:
|
||||||
raise RuntimeError(f"{kv_cache_dtype} KV cache is only supported on Nvidia and AMD GPUs.")
|
if SYSTEM not in {"cuda", "rocm"}:
|
||||||
|
raise RuntimeError(f"`{kv_cache_dtype}` KV cache is only supported on Nvidia and AMD GPUs.")
|
||||||
if kv_cache_dtype == "fp8_e5m2" and SYSTEM != "cuda":
|
if kv_cache_dtype == "fp8_e5m2" and SYSTEM != "cuda":
|
||||||
raise RuntimeError(f"fp8_e5m2 KV cache is only supported on Nvidia GPUs.")
|
raise RuntimeError(f"`fp8_e5m2` KV cache is only supported on Nvidia GPUs.")
|
||||||
|
|
||||||
server.serve(
|
server.serve(
|
||||||
model_id,
|
model_id,
|
||||||
|
|
|
@ -1,90 +0,0 @@
|
||||||
"""
|
|
||||||
This file contains the Pydantic schemas for various quantization-related
|
|
||||||
parameters. When a relevant quantization technique is specified, these
|
|
||||||
parameters are loaded in the form of a JSON alongside the model weights
|
|
||||||
and augment the model with additional information needed for use of that
|
|
||||||
technique. The format of this JSON should be specified by one or more
|
|
||||||
schemas contained here.
|
|
||||||
|
|
||||||
For example, when the KV cache is quantized to FP8-E4M3 (currently only
|
|
||||||
possible on ROCm), the model can be optionally augmented with KV cache
|
|
||||||
scaling factors.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from typing import Dict, Optional
|
|
||||||
|
|
||||||
from pydantic import BaseModel, ConfigDict, ValidationInfo, model_validator
|
|
||||||
|
|
||||||
|
|
||||||
class KVCacheQuantSchema(BaseModel):
    """Schema for the ``kv_cache`` entry of a quantization-parameter JSON.

    The rank/layer validators only run their checks when a validation
    ``context`` is supplied; they read ``tp_size``, ``tp_rank`` and
    ``num_hidden_layers`` from it.

    Checks raise ``ValueError`` instead of using ``assert`` so they are not
    stripped when Python runs with ``-O``.
    """

    dtype: str
    # Each key is a TP rank. Each value is a dictionary mapping a TP rank's
    # layer indices to their per-tensor KV cache scaling factor.
    # TODO: Consider pulling this and its validation methods out into its
    # own schema class (tricky as its members are variable)
    scaling_factor: Dict[int, Dict[int, float]]

    @model_validator(mode="after")
    def check_is_fp8(self) -> "KVCacheQuantSchema":
        """Reject scaling factors that were not produced for ``float8_e4m3fn``."""
        if self.dtype != "float8_e4m3fn":
            raise ValueError(
                "Loaded scaling factors intended for KV cache dtype = "
                f"{self.dtype} rather than float8_e4m3fn!"
            )
        return self

    @model_validator(mode="after")
    def check_tp_ranks(self, info: ValidationInfo) -> "KVCacheQuantSchema":
        """Check that there is one complete layer map per tensor-parallel rank."""
        context = info.context
        if context:
            tp_size = context["tp_size"]
            num_hidden_layers = context["num_hidden_layers"]
            if len(self.scaling_factor) != tp_size:
                raise ValueError(
                    f"Loaded dictionary has TP size {len(self.scaling_factor)} "
                    f"but LLM engine is currently running with TP size {tp_size}."
                )
            for tp_rank, layer_maps in self.scaling_factor.items():
                if len(layer_maps) != num_hidden_layers:
                    raise ValueError(
                        f"KV cache scales map for TP rank {tp_rank} is malformed. "
                        f"Expected {num_hidden_layers} layers, got "
                        f"{len(layer_maps)}."
                    )
            # The size check above does not guarantee the keys are exactly
            # 0..tp_size-1, so verify each expected rank is present.
            for i in range(tp_size):
                if i not in self.scaling_factor:
                    raise ValueError(
                        f"KV cache scales map for TP rank {i} not found."
                    )
        return self

    @model_validator(mode="after")
    def check_current_rank(self, info: ValidationInfo) -> "KVCacheQuantSchema":
        """Check that the current rank has a scale for every layer index."""
        context = info.context
        if context:
            tp_rank = context["tp_rank"]
            num_hidden_layers = context["num_hidden_layers"]
            layer_scales_map = self.scaling_factor[tp_rank]
            for i in range(num_hidden_layers):
                if i not in layer_scales_map:
                    raise ValueError(
                        f"Could not find KV cache scales for layer {i} in "
                        f"TP rank {tp_rank}."
                    )
        return self
|
|
||||||
|
|
||||||
|
|
||||||
class QuantParamSchema(BaseModel):
    """Top-level schema of the quantization-parameter JSON.

    When the validation ``context`` carries a non-``None`` ``model_type``, it
    must match the ``model_type`` stored in the file. The check raises
    ``ValueError`` instead of using ``assert`` so it is not stripped when
    Python runs with ``-O``.
    """

    # TODO: Generalize and extend with more fields
    # (e.g. weights/activations params) once functionality is enabled
    model_config = ConfigDict(protected_namespaces=())
    model_type: Optional[str]
    kv_cache: KVCacheQuantSchema

    @model_validator(mode="after")
    def check_model_type(self, info: ValidationInfo) -> "QuantParamSchema":
        """Reject scaling factors that were saved for a different model type."""
        context = info.context
        if context:
            model_type = context.get("model_type", None)
            if model_type is not None and model_type != self.model_type:
                raise ValueError(
                    f"Model type is {model_type} but loaded "
                    f"scaling factors belonging to different "
                    f"model type {self.model_type}!"
                )
        return self
|
|
|
@ -294,7 +294,7 @@ def get_model(
|
||||||
|
|
||||||
if model_type not in FP8_KVCACHE_SUPPORTED_MODELS and kv_cache_dtype != "auto":
|
if model_type not in FP8_KVCACHE_SUPPORTED_MODELS and kv_cache_dtype != "auto":
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
f"kv_cache_dtype is only supported for Llama models. Got model_type: {model_type}, kv_cache_dtype: {kv_cache_dtype}"
|
f"kv_cache_dtype is only supported for {', '.join(FP8_KVCACHE_SUPPORTED_MODELS)} models. Got model_type: {model_type}, kv_cache_dtype: {kv_cache_dtype}"
|
||||||
)
|
)
|
||||||
|
|
||||||
speculator = None
|
speculator = None
|
||||||
|
|
|
@ -726,8 +726,6 @@ class FlashCausalLM(Model):
|
||||||
head_size: int,
|
head_size: int,
|
||||||
dtype: torch.dtype,
|
dtype: torch.dtype,
|
||||||
device: torch.device,
|
device: torch.device,
|
||||||
kv_cache_dtype: str = "auto",
|
|
||||||
quantization_param_path: Optional[str] = None,
|
|
||||||
rank: int = 0,
|
rank: int = 0,
|
||||||
world_size: int = 1,
|
world_size: int = 1,
|
||||||
sliding_window: Optional[int] = None,
|
sliding_window: Optional[int] = None,
|
||||||
|
@ -750,37 +748,6 @@ class FlashCausalLM(Model):
|
||||||
sliding_window=sliding_window,
|
sliding_window=sliding_window,
|
||||||
)
|
)
|
||||||
|
|
||||||
if kv_cache_dtype == "fp8":
|
|
||||||
self.kv_cache_dtype = torch.uint8
|
|
||||||
else:
|
|
||||||
self.kv_cache_dtype = self.dtype
|
|
||||||
|
|
||||||
if kv_cache_dtype == "fp8" and SYSTEM == "rocm":
|
|
||||||
logger.info(f"Using KV cache data type: {kv_cache_dtype}")
|
|
||||||
# Currently scaled KV cache is only enabled on ROCm
|
|
||||||
if quantization_param_path is not None:
|
|
||||||
if callable(getattr(self.model, "load_kv_cache_scales", None)):
|
|
||||||
self.model.load_kv_cache_scales(quantization_param_path)
|
|
||||||
else:
|
|
||||||
raise RuntimeError(
|
|
||||||
"Using FP8 KV cache and scaling "
|
|
||||||
"factors provided but model "
|
|
||||||
f"{self.model.__class__} does not "
|
|
||||||
"support loading scaling factors."
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
logger.info(
|
|
||||||
"Using FP8 KV cache but no scaling factors "
|
|
||||||
"provided. Defaulting to scaling factors of 1.0. "
|
|
||||||
"This may lead to less accurate results!"
|
|
||||||
)
|
|
||||||
elif quantization_param_path is not None:
|
|
||||||
logger.info(
|
|
||||||
"KV cache scaling factors provided, "
|
|
||||||
"but the KV cache data type is not FP8. "
|
|
||||||
"KV cache scaling factors will not be used."
|
|
||||||
)
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def batch_type(self) -> Type[FlashCausalLMBatch]:
|
def batch_type(self) -> Type[FlashCausalLMBatch]:
|
||||||
return FlashCausalLMBatch
|
return FlashCausalLMBatch
|
||||||
|
@ -906,7 +873,7 @@ class FlashCausalLM(Model):
|
||||||
|
|
||||||
# Inspired by the original implementation in [vllm](https://github.com/vllm-project/vllm)
|
# Inspired by the original implementation in [vllm](https://github.com/vllm-project/vllm)
|
||||||
# Calculate the number of blocks that can be allocated with the free memory
|
# Calculate the number of blocks that can be allocated with the free memory
|
||||||
dtype_size = torch.tensor([], dtype=self.kv_cache_dtype).element_size()
|
dtype_size = torch.tensor([], dtype=self.dtype).element_size()
|
||||||
cache_block_size = BLOCK_SIZE * self.num_kv_heads * self.head_size
|
cache_block_size = BLOCK_SIZE * self.num_kv_heads * self.head_size
|
||||||
total_cache_size = self.num_layers * cache_block_size * 2 * dtype_size
|
total_cache_size = self.num_layers * cache_block_size * 2 * dtype_size
|
||||||
|
|
||||||
|
@ -986,7 +953,7 @@ class FlashCausalLM(Model):
|
||||||
if self.speculate is None or self.speculate + 1 <= bs:
|
if self.speculate is None or self.speculate + 1 <= bs:
|
||||||
self.cuda_graph_warmup(bs, max_s, max_bt)
|
self.cuda_graph_warmup(bs, max_s, max_bt)
|
||||||
except torch.cuda.OutOfMemoryError:
|
except torch.cuda.OutOfMemoryError:
|
||||||
logger.exception("Decode cuda graph warmup failed")
|
logger.exception(f"Decode cuda graph warmup failed")
|
||||||
else:
|
else:
|
||||||
logger.info(f"Cuda Graphs are disabled (CUDA_GRAPHS={CUDA_GRAPHS}).")
|
logger.info(f"Cuda Graphs are disabled (CUDA_GRAPHS={CUDA_GRAPHS}).")
|
||||||
|
|
||||||
|
|
|
@ -1,142 +0,0 @@
|
||||||
from loguru import logger
|
|
||||||
import torch
|
|
||||||
from text_generation_server.utils.import_utils import SYSTEM
|
|
||||||
|
|
||||||
_PARTITION_SIZE = 512
|
|
||||||
|
|
||||||
# Pick the paged-attention backend once at import time: IPEX on XPU,
# the vLLM compiled extension otherwise.
if SYSTEM == "xpu":
    import intel_extension_for_pytorch as ipex
else:
    try:
        from vllm._C import cache_ops
        from vllm._C import ops
    except Exception as e:
        # Chain the original exception so its traceback is preserved.
        raise ImportError(
            f"Could not import vllm paged attention. Make sure your installation is correct. Complete error: {e}"
        ) from e
|
|
||||||
|
|
||||||
|
|
||||||
def reshape_and_cache(
    key: torch.Tensor,
    value: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    slots: torch.Tensor,
    kv_cache_dtype: str = "auto",
    kv_scale: float = 1.0,
):
    """Store ``key``/``value`` into the paged KV cache at ``slots``.

    Dispatches to the IPEX implementation on XPU and to the vLLM
    ``cache_ops`` kernel otherwise.

    Args:
        key: Key tensor to store.
        value: Value tensor to store.
        key_cache: Destination key cache.
        value_cache: Destination value cache.
        slots: Cache slots to store into.
        kv_cache_dtype: KV cache dtype name, forwarded to the vLLM kernel.
        kv_scale: KV cache scaling factor, forwarded to the vLLM kernel.
    """
    if SYSTEM == "xpu":
        # NOTE: the IPEX call takes no dtype/scale arguments, so
        # `kv_cache_dtype` and `kv_scale` are ignored on this path.
        ipex.llm.modules.PagedAttention.reshape_and_cache(
            key, value, key_cache, value_cache, slots
        )
    else:
        cache_ops.reshape_and_cache(
            key, value, key_cache, value_cache, slots, kv_cache_dtype, kv_scale
        )
|
|
||||||
|
|
||||||
|
|
||||||
def attention(
    out: torch.Tensor,
    query: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    kv_head_mapping: torch.Tensor,
    softmax_scale: float,
    block_tables: torch.Tensor,
    input_lengths: torch.Tensor,
    max_s: int,
    kv_cache_dtype: str = "auto",
    kv_scale: float = 1.0,
):
    """Run paged attention for ``query`` against the paged KV cache.

    On XPU the IPEX kernel is called and its result is returned. Otherwise
    vLLM PagedAttention V1 or V2 is launched with ``out`` as the first
    argument and nothing is returned (presumably the kernels write into
    ``out`` in place; confirm against the vLLM ops).

    Args:
        out: Output tensor passed to the kernels.
        query: Query tensor; unpacked as (num_seqs, num_heads, head_size).
        key_cache: Paged key cache.
        value_cache: Paged value cache, laid out as
            [num_blocks, num_heads, head_size, block_size].
        kv_head_mapping: Head mapping tensor forwarded to the kernels.
        softmax_scale: Scale forwarded to the kernels.
        block_tables: Block tables forwarded to the kernels.
        input_lengths: Per-sequence lengths forwarded to the kernels.
        max_s: Maximum sequence length; drives the V1/V2 choice.
        kv_cache_dtype: KV cache dtype name (vLLM kernels only).
        kv_scale: KV cache scaling factor (vLLM kernels only).
    """
    # Adapted from: https://github.com/vllm-project/vllm/blob/f8a1e39fae05ca610be8d5a78be9d40f5274e5fc/vllm/model_executor/layers/attention.py
    # Copyright 2023 The vLLM team. All rights
    # reserved.
    #
    # Licensed under the Apache License, Version 2.0 (the "License");
    # you may not use this file except in compliance with the License.
    # You may obtain a copy of the License at
    #
    #     http://www.apache.org/licenses/LICENSE-2.0
    #
    # Unless required by applicable law or agreed to in writing, software
    # distributed under the License is distributed on an "AS IS" BASIS,
    # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    # See the License for the specific language governing permissions and
    # limitations under the License.
    #

    # value_cache => [num_blocks, num_heads, head_size, block_size]
    block_size = value_cache.shape[3]
    num_seqs, num_heads, head_size = query.shape
    # Ceiling division: number of _PARTITION_SIZE-sized chunks covering max_s.
    max_num_partitions = (max_s + _PARTITION_SIZE - 1) // _PARTITION_SIZE
    if SYSTEM == "xpu":
        # NOTE: `kv_cache_dtype` and `kv_scale` are not passed on this path.
        query = query.contiguous()
        return ipex.llm.modules.PagedAttention.single_query_cached_kv_attention(
            out,
            query,
            key_cache,
            value_cache,
            kv_head_mapping,
            softmax_scale,
            block_tables,
            input_lengths,
            block_size,
            max_s,
            None,
        )

    # NOTE(woosuk): We use a simple heuristic to decide whether to use
    # PagedAttention V1 or V2. If the number of partitions is 1, we use
    # V1 to avoid the overhead of reduction. Also, if the number of
    # sequences or heads is large, we use V1 since there is enough work
    # to parallelize.
    use_v1 = max_s <= 8192 and (max_num_partitions == 1 or num_seqs * num_heads > 512)
    if use_v1:
        ops.paged_attention_v1(
            out,
            query,
            key_cache,
            value_cache,
            kv_head_mapping,
            softmax_scale,
            block_tables,
            input_lengths,
            block_size,
            max_s,
            None,
            kv_cache_dtype,
            kv_scale,
        )
    else:
        # Run PagedAttention V2.
        assert _PARTITION_SIZE % block_size == 0
        # Per-partition scratch buffers handed to the V2 kernel.
        tmp_output = torch.empty(
            size=(num_seqs, num_heads, max_num_partitions, head_size),
            dtype=out.dtype,
            device=out.device,
        )
        exp_sums = torch.empty(
            size=(num_seqs, num_heads, max_num_partitions),
            dtype=torch.float32,
            device=out.device,
        )
        max_logits = torch.empty_like(exp_sums)

        ops.paged_attention_v2(
            out,
            exp_sums,
            max_logits,
            tmp_output,
            query,
            key_cache,
            value_cache,
            kv_head_mapping,
            softmax_scale,
            block_tables,
            input_lengths,
            block_size,
            max_s,
            None,
            kv_cache_dtype,
            kv_scale,
        )
|
|
|
@ -1,48 +0,0 @@
|
||||||
from typing import Optional, Tuple, Iterable
|
|
||||||
from loguru import logger
|
|
||||||
import json
|
|
||||||
from text_generation_server.layers.schema import QuantParamSchema
|
|
||||||
|
|
||||||
|
|
||||||
def kv_cache_scales_loader(
    filename: str,
    tp_rank: int,
    tp_size: int,
    num_hidden_layers: int,
    model_type: Optional[str],
) -> Iterable[Tuple[int, float]]:
    """
    A simple utility to read in KV cache scaling factors that have been
    previously serialized to disk. Used by the model to populate the appropriate
    KV cache scaling factors. The serialization should represent a dictionary
    whose keys are the TP ranks and values are another dictionary mapping layers
    to their KV cache scaling factors.

    Loading is best-effort: any failure (missing file, malformed JSON,
    schema validation error, ...) is logged and an empty list is returned,
    so the caller ends up with no scales for this rank.

    Args:
        filename: Path of the JSON file holding the scaling factors.
        tp_rank: Tensor-parallel rank whose layer scales are returned.
        tp_size: Tensor-parallel world size, used for validation.
        num_hidden_layers: Expected number of layers per rank.
        model_type: Expected model type, or ``None`` to skip that check.

    Returns:
        The ``(layer index, scaling factor)`` pairs for ``tp_rank``, or an
        empty list if loading failed.
    """
    # Values the schema validators read while validating the file.
    context = {
        "model_type": model_type,
        "num_hidden_layers": num_hidden_layers,
        "tp_rank": tp_rank,
        "tp_size": tp_size,
    }
    try:
        # Read as UTF-8 explicitly rather than relying on the
        # platform/locale default encoding.
        with open(filename, encoding="utf-8") as f:
            schema_dct = json.load(f)
        schema = QuantParamSchema.model_validate(schema_dct, context=context)
        layer_scales_map = schema.kv_cache.scaling_factor[tp_rank]
        return layer_scales_map.items()

    except FileNotFoundError:
        logger.error(f"File or directory '{filename}' not found.")
    except json.JSONDecodeError:
        logger.error(f"Error decoding JSON in file '{filename}'.")
    except Exception as e:
        # Deliberately broad: loading scales must never abort model start-up.
        logger.error(f"An error occurred while reading '{filename}': {e}")
    # This section is reached if and only if any of the excepts are hit
    # Return an empty iterable (list) => no KV cache scales are loaded
    # which ultimately defaults to 1.0 scales
    logger.warning(
        "Defaulting to KV cache scaling factors = 1.0 "
        f"for all layers in TP rank {tp_rank} "
        "as an error occurred during loading."
    )
    return []
|
|
Loading…
Reference in New Issue