add skinny kernel and merge fixes

This commit is contained in:
Mohit Sharma 2024-09-12 13:16:13 +00:00
parent 058162685f
commit 59fd0cbdff
23 changed files with 121 additions and 101 deletions

View File

@ -152,9 +152,6 @@ ENV HIP_FORCE_DEV_KERNARG=1
# On MI250 and MI300, performances for flash with Triton FA are slightly better than CK.
# However, Triton requires a tunning for each prompt length, which is prohibitive.
ENV ROCM_USE_FLASH_ATTN_V2_TRITON=0
ENV ROCM_USE_CUSTOM_PAGED_ATTN=1
ENV PYTORCH_TUNABLEOP_TUNING_AFTER_WARMUP=0
ENV VLLM_MOE_PADDING=0
FROM base AS kernel-builder
@ -245,6 +242,13 @@ ENTRYPOINT ["./entrypoint.sh"]
# Final image
FROM base-copy
ENV ROCM_USE_CUSTOM_PAGED_ATTN=1
ENV PYTORCH_TUNABLEOP_TUNING_AFTER_WARMUP=0
ENV VLLM_MOE_PADDING=0
ENV ATTENTION=paged
ENV USE_PREFIX_CACHING=0
ENV ROCM_USE_SKINNY_GEMM=1
COPY ./tgi-entrypoint.sh /tgi-entrypoint.sh
RUN chmod +x /tgi-entrypoint.sh

View File

@ -45,7 +45,6 @@ def paged_attention(
block_tables: torch.Tensor,
seqlen: Seqlen,
max_s: int,
num_kv_heads: int,
softcap: Optional[float] = None,
):
# Adapted from: https://github.com/vllm-project/vllm/blob/f8a1e39fae05ca610be8d5a78be9d40f5274e5fc/vllm/model_executor/layers/attention.py

View File

@ -62,7 +62,6 @@ def paged_attention(
block_tables: torch.Tensor,
seqlen: Seqlen,
max_s: int,
num_kv_heads: int,
softcap: Optional[float] = None,
):
out = torch.empty_like(query)

View File

