Switch from fbgemm-gpu w8a8 scaled matmul to vLLM/marlin-kernels (#2688)

* Switch from fbgemm-gpu w8a8 scaled matmul to vLLM/marlin-kernels

Performance and accuracy of these kernels are on par with the fbgemm-gpu kernels they
replace (tested with Llama 70B and 405B). The switch removes a dependency and resolves
some stability issues we have been seeing.

* Update test snapshots
Daniël de Kok 2024-10-25 16:40:47 +02:00 committed by GitHub
parent ba5fc7d922
commit 0f346a3296
12 changed files with 109 additions and 140 deletions
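
For reference, the w8a8 path this change switches to amounts to dynamic per-token FP8 quantization of the activations followed by a cutlass scaled matmul. Below is a minimal sketch that mirrors the calls in the fp8.py diff further down; the wrapper function itself is illustrative and not part of the codebase, and it assumes marlin-kernels is installed on an FP8-capable GPU.

# Minimal sketch of the new w8a8 path; mirrors the marlin_kernels calls shown in
# the fp8.py diff below. The wrapper name is illustrative, not part of TGI.
from typing import Optional

import torch
import marlin_kernels


def w8a8_scaled_matmul(
    x: torch.Tensor,             # activations, shape (tokens, in_features)
    qweight: torch.Tensor,       # FP8 weight, shape (out_features, in_features)
    weight_scale: torch.Tensor,  # float32 weight scale
    bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    # Dynamically quantize the activations to FP8 with per-token scales.
    qx, x_scale = marlin_kernels.scaled_fp8_quant(
        x, dtype=torch.float8_e4m3fn, use_per_token_if_dynamic=True
    )
    # cutlass scaled matmul: (tokens, in) x (in, out) -> (tokens, out) in x.dtype.
    return marlin_kernels.cutlass_scaled_mm(
        qx, qweight.t(), x_scale, weight_scale, x.dtype, bias
    )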


@@ -161,15 +161,6 @@ COPY server/custom_kernels/ .
 # Build specific version of transformers
 RUN python setup.py build
-# Build FBGEMM CUDA kernels
-FROM kernel-builder AS fbgemm-builder
-WORKDIR /usr/src
-COPY server/Makefile-fbgemm Makefile
-RUN make build-fbgemm
 # Build vllm CUDA kernels
 FROM kernel-builder AS vllm-builder
@@ -239,8 +230,6 @@ COPY --from=awq-kernels-builder /usr/src/llm-awq/awq/kernels/build/lib.linux-x86
 COPY --from=eetq-kernels-builder /usr/src/eetq/build/lib.linux-x86_64-cpython-311 /opt/conda/lib/python3.11/site-packages
 # Copy build artifacts from lorax punica kernels builder
 COPY --from=lorax-punica-builder /usr/src/lorax-punica/server/punica_kernels/build/lib.linux-x86_64-cpython-311 /opt/conda/lib/python3.11/site-packages
-# Copy build artifacts from fbgemm builder
-COPY --from=fbgemm-builder /usr/src/fbgemm/fbgemm_gpu/_skbuild/linux-x86_64-3.11/cmake-install /opt/conda/lib/python3.11/site-packages
 # Copy build artifacts from vllm builder
 COPY --from=vllm-builder /usr/src/vllm/build/lib.linux-x86_64-cpython-311 /opt/conda/lib/python3.11/site-packages
 # Copy build artifacts from mamba builder


@@ -978,16 +978,16 @@
       "nixpkgs": "nixpkgs_6"
     },
     "locked": {
-      "lastModified": 1729531056,
-      "narHash": "sha256-dW9IOA31+j3VS19WAWAmkJW2YCzeVZGqd6HpIJfODtI=",
+      "lastModified": 1729761651,
+      "narHash": "sha256-GYykQ9Fxji2EuXCGcPn0dx8Qx8VQBJTkRdcCytp4A/k=",
       "owner": "huggingface",
       "repo": "text-generation-inference-nix",
-      "rev": "a84a90281a17b15762873845c947e5c78f5a8dd1",
+      "rev": "f7e3c4fa67d70590ed9ee47feeab645bd9ba81b1",
       "type": "github"
     },
     "original": {
       "owner": "huggingface",
-      "ref": "marlin-kernels-0.3.0",
+      "ref": "marlin-kernels-0.3.1",
       "repo": "text-generation-inference-nix",
       "type": "github"
     }


@@ -5,7 +5,7 @@
      inputs.nixpkgs.follows = "tgi-nix/nixpkgs";
    };
    nix-filter.url = "github:numtide/nix-filter";
-    tgi-nix.url = "github:huggingface/text-generation-inference-nix/marlin-kernels-0.3.0";
+    tgi-nix.url = "github:huggingface/text-generation-inference-nix/marlin-kernels-0.3.1";
    nixpkgs.follows = "tgi-nix/nixpkgs";
    flake-utils.url = "github:numtide/flake-utils";
    rust-overlay = {


@@ -1,8 +1,8 @@
 {
   "details": {
     "best_of_sequences": null,
-    "finish_reason": "stop_sequence",
-    "generated_tokens": 5,
+    "finish_reason": "length",
+    "generated_tokens": 10,
     "prefill": [
       {
         "id": 128000,
@@ -11,12 +11,12 @@
       },
       {
         "id": 2323,
-        "logprob": -9.5625,
+        "logprob": -9.5234375,
         "text": "Test"
       },
       {
         "id": 1715,
-        "logprob": -10.4375,
+        "logprob": -10.421875,
         "text": " request"
       }
     ],
@@ -24,36 +24,66 @@
     "tokens": [
       {
         "id": 25,
-        "logprob": -0.8984375,
+        "logprob": -0.88183594,
         "special": false,
         "text": ":"
       },
       {
-        "id": 923,
-        "logprob": -2.84375,
+        "id": 2209,
+        "logprob": -2.6699219,
         "special": false,
-        "text": " add"
+        "text": " Is"
       },
       {
-        "id": 264,
-        "logprob": 0.0,
+        "id": 279,
+        "logprob": -0.61083984,
         "special": false,
-        "text": " a"
+        "text": " the"
+      },
+      {
+        "id": 734,
+        "logprob": -2.6660156,
+        "special": false,
+        "text": " function"
       },
       {
         "id": 330,
-        "logprob": -0.31640625,
+        "logprob": -0.35498047,
         "special": false,
         "text": " \""
       },
       {
-        "id": 1985,
-        "logprob": 0.0,
+        "id": 4110,
+        "logprob": -2.4101562,
         "special": false,
-        "text": "test"
+        "text": "Create"
+      },
+      {
+        "id": 7575,
+        "logprob": -2.2304688,
+        "special": false,
+        "text": "Process"
+      },
+      {
+        "id": 1,
+        "logprob": -0.080078125,
+        "special": false,
+        "text": "\""
+      },
+      {
+        "id": 304,
+        "logprob": -0.75439453,
+        "special": false,
+        "text": " in"
+      },
+      {
+        "id": 12468,
+        "logprob": -1.8769531,
+        "special": false,
+        "text": " Win"
       }
     ],
     "top_tokens": null
   },
-  "generated_text": "Test request: add a \"test"
+  "generated_text": "Test request: Is the function \"CreateProcess\" in Win"
 }


@@ -16,17 +16,17 @@
       },
       {
         "id": 5655,
-        "logprob": -11.75,
+        "logprob": -11.8359375,
         "text": " deep"
       },
       {
         "id": 6975,
-        "logprob": -2.0625,
+        "logprob": -2.0703125,
         "text": " learning"
       },
       {
         "id": 30,
-        "logprob": -6.0,
+        "logprob": -5.9765625,
         "text": "?"
       }
     ],
@@ -40,25 +40,25 @@
       },
       {
         "id": 34564,
-        "logprob": -0.11279297,
+        "logprob": -0.12512207,
         "special": false,
         "text": "Deep"
       },
       {
         "id": 6975,
-        "logprob": -0.16015625,
+        "logprob": 0.0,
         "special": false,
         "text": " learning"
       },
       {
         "id": 320,
-        "logprob": -0.25195312,
+        "logprob": -0.23840332,
         "special": false,
         "text": " ("
       },
       {
         "id": 16931,
-        "logprob": -1.703125,
+        "logprob": -2.0175781,
         "special": false,
         "text": "DL"
       },
@@ -70,7 +70,7 @@
       },
       {
         "id": 374,
-        "logprob": -1.140625,
+        "logprob": -0.8613281,
         "special": false,
         "text": " is"
       },
@@ -82,7 +82,7 @@
       },
       {
         "id": 1207,
-        "logprob": -1.3125,
+        "logprob": -1.2451172,
         "special": false,
         "text": " sub"
       },


@@ -8,7 +8,6 @@
   eetq,
   einops,
   exllamav2,
-  fbgemm-gpu,
   flashinfer,
   flash-attn,
   flash-attn-layer-norm,
@@ -77,7 +76,6 @@ buildPythonPackage {
       causal-conv1d
       einops
       exllamav2
-      fbgemm-gpu
       flashinfer
       flash-attn
       flash-attn-layer-norm


@@ -5,7 +5,6 @@ include Makefile-awq
 include Makefile-eetq
 include Makefile-selective-scan
 include Makefile-lorax-punica
-include Makefile-fbgemm
 include Makefile-exllamav2
 include Makefile-flashinfer
@@ -30,7 +29,7 @@ install-server: gen-server
 install: install-cuda
 	echo "Installed server"
-install-cuda: install-server install-flash-attention-v2-cuda install-vllm-cuda install-flash-attention install-fbgemm
+install-cuda: install-server install-flash-attention-v2-cuda install-vllm-cuda install-flash-attention
 	pip install -e ".[bnb,marlin,moe]"
 	pip install nvidia-nccl-cu12==2.22.3


@@ -1,15 +0,0 @@
-fbgemm_commit := v0.8.0
-
-build-fbgemm:
-	@if [ ! -d "fbgemm" ]; then \
-		git clone https://github.com/pytorch/FBGEMM.git fbgemm; \
-	fi
-	cd fbgemm && git fetch && git checkout $(fbgemm_commit) && \
-	git submodule update --init --recursive && \
-	cd fbgemm_gpu && \
-	pip install -r requirements.txt && \
-	CUDA_ARCH_LIST="8.0;9.0a" NVCC_GENCODE="-gencode=arch=compute_80,code=sm_80 -gencode=arch=compute_90a,code=sm_90a" TORCH_CUDA_ARCH_LIST="8.0;9.0a" python setup.py --package_variant genai build
-
-install-fbgemm: build-fbgemm
-	cd fbgemm/fbgemm_gpu && \
-	CUDA_ARCH_LIST="8.0;9.0a" NVCC_GENCODE="-gencode=arch=compute_80,code=sm_80 -gencode=arch=compute_90a,code=sm_90a" TORCH_CUDA_ARCH_LIST="8.0;9.0a" python setup.py --package_variant genai install

server/poetry.lock (generated)

@@ -1,4 +1,4 @@
-# This file is automatically @generated by Poetry 1.6.1 and should not be changed by hand.
+# This file is automatically @generated by Poetry 1.8.4 and should not be changed by hand.
 
 [[package]]
 name = "accelerate"
@@ -1215,12 +1215,12 @@ files = [
 [[package]]
 name = "marlin-kernels"
-version = "0.3.0"
+version = "0.3.1"
 description = "Marlin quantization kernels"
 optional = true
 python-versions = ">=3.7"
 files = [
-    {file = "marlin_kernels-0.3.0+cu123torch2.4-cp310-cp310-linux_x86_64.whl", hash = "sha256:a2086b9e98d22071f52c5b4b4b98b1b4a988565258905173fa74c5a9eddd1a0a"},
+    {file = "marlin_kernels-0.3.1+cu123torch2.4-cp310-cp310-linux_x86_64.whl", hash = "sha256:705c89ed54977099a40b37dc0c796964649024f1a8819a1832118cd7b146efe1"},
 ]
 
 [package.dependencies]
@@ -1228,16 +1228,16 @@ torch = "*"
 [package.source]
 type = "url"
-url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.3.0/marlin_kernels-0.3.0+cu123torch2.4-cp310-cp310-linux_x86_64.whl"
+url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.3.1/marlin_kernels-0.3.1+cu123torch2.4-cp310-cp310-linux_x86_64.whl"
 
 [[package]]
 name = "marlin-kernels"
-version = "0.3.0"
+version = "0.3.1"
 description = "Marlin quantization kernels"
 optional = true
 python-versions = ">=3.7"
 files = [
-    {file = "marlin_kernels-0.3.0+cu123torch2.4-cp311-cp311-linux_x86_64.whl", hash = "sha256:f39a6946d8247629446ec170832d832c7038c363f1d8803211fe67249c2d804d"},
+    {file = "marlin_kernels-0.3.1+cu123torch2.4-cp311-cp311-linux_x86_64.whl", hash = "sha256:e1f3d123eca643149d0a4f6b81c4405d78abb3a694a78fccc8670a25b3404406"},
 ]
 
 [package.dependencies]
@@ -1245,16 +1245,16 @@ torch = "*"
 [package.source]
 type = "url"
-url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.3.0/marlin_kernels-0.3.0+cu123torch2.4-cp311-cp311-linux_x86_64.whl"
+url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.3.1/marlin_kernels-0.3.1+cu123torch2.4-cp311-cp311-linux_x86_64.whl"
 
 [[package]]
 name = "marlin-kernels"
-version = "0.3.0"
+version = "0.3.1"
 description = "Marlin quantization kernels"
 optional = true
 python-versions = ">=3.7"
 files = [
-    {file = "marlin_kernels-0.3.0+cu123torch2.4-cp312-cp312-linux_x86_64.whl", hash = "sha256:07fd869d5289777fa866107dae676523e18b1f6ba4afce79946ddc58a6870169"},
+    {file = "marlin_kernels-0.3.1+cu123torch2.4-cp312-cp312-linux_x86_64.whl", hash = "sha256:9d68367fd5e1caf2edc90b77ad5d074b11586012265a3147ecca1f1171ae22f8"},
 ]
 
 [package.dependencies]
@@ -1262,16 +1262,16 @@ torch = "*"
 [package.source]
 type = "url"
-url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.3.0/marlin_kernels-0.3.0+cu123torch2.4-cp312-cp312-linux_x86_64.whl"
+url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.3.1/marlin_kernels-0.3.1+cu123torch2.4-cp312-cp312-linux_x86_64.whl"
 
 [[package]]
 name = "marlin-kernels"
-version = "0.3.0"
+version = "0.3.1"
 description = "Marlin quantization kernels"
 optional = true
 python-versions = ">=3.7"
 files = [
-    {file = "marlin_kernels-0.3.0+cu123torch2.4-cp39-cp39-linux_x86_64.whl", hash = "sha256:0dedaa418225d490a5f1d8f85dbc75e439a8c43a8870e4ef32945bf61672d7dc"},
+    {file = "marlin_kernels-0.3.1+cu123torch2.4-cp39-cp39-linux_x86_64.whl", hash = "sha256:d962277c5f7642972e298650913dd0546b9f735b706dc88bb34955b3cac7f330"},
 ]
 
 [package.dependencies]
@@ -1279,7 +1279,7 @@ torch = "*"
 [package.source]
 type = "url"
-url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.3.0/marlin_kernels-0.3.0+cu123torch2.4-cp39-cp39-linux_x86_64.whl"
+url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.3.1/marlin_kernels-0.3.1+cu123torch2.4-cp39-cp39-linux_x86_64.whl"
 
 [[package]]
 name = "mdurl"
@@ -1770,6 +1770,7 @@ description = "Nvidia JIT LTO Library"
 optional = true
 python-versions = ">=3"
 files = [
+    {file = "nvidia_nvjitlink_cu12-12.4.127-py3-none-manylinux2014_aarch64.whl", hash = "sha256:4abe7fef64914ccfa909bc2ba39739670ecc9e820c83ccc7a6ed414122599b83"},
     {file = "nvidia_nvjitlink_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl", hash = "sha256:06b3b9b25bf3f8af351d664978ca26a16d2c5127dbd53c0497e28d1fb9611d57"},
     {file = "nvidia_nvjitlink_cu12-12.4.127-py3-none-win_amd64.whl", hash = "sha256:fd9020c501d27d135f983c6d3e244b197a7ccad769e34df53a42e276b0e25fa1"},
 ]
@@ -3966,4 +3967,4 @@ torch = ["torch"]
 [metadata]
 lock-version = "2.0"
 python-versions = ">=3.9,<3.13"
-content-hash = "500fa44255e4a6c89a16314a931548447afe1ba71ea341a73cad6670e46ddac7"
+content-hash = "b39033e573f50a0f046787aebf1702d86673aad0b2fcee818404fcea7f644b81"


@@ -41,10 +41,10 @@ py-cpuinfo = "^9.0.0"
 numpy = "^1.26"
 
 marlin-kernels = [
-    { url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.3.0/marlin_kernels-0.3.0+cu123torch2.4-cp39-cp39-linux_x86_64.whl", python = "~3.9", optional = true },
-    { url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.3.0/marlin_kernels-0.3.0+cu123torch2.4-cp310-cp310-linux_x86_64.whl", python = "~3.10", optional = true },
-    { url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.3.0/marlin_kernels-0.3.0+cu123torch2.4-cp311-cp311-linux_x86_64.whl", python = "~3.11", optional = true },
-    { url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.3.0/marlin_kernels-0.3.0+cu123torch2.4-cp312-cp312-linux_x86_64.whl", python = "~3.12", optional = true },
+    { url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.3.1/marlin_kernels-0.3.1+cu123torch2.4-cp39-cp39-linux_x86_64.whl", python = "~3.9", optional = true },
+    { url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.3.1/marlin_kernels-0.3.1+cu123torch2.4-cp310-cp310-linux_x86_64.whl", python = "~3.10", optional = true },
+    { url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.3.1/marlin_kernels-0.3.1+cu123torch2.4-cp311-cp311-linux_x86_64.whl", python = "~3.11", optional = true },
+    { url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.3.1/marlin_kernels-0.3.1+cu123torch2.4-cp312-cp312-linux_x86_64.whl", python = "~3.12", optional = true },
 ]
 moe-kernels = [
     { url = "https://github.com/danieldk/moe-kernels/releases/download/v0.6.0/moe_kernels-0.6.0+cu123torch2.4-cp39-cp39-linux_x86_64.whl", python = "~3.9", optional = true },


@@ -1,7 +1,8 @@
-import torch
 from dataclasses import dataclass
-from typing import Optional, Tuple, Union, List
+import os
+from typing import Optional, Tuple, Type, Union, List
 
+import torch
 from loguru import logger
 
 from text_generation_server.utils.import_utils import SYSTEM
@@ -11,20 +12,7 @@ from text_generation_server.utils.weights import (
     UnquantizedWeight,
     Weights,
 )
-from text_generation_server.utils.log import log_master, log_once
-
-import importlib.util
-
-
-FBGEMM_MM_AVAILABLE = False
-FBGEMM_DYN_AVAILABLE = False
-
-
-def is_fbgemm_gpu_available():
-    try:
-        return importlib.util.find_spec("fbgemm_gpu.experimental.gen_ai") is not None
-    except ModuleNotFoundError:
-        return False
-
+from text_generation_server.utils.log import log_once
 
 try:
     import marlin_kernels
@@ -32,23 +20,26 @@ except ImportError:
     marlin_kernels = None
 
-if is_fbgemm_gpu_available():
-    if SYSTEM == "cuda":
-        major, _ = torch.cuda.get_device_capability()
-        FBGEMM_MM_AVAILABLE = major == 9
-        FBGEMM_DYN_AVAILABLE = major >= 8
+if SYSTEM == "cuda" and marlin_kernels is not None:
+    major, minor = torch.cuda.get_device_capability()
+    CUTLASS_FP8_AVAILABLE = marlin_kernels.cutlass_scaled_mm_supports_fp8(
+        major * 10 + minor
+    )
 else:
-    log_master(logger.warning, "FBGEMM fp8 kernels are not installed.")
+    CUTLASS_FP8_AVAILABLE = False
 
 
-def get_fp8_linear() -> torch.nn.Module:
+def get_fp8_linear() -> Type[torch.nn.Module]:
     """
     Return an FP8 linear `Module` that is compatible with the current system.
     """
     if SYSTEM == "cuda":
         major, _ = torch.cuda.get_device_capability()
-        if major == 8:
+        if major == 8 and os.getenv("USE_CUTLASS_W8A8", "0") != "1":
+            # NOTE: Capability 8.9 is supported by cutlass kernels, but FP8-Marlin
+            # gives better decoding throughput on L4 and L40.
             from text_generation_server.layers.marlin import GPTQMarlinFP8Linear
 
             return GPTQMarlinFP8Linear
@@ -94,12 +85,6 @@ def fp8_quantize(
     argument, it must also be a reciprocal (so that scales from an FP8 checkpoint can
     be used without modification).
     """
-    if FBGEMM_DYN_AVAILABLE and not scalar and not scale:
-        qweight, scale = torch.ops.fbgemm.quantize_fp8_per_row(
-            weight, bs=None, scale_ub=scale_upper_bound, output_dtype=qdtype
-        )
-        return qweight, scale
-
     if marlin_kernels is not None:
         shape = weight.shape
         qweight, scale = marlin_kernels.scaled_fp8_quant(
@@ -107,11 +92,12 @@ def fp8_quantize(
             dtype=qdtype,
             scale=scale,
             scale_ub=scale_upper_bound,
+            # TODO: don't do this when we have to use the Torch kernel.
+            use_per_token_if_dynamic=not scalar,
         )
 
         return qweight.reshape(shape), scale
 
-    # weight, scale = quant_weights(weight, torch.int8, False)
     finfo = torch.finfo(qdtype)
 
     if scale is None:
@@ -327,8 +313,8 @@ class Fp8Linear(torch.nn.Module):
         scale_upper_bound: Optional[float] = None,
     ) -> None:
         super().__init__()
-        if FBGEMM_MM_AVAILABLE:
-            log_once(logger.info, "Using FBGEMM fp8 optimized kernels")
+        if CUTLASS_FP8_AVAILABLE:
+            log_once(logger.info, "Using cutlass w8a8 kernels")
         if SYSTEM == "rocm" and qweight.dtype == torch.float8_e4m3fn:
             qweight, scale, _ = normalize_e4m3fn_to_e4m3fnuz(
                 weight=qweight, weight_scale=scale
@@ -339,13 +325,9 @@ class Fp8Linear(torch.nn.Module):
         self.scale = scale.float()
         self.input_scale = input_scale.float() if input_scale is not None else None
 
-        if FBGEMM_MM_AVAILABLE:
-            self.scale_upper_bound = (
-                torch.tensor(
-                    [scale_upper_bound], dtype=torch.float32, device=qweight.device
-                )
-                if scale_upper_bound is not None
-                else None
+        if CUTLASS_FP8_AVAILABLE and scale_upper_bound is not None:
+            self.scale_upper_bound = torch.tensor(
+                scale_upper_bound, dtype=torch.float32, device=qweight.device
             )
         else:
             self.scale_upper_bound = scale_upper_bound
@@ -354,7 +336,7 @@ class Fp8Linear(torch.nn.Module):
     @classmethod
     def from_unquant(cls, weight, bias, dtype):
-        qweight, scale = fp8_quantize(weight, scalar=not FBGEMM_MM_AVAILABLE)
+        qweight, scale = fp8_quantize(weight, scalar=not CUTLASS_FP8_AVAILABLE)
         return cls(
             qweight=qweight,
             scale=scale,
@@ -376,9 +358,6 @@ class Fp8Linear(torch.nn.Module):
         input_scale = kwargs.get("input_scale", None)
         scale_upper_bound = kwargs.get("scale_upper_bound", None)
 
-        if FBGEMM_DYN_AVAILABLE:
-            # fbgemm needs float32 scales.
-            scale = scale.float()
-
         return cls(
             qweight=weight,
             scale=scale,
@@ -397,20 +376,14 @@ class Fp8Linear(torch.nn.Module):
         return cls._device_identity_cache[device]
 
     def forward(self, input: torch.Tensor) -> torch.Tensor:
-        if FBGEMM_MM_AVAILABLE:
+        if CUTLASS_FP8_AVAILABLE:
+            # cutlass FP8 supports per-token scales, so get non-scalar scales.
             qinput, scale = fp8_quantize(
-                input, scale_upper_bound=self.scale_upper_bound
+                input, scale_upper_bound=self.scale_upper_bound, scalar=False
             )
-
-            y = torch.ops.fbgemm.f8f8bf16_rowwise(
-                qinput,
-                self.qweight,
-                scale,
-                self.scale,
-                use_fast_accum=True,
-                bias=self.bias,
+            return marlin_kernels.cutlass_scaled_mm(
+                qinput, self.qweight.t(), scale, self.scale, input.dtype, self.bias
             )
-            return y.to(self.dtype)
 
         qinput, scale = fp8_quantize(
             input,
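
To make the gating above easier to follow, here is a small sketch of how the kernel choice plays out at runtime on CUDA. It assumes marlin_kernels imports successfully and is a reading aid based on the diff, not a verbatim excerpt from fp8.py.

# Sketch of the runtime kernel selection implied by the fp8.py diff above.
# Assumes a CUDA device and an importable marlin_kernels package.
import os

import torch
import marlin_kernels

major, minor = torch.cuda.get_device_capability()

# cutlass w8a8 is only usable when the kernel package reports FP8 support for
# this compute capability (e.g. 8.9 on L4/L40, 9.0 on H100).
cutlass_fp8_available = marlin_kernels.cutlass_scaled_mm_supports_fp8(major * 10 + minor)

if major == 8 and os.getenv("USE_CUTLASS_W8A8", "0") != "1":
    # Default on SM 8.x: FP8-Marlin, which gives better decoding throughput on
    # L4/L40 even though capability 8.9 also supports the cutlass kernels.
    choice = "GPTQMarlinFP8Linear"
elif cutlass_fp8_available:
    # Fp8Linear quantizes activations per token and calls cutlass_scaled_mm.
    choice = "Fp8Linear (cutlass w8a8)"
else:
    # Fp8Linear falls back to its non-cutlass matmul path.
    choice = "Fp8Linear (fallback path)"

print(choice)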


@@ -410,12 +410,6 @@ def get_model(
             else:
                 # These quantizers only work with float16 params.
                 dtype = torch.float16
-        elif quantize == "fp8":
-            from text_generation_server.layers.fp8 import FBGEMM_DYN_AVAILABLE
-
-            if FBGEMM_DYN_AVAILABLE:
-                # fbgemm kernels are fp8xfp8->bf16
-                dtype = torch.bfloat16
         else:
             # Keep it as default for now and let
             # every model resolve their own default dtype.