Switch from fbgemm-gpu w8a8 scaled matmul to vLLM/marlin-kernels (#2688)

* Switch from fbgemm-gpu w8a8 scaled matmul to vLLM/marlin-kernels

Performance and accuracy of these kernels are on par with the fbgemm-gpu
kernels (tested with Llama 70B and 405B). The switch removes the fbgemm-gpu
dependency and resolves some stability issues we have been seeing; the core
kernel-call change is sketched below.

* Update test snapshots
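
A condensed before/after sketch of Fp8Linear.forward, taken from the fp8.py hunks further down. The forward_before/forward_after names are labels for this sketch only, and the surrounding class code is elided:

def forward_before(self, input: torch.Tensor) -> torch.Tensor:
    # fbgemm-gpu path: row-wise FP8 GEMM (fp8 x fp8 -> bf16), cast back to the layer dtype.
    qinput, scale = fp8_quantize(input, scale_upper_bound=self.scale_upper_bound)
    y = torch.ops.fbgemm.f8f8bf16_rowwise(
        qinput,
        self.qweight,
        scale,
        self.scale,
        use_fast_accum=True,
        bias=self.bias,
    )
    return y.to(self.dtype)

def forward_after(self, input: torch.Tensor) -> torch.Tensor:
    # marlin-kernels path: quantize the input with per-token scales (scalar=False);
    # cutlass_scaled_mm applies both scale tensors, the bias, and the output dtype
    # in a single call.
    qinput, scale = fp8_quantize(
        input, scale_upper_bound=self.scale_upper_bound, scalar=False
    )
    return marlin_kernels.cutlass_scaled_mm(
        qinput, self.qweight.t(), scale, self.scale, input.dtype, self.bias
    )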
Daniël de Kok 2024-10-25 16:40:47 +02:00 committed by GitHub
parent ba5fc7d922
commit 0f346a3296
12 changed files with 109 additions and 140 deletions

View File

@@ -161,15 +161,6 @@ COPY server/custom_kernels/ .
# Build specific version of transformers
RUN python setup.py build
# Build FBGEMM CUDA kernels
FROM kernel-builder AS fbgemm-builder
WORKDIR /usr/src
COPY server/Makefile-fbgemm Makefile
RUN make build-fbgemm
# Build vllm CUDA kernels
FROM kernel-builder AS vllm-builder
@@ -239,8 +230,6 @@ COPY --from=awq-kernels-builder /usr/src/llm-awq/awq/kernels/build/lib.linux-x86
COPY --from=eetq-kernels-builder /usr/src/eetq/build/lib.linux-x86_64-cpython-311 /opt/conda/lib/python3.11/site-packages
# Copy build artifacts from lorax punica kernels builder
COPY --from=lorax-punica-builder /usr/src/lorax-punica/server/punica_kernels/build/lib.linux-x86_64-cpython-311 /opt/conda/lib/python3.11/site-packages
# Copy build artifacts from fbgemm builder
COPY --from=fbgemm-builder /usr/src/fbgemm/fbgemm_gpu/_skbuild/linux-x86_64-3.11/cmake-install /opt/conda/lib/python3.11/site-packages
# Copy build artifacts from vllm builder
COPY --from=vllm-builder /usr/src/vllm/build/lib.linux-x86_64-cpython-311 /opt/conda/lib/python3.11/site-packages
# Copy build artifacts from mamba builder

View File

@@ -978,16 +978,16 @@
"nixpkgs": "nixpkgs_6"
},
"locked": {
"lastModified": 1729531056,
"narHash": "sha256-dW9IOA31+j3VS19WAWAmkJW2YCzeVZGqd6HpIJfODtI=",
"lastModified": 1729761651,
"narHash": "sha256-GYykQ9Fxji2EuXCGcPn0dx8Qx8VQBJTkRdcCytp4A/k=",
"owner": "huggingface",
"repo": "text-generation-inference-nix",
"rev": "a84a90281a17b15762873845c947e5c78f5a8dd1",
"rev": "f7e3c4fa67d70590ed9ee47feeab645bd9ba81b1",
"type": "github"
},
"original": {
"owner": "huggingface",
"ref": "marlin-kernels-0.3.0",
"ref": "marlin-kernels-0.3.1",
"repo": "text-generation-inference-nix",
"type": "github"
}

View File

@@ -5,7 +5,7 @@
inputs.nixpkgs.follows = "tgi-nix/nixpkgs";
};
nix-filter.url = "github:numtide/nix-filter";
tgi-nix.url = "github:huggingface/text-generation-inference-nix/marlin-kernels-0.3.0";
tgi-nix.url = "github:huggingface/text-generation-inference-nix/marlin-kernels-0.3.1";
nixpkgs.follows = "tgi-nix/nixpkgs";
flake-utils.url = "github:numtide/flake-utils";
rust-overlay = {

View File

@@ -1,8 +1,8 @@
{
"details": {
"best_of_sequences": null,
"finish_reason": "stop_sequence",
"generated_tokens": 5,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 128000,
@@ -11,12 +11,12 @@
},
{
"id": 2323,
"logprob": -9.5625,
"logprob": -9.5234375,
"text": "Test"
},
{
"id": 1715,
"logprob": -10.4375,
"logprob": -10.421875,
"text": " request"
}
],
@@ -24,36 +24,66 @@
"tokens": [
{
"id": 25,
"logprob": -0.8984375,
"logprob": -0.88183594,
"special": false,
"text": ":"
},
{
"id": 923,
"logprob": -2.84375,
"id": 2209,
"logprob": -2.6699219,
"special": false,
"text": " add"
"text": " Is"
},
{
"id": 264,
"logprob": 0.0,
"id": 279,
"logprob": -0.61083984,
"special": false,
"text": " a"
"text": " the"
},
{
"id": 734,
"logprob": -2.6660156,
"special": false,
"text": " function"
},
{
"id": 330,
"logprob": -0.31640625,
"logprob": -0.35498047,
"special": false,
"text": " \""
},
{
"id": 1985,
"logprob": 0.0,
"id": 4110,
"logprob": -2.4101562,
"special": false,
"text": "test"
"text": "Create"
},
{
"id": 7575,
"logprob": -2.2304688,
"special": false,
"text": "Process"
},
{
"id": 1,
"logprob": -0.080078125,
"special": false,
"text": "\""
},
{
"id": 304,
"logprob": -0.75439453,
"special": false,
"text": " in"
},
{
"id": 12468,
"logprob": -1.8769531,
"special": false,
"text": " Win"
}
],
"top_tokens": null
},
"generated_text": "Test request: add a \"test"
"generated_text": "Test request: Is the function \"CreateProcess\" in Win"
}

View File

@@ -16,17 +16,17 @@
},
{
"id": 5655,
"logprob": -11.75,
"logprob": -11.8359375,
"text": " deep"
},
{
"id": 6975,
"logprob": -2.0625,
"logprob": -2.0703125,
"text": " learning"
},
{
"id": 30,
"logprob": -6.0,
"logprob": -5.9765625,
"text": "?"
}
],
@@ -40,25 +40,25 @@
},
{
"id": 34564,
"logprob": -0.11279297,
"logprob": -0.12512207,
"special": false,
"text": "Deep"
},
{
"id": 6975,
"logprob": -0.16015625,
"logprob": 0.0,
"special": false,
"text": " learning"
},
{
"id": 320,
"logprob": -0.25195312,
"logprob": -0.23840332,
"special": false,
"text": " ("
},
{
"id": 16931,
"logprob": -1.703125,
"logprob": -2.0175781,
"special": false,
"text": "DL"
},
@@ -70,7 +70,7 @@
},
{
"id": 374,
"logprob": -1.140625,
"logprob": -0.8613281,
"special": false,
"text": " is"
},
@@ -82,7 +82,7 @@
},
{
"id": 1207,
"logprob": -1.3125,
"logprob": -1.2451172,
"special": false,
"text": " sub"
},

View File

@@ -8,7 +8,6 @@
eetq,
einops,
exllamav2,
fbgemm-gpu,
flashinfer,
flash-attn,
flash-attn-layer-norm,
@@ -77,7 +76,6 @@ buildPythonPackage {
causal-conv1d
einops
exllamav2
fbgemm-gpu
flashinfer
flash-attn
flash-attn-layer-norm

View File

@@ -5,7 +5,6 @@ include Makefile-awq
include Makefile-eetq
include Makefile-selective-scan
include Makefile-lorax-punica
include Makefile-fbgemm
include Makefile-exllamav2
include Makefile-flashinfer
@@ -30,7 +29,7 @@ install-server: gen-server
install: install-cuda
echo "Installed server"
install-cuda: install-server install-flash-attention-v2-cuda install-vllm-cuda install-flash-attention install-fbgemm
install-cuda: install-server install-flash-attention-v2-cuda install-vllm-cuda install-flash-attention
pip install -e ".[bnb,marlin,moe]"
pip install nvidia-nccl-cu12==2.22.3

View File

@@ -1,15 +0,0 @@
fbgemm_commit := v0.8.0
build-fbgemm:
@if [ ! -d "fbgemm" ]; then \
git clone https://github.com/pytorch/FBGEMM.git fbgemm; \
fi
cd fbgemm && git fetch && git checkout $(fbgemm_commit) && \
git submodule update --init --recursive && \
cd fbgemm_gpu && \
pip install -r requirements.txt && \
CUDA_ARCH_LIST="8.0;9.0a" NVCC_GENCODE="-gencode=arch=compute_80,code=sm_80 -gencode=arch=compute_90a,code=sm_90a" TORCH_CUDA_ARCH_LIST="8.0;9.0a" python setup.py --package_variant genai build
install-fbgemm: build-fbgemm
cd fbgemm/fbgemm_gpu && \
CUDA_ARCH_LIST="8.0;9.0a" NVCC_GENCODE="-gencode=arch=compute_80,code=sm_80 -gencode=arch=compute_90a,code=sm_90a" TORCH_CUDA_ARCH_LIST="8.0;9.0a" python setup.py --package_variant genai install

server/poetry.lock (generated): 29 lines changed
View File

@@ -1,4 +1,4 @@
# This file is automatically @generated by Poetry 1.6.1 and should not be changed by hand.
# This file is automatically @generated by Poetry 1.8.4 and should not be changed by hand.
[[package]]
name = "accelerate"
@@ -1215,12 +1215,12 @@ files = [
[[package]]
name = "marlin-kernels"
version = "0.3.0"
version = "0.3.1"
description = "Marlin quantization kernels"
optional = true
python-versions = ">=3.7"
files = [
{file = "marlin_kernels-0.3.0+cu123torch2.4-cp310-cp310-linux_x86_64.whl", hash = "sha256:a2086b9e98d22071f52c5b4b4b98b1b4a988565258905173fa74c5a9eddd1a0a"},
{file = "marlin_kernels-0.3.1+cu123torch2.4-cp310-cp310-linux_x86_64.whl", hash = "sha256:705c89ed54977099a40b37dc0c796964649024f1a8819a1832118cd7b146efe1"},
]
[package.dependencies]
@@ -1228,16 +1228,16 @@ torch = "*"
[package.source]
type = "url"
url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.3.0/marlin_kernels-0.3.0+cu123torch2.4-cp310-cp310-linux_x86_64.whl"
url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.3.1/marlin_kernels-0.3.1+cu123torch2.4-cp310-cp310-linux_x86_64.whl"
[[package]]
name = "marlin-kernels"
version = "0.3.0"
version = "0.3.1"
description = "Marlin quantization kernels"
optional = true
python-versions = ">=3.7"
files = [
{file = "marlin_kernels-0.3.0+cu123torch2.4-cp311-cp311-linux_x86_64.whl", hash = "sha256:f39a6946d8247629446ec170832d832c7038c363f1d8803211fe67249c2d804d"},
{file = "marlin_kernels-0.3.1+cu123torch2.4-cp311-cp311-linux_x86_64.whl", hash = "sha256:e1f3d123eca643149d0a4f6b81c4405d78abb3a694a78fccc8670a25b3404406"},
]
[package.dependencies]
@@ -1245,16 +1245,16 @@ torch = "*"
[package.source]
type = "url"
url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.3.0/marlin_kernels-0.3.0+cu123torch2.4-cp311-cp311-linux_x86_64.whl"
url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.3.1/marlin_kernels-0.3.1+cu123torch2.4-cp311-cp311-linux_x86_64.whl"
[[package]]
name = "marlin-kernels"
version = "0.3.0"
version = "0.3.1"
description = "Marlin quantization kernels"
optional = true
python-versions = ">=3.7"
files = [
{file = "marlin_kernels-0.3.0+cu123torch2.4-cp312-cp312-linux_x86_64.whl", hash = "sha256:07fd869d5289777fa866107dae676523e18b1f6ba4afce79946ddc58a6870169"},
{file = "marlin_kernels-0.3.1+cu123torch2.4-cp312-cp312-linux_x86_64.whl", hash = "sha256:9d68367fd5e1caf2edc90b77ad5d074b11586012265a3147ecca1f1171ae22f8"},
]
[package.dependencies]
@@ -1262,16 +1262,16 @@ torch = "*"
[package.source]
type = "url"
url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.3.0/marlin_kernels-0.3.0+cu123torch2.4-cp312-cp312-linux_x86_64.whl"
url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.3.1/marlin_kernels-0.3.1+cu123torch2.4-cp312-cp312-linux_x86_64.whl"
[[package]]
name = "marlin-kernels"
version = "0.3.0"
version = "0.3.1"
description = "Marlin quantization kernels"
optional = true
python-versions = ">=3.7"
files = [
{file = "marlin_kernels-0.3.0+cu123torch2.4-cp39-cp39-linux_x86_64.whl", hash = "sha256:0dedaa418225d490a5f1d8f85dbc75e439a8c43a8870e4ef32945bf61672d7dc"},
{file = "marlin_kernels-0.3.1+cu123torch2.4-cp39-cp39-linux_x86_64.whl", hash = "sha256:d962277c5f7642972e298650913dd0546b9f735b706dc88bb34955b3cac7f330"},
]
[package.dependencies]
@@ -1279,7 +1279,7 @@ torch = "*"
[package.source]
type = "url"
url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.3.0/marlin_kernels-0.3.0+cu123torch2.4-cp39-cp39-linux_x86_64.whl"
url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.3.1/marlin_kernels-0.3.1+cu123torch2.4-cp39-cp39-linux_x86_64.whl"
[[package]]
name = "mdurl"
@@ -1770,6 +1770,7 @@ description = "Nvidia JIT LTO Library"
optional = true
python-versions = ">=3"
files = [
{file = "nvidia_nvjitlink_cu12-12.4.127-py3-none-manylinux2014_aarch64.whl", hash = "sha256:4abe7fef64914ccfa909bc2ba39739670ecc9e820c83ccc7a6ed414122599b83"},
{file = "nvidia_nvjitlink_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl", hash = "sha256:06b3b9b25bf3f8af351d664978ca26a16d2c5127dbd53c0497e28d1fb9611d57"},
{file = "nvidia_nvjitlink_cu12-12.4.127-py3-none-win_amd64.whl", hash = "sha256:fd9020c501d27d135f983c6d3e244b197a7ccad769e34df53a42e276b0e25fa1"},
]
@@ -3966,4 +3967,4 @@ torch = ["torch"]
[metadata]
lock-version = "2.0"
python-versions = ">=3.9,<3.13"
content-hash = "500fa44255e4a6c89a16314a931548447afe1ba71ea341a73cad6670e46ddac7"
content-hash = "b39033e573f50a0f046787aebf1702d86673aad0b2fcee818404fcea7f644b81"

View File

@@ -41,10 +41,10 @@ py-cpuinfo = "^9.0.0"
numpy = "^1.26"
marlin-kernels = [
{ url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.3.0/marlin_kernels-0.3.0+cu123torch2.4-cp39-cp39-linux_x86_64.whl", python = "~3.9", optional = true },
{ url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.3.0/marlin_kernels-0.3.0+cu123torch2.4-cp310-cp310-linux_x86_64.whl", python = "~3.10", optional = true },
{ url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.3.0/marlin_kernels-0.3.0+cu123torch2.4-cp311-cp311-linux_x86_64.whl", python = "~3.11", optional = true },
{ url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.3.0/marlin_kernels-0.3.0+cu123torch2.4-cp312-cp312-linux_x86_64.whl", python = "~3.12", optional = true },
{ url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.3.1/marlin_kernels-0.3.1+cu123torch2.4-cp39-cp39-linux_x86_64.whl", python = "~3.9", optional = true },
{ url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.3.1/marlin_kernels-0.3.1+cu123torch2.4-cp310-cp310-linux_x86_64.whl", python = "~3.10", optional = true },
{ url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.3.1/marlin_kernels-0.3.1+cu123torch2.4-cp311-cp311-linux_x86_64.whl", python = "~3.11", optional = true },
{ url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.3.1/marlin_kernels-0.3.1+cu123torch2.4-cp312-cp312-linux_x86_64.whl", python = "~3.12", optional = true },
]
moe-kernels = [
{ url = "https://github.com/danieldk/moe-kernels/releases/download/v0.6.0/moe_kernels-0.6.0+cu123torch2.4-cp39-cp39-linux_x86_64.whl", python = "~3.9", optional = true },

View File

@@ -1,7 +1,8 @@
import torch
from dataclasses import dataclass
from typing import Optional, Tuple, Union, List
import os
from typing import Optional, Tuple, Type, Union, List
import torch
from loguru import logger
from text_generation_server.utils.import_utils import SYSTEM
@@ -11,20 +12,7 @@ from text_generation_server.utils.weights import (
UnquantizedWeight,
Weights,
)
from text_generation_server.utils.log import log_master, log_once
import importlib.util
FBGEMM_MM_AVAILABLE = False
FBGEMM_DYN_AVAILABLE = False
def is_fbgemm_gpu_available():
try:
return importlib.util.find_spec("fbgemm_gpu.experimental.gen_ai") is not None
except ModuleNotFoundError:
return False
from text_generation_server.utils.log import log_once
try:
import marlin_kernels
@@ -32,23 +20,26 @@ except ImportError:
marlin_kernels = None
if is_fbgemm_gpu_available():
if SYSTEM == "cuda":
major, _ = torch.cuda.get_device_capability()
FBGEMM_MM_AVAILABLE = major == 9
FBGEMM_DYN_AVAILABLE = major >= 8
if SYSTEM == "cuda" and marlin_kernels is not None:
major, minor = torch.cuda.get_device_capability()
CUTLASS_FP8_AVAILABLE = marlin_kernels.cutlass_scaled_mm_supports_fp8(
major * 10 + minor
)
else:
log_master(logger.warning, "FBGEMM fp8 kernels are not installed.")
CUTLASS_FP8_AVAILABLE = False
def get_fp8_linear() -> torch.nn.Module:
def get_fp8_linear() -> Type[torch.nn.Module]:
"""
Return an FP8 linear `Module` that is compatible with the current system.
"""
if SYSTEM == "cuda":
major, _ = torch.cuda.get_device_capability()
if major == 8:
if major == 8 and os.getenv("USE_CUTLASS_W8A8", "0") != "1":
# NOTE: Capability 8.9 is supported by cutlass kernels, but FP8-Marlin
# gives better decoding throughput on L4 and L40.
from text_generation_server.layers.marlin import GPTQMarlinFP8Linear
return GPTQMarlinFP8Linear
@@ -94,12 +85,6 @@ def fp8_quantize(
argument, it must also be a reciprocal (so that scales from an FP8 checkpoint can
be used without modification).
"""
if FBGEMM_DYN_AVAILABLE and not scalar and not scale:
qweight, scale = torch.ops.fbgemm.quantize_fp8_per_row(
weight, bs=None, scale_ub=scale_upper_bound, output_dtype=qdtype
)
return qweight, scale
if marlin_kernels is not None:
shape = weight.shape
qweight, scale = marlin_kernels.scaled_fp8_quant(
@@ -107,11 +92,12 @@ def fp8_quantize(
dtype=qdtype,
scale=scale,
scale_ub=scale_upper_bound,
# TODO: don't do this when we have to use the Torch kernel.
use_per_token_if_dynamic=not scalar,
)
return qweight.reshape(shape), scale
# weight, scale = quant_weights(weight, torch.int8, False)
finfo = torch.finfo(qdtype)
if scale is None:
@@ -327,8 +313,8 @@ class Fp8Linear(torch.nn.Module):
scale_upper_bound: Optional[float] = None,
) -> None:
super().__init__()
if FBGEMM_MM_AVAILABLE:
log_once(logger.info, "Using FBGEMM fp8 optimized kernels")
if CUTLASS_FP8_AVAILABLE:
log_once(logger.info, "Using cutlass w8a8 kernels")
if SYSTEM == "rocm" and qweight.dtype == torch.float8_e4m3fn:
qweight, scale, _ = normalize_e4m3fn_to_e4m3fnuz(
weight=qweight, weight_scale=scale
@@ -339,13 +325,9 @@ class Fp8Linear(torch.nn.Module):
self.scale = scale.float()
self.input_scale = input_scale.float() if input_scale is not None else None
if FBGEMM_MM_AVAILABLE:
self.scale_upper_bound = (
torch.tensor(
[scale_upper_bound], dtype=torch.float32, device=qweight.device
)
if scale_upper_bound is not None
else None
if CUTLASS_FP8_AVAILABLE and scale_upper_bound is not None:
self.scale_upper_bound = torch.tensor(
scale_upper_bound, dtype=torch.float32, device=qweight.device
)
else:
self.scale_upper_bound = scale_upper_bound
@@ -354,7 +336,7 @@ class Fp8Linear(torch.nn.Module):
@classmethod
def from_unquant(cls, weight, bias, dtype):
qweight, scale = fp8_quantize(weight, scalar=not FBGEMM_MM_AVAILABLE)
qweight, scale = fp8_quantize(weight, scalar=not CUTLASS_FP8_AVAILABLE)
return cls(
qweight=qweight,
scale=scale,
@@ -376,9 +358,6 @@ class Fp8Linear(torch.nn.Module):
input_scale = kwargs.get("input_scale", None)
scale_upper_bound = kwargs.get("scale_upper_bound", None)
if FBGEMM_DYN_AVAILABLE:
# fbgemm needs float32 scales.
scale = scale.float()
return cls(
qweight=weight,
scale=scale,
@@ -397,20 +376,14 @@ class Fp8Linear(torch.nn.Module):
return cls._device_identity_cache[device]
def forward(self, input: torch.Tensor) -> torch.Tensor:
if FBGEMM_MM_AVAILABLE:
if CUTLASS_FP8_AVAILABLE:
# cutlass FP8 supports per-token scales, so get non-scalar scales.
qinput, scale = fp8_quantize(
input, scale_upper_bound=self.scale_upper_bound
input, scale_upper_bound=self.scale_upper_bound, scalar=False
)
y = torch.ops.fbgemm.f8f8bf16_rowwise(
qinput,
self.qweight,
scale,
self.scale,
use_fast_accum=True,
bias=self.bias,
return marlin_kernels.cutlass_scaled_mm(
qinput, self.qweight.t(), scale, self.scale, input.dtype, self.bias
)
return y.to(self.dtype)
qinput, scale = fp8_quantize(
input,

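Putting the scattered fp8.py hunks above together, FP8 kernel selection after this change works roughly as follows. This is a condensed sketch only; the final fallback to the cutlass-backed Fp8Linear is assumed from the surrounding file rather than shown in the hunks above.

import os
from typing import Type

import torch
from text_generation_server.utils.import_utils import SYSTEM

try:
    import marlin_kernels
except ImportError:
    marlin_kernels = None

# CUTLASS w8a8 is usable when marlin-kernels is installed and the GPU's
# compute capability passes cutlass_scaled_mm_supports_fp8 (checked once at
# import time).
if SYSTEM == "cuda" and marlin_kernels is not None:
    major, minor = torch.cuda.get_device_capability()
    CUTLASS_FP8_AVAILABLE = marlin_kernels.cutlass_scaled_mm_supports_fp8(
        major * 10 + minor
    )
else:
    CUTLASS_FP8_AVAILABLE = False


def get_fp8_linear() -> Type[torch.nn.Module]:
    if SYSTEM == "cuda":
        major, _ = torch.cuda.get_device_capability()
        # On SM 8.x GPUs FP8-Marlin still gives better decoding throughput
        # (L4, L40), so the cutlass path must be opted into explicitly.
        if major == 8 and os.getenv("USE_CUTLASS_W8A8", "0") != "1":
            from text_generation_server.layers.marlin import GPTQMarlinFP8Linear

            return GPTQMarlinFP8Linear
    # Assumed fallback (not visible in the hunks above): the Fp8Linear class,
    # which uses cutlass_scaled_mm whenever CUTLASS_FP8_AVAILABLE is true.
    return Fp8Linear

In practice this means SM 8.x GPUs keep using FP8-Marlin unless USE_CUTLASS_W8A8=1 is set in the server environment, while on SM 9.0 the cutlass w8a8 path is used whenever marlin-kernels reports support for the device capability.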
View File

@@ -410,12 +410,6 @@ def get_model(
else:
# These quantizers only work with float16 params.
dtype = torch.float16
elif quantize == "fp8":
from text_generation_server.layers.fp8 import FBGEMM_DYN_AVAILABLE
if FBGEMM_DYN_AVAILABLE:
# fbgemm kernels are fp8xfp8->bf16
dtype = torch.bfloat16
else:
# Keep it as default for now and let
# every model resolve their own default dtype.