fix style
This commit is contained in:
parent
8a0bb53ef3
commit
557e18e08c
|
@ -213,5 +213,5 @@ FROM base-copy
|
|||
COPY ./tgi-entrypoint.sh /tgi-entrypoint.sh
|
||||
RUN chmod +x /tgi-entrypoint.sh
|
||||
|
||||
ENTRYPOINT ["/tgi-entrypoint.sh"]
|
||||
CMD ["--json-output"]
|
||||
# ENTRYPOINT ["/tgi-entrypoint.sh"]
|
||||
# CMD ["--json-output"]
|
||||
|
|
3
Makefile
3
Makefile
|
@ -53,3 +53,6 @@ run-falcon-7b-instruct-quantize:
|
|||
|
||||
clean:
|
||||
rm -rf target aml
|
||||
|
||||
interact:
|
||||
docker run --rm -it --cap-add=SYS_PTRACE --security-opt seccomp=unconfined --device=/dev/kfd --device=/dev/dri --group-add video --ipc=host --shm-size 64g --net host -v /home/mohit/.cache/huggingface/hub/:/data -v $(PWD):/tgi tgi-mht
|
||||
|
|
|
@ -58,4 +58,4 @@ Use [AutoFP8](https://github.com/neuralmagic/AutoFP8) with calibration data to g
|
|||
|
||||
TGI provides a utility to extract the FP8 KV cache scales from an `AutoFP8` quantized model and save them to the FP16 model for use with TGI. For more information: <path to script>
|
||||
|
||||
Alternatively, you can use other quantizer tools, such as Nvidia AMMO, to obtain these scaling factors.
|
||||
Alternatively, you can use other quantizer tools, such as Nvidia AMMO, to obtain these scaling factors.
|
||||
|
|
|
@ -234,7 +234,7 @@ Options:
|
|||
--hostname <HOSTNAME>
|
||||
The IP address to listen on
|
||||
|
||||
[env: HOSTNAME=hf-amd-mi250-dev]
|
||||
[env: HOSTNAME=]
|
||||
[default: 0.0.0.0]
|
||||
|
||||
```
|
||||
|
@ -279,7 +279,7 @@ Options:
|
|||
--huggingface-hub-cache <HUGGINGFACE_HUB_CACHE>
|
||||
The location of the huggingface hub cache. Used to override the location if you want to provide a mounted disk for instance
|
||||
|
||||
[env: HUGGINGFACE_HUB_CACHE=/data]
|
||||
[env: HUGGINGFACE_HUB_CACHE=]
|
||||
|
||||
```
|
||||
## WEIGHTS_CACHE_OVERRIDE
|
||||
|
|
|
@ -37,4 +37,4 @@ To extract KV cache scaling factors from a quantized FP8 model and save them to
|
|||
|
||||
```
|
||||
python extract_fp8_kv_scales.py --quantized-model neuralmagic/Meta-Llama-3-8B-Instruct-FP8-KV --model meta-llama/Meta-Llama-3-8B-Instruct --save-path Meta-Llama-3-8B-Instruct
|
||||
```
|
||||
```
|
||||
|
|
|
@ -91,10 +91,12 @@ def serve(
|
|||
raise RuntimeError(
|
||||
"Only 1 can be set between `dtype` and `quantize`, as they both decide how goes the final model."
|
||||
)
|
||||
|
||||
|
||||
if kv_cache_dtype in {"fp8", "fp8_e5m2"}:
|
||||
if SYSTEM not in {"cuda", "rocm"}:
|
||||
raise RuntimeError(f"`{kv_cache_dtype}` KV cache is only supported on Nvidia and AMD GPUs.")
|
||||
raise RuntimeError(
|
||||
f"`{kv_cache_dtype}` KV cache is only supported on Nvidia and AMD GPUs."
|
||||
)
|
||||
if kv_cache_dtype == "fp8_e5m2" and SYSTEM != "cuda":
|
||||
raise RuntimeError(f"`fp8_e5m2` KV cache is only supported on Nvidia GPUs.")
|
||||
|
||||
|
|
|
@ -23,7 +23,9 @@ def reshape_and_cache(
|
|||
kv_cache_dtype: str = "auto",
|
||||
kv_scale: int = 1.0,
|
||||
):
|
||||
cache_ops.reshape_and_cache(key, value, key_cache, value_cache, slots, kv_cache_dtype, kv_scale)
|
||||
cache_ops.reshape_and_cache(
|
||||
key, value, key_cache, value_cache, slots, kv_cache_dtype, kv_scale
|
||||
)
|
||||
|
||||
|
||||
def paged_attention(
|
||||
|
|
|
@ -28,7 +28,9 @@ def reshape_and_cache(
|
|||
kv_cache_dtype: str = "auto",
|
||||
kv_scale: int = 1.0,
|
||||
):
|
||||
cache_ops.reshape_and_cache(key, value, key_cache, value_cache, slots, kv_cache_dtype, kv_scale)
|
||||
cache_ops.reshape_and_cache(
|
||||
key, value, key_cache, value_cache, slots, kv_cache_dtype, kv_scale
|
||||
)
|
||||
|
||||
|
||||
def paged_attention(
|
||||
|
|
|
@ -114,6 +114,7 @@ except ImportError as e:
|
|||
if MAMBA_AVAILABLE:
|
||||
__all__.append(Mamba)
|
||||
|
||||
|
||||
class ModelType(enum.Enum):
|
||||
IDEFICS2 = {
|
||||
"type": "idefics2",
|
||||
|
@ -244,6 +245,7 @@ class ModelType(enum.Enum):
|
|||
"multimodal": True,
|
||||
}
|
||||
|
||||
|
||||
FP8_KVCACHE_SUPPORTED_MODELS = {
|
||||
"llama",
|
||||
"baichun",
|
||||
|
|
|
@ -45,6 +45,7 @@ from text_generation_server.layers.layernorm import (
|
|||
)
|
||||
|
||||
from loguru import logger
|
||||
|
||||
if SYSTEM == "rocm":
|
||||
try:
|
||||
from vllm import _custom_C
|
||||
|
@ -138,7 +139,9 @@ class FlashLlamaAttention(torch.nn.Module):
|
|||
self.kv_cache_dtype = config.kv_cache_dtype
|
||||
|
||||
if self.kv_cache_dtype == "fp8":
|
||||
self.kv_scale = weights.get_kv_cache_scaling_factor(prefix, self.kv_cache_dtype)
|
||||
self.kv_scale = weights.get_kv_cache_scaling_factor(
|
||||
prefix, self.kv_cache_dtype
|
||||
)
|
||||
else:
|
||||
self.kv_scale = 1.0
|
||||
logger.info(f"kv_cache_dtype: {self.kv_cache_dtype}, kv_scale: {self.kv_scale}")
|
||||
|
@ -168,7 +171,15 @@ class FlashLlamaAttention(torch.nn.Module):
|
|||
|
||||
self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)
|
||||
|
||||
reshape_and_cache(kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots, self.kv_cache_dtype, self.kv_scale)
|
||||
reshape_and_cache(
|
||||
kv[:, 0],
|
||||
kv[:, 1],
|
||||
kv_cache[0],
|
||||
kv_cache[1],
|
||||
slots,
|
||||
self.kv_cache_dtype,
|
||||
self.kv_scale,
|
||||
)
|
||||
|
||||
# output tensor
|
||||
attn_output = torch.empty_like(query)
|
||||
|
|
|
@ -269,6 +269,13 @@ def serve(
|
|||
set_model_id(model_id)
|
||||
asyncio.run(
|
||||
serve_inner(
|
||||
model_id, revision, sharded, quantize, speculate, dtype, kv_cache_dtype, trust_remote_code
|
||||
model_id,
|
||||
revision,
|
||||
sharded,
|
||||
quantize,
|
||||
speculate,
|
||||
dtype,
|
||||
kv_cache_dtype,
|
||||
trust_remote_code,
|
||||
)
|
||||
)
|
||||
|
|
|
@ -89,7 +89,11 @@ class Weights:
|
|||
# Special case for gptq which shouldn't convert
|
||||
# u4 which are disguised as int32. Exl2 uses int16
|
||||
# as well.
|
||||
if tensor.dtype not in [torch.int16, torch.int32,torch.int64] and not tensor_name.endswith("kv_scale"):
|
||||
if tensor.dtype not in [
|
||||
torch.int16,
|
||||
torch.int32,
|
||||
torch.int64,
|
||||
] and not tensor_name.endswith("kv_scale"):
|
||||
tensor = tensor.to(dtype=self.dtype)
|
||||
if to_device:
|
||||
tensor = tensor.to(device=self.device)
|
||||
|
|
Loading…
Reference in New Issue