Remove vLLM dependency for CUDA (#2751)
* Remove vLLM dependency for CUDA

This change adds `attention-kernels` as a dependency for paged attention and cache reshaping. With that, we don't use vLLM anywhere for CUDA.

Tested run (since we don't have paged attention in CI):

```
❯ ATTENTION=paged python -m pytest integration-tests -k "llama and awq" --release
[...]
5 snapshots passed.
```

* Fix clippy warning
Parent: 6489f85269
Commit: 52e48739a5
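The heart of the change is the CUDA attention path: paged attention and cache reshaping now come from the `attention-kernels` package rather than vLLM's `vllm._C` extension. A minimal sketch of the new import pattern, assuming the new `attention` extra is installed (the helper name below is hypothetical; the entry points and error message mirror the hunks further down):

```python
# Illustrative sketch only; the real change lives in text_generation_server's
# attention layer (see the Python hunks below).
def load_paged_attention_kernels():
    try:
        # Replaces `from vllm._C import ops` / `from vllm._C import cache_ops`.
        import attention_kernels
    except Exception as e:
        raise ImportError(
            f"Could not import attention_kernels. Make sure your installation is correct. Complete error: {e}"
        )
    # Entry points used on CUDA:
    #   attention_kernels.paged_attention_v1 / attention_kernels.paged_attention_v2
    #   attention_kernels.reshape_and_cache
    return attention_kernels
```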
Dockerfile (16 changed lines)
@@ -161,18 +161,6 @@ COPY server/custom_kernels/ .
 # Build specific version of transformers
 RUN python setup.py build
 
-# Build vllm CUDA kernels
-FROM kernel-builder AS vllm-builder
-
-WORKDIR /usr/src
-
-ENV TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6 8.9 9.0+PTX"
-
-COPY server/Makefile-vllm Makefile
-
-# Build specific version of vllm
-RUN make build-vllm-cuda
-
 # Build mamba kernels
 FROM kernel-builder AS mamba-builder
 WORKDIR /usr/src
@@ -230,8 +218,6 @@ COPY --from=awq-kernels-builder /usr/src/llm-awq/awq/kernels/build/lib.linux-x86
 COPY --from=eetq-kernels-builder /usr/src/eetq/build/lib.linux-x86_64-cpython-311 /opt/conda/lib/python3.11/site-packages
 # Copy build artifacts from lorax punica kernels builder
 COPY --from=lorax-punica-builder /usr/src/lorax-punica/server/punica_kernels/build/lib.linux-x86_64-cpython-311 /opt/conda/lib/python3.11/site-packages
-# Copy build artifacts from vllm builder
-COPY --from=vllm-builder /usr/src/vllm/build/lib.linux-x86_64-cpython-311 /opt/conda/lib/python3.11/site-packages
 # Copy build artifacts from mamba builder
 COPY --from=mamba-builder /usr/src/mamba/build/lib.linux-x86_64-cpython-311/ /opt/conda/lib/python3.11/site-packages
 COPY --from=mamba-builder /usr/src/causal-conv1d/build/lib.linux-x86_64-cpython-311/ /opt/conda/lib/python3.11/site-packages
@@ -247,7 +233,7 @@ COPY server/Makefile server/Makefile
 RUN cd server && \
     make gen-server && \
     pip install -r requirements_cuda.txt && \
-    pip install ".[bnb, accelerate, compressed-tensors, marlin, moe, quantize, peft, outlines]" --no-cache-dir && \
+    pip install ".[attention, bnb, accelerate, compressed-tensors, marlin, moe, quantize, peft, outlines]" --no-cache-dir && \
     pip install nvidia-nccl-cu12==2.22.3
 
 ENV LD_PRELOAD=/opt/conda/lib/python3.11/site-packages/nvidia/nccl/lib/libnccl.so.2
@@ -978,16 +978,15 @@
         "nixpkgs": "nixpkgs_6"
       },
       "locked": {
-        "lastModified": 1731601436,
-        "narHash": "sha256-PJmXLyz06XnLG3wB5vRLgeJXoVvpuCx6c70khYv6J1o=",
+        "lastModified": 1731674227,
+        "narHash": "sha256-k/ur37KSc+RXcwwz0tgxeamz6wQ5rsOe5hMepzIdD2s=",
         "owner": "huggingface",
         "repo": "text-generation-inference-nix",
-        "rev": "9510f57282795d6e0dbbd163d2b77a6b5bb52566",
+        "rev": "407b9e22a0b7121bf6e171d67ce0144e3f3e39bf",
         "type": "github"
       },
       "original": {
         "owner": "huggingface",
-        "ref": "nixpkgs-update-20241114",
         "repo": "text-generation-inference-nix",
         "type": "github"
       }
@@ -5,7 +5,7 @@
       inputs.nixpkgs.follows = "tgi-nix/nixpkgs";
     };
     nix-filter.url = "github:numtide/nix-filter";
-    tgi-nix.url = "github:huggingface/text-generation-inference-nix/nixpkgs-update-20241114";
+    tgi-nix.url = "github:huggingface/text-generation-inference-nix";
     nixpkgs.follows = "tgi-nix/nixpkgs";
     flake-utils.url = "github:numtide/flake-utils";
     rust-overlay = {
@@ -3,6 +3,7 @@
   buildPythonPackage,
   poetry-core,
   mypy-protobuf,
+  attention-kernels,
   awq-inference-engine,
   causal-conv1d,
   compressed-tensors,
@@ -27,15 +28,18 @@
   opentelemetry-exporter-otlp,
   opentelemetry-instrumentation-grpc,
   opentelemetry-semantic-conventions,
+  outlines,
   peft,
+  prometheus-client,
   punica-kernels,
+  py-cpuinfo,
+  pydantic,
   safetensors,
   tokenizers,
   torch,
   sentencepiece,
   transformers,
   typer,
-  vllm,
 }:
 
 let
@@ -72,6 +76,7 @@ buildPythonPackage {
   pythonRemoveDeps = [ "scipy" ];
 
   dependencies = [
+    attention-kernels
     awq-inference-engine
     eetq
     causal-conv1d
@@ -95,14 +100,17 @@ buildPythonPackage {
     opentelemetry-exporter-otlp
     opentelemetry-instrumentation-grpc
     opentelemetry-semantic-conventions
+    outlines
     peft
+    prometheus-client
     punica-kernels
+    py-cpuinfo
+    pydantic
     safetensors
     sentencepiece
     tokenizers
     transformers
     typer
-    vllm
   ];
 
   prePatch = ''
@@ -22,6 +22,7 @@ use tracing::warn;
 use utoipa::ToSchema;
 use validation::Validation;
 
+#[allow(clippy::large_enum_variant)]
 #[derive(Clone)]
 pub enum Tokenizer {
     Python {
@@ -29,8 +29,8 @@ install-server: gen-server
 install: install-cuda
 	echo "Installed server"
 
-install-cuda: install-server install-flash-attention-v2-cuda install-vllm-cuda install-flash-attention
-	pip install -e ".[bnb,marlin,moe]"
+install-cuda: install-server install-flash-attention-v2-cuda install-flash-attention
+	pip install -e ".[attention,bnb,marlin,moe]"
 	pip install nvidia-nccl-cu12==2.22.3
 
 install-rocm: install-server install-flash-attention-v2-rocm install-vllm-rocm
@@ -1,14 +1,4 @@
-commit_cuda := d243e9dc7e2c9c2e36a4150ec8e64809cb55c01b
 commit_rocm := 4e0929e6e4fa0a3d09d358715c288020ea9dc247
-build-vllm-cuda:
-	if [ ! -d 'vllm' ]; then \
-		pip install -U ninja packaging --no-cache-dir && \
-		git clone https://github.com/Narsil/vllm.git vllm; \
-	fi
-	cd vllm && git fetch origin && git checkout $(commit_cuda) && python setup.py build
-
-install-vllm-cuda: build-vllm-cuda
-	cd vllm && git fetch origin && git checkout $(commit_cuda) && pip install -e .
 
 build-vllm-rocm:
 	if [ ! -d 'vllm' ]; then \
@@ -200,6 +200,74 @@ files = [
     {file = "async_timeout-4.0.3-py3-none-any.whl", hash = "sha256:7405140ff1230c310e51dc27b3145b9092d659ce68ff733fb0cefe3ee42be028"},
 ]
 
+[[package]]
+name = "attention-kernels"
+version = "0.1.1"
+description = "Attention kernels"
+optional = true
+python-versions = ">=3.8"
+files = [
+    {file = "attention_kernels-0.1.1+cu123torch2.4-cp310-cp310-linux_x86_64.whl", hash = "sha256:812851d4ce0f54ca764ff3815a731b15f0cb110115d0aa2d0997cd7794d808bb"},
+]
+
+[package.dependencies]
+torch = "*"
+
+[package.source]
+type = "url"
+url = "https://github.com/danieldk/attention-kernels/releases/download/v0.1.1/attention_kernels-0.1.1+cu123torch2.4-cp310-cp310-linux_x86_64.whl"
+
+[[package]]
+name = "attention-kernels"
+version = "0.1.1"
+description = "Attention kernels"
+optional = true
+python-versions = ">=3.8"
+files = [
+    {file = "attention_kernels-0.1.1+cu123torch2.4-cp311-cp311-linux_x86_64.whl", hash = "sha256:614c402621b11dd1f5741a016b9fd27cb6a68814471f2048bc05206923516268"},
+]
+
+[package.dependencies]
+torch = "*"
+
+[package.source]
+type = "url"
+url = "https://github.com/danieldk/attention-kernels/releases/download/v0.1.1/attention_kernels-0.1.1+cu123torch2.4-cp311-cp311-linux_x86_64.whl"
+
+[[package]]
+name = "attention-kernels"
+version = "0.1.1"
+description = "Attention kernels"
+optional = true
+python-versions = ">=3.8"
+files = [
+    {file = "attention_kernels-0.1.1+cu123torch2.4-cp312-cp312-linux_x86_64.whl", hash = "sha256:6b2ca7c98997431d5f6c4af7553dce6b1bff8dfdec374c97c6ffba71325a02b7"},
+]
+
+[package.dependencies]
+torch = "*"
+
+[package.source]
+type = "url"
+url = "https://github.com/danieldk/attention-kernels/releases/download/v0.1.1/attention_kernels-0.1.1+cu123torch2.4-cp312-cp312-linux_x86_64.whl"
+
+[[package]]
+name = "attention-kernels"
+version = "0.1.1"
+description = "Attention kernels"
+optional = true
+python-versions = ">=3.8"
+files = [
+    {file = "attention_kernels-0.1.1+cu123torch2.4-cp39-cp39-linux_x86_64.whl", hash = "sha256:a56710c5626e461d6f628ae14b74ffc89833578ebd59c3c0c47f5d6f07461fbf"},
+]
+
+[package.dependencies]
+torch = "*"
+
+[package.source]
+type = "url"
+url = "https://github.com/danieldk/attention-kernels/releases/download/v0.1.1/attention_kernels-0.1.1+cu123torch2.4-cp39-cp39-linux_x86_64.whl"
+
 [[package]]
 name = "attrs"
 version = "24.2.0"
@@ -3985,6 +4053,7 @@ type = ["pytest-mypy"]
 
 [extras]
 accelerate = ["accelerate"]
+attention = ["attention-kernels", "attention-kernels", "attention-kernels", "attention-kernels"]
 bnb = ["bitsandbytes"]
 compressed-tensors = ["compressed-tensors"]
 marlin = ["marlin-kernels", "marlin-kernels", "marlin-kernels", "marlin-kernels"]
@@ -3997,4 +4066,4 @@ torch = ["torch"]
 [metadata]
 lock-version = "2.0"
 python-versions = ">=3.9,<3.13"
-content-hash = "5d1295a8becce2f65dc68d64f200acb5832de50fc0c37392f6f87bbc5b15d32a"
+content-hash = "05add88628d836faceae1a26fde4092651a6eca74555ae38ebff879a7895be7e"
@@ -9,7 +9,7 @@ text-generation-server = 'text_generation_server.cli:app'
 
 [tool.poetry.dependencies]
 python = ">=3.9,<3.13"
-protobuf = "^4.25.3"
+protobuf = ">=4.25.3,<6"
 grpcio = "^1.51.1"
 grpcio-status = "^1.51.1"
 grpcio-reflection = "^1.51.1"
@@ -35,12 +35,18 @@ torch = { version = "^2.4.0", optional = true }
 scipy = "^1.11.1"
 pillow = "^10.0.0"
 outlines= { version = "^0.1.1", optional = true }
-prometheus-client = "^0.20.0"
+prometheus-client = ">=0.20.0,<0.22"
 py-cpuinfo = "^9.0.0"
 compressed-tensors = { version = "^0.7.1", optional = true }
 # Remove later, temporary workaround for outlines.
 numpy = "^1.26"
 
+attention-kernels = [
+    { url = "https://github.com/danieldk/attention-kernels/releases/download/v0.1.1/attention_kernels-0.1.1+cu123torch2.4-cp39-cp39-linux_x86_64.whl", python = "~3.9", optional = true },
+    { url = "https://github.com/danieldk/attention-kernels/releases/download/v0.1.1/attention_kernels-0.1.1+cu123torch2.4-cp310-cp310-linux_x86_64.whl", python = "~3.10", optional = true },
+    { url = "https://github.com/danieldk/attention-kernels/releases/download/v0.1.1/attention_kernels-0.1.1+cu123torch2.4-cp311-cp311-linux_x86_64.whl", python = "~3.11", optional = true },
+    { url = "https://github.com/danieldk/attention-kernels/releases/download/v0.1.1/attention_kernels-0.1.1+cu123torch2.4-cp312-cp312-linux_x86_64.whl", python = "~3.12", optional = true },
+]
 marlin-kernels = [
     { url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.3.1/marlin_kernels-0.3.1+cu123torch2.4-cp39-cp39-linux_x86_64.whl", python = "~3.9", optional = true },
     { url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.3.1/marlin_kernels-0.3.1+cu123torch2.4-cp310-cp310-linux_x86_64.whl", python = "~3.10", optional = true },
@@ -58,6 +64,7 @@ rich = "^13.7.1"
 [tool.poetry.extras]
 torch = ["torch"]
 accelerate = ["accelerate"]
+attention = ["attention-kernels"]
 bnb = ["bitsandbytes"]
 compressed-tensors = ["compressed-tensors"]
 marlin = ["marlin-kernels"]
@@ -108,7 +108,7 @@ def paged_attention(
     if softcap is not None:
         raise RuntimeError("Paged attention doesn't support softcapping")
     input_lengths = seqlen.input_lengths + seqlen.cache_lengths
-    from vllm._C import ops
+    import attention_kernels
 
     out = torch.empty_like(query)
 
@@ -116,7 +116,7 @@ def paged_attention(
         max_num_partitions == 1 or num_seqs * num_heads > 512
     )
     if use_v1:
-        ops.paged_attention_v1(
+        attention_kernels.paged_attention_v1(
             out,
             query,
             kv_cache.key,
@@ -146,7 +146,7 @@ def paged_attention(
         )
         max_logits = torch.empty_like(exp_sums)
 
-        ops.paged_attention_v2(
+        attention_kernels.paged_attention_v2(
             out,
             exp_sums,
             max_logits,
@@ -200,12 +200,12 @@ def paged_reshape_and_cache(
 ):
     if SYSTEM == "cuda":
         try:
-            from vllm._C import cache_ops
+            import attention_kernels
         except Exception as e:
             raise ImportError(
-                f"Could not import vllm paged attention. Make sure your installation is correct. Complete error: {e}"
+                f"Could not import attention_kernels. Make sure your installation is correct. Complete error: {e}"
             )
-        cache_ops.reshape_and_cache(
+        attention_kernels.reshape_and_cache(
             key, value, key_cache, value_cache, slots, "auto", 1.0
         )
     elif SYSTEM == "rocm":
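Since paged attention is not exercised in CI (hence the manual test run in the commit message), a quick way to sanity-check an installation is to confirm that the wheel pulled in by the new `attention` extra exposes the entry points used above. A minimal smoke test (illustrative, not part of this commit):

```python
# Probe the attention-kernels package for the kernels the CUDA path now calls.
import attention_kernels

for name in ("paged_attention_v1", "paged_attention_v2", "reshape_and_cache"):
    assert hasattr(attention_kernels, name), f"attention_kernels is missing {name}"
print("attention-kernels provides all kernels used for paged attention and cache reshaping")
```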
@@ -23,8 +23,10 @@ from typing import Optional, List, Tuple, Any
 from text_generation_server.layers.attention.kv_cache import get_kv_scales
 from text_generation_server.utils.import_utils import SYSTEM
 
-if SYSTEM != "ipex":
+if SYSTEM == "rocm":
     from vllm.model_executor.layers.fused_moe import fused_moe
+elif SYSTEM != "ipex":
+    from moe_kernels.fused_moe import fused_moe
 
 from text_generation_server.layers.attention import (
     paged_attention,