This commit is contained in:
Mohit Sharma 2024-06-24 08:25:59 +00:00
parent 81fd601c44
commit fb83e3416b
6 changed files with 9 additions and 322 deletions

View File

@ -92,11 +92,11 @@ def serve(
"Only 1 can be set between `dtype` and `quantize`, as they both decide how goes the final model." "Only 1 can be set between `dtype` and `quantize`, as they both decide how goes the final model."
) )
if kv_cache_dtype in {"fp8", "fp8_e5m2"} and SYSTEM not in {"cuda", "rocm"}: if kv_cache_dtype in {"fp8", "fp8_e5m2"}:
raise RuntimeError(f"{kv_cache_dtype} KV cache is only supported on Nvidia and AMD GPUs.") if SYSTEM not in {"cuda", "rocm"}:
raise RuntimeError(f"`{kv_cache_dtype}` KV cache is only supported on Nvidia and AMD GPUs.")
if kv_cache_dtype == "fp8_e5m2" and SYSTEM != "cuda": if kv_cache_dtype == "fp8_e5m2" and SYSTEM != "cuda":
raise RuntimeError(f"fp8_e5m2 KV cache is only supported on Nvidia GPUs.") raise RuntimeError(f"`fp8_e5m2` KV cache is only supported on Nvidia GPUs.")
server.serve( server.serve(
model_id, model_id,

View File

@ -1,90 +0,0 @@
"""
This file contains the Pydantic schemas for various quantization-related
parameters. When a relevant quantization technique is specified, these
parameters are loaded in the form of a JSON alongside the model weights
and augment the model with additional information needed for use of that
technique. The format of this JSON should be specified by one or more
schemas contained here.
For example, when the KV cache is quantized to FP8-E4M3 (currently only
possible on ROCm), the model can be optionally augmented with KV cache
scaling factors.
"""
from typing import Dict, Optional
from pydantic import BaseModel, ConfigDict, ValidationInfo, model_validator
class KVCacheQuantSchema(BaseModel):
dtype: str
# Each key is a TP rank. Each value is a dictionary mapping a TP rank's
# layer indices to their per-tensor KV cache scaling factor.
# TODO: Consider pulling this and its validation methods out into its
# own schema class (tricky as its members are variable)
scaling_factor: Dict[int, Dict[int, float]]
@model_validator(mode="after")
def check_is_fp8(self) -> "KVCacheQuantSchema":
assert self.dtype == "float8_e4m3fn", (
"Loaded scaling factors intended for KV cache dtype = "
f"{self.dtype} rather than float8_e4m3fn!"
)
return self
@model_validator(mode="after")
def check_tp_ranks(self, info: ValidationInfo) -> "KVCacheQuantSchema":
context = info.context
if context:
tp_size = context["tp_size"]
num_hidden_layers = context["num_hidden_layers"]
assert len(self.scaling_factor) == tp_size, (
f"Loaded dictionary has TP size {len(self.scaling_factor)} "
f"but LLM engine is currently running with TP size {tp_size}."
)
for tp_rank, layer_maps in self.scaling_factor.items():
assert len(layer_maps) == num_hidden_layers, (
f"KV cache scales map for TP rank {tp_rank} is malformed. "
f"Expected {num_hidden_layers} layers, got "
f"{len(layer_maps)}."
)
for i in range(tp_size):
assert (
i in self.scaling_factor
), f"KV cache scales map for TP rank {i} not found."
return self
@model_validator(mode="after")
def check_current_rank(self, info: ValidationInfo) -> "KVCacheQuantSchema":
context = info.context
if context:
tp_rank = context["tp_rank"]
num_hidden_layers = context["num_hidden_layers"]
layer_scales_map = self.scaling_factor[tp_rank]
for i in range(num_hidden_layers):
assert i in layer_scales_map, (
f"Could not find KV cache scales for layer {i} in "
f"TP rank {tp_rank}."
)
return self
class QuantParamSchema(BaseModel):
# TODO: Generalize and extend with more fields
# (e.g. weights/activations params) once functionality is enabled
model_config = ConfigDict(protected_namespaces=())
model_type: Optional[str]
kv_cache: KVCacheQuantSchema
@model_validator(mode="after")
def check_model_type(self, info: ValidationInfo) -> "QuantParamSchema":
context = info.context
if context:
model_type = context.get("model_type", None)
if model_type is not None:
assert model_type == self.model_type, (
f"Model type is {model_type} but loaded "
f"scaling factors belonging to different "
f"model type {self.model_type}!"
)
return self

View File

@ -294,7 +294,7 @@ def get_model(
if model_type not in FP8_KVCACHE_SUPPORTED_MODELS and kv_cache_dtype != "auto": if model_type not in FP8_KVCACHE_SUPPORTED_MODELS and kv_cache_dtype != "auto":
raise RuntimeError( raise RuntimeError(
f"kv_cache_dtype is only supported for Llama models. Got model_type: {model_type}, kv_cache_dtype: {kv_cache_dtype}" f"kv_cache_dtype is only supported for {", ".join(FP8_KVCACHE_SUPPORTED_MODELS)} models. Got model_type: {model_type}, kv_cache_dtype: {kv_cache_dtype}"
) )
speculator = None speculator = None

View File

@ -726,8 +726,6 @@ class FlashCausalLM(Model):
head_size: int, head_size: int,
dtype: torch.dtype, dtype: torch.dtype,
device: torch.device, device: torch.device,
kv_cache_dtype: str = "auto",
quantization_param_path: Optional[str] = None,
rank: int = 0, rank: int = 0,
world_size: int = 1, world_size: int = 1,
sliding_window: Optional[int] = None, sliding_window: Optional[int] = None,
@ -750,37 +748,6 @@ class FlashCausalLM(Model):
sliding_window=sliding_window, sliding_window=sliding_window,
) )
if kv_cache_dtype == "fp8":
self.kv_cache_dtype = torch.uint8
else:
self.kv_cache_dtype = self.dtype
if kv_cache_dtype == "fp8" and SYSTEM == "rocm":
logger.info(f"Using KV cache data type: {kv_cache_dtype}")
# Currently scaled KV cache is only enabled on ROCm
if quantization_param_path is not None:
if callable(getattr(self.model, "load_kv_cache_scales", None)):
self.model.load_kv_cache_scales(quantization_param_path)
else:
raise RuntimeError(
"Using FP8 KV cache and scaling "
"factors provided but model "
f"{self.model.__class__} does not "
"support loading scaling factors."
)
else:
logger.info(
"Using FP8 KV cache but no scaling factors "
"provided. Defaulting to scaling factors of 1.0. "
"This may lead to less accurate results!"
)
elif quantization_param_path is not None:
logger.info(
"KV cache scaling factors provided, "
"but the KV cache data type is not FP8. "
"KV cache scaling factors will not be used."
)
@property @property
def batch_type(self) -> Type[FlashCausalLMBatch]: def batch_type(self) -> Type[FlashCausalLMBatch]:
return FlashCausalLMBatch return FlashCausalLMBatch
@ -906,7 +873,7 @@ class FlashCausalLM(Model):
# Inspired by the original implementation in [vllm](https://github.com/vllm-project/vllm) # Inspired by the original implementation in [vllm](https://github.com/vllm-project/vllm)
# Calculate the number of blocks that can be allocated with the free memory # Calculate the number of blocks that can be allocated with the free memory
dtype_size = torch.tensor([], dtype=self.kv_cache_dtype).element_size() dtype_size = torch.tensor([], dtype=self.dtype).element_size()
cache_block_size = BLOCK_SIZE * self.num_kv_heads * self.head_size cache_block_size = BLOCK_SIZE * self.num_kv_heads * self.head_size
total_cache_size = self.num_layers * cache_block_size * 2 * dtype_size total_cache_size = self.num_layers * cache_block_size * 2 * dtype_size
@ -986,7 +953,7 @@ class FlashCausalLM(Model):
if self.speculate is None or self.speculate + 1 <= bs: if self.speculate is None or self.speculate + 1 <= bs:
self.cuda_graph_warmup(bs, max_s, max_bt) self.cuda_graph_warmup(bs, max_s, max_bt)
except torch.cuda.OutOfMemoryError: except torch.cuda.OutOfMemoryError:
logger.exception("Decode cuda graph warmup failed") logger.exception(f"Decode cuda graph warmup failed")
else: else:
logger.info(f"Cuda Graphs are disabled (CUDA_GRAPHS={CUDA_GRAPHS}).") logger.info(f"Cuda Graphs are disabled (CUDA_GRAPHS={CUDA_GRAPHS}).")

View File

@ -1,142 +0,0 @@
from loguru import logger
import torch
from text_generation_server.utils.import_utils import SYSTEM
_PARTITION_SIZE = 512
if SYSTEM == "xpu":
import intel_extension_for_pytorch as ipex
else:
try:
from vllm._C import cache_ops
from vllm._C import ops
except Exception as e:
raise ImportError(
f"Could not import vllm paged attention. Make sure your installation is correct. Complete error: {e}"
)
def reshape_and_cache(
key: torch.Tensor,
value: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
slots: torch.Tensor,
kv_cache_dtype: str = "auto",
kv_scale: int = 1.0,
):
if SYSTEM == "xpu":
ipex.llm.modules.PagedAttention.reshape_and_cache(
key, value, key_cache, value_cache, slots
)
else:
cache_ops.reshape_and_cache(
key, value, key_cache, value_cache, slots, kv_cache_dtype, kv_scale
)
def attention(
out: torch.Tensor,
query: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
kv_head_mapping: torch.Tensor,
softmax_scale: float,
block_tables: torch.Tensor,
input_lengths: torch.Tensor,
max_s: int,
kv_cache_dtype: str = "auto",
kv_scale: int = 1.0,
):
# Adapted from: https://github.com/vllm-project/vllm/blob/f8a1e39fae05ca610be8d5a78be9d40f5274e5fc/vllm/model_executor/layers/attention.py
# Copyright 2023 The vLLM team. All rights
# reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# value_cache => [num_blocks, num_heads, head_size, block_size]
block_size = value_cache.shape[3]
num_seqs, num_heads, head_size = query.shape
max_num_partitions = (max_s + _PARTITION_SIZE - 1) // _PARTITION_SIZE
if SYSTEM == "xpu":
query = query.contiguous()
return ipex.llm.modules.PagedAttention.single_query_cached_kv_attention(
out,
query,
key_cache,
value_cache,
kv_head_mapping,
softmax_scale,
block_tables,
input_lengths,
block_size,
max_s,
None,
)
# NOTE(woosuk): We use a simple heuristic to decide whether to use
# PagedAttention V1 or V2. If the number of partitions is 1, we use
# V1 to avoid the overhead of reduction. Also, if the number of
# sequences or heads is large, we use V1 since there is enough work
# to parallelize.
use_v1 = max_s <= 8192 and (max_num_partitions == 1 or num_seqs * num_heads > 512)
if use_v1:
ops.paged_attention_v1(
out,
query,
key_cache,
value_cache,
kv_head_mapping,
softmax_scale,
block_tables,
input_lengths,
block_size,
max_s,
None,
kv_cache_dtype,
kv_scale,
)
else:
# Run PagedAttention V2.
assert _PARTITION_SIZE % block_size == 0
tmp_output = torch.empty(
size=(num_seqs, num_heads, max_num_partitions, head_size),
dtype=out.dtype,
device=out.device,
)
exp_sums = torch.empty(
size=(num_seqs, num_heads, max_num_partitions),
dtype=torch.float32,
device=out.device,
)
max_logits = torch.empty_like(exp_sums)
ops.paged_attention_v2(
out,
exp_sums,
max_logits,
tmp_output,
query,
key_cache,
value_cache,
kv_head_mapping,
softmax_scale,
block_tables,
input_lengths,
block_size,
max_s,
None,
kv_cache_dtype,
kv_scale,
)

View File

@ -1,48 +0,0 @@
from typing import Optional, Tuple, Iterable
from loguru import logger
import json
from text_generation_server.layers.schema import QuantParamSchema
def kv_cache_scales_loader(
filename: str,
tp_rank: int,
tp_size: int,
num_hidden_layers: int,
model_type: Optional[str],
) -> Iterable[Tuple[int, float]]:
"""
A simple utility to read in KV cache scaling factors that have been
previously serialized to disk. Used by the model to populate the appropriate
KV cache scaling factors. The serialization should represent a dictionary
whose keys are the TP ranks and values are another dictionary mapping layers
to their KV cache scaling factors.
"""
try:
with open(filename) as f:
context = {
"model_type": model_type,
"num_hidden_layers": num_hidden_layers,
"tp_rank": tp_rank,
"tp_size": tp_size,
}
schema_dct = json.load(f)
schema = QuantParamSchema.model_validate(schema_dct, context=context)
layer_scales_map = schema.kv_cache.scaling_factor[tp_rank]
return layer_scales_map.items()
except FileNotFoundError:
logger.error(f"File or directory '{filename}' not found.")
except json.JSONDecodeError:
logger.error(f"Error decoding JSON in file '{filename}'.")
except Exception as e:
logger.error(f"An error occurred while reading '{filename}': {e}")
# This section is reached if and only if any of the excepts are hit
# Return an empty iterable (list) => no KV cache scales are loaded
# which ultimately defaults to 1.0 scales
logger.warning(
"Defaulting to KV cache scaling factors = 1.0 "
f"for all layers in TP rank {tp_rank} "
"as an error occurred during loading."
)
return []