Switch from fbgemm-gpu w8a8 scaled matmul to vLLM/marlin-kernels (#2688)
* Switch from fbgemm-gpu w8a8 scaled matmul to vLLM/marlin-kernels

  Performance and accuracy of these kernels are on par (tested with Llama
  70B and 405B). Removes a dependency and resolves some stability issues
  we have been seeing.

* Update test snapshots
This commit is contained in:
parent ba5fc7d922
commit 0f346a3296
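For context on what is being swapped: both fbgemm-gpu's f8f8bf16_rowwise and marlin-kernels' cutlass_scaled_mm implement a w8a8 scaled matmul, i.e. activations and weights are both 8-bit FP8, multiplied with higher-precision accumulation, and rescaled back to the activation dtype. A rough pure-PyTorch sketch of the math (an illustration only; the helper names are ours, and the real kernels fuse all of this into a single pass):

    import torch

    def fp8_quantize_per_tensor(x):
        # Dynamic FP8 (e4m3) quantization with a single tensor-wide scale.
        finfo = torch.finfo(torch.float8_e4m3fn)
        scale = x.abs().amax().clamp(min=1e-12) / finfo.max
        qx = (x / scale).clamp(finfo.min, finfo.max).to(torch.float8_e4m3fn)
        return qx, scale.float()

    def w8a8_scaled_mm_reference(x, qweight, weight_scale, bias=None):
        # qweight: FP8 weight of shape (N, K); weight_scale: its scale(s).
        qx, x_scale = fp8_quantize_per_tensor(x)
        # The FP8 operands are multiplied with higher-precision accumulation;
        # mathematically the result is the dequantized product.
        out = (qx.float() @ qweight.float().t()) * (x_scale * weight_scale)
        if bias is not None:
            out = out + bias
        return out.to(x.dtype)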
Dockerfile (11 changed lines)

@@ -161,15 +161,6 @@ COPY server/custom_kernels/ .
 # Build specific version of transformers
 RUN python setup.py build
 
-# Build FBGEMM CUDA kernels
-FROM kernel-builder AS fbgemm-builder
-
-WORKDIR /usr/src
-
-COPY server/Makefile-fbgemm Makefile
-
-RUN make build-fbgemm
-
 # Build vllm CUDA kernels
 FROM kernel-builder AS vllm-builder
 
@@ -239,8 +230,6 @@ COPY --from=awq-kernels-builder /usr/src/llm-awq/awq/kernels/build/lib.linux-x86
 COPY --from=eetq-kernels-builder /usr/src/eetq/build/lib.linux-x86_64-cpython-311 /opt/conda/lib/python3.11/site-packages
 # Copy build artifacts from lorax punica kernels builder
 COPY --from=lorax-punica-builder /usr/src/lorax-punica/server/punica_kernels/build/lib.linux-x86_64-cpython-311 /opt/conda/lib/python3.11/site-packages
-# Copy build artifacts from fbgemm builder
-COPY --from=fbgemm-builder /usr/src/fbgemm/fbgemm_gpu/_skbuild/linux-x86_64-3.11/cmake-install /opt/conda/lib/python3.11/site-packages
 # Copy build artifacts from vllm builder
 COPY --from=vllm-builder /usr/src/vllm/build/lib.linux-x86_64-cpython-311 /opt/conda/lib/python3.11/site-packages
 # Copy build artifacts from mamba builder
flake.lock

@@ -978,16 +978,16 @@
       "nixpkgs": "nixpkgs_6"
     },
     "locked": {
-      "lastModified": 1729531056,
-      "narHash": "sha256-dW9IOA31+j3VS19WAWAmkJW2YCzeVZGqd6HpIJfODtI=",
+      "lastModified": 1729761651,
+      "narHash": "sha256-GYykQ9Fxji2EuXCGcPn0dx8Qx8VQBJTkRdcCytp4A/k=",
       "owner": "huggingface",
       "repo": "text-generation-inference-nix",
-      "rev": "a84a90281a17b15762873845c947e5c78f5a8dd1",
+      "rev": "f7e3c4fa67d70590ed9ee47feeab645bd9ba81b1",
       "type": "github"
     },
     "original": {
       "owner": "huggingface",
-      "ref": "marlin-kernels-0.3.0",
+      "ref": "marlin-kernels-0.3.1",
       "repo": "text-generation-inference-nix",
       "type": "github"
     }
flake.nix

@@ -5,7 +5,7 @@
      inputs.nixpkgs.follows = "tgi-nix/nixpkgs";
    };
    nix-filter.url = "github:numtide/nix-filter";
-   tgi-nix.url = "github:huggingface/text-generation-inference-nix/marlin-kernels-0.3.0";
+   tgi-nix.url = "github:huggingface/text-generation-inference-nix/marlin-kernels-0.3.1";
    nixpkgs.follows = "tgi-nix/nixpkgs";
    flake-utils.url = "github:numtide/flake-utils";
    rust-overlay = {
integration-tests/models/__snapshots__/… ("Test request" FP8 snapshot)

@@ -1,8 +1,8 @@
 {
   "details": {
     "best_of_sequences": null,
-    "finish_reason": "stop_sequence",
-    "generated_tokens": 5,
+    "finish_reason": "length",
+    "generated_tokens": 10,
     "prefill": [
       {
         "id": 128000,
@@ -11,12 +11,12 @@
       },
       {
         "id": 2323,
-        "logprob": -9.5625,
+        "logprob": -9.5234375,
         "text": "Test"
       },
       {
         "id": 1715,
-        "logprob": -10.4375,
+        "logprob": -10.421875,
         "text": " request"
       }
     ],
@@ -24,36 +24,66 @@
     "tokens": [
       {
         "id": 25,
-        "logprob": -0.8984375,
+        "logprob": -0.88183594,
         "special": false,
         "text": ":"
       },
       {
-        "id": 923,
-        "logprob": -2.84375,
+        "id": 2209,
+        "logprob": -2.6699219,
         "special": false,
-        "text": " add"
+        "text": " Is"
       },
       {
-        "id": 264,
-        "logprob": 0.0,
+        "id": 279,
+        "logprob": -0.61083984,
         "special": false,
-        "text": " a"
+        "text": " the"
+      },
+      {
+        "id": 734,
+        "logprob": -2.6660156,
+        "special": false,
+        "text": " function"
       },
       {
         "id": 330,
-        "logprob": -0.31640625,
+        "logprob": -0.35498047,
         "special": false,
         "text": " \""
       },
       {
-        "id": 1985,
-        "logprob": 0.0,
+        "id": 4110,
+        "logprob": -2.4101562,
         "special": false,
-        "text": "test"
+        "text": "Create"
+      },
+      {
+        "id": 7575,
+        "logprob": -2.2304688,
+        "special": false,
+        "text": "Process"
+      },
+      {
+        "id": 1,
+        "logprob": -0.080078125,
+        "special": false,
+        "text": "\""
+      },
+      {
+        "id": 304,
+        "logprob": -0.75439453,
+        "special": false,
+        "text": " in"
+      },
+      {
+        "id": 12468,
+        "logprob": -1.8769531,
+        "special": false,
+        "text": " Win"
       }
     ],
     "top_tokens": null
   },
-  "generated_text": "Test request: add a \"test"
+  "generated_text": "Test request: Is the function \"CreateProcess\" in Win"
 }
integration-tests/models/__snapshots__/… ("What is deep learning" FP8 snapshot)

@@ -16,17 +16,17 @@
       },
       {
         "id": 5655,
-        "logprob": -11.75,
+        "logprob": -11.8359375,
         "text": " deep"
       },
       {
         "id": 6975,
-        "logprob": -2.0625,
+        "logprob": -2.0703125,
         "text": " learning"
       },
       {
         "id": 30,
-        "logprob": -6.0,
+        "logprob": -5.9765625,
         "text": "?"
       }
     ],
@@ -40,25 +40,25 @@
       },
       {
         "id": 34564,
-        "logprob": -0.11279297,
+        "logprob": -0.12512207,
         "special": false,
         "text": "Deep"
       },
       {
         "id": 6975,
-        "logprob": -0.16015625,
+        "logprob": 0.0,
         "special": false,
         "text": " learning"
       },
       {
         "id": 320,
-        "logprob": -0.25195312,
+        "logprob": -0.23840332,
         "special": false,
         "text": " ("
       },
       {
         "id": 16931,
-        "logprob": -1.703125,
+        "logprob": -2.0175781,
         "special": false,
         "text": "DL"
       },
@@ -70,7 +70,7 @@
       },
       {
         "id": 374,
-        "logprob": -1.140625,
+        "logprob": -0.8613281,
         "special": false,
         "text": " is"
       },
@@ -82,7 +82,7 @@
       },
       {
         "id": 1207,
-        "logprob": -1.3125,
+        "logprob": -1.2451172,
         "special": false,
         "text": " sub"
       },
nix/server.nix

@@ -8,7 +8,6 @@
   eetq,
   einops,
   exllamav2,
-  fbgemm-gpu,
   flashinfer,
   flash-attn,
   flash-attn-layer-norm,
@@ -77,7 +76,6 @@ buildPythonPackage {
       causal-conv1d
       einops
       exllamav2
-      fbgemm-gpu
       flashinfer
       flash-attn
       flash-attn-layer-norm
server/Makefile

@@ -5,7 +5,6 @@ include Makefile-awq
 include Makefile-eetq
 include Makefile-selective-scan
 include Makefile-lorax-punica
-include Makefile-fbgemm
 include Makefile-exllamav2
 include Makefile-flashinfer
 
@@ -30,7 +29,7 @@ install-server: gen-server
 install: install-cuda
 	echo "Installed server"
 
-install-cuda: install-server install-flash-attention-v2-cuda install-vllm-cuda install-flash-attention install-fbgemm
+install-cuda: install-server install-flash-attention-v2-cuda install-vllm-cuda install-flash-attention
 	pip install -e ".[bnb,marlin,moe]"
 	pip install nvidia-nccl-cu12==2.22.3
 
server/Makefile-fbgemm (deleted)

@@ -1,15 +0,0 @@
-fbgemm_commit := v0.8.0
-
-build-fbgemm:
-	@if [ ! -d "fbgemm" ]; then \
-		git clone https://github.com/pytorch/FBGEMM.git fbgemm; \
-	fi
-	cd fbgemm && git fetch && git checkout $(fbgemm_commit) && \
-	git submodule update --init --recursive && \
-	cd fbgemm_gpu && \
-	pip install -r requirements.txt && \
-	CUDA_ARCH_LIST="8.0;9.0a" NVCC_GENCODE="-gencode=arch=compute_80,code=sm_80 -gencode=arch=compute_90a,code=sm_90a" TORCH_CUDA_ARCH_LIST="8.0;9.0a" python setup.py --package_variant genai build
-
-install-fbgemm: build-fbgemm
-	cd fbgemm/fbgemm_gpu && \
-	CUDA_ARCH_LIST="8.0;9.0a" NVCC_GENCODE="-gencode=arch=compute_80,code=sm_80 -gencode=arch=compute_90a,code=sm_90a" TORCH_CUDA_ARCH_LIST="8.0;9.0a" python setup.py --package_variant genai install
server/poetry.lock

@@ -1,4 +1,4 @@
-# This file is automatically @generated by Poetry 1.6.1 and should not be changed by hand.
+# This file is automatically @generated by Poetry 1.8.4 and should not be changed by hand.
 
 [[package]]
 name = "accelerate"
@@ -1215,12 +1215,12 @@ files = [
 
 [[package]]
 name = "marlin-kernels"
-version = "0.3.0"
+version = "0.3.1"
 description = "Marlin quantization kernels"
 optional = true
 python-versions = ">=3.7"
 files = [
-    {file = "marlin_kernels-0.3.0+cu123torch2.4-cp310-cp310-linux_x86_64.whl", hash = "sha256:a2086b9e98d22071f52c5b4b4b98b1b4a988565258905173fa74c5a9eddd1a0a"},
+    {file = "marlin_kernels-0.3.1+cu123torch2.4-cp310-cp310-linux_x86_64.whl", hash = "sha256:705c89ed54977099a40b37dc0c796964649024f1a8819a1832118cd7b146efe1"},
 ]
 
 [package.dependencies]
@ -1228,16 +1228,16 @@ torch = "*"
|
||||||
|
|
||||||
[package.source]
|
[package.source]
|
||||||
type = "url"
|
type = "url"
|
||||||
url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.3.0/marlin_kernels-0.3.0+cu123torch2.4-cp310-cp310-linux_x86_64.whl"
|
url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.3.1/marlin_kernels-0.3.1+cu123torch2.4-cp310-cp310-linux_x86_64.whl"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "marlin-kernels"
|
name = "marlin-kernels"
|
||||||
version = "0.3.0"
|
version = "0.3.1"
|
||||||
description = "Marlin quantization kernels"
|
description = "Marlin quantization kernels"
|
||||||
optional = true
|
optional = true
|
||||||
python-versions = ">=3.7"
|
python-versions = ">=3.7"
|
||||||
files = [
|
files = [
|
||||||
{file = "marlin_kernels-0.3.0+cu123torch2.4-cp311-cp311-linux_x86_64.whl", hash = "sha256:f39a6946d8247629446ec170832d832c7038c363f1d8803211fe67249c2d804d"},
|
{file = "marlin_kernels-0.3.1+cu123torch2.4-cp311-cp311-linux_x86_64.whl", hash = "sha256:e1f3d123eca643149d0a4f6b81c4405d78abb3a694a78fccc8670a25b3404406"},
|
||||||
]
|
]
|
||||||
|
|
||||||
[package.dependencies]
|
[package.dependencies]
|
||||||
|
@ -1245,16 +1245,16 @@ torch = "*"
|
||||||
|
|
||||||
[package.source]
|
[package.source]
|
||||||
type = "url"
|
type = "url"
|
||||||
url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.3.0/marlin_kernels-0.3.0+cu123torch2.4-cp311-cp311-linux_x86_64.whl"
|
url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.3.1/marlin_kernels-0.3.1+cu123torch2.4-cp311-cp311-linux_x86_64.whl"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "marlin-kernels"
|
name = "marlin-kernels"
|
||||||
version = "0.3.0"
|
version = "0.3.1"
|
||||||
description = "Marlin quantization kernels"
|
description = "Marlin quantization kernels"
|
||||||
optional = true
|
optional = true
|
||||||
python-versions = ">=3.7"
|
python-versions = ">=3.7"
|
||||||
files = [
|
files = [
|
||||||
{file = "marlin_kernels-0.3.0+cu123torch2.4-cp312-cp312-linux_x86_64.whl", hash = "sha256:07fd869d5289777fa866107dae676523e18b1f6ba4afce79946ddc58a6870169"},
|
{file = "marlin_kernels-0.3.1+cu123torch2.4-cp312-cp312-linux_x86_64.whl", hash = "sha256:9d68367fd5e1caf2edc90b77ad5d074b11586012265a3147ecca1f1171ae22f8"},
|
||||||
]
|
]
|
||||||
|
|
||||||
[package.dependencies]
|
[package.dependencies]
|
||||||
|
@ -1262,16 +1262,16 @@ torch = "*"
|
||||||
|
|
||||||
[package.source]
|
[package.source]
|
||||||
type = "url"
|
type = "url"
|
||||||
url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.3.0/marlin_kernels-0.3.0+cu123torch2.4-cp312-cp312-linux_x86_64.whl"
|
url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.3.1/marlin_kernels-0.3.1+cu123torch2.4-cp312-cp312-linux_x86_64.whl"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "marlin-kernels"
|
name = "marlin-kernels"
|
||||||
version = "0.3.0"
|
version = "0.3.1"
|
||||||
description = "Marlin quantization kernels"
|
description = "Marlin quantization kernels"
|
||||||
optional = true
|
optional = true
|
||||||
python-versions = ">=3.7"
|
python-versions = ">=3.7"
|
||||||
files = [
|
files = [
|
||||||
{file = "marlin_kernels-0.3.0+cu123torch2.4-cp39-cp39-linux_x86_64.whl", hash = "sha256:0dedaa418225d490a5f1d8f85dbc75e439a8c43a8870e4ef32945bf61672d7dc"},
|
{file = "marlin_kernels-0.3.1+cu123torch2.4-cp39-cp39-linux_x86_64.whl", hash = "sha256:d962277c5f7642972e298650913dd0546b9f735b706dc88bb34955b3cac7f330"},
|
||||||
]
|
]
|
||||||
|
|
||||||
[package.dependencies]
|
[package.dependencies]
|
||||||
|
@ -1279,7 +1279,7 @@ torch = "*"
|
||||||
|
|
||||||
[package.source]
|
[package.source]
|
||||||
type = "url"
|
type = "url"
|
||||||
url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.3.0/marlin_kernels-0.3.0+cu123torch2.4-cp39-cp39-linux_x86_64.whl"
|
url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.3.1/marlin_kernels-0.3.1+cu123torch2.4-cp39-cp39-linux_x86_64.whl"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "mdurl"
|
name = "mdurl"
|
||||||
|
@@ -1770,6 +1770,7 @@ description = "Nvidia JIT LTO Library"
 optional = true
 python-versions = ">=3"
 files = [
+    {file = "nvidia_nvjitlink_cu12-12.4.127-py3-none-manylinux2014_aarch64.whl", hash = "sha256:4abe7fef64914ccfa909bc2ba39739670ecc9e820c83ccc7a6ed414122599b83"},
     {file = "nvidia_nvjitlink_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl", hash = "sha256:06b3b9b25bf3f8af351d664978ca26a16d2c5127dbd53c0497e28d1fb9611d57"},
     {file = "nvidia_nvjitlink_cu12-12.4.127-py3-none-win_amd64.whl", hash = "sha256:fd9020c501d27d135f983c6d3e244b197a7ccad769e34df53a42e276b0e25fa1"},
 ]
@ -3966,4 +3967,4 @@ torch = ["torch"]
|
||||||
[metadata]
|
[metadata]
|
||||||
lock-version = "2.0"
|
lock-version = "2.0"
|
||||||
python-versions = ">=3.9,<3.13"
|
python-versions = ">=3.9,<3.13"
|
||||||
content-hash = "500fa44255e4a6c89a16314a931548447afe1ba71ea341a73cad6670e46ddac7"
|
content-hash = "b39033e573f50a0f046787aebf1702d86673aad0b2fcee818404fcea7f644b81"
|
||||||
|
|
|
server/pyproject.toml

@@ -41,10 +41,10 @@ py-cpuinfo = "^9.0.0"
 numpy = "^1.26"
 
 marlin-kernels = [
-    { url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.3.0/marlin_kernels-0.3.0+cu123torch2.4-cp39-cp39-linux_x86_64.whl", python = "~3.9", optional = true },
-    { url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.3.0/marlin_kernels-0.3.0+cu123torch2.4-cp310-cp310-linux_x86_64.whl", python = "~3.10", optional = true },
-    { url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.3.0/marlin_kernels-0.3.0+cu123torch2.4-cp311-cp311-linux_x86_64.whl", python = "~3.11", optional = true },
-    { url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.3.0/marlin_kernels-0.3.0+cu123torch2.4-cp312-cp312-linux_x86_64.whl", python = "~3.12", optional = true },
+    { url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.3.1/marlin_kernels-0.3.1+cu123torch2.4-cp39-cp39-linux_x86_64.whl", python = "~3.9", optional = true },
+    { url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.3.1/marlin_kernels-0.3.1+cu123torch2.4-cp310-cp310-linux_x86_64.whl", python = "~3.10", optional = true },
+    { url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.3.1/marlin_kernels-0.3.1+cu123torch2.4-cp311-cp311-linux_x86_64.whl", python = "~3.11", optional = true },
+    { url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.3.1/marlin_kernels-0.3.1+cu123torch2.4-cp312-cp312-linux_x86_64.whl", python = "~3.12", optional = true },
 ]
 moe-kernels = [
     { url = "https://github.com/danieldk/moe-kernels/releases/download/v0.6.0/moe_kernels-0.6.0+cu123torch2.4-cp39-cp39-linux_x86_64.whl", python = "~3.9", optional = true },
server/text_generation_server/layers/fp8.py

@@ -1,7 +1,8 @@
-import torch
-
 from dataclasses import dataclass
-from typing import Optional, Tuple, Union, List
+import os
+from typing import Optional, Tuple, Type, Union, List
 
+import torch
 from loguru import logger
 
 from text_generation_server.utils.import_utils import SYSTEM
@@ -11,20 +12,7 @@ from text_generation_server.utils.weights import (
     UnquantizedWeight,
     Weights,
 )
-from text_generation_server.utils.log import log_master, log_once
-import importlib.util
-
-
-FBGEMM_MM_AVAILABLE = False
-FBGEMM_DYN_AVAILABLE = False
-
-
-def is_fbgemm_gpu_available():
-    try:
-        return importlib.util.find_spec("fbgemm_gpu.experimental.gen_ai") is not None
-    except ModuleNotFoundError:
-        return False
-
+from text_generation_server.utils.log import log_once
 
 try:
     import marlin_kernels
@@ -32,23 +20,26 @@ except ImportError:
     marlin_kernels = None
 
 
-if is_fbgemm_gpu_available():
-    if SYSTEM == "cuda":
-        major, _ = torch.cuda.get_device_capability()
-        FBGEMM_MM_AVAILABLE = major == 9
-        FBGEMM_DYN_AVAILABLE = major >= 8
+if SYSTEM == "cuda" and marlin_kernels is not None:
+    major, minor = torch.cuda.get_device_capability()
+    CUTLASS_FP8_AVAILABLE = marlin_kernels.cutlass_scaled_mm_supports_fp8(
+        major * 10 + minor
+    )
 else:
-    log_master(logger.warning, "FBGEMM fp8 kernels are not installed.")
+    CUTLASS_FP8_AVAILABLE = False
 
 
-def get_fp8_linear() -> torch.nn.Module:
+def get_fp8_linear() -> Type[torch.nn.Module]:
     """
     Return an FP8 linear `Module` that is compatible with the current system.
     """
 
     if SYSTEM == "cuda":
 
         major, _ = torch.cuda.get_device_capability()
-        if major == 8:
+        if major == 8 and os.getenv("USE_CUTLASS_W8A8", "0") != "1":
+            # NOTE: Capability 8.9 is supported by cutlass kernels, but FP8-Marlin
+            # gives better decoding throughput on L4 and L40.
            from text_generation_server.layers.marlin import GPTQMarlinFP8Linear
 
            return GPTQMarlinFP8Linear
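The availability logic above condenses to the following (a restatement of the diff, assuming marlin_kernels is installed): cutlass w8a8 requires hardware FP8 support, and on compute capability 8.x TGI still prefers FP8-Marlin for decoding throughput unless the user opts in:

    import os
    import torch
    import marlin_kernels  # provides cutlass_scaled_mm_supports_fp8, per the diff above

    major, minor = torch.cuda.get_device_capability()
    cutlass_fp8 = marlin_kernels.cutlass_scaled_mm_supports_fp8(major * 10 + minor)

    # e.g. on an L4/L40 (capability 8.9) cutlass is supported, but FP8-Marlin
    # is chosen for decoding throughput unless USE_CUTLASS_W8A8=1 is set:
    use_cutlass = cutlass_fp8 and (major != 8 or os.getenv("USE_CUTLASS_W8A8", "0") == "1")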
@@ -94,12 +85,6 @@ def fp8_quantize(
     argument, it must also be a reciprocal (so that scales from an FP8 checkpoint can
     be used without modification).
     """
-    if FBGEMM_DYN_AVAILABLE and not scalar and not scale:
-        qweight, scale = torch.ops.fbgemm.quantize_fp8_per_row(
-            weight, bs=None, scale_ub=scale_upper_bound, output_dtype=qdtype
-        )
-        return qweight, scale
-
     if marlin_kernels is not None:
         shape = weight.shape
         qweight, scale = marlin_kernels.scaled_fp8_quant(
@@ -107,11 +92,12 @@ def fp8_quantize(
             dtype=qdtype,
             scale=scale,
             scale_ub=scale_upper_bound,
+            # TODO: don't do this when we have to use the Torch kernel.
+            use_per_token_if_dynamic=not scalar,
         )
 
         return qweight.reshape(shape), scale
 
-    # weight, scale = quant_weights(weight, torch.int8, False)
     finfo = torch.finfo(qdtype)
 
     if scale is None:
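The new use_per_token_if_dynamic=not scalar argument is what takes over from fbgemm's quantize_fp8_per_row: when activations are quantized on the fly and scalar is false, every token (row) gets its own scale instead of one tensor-wide scale. A plain-PyTorch sketch of the distinction (ours; the real work happens inside marlin_kernels.scaled_fp8_quant):

    import torch

    def dynamic_fp8_scales(x, per_token):
        finfo = torch.finfo(torch.float8_e4m3fn)
        if per_token:
            # One scale per row (token); tracks uneven activation ranges better.
            amax = x.abs().amax(dim=-1, keepdim=True)
        else:
            # One scale for the whole tensor (what scalar=True requests).
            amax = x.abs().amax()
        return (amax.clamp(min=1e-12) / finfo.max).float()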
@@ -327,8 +313,8 @@ class Fp8Linear(torch.nn.Module):
         scale_upper_bound: Optional[float] = None,
     ) -> None:
         super().__init__()
-        if FBGEMM_MM_AVAILABLE:
-            log_once(logger.info, "Using FBGEMM fp8 optimized kernels")
+        if CUTLASS_FP8_AVAILABLE:
+            log_once(logger.info, "Using cutlass w8a8 kernels")
         if SYSTEM == "rocm" and qweight.dtype == torch.float8_e4m3fn:
             qweight, scale, _ = normalize_e4m3fn_to_e4m3fnuz(
                 weight=qweight, weight_scale=scale
@@ -339,13 +325,9 @@ class Fp8Linear(torch.nn.Module):
         self.scale = scale.float()
         self.input_scale = input_scale.float() if input_scale is not None else None
 
-        if FBGEMM_MM_AVAILABLE:
-            self.scale_upper_bound = (
-                torch.tensor(
-                    [scale_upper_bound], dtype=torch.float32, device=qweight.device
-                )
-                if scale_upper_bound is not None
-                else None
-            )
+        if CUTLASS_FP8_AVAILABLE and scale_upper_bound is not None:
+            self.scale_upper_bound = torch.tensor(
+                scale_upper_bound, dtype=torch.float32, device=qweight.device
+            )
         else:
             self.scale_upper_bound = scale_upper_bound
@@ -354,7 +336,7 @@ class Fp8Linear(torch.nn.Module):
 
     @classmethod
     def from_unquant(cls, weight, bias, dtype):
-        qweight, scale = fp8_quantize(weight, scalar=not FBGEMM_MM_AVAILABLE)
+        qweight, scale = fp8_quantize(weight, scalar=not CUTLASS_FP8_AVAILABLE)
         return cls(
             qweight=qweight,
             scale=scale,
@@ -376,9 +358,6 @@ class Fp8Linear(torch.nn.Module):
         input_scale = kwargs.get("input_scale", None)
         scale_upper_bound = kwargs.get("scale_upper_bound", None)
 
-        if FBGEMM_DYN_AVAILABLE:
-            # fbgemm needs float32 scales.
-            scale = scale.float()
         return cls(
             qweight=weight,
             scale=scale,
@@ -397,20 +376,14 @@ class Fp8Linear(torch.nn.Module):
             return cls._device_identity_cache[device]
 
     def forward(self, input: torch.Tensor) -> torch.Tensor:
-        if FBGEMM_MM_AVAILABLE:
+        if CUTLASS_FP8_AVAILABLE:
+            # cutlass FP8 supports per-token scales, so get non-scalar scales.
             qinput, scale = fp8_quantize(
-                input, scale_upper_bound=self.scale_upper_bound
+                input, scale_upper_bound=self.scale_upper_bound, scalar=False
             )
-            y = torch.ops.fbgemm.f8f8bf16_rowwise(
-                qinput,
-                self.qweight,
-                scale,
-                self.scale,
-                use_fast_accum=True,
-                bias=self.bias,
+            return marlin_kernels.cutlass_scaled_mm(
+                qinput, self.qweight.t(), scale, self.scale, input.dtype, self.bias
             )
-            return y.to(self.dtype)
 
         qinput, scale = fp8_quantize(
             input,
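Reading the new forward pass: the positional arguments of cutlass_scaled_mm, as inferred from this call site (FP8 input, transposed FP8 weight, per-token input scales, weight scale, output dtype, bias), make it equivalent to the following pure-PyTorch reference (our sketch, not the kernel itself):

    import torch

    def cutlass_scaled_mm_reference(a_q, b_q, a_scales, b_scales, out_dtype, bias=None):
        # a_q: (M, K) FP8, b_q: (K, N) FP8; accumulate in FP32, then rescale.
        out = (a_q.float() @ b_q.float()) * a_scales * b_scales
        if bias is not None:
            out = out + bias
        return out.to(out_dtype)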
server/text_generation_server/models/__init__.py

@@ -410,12 +410,6 @@ def get_model(
         if quantize in ["awq", "exl2", "gptq", "marlin"]:
             # These quantizers only work with float16 params.
             dtype = torch.float16
-        elif quantize == "fp8":
-            from text_generation_server.layers.fp8 import FBGEMM_DYN_AVAILABLE
-
-            if FBGEMM_DYN_AVAILABLE:
-                # fbgemm kernels are fp8xfp8->bf16
-                dtype = torch.bfloat16
         else:
             # Keep it as default for now and let
             # every model resolve their own default dtype.