diff --git a/Dockerfile b/Dockerfile
index daeb9309..d4189c9f 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -161,15 +161,6 @@ COPY server/custom_kernels/ .
 # Build specific version of transformers
 RUN python setup.py build
 
-# Build FBGEMM CUDA kernels
-FROM kernel-builder AS fbgemm-builder
-
-WORKDIR /usr/src
-
-COPY server/Makefile-fbgemm Makefile
-
-RUN make build-fbgemm
-
 # Build vllm CUDA kernels
 FROM kernel-builder AS vllm-builder
@@ -239,8 +230,6 @@ COPY --from=awq-kernels-builder /usr/src/llm-awq/awq/kernels/build/lib.linux-x86
 COPY --from=eetq-kernels-builder /usr/src/eetq/build/lib.linux-x86_64-cpython-311 /opt/conda/lib/python3.11/site-packages
 # Copy build artifacts from lorax punica kernels builder
 COPY --from=lorax-punica-builder /usr/src/lorax-punica/server/punica_kernels/build/lib.linux-x86_64-cpython-311 /opt/conda/lib/python3.11/site-packages
-# Copy build artifacts from fbgemm builder
-COPY --from=fbgemm-builder /usr/src/fbgemm/fbgemm_gpu/_skbuild/linux-x86_64-3.11/cmake-install /opt/conda/lib/python3.11/site-packages
 # Copy build artifacts from vllm builder
 COPY --from=vllm-builder /usr/src/vllm/build/lib.linux-x86_64-cpython-311 /opt/conda/lib/python3.11/site-packages
 # Copy build artifacts from mamba builder
diff --git a/flake.lock b/flake.lock
index 76b4ca2f..1706385a 100644
--- a/flake.lock
+++ b/flake.lock
@@ -978,16 +978,16 @@
         "nixpkgs": "nixpkgs_6"
       },
       "locked": {
-        "lastModified": 1729531056,
-        "narHash": "sha256-dW9IOA31+j3VS19WAWAmkJW2YCzeVZGqd6HpIJfODtI=",
+        "lastModified": 1729761651,
+        "narHash": "sha256-GYykQ9Fxji2EuXCGcPn0dx8Qx8VQBJTkRdcCytp4A/k=",
         "owner": "huggingface",
         "repo": "text-generation-inference-nix",
-        "rev": "a84a90281a17b15762873845c947e5c78f5a8dd1",
+        "rev": "f7e3c4fa67d70590ed9ee47feeab645bd9ba81b1",
         "type": "github"
       },
       "original": {
         "owner": "huggingface",
-        "ref": "marlin-kernels-0.3.0",
+        "ref": "marlin-kernels-0.3.1",
         "repo": "text-generation-inference-nix",
         "type": "github"
       }
diff --git a/flake.nix b/flake.nix
index 5c05bfae..45441cae 100644
--- a/flake.nix
+++ b/flake.nix
@@ -5,7 +5,7 @@
       inputs.nixpkgs.follows = "tgi-nix/nixpkgs";
     };
     nix-filter.url = "github:numtide/nix-filter";
-    tgi-nix.url = "github:huggingface/text-generation-inference-nix/marlin-kernels-0.3.0";
+    tgi-nix.url = "github:huggingface/text-generation-inference-nix/marlin-kernels-0.3.1";
     nixpkgs.follows = "tgi-nix/nixpkgs";
     flake-utils.url = "github:numtide/flake-utils";
     rust-overlay = {
diff --git a/integration-tests/models/__snapshots__/test_flash_llama_fp8/test_flash_llama_fp8_all_params.json b/integration-tests/models/__snapshots__/test_flash_llama_fp8/test_flash_llama_fp8_all_params.json
index e39829ec..13c46f54 100644
--- a/integration-tests/models/__snapshots__/test_flash_llama_fp8/test_flash_llama_fp8_all_params.json
+++ b/integration-tests/models/__snapshots__/test_flash_llama_fp8/test_flash_llama_fp8_all_params.json
@@ -1,8 +1,8 @@
 {
   "details": {
     "best_of_sequences": null,
-    "finish_reason": "stop_sequence",
-    "generated_tokens": 5,
+    "finish_reason": "length",
+    "generated_tokens": 10,
     "prefill": [
       {
         "id": 128000,
@@ -11,12 +11,12 @@
       },
       {
         "id": 2323,
-        "logprob": -9.5625,
+        "logprob": -9.5234375,
         "text": "Test"
       },
       {
         "id": 1715,
-        "logprob": -10.4375,
+        "logprob": -10.421875,
         "text": " request"
       }
     ],
@@ -24,36 +24,66 @@
     "tokens": [
       {
         "id": 25,
-        "logprob": -0.8984375,
+        "logprob": -0.88183594,
        "special": false,
         "text": ":"
       },
       {
-        "id": 923,
-        "logprob": -2.84375,
+        "id": 2209,
+        "logprob": -2.6699219,
         "special": false,
-        "text": " add"
+        "text": " Is"
       },
       {
-        "id": 264,
-        "logprob": 0.0,
+        "id": 279,
+        "logprob": -0.61083984,
         "special": false,
-        "text": " a"
+        "text": " the"
+      },
+      {
+        "id": 734,
+        "logprob": -2.6660156,
+        "special": false,
+        "text": " function"
       },
       {
         "id": 330,
-        "logprob": -0.31640625,
+        "logprob": -0.35498047,
         "special": false,
         "text": " \""
       },
       {
-        "id": 1985,
-        "logprob": 0.0,
+        "id": 4110,
+        "logprob": -2.4101562,
         "special": false,
-        "text": "test"
+        "text": "Create"
+      },
+      {
+        "id": 7575,
+        "logprob": -2.2304688,
+        "special": false,
+        "text": "Process"
+      },
+      {
+        "id": 1,
+        "logprob": -0.080078125,
+        "special": false,
+        "text": "\""
+      },
+      {
+        "id": 304,
+        "logprob": -0.75439453,
+        "special": false,
+        "text": " in"
+      },
+      {
+        "id": 12468,
+        "logprob": -1.8769531,
+        "special": false,
+        "text": " Win"
       }
     ],
     "top_tokens": null
   },
-  "generated_text": "Test request: add a \"test"
+  "generated_text": "Test request: Is the function \"CreateProcess\" in Win"
 }
diff --git a/integration-tests/models/__snapshots__/test_flash_llama_fp8_kv_cache/test_flash_llama_fp8_kv_cache_all_params.json b/integration-tests/models/__snapshots__/test_flash_llama_fp8_kv_cache/test_flash_llama_fp8_kv_cache_all_params.json
index 8bce3e10..f195f8f7 100644
--- a/integration-tests/models/__snapshots__/test_flash_llama_fp8_kv_cache/test_flash_llama_fp8_kv_cache_all_params.json
+++ b/integration-tests/models/__snapshots__/test_flash_llama_fp8_kv_cache/test_flash_llama_fp8_kv_cache_all_params.json
@@ -16,17 +16,17 @@
       },
       {
         "id": 5655,
-        "logprob": -11.75,
+        "logprob": -11.8359375,
         "text": " deep"
       },
       {
         "id": 6975,
-        "logprob": -2.0625,
+        "logprob": -2.0703125,
         "text": " learning"
       },
       {
         "id": 30,
-        "logprob": -6.0,
+        "logprob": -5.9765625,
         "text": "?"
       }
     ],
@@ -40,25 +40,25 @@
       },
       {
         "id": 34564,
-        "logprob": -0.11279297,
+        "logprob": -0.12512207,
         "special": false,
         "text": "Deep"
       },
       {
         "id": 6975,
-        "logprob": -0.16015625,
+        "logprob": 0.0,
         "special": false,
         "text": " learning"
       },
       {
         "id": 320,
-        "logprob": -0.25195312,
+        "logprob": -0.23840332,
         "special": false,
         "text": " ("
       },
       {
         "id": 16931,
-        "logprob": -1.703125,
+        "logprob": -2.0175781,
         "special": false,
         "text": "DL"
       },
@@ -70,7 +70,7 @@
       },
       {
         "id": 374,
-        "logprob": -1.140625,
+        "logprob": -0.8613281,
         "special": false,
         "text": " is"
       },
@@ -82,7 +82,7 @@
       },
       {
         "id": 1207,
-        "logprob": -1.3125,
+        "logprob": -1.2451172,
         "special": false,
         "text": " sub"
       },
diff --git a/nix/server.nix b/nix/server.nix
index 7406d563..40915546 100644
--- a/nix/server.nix
+++ b/nix/server.nix
@@ -8,7 +8,6 @@
   eetq,
   einops,
   exllamav2,
-  fbgemm-gpu,
   flashinfer,
   flash-attn,
   flash-attn-layer-norm,
@@ -77,7 +76,6 @@ buildPythonPackage {
     causal-conv1d
     einops
     exllamav2
-    fbgemm-gpu
    flashinfer
     flash-attn
     flash-attn-layer-norm
diff --git a/server/Makefile b/server/Makefile
index 18424dd6..018d3d8c 100644
--- a/server/Makefile
+++ b/server/Makefile
@@ -5,7 +5,6 @@ include Makefile-awq
 include Makefile-eetq
 include Makefile-selective-scan
 include Makefile-lorax-punica
-include Makefile-fbgemm
 include Makefile-exllamav2
 include Makefile-flashinfer
@@ -30,7 +29,7 @@ install-server: gen-server
 install: install-cuda
 	echo "Installed server"
 
-install-cuda: install-server install-flash-attention-v2-cuda install-vllm-cuda install-flash-attention install-fbgemm
+install-cuda: install-server install-flash-attention-v2-cuda install-vllm-cuda install-flash-attention
 	pip install -e ".[bnb,marlin,moe]"
 	pip install nvidia-nccl-cu12==2.22.3
diff --git a/server/Makefile-fbgemm b/server/Makefile-fbgemm
deleted file mode 100644
index 3b8061a1..00000000
--- a/server/Makefile-fbgemm
+++ /dev/null
@@ -1,15 +0,0 @@
-fbgemm_commit := v0.8.0
-
-build-fbgemm:
-	@if [ ! -d "fbgemm" ]; then \
-		git clone https://github.com/pytorch/FBGEMM.git fbgemm; \
-	fi
-	cd fbgemm && git fetch && git checkout $(fbgemm_commit) && \
-	git submodule update --init --recursive && \
-	cd fbgemm_gpu && \
-	pip install -r requirements.txt && \
-	CUDA_ARCH_LIST="8.0;9.0a" NVCC_GENCODE="-gencode=arch=compute_80,code=sm_80 -gencode=arch=compute_90a,code=sm_90a" TORCH_CUDA_ARCH_LIST="8.0;9.0a" python setup.py --package_variant genai build
-
-install-fbgemm: build-fbgemm
-	cd fbgemm/fbgemm_gpu && \
-	CUDA_ARCH_LIST="8.0;9.0a" NVCC_GENCODE="-gencode=arch=compute_80,code=sm_80 -gencode=arch=compute_90a,code=sm_90a" TORCH_CUDA_ARCH_LIST="8.0;9.0a" python setup.py --package_variant genai install
diff --git a/server/poetry.lock b/server/poetry.lock
index 1293e883..e75786c3 100644
--- a/server/poetry.lock
+++ b/server/poetry.lock
@@ -1,4 +1,4 @@
-# This file is automatically @generated by Poetry 1.6.1 and should not be changed by hand.
+# This file is automatically @generated by Poetry 1.8.4 and should not be changed by hand.
 
 [[package]]
 name = "accelerate"
@@ -1215,12 +1215,12 @@ files = [
 
 [[package]]
 name = "marlin-kernels"
-version = "0.3.0"
+version = "0.3.1"
 description = "Marlin quantization kernels"
 optional = true
 python-versions = ">=3.7"
 files = [
-    {file = "marlin_kernels-0.3.0+cu123torch2.4-cp310-cp310-linux_x86_64.whl", hash = "sha256:a2086b9e98d22071f52c5b4b4b98b1b4a988565258905173fa74c5a9eddd1a0a"},
+    {file = "marlin_kernels-0.3.1+cu123torch2.4-cp310-cp310-linux_x86_64.whl", hash = "sha256:705c89ed54977099a40b37dc0c796964649024f1a8819a1832118cd7b146efe1"},
 ]
 
 [package.dependencies]
@@ -1228,16 +1228,16 @@ torch = "*"
 
 [package.source]
 type = "url"
-url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.3.0/marlin_kernels-0.3.0+cu123torch2.4-cp310-cp310-linux_x86_64.whl"
+url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.3.1/marlin_kernels-0.3.1+cu123torch2.4-cp310-cp310-linux_x86_64.whl"
 
 [[package]]
 name = "marlin-kernels"
-version = "0.3.0"
+version = "0.3.1"
 description = "Marlin quantization kernels"
 optional = true
 python-versions = ">=3.7"
 files = [
-    {file = "marlin_kernels-0.3.0+cu123torch2.4-cp311-cp311-linux_x86_64.whl", hash = "sha256:f39a6946d8247629446ec170832d832c7038c363f1d8803211fe67249c2d804d"},
+    {file = "marlin_kernels-0.3.1+cu123torch2.4-cp311-cp311-linux_x86_64.whl", hash = "sha256:e1f3d123eca643149d0a4f6b81c4405d78abb3a694a78fccc8670a25b3404406"},
 ]
 
 [package.dependencies]
@@ -1245,16 +1245,16 @@ torch = "*"
 
 [package.source]
 type = "url"
-url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.3.0/marlin_kernels-0.3.0+cu123torch2.4-cp311-cp311-linux_x86_64.whl"
+url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.3.1/marlin_kernels-0.3.1+cu123torch2.4-cp311-cp311-linux_x86_64.whl"
 
 [[package]]
 name = "marlin-kernels"
-version = "0.3.0"
+version = "0.3.1"
 description = "Marlin quantization kernels"
 optional = true
 python-versions = ">=3.7"
 files = [
-    {file = "marlin_kernels-0.3.0+cu123torch2.4-cp312-cp312-linux_x86_64.whl", hash = "sha256:07fd869d5289777fa866107dae676523e18b1f6ba4afce79946ddc58a6870169"},
+    {file = "marlin_kernels-0.3.1+cu123torch2.4-cp312-cp312-linux_x86_64.whl", hash = "sha256:9d68367fd5e1caf2edc90b77ad5d074b11586012265a3147ecca1f1171ae22f8"},
 ]
 
 [package.dependencies]
@@ -1262,16 +1262,16 @@ torch = "*"
 
 [package.source]
 type = "url"
-url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.3.0/marlin_kernels-0.3.0+cu123torch2.4-cp312-cp312-linux_x86_64.whl"
+url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.3.1/marlin_kernels-0.3.1+cu123torch2.4-cp312-cp312-linux_x86_64.whl"
 
 [[package]]
 name = "marlin-kernels"
-version = "0.3.0"
+version = "0.3.1"
 description = "Marlin quantization kernels"
 optional = true
 python-versions = ">=3.7"
 files = [
-    {file = "marlin_kernels-0.3.0+cu123torch2.4-cp39-cp39-linux_x86_64.whl", hash = "sha256:0dedaa418225d490a5f1d8f85dbc75e439a8c43a8870e4ef32945bf61672d7dc"},
+    {file = "marlin_kernels-0.3.1+cu123torch2.4-cp39-cp39-linux_x86_64.whl", hash = "sha256:d962277c5f7642972e298650913dd0546b9f735b706dc88bb34955b3cac7f330"},
 ]
 
 [package.dependencies]
@@ -1279,7 +1279,7 @@ torch = "*"
 
 [package.source]
 type = "url"
-url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.3.0/marlin_kernels-0.3.0+cu123torch2.4-cp39-cp39-linux_x86_64.whl"
+url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.3.1/marlin_kernels-0.3.1+cu123torch2.4-cp39-cp39-linux_x86_64.whl"
 
 [[package]]
 name = "mdurl"
@@ -1770,6 +1770,7 @@ description = "Nvidia JIT LTO Library"
 optional = true
 python-versions = ">=3"
 files = [
+    {file = "nvidia_nvjitlink_cu12-12.4.127-py3-none-manylinux2014_aarch64.whl", hash = "sha256:4abe7fef64914ccfa909bc2ba39739670ecc9e820c83ccc7a6ed414122599b83"},
     {file = "nvidia_nvjitlink_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl", hash = "sha256:06b3b9b25bf3f8af351d664978ca26a16d2c5127dbd53c0497e28d1fb9611d57"},
     {file = "nvidia_nvjitlink_cu12-12.4.127-py3-none-win_amd64.whl", hash = "sha256:fd9020c501d27d135f983c6d3e244b197a7ccad769e34df53a42e276b0e25fa1"},
 ]
@@ -3966,4 +3967,4 @@ torch = ["torch"]
 [metadata]
 lock-version = "2.0"
 python-versions = ">=3.9,<3.13"
-content-hash = "500fa44255e4a6c89a16314a931548447afe1ba71ea341a73cad6670e46ddac7"
+content-hash = "b39033e573f50a0f046787aebf1702d86673aad0b2fcee818404fcea7f644b81"
diff --git a/server/pyproject.toml b/server/pyproject.toml
index d08d0b8f..5c414d6e 100644
--- a/server/pyproject.toml
+++ b/server/pyproject.toml
@@ -41,10 +41,10 @@ py-cpuinfo = "^9.0.0"
 numpy = "^1.26"
 
 marlin-kernels = [
-    { url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.3.0/marlin_kernels-0.3.0+cu123torch2.4-cp39-cp39-linux_x86_64.whl", python = "~3.9", optional = true },
-    { url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.3.0/marlin_kernels-0.3.0+cu123torch2.4-cp310-cp310-linux_x86_64.whl", python = "~3.10", optional = true },
-    { url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.3.0/marlin_kernels-0.3.0+cu123torch2.4-cp311-cp311-linux_x86_64.whl", python = "~3.11", optional = true },
-    { url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.3.0/marlin_kernels-0.3.0+cu123torch2.4-cp312-cp312-linux_x86_64.whl", python = "~3.12", optional = true },
+    { url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.3.1/marlin_kernels-0.3.1+cu123torch2.4-cp39-cp39-linux_x86_64.whl", python = "~3.9", optional = true },
+    { url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.3.1/marlin_kernels-0.3.1+cu123torch2.4-cp310-cp310-linux_x86_64.whl", python = "~3.10", optional = true },
+    { url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.3.1/marlin_kernels-0.3.1+cu123torch2.4-cp311-cp311-linux_x86_64.whl", python = "~3.11", optional = true },
+    { url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.3.1/marlin_kernels-0.3.1+cu123torch2.4-cp312-cp312-linux_x86_64.whl", python = "~3.12", optional = true },
 ]
 moe-kernels = [
     { url = "https://github.com/danieldk/moe-kernels/releases/download/v0.6.0/moe_kernels-0.6.0+cu123torch2.4-cp39-cp39-linux_x86_64.whl", python = "~3.9", optional = true },
diff --git a/server/text_generation_server/layers/fp8.py b/server/text_generation_server/layers/fp8.py
index a58c7f7b..21688173 100644
--- a/server/text_generation_server/layers/fp8.py
+++ b/server/text_generation_server/layers/fp8.py
@@ -1,7 +1,8 @@
-import torch
-
 from dataclasses import dataclass
-from typing import Optional, Tuple, Union, List
+import os
+from typing import Optional, Tuple, Type, Union, List
+
+import torch
 from loguru import logger
 
 from text_generation_server.utils.import_utils import SYSTEM
@@ -11,20 +12,7 @@ from text_generation_server.utils.weights import (
     UnquantizedWeight,
     Weights,
 )
-from text_generation_server.utils.log import log_master, log_once
-import importlib.util
-
-
-FBGEMM_MM_AVAILABLE = False
-FBGEMM_DYN_AVAILABLE = False
-
-
-def is_fbgemm_gpu_available():
-    try:
-        return importlib.util.find_spec("fbgemm_gpu.experimental.gen_ai") is not None
-    except ModuleNotFoundError:
-        return False
-
+from text_generation_server.utils.log import log_once
 
 try:
     import marlin_kernels
@@ -32,23 +20,26 @@ except ImportError:
     marlin_kernels = None
 
 
-if is_fbgemm_gpu_available():
-    if SYSTEM == "cuda":
-        major, _ = torch.cuda.get_device_capability()
-        FBGEMM_MM_AVAILABLE = major == 9
-        FBGEMM_DYN_AVAILABLE = major >= 8
+if SYSTEM == "cuda" and marlin_kernels is not None:
+    major, minor = torch.cuda.get_device_capability()
+    CUTLASS_FP8_AVAILABLE = marlin_kernels.cutlass_scaled_mm_supports_fp8(
+        major * 10 + minor
+    )
 else:
-    log_master(logger.warning, "FBGEMM fp8 kernels are not installed.")
+    CUTLASS_FP8_AVAILABLE = False
 
 
-def get_fp8_linear() -> torch.nn.Module:
+def get_fp8_linear() -> Type[torch.nn.Module]:
     """
     Return an FP8 linear `Module` that is compatible with the current system.
     """
 
     if SYSTEM == "cuda":
+
         major, _ = torch.cuda.get_device_capability()
-        if major == 8:
+        if major == 8 and os.getenv("USE_CUTLASS_W8A8", "0") != "1":
+            # NOTE: Capability 8.9 is supported by cutlass kernels, but FP8-Marlin
+            # gives better decoding throughput on L4 and L40.
             from text_generation_server.layers.marlin import GPTQMarlinFP8Linear
 
             return GPTQMarlinFP8Linear
@@ -94,12 +85,6 @@ def fp8_quantize(
     argument, it must also be a reciprocal (so that scales from an FP8
     checkpoint can be used without modification).
     """
-    if FBGEMM_DYN_AVAILABLE and not scalar and not scale:
-        qweight, scale = torch.ops.fbgemm.quantize_fp8_per_row(
-            weight, bs=None, scale_ub=scale_upper_bound, output_dtype=qdtype
-        )
-        return qweight, scale
-
     if marlin_kernels is not None:
         shape = weight.shape
         qweight, scale = marlin_kernels.scaled_fp8_quant(
             weight.reshape(-1, shape[-1]),
             dtype=qdtype,
             scale=scale,
             scale_ub=scale_upper_bound,
+            # TODO: don't do this when we have to use the Torch kernel.
+            use_per_token_if_dynamic=not scalar,
         )
 
         return qweight.reshape(shape), scale
 
-    # weight, scale = quant_weights(weight, torch.int8, False)
     finfo = torch.finfo(qdtype)
 
     if scale is None:
@@ -327,8 +313,8 @@ class Fp8Linear(torch.nn.Module):
         scale_upper_bound: Optional[float] = None,
     ) -> None:
         super().__init__()
-        if FBGEMM_MM_AVAILABLE:
-            log_once(logger.info, "Using FBGEMM fp8 optimized kernels")
+        if CUTLASS_FP8_AVAILABLE:
+            log_once(logger.info, "Using cutlass w8a8 kernels")
         if SYSTEM == "rocm" and qweight.dtype == torch.float8_e4m3fn:
             qweight, scale, _ = normalize_e4m3fn_to_e4m3fnuz(
                 weight=qweight, weight_scale=scale
@@ -339,13 +325,9 @@ class Fp8Linear(torch.nn.Module):
         self.scale = scale.float()
         self.input_scale = input_scale.float() if input_scale is not None else None
 
-        if FBGEMM_MM_AVAILABLE:
-            self.scale_upper_bound = (
-                torch.tensor(
-                    [scale_upper_bound], dtype=torch.float32, device=qweight.device
-                )
-                if scale_upper_bound is not None
-                else None
+        if CUTLASS_FP8_AVAILABLE and scale_upper_bound is not None:
+            self.scale_upper_bound = torch.tensor(
+                scale_upper_bound, dtype=torch.float32, device=qweight.device
             )
         else:
             self.scale_upper_bound = scale_upper_bound
@@ -354,7 +336,7 @@ class Fp8Linear(torch.nn.Module):
 
     @classmethod
     def from_unquant(cls, weight, bias, dtype):
-        qweight, scale = fp8_quantize(weight, scalar=not FBGEMM_MM_AVAILABLE)
+        qweight, scale = fp8_quantize(weight, scalar=not CUTLASS_FP8_AVAILABLE)
         return cls(
             qweight=qweight,
             scale=scale,
@@ -376,9 +358,6 @@ class Fp8Linear(torch.nn.Module):
         input_scale = kwargs.get("input_scale", None)
         scale_upper_bound = kwargs.get("scale_upper_bound", None)
 
-        if FBGEMM_DYN_AVAILABLE:
-            # fbgemm needs float32 scales.
-            scale = scale.float()
         return cls(
             qweight=weight,
             scale=scale,
@@ -397,20 +376,14 @@ class Fp8Linear(torch.nn.Module):
         return cls._device_identity_cache[device]
 
     def forward(self, input: torch.Tensor) -> torch.Tensor:
-        if FBGEMM_MM_AVAILABLE:
+        if CUTLASS_FP8_AVAILABLE:
+            # cutlass FP8 supports per-token scales, so get non-scalar scales.
             qinput, scale = fp8_quantize(
-                input, scale_upper_bound=self.scale_upper_bound
+                input, scale_upper_bound=self.scale_upper_bound, scalar=False
             )
-
-            y = torch.ops.fbgemm.f8f8bf16_rowwise(
-                qinput,
-                self.qweight,
-                scale,
-                self.scale,
-                use_fast_accum=True,
-                bias=self.bias,
+            return marlin_kernels.cutlass_scaled_mm(
+                qinput, self.qweight.t(), scale, self.scale, input.dtype, self.bias
             )
-            return y.to(self.dtype)
 
         qinput, scale = fp8_quantize(
             input,
diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py
index d3015408..f4fa431c 100644
--- a/server/text_generation_server/models/__init__.py
+++ b/server/text_generation_server/models/__init__.py
@@ -410,12 +410,6 @@ def get_model(
             else:
                 # These quantizers only work with float16 params.
                 dtype = torch.float16
-        elif quantize == "fp8":
-            from text_generation_server.layers.fp8 import FBGEMM_DYN_AVAILABLE
-
-            if FBGEMM_DYN_AVAILABLE:
-                # fbgemm kernels are fp8xfp8->bf16
-                dtype = torch.bfloat16
         else:
             # Keep it as default for now and let
             # every model resolve their own default dtype.
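
For reviewers who want to try the new path outside of TGI, here is a minimal, standalone sketch of the cutlass w8a8 flow this diff switches to. It only uses the `marlin_kernels` calls that appear in the patch (`cutlass_scaled_mm_supports_fp8`, `scaled_fp8_quant`, `cutlass_scaled_mm`); the tensor names and shapes are illustrative, and it assumes the marlin-kernels 0.3.1 wheel from this PR plus a CUDA GPU that passes the capability check. It is not part of the patch itself.

```python
# Sketch of the cutlass w8a8 path used by Fp8Linear after this change.
# Assumes marlin-kernels 0.3.1 (wheels referenced above) and a CUDA device;
# `weight` and `activations` are illustrative stand-ins.
import torch
import marlin_kernels

major, minor = torch.cuda.get_device_capability()
assert marlin_kernels.cutlass_scaled_mm_supports_fp8(major * 10 + minor)

weight = torch.randn(4096, 4096, dtype=torch.float16, device="cuda")
activations = torch.randn(8, 4096, dtype=torch.float16, device="cuda")

# Dynamic FP8 quantization with per-row / per-token scales, mirroring
# fp8_quantize(..., scalar=False) in layers/fp8.py.
qweight, weight_scale = marlin_kernels.scaled_fp8_quant(
    weight, dtype=torch.float8_e4m3fn, use_per_token_if_dynamic=True
)
qinput, input_scale = marlin_kernels.scaled_fp8_quant(
    activations, dtype=torch.float8_e4m3fn, use_per_token_if_dynamic=True
)

# w8a8 matmul: the quantized weight goes in transposed and the output comes
# back in the activation dtype, as in Fp8Linear.forward (bias omitted here).
out = marlin_kernels.cutlass_scaled_mm(
    qinput, qweight.t(), input_scale, weight_scale, activations.dtype, None
)
print(out.shape)  # torch.Size([8, 4096])
```

The argument order of `cutlass_scaled_mm` above copies the `Fp8Linear.forward` call in this diff (activation scales before weight scales, then output dtype and optional bias); quantizing the weight with per-row scales matches `from_unquant`, which passes `scalar=False` whenever the cutlass kernels are available.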