Switch from fbgemm-gpu w8a8 scaled matmul to vLLM/marlin-kernels (#2688)
* Switch from fbgemm-gpu w8a8 scaled matmul to vLLM/marlin-kernels Performance and accuracy of these kernels are on par (tested with Llama 70B and 405B). Removes a dependency and resolves some stability issues we have been seeing. * Update test snapshots
This commit is contained in:
parent
ba5fc7d922
commit
0f346a3296
11
Dockerfile
11
Dockerfile
|
@ -161,15 +161,6 @@ COPY server/custom_kernels/ .
|
|||
# Build specific version of transformers
|
||||
RUN python setup.py build
|
||||
|
||||
# Build FBGEMM CUDA kernels
|
||||
FROM kernel-builder AS fbgemm-builder
|
||||
|
||||
WORKDIR /usr/src
|
||||
|
||||
COPY server/Makefile-fbgemm Makefile
|
||||
|
||||
RUN make build-fbgemm
|
||||
|
||||
# Build vllm CUDA kernels
|
||||
FROM kernel-builder AS vllm-builder
|
||||
|
||||
|
@ -239,8 +230,6 @@ COPY --from=awq-kernels-builder /usr/src/llm-awq/awq/kernels/build/lib.linux-x86
|
|||
COPY --from=eetq-kernels-builder /usr/src/eetq/build/lib.linux-x86_64-cpython-311 /opt/conda/lib/python3.11/site-packages
|
||||
# Copy build artifacts from lorax punica kernels builder
|
||||
COPY --from=lorax-punica-builder /usr/src/lorax-punica/server/punica_kernels/build/lib.linux-x86_64-cpython-311 /opt/conda/lib/python3.11/site-packages
|
||||
# Copy build artifacts from fbgemm builder
|
||||
COPY --from=fbgemm-builder /usr/src/fbgemm/fbgemm_gpu/_skbuild/linux-x86_64-3.11/cmake-install /opt/conda/lib/python3.11/site-packages
|
||||
# Copy build artifacts from vllm builder
|
||||
COPY --from=vllm-builder /usr/src/vllm/build/lib.linux-x86_64-cpython-311 /opt/conda/lib/python3.11/site-packages
|
||||
# Copy build artifacts from mamba builder
|
||||
|
|
|
@ -978,16 +978,16 @@
|
|||
"nixpkgs": "nixpkgs_6"
|
||||
},
|
||||
"locked": {
|
||||
"lastModified": 1729531056,
|
||||
"narHash": "sha256-dW9IOA31+j3VS19WAWAmkJW2YCzeVZGqd6HpIJfODtI=",
|
||||
"lastModified": 1729761651,
|
||||
"narHash": "sha256-GYykQ9Fxji2EuXCGcPn0dx8Qx8VQBJTkRdcCytp4A/k=",
|
||||
"owner": "huggingface",
|
||||
"repo": "text-generation-inference-nix",
|
||||
"rev": "a84a90281a17b15762873845c947e5c78f5a8dd1",
|
||||
"rev": "f7e3c4fa67d70590ed9ee47feeab645bd9ba81b1",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
"owner": "huggingface",
|
||||
"ref": "marlin-kernels-0.3.0",
|
||||
"ref": "marlin-kernels-0.3.1",
|
||||
"repo": "text-generation-inference-nix",
|
||||
"type": "github"
|
||||
}
|
||||
|
|
|
@ -5,7 +5,7 @@
|
|||
inputs.nixpkgs.follows = "tgi-nix/nixpkgs";
|
||||
};
|
||||
nix-filter.url = "github:numtide/nix-filter";
|
||||
tgi-nix.url = "github:huggingface/text-generation-inference-nix/marlin-kernels-0.3.0";
|
||||
tgi-nix.url = "github:huggingface/text-generation-inference-nix/marlin-kernels-0.3.1";
|
||||
nixpkgs.follows = "tgi-nix/nixpkgs";
|
||||
flake-utils.url = "github:numtide/flake-utils";
|
||||
rust-overlay = {
|
||||
|
|
|
@ -1,8 +1,8 @@
|
|||
{
|
||||
"details": {
|
||||
"best_of_sequences": null,
|
||||
"finish_reason": "stop_sequence",
|
||||
"generated_tokens": 5,
|
||||
"finish_reason": "length",
|
||||
"generated_tokens": 10,
|
||||
"prefill": [
|
||||
{
|
||||
"id": 128000,
|
||||
|
@ -11,12 +11,12 @@
|
|||
},
|
||||
{
|
||||
"id": 2323,
|
||||
"logprob": -9.5625,
|
||||
"logprob": -9.5234375,
|
||||
"text": "Test"
|
||||
},
|
||||
{
|
||||
"id": 1715,
|
||||
"logprob": -10.4375,
|
||||
"logprob": -10.421875,
|
||||
"text": " request"
|
||||
}
|
||||
],
|
||||
|
@ -24,36 +24,66 @@
|
|||
"tokens": [
|
||||
{
|
||||
"id": 25,
|
||||
"logprob": -0.8984375,
|
||||
"logprob": -0.88183594,
|
||||
"special": false,
|
||||
"text": ":"
|
||||
},
|
||||
{
|
||||
"id": 923,
|
||||
"logprob": -2.84375,
|
||||
"id": 2209,
|
||||
"logprob": -2.6699219,
|
||||
"special": false,
|
||||
"text": " add"
|
||||
"text": " Is"
|
||||
},
|
||||
{
|
||||
"id": 264,
|
||||
"logprob": 0.0,
|
||||
"id": 279,
|
||||
"logprob": -0.61083984,
|
||||
"special": false,
|
||||
"text": " a"
|
||||
"text": " the"
|
||||
},
|
||||
{
|
||||
"id": 734,
|
||||
"logprob": -2.6660156,
|
||||
"special": false,
|
||||
"text": " function"
|
||||
},
|
||||
{
|
||||
"id": 330,
|
||||
"logprob": -0.31640625,
|
||||
"logprob": -0.35498047,
|
||||
"special": false,
|
||||
"text": " \""
|
||||
},
|
||||
{
|
||||
"id": 1985,
|
||||
"logprob": 0.0,
|
||||
"id": 4110,
|
||||
"logprob": -2.4101562,
|
||||
"special": false,
|
||||
"text": "test"
|
||||
"text": "Create"
|
||||
},
|
||||
{
|
||||
"id": 7575,
|
||||
"logprob": -2.2304688,
|
||||
"special": false,
|
||||
"text": "Process"
|
||||
},
|
||||
{
|
||||
"id": 1,
|
||||
"logprob": -0.080078125,
|
||||
"special": false,
|
||||
"text": "\""
|
||||
},
|
||||
{
|
||||
"id": 304,
|
||||
"logprob": -0.75439453,
|
||||
"special": false,
|
||||
"text": " in"
|
||||
},
|
||||
{
|
||||
"id": 12468,
|
||||
"logprob": -1.8769531,
|
||||
"special": false,
|
||||
"text": " Win"
|
||||
}
|
||||
],
|
||||
"top_tokens": null
|
||||
},
|
||||
"generated_text": "Test request: add a \"test"
|
||||
"generated_text": "Test request: Is the function \"CreateProcess\" in Win"
|
||||
}
|
||||
|
|
|
@ -16,17 +16,17 @@
|
|||
},
|
||||
{
|
||||
"id": 5655,
|
||||
"logprob": -11.75,
|
||||
"logprob": -11.8359375,
|
||||
"text": " deep"
|
||||
},
|
||||
{
|
||||
"id": 6975,
|
||||
"logprob": -2.0625,
|
||||
"logprob": -2.0703125,
|
||||
"text": " learning"
|
||||
},
|
||||
{
|
||||
"id": 30,
|
||||
"logprob": -6.0,
|
||||
"logprob": -5.9765625,
|
||||
"text": "?"
|
||||
}
|
||||
],
|
||||
|
@ -40,25 +40,25 @@
|
|||
},
|
||||
{
|
||||
"id": 34564,
|
||||
"logprob": -0.11279297,
|
||||
"logprob": -0.12512207,
|
||||
"special": false,
|
||||
"text": "Deep"
|
||||
},
|
||||
{
|
||||
"id": 6975,
|
||||
"logprob": -0.16015625,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": " learning"
|
||||
},
|
||||
{
|
||||
"id": 320,
|
||||
"logprob": -0.25195312,
|
||||
"logprob": -0.23840332,
|
||||
"special": false,
|
||||
"text": " ("
|
||||
},
|
||||
{
|
||||
"id": 16931,
|
||||
"logprob": -1.703125,
|
||||
"logprob": -2.0175781,
|
||||
"special": false,
|
||||
"text": "DL"
|
||||
},
|
||||
|
@ -70,7 +70,7 @@
|
|||
},
|
||||
{
|
||||
"id": 374,
|
||||
"logprob": -1.140625,
|
||||
"logprob": -0.8613281,
|
||||
"special": false,
|
||||
"text": " is"
|
||||
},
|
||||
|
@ -82,7 +82,7 @@
|
|||
},
|
||||
{
|
||||
"id": 1207,
|
||||
"logprob": -1.3125,
|
||||
"logprob": -1.2451172,
|
||||
"special": false,
|
||||
"text": " sub"
|
||||
},
|
||||
|
|
|
@ -8,7 +8,6 @@
|
|||
eetq,
|
||||
einops,
|
||||
exllamav2,
|
||||
fbgemm-gpu,
|
||||
flashinfer,
|
||||
flash-attn,
|
||||
flash-attn-layer-norm,
|
||||
|
@ -77,7 +76,6 @@ buildPythonPackage {
|
|||
causal-conv1d
|
||||
einops
|
||||
exllamav2
|
||||
fbgemm-gpu
|
||||
flashinfer
|
||||
flash-attn
|
||||
flash-attn-layer-norm
|
||||
|
|
|
@ -5,7 +5,6 @@ include Makefile-awq
|
|||
include Makefile-eetq
|
||||
include Makefile-selective-scan
|
||||
include Makefile-lorax-punica
|
||||
include Makefile-fbgemm
|
||||
include Makefile-exllamav2
|
||||
include Makefile-flashinfer
|
||||
|
||||
|
@ -30,7 +29,7 @@ install-server: gen-server
|
|||
install: install-cuda
|
||||
echo "Installed server"
|
||||
|
||||
install-cuda: install-server install-flash-attention-v2-cuda install-vllm-cuda install-flash-attention install-fbgemm
|
||||
install-cuda: install-server install-flash-attention-v2-cuda install-vllm-cuda install-flash-attention
|
||||
pip install -e ".[bnb,marlin,moe]"
|
||||
pip install nvidia-nccl-cu12==2.22.3
|
||||
|
||||
|
|
|
@ -1,15 +0,0 @@
|
|||
fbgemm_commit := v0.8.0
|
||||
|
||||
build-fbgemm:
|
||||
@if [ ! -d "fbgemm" ]; then \
|
||||
git clone https://github.com/pytorch/FBGEMM.git fbgemm; \
|
||||
fi
|
||||
cd fbgemm && git fetch && git checkout $(fbgemm_commit) && \
|
||||
git submodule update --init --recursive && \
|
||||
cd fbgemm_gpu && \
|
||||
pip install -r requirements.txt && \
|
||||
CUDA_ARCH_LIST="8.0;9.0a" NVCC_GENCODE="-gencode=arch=compute_80,code=sm_80 -gencode=arch=compute_90a,code=sm_90a" TORCH_CUDA_ARCH_LIST="8.0;9.0a" python setup.py --package_variant genai build
|
||||
|
||||
install-fbgemm: build-fbgemm
|
||||
cd fbgemm/fbgemm_gpu && \
|
||||
CUDA_ARCH_LIST="8.0;9.0a" NVCC_GENCODE="-gencode=arch=compute_80,code=sm_80 -gencode=arch=compute_90a,code=sm_90a" TORCH_CUDA_ARCH_LIST="8.0;9.0a" python setup.py --package_variant genai install
|
|
@ -1,4 +1,4 @@
|
|||
# This file is automatically @generated by Poetry 1.6.1 and should not be changed by hand.
|
||||
# This file is automatically @generated by Poetry 1.8.4 and should not be changed by hand.
|
||||
|
||||
[[package]]
|
||||
name = "accelerate"
|
||||
|
@ -1215,12 +1215,12 @@ files = [
|
|||
|
||||
[[package]]
|
||||
name = "marlin-kernels"
|
||||
version = "0.3.0"
|
||||
version = "0.3.1"
|
||||
description = "Marlin quantization kernels"
|
||||
optional = true
|
||||
python-versions = ">=3.7"
|
||||
files = [
|
||||
{file = "marlin_kernels-0.3.0+cu123torch2.4-cp310-cp310-linux_x86_64.whl", hash = "sha256:a2086b9e98d22071f52c5b4b4b98b1b4a988565258905173fa74c5a9eddd1a0a"},
|
||||
{file = "marlin_kernels-0.3.1+cu123torch2.4-cp310-cp310-linux_x86_64.whl", hash = "sha256:705c89ed54977099a40b37dc0c796964649024f1a8819a1832118cd7b146efe1"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
|
@ -1228,16 +1228,16 @@ torch = "*"
|
|||
|
||||
[package.source]
|
||||
type = "url"
|
||||
url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.3.0/marlin_kernels-0.3.0+cu123torch2.4-cp310-cp310-linux_x86_64.whl"
|
||||
url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.3.1/marlin_kernels-0.3.1+cu123torch2.4-cp310-cp310-linux_x86_64.whl"
|
||||
|
||||
[[package]]
|
||||
name = "marlin-kernels"
|
||||
version = "0.3.0"
|
||||
version = "0.3.1"
|
||||
description = "Marlin quantization kernels"
|
||||
optional = true
|
||||
python-versions = ">=3.7"
|
||||
files = [
|
||||
{file = "marlin_kernels-0.3.0+cu123torch2.4-cp311-cp311-linux_x86_64.whl", hash = "sha256:f39a6946d8247629446ec170832d832c7038c363f1d8803211fe67249c2d804d"},
|
||||
{file = "marlin_kernels-0.3.1+cu123torch2.4-cp311-cp311-linux_x86_64.whl", hash = "sha256:e1f3d123eca643149d0a4f6b81c4405d78abb3a694a78fccc8670a25b3404406"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
|
@ -1245,16 +1245,16 @@ torch = "*"
|
|||
|
||||
[package.source]
|
||||
type = "url"
|
||||
url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.3.0/marlin_kernels-0.3.0+cu123torch2.4-cp311-cp311-linux_x86_64.whl"
|
||||
url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.3.1/marlin_kernels-0.3.1+cu123torch2.4-cp311-cp311-linux_x86_64.whl"
|
||||
|
||||
[[package]]
|
||||
name = "marlin-kernels"
|
||||
version = "0.3.0"
|
||||
version = "0.3.1"
|
||||
description = "Marlin quantization kernels"
|
||||
optional = true
|
||||
python-versions = ">=3.7"
|
||||
files = [
|
||||
{file = "marlin_kernels-0.3.0+cu123torch2.4-cp312-cp312-linux_x86_64.whl", hash = "sha256:07fd869d5289777fa866107dae676523e18b1f6ba4afce79946ddc58a6870169"},
|
||||
{file = "marlin_kernels-0.3.1+cu123torch2.4-cp312-cp312-linux_x86_64.whl", hash = "sha256:9d68367fd5e1caf2edc90b77ad5d074b11586012265a3147ecca1f1171ae22f8"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
|
@ -1262,16 +1262,16 @@ torch = "*"
|
|||
|
||||
[package.source]
|
||||
type = "url"
|
||||
url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.3.0/marlin_kernels-0.3.0+cu123torch2.4-cp312-cp312-linux_x86_64.whl"
|
||||
url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.3.1/marlin_kernels-0.3.1+cu123torch2.4-cp312-cp312-linux_x86_64.whl"
|
||||
|
||||
[[package]]
|
||||
name = "marlin-kernels"
|
||||
version = "0.3.0"
|
||||
version = "0.3.1"
|
||||
description = "Marlin quantization kernels"
|
||||
optional = true
|
||||
python-versions = ">=3.7"
|
||||
files = [
|
||||
{file = "marlin_kernels-0.3.0+cu123torch2.4-cp39-cp39-linux_x86_64.whl", hash = "sha256:0dedaa418225d490a5f1d8f85dbc75e439a8c43a8870e4ef32945bf61672d7dc"},
|
||||
{file = "marlin_kernels-0.3.1+cu123torch2.4-cp39-cp39-linux_x86_64.whl", hash = "sha256:d962277c5f7642972e298650913dd0546b9f735b706dc88bb34955b3cac7f330"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
|
@ -1279,7 +1279,7 @@ torch = "*"
|
|||
|
||||
[package.source]
|
||||
type = "url"
|
||||
url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.3.0/marlin_kernels-0.3.0+cu123torch2.4-cp39-cp39-linux_x86_64.whl"
|
||||
url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.3.1/marlin_kernels-0.3.1+cu123torch2.4-cp39-cp39-linux_x86_64.whl"
|
||||
|
||||
[[package]]
|
||||
name = "mdurl"
|
||||
|
@ -1770,6 +1770,7 @@ description = "Nvidia JIT LTO Library"
|
|||
optional = true
|
||||
python-versions = ">=3"
|
||||
files = [
|
||||
{file = "nvidia_nvjitlink_cu12-12.4.127-py3-none-manylinux2014_aarch64.whl", hash = "sha256:4abe7fef64914ccfa909bc2ba39739670ecc9e820c83ccc7a6ed414122599b83"},
|
||||
{file = "nvidia_nvjitlink_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl", hash = "sha256:06b3b9b25bf3f8af351d664978ca26a16d2c5127dbd53c0497e28d1fb9611d57"},
|
||||
{file = "nvidia_nvjitlink_cu12-12.4.127-py3-none-win_amd64.whl", hash = "sha256:fd9020c501d27d135f983c6d3e244b197a7ccad769e34df53a42e276b0e25fa1"},
|
||||
]
|
||||
|
@ -3966,4 +3967,4 @@ torch = ["torch"]
|
|||
[metadata]
|
||||
lock-version = "2.0"
|
||||
python-versions = ">=3.9,<3.13"
|
||||
content-hash = "500fa44255e4a6c89a16314a931548447afe1ba71ea341a73cad6670e46ddac7"
|
||||
content-hash = "b39033e573f50a0f046787aebf1702d86673aad0b2fcee818404fcea7f644b81"
|
||||
|
|
|
@ -41,10 +41,10 @@ py-cpuinfo = "^9.0.0"
|
|||
numpy = "^1.26"
|
||||
|
||||
marlin-kernels = [
|
||||
{ url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.3.0/marlin_kernels-0.3.0+cu123torch2.4-cp39-cp39-linux_x86_64.whl", python = "~3.9", optional = true },
|
||||
{ url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.3.0/marlin_kernels-0.3.0+cu123torch2.4-cp310-cp310-linux_x86_64.whl", python = "~3.10", optional = true },
|
||||
{ url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.3.0/marlin_kernels-0.3.0+cu123torch2.4-cp311-cp311-linux_x86_64.whl", python = "~3.11", optional = true },
|
||||
{ url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.3.0/marlin_kernels-0.3.0+cu123torch2.4-cp312-cp312-linux_x86_64.whl", python = "~3.12", optional = true },
|
||||
{ url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.3.1/marlin_kernels-0.3.1+cu123torch2.4-cp39-cp39-linux_x86_64.whl", python = "~3.9", optional = true },
|
||||
{ url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.3.1/marlin_kernels-0.3.1+cu123torch2.4-cp310-cp310-linux_x86_64.whl", python = "~3.10", optional = true },
|
||||
{ url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.3.1/marlin_kernels-0.3.1+cu123torch2.4-cp311-cp311-linux_x86_64.whl", python = "~3.11", optional = true },
|
||||
{ url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.3.1/marlin_kernels-0.3.1+cu123torch2.4-cp312-cp312-linux_x86_64.whl", python = "~3.12", optional = true },
|
||||
]
|
||||
moe-kernels = [
|
||||
{ url = "https://github.com/danieldk/moe-kernels/releases/download/v0.6.0/moe_kernels-0.6.0+cu123torch2.4-cp39-cp39-linux_x86_64.whl", python = "~3.9", optional = true },
|
||||
|
|
|
@ -1,7 +1,8 @@
|
|||
import torch
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, Tuple, Union, List
|
||||
import os
|
||||
from typing import Optional, Tuple, Type, Union, List
|
||||
|
||||
import torch
|
||||
from loguru import logger
|
||||
|
||||
from text_generation_server.utils.import_utils import SYSTEM
|
||||
|
@ -11,20 +12,7 @@ from text_generation_server.utils.weights import (
|
|||
UnquantizedWeight,
|
||||
Weights,
|
||||
)
|
||||
from text_generation_server.utils.log import log_master, log_once
|
||||
import importlib.util
|
||||
|
||||
|
||||
FBGEMM_MM_AVAILABLE = False
|
||||
FBGEMM_DYN_AVAILABLE = False
|
||||
|
||||
|
||||
def is_fbgemm_gpu_available():
|
||||
try:
|
||||
return importlib.util.find_spec("fbgemm_gpu.experimental.gen_ai") is not None
|
||||
except ModuleNotFoundError:
|
||||
return False
|
||||
|
||||
from text_generation_server.utils.log import log_once
|
||||
|
||||
try:
|
||||
import marlin_kernels
|
||||
|
@ -32,23 +20,26 @@ except ImportError:
|
|||
marlin_kernels = None
|
||||
|
||||
|
||||
if is_fbgemm_gpu_available():
|
||||
if SYSTEM == "cuda":
|
||||
major, _ = torch.cuda.get_device_capability()
|
||||
FBGEMM_MM_AVAILABLE = major == 9
|
||||
FBGEMM_DYN_AVAILABLE = major >= 8
|
||||
if SYSTEM == "cuda" and marlin_kernels is not None:
|
||||
major, minor = torch.cuda.get_device_capability()
|
||||
CUTLASS_FP8_AVAILABLE = marlin_kernels.cutlass_scaled_mm_supports_fp8(
|
||||
major * 10 + minor
|
||||
)
|
||||
else:
|
||||
log_master(logger.warning, "FBGEMM fp8 kernels are not installed.")
|
||||
CUTLASS_FP8_AVAILABLE = False
|
||||
|
||||
|
||||
def get_fp8_linear() -> torch.nn.Module:
|
||||
def get_fp8_linear() -> Type[torch.nn.Module]:
|
||||
"""
|
||||
Return an FP8 linear `Module` that is compatible with the current system.
|
||||
"""
|
||||
|
||||
if SYSTEM == "cuda":
|
||||
|
||||
major, _ = torch.cuda.get_device_capability()
|
||||
if major == 8:
|
||||
if major == 8 and os.getenv("USE_CUTLASS_W8A8", "0") != "1":
|
||||
# NOTE: Capability 8.9 is supported by cutlass kernels, but FP8-Marlin
|
||||
# gives better decoding throughput on L4 and L40.
|
||||
from text_generation_server.layers.marlin import GPTQMarlinFP8Linear
|
||||
|
||||
return GPTQMarlinFP8Linear
|
||||
|
@ -94,12 +85,6 @@ def fp8_quantize(
|
|||
argument, it must also be a reciprocal (so that scales from an FP8 checkpoint can
|
||||
be used without modification).
|
||||
"""
|
||||
if FBGEMM_DYN_AVAILABLE and not scalar and not scale:
|
||||
qweight, scale = torch.ops.fbgemm.quantize_fp8_per_row(
|
||||
weight, bs=None, scale_ub=scale_upper_bound, output_dtype=qdtype
|
||||
)
|
||||
return qweight, scale
|
||||
|
||||
if marlin_kernels is not None:
|
||||
shape = weight.shape
|
||||
qweight, scale = marlin_kernels.scaled_fp8_quant(
|
||||
|
@ -107,11 +92,12 @@ def fp8_quantize(
|
|||
dtype=qdtype,
|
||||
scale=scale,
|
||||
scale_ub=scale_upper_bound,
|
||||
# TODO: don't do this when we have to use the Torch kernel.
|
||||
use_per_token_if_dynamic=not scalar,
|
||||
)
|
||||
|
||||
return qweight.reshape(shape), scale
|
||||
|
||||
# weight, scale = quant_weights(weight, torch.int8, False)
|
||||
finfo = torch.finfo(qdtype)
|
||||
|
||||
if scale is None:
|
||||
|
@ -327,8 +313,8 @@ class Fp8Linear(torch.nn.Module):
|
|||
scale_upper_bound: Optional[float] = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
if FBGEMM_MM_AVAILABLE:
|
||||
log_once(logger.info, "Using FBGEMM fp8 optimized kernels")
|
||||
if CUTLASS_FP8_AVAILABLE:
|
||||
log_once(logger.info, "Using cutlass w8a8 kernels")
|
||||
if SYSTEM == "rocm" and qweight.dtype == torch.float8_e4m3fn:
|
||||
qweight, scale, _ = normalize_e4m3fn_to_e4m3fnuz(
|
||||
weight=qweight, weight_scale=scale
|
||||
|
@ -339,13 +325,9 @@ class Fp8Linear(torch.nn.Module):
|
|||
self.scale = scale.float()
|
||||
self.input_scale = input_scale.float() if input_scale is not None else None
|
||||
|
||||
if FBGEMM_MM_AVAILABLE:
|
||||
self.scale_upper_bound = (
|
||||
torch.tensor(
|
||||
[scale_upper_bound], dtype=torch.float32, device=qweight.device
|
||||
)
|
||||
if scale_upper_bound is not None
|
||||
else None
|
||||
if CUTLASS_FP8_AVAILABLE and scale_upper_bound is not None:
|
||||
self.scale_upper_bound = torch.tensor(
|
||||
scale_upper_bound, dtype=torch.float32, device=qweight.device
|
||||
)
|
||||
else:
|
||||
self.scale_upper_bound = scale_upper_bound
|
||||
|
@ -354,7 +336,7 @@ class Fp8Linear(torch.nn.Module):
|
|||
|
||||
@classmethod
|
||||
def from_unquant(cls, weight, bias, dtype):
|
||||
qweight, scale = fp8_quantize(weight, scalar=not FBGEMM_MM_AVAILABLE)
|
||||
qweight, scale = fp8_quantize(weight, scalar=not CUTLASS_FP8_AVAILABLE)
|
||||
return cls(
|
||||
qweight=qweight,
|
||||
scale=scale,
|
||||
|
@ -376,9 +358,6 @@ class Fp8Linear(torch.nn.Module):
|
|||
input_scale = kwargs.get("input_scale", None)
|
||||
scale_upper_bound = kwargs.get("scale_upper_bound", None)
|
||||
|
||||
if FBGEMM_DYN_AVAILABLE:
|
||||
# fbgemm needs float32 scales.
|
||||
scale = scale.float()
|
||||
return cls(
|
||||
qweight=weight,
|
||||
scale=scale,
|
||||
|
@ -397,20 +376,14 @@ class Fp8Linear(torch.nn.Module):
|
|||
return cls._device_identity_cache[device]
|
||||
|
||||
def forward(self, input: torch.Tensor) -> torch.Tensor:
|
||||
if FBGEMM_MM_AVAILABLE:
|
||||
if CUTLASS_FP8_AVAILABLE:
|
||||
# cutlass FP8 supports per-token scales, so get non-scalar scales.
|
||||
qinput, scale = fp8_quantize(
|
||||
input, scale_upper_bound=self.scale_upper_bound
|
||||
input, scale_upper_bound=self.scale_upper_bound, scalar=False
|
||||
)
|
||||
|
||||
y = torch.ops.fbgemm.f8f8bf16_rowwise(
|
||||
qinput,
|
||||
self.qweight,
|
||||
scale,
|
||||
self.scale,
|
||||
use_fast_accum=True,
|
||||
bias=self.bias,
|
||||
return marlin_kernels.cutlass_scaled_mm(
|
||||
qinput, self.qweight.t(), scale, self.scale, input.dtype, self.bias
|
||||
)
|
||||
return y.to(self.dtype)
|
||||
|
||||
qinput, scale = fp8_quantize(
|
||||
input,
|
||||
|
|
|
@ -410,12 +410,6 @@ def get_model(
|
|||
else:
|
||||
# These quantizers only work with float16 params.
|
||||
dtype = torch.float16
|
||||
elif quantize == "fp8":
|
||||
from text_generation_server.layers.fp8 import FBGEMM_DYN_AVAILABLE
|
||||
|
||||
if FBGEMM_DYN_AVAILABLE:
|
||||
# fbgemm kernels are fp8xfp8->bf16
|
||||
dtype = torch.bfloat16
|
||||
else:
|
||||
# Keep it as default for now and let
|
||||
# every model resolve their own default dtype.
|
||||
|
|
Loading…
Reference in New Issue