@ -50,9 +50,8 @@ def paged_attention(
kv_head_mapping: torch.Tensor,
softmax_scale: float,
block_tables: torch.Tensor,
input_lengths: Seqlen,
seqlen: Seqlen,
max_s: int,
num_kv_heads: int,
softcap: Optional[float] = None,
):
# Adapted from: https://github.com/vllm-project/vllm/blob/f8a1e39fae05ca610be8d5a78be9d40f5274e5fc/vllm/model_executor/layers/attention.py
@ -76,6 +75,7 @@ def paged_attention(
block_size = value_cache.shape[3]
num_seqs, num_heads, head_size = query.shape
num_kv_heads = key_cache.shape[1]
gqa_ratio = num_heads // num_kv_heads
use_custom = (
custom_attn_available
@ -92,7 +92,7 @@ def paged_attention(
_PARTITION_SIZE = _PARTITION_SIZE_CUSTOM
max_num_partitions = (max_s + _PARTITION_SIZE - 1) // _PARTITION_SIZE
input_lengths = input_lengths.input_lengths
input_lengths = seqlen.input_lengths
out = torch.empty_like(query)
@ -220,10 +220,10 @@ if ENGINE == "ck":
def attention(
q,
k,
v,
cu_seqlens,
max_s,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
seqlen: Seqlen,
block_tables: torch.Tensor,
softmax_scale,
window_size_left=-1,
causal=True,
@ -237,17 +237,17 @@ if ENGINE == "ck":
# We do not need to check window_size_left (not supported) here, so it is already checked ahead of time at model load.
return flash_attn_2_cuda.varlen_fwd(
q,
k,
v,
key_cache,
value_cache,
out,
cu_seqlens,
cu_seqlens,
seqlen.cu_seqlen_q,
seqlen.cu_seqlen_q,
None,
None,
None,
None,
max_s,
max_s,
seqlen.max_q,
seqlen.max_k,
0.0,
softmax_scale,
False,
@ -264,26 +264,27 @@ elif ENGINE == "triton":
def attention(
q,
k,
v,
cu_seqlens,
max_s,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
seqlen: Seqlen,
block_tables: torch.Tensor,
softmax_scale,
window_size_left=-1,
causal=True,
softcap=0.0,
):
out = torch.empty_like(q)
# We do not need to check window_size_left (not supported) here, so it is already checked ahead of time at model load.
output, _ = triton_attention(
q,
k,
v,
key_cache,
value_cache,
out,
cu_seqlens,
cu_seqlens,
max_s,
max_s,
seqlen.cu_seqlen_q,
seqlen.cu_seqlen_q,
seqlen.max_q,
seqlen.max_k,
causal,
softmax_scale,
)

View File

@ -1,12 +1,19 @@
import torch
from text_generation_server.utils.import_utils import SYSTEM
from torch.nn import functional as F
import os
if SYSTEM == "rocm":
try:
from vllm import _custom_C
except Exception as e:
raise ImportError(f"Could not load `vllm._custom_C`. Full error: {e}")
ROCM_USE_SKINNY_GEMM = os.getenv("ROCM_USE_SKINNY_GEMM", "True").lower() in (
"true",
"1",
)
if ROCM_USE_SKINNY_GEMM:
try:
from vllm import _custom_C
except Exception as e:
raise ImportError(f"Could not load `vllm._custom_C`. Full error: {e}")
class FastLinear(torch.nn.Module):
@ -48,6 +55,14 @@ class FastLinearROCm(torch.nn.Module):
else:
self.bias = None
self.cu_count = torch.cuda.get_device_properties(
device="cuda"
).multi_processor_count
self.use_skinny_gemm = (
ROCM_USE_SKINNY_GEMM
and "gfx1" not in torch.cuda.get_device_properties("cuda").gcnArchName
)
@classmethod
def load(cls, config, prefix: str, weights, bias: bool):
weight = weights.get_tensor(f"{prefix}.weight")
@ -62,9 +77,9 @@ class FastLinearROCm(torch.nn.Module):
bias = self.bias
if (
SYSTEM == "rocm"
and inp.numel() // inp.shape[-1] == 1
self.use_skinny_gemm
and inp.dtype == torch.float16
and inp.shape[-1] % 8 == 0
):
batched = False
inp_shape = inp.shape
@ -73,13 +88,16 @@ class FastLinearROCm(torch.nn.Module):
inp = inp.view(-1, inp_shape[-1])
batched = True
m, k = weight.shape[0], inp_shape[1]
out = torch.empty(
inp_shape[0], weight.shape[0], dtype=inp.dtype, device="cuda"
)
if (k == 8192 and (m == 1280 or m == 7168)) or (k == 3584 and m == 8192):
_custom_C.LLMM1(weight, inp, out, 8)
elif k <= 8192 and k % 8 == 0 and m % 4 == 0:
m, n, k = weight.shape[0], inp_shape[0], inp_shape[1]
if m > 8 and n <= 4:
out = torch.empty(
inp_shape[0], weight.shape[0], dtype=inp.dtype, device=weight.device
)
_custom_C.wvSpltK(weight, inp, out, n, self.cu_count)
elif m % 4 == 0 and n == 1 and k <= 8192:
out = torch.empty(
inp_shape[0], weight.shape[0], dtype=inp.dtype, device=weight.device
)
_custom_C.LLMM1(weight, inp, out, 4)
else:
out = F.linear(inp, weight)

View File

@ -18,6 +18,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from text_generation_server.models.globals import PAGED_KV
import torch
import torch.distributed
@ -297,8 +298,8 @@ class FlashCohereAttention(torch.nn.Module):
# flash attention
attn_output = attention(
query,
kv_cache[0] if SYSTEM != "ipex" else key,
kv_cache[1] if SYSTEM != "ipex" else value,
kv_cache[0] if PAGED_KV else key,
kv_cache[1] if PAGED_KV else value,
seqlen,
block_tables,
self.softmax_scale,
@ -314,7 +315,6 @@ class FlashCohereAttention(torch.nn.Module):
block_tables,
seqlen,
max_s,
self.num_key_value_heads,
)
return self.o_proj(

View File

@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from text_generation_server.models.globals import PAGED_KV
import torch
import torch.distributed
@ -336,8 +337,8 @@ class DbrxAttention(torch.nn.Module):
# flash attention
attn_output = attention(
query,
kv_cache[0] if SYSTEM != "ipex" else kv[:, 0],
kv_cache[1] if SYSTEM != "ipex" else kv[:, 1],
kv_cache[0] if PAGED_KV else kv[:, 0],
kv_cache[1] if PAGED_KV else kv[:, 1],
seqlen,
block_tables,
self.softmax_scale,
@ -353,7 +354,6 @@ class DbrxAttention(torch.nn.Module):
block_tables,
seqlen,
max_s,
self.num_key_value_heads,
)
return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))

View File

@ -15,6 +15,7 @@
from typing import Any, Dict, List, Optional, Tuple
from text_generation_server.models.globals import PAGED_KV
import torch
import torch.distributed
from text_generation_server.layers import (
@ -363,8 +364,8 @@ class DeepseekV2Attention(torch.nn.Module):
# flash attention
attn_output = attention(
query,
kv_cache[0] if SYSTEM != "ipex" else key,
kv_cache[1] if SYSTEM != "ipex" else value,
kv_cache[0] if PAGED_KV else key,
kv_cache[1] if PAGED_KV else value,
seqlen,
block_tables,
self.softmax_scale,
@ -380,7 +381,6 @@ class DeepseekV2Attention(torch.nn.Module):
block_tables,
seqlen,
max_s,
self.num_key_value_heads,
)
# Remove padding.

View File

@ -18,6 +18,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from text_generation_server.models.globals import PAGED_KV
import torch
import torch.distributed
@ -25,7 +26,6 @@ from torch import nn
from transformers.activations import ACT2FN
from transformers.configuration_utils import PretrainedConfig
from typing import Optional, List, Tuple
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.layers.attention import (
paged_attention,
attention,
@ -237,8 +237,8 @@ class FlashGemma2Attention(torch.nn.Module):
# flash attention
attn_output = attention(
query,
kv_cache[0] if SYSTEM != "ipex" else kv[:, 0],
kv_cache[1] if SYSTEM != "ipex" else kv[:, 1],
kv_cache[0] if PAGED_KV else kv[:, 0],
kv_cache[1] if PAGED_KV else kv[:, 1],
seqlen,
block_tables,
self.softmax_scale,
@ -257,7 +257,6 @@ class FlashGemma2Attention(torch.nn.Module):
block_tables,
seqlen,
max_s,
self.num_key_value_heads,
softcap=self.softcap,
)

View File

@ -18,6 +18,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from text_generation_server.models.globals import PAGED_KV
import torch
import torch.distributed
@ -25,7 +26,6 @@ from torch import nn
from transformers.activations import ACT2FN
from transformers.configuration_utils import PretrainedConfig
from typing import Optional, List, Tuple
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.layers.attention import (
paged_attention,
attention,
@ -231,8 +231,8 @@ class FlashGemmaAttention(torch.nn.Module):
# flash attention
attn_output = attention(
query,
kv_cache[0] if SYSTEM != "ipex" else kv[:, 0],
kv_cache[1] if SYSTEM != "ipex" else kv[:, 1],
kv_cache[0] if PAGED_KV else kv[:, 0],
kv_cache[1] if PAGED_KV else kv[:, 1],
seqlen,
block_tables,
self.softmax_scale,
@ -249,7 +249,6 @@ class FlashGemmaAttention(torch.nn.Module):
block_tables,
seqlen,
max_s,
self.num_key_value_heads,
)
return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))

View File

@ -18,13 +18,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from text_generation_server.models.globals import PAGED_KV
import torch
import torch.distributed
from torch import nn
from transformers.activations import ACT2FN
from typing import Optional, List, Tuple
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.layers.attention import (
paged_attention,
attention,
@ -231,8 +231,8 @@ class FlashGPT2Attention(torch.nn.Module):
# flash attention
attn_output = attention(
query,
kv_cache[0] if SYSTEM != "ipex" else key,
kv_cache[1] if SYSTEM != "ipex" else value,
kv_cache[0] if PAGED_KV else key,
kv_cache[1] if PAGED_KV else value,
seqlen,
block_tables,
self.softmax_scale,
@ -248,7 +248,6 @@ class FlashGPT2Attention(torch.nn.Module):
block_tables,
seqlen,
max_s,
self.num_key_value_heads,
)
return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))

View File

@ -18,6 +18,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from text_generation_server.models.globals import PAGED_KV
import torch
import torch.distributed
@ -192,8 +193,8 @@ class FlashGPTJAttention(torch.nn.Module):
# flash attention
attn_output = attention(
query,
kv_cache[0] if SYSTEM != "ipex" else key,
kv_cache[1] if SYSTEM != "ipex" else value,
kv_cache[0] if PAGED_KV else key,
kv_cache[1] if PAGED_KV else value,
seqlen,
block_tables,
self.softmax_scale,

View File

@ -28,6 +28,7 @@ from torch import nn
from transformers.activations import ACT2FN
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.models.globals import PAGED_KV
from text_generation_server.layers.attention import (
paged_attention,
attention,
@ -220,8 +221,8 @@ class FlashLlamaAttention(torch.nn.Module):
# flash attention
attn_output = attention(
query,
kv_cache[0] if SYSTEM != "ipex" else kv[:, 0],
kv_cache[1] if SYSTEM != "ipex" else kv[:, 1],
kv_cache[0] if PAGED_KV else kv[:, 0],
kv_cache[1] if PAGED_KV else kv[:, 1],
seqlen,
block_tables,
self.softmax_scale,
@ -237,7 +238,6 @@ class FlashLlamaAttention(torch.nn.Module):
block_tables,
seqlen,
max_s,
self.num_key_value_heads,
)
return self.o_proj(

View File

@ -18,6 +18,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from text_generation_server.models.globals import PAGED_KV
import torch
import torch.distributed
@ -218,8 +219,8 @@ class MistralAttention(torch.nn.Module):
# flash attention
attn_output = attention(
query,
kv_cache[0] if SYSTEM != "ipex" else kv_to_cache[:, 0],
kv_cache[1] if SYSTEM != "ipex" else kv_to_cache[:, 1],
kv_cache[0] if PAGED_KV else kv_to_cache[:, 0],
kv_cache[1] if PAGED_KV else kv_to_cache[:, 1],
seqlen,
block_tables,
self.softmax_scale,
@ -236,7 +237,6 @@ class MistralAttention(torch.nn.Module):
block_tables,
seqlen,
max_s,
self.num_key_value_heads,
)
return self.o_proj(

View File

@ -18,6 +18,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from text_generation_server.models.globals import PAGED_KV
import torch
import torch.distributed
@ -275,8 +276,8 @@ class MixtralAttention(torch.nn.Module):
# flash attention
attn_output = attention(
query,
kv_cache[0] if SYSTEM != "ipex" else kv_to_cache[:, 0],
kv_cache[1] if SYSTEM != "ipex" else kv_to_cache[:, 1],
kv_cache[0] if PAGED_KV else kv_to_cache[:, 0],
kv_cache[1] if PAGED_KV else kv_to_cache[:, 1],
seqlen,
block_tables,
self.softmax_scale,
@ -293,7 +294,6 @@ class MixtralAttention(torch.nn.Module):
block_tables,
seqlen,
max_s,
self.num_key_value_heads,
)
return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))

View File

@ -18,6 +18,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from text_generation_server.models.globals import PAGED_KV
import torch
import torch.distributed
@ -26,7 +27,6 @@ from transformers.activations import ACT2FN
from transformers.modeling_utils import PreTrainedModel
from transformers.models.gpt_neox import GPTNeoXConfig as TransformersGPTNeoXConfig
from typing import Optional, List, Tuple
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.layers.attention import (
paged_attention,
attention,
@ -172,8 +172,8 @@ class FlashNeoxAttention(torch.nn.Module):
# flash attention
attn_output = attention(
qkv[:, 0],
kv_cache[0] if SYSTEM != "ipex" else qkv[:, 1],
kv_cache[1] if SYSTEM != "ipex" else qkv[:, 2],
kv_cache[0] if PAGED_KV else qkv[:, 1],
kv_cache[1] if PAGED_KV else qkv[:, 2],
seqlen,
block_tables,
self.softmax_scale,
@ -189,7 +189,6 @@ class FlashNeoxAttention(torch.nn.Module):
block_tables,
seqlen,
max_s,
self.num_key_value_heads,
)
return self.dense(attn_output.view(-1, self.num_heads * self.head_size))

View File

@ -1,3 +1,4 @@
from text_generation_server.models.globals import PAGED_KV
import torch
import torch.distributed
@ -25,7 +26,6 @@ from text_generation_server.layers.layernorm import (
from text_generation_server.layers.rotary import (
PositionRotaryEmbedding,
)
from text_generation_server.utils.import_utils import SYSTEM
class PhiConfig(PretrainedConfig):
@ -194,8 +194,8 @@ class FlashPhiAttention(torch.nn.Module):
if cu_seqlen_prefill is not None:
attn_output = attention(
query,
kv_cache[0] if SYSTEM != "ipex" else kv[:, 0],
kv_cache[1] if SYSTEM != "ipex" else kv[:, 1],
kv_cache[0] if PAGED_KV else kv[:, 0],
kv_cache[1] if PAGED_KV else kv[:, 1],
seqlen,
block_tables,
self.softmax_scale,
@ -211,7 +211,6 @@ class FlashPhiAttention(torch.nn.Module):
block_tables,
seqlen,
max_s,
self.num_key_value_heads,
)
return self.dense(attn_output.view(-1, self.num_heads * self.head_size))

View File

@ -1,3 +1,4 @@
from text_generation_server.models.globals import PAGED_KV
import torch
import torch.distributed
@ -21,7 +22,6 @@ from text_generation_server.layers.rotary import PositionRotaryEmbedding
from text_generation_server.layers.layernorm import (
FastRMSNorm,
)
from text_generation_server.utils.import_utils import SYSTEM
def load_attention(config, prefix, weights):
@ -137,8 +137,8 @@ class Qwen2Attention(torch.nn.Module):
# flash attention
attn_output = attention(
query,
kv_cache[0] if SYSTEM != "ipex" else kv_to_cache[:, 0],
kv_cache[1] if SYSTEM != "ipex" else kv_to_cache[:, 1],
kv_cache[0] if PAGED_KV else kv_to_cache[:, 0],
kv_cache[1] if PAGED_KV else kv_to_cache[:, 1],
seqlen,
block_tables,
self.softmax_scale,
@ -155,7 +155,6 @@ class Qwen2Attention(torch.nn.Module):
block_tables,
seqlen,
max_s,
self.num_key_value_heads,
)
return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))

View File

@ -1,11 +1,11 @@
from typing import List, Optional, Tuple
from text_generation_server.models.globals import PAGED_KV
import torch
import torch.distributed
from torch import nn
from transformers.configuration_utils import PretrainedConfig
from transformers.modeling_utils import PreTrainedModel
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.layers import (
SpeculativeHead,
TensorParallelColumnLinear,
@ -207,8 +207,8 @@ class FlashRWAttention(torch.nn.Module):
# flash attention
attn_output = attention(
query,
kv_cache[0] if SYSTEM != "ipex" else kv[:, 0],
kv_cache[1] if SYSTEM != "ipex" else kv[:, 1],
kv_cache[0] if PAGED_KV else kv[:, 0],
kv_cache[1] if PAGED_KV else kv[:, 1],
seqlen,
block_tables,
self.softmax_scale,
@ -224,7 +224,6 @@ class FlashRWAttention(torch.nn.Module):
block_tables,
seqlen,
max_s,
self.num_key_value_heads,
)
return self.dense(attn_output.view(-1, self.num_heads * self.head_size))
@ -326,8 +325,8 @@ class FlashRWLargeAttention(torch.nn.Module):
# flash attention
attn_output = attention(
query,
kv_cache[0] if SYSTEM != "ipex" else kv[:, :, 0].contiguous(),
kv_cache[1] if SYSTEM != "ipex" else kv[:, :, 1].contiguous(),
kv_cache[0] if PAGED_KV else kv[:, :, 0].contiguous(),
kv_cache[1] if PAGED_KV else kv[:, :, 1].contiguous(),
seqlen,
block_tables,
self.softmax_scale,
@ -343,7 +342,6 @@ class FlashRWLargeAttention(torch.nn.Module):
block_tables,
seqlen,
max_s,
self.num_key_value_heads,
)
return self.dense(

View File

@ -1,3 +1,4 @@
from text_generation_server.models.globals import PAGED_KV
import torch
import torch.distributed
@ -22,7 +23,6 @@ from text_generation_server.layers.gptq import GPTQWeightsLoader
from text_generation_server.layers.layernorm import (
FastLayerNorm,
)
from text_generation_server.utils.import_utils import SYSTEM
def load_multi_mqa(
@ -293,8 +293,8 @@ class FlashMQAttention(torch.nn.Module):
# flash attention
attn_output = attention(
query,
kv_cache[0] if SYSTEM != "ipex" else key_value[:, 0],
kv_cache[1] if SYSTEM != "ipex" else key_value[:, 1],
kv_cache[0] if PAGED_KV else key_value[:, 0],
kv_cache[1] if PAGED_KV else key_value[:, 1],
seqlen,
block_tables,
self.softmax_scale,
@ -310,7 +310,6 @@ class FlashMQAttention(torch.nn.Module):
block_tables,
seqlen,
max_s,
self.num_key_value_heads,
)
return self.c_proj(attn_output.view(-1, self.num_heads * self.head_size))

View File

@ -18,6 +18,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from text_generation_server.models.globals import PAGED_KV
import torch
import torch.distributed
@ -47,7 +48,6 @@ from text_generation_server.layers.rotary import (
PositionRotaryEmbedding,
)
from text_generation_server.utils.weights import UnquantizedWeight
from text_generation_server.utils.import_utils import SYSTEM
class Starcoder2Config(PretrainedConfig):
@ -242,8 +242,8 @@ class Starcoder2Attention(torch.nn.Module):
# flash attention
attn_output = attention(
query,
kv_cache[0] if SYSTEM != "ipex" else kv_to_cache[:, 0],
kv_cache[1] if SYSTEM != "ipex" else kv_to_cache[:, 1],
kv_cache[0] if PAGED_KV else kv_to_cache[:, 0],
kv_cache[1] if PAGED_KV else kv_to_cache[:, 1],
seqlen,
block_tables,
self.softmax_scale,
@ -260,7 +260,6 @@ class Starcoder2Attention(torch.nn.Module):
block_tables,
seqlen,
max_s,
self.num_key_value_heads,
)
return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))

View File

@ -1379,6 +1379,7 @@ class FlashCausalLM(Model):
cu_seqlen_prefill = torch.tensor(
[0, seqlen], device=self.device, dtype=torch.int32
)
max_s = seqlen
seqlen = Seqlen(
input_lengths=input_lengths,
prefix_lengths=prefix_lens_tensor,
@ -1396,7 +1397,7 @@ class FlashCausalLM(Model):
block_tables=None,
seqlen=seqlen,
slots=slots,
max_s=seqlen,
max_s=max_s,
lm_head_indices=None,
prefill_cache_indices=None,
)

View File

@ -4,6 +4,7 @@ from loguru import logger
from typing import Dict, Optional
from text_generation_server.utils.log import log_master
from text_generation_server.utils.import_utils import SYSTEM
PREFIX_CACHING = os.getenv("USE_PREFIX_CACHING").lower() in {"1", "true"}
log_master(logger.info, f"Using prefix caching = {PREFIX_CACHING}")
@ -52,6 +53,12 @@ CUDA_GRAPHS = cuda_graphs
# index in all cases.
ADAPTER_TO_INDEX: Optional[Dict[str, int]] = None
PAGED_KV: bool
if SYSTEM in {"rocm", "ipex"}:
PAGED_KV = False
else:
PAGED_KV = True
def set_adapter_to_index(adapter_to_index: Dict[str, int]):
global ADAPTER_TO_INDEX