add skinny kernel and merge fixes
parent 058162685f
commit 59fd0cbdff
@@ -152,9 +152,6 @@ ENV HIP_FORCE_DEV_KERNARG=1
# On MI250 and MI300, performance for flash attention with Triton FA is slightly better than CK.
# However, Triton requires a tuning for each prompt length, which is prohibitive.
ENV ROCM_USE_FLASH_ATTN_V2_TRITON=0
ENV ROCM_USE_CUSTOM_PAGED_ATTN=1
ENV PYTORCH_TUNABLEOP_TUNING_AFTER_WARMUP=0
ENV VLLM_MOE_PADDING=0

FROM base AS kernel-builder
@@ -245,6 +242,13 @@ ENTRYPOINT ["./entrypoint.sh"]
# Final image
FROM base-copy

ENV ROCM_USE_CUSTOM_PAGED_ATTN=1
ENV PYTORCH_TUNABLEOP_TUNING_AFTER_WARMUP=0
ENV VLLM_MOE_PADDING=0
ENV ATTENTION=paged
ENV USE_PREFIX_CACHING=0
ENV ROCM_USE_SKINNY_GEMM=1

COPY ./tgi-entrypoint.sh /tgi-entrypoint.sh
RUN chmod +x /tgi-entrypoint.sh
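These toggles are read by the Python server as plain environment variables. A minimal sketch of the truthy parsing convention this commit relies on (the `env_flag` helper is illustrative; the actual parsing of `ROCM_USE_SKINNY_GEMM` and `USE_PREFIX_CACHING` appears verbatim later in this diff):

```python
import os

def env_flag(name: str, default: str = "0") -> bool:
    # "1" / "true" (case-insensitive) enable the feature; anything else disables it.
    return os.getenv(name, default).lower() in {"1", "true"}

# Defaults here mirror the ENV lines of the final image above.
ROCM_USE_SKINNY_GEMM = env_flag("ROCM_USE_SKINNY_GEMM", "1")
USE_PREFIX_CACHING = env_flag("USE_PREFIX_CACHING", "0")
```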
@@ -45,7 +45,6 @@ def paged_attention(
block_tables: torch.Tensor,
seqlen: Seqlen,
max_s: int,
num_kv_heads: int,
softcap: Optional[float] = None,
):
# Adapted from: https://github.com/vllm-project/vllm/blob/f8a1e39fae05ca610be8d5a78be9d40f5274e5fc/vllm/model_executor/layers/attention.py

@@ -62,7 +62,6 @@ def paged_attention(
block_tables: torch.Tensor,
seqlen: Seqlen,
max_s: int,
num_kv_heads: int,
softcap: Optional[float] = None,
):
out = torch.empty_like(query)

@@ -50,9 +50,8 @@ def paged_attention(
kv_head_mapping: torch.Tensor,
softmax_scale: float,
block_tables: torch.Tensor,
input_lengths: Seqlen,
seqlen: Seqlen,
max_s: int,
num_kv_heads: int,
softcap: Optional[float] = None,
):
# Adapted from: https://github.com/vllm-project/vllm/blob/f8a1e39fae05ca610be8d5a78be9d40f5274e5fc/vllm/model_executor/layers/attention.py

@@ -76,6 +75,7 @@ def paged_attention(
block_size = value_cache.shape[3]
num_seqs, num_heads, head_size = query.shape

num_kv_heads = key_cache.shape[1]
gqa_ratio = num_heads // num_kv_heads
use_custom = (
custom_attn_available

@@ -92,7 +92,7 @@ def paged_attention(
_PARTITION_SIZE = _PARTITION_SIZE_CUSTOM

max_num_partitions = (max_s + _PARTITION_SIZE - 1) // _PARTITION_SIZE
input_lengths = input_lengths.input_lengths
input_lengths = seqlen.input_lengths

out = torch.empty_like(query)
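In the ROCm decode path, the `Seqlen` container is now the only length argument: the earlier hunk drops the duplicated `input_lengths: Seqlen` parameter, so the local unpacking switches from `input_lengths.input_lengths` to `seqlen.input_lengths`. A minimal sketch of that access, assuming only the field this hunk touches:

```python
from dataclasses import dataclass
import torch

@dataclass
class Seqlen:
    # Minimal stand-in; the real class carries more fields (e.g. cu_seqlen_q, seen below).
    input_lengths: torch.Tensor

seqlen = Seqlen(input_lengths=torch.tensor([7, 12, 3], dtype=torch.int32))

# With the `input_lengths: Seqlen` parameter removed from the signature,
# the per-sequence KV lengths for the paged decode kernel come from here:
input_lengths = seqlen.input_lengths
```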
@@ -220,10 +220,10 @@ if ENGINE == "ck":

def attention(
q,
k,
v,
cu_seqlens,
max_s,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
seqlen: Seqlen,
block_tables: torch.Tensor,
softmax_scale,
window_size_left=-1,
causal=True,

@@ -237,17 +237,17 @@ if ENGINE == "ck":
# We do not need to check window_size_left (not supported) here, so it is already checked ahead of time at model load.
return flash_attn_2_cuda.varlen_fwd(
q,
k,
v,
key_cache,
value_cache,
out,
cu_seqlens,
cu_seqlens,
seqlen.cu_seqlen_q,
seqlen.cu_seqlen_q,
None,
None,
None,
None,
max_s,
max_s,
seqlen.max_q,
seqlen.max_k,
0.0,
softmax_scale,
False,

@@ -264,26 +264,27 @@ elif ENGINE == "triton":

def attention(
q,
k,
v,
cu_seqlens,
max_s,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
seqlen: Seqlen,
block_tables: torch.Tensor,
softmax_scale,
window_size_left=-1,
causal=True,
softcap=0.0,
):
out = torch.empty_like(q)

# We do not need to check window_size_left (not supported) here, so it is already checked ahead of time at model load.
output, _ = triton_attention(
q,
k,
v,
key_cache,
value_cache,
out,
cu_seqlens,
cu_seqlens,
max_s,
max_s,
seqlen.cu_seqlen_q,
seqlen.cu_seqlen_q,
seqlen.max_q,
seqlen.max_k,
causal,
softmax_scale,
)
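Both the CK and Triton prefill wrappers now take a single `Seqlen` argument in place of the `cu_seqlens`/`max_s` pair and read `cu_seqlen_q`, `max_q` and `max_k` from it. A minimal sketch of that call-site migration, assuming a container with just the fields referenced above (the helper is illustrative, not part of the diff):

```python
from dataclasses import dataclass
import torch

@dataclass
class Seqlen:
    cu_seqlen_q: torch.Tensor  # cumulative query lengths, shape [batch + 1]
    max_q: int                 # longest query segment in the batch
    max_k: int                 # longest key/value segment in the batch

def prefill_args(seqlen: Seqlen):
    # Before: callers passed (cu_seqlens, cu_seqlens, max_s, max_s).
    # After: the same positions are filled from one Seqlen object, which also
    # lets query and key lengths differ (e.g. when part of the KV comes from a cache).
    return seqlen.cu_seqlen_q, seqlen.cu_seqlen_q, seqlen.max_q, seqlen.max_k

cu = torch.tensor([0, 7, 19], dtype=torch.int32)
print(prefill_args(Seqlen(cu_seqlen_q=cu, max_q=12, max_k=12)))
```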
@@ -1,12 +1,19 @@
import torch
from text_generation_server.utils.import_utils import SYSTEM
from torch.nn import functional as F
import os

if SYSTEM == "rocm":
try:
from vllm import _custom_C
except Exception as e:
raise ImportError(f"Could not load `vllm._custom_C`. Full error: {e}")
ROCM_USE_SKINNY_GEMM = os.getenv("ROCM_USE_SKINNY_GEMM", "True").lower() in (
"true",
"1",
)

if ROCM_USE_SKINNY_GEMM:
try:
from vllm import _custom_C
except Exception as e:
raise ImportError(f"Could not load `vllm._custom_C`. Full error: {e}")


class FastLinear(torch.nn.Module):
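Because `ROCM_USE_SKINNY_GEMM` is evaluated once at import time, it must be set before the module defining `FastLinearROCm` is imported (for example via the Dockerfile ENV above, or the container environment). A minimal sketch of opting out:

```python
import os

# Opt out of the custom skinny-GEMM kernels; this must happen before the
# layer module that reads ROCM_USE_SKINNY_GEMM is imported.
os.environ["ROCM_USE_SKINNY_GEMM"] = "0"
```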
@@ -48,6 +55,14 @@ class FastLinearROCm(torch.nn.Module):
else:
self.bias = None

self.cu_count = torch.cuda.get_device_properties(
device="cuda"
).multi_processor_count
self.use_skinny_gemm = (
ROCM_USE_SKINNY_GEMM
and "gfx1" not in torch.cuda.get_device_properties("cuda").gcnArchName
)

@classmethod
def load(cls, config, prefix: str, weights, bias: bool):
weight = weights.get_tensor(f"{prefix}.weight")

@@ -62,9 +77,9 @@ class FastLinearROCm(torch.nn.Module):
bias = self.bias

if (
SYSTEM == "rocm"
and inp.numel() // inp.shape[-1] == 1
self.use_skinny_gemm
and inp.dtype == torch.float16
and inp.shape[-1] % 8 == 0
):
batched = False
inp_shape = inp.shape

@@ -73,13 +88,16 @@ class FastLinearROCm(torch.nn.Module):
inp = inp.view(-1, inp_shape[-1])
batched = True

m, k = weight.shape[0], inp_shape[1]
out = torch.empty(
inp_shape[0], weight.shape[0], dtype=inp.dtype, device="cuda"
)
if (k == 8192 and (m == 1280 or m == 7168)) or (k == 3584 and m == 8192):
_custom_C.LLMM1(weight, inp, out, 8)
elif k <= 8192 and k % 8 == 0 and m % 4 == 0:
m, n, k = weight.shape[0], inp_shape[0], inp_shape[1]
if m > 8 and n <= 4:
out = torch.empty(
inp_shape[0], weight.shape[0], dtype=inp.dtype, device=weight.device
)
_custom_C.wvSpltK(weight, inp, out, n, self.cu_count)
elif m % 4 == 0 and n == 1 and k <= 8192:
out = torch.empty(
inp_shape[0], weight.shape[0], dtype=inp.dtype, device=weight.device
)
_custom_C.LLMM1(weight, inp, out, 4)
else:
out = F.linear(inp, weight)
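The rewritten forward pass above picks a kernel from the shape of the problem: `m` output features, `n` flattened input rows, `k` input features. A minimal sketch of just that selection, with the thresholds copied from the hunk; it assumes the fp16 / `k % 8 == 0` / `use_skinny_gemm` preconditions checked earlier already passed and omits the actual kernel launches:

```python
def pick_rocm_gemm_path(m: int, n: int, k: int) -> str:
    """Which GEMM path FastLinearROCm.forward (above) would take."""
    if m > 8 and n <= 4:
        return "wvSpltK"    # custom skinny kernel, also given the CU count
    if m % 4 == 0 and n == 1 and k <= 8192:
        return "LLMM1"      # custom single-row GEMM, invoked with the constant 4 above
    return "F.linear"       # default PyTorch / rocBLAS path

# A decode step (one token per sequence) on a 4096x4096 projection hits the skinny path.
assert pick_rocm_gemm_path(m=4096, n=1, k=4096) == "wvSpltK"
```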
@@ -18,6 +18,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from text_generation_server.models.globals import PAGED_KV
import torch
import torch.distributed

@@ -297,8 +298,8 @@ class FlashCohereAttention(torch.nn.Module):
# flash attention
attn_output = attention(
query,
kv_cache[0] if SYSTEM != "ipex" else key,
kv_cache[1] if SYSTEM != "ipex" else value,
kv_cache[0] if PAGED_KV else key,
kv_cache[1] if PAGED_KV else value,
seqlen,
block_tables,
self.softmax_scale,

@@ -314,7 +315,6 @@ class FlashCohereAttention(torch.nn.Module):
block_tables,
seqlen,
max_s,
self.num_key_value_heads,
)

return self.o_proj(

@@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from text_generation_server.models.globals import PAGED_KV
import torch
import torch.distributed

@@ -336,8 +337,8 @@ class DbrxAttention(torch.nn.Module):
# flash attention
attn_output = attention(
query,
kv_cache[0] if SYSTEM != "ipex" else kv[:, 0],
kv_cache[1] if SYSTEM != "ipex" else kv[:, 1],
kv_cache[0] if PAGED_KV else kv[:, 0],
kv_cache[1] if PAGED_KV else kv[:, 1],
seqlen,
block_tables,
self.softmax_scale,

@@ -353,7 +354,6 @@ class DbrxAttention(torch.nn.Module):
block_tables,
seqlen,
max_s,
self.num_key_value_heads,
)

return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))

@@ -15,6 +15,7 @@
from typing import Any, Dict, List, Optional, Tuple

from text_generation_server.models.globals import PAGED_KV
import torch
import torch.distributed
from text_generation_server.layers import (

@@ -363,8 +364,8 @@ class DeepseekV2Attention(torch.nn.Module):
# flash attention
attn_output = attention(
query,
kv_cache[0] if SYSTEM != "ipex" else key,
kv_cache[1] if SYSTEM != "ipex" else value,
kv_cache[0] if PAGED_KV else key,
kv_cache[1] if PAGED_KV else value,
seqlen,
block_tables,
self.softmax_scale,

@@ -380,7 +381,6 @@ class DeepseekV2Attention(torch.nn.Module):
block_tables,
seqlen,
max_s,
self.num_key_value_heads,
)

# Remove padding.
@@ -18,6 +18,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from text_generation_server.models.globals import PAGED_KV
import torch
import torch.distributed

@@ -25,7 +26,6 @@ from torch import nn
from transformers.activations import ACT2FN
from transformers.configuration_utils import PretrainedConfig
from typing import Optional, List, Tuple
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.layers.attention import (
paged_attention,
attention,

@@ -237,8 +237,8 @@ class FlashGemma2Attention(torch.nn.Module):
# flash attention
attn_output = attention(
query,
kv_cache[0] if SYSTEM != "ipex" else kv[:, 0],
kv_cache[1] if SYSTEM != "ipex" else kv[:, 1],
kv_cache[0] if PAGED_KV else kv[:, 0],
kv_cache[1] if PAGED_KV else kv[:, 1],
seqlen,
block_tables,
self.softmax_scale,

@@ -257,7 +257,6 @@ class FlashGemma2Attention(torch.nn.Module):
block_tables,
seqlen,
max_s,
self.num_key_value_heads,
softcap=self.softcap,
)

@@ -18,6 +18,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from text_generation_server.models.globals import PAGED_KV
import torch
import torch.distributed

@@ -25,7 +26,6 @@ from torch import nn
from transformers.activations import ACT2FN
from transformers.configuration_utils import PretrainedConfig
from typing import Optional, List, Tuple
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.layers.attention import (
paged_attention,
attention,

@@ -231,8 +231,8 @@ class FlashGemmaAttention(torch.nn.Module):
# flash attention
attn_output = attention(
query,
kv_cache[0] if SYSTEM != "ipex" else kv[:, 0],
kv_cache[1] if SYSTEM != "ipex" else kv[:, 1],
kv_cache[0] if PAGED_KV else kv[:, 0],
kv_cache[1] if PAGED_KV else kv[:, 1],
seqlen,
block_tables,
self.softmax_scale,

@@ -249,7 +249,6 @@ class FlashGemmaAttention(torch.nn.Module):
block_tables,
seqlen,
max_s,
self.num_key_value_heads,
)

return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))

@@ -18,13 +18,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from text_generation_server.models.globals import PAGED_KV
import torch
import torch.distributed

from torch import nn
from transformers.activations import ACT2FN
from typing import Optional, List, Tuple
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.layers.attention import (
paged_attention,
attention,

@@ -231,8 +231,8 @@ class FlashGPT2Attention(torch.nn.Module):
# flash attention
attn_output = attention(
query,
kv_cache[0] if SYSTEM != "ipex" else key,
kv_cache[1] if SYSTEM != "ipex" else value,
kv_cache[0] if PAGED_KV else key,
kv_cache[1] if PAGED_KV else value,
seqlen,
block_tables,
self.softmax_scale,

@@ -248,7 +248,6 @@ class FlashGPT2Attention(torch.nn.Module):
block_tables,
seqlen,
max_s,
self.num_key_value_heads,
)

return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))
@@ -18,6 +18,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from text_generation_server.models.globals import PAGED_KV
import torch
import torch.distributed

@@ -192,8 +193,8 @@ class FlashGPTJAttention(torch.nn.Module):
# flash attention
attn_output = attention(
query,
kv_cache[0] if SYSTEM != "ipex" else key,
kv_cache[1] if SYSTEM != "ipex" else value,
kv_cache[0] if PAGED_KV else key,
kv_cache[1] if PAGED_KV else value,
seqlen,
block_tables,
self.softmax_scale,

@@ -28,6 +28,7 @@ from torch import nn
from transformers.activations import ACT2FN

from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.models.globals import PAGED_KV
from text_generation_server.layers.attention import (
paged_attention,
attention,

@@ -220,8 +221,8 @@ class FlashLlamaAttention(torch.nn.Module):
# flash attention
attn_output = attention(
query,
kv_cache[0] if SYSTEM != "ipex" else kv[:, 0],
kv_cache[1] if SYSTEM != "ipex" else kv[:, 1],
kv_cache[0] if PAGED_KV else kv[:, 0],
kv_cache[1] if PAGED_KV else kv[:, 1],
seqlen,
block_tables,
self.softmax_scale,

@@ -237,7 +238,6 @@ class FlashLlamaAttention(torch.nn.Module):
block_tables,
seqlen,
max_s,
self.num_key_value_heads,
)

return self.o_proj(

@@ -18,6 +18,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from text_generation_server.models.globals import PAGED_KV
import torch
import torch.distributed

@@ -218,8 +219,8 @@ class MistralAttention(torch.nn.Module):
# flash attention
attn_output = attention(
query,
kv_cache[0] if SYSTEM != "ipex" else kv_to_cache[:, 0],
kv_cache[1] if SYSTEM != "ipex" else kv_to_cache[:, 1],
kv_cache[0] if PAGED_KV else kv_to_cache[:, 0],
kv_cache[1] if PAGED_KV else kv_to_cache[:, 1],
seqlen,
block_tables,
self.softmax_scale,

@@ -236,7 +237,6 @@ class MistralAttention(torch.nn.Module):
block_tables,
seqlen,
max_s,
self.num_key_value_heads,
)

return self.o_proj(

@@ -18,6 +18,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from text_generation_server.models.globals import PAGED_KV
import torch
import torch.distributed

@@ -275,8 +276,8 @@ class MixtralAttention(torch.nn.Module):
# flash attention
attn_output = attention(
query,
kv_cache[0] if SYSTEM != "ipex" else kv_to_cache[:, 0],
kv_cache[1] if SYSTEM != "ipex" else kv_to_cache[:, 1],
kv_cache[0] if PAGED_KV else kv_to_cache[:, 0],
kv_cache[1] if PAGED_KV else kv_to_cache[:, 1],
seqlen,
block_tables,
self.softmax_scale,

@@ -293,7 +294,6 @@ class MixtralAttention(torch.nn.Module):
block_tables,
seqlen,
max_s,
self.num_key_value_heads,
)

return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))
@@ -18,6 +18,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from text_generation_server.models.globals import PAGED_KV
import torch
import torch.distributed

@@ -26,7 +27,6 @@ from transformers.activations import ACT2FN
from transformers.modeling_utils import PreTrainedModel
from transformers.models.gpt_neox import GPTNeoXConfig as TransformersGPTNeoXConfig
from typing import Optional, List, Tuple
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.layers.attention import (
paged_attention,
attention,

@@ -172,8 +172,8 @@ class FlashNeoxAttention(torch.nn.Module):
# flash attention
attn_output = attention(
qkv[:, 0],
kv_cache[0] if SYSTEM != "ipex" else qkv[:, 1],
kv_cache[1] if SYSTEM != "ipex" else qkv[:, 2],
kv_cache[0] if PAGED_KV else qkv[:, 1],
kv_cache[1] if PAGED_KV else qkv[:, 2],
seqlen,
block_tables,
self.softmax_scale,

@@ -189,7 +189,6 @@ class FlashNeoxAttention(torch.nn.Module):
block_tables,
seqlen,
max_s,
self.num_key_value_heads,
)

return self.dense(attn_output.view(-1, self.num_heads * self.head_size))

@@ -1,3 +1,4 @@
from text_generation_server.models.globals import PAGED_KV
import torch
import torch.distributed

@@ -25,7 +26,6 @@ from text_generation_server.layers.layernorm import (
from text_generation_server.layers.rotary import (
PositionRotaryEmbedding,
)
from text_generation_server.utils.import_utils import SYSTEM


class PhiConfig(PretrainedConfig):

@@ -194,8 +194,8 @@ class FlashPhiAttention(torch.nn.Module):
if cu_seqlen_prefill is not None:
attn_output = attention(
query,
kv_cache[0] if SYSTEM != "ipex" else kv[:, 0],
kv_cache[1] if SYSTEM != "ipex" else kv[:, 1],
kv_cache[0] if PAGED_KV else kv[:, 0],
kv_cache[1] if PAGED_KV else kv[:, 1],
seqlen,
block_tables,
self.softmax_scale,

@@ -211,7 +211,6 @@ class FlashPhiAttention(torch.nn.Module):
block_tables,
seqlen,
max_s,
self.num_key_value_heads,
)

return self.dense(attn_output.view(-1, self.num_heads * self.head_size))

@@ -1,3 +1,4 @@
from text_generation_server.models.globals import PAGED_KV
import torch
import torch.distributed

@@ -21,7 +22,6 @@ from text_generation_server.layers.rotary import PositionRotaryEmbedding
from text_generation_server.layers.layernorm import (
FastRMSNorm,
)
from text_generation_server.utils.import_utils import SYSTEM


def load_attention(config, prefix, weights):

@@ -137,8 +137,8 @@ class Qwen2Attention(torch.nn.Module):
# flash attention
attn_output = attention(
query,
kv_cache[0] if SYSTEM != "ipex" else kv_to_cache[:, 0],
kv_cache[1] if SYSTEM != "ipex" else kv_to_cache[:, 1],
kv_cache[0] if PAGED_KV else kv_to_cache[:, 0],
kv_cache[1] if PAGED_KV else kv_to_cache[:, 1],
seqlen,
block_tables,
self.softmax_scale,

@@ -155,7 +155,6 @@ class Qwen2Attention(torch.nn.Module):
block_tables,
seqlen,
max_s,
self.num_key_value_heads,
)

return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))
@@ -1,11 +1,11 @@
from typing import List, Optional, Tuple

from text_generation_server.models.globals import PAGED_KV
import torch
import torch.distributed
from torch import nn
from transformers.configuration_utils import PretrainedConfig
from transformers.modeling_utils import PreTrainedModel
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.layers import (
SpeculativeHead,
TensorParallelColumnLinear,

@@ -207,8 +207,8 @@ class FlashRWAttention(torch.nn.Module):
# flash attention
attn_output = attention(
query,
kv_cache[0] if SYSTEM != "ipex" else kv[:, 0],
kv_cache[1] if SYSTEM != "ipex" else kv[:, 1],
kv_cache[0] if PAGED_KV else kv[:, 0],
kv_cache[1] if PAGED_KV else kv[:, 1],
seqlen,
block_tables,
self.softmax_scale,

@@ -224,7 +224,6 @@ class FlashRWAttention(torch.nn.Module):
block_tables,
seqlen,
max_s,
self.num_key_value_heads,
)

return self.dense(attn_output.view(-1, self.num_heads * self.head_size))

@@ -326,8 +325,8 @@ class FlashRWLargeAttention(torch.nn.Module):
# flash attention
attn_output = attention(
query,
kv_cache[0] if SYSTEM != "ipex" else kv[:, :, 0].contiguous(),
kv_cache[1] if SYSTEM != "ipex" else kv[:, :, 1].contiguous(),
kv_cache[0] if PAGED_KV else kv[:, :, 0].contiguous(),
kv_cache[1] if PAGED_KV else kv[:, :, 1].contiguous(),
seqlen,
block_tables,
self.softmax_scale,

@@ -343,7 +342,6 @@ class FlashRWLargeAttention(torch.nn.Module):
block_tables,
seqlen,
max_s,
self.num_key_value_heads,
)

return self.dense(

@@ -1,3 +1,4 @@
from text_generation_server.models.globals import PAGED_KV
import torch
import torch.distributed

@@ -22,7 +23,6 @@ from text_generation_server.layers.gptq import GPTQWeightsLoader
from text_generation_server.layers.layernorm import (
FastLayerNorm,
)
from text_generation_server.utils.import_utils import SYSTEM


def load_multi_mqa(

@@ -293,8 +293,8 @@ class FlashMQAttention(torch.nn.Module):
# flash attention
attn_output = attention(
query,
kv_cache[0] if SYSTEM != "ipex" else key_value[:, 0],
kv_cache[1] if SYSTEM != "ipex" else key_value[:, 1],
kv_cache[0] if PAGED_KV else key_value[:, 0],
kv_cache[1] if PAGED_KV else key_value[:, 1],
seqlen,
block_tables,
self.softmax_scale,

@@ -310,7 +310,6 @@ class FlashMQAttention(torch.nn.Module):
block_tables,
seqlen,
max_s,
self.num_key_value_heads,
)

return self.c_proj(attn_output.view(-1, self.num_heads * self.head_size))

@@ -18,6 +18,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from text_generation_server.models.globals import PAGED_KV
import torch
import torch.distributed

@@ -47,7 +48,6 @@ from text_generation_server.layers.rotary import (
PositionRotaryEmbedding,
)
from text_generation_server.utils.weights import UnquantizedWeight
from text_generation_server.utils.import_utils import SYSTEM


class Starcoder2Config(PretrainedConfig):

@@ -242,8 +242,8 @@ class Starcoder2Attention(torch.nn.Module):
# flash attention
attn_output = attention(
query,
kv_cache[0] if SYSTEM != "ipex" else kv_to_cache[:, 0],
kv_cache[1] if SYSTEM != "ipex" else kv_to_cache[:, 1],
kv_cache[0] if PAGED_KV else kv_to_cache[:, 0],
kv_cache[1] if PAGED_KV else kv_to_cache[:, 1],
seqlen,
block_tables,
self.softmax_scale,

@@ -260,7 +260,6 @@ class Starcoder2Attention(torch.nn.Module):
block_tables,
seqlen,
max_s,
self.num_key_value_heads,
)

return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))
@@ -1379,6 +1379,7 @@ class FlashCausalLM(Model):
cu_seqlen_prefill = torch.tensor(
[0, seqlen], device=self.device, dtype=torch.int32
)
max_s = seqlen
seqlen = Seqlen(
input_lengths=input_lengths,
prefix_lengths=prefix_lens_tensor,

@@ -1396,7 +1397,7 @@ class FlashCausalLM(Model):
block_tables=None,
seqlen=seqlen,
slots=slots,
max_s=seqlen,
max_s=max_s,
lm_head_indices=None,
prefill_cache_indices=None,
)
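The warmup fix above is about name rebinding: `seqlen` starts out as an int (the warmup sequence length), gets saved into `max_s`, and is then rebound to a `Seqlen` object, so continuing to pass `max_s=seqlen` would hand the model a `Seqlen` where an int is expected. A simplified illustration of that shadowing (stand-in types only):

```python
# Simplified stand-ins for the warmup path above.
seqlen = 256                        # starts as an int: the warmup sequence length
max_s = seqlen                      # keep the int before the name is reused
seqlen = {"input_lengths": [256]}   # stand-in for the Seqlen object it becomes

# forward(..., max_s=seqlen) would now pass the Seqlen stand-in, not the int;
# forward(..., max_s=max_s) passes the intended integer.
assert isinstance(max_s, int)
```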
@@ -4,6 +4,7 @@ from loguru import logger
from typing import Dict, Optional

from text_generation_server.utils.log import log_master
from text_generation_server.utils.import_utils import SYSTEM

PREFIX_CACHING = os.getenv("USE_PREFIX_CACHING").lower() in {"1", "true"}
log_master(logger.info, f"Using prefix caching = {PREFIX_CACHING}")

@@ -52,6 +53,12 @@ CUDA_GRAPHS = cuda_graphs
# index in all cases.
ADAPTER_TO_INDEX: Optional[Dict[str, int]] = None

PAGED_KV: bool
if SYSTEM in {"rocm", "ipex"}:
PAGED_KV = False
else:
PAGED_KV = True


def set_adapter_to_index(adapter_to_index: Dict[str, int]):
global ADAPTER_TO_INDEX
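`PAGED_KV` centralizes the backend check that the model files above previously spelled as `SYSTEM != "ipex"`: when it is False (rocm and ipex), prefill attention reads the freshly computed key/value tensors instead of the paged KV cache. A condensed sketch of what this commit wires together (tensor shapes are illustrative):

```python
import torch

SYSTEM = "rocm"                            # normally set by text_generation_server.utils.import_utils
PAGED_KV = SYSTEM not in {"rocm", "ipex"}  # same decision as the if/else above

def select_kv(kv_cache, key, value):
    # Mirrors `kv_cache[0] if PAGED_KV else key` in the model files above.
    k = kv_cache[0] if PAGED_KV else key
    v = kv_cache[1] if PAGED_KV else value
    return k, v

kv_cache = (torch.empty(0), torch.empty(0))          # paged K and V caches (unused here)
k, v = select_kv(kv_cache, torch.randn(5, 8, 64), torch.randn(5, 8, 64))
```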