feat(fp8): use fbgemm kernels and load fp8 weights directly (#2248)

* feat(fp8): add support for fbgemm

* allow loading fp8 weights directly

* update outlines

* fix makefile

* build fbgemm

* avoid circular import and fix dockerfile

* add default dtype

* refactored weights loader

* fix auto conversion

* fix quantization config parsing

* force new nccl on install

* missing get_weights implementation

* increase timeout
OlivierDehaene 2024-07-20 17:02:04 +00:00 committed by GitHub
parent e5c1d6d611
commit 53ec0b790b
23 changed files with 737 additions and 110 deletions
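
In short, "load fp8 weights directly" means the weight loader now dispatches on the checkpoint dtype instead of always quantizing at load time. A rough standalone sketch of that idea (the file path and tensor names are illustrative only; TGI itself goes through its Weights helper):

import torch
from safetensors import safe_open

with safe_open("model-00001-of-00002.safetensors", framework="pt", device="cpu") as f:
    w = f.get_tensor("model.layers.1.mlp.gate_proj.weight")
    if w.dtype == torch.float8_e4m3fn:
        # fbgemm_fp8 checkpoints ship a companion scale tensor next to each weight.
        w_scale = f.get_tensor("model.layers.1.mlp.gate_proj.weight_scale")
    else:
        w_scale = None  # plain bf16/fp16 weights are quantized on the fly when --quantize fp8 is set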


@@ -183,4 +183,3 @@ jobs:
export HF_TOKEN=${{ secrets.HF_TOKEN }}
echo $DOCKER_IMAGE
pytest -s -vv integration-tests ${PYTEST_FLAGS}


@@ -161,6 +161,17 @@ COPY server/custom_kernels/ .
# Build specific version of transformers
RUN python setup.py build
# Build FBGEMM CUDA kernels
FROM kernel-builder AS fbgemm-builder
WORKDIR /usr/src
COPY server/Makefile-fbgemm Makefile
COPY server/fbgemm_remove_unused.patch fbgemm_remove_unused.patch
COPY server/fix_torch90a.sh fix_torch90a.sh
RUN make build-fbgemm
# Build vllm CUDA kernels
FROM kernel-builder AS vllm-builder
@@ -225,10 +236,10 @@ COPY --from=eetq-kernels-builder /usr/src/eetq/build/lib.linux-x86_64-cpython-31
# Copy build artifacts from marlin kernels builder
COPY --from=marlin-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
COPY --from=lorax-punica-builder /usr/src/lorax-punica/server/punica_kernels/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
# Copy build artifacts from fbgemm builder
-# Copy builds artifacts from vllm builder
+COPY --from=fbgemm-builder /usr/src/fbgemm/fbgemm_gpu/_skbuild/linux-x86_64-3.10/cmake-install /opt/conda/lib/python3.10/site-packages
# Copy build artifacts from vllm builder
COPY --from=vllm-builder /usr/src/vllm/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
# Copy build artifacts from mamba builder
COPY --from=mamba-builder /usr/src/mamba/build/lib.linux-x86_64-cpython-310/ /opt/conda/lib/python3.10/site-packages
COPY --from=mamba-builder /usr/src/causal-conv1d/build/lib.linux-x86_64-cpython-310/ /opt/conda/lib/python3.10/site-packages


@@ -5,6 +5,7 @@ include Makefile-awq
include Makefile-eetq
include Makefile-selective-scan
include Makefile-lorax-punica
include Makefile-fbgemm

unit-tests:
pytest -s -vv -m "not private" tests

@@ -27,8 +28,9 @@ install-server: gen-server
install: install-cuda
echo "Installed server"

-install-cuda: install-server install-flash-attention-v2-cuda install-vllm-cuda install-flash-attention
+install-cuda: install-server install-flash-attention-v2-cuda install-vllm-cuda install-flash-attention install-fbgemm
pip install -e ".[bnb]"
pip install nvidia-nccl-cu12==2.22.3

install-rocm: install-server install-flash-attention-v2-rocm install-vllm-rocm

@@ -36,5 +38,6 @@ run-dev:
SAFETENSORS_FAST_GPU=1 python -m torch.distributed.run --nproc_per_node=2 text_generation_server/cli.py serve bigscience/bloom-560m --sharded

export-requirements:
-poetry export -o requirements_cuda.txt --without-hashes --with cuda
+poetry export -o requirements_cuda.txt --without-hashes
poetry export -o requirements_rocm.txt --without-hashes
poetry export -o requirements_intel.txt --without-hashes

server/Makefile-fbgemm (new file)

@@ -0,0 +1,15 @@
fbgemm_commit := 9cf0429b726931cfab72b8264730bea682f32fca
build-fbgemm:
chmod +x fix_torch90a.sh && ./fix_torch90a.sh && \
git clone https://github.com/pytorch/FBGEMM.git fbgemm && \
cp fbgemm_remove_unused.patch fbgemm && \
cd fbgemm && git fetch && git checkout $(fbgemm_commit) && git apply fbgemm_remove_unused.patch && \
git submodule update --init --recursive && \
cd fbgemm_gpu && \
pip install -r requirements.txt && \
CUDA_ARCH_LIST="8.0;9.0a" NVCC_GENCODE="-gencode=arch=compute_80,code=sm_80 -gencode=arch=compute_90a,code=sm_90a" TORCH_CUDA_ARCH_LIST="8.0;9.0a" python setup.py --package_variant genai build
install-fbgemm: build-fbgemm
cd fbgemm/fbgemm_gpu && \
CUDA_ARCH_LIST="8.0;9.0a" NVCC_GENCODE="-gencode=arch=compute_80,code=sm_80 -gencode=arch=compute_90a,code=sm_90a" TORCH_CUDA_ARCH_LIST="8.0;9.0a" python setup.py --package_variant genai install

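The build above compiles only FBGEMM's quantize / gen_ai kernels (the patch below strips everything else out). A quick way to sanity-check the result on a CUDA machine, mirroring the capability gating applied later in fp8.py, could look like this (illustrative snippet, not part of the diff):

import torch
import fbgemm_gpu.experimental.gen_ai  # noqa: F401 -- importing registers the torch.ops.fbgemm kernels

major, _ = torch.cuda.get_device_capability()
# The rowwise fp8 x fp8 -> bf16 GEMM is only taken on SM90 (H100);
# dynamic per-row fp8 quantization is used on SM80 and newer.
print("f8f8bf16_rowwise usable:", major == 9)
print("quantize_fp8_per_row usable:", major >= 8)
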
server/fbgemm_remove_unused.patch (new file)

@@ -0,0 +1,306 @@
diff --git a/fbgemm_gpu/CMakeLists.txt b/fbgemm_gpu/CMakeLists.txt
index 2244ea6f..96265a48 100644
--- a/fbgemm_gpu/CMakeLists.txt
+++ b/fbgemm_gpu/CMakeLists.txt
@@ -94,14 +94,14 @@ endif()
# Build Experimental Modules
################################################################################
-if(NOT FBGEMM_CPU_ONLY AND NOT USE_ROCM)
- # TODO: Figure out NCCL/RCCL integration with ROCm
- add_subdirectory(experimental/example)
-endif()
-
-if(NOT FBGEMM_CPU_ONLY)
- add_subdirectory(experimental/gemm)
-endif()
+# if(NOT FBGEMM_CPU_ONLY AND NOT USE_ROCM)
+# # TODO: Figure out NCCL/RCCL integration with ROCm
+# add_subdirectory(experimental/example)
+# endif()
+
+# if(NOT FBGEMM_CPU_ONLY)
+# add_subdirectory(experimental/gemm)
+# endif()
if(NOT FBGEMM_CPU_ONLY AND NOT USE_ROCM)
# CUTLASS currently doesn't build on ROCm and CK hasnt yet been added:
diff --git a/fbgemm_gpu/FbgemmGpu.cmake b/fbgemm_gpu/FbgemmGpu.cmake
index c56773fe..0c0d349e 100644
--- a/fbgemm_gpu/FbgemmGpu.cmake
+++ b/fbgemm_gpu/FbgemmGpu.cmake
@@ -446,53 +446,55 @@ set_source_files_properties(${fbgemm_sources}
################################################################################
set(fbgemm_gpu_sources_static_cpu
- codegen/training/forward/embedding_forward_split_cpu.cpp
- codegen/inference/embedding_forward_quantized_host_cpu.cpp
- codegen/training/backward/embedding_backward_dense_host_cpu.cpp
- codegen/utils/embedding_bounds_check_host_cpu.cpp
- src/merge_pooled_embedding_ops/merge_pooled_embedding_ops_cpu.cpp
- src/permute_multi_embedding_ops/permute_multi_embedding_function.cpp
- src/permute_multi_embedding_ops/permute_multi_embedding_ops_cpu.cpp
- src/permute_pooled_embedding_ops/permute_pooled_embedding_function.cpp
- src/permute_pooled_embedding_ops/permute_pooled_embedding_ops_cpu.cpp
- src/permute_pooled_embedding_ops/permute_pooled_embedding_ops_split_cpu.cpp
- src/jagged_tensor_ops/jagged_tensor_ops_autograd.cpp
- src/jagged_tensor_ops/jagged_tensor_ops_meta.cpp
- src/jagged_tensor_ops/jagged_tensor_ops_cpu.cpp
- src/input_combine_ops/input_combine_cpu.cpp
- src/layout_transform_ops/layout_transform_ops_cpu.cpp
+ # codegen/training/forward/embedding_forward_split_cpu.cpp
+ # codegen/inference/embedding_forward_quantized_host_cpu.cpp
+ # codegen/training/backward/embedding_backward_dense_host_cpu.cpp
+ # codegen/utils/embedding_bounds_check_host_cpu.cpp
+ # src/merge_pooled_embedding_ops/merge_pooled_embedding_ops_cpu.cpp
+ # src/permute_multi_embedding_ops/permute_multi_embedding_function.cpp
+ # src/permute_multi_embedding_ops/permute_multi_embedding_ops_cpu.cpp
+ # src/permute_pooled_embedding_ops/permute_pooled_embedding_function.cpp
+ # src/permute_pooled_embedding_ops/permute_pooled_embedding_ops_cpu.cpp
+ # src/permute_pooled_embedding_ops/permute_pooled_embedding_ops_split_cpu.cpp
+ # src/jagged_tensor_ops/jagged_tensor_ops_autograd.cpp
+ # src/jagged_tensor_ops/jagged_tensor_ops_meta.cpp
+ # src/jagged_tensor_ops/jagged_tensor_ops_cpu.cpp
+ # src/input_combine_ops/input_combine_cpu.cpp
+ # src/layout_transform_ops/layout_transform_ops_cpu.cpp
src/quantize_ops/quantize_ops_cpu.cpp
src/quantize_ops/quantize_ops_meta.cpp
- src/sparse_ops/sparse_ops_cpu.cpp
- src/sparse_ops/sparse_ops_meta.cpp
- src/embedding_inplace_ops/embedding_inplace_update_cpu.cpp
- src/split_embeddings_cache/linearize_cache_indices.cpp
- src/split_embeddings_cache/lfu_cache_populate_byte.cpp
- src/split_embeddings_cache/lru_cache_populate_byte.cpp
- src/split_embeddings_cache/lxu_cache.cpp
- src/split_embeddings_cache/split_embeddings_cache_ops.cpp
- codegen/training/index_select/batch_index_select_dim0_ops.cpp
- codegen/training/index_select/batch_index_select_dim0_cpu_host.cpp)
+ # src/sparse_ops/sparse_ops_cpu.cpp
+ # src/sparse_ops/sparse_ops_meta.cpp
+ # src/embedding_inplace_ops/embedding_inplace_update_cpu.cpp
+ # src/split_embeddings_cache/linearize_cache_indices.cpp
+ # src/split_embeddings_cache/lfu_cache_populate_byte.cpp
+ # src/split_embeddings_cache/lru_cache_populate_byte.cpp
+ # src/split_embeddings_cache/lxu_cache.cpp
+ # src/split_embeddings_cache/split_embeddings_cache_ops.cpp
+ # codegen/training/index_select/batch_index_select_dim0_ops.cpp
+ # codegen/training/index_select/batch_index_select_dim0_cpu_host.cpp)
+)
if(NOT FBGEMM_CPU_ONLY)
list(APPEND fbgemm_gpu_sources_static_cpu
- codegen/inference/embedding_forward_quantized_host.cpp
- codegen/utils/embedding_bounds_check_host.cpp
- src/intraining_embedding_pruning_ops/intraining_embedding_pruning_gpu.cpp
- src/layout_transform_ops/layout_transform_ops_gpu.cpp
- src/memory_utils/memory_utils.cpp
- src/memory_utils/memory_utils_ops.cpp
- src/memory_utils/memory_utils_ops_cpu.cpp
- src/permute_pooled_embedding_ops/permute_pooled_embedding_ops_gpu.cpp
- src/permute_pooled_embedding_ops/permute_pooled_embedding_ops_split_gpu.cpp
+ # codegen/inference/embedding_forward_quantized_host.cpp
+ # codegen/utils/embedding_bounds_check_host.cpp
+ # src/intraining_embedding_pruning_ops/intraining_embedding_pruning_gpu.cpp
+ # src/layout_transform_ops/layout_transform_ops_gpu.cpp
+ # src/memory_utils/memory_utils.cpp
+ # src/memory_utils/memory_utils_ops.cpp
+ # src/memory_utils/memory_utils_ops_cpu.cpp
+ # src/permute_pooled_embedding_ops/permute_pooled_embedding_ops_gpu.cpp
+ # src/permute_pooled_embedding_ops/permute_pooled_embedding_ops_split_gpu.cpp
src/quantize_ops/quantize_ops_gpu.cpp
- src/sparse_ops/sparse_ops_gpu.cpp
- src/split_embeddings_utils/split_embeddings_utils.cpp
- src/split_embeddings_cache/split_embeddings_cache_ops.cu
- src/metric_ops/metric_ops_host.cpp
- src/embedding_inplace_ops/embedding_inplace_update_gpu.cpp
- src/input_combine_ops/input_combine_gpu.cpp
- codegen/training/index_select/batch_index_select_dim0_host.cpp)
+ # src/sparse_ops/sparse_ops_gpu.cpp
+ # src/split_embeddings_utils/split_embeddings_utils.cpp
+ # src/split_embeddings_cache/split_embeddings_cache_ops.cu
+ # src/metric_ops/metric_ops_host.cpp
+ # src/embedding_inplace_ops/embedding_inplace_update_gpu.cpp
+ # src/input_combine_ops/input_combine_gpu.cpp
+ # codegen/training/index_select/batch_index_select_dim0_host.cpp)
+ )
if(NVML_LIB_PATH OR USE_ROCM)
message(STATUS "Adding merge_pooled_embeddings sources")
@@ -516,36 +518,36 @@ endif()
if(NOT FBGEMM_CPU_ONLY)
set(fbgemm_gpu_sources_static_gpu
- codegen/utils/embedding_bounds_check.cu
- codegen/inference/embedding_forward_quantized_split_lookup.cu
- src/embedding_inplace_ops/embedding_inplace_update.cu
- src/histogram_binning_calibration_ops.cu
- src/input_combine_ops/input_combine.cu
- src/intraining_embedding_pruning_ops/intraining_embedding_pruning.cu
- src/memory_utils/memory_utils.cu
- src/memory_utils/memory_utils_ops.cu
- src/jagged_tensor_ops/batched_dense_vec_jagged_2d_mul_backward.cu
- src/jagged_tensor_ops/batched_dense_vec_jagged_2d_mul_forward.cu
- src/jagged_tensor_ops/dense_to_jagged_forward.cu
- src/jagged_tensor_ops/jagged_dense_bmm_forward.cu
- src/jagged_tensor_ops/jagged_dense_dense_elementwise_add_jagged_output_forward.cu
- src/jagged_tensor_ops/jagged_dense_elementwise_mul_backward.cu
- src/jagged_tensor_ops/jagged_dense_elementwise_mul_forward.cu
- src/jagged_tensor_ops/jagged_index_add_2d_forward.cu
- src/jagged_tensor_ops/jagged_index_select_2d_forward.cu
- src/jagged_tensor_ops/jagged_jagged_bmm_forward.cu
- src/jagged_tensor_ops/jagged_softmax_backward.cu
- src/jagged_tensor_ops/jagged_softmax_forward.cu
- src/jagged_tensor_ops/jagged_tensor_ops.cu
- src/jagged_tensor_ops/jagged_to_padded_dense_backward.cu
- src/jagged_tensor_ops/jagged_to_padded_dense_forward.cu
- src/jagged_tensor_ops/jagged_unique_indices.cu
- src/jagged_tensor_ops/keyed_jagged_index_select_dim1.cu
- src/layout_transform_ops/layout_transform_ops.cu
- src/metric_ops/metric_ops.cu
- src/permute_pooled_embedding_ops/permute_pooled_embedding_ops_split.cu
- src/permute_pooled_embedding_ops/permute_pooled_embedding_ops.cu
- src/permute_multi_embedding_ops/permute_multi_embedding_ops.cu
+ # codegen/utils/embedding_bounds_check.cu
+ # codegen/inference/embedding_forward_quantized_split_lookup.cu
+ # src/embedding_inplace_ops/embedding_inplace_update.cu
+ # src/histogram_binning_calibration_ops.cu
+ # src/input_combine_ops/input_combine.cu
+ # src/intraining_embedding_pruning_ops/intraining_embedding_pruning.cu
+ # src/memory_utils/memory_utils.cu
+ # src/memory_utils/memory_utils_ops.cu
+ # src/jagged_tensor_ops/batched_dense_vec_jagged_2d_mul_backward.cu
+ # src/jagged_tensor_ops/batched_dense_vec_jagged_2d_mul_forward.cu
+ # src/jagged_tensor_ops/dense_to_jagged_forward.cu
+ # src/jagged_tensor_ops/jagged_dense_bmm_forward.cu
+ # src/jagged_tensor_ops/jagged_dense_dense_elementwise_add_jagged_output_forward.cu
+ # src/jagged_tensor_ops/jagged_dense_elementwise_mul_backward.cu
+ # src/jagged_tensor_ops/jagged_dense_elementwise_mul_forward.cu
+ # src/jagged_tensor_ops/jagged_index_add_2d_forward.cu
+ # src/jagged_tensor_ops/jagged_index_select_2d_forward.cu
+ # src/jagged_tensor_ops/jagged_jagged_bmm_forward.cu
+ # src/jagged_tensor_ops/jagged_softmax_backward.cu
+ # src/jagged_tensor_ops/jagged_softmax_forward.cu
+ # src/jagged_tensor_ops/jagged_tensor_ops.cu
+ # src/jagged_tensor_ops/jagged_to_padded_dense_backward.cu
+ # src/jagged_tensor_ops/jagged_to_padded_dense_forward.cu
+ # src/jagged_tensor_ops/jagged_unique_indices.cu
+ # src/jagged_tensor_ops/keyed_jagged_index_select_dim1.cu
+ # src/layout_transform_ops/layout_transform_ops.cu
+ # src/metric_ops/metric_ops.cu
+ # src/permute_pooled_embedding_ops/permute_pooled_embedding_ops_split.cu
+ # src/permute_pooled_embedding_ops/permute_pooled_embedding_ops.cu
+ # src/permute_multi_embedding_ops/permute_multi_embedding_ops.cu
src/quantize_ops/quantize_bfloat16.cu
src/quantize_ops/quantize_fp8_rowwise.cu
src/quantize_ops/quantize_fused_8bit_rowwise.cu
@@ -554,39 +556,40 @@ if(NOT FBGEMM_CPU_ONLY)
src/quantize_ops/quantize_msfp.cu
src/quantize_ops/quantize_padded_fp8_rowwise.cu
src/quantize_ops/quantize_mx.cu
- src/sparse_ops/sparse_async_cumsum.cu
- src/sparse_ops/sparse_block_bucketize_features.cu
- src/sparse_ops/sparse_bucketize_features.cu
- src/sparse_ops/sparse_batched_unary_embeddings.cu
- src/sparse_ops/sparse_compute_frequency_sequence.cu
- src/sparse_ops/sparse_expand_into_jagged_permute.cu
- src/sparse_ops/sparse_group_index.cu
- src/sparse_ops/sparse_index_add.cu
- src/sparse_ops/sparse_index_select.cu
- src/sparse_ops/sparse_invert_permute.cu
- src/sparse_ops/sparse_pack_segments_backward.cu
- src/sparse_ops/sparse_pack_segments_forward.cu
- src/sparse_ops/sparse_permute_1d.cu
- src/sparse_ops/sparse_permute_2d.cu
- src/sparse_ops/sparse_permute102.cu
- src/sparse_ops/sparse_permute_embeddings.cu
- src/sparse_ops/sparse_range.cu
- src/sparse_ops/sparse_reorder_batched_ad.cu
- src/sparse_ops/sparse_segment_sum_csr.cu
- src/sparse_ops/sparse_zipf.cu
- src/split_embeddings_cache/lfu_cache_find.cu
- src/split_embeddings_cache/lfu_cache_populate.cu
- src/split_embeddings_cache/lfu_cache_populate_byte.cu
- src/split_embeddings_cache/lru_cache_find.cu
- src/split_embeddings_cache/lru_cache_populate.cu
- src/split_embeddings_cache/lru_cache_populate_byte.cu
- src/split_embeddings_cache/lxu_cache.cu
- src/split_embeddings_cache/linearize_cache_indices.cu
- src/split_embeddings_cache/reset_weight_momentum.cu
- src/split_embeddings_utils/generate_vbe_metadata.cu
- src/split_embeddings_utils/get_infos_metadata.cu
- src/split_embeddings_utils/radix_sort_pairs.cu
- src/split_embeddings_utils/transpose_embedding_input.cu)
+ # src/sparse_ops/sparse_async_cumsum.cu
+ # src/sparse_ops/sparse_block_bucketize_features.cu
+ # src/sparse_ops/sparse_bucketize_features.cu
+ # src/sparse_ops/sparse_batched_unary_embeddings.cu
+ # src/sparse_ops/sparse_compute_frequency_sequence.cu
+ # src/sparse_ops/sparse_expand_into_jagged_permute.cu
+ # src/sparse_ops/sparse_group_index.cu
+ # src/sparse_ops/sparse_index_add.cu
+ # src/sparse_ops/sparse_index_select.cu
+ # src/sparse_ops/sparse_invert_permute.cu
+ # src/sparse_ops/sparse_pack_segments_backward.cu
+ # src/sparse_ops/sparse_pack_segments_forward.cu
+ # src/sparse_ops/sparse_permute_1d.cu
+ # src/sparse_ops/sparse_permute_2d.cu
+ # src/sparse_ops/sparse_permute102.cu
+ # src/sparse_ops/sparse_permute_embeddings.cu
+ # src/sparse_ops/sparse_range.cu
+ # src/sparse_ops/sparse_reorder_batched_ad.cu
+ # src/sparse_ops/sparse_segment_sum_csr.cu
+ # src/sparse_ops/sparse_zipf.cu
+ # src/split_embeddings_cache/lfu_cache_find.cu
+ # src/split_embeddings_cache/lfu_cache_populate.cu
+ # src/split_embeddings_cache/lfu_cache_populate_byte.cu
+ # src/split_embeddings_cache/lru_cache_find.cu
+ # src/split_embeddings_cache/lru_cache_populate.cu
+ # src/split_embeddings_cache/lru_cache_populate_byte.cu
+ # src/split_embeddings_cache/lxu_cache.cu
+ # src/split_embeddings_cache/linearize_cache_indices.cu
+ # src/split_embeddings_cache/reset_weight_momentum.cu
+ # src/split_embeddings_utils/generate_vbe_metadata.cu
+ # src/split_embeddings_utils/get_infos_metadata.cu
+ # src/split_embeddings_utils/radix_sort_pairs.cu
+ # src/split_embeddings_utils/transpose_embedding_input.cu)
+ )
set_source_files_properties(${fbgemm_gpu_sources_static_gpu}
PROPERTIES COMPILE_OPTIONS
diff --git a/fbgemm_gpu/experimental/gen_ai/CMakeLists.txt b/fbgemm_gpu/experimental/gen_ai/CMakeLists.txt
index 01f1d6ab..a6b8d7a8 100644
--- a/fbgemm_gpu/experimental/gen_ai/CMakeLists.txt
+++ b/fbgemm_gpu/experimental/gen_ai/CMakeLists.txt
@@ -25,23 +25,24 @@ set(fbgemm_sources_include_directories
${THIRDPARTY}/json/include
${NCCL_INCLUDE_DIRS})
-set(attention_ops_sources
- src/attention/attention.cpp
- src/attention/gqa_attn_splitk.cu)
+# set(attention_ops_sources
+# src/attention/attention.cpp
+# src/attention/gqa_attn_splitk.cu)
set(quantize_ops_sources
src/quantize/cutlass_extensions.cu
src/quantize/quantize.cu
src/quantize/quantize.cpp)
-set(comm_ops_sources
- src/comm/car.cu
- src/comm/car.cpp)
+# set(comm_ops_sources
+# src/comm/car.cu
+# src/comm/car.cpp)
set(experimental_gen_ai_cpp_source_files
- ${attention_ops_sources}
+ # ${attention_ops_sources}
${quantize_ops_sources}
- ${comm_ops_sources})
+ # ${comm_ops_sources}
+)
set_source_files_properties(${experimental_gen_ai_cpp_source_files}
PROPERTIES INCLUDE_DIRECTORIES

server/fix_torch90a.sh (new executable file)

@@ -0,0 +1,11 @@
#!/bin/bash
# This script is required to patch torch < 2.4
# It adds the 90a cuda target (H100)
# This target is required to build FBGEMM kernels
torch_cuda_arch=$(python -c "import torch; print(torch.__file__)" | sed 's/\/__init__.py//; s|$|/share/cmake/Caffe2/Modules_CUDA_fix/upstream/FindCUDA/select_compute_arch.cmake|')
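# The three sed edits below relax the version-matching regexes in that cmake module
# so that CUDA architectures with an 'a' suffix (for example 9.0a) are accepted.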
sed -i '189s/\[0-9]\\\\\.\[0-9](/[0-9]\\\\.[0-9]a?(/' $torch_cuda_arch
sed -i '245s/\[0-9()]+\+"/[0-9()]+a?"/' $torch_cuda_arch
sed -i '246s/\[0-9]+\+"/[0-9]+a?"/' $torch_cuda_arch


@@ -8,6 +8,7 @@ from typing import Optional
from enum import Enum
from huggingface_hub import hf_hub_download
from text_generation_server.utils.log import log_master
app = typer.Typer()
@@ -87,15 +88,17 @@ def serve(
)
if len(lora_adapter_ids) > 0:
-logger.warning(
-f"LoRA adapters are enabled. This is an experimental feature and may not work as expected."
+log_master(
+logger.warning,
+f"LoRA adapters are enabled. This is an experimental feature and may not work as expected.",
)
# TODO: enable lora with cuda graphs. for now disable cuda graphs if lora is enabled
# and warn the user
if len(lora_adapter_ids) > 0 and os.getenv("CUDA_GRAPHS", None) is not None:
-logger.warning(
-f"LoRa adapter are not supported with CUDA Graphs. Disabling CUDA Graphs."
+log_master(
+logger.warning,
+f"LoRa adapter are not supported with CUDA Graphs. Disabling CUDA Graphs.",
)
global CUDA_GRAPHS
CUDA_GRAPHS = None


@@ -3,6 +3,7 @@ import torch
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.models.globals import FLASH_DECODING
from text_generation_server.layers.attention import Seqlen
from text_generation_server.utils.log import log_master
from loguru import logger
major, minor = torch.cuda.get_device_capability()
@@ -136,7 +137,10 @@ if ENGINE != "triton":
try:
import flash_attn_2_cuda
-logger.info("ROCm: using Flash Attention 2 Composable Kernel implementation.")
+log_master(
+logger.info,
+"ROCm: using Flash Attention 2 Composable Kernel implementation.",
+)
except ImportError as e:
if major >= 8:
architecture_suffix = f"-{SYSTEM}"


@@ -4,19 +4,11 @@ from functools import lru_cache
import bitsandbytes as bnb
import torch
from bitsandbytes.nn import Int8Params, Params4bit
-from loguru import logger
-from text_generation_server.utils.weights import Weight
+from text_generation_server.utils.weights import UnquantizedWeight
-@lru_cache(1)
-def warn_deprecate_bnb():
-logger.warning(
-"Bitsandbytes 8bit is deprecated, using `eetq` is a drop-in replacement, and has much better performnce"
-)
@dataclass
-class BNBWeight(Weight):
+class BNBWeight(UnquantizedWeight):
weight: torch.Tensor
def get_linear(self, bias: torch.Tensor):
@@ -82,7 +74,7 @@ class Linear8bitLt(torch.nn.Module):
@dataclass
-class BNBFP4Weight(Weight):
+class BNBFP4Weight(UnquantizedWeight):
weight: torch.Tensor
def get_linear(self, bias: torch.Tensor):
@@ -90,7 +82,7 @@ class BNBFP4Weight(Weight):
@dataclass
-class BNBNF4Weight(Weight):
+class BNBNF4Weight(UnquantizedWeight):
weight: torch.Tensor
def get_linear(self, bias: torch.Tensor):


@@ -2,11 +2,11 @@ from dataclasses import dataclass
import torch
from EETQ import quant_weights, w8_a16_gemm
-from text_generation_server.utils.weights import Weight
+from text_generation_server.utils.weights import UnquantizedWeight
@dataclass
-class EETQWeight(Weight):
+class EETQWeight(UnquantizedWeight):
weight: torch.Tensor
def get_linear(self, bias: torch.Tensor):


@@ -1,8 +1,29 @@
-from dataclasses import dataclass
import torch
from dataclasses import dataclass
from typing import Optional, Union, List
from loguru import logger
from text_generation_server.utils.import_utils import SYSTEM
-from text_generation_server.utils.weights import Weight
+from text_generation_server.utils.weights import (
Weight,
WeightsLoader,
UnquantizedWeight,
Weights,
)
from text_generation_server.utils.log import log_master, log_once
FBGEMM_MM_AVAILABLE = False
FBGEMM_DYN_AVAILABLE = False
try:
import fbgemm_gpu.experimental.gen_ai
if SYSTEM == "cuda":
major, _ = torch.cuda.get_device_capability()
FBGEMM_MM_AVAILABLE = major == 9
FBGEMM_DYN_AVAILABLE = major >= 8
except (ImportError, ModuleNotFoundError):
log_master(logger.warning, "FBGEMM fp8 kernels are not installed.")
def get_fp8_linear() -> torch.nn.Module:
@@ -21,12 +42,17 @@ def get_fp8_linear() -> torch.nn.Module:
return Fp8Linear
-def fp8_quantize(weight, qdtype=torch.float8_e4m3fn):
+def fp8_quantize(weight, scale_upper_bound=None, qdtype=torch.float8_e4m3fn):
-device = weight.device
+if FBGEMM_DYN_AVAILABLE:
+qweight, scale = torch.ops.fbgemm.quantize_fp8_per_row(
+weight, bs=None, scale_ub=scale_upper_bound, output_dtype=qdtype
+)
+return qweight, scale
# weight, scale = quant_weights(weight, torch.int8, False)
finfo = torch.finfo(qdtype)
# Calculate the scale as dtype max divided by absmax
-scale = finfo.max / weight.abs().max().clamp(min=1e-12)
+scale = finfo.max / weight.abs().max().clamp(min=1e-12, max=scale_upper_bound)
# scale and clamp the tensor to bring it to
# the representative range of float8 data type
# (as default cast is unsaturated)
@@ -38,27 +64,166 @@ def fp8_quantize(weight, qdtype=torch.float8_e4m3fn):
return qweight, scale
class HybridFP8UnquantLoader(WeightsLoader):
"""Weight loader that loads FP8 and unquantized Torch tensors."""
def __init__(self, activation_scale_ub: Optional[float], to_fp8: bool):
self.activation_scale_ub = activation_scale_ub
self.to_fp8 = to_fp8
def get_weights(self, weights: "Weights", prefix: str):
w = weights.get_tensor(f"{prefix}.weight")
if w.dtype == torch.float8_e4m3fn:
# FP8 branch
scale = weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False)
return Fp8Weight(
weight=w,
weight_scale=scale,
activation_scale_ub=self.activation_scale_ub,
dtype=weights.dtype,
)
if self.to_fp8:
return Fp8Weight(weight=w, dtype=weights.dtype)
return UnquantizedWeight(w)
def get_weights_col_packed(
self,
weights: Weights,
prefix: str,
block_sizes: Union[int, List[int]],
):
w = weights.get_packed_sharded(
f"{prefix}.weight", dim=0, block_sizes=block_sizes
)
if w.dtype == torch.float8_e4m3fn:
# FP8 branch
scale = weights.get_packed_sharded(
f"{prefix}.weight_scale", dim=0, block_sizes=block_sizes, to_dtype=False
)
return Fp8Weight(
weight=w,
weight_scale=scale,
activation_scale_ub=self.activation_scale_ub,
dtype=weights.dtype,
)
if self.to_fp8:
return Fp8Weight(weight=w, dtype=weights.dtype)
return UnquantizedWeight(w)
def get_multi_weights_col(self, weights: "Weights", prefixes: List[str], dim: int):
w = [weights.get_sharded(f"{p}.weight", dim=0) for p in prefixes]
w = torch.cat(w, dim=dim)
# FP8 branch
if w.dtype == torch.float8_e4m3fn:
scale = [
weights.get_sharded(f"{p}.weight_scale", dim=0, to_dtype=False)
for p in prefixes
]
scale = torch.cat(scale, dim=0)
return Fp8Weight(
weight=w,
weight_scale=scale,
activation_scale_ub=self.activation_scale_ub,
dtype=weights.dtype,
)
if self.to_fp8:
return Fp8Weight(weight=w, dtype=weights.dtype)
return UnquantizedWeight(w)
def get_weights_row(self, weights: "Weights", prefix: str):
w = weights.get_sharded(f"{prefix}.weight", dim=1)
# FP8 branch
if w.dtype == torch.float8_e4m3fn:
scale = weights.get_sharded(f"{prefix}.weight_scale", dim=0, to_dtype=False)
return Fp8Weight(
weight=w,
weight_scale=scale,
activation_scale_ub=self.activation_scale_ub,
dtype=weights.dtype,
)
if self.to_fp8:
return Fp8Weight(weight=w, dtype=weights.dtype)
return UnquantizedWeight(w)
@dataclass
class Fp8Weight(Weight):
weight: torch.Tensor
dtype: torch.dtype
weight_scale: Optional[torch.Tensor] = None
activation_scale_ub: Optional[float] = None
def get_linear(self, bias: torch.Tensor):
-return get_fp8_linear()(self.weight, bias)
+if self.weight_scale is None:
+return get_fp8_linear().from_unquant(self.weight, bias, self.dtype)
+return get_fp8_linear().from_fp8(
+self.weight, self.weight_scale, self.activation_scale_ub, bias, self.dtype
+)
class Fp8Linear(torch.nn.Module):
def __init__(
self,
-weight,
+qweight,
scale,
scale_upper_bound,
bias,
dtype,
) -> None:
super().__init__()
-self.dtype = weight.dtype
+self.dtype = dtype
-self.qweight, self.scale = fp8_quantize(weight)
+self.qweight = qweight
self.scale = scale
self.scale_upper_bound = (
torch.tensor(
[scale_upper_bound], dtype=torch.float32, device=qweight.device
)
if scale_upper_bound is not None
else None
)
self.bias = bias if bias is not None else None
@classmethod
def from_unquant(cls, weight, bias, dtype):
qweight, scale = fp8_quantize(weight)
return cls(
qweight=qweight, scale=scale, scale_upper_bound=None, bias=bias, dtype=dtype
)
@classmethod
def from_fp8(cls, weight, scale, input_scale, bias, dtype):
return cls(
qweight=weight,
scale=scale,
scale_upper_bound=input_scale,
bias=bias,
dtype=dtype,
)
def forward(self, input: torch.Tensor) -> torch.Tensor:
if FBGEMM_MM_AVAILABLE:
qinput, scale = fp8_quantize(
input, scale_upper_bound=self.scale_upper_bound
)
y = torch.ops.fbgemm.f8f8bf16_rowwise(
qinput,
self.qweight,
scale,
self.scale,
use_fast_accum=True,
bias=self.bias,
)
return y.to(self.dtype)
qinput, scale = fp8_quantize(input)
output, _ = torch._scaled_mm(
qinput,

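For reference, the non-fbgemm fallback in fp8_quantize boils down to per-tensor scaling; the lines that perform the actual cast are unchanged by this commit and therefore not shown above. A standalone sketch of the same arithmetic (not the TGI code itself, just an illustration):

import torch

def naive_per_tensor_fp8_quantize(weight: torch.Tensor, qdtype=torch.float8_e4m3fn):
    finfo = torch.finfo(qdtype)
    # Map the tensor's absmax onto the edge of the fp8 range; the clamp avoids
    # a division by zero for all-zero tensors.
    scale = finfo.max / weight.abs().max().clamp(min=1e-12)
    qweight = (weight * scale).clamp(min=finfo.min, max=finfo.max).to(qdtype)
    # Return the reciprocal so a matmul kernel can multiply by it to dequantize.
    return qweight, scale.float().reciprocal()

w = torch.randn(16, 32)
qw, s = naive_per_tensor_fp8_quantize(w)
print((w - qw.to(torch.float32) * s).abs().max())  # small quantization error

With fbgemm installed and a supported GPU, fp8_quantize instead calls torch.ops.fbgemm.quantize_fp8_per_row, which produces per-row scales and can clamp activations with the checkpoint's activation_scale_ub.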

@@ -9,11 +9,12 @@ from loguru import logger
from text_generation_server.layers.exl2 import Exl2Weight
from text_generation_server.layers.gptq import GPTQWeight
from text_generation_server.utils.log import log_master
try:
from exllamav2_kernels import make_q_matrix, gemm_half_q_half
except ImportError:
-logger.error("exllamav2_kernels not installed.")
+log_master(logger.warning, "exllamav2_kernels not installed.")
raise
# Dummy tensor to pass instead of g_idx since there is no way to pass "None" to a C++ extension


@@ -503,7 +503,8 @@ class GPTQMarlinFP8Linear(nn.Module):
def __init__(
self,
-weight: torch.Tensor,
+qweight: torch.Tensor,
scale: torch.Tensor,
bias: Optional[torch.Tensor],
) -> None:
super().__init__()
@@ -513,7 +514,6 @@ class GPTQMarlinFP8Linear(nn.Module):
log_once(logger.info, "GPU does not support FP8, using Marlin FP8 kernel")
-qweight, scale = fp8_quantize(weight)
scale = scale.to(torch.float16)
qweight, scales = repack_fp8_for_marlin(qweight, scale)
@@ -529,6 +529,15 @@
out_features // 64 * 16, dtype=torch.int, device=qweight.device
)
@classmethod
def from_unquant(cls, weight, bias, _dtype):
qweight, scale = fp8_quantize(weight)
return cls(qweight=qweight, scale=scale, bias=bias)
@classmethod
def from_fp8(cls, weight, scale, _input_scale, bias, _dtype):
return cls(qweight=weight, scale=scale, bias=bias)
def forward(self, A: torch.Tensor) -> torch.Tensor:
assert marlin_kernels is not None


@@ -34,6 +34,7 @@ from text_generation_server.models.custom_modeling.t5_modeling import (
)
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.utils.log import log_master
# The flag below controls whether to allow TF32 on matmul. This flag defaults to False
# in PyTorch 1.12 and later.
@@ -47,9 +48,7 @@ torch.set_grad_enabled(False)
__all__ = [
"Model",
-"BLOOMSharded",
"CausalLM",
-"GalacticaSharded",
"Seq2SeqLM",
"get_model",
]
@@ -125,7 +124,7 @@
)
from text_generation_server.layers.attention import SUPPORTS_WINDOWING
except ImportError as e:
-logger.warning(f"Could not import Flash Attention enabled models: {e}")
+log_master(logger.warning, f"Could not import Flash Attention enabled models: {e}")
SUPPORTS_WINDOWING = False
FLASH_ATTENTION = False
@@ -137,7 +136,7 @@ MAMBA_AVAILABLE = True
try:
from text_generation_server.models.mamba import Mamba
except ImportError as e:
-logger.warning(f"Could not import Mamba: {e}")
+log_master(logger.warning, f"Could not import Mamba: {e}")
MAMBA_AVAILABLE = False
if MAMBA_AVAILABLE:
@@ -311,6 +310,12 @@ def get_model(
if quantize in ["awq", "exl2", "gptq", "marlin"]:
# These quantizers only work with float16 params.
dtype = torch.float16
elif quantize == "fp8":
from text_generation_server.layers.fp8 import FBGEMM_MM_AVAILABLE
if FBGEMM_MM_AVAILABLE:
# fbgemm kernels are fp8xfp8->bf16
dtype = torch.bfloat16
else:
# Keep it as default for now and let
# every model resolve their own default dtype.
@@ -433,7 +438,9 @@ def get_model(
speculate = get_speculate()
if speculate > 0:
-logger.info(f"Using speculation {method} with {speculate} input ids.")
+log_master(
+logger.info, f"Using speculation {method} with {speculate} input ids."
+)
if model_type is None:
# TODO: fix how we determine model type for Mamba
@@ -448,10 +455,10 @@
if quantization_config is not None and quantize is None:
method = quantization_config.get("quant_method", None)
if method in {"gptq", "awq", "exl2"}:
-logger.info(f"Auto selecting quantization method {method}")
+log_master(logger.info, f"Auto selecting quantization method {method}")
quantize = method
else:
-logger.info(f"Unknown quantization method {method}")
+log_master(logger.warning, f"Unknown quantization method {method}")
if quantize == "exl2" and sharded:
raise RuntimeError(
@@ -593,7 +600,7 @@
)
except RuntimeError as e:
# Lots of legacy models with various weight names.
-logger.warning(f"Couldn't load flash gpt2 variant: {e}")
+log_master(logger.warning, f"Couldn't load flash gpt2 variant: {e}")
return CausalLM.fallback(
model_id,
revision,


@@ -33,7 +33,6 @@ from text_generation_server.layers.attention import (
attention,
reshape_and_cache,
)
-from text_generation_server.models.globals import FLASH_DECODING
from text_generation_server.layers import (
TensorParallelRowLinear,
TensorParallelColumnLinear,
@@ -42,16 +41,15 @@ from text_generation_server.layers import (
TensorParallelMultiAdapterLinear,
TensorParallelAdapterRowLinear,
)
-from text_generation_server.layers.fp8 import Fp8Weight
from text_generation_server.layers.rotary import PositionRotaryEmbedding
from text_generation_server.layers.layernorm import (
FastRMSNorm,
)
from text_generation_server.utils.weights import (
-DefaultWeightsLoader,
UnquantizedWeight,
Weights,
)
from text_generation_server.layers.fp8 import HybridFP8UnquantLoader
if SYSTEM == "rocm":
try:
@@ -113,12 +111,12 @@ def load_attention(config, prefix: str, weights, layer_id):
@contextmanager
def no_fp8(weights: Weights):
"""De-activate fp8 auto conversion for the duration of this context manager"""
weights_loader = weights.weights_loader
-if (
-isinstance(weights_loader, DefaultWeightsLoader)
-and weights_loader.weight_class is Fp8Weight
-):
-weights_loader = DefaultWeightsLoader(UnquantizedWeight)
+if isinstance(weights_loader, HybridFP8UnquantLoader) and weights_loader.to_fp8:
+weights_loader = HybridFP8UnquantLoader(
+weights_loader.activation_scale_ub, to_fp8=False
+)
with weights.use_loader(weights_loader):
yield
@@ -418,7 +416,22 @@ class FlashLlamaModel(torch.nn.Module):
process_group = weights.process_group
self.tp_rank = process_group.rank()
self.tp_world_size = process_group.size()
-self.layers = nn.ModuleList(
# Skip fp8 quant for first and last layers
self.layers = nn.ModuleList()
with no_fp8(weights):
self.layers.append(
FlashLlamaLayer(
index=0,
prefix=(
"model.layers.0" if not prefix else "{prefix}.model.layers.0"
),
config=config,
weights=weights,
)
)
self.layers.extend(
[
FlashLlamaLayer(
index=layer_id,
@@ -430,9 +443,26 @@
config=config,
weights=weights,
)
-for layer_id in range(config.num_hidden_layers)
+# Skip first and last layers
+for layer_id in range(1, config.num_hidden_layers - 1)
]
)
with no_fp8(weights):
last_layer_id = config.num_hidden_layers - 1
self.layers.append(
FlashLlamaLayer(
index=last_layer_id,
prefix=(
f"model.layers.{last_layer_id}"
if not prefix
else f"{prefix}.model.layers.{last_layer_id}"
),
config=config,
weights=weights,
)
)
self.norm = FastRMSNorm.load(
prefix="model.norm" if not prefix else f"{prefix}.model.norm",
weights=weights,


@@ -23,14 +23,13 @@ from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE
from text_generation_server.utils.chunks import concat_text_chunks
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.models import Model
from text_generation_server.utils.log import log_master
from text_generation_server.utils.tokens import batch_top_tokens
-from text_generation_server.utils.dist import RANK
from text_generation_server.utils.speculate import get_speculate
from text_generation_server.utils import (
initialize_torch_distributed,
weight_files,
Weights,
-hub,
)
from text_generation_server.models.types import (
Batch,
@@ -1156,31 +1155,36 @@ class FlashCausalLM(Model):
f"tunableop_{MODEL_ID.replace('/', '-')}_tp{self.world_size}_rank{self.rank}.csv",
)
-logger.info(
-f"PyTorch TunableOp (https://github.com/fxmarty/pytorch/tree/2.3-patched/aten/src/ATen/cuda/tunable) is enabled. The warmup may take several minutes, picking the ROCm optimal matrix multiplication kernel for the target lengths {', '.join([str(seqlen) for seqlen in tuning_sequences])}, with typical 5-8% latency improvement for small sequence lengths. The picked GEMMs are saved in the file {tunableop_filepath}. To disable TunableOp, please launch TGI with `PYTORCH_TUNABLEOP_ENABLED=0`."
+log_master(
+logger.info,
+f"PyTorch TunableOp (https://github.com/fxmarty/pytorch/tree/2.3-patched/aten/src/ATen/cuda/tunable) is enabled. The warmup may take several minutes, picking the ROCm optimal matrix multiplication kernel for the target lengths {', '.join([str(seqlen) for seqlen in tuning_sequences])}, with typical 5-8% latency improvement for small sequence lengths. The picked GEMMs are saved in the file {tunableop_filepath}. To disable TunableOp, please launch TGI with `PYTORCH_TUNABLEOP_ENABLED=0`.",
)
if os.path.isfile(tunableop_filepath):
-logger.info(
-f"The file {tunableop_filepath} already exists and will be reused."
+log_master(
+logger.info,
+f"The file {tunableop_filepath} already exists and will be reused.",
)
torch.cuda.tunable.read_file(tunableop_filepath)
os.makedirs(HUGGINGFACE_HUB_CACHE, exist_ok=True)
for seqlen in tuning_sequences:
-logger.info(f"Warming up TunableOp for seqlen={seqlen}")
+log_master(logger.info, f"Warming up TunableOp for seqlen={seqlen}")
self.tunableop_warmup(seqlen)
torch.cuda.tunable.write_file(tunableop_filepath)
torch.cuda.tunable.tuning_enable(False)
else:
-logger.info(
-"PyTorch ROCm TunableOp (https://github.com/pytorch/pytorch/tree/main/aten/src/ATen/cuda/tunable) is disabled. TunableOp brings an additional 5-8% latency improvement for small sequence lengths but requires a warmup. If necessary, please use the environment variable PYTORCH_TUNABLEOP_ENABLED=1 to enable TunableOp."
+log_master(
+logger.info,
+"PyTorch ROCm TunableOp (https://github.com/pytorch/pytorch/tree/main/aten/src/ATen/cuda/tunable) is disabled. TunableOp brings an additional 5-8% latency improvement for small sequence lengths but requires a warmup. If necessary, please use the environment variable PYTORCH_TUNABLEOP_ENABLED=1 to enable TunableOp.",
)
if CUDA_GRAPHS:
try:
-logger.info(f"Cuda Graphs are enabled for sizes {CUDA_GRAPHS}")
+log_master(
+logger.info, f"Cuda Graphs are enabled for sizes {CUDA_GRAPHS}"
+)
# Warmup cuda graphs
for bs in CUDA_GRAPHS:
if self.speculate is None or self.speculate + 1 <= bs:
@@ -1188,7 +1192,9 @@ class FlashCausalLM(Model):
except torch.cuda.OutOfMemoryError:
logger.exception(f"Decode cuda graph warmup failed")
else:
-logger.info(f"Cuda Graphs are disabled (CUDA_GRAPHS={CUDA_GRAPHS}).")
+log_master(
+logger.info, f"Cuda Graphs are disabled (CUDA_GRAPHS={CUDA_GRAPHS})."
+)
return int(num_blocks * BLOCK_SIZE)
@@ -1540,8 +1546,7 @@ class FlashCausalLM(Model):
left = 0
if n_accepted_ids > 1:
-if RANK == 0:
-logger.debug(f"Speculated ids {n_accepted_ids - 1}")
+log_master(logger.debug, f"Speculated ids {n_accepted_ids - 1}")
current_stopped = False
for j in range(index, index + n_accepted_ids):


@@ -1,15 +1,16 @@
import torch
import os
from loguru import logger
-from typing import Dict
+from typing import Dict, Optional
from text_generation_server.utils.log import log_master
MEM_POOL = torch.cuda.graph_pool_handle() if torch.cuda.is_available() else None
# This is overridden by the cli
FLASH_DECODING = os.getenv("FLASH_DECODING") in {"1", "true", "True"}
BLOCK_SIZE: int = 256 if FLASH_DECODING else 16
if FLASH_DECODING:
-logger.info("Using FLASH_DECODING")
+log_master(logger.info, "Using FLASH_DECODING")
cuda_graphs = os.getenv("CUDA_GRAPHS")
if cuda_graphs is not None:
@@ -26,11 +27,9 @@ else:
if cuda_graphs is not None:
cuda_graphs.sort(reverse=True)
CUDA_GRAPHS = cuda_graphs
# This is overridden at model loading.
-global MODEL_ID
MODEL_ID = None
@@ -41,8 +40,7 @@ def set_model_id(model_id: str):
# NOTE: eventually we should move this into the router and pass back the
# index in all cases.
-global ADAPTER_TO_INDEX
-ADAPTER_TO_INDEX: Dict[str, int] = None
+ADAPTER_TO_INDEX: Optional[Dict[str, int]] = None
def set_adapter_to_index(adapter_to_index: Dict[str, int]):


@@ -15,6 +15,7 @@ from text_generation_server.utils.adapter import (
AdapterParameters,
AdapterSource,
)
from text_generation_server.utils.log import log_master
from loguru import logger
@@ -204,8 +205,9 @@ class Model(ABC):
f"order to use the dynamic adapter loading feature."
)
-logger.info(
-f"Loading adapter weights into model: {','.join(adapter_parameters.adapter_ids)}"
+log_master(
+logger.info,
+f"Loading adapter weights into model: {','.join(adapter_parameters.adapter_ids)}",
)
weight_names = tuple([v[0] for v in self.target_to_layer.values()])
(
@@ -240,8 +242,9 @@ class Model(ABC):
layer_weights.add_adapter(adapter_index, adapter_weights)
if len(unused_weight_names) > 0:
-logger.warning(
-f"{','.join(adapter_parameters.adapter_ids)} unused adapter weights: {unused_weight_names}"
+log_master(
+logger.warning,
+f"{','.join(adapter_parameters.adapter_ids)} unused adapter weights: {unused_weight_names}",
)
if adapter_tokenizer is not None:


@@ -1,4 +1,3 @@
-from itertools import repeat
import torch
from PIL import Image
from io import BytesIO
@@ -13,6 +12,7 @@ from text_generation_server.models.flash_causal_lm import (
FlashCausalLMBatch,
FlashCausalLM,
)
from text_generation_server.utils.log import log_master
from transformers import AutoProcessor
tracer = trace.get_tracer(__name__)
@@ -56,8 +56,9 @@ def image_text_replacement(processor, image_input, config, image_id: int) -> str
num_features = get_number_of_features(height, width, config)
from loguru import logger
-logger.info(
-f"Found {num_features} features in image of resolution {height}x{width}"
+log_master(
+logger.info,
+f"Found {num_features} features in image of resolution {height}x{width}",
)
return "<image>" * num_features


@@ -56,7 +56,7 @@ def initialize_torch_distributed():
backend = "nccl"
options = ProcessGroupNCCL.Options()
options.is_high_priority_stream = True
-options._timeout = timedelta(seconds=60)
+options._timeout = timedelta(seconds=120)
else:
backend = "gloo"
options = None
@@ -76,7 +76,7 @@ def initialize_torch_distributed():
backend="ccl",
world_size=WORLD_SIZE,
rank=RANK,
-timeout=timedelta(seconds=60),
+timeout=timedelta(seconds=120),
pg_options=options,
)
else:
@@ -84,7 +84,7 @@ def initialize_torch_distributed():
backend=backend,
world_size=WORLD_SIZE,
rank=RANK,
-timeout=timedelta(seconds=60),
+timeout=timedelta(seconds=120),
pg_options=options,
)
else:


@@ -1,6 +1,15 @@
from functools import lru_cache
from text_generation_server.utils.dist import RANK
@lru_cache(10)
-def log_once(log, msg: str):
-log(msg)
+def log_once(log, msg: str, master=True):
+if master:
log_master(log, msg)
else:
log(msg)
def log_master(log, msg: str):
if RANK == 0:
log(msg)


@@ -11,6 +11,7 @@ from text_generation_server.utils.weights import (
)
# TODO: Split this config to have a single config type per quant method
@dataclass
class _QuantizerConfig:
bits: int
@@ -21,6 +22,11 @@ class _QuantizerConfig:
sym: bool
@dataclass
class _FP8QuantizerConfig:
activation_scale_ub: float
# We should probably do this with Pytantic JSON deserialization,
# but for now we'll stay close to the old _set_gptq_params.
def _get_quantizer_config(model_id, revision):
@@ -39,6 +45,13 @@ def _get_quantizer_config(model_id, revision):
filename = hf_hub_download(model_id, filename=filename, revision=revision)
with open(filename, "r") as f:
data = json.load(f)
# FP8 config
if data["quantization_config"]["quant_method"] == "fbgemm_fp8":
return _FP8QuantizerConfig(
activation_scale_ub=data["quantization_config"]["activation_scale_ub"]
)
bits = data["quantization_config"]["bits"]
groupsize = data["quantization_config"]["group_size"]
# Order is important here, desc_act is missing on some real models
@@ -99,6 +112,12 @@ def get_loader(
if quantize in {"awq", "gptq"}:
from text_generation_server.layers.gptq import GPTQWeightsLoader
# TODO: improve check once we have one config type per quantize value
if not isinstance(quantizer_config, _QuantizerConfig):
raise ValueError(
f"Quantize is set to `{quantize}` but received a `{quantizer_config.__class__.__name__}` config."
)
return GPTQWeightsLoader(
bits=quantizer_config.bits,
desc_act=quantizer_config.desc_act,
@@ -127,18 +146,28 @@ def get_loader(
from text_generation_server.layers.exl2 import Exl2WeightsLoader
return Exl2WeightsLoader()
-elif quantize == "fp8":
-from text_generation_server.layers.fp8 import Fp8Weight
-return DefaultWeightsLoader(Fp8Weight)
elif quantize == "marlin":
from text_generation_server.layers.marlin import MarlinWeightsLoader
# TODO: improve check once we have one config type per quantize value
if not isinstance(quantizer_config, _QuantizerConfig):
raise ValueError(
f"Quantize is set to `{quantize}` but received a `{quantizer_config.__class__.__name__}` config."
)
return MarlinWeightsLoader(
bits=quantizer_config.bits,
is_marlin_24=quantizer_config.checkpoint_format == "marlin_24",
)
-elif quantize is None:
-return DefaultWeightsLoader(UnquantizedWeight)
+elif quantize == "fp8" or quantize is None:
+from text_generation_server.layers.fp8 import HybridFP8UnquantLoader
# Since the default for the quantize config is _QuantizerConfig,
# we need to add this check to not get an attribute error
activation_scale_ub = None
if isinstance(quantizer_config, _FP8QuantizerConfig):
activation_scale_ub = quantizer_config.activation_scale_ub
return HybridFP8UnquantLoader(activation_scale_ub, to_fp8=quantize == "fp8")
else:
raise ValueError(f"Unknown quantization method: {quantize}")

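To make the new config handling above concrete: an fbgemm_fp8 checkpoint only needs its config.json to carry the new quant method for the hybrid loader to be picked. A small illustration (the numeric value is invented; only the keys appear in this diff):

data = {
    "quantization_config": {
        "quant_method": "fbgemm_fp8",
        "activation_scale_ub": 1200.0,
    }
}
if data["quantization_config"]["quant_method"] == "fbgemm_fp8":
    activation_scale_ub = data["quantization_config"]["activation_scale_ub"]

# get_loader() then returns HybridFP8UnquantLoader(activation_scale_ub, to_fp8=quantize == "fp8"):
# weights stored as float8_e4m3fn are loaded directly together with their weight_scale,
# while bf16/fp16 weights are only quantized at load time when --quantize fp8 was requested.
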

@@ -1,12 +1,12 @@
-import torch
from abc import ABC, abstractmethod
from contextlib import contextmanager
from dataclasses import dataclass
from enum import Enum, auto
from pathlib import Path
-from typing import Dict, List, Optional, Union
+from typing import Dict, List, Optional, Union, Type
import torch
from safetensors import safe_open
-from dataclasses import dataclass
from text_generation_server.utils.import_utils import SYSTEM
@@ -84,7 +84,7 @@ class Weight(ABC):
@dataclass
-class UnquantizedWeight:
+class UnquantizedWeight(Weight):
weight: torch.Tensor
def get_linear(self, bias: torch.Tensor):
@@ -99,7 +99,7 @@ class UnquantizedWeight:
class DefaultWeightsLoader(WeightsLoader):
"""Weight loader that loads (unquantized) Torch tensors."""
-def __init__(self, weight_class):
+def __init__(self, weight_class: Type[UnquantizedWeight]):
"""Create a loader. Weights will be wrapped using the given `weights_class`,
normally this will be `UnquantizedWeight`, but a quantizer-specific class
such as `Fp8Weight` can be used to quantize the weights during loading.
@@ -208,20 +208,29 @@ class Weights:
def get_shape(self, tensor_name: str):
return self._get_slice(tensor_name).get_shape()
-def get_tensor(self, tensor_name: str, to_device=True):
+def get_tensor(self, tensor_name: str, to_device=True, to_dtype=True):
filename, tensor_name = self.get_filename(tensor_name)
f = self._get_handle(filename)
tensor = f.get_tensor(tensor_name)
# Special case for gptq which shouldn't convert
# u4 which are disguised as int32. Exl2 uses int16
-# as well.
+# as well. FP8 uses torch.float8_e4m3fn
-if tensor.dtype not in [torch.int16, torch.int32, torch.int64]:
+if (
tensor.dtype
not in [
torch.float8_e4m3fn,
torch.int16,
torch.int32,
torch.int64,
]
and to_dtype
):
tensor = tensor.to(dtype=self.dtype)
if to_device:
tensor = tensor.to(device=self.device)
return tensor
-def get_partial_sharded(self, tensor_name: str, dim: int):
+def get_partial_sharded(self, tensor_name: str, dim: int, to_dtype=True):
filename, tensor_name = self.get_filename(tensor_name)
f = self._get_handle(filename)
slice_ = f.get_slice(tensor_name)
@@ -241,12 +250,16 @@ class Weights:
raise NotImplementedError("Let's make that generic when needed")
# Special case for gptq which shouldn't convert
# u4 which are disguised as int32. exl2 uses int16.
-if tensor.dtype not in (torch.int16, torch.int32):
+# FP8 uses torch.float8_e4m3fn.
+if (
+tensor.dtype not in (torch.float8_e4m3fn, torch.int16, torch.int32)
+and to_dtype
+):
tensor = tensor.to(dtype=self.dtype)
tensor = tensor.to(device=self.device)
return tensor
-def get_sharded(self, tensor_name: str, dim: int):
+def get_sharded(self, tensor_name: str, dim: int, to_dtype=True):
filename, tensor_name = self.get_filename(tensor_name)
f = self._get_handle(filename)
slice_ = f.get_slice(tensor_name)
@@ -255,10 +268,14 @@ class Weights:
assert (
size % world_size == 0
), f"The choosen size {size} is not compatible with sharding on {world_size} shards"
-return self.get_partial_sharded(tensor_name, dim)
+return self.get_partial_sharded(tensor_name, dim, to_dtype=to_dtype)
def get_packed_sharded(
-self, tensor_name: str, dim: int, block_sizes: Union[int, List[int]]
+self,
+tensor_name: str,
+dim: int,
+block_sizes: Union[int, List[int]],
+to_dtype=True,
) -> torch.Tensor:
"""
Get a shard from a tensor that packs multiple tensors.
@@ -304,7 +321,16 @@ class Weights:
tensor = tensor.to(device=self.device)
# Avoid casting quantizer dtypes.
-if tensor.dtype not in [torch.int16, torch.int32, torch.int64]:
+if (
tensor.dtype
not in [
torch.float8_e4m3fn,
torch.int16,
torch.int32,
torch.int64,
]
and to_dtype
):
tensor = tensor.to(dtype=self.dtype)
return tensor