feat(fp8): use fbgemm kernels and load fp8 weights directly (#2248)
* feat(fp8): add support for fbgemm * allow loading fp8 weights directly * update outlines * fix makefile * build fbgemm * avoid circular import and fix dockerfile * add default dtype * refactored weights loader * fix auto conversion * fix quantization config parsing * force new nccl on install * missing get_weights implementation * increase timeout
This commit is contained in:
parent
e5c1d6d611
commit
53ec0b790b
|
@ -183,4 +183,3 @@ jobs:
|
||||||
export HF_TOKEN=${{ secrets.HF_TOKEN }}
|
export HF_TOKEN=${{ secrets.HF_TOKEN }}
|
||||||
echo $DOCKER_IMAGE
|
echo $DOCKER_IMAGE
|
||||||
pytest -s -vv integration-tests ${PYTEST_FLAGS}
|
pytest -s -vv integration-tests ${PYTEST_FLAGS}
|
||||||
|
|
||||||
|
|
17
Dockerfile
17
Dockerfile
|
@ -161,6 +161,17 @@ COPY server/custom_kernels/ .
|
||||||
# Build specific version of transformers
|
# Build specific version of transformers
|
||||||
RUN python setup.py build
|
RUN python setup.py build
|
||||||
|
|
||||||
|
# Build FBGEMM CUDA kernels
|
||||||
|
FROM kernel-builder AS fbgemm-builder
|
||||||
|
|
||||||
|
WORKDIR /usr/src
|
||||||
|
|
||||||
|
COPY server/Makefile-fbgemm Makefile
|
||||||
|
COPY server/fbgemm_remove_unused.patch fbgemm_remove_unused.patch
|
||||||
|
COPY server/fix_torch90a.sh fix_torch90a.sh
|
||||||
|
|
||||||
|
RUN make build-fbgemm
|
||||||
|
|
||||||
# Build vllm CUDA kernels
|
# Build vllm CUDA kernels
|
||||||
FROM kernel-builder AS vllm-builder
|
FROM kernel-builder AS vllm-builder
|
||||||
|
|
||||||
|
@ -225,10 +236,10 @@ COPY --from=eetq-kernels-builder /usr/src/eetq/build/lib.linux-x86_64-cpython-31
|
||||||
# Copy build artifacts from marlin kernels builder
|
# Copy build artifacts from marlin kernels builder
|
||||||
COPY --from=marlin-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
|
COPY --from=marlin-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
|
||||||
COPY --from=lorax-punica-builder /usr/src/lorax-punica/server/punica_kernels/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
|
COPY --from=lorax-punica-builder /usr/src/lorax-punica/server/punica_kernels/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
|
||||||
|
# Copy build artifacts from fbgemm builder
|
||||||
# Copy builds artifacts from vllm builder
|
COPY --from=fbgemm-builder /usr/src/fbgemm/fbgemm_gpu/_skbuild/linux-x86_64-3.10/cmake-install /opt/conda/lib/python3.10/site-packages
|
||||||
|
# Copy build artifacts from vllm builder
|
||||||
COPY --from=vllm-builder /usr/src/vllm/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
|
COPY --from=vllm-builder /usr/src/vllm/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
|
||||||
|
|
||||||
# Copy build artifacts from mamba builder
|
# Copy build artifacts from mamba builder
|
||||||
COPY --from=mamba-builder /usr/src/mamba/build/lib.linux-x86_64-cpython-310/ /opt/conda/lib/python3.10/site-packages
|
COPY --from=mamba-builder /usr/src/mamba/build/lib.linux-x86_64-cpython-310/ /opt/conda/lib/python3.10/site-packages
|
||||||
COPY --from=mamba-builder /usr/src/causal-conv1d/build/lib.linux-x86_64-cpython-310/ /opt/conda/lib/python3.10/site-packages
|
COPY --from=mamba-builder /usr/src/causal-conv1d/build/lib.linux-x86_64-cpython-310/ /opt/conda/lib/python3.10/site-packages
|
||||||
|
|
|
@ -5,6 +5,7 @@ include Makefile-awq
|
||||||
include Makefile-eetq
|
include Makefile-eetq
|
||||||
include Makefile-selective-scan
|
include Makefile-selective-scan
|
||||||
include Makefile-lorax-punica
|
include Makefile-lorax-punica
|
||||||
|
include Makefile-fbgemm
|
||||||
|
|
||||||
unit-tests:
|
unit-tests:
|
||||||
pytest -s -vv -m "not private" tests
|
pytest -s -vv -m "not private" tests
|
||||||
|
@ -27,8 +28,9 @@ install-server: gen-server
|
||||||
install: install-cuda
|
install: install-cuda
|
||||||
echo "Installed server"
|
echo "Installed server"
|
||||||
|
|
||||||
install-cuda: install-server install-flash-attention-v2-cuda install-vllm-cuda install-flash-attention
|
install-cuda: install-server install-flash-attention-v2-cuda install-vllm-cuda install-flash-attention install-fbgemm
|
||||||
pip install -e ".[bnb]"
|
pip install -e ".[bnb]"
|
||||||
|
pip install nvidia-nccl-cu12==2.22.3
|
||||||
|
|
||||||
install-rocm: install-server install-flash-attention-v2-rocm install-vllm-rocm
|
install-rocm: install-server install-flash-attention-v2-rocm install-vllm-rocm
|
||||||
|
|
||||||
|
@ -36,5 +38,6 @@ run-dev:
|
||||||
SAFETENSORS_FAST_GPU=1 python -m torch.distributed.run --nproc_per_node=2 text_generation_server/cli.py serve bigscience/bloom-560m --sharded
|
SAFETENSORS_FAST_GPU=1 python -m torch.distributed.run --nproc_per_node=2 text_generation_server/cli.py serve bigscience/bloom-560m --sharded
|
||||||
|
|
||||||
export-requirements:
|
export-requirements:
|
||||||
poetry export -o requirements_cuda.txt --without-hashes --with cuda
|
poetry export -o requirements_cuda.txt --without-hashes
|
||||||
poetry export -o requirements_rocm.txt --without-hashes
|
poetry export -o requirements_rocm.txt --without-hashes
|
||||||
|
poetry export -o requirements_intel.txt --without-hashes
|
||||||
|
|
|
@ -0,0 +1,15 @@
|
||||||
|
fbgemm_commit := 9cf0429b726931cfab72b8264730bea682f32fca
|
||||||
|
|
||||||
|
build-fbgemm:
|
||||||
|
chmod +x fix_torch90a.sh && ./fix_torch90a.sh && \
|
||||||
|
git clone https://github.com/pytorch/FBGEMM.git fbgemm && \
|
||||||
|
cp fbgemm_remove_unused.patch fbgemm && \
|
||||||
|
cd fbgemm && git fetch && git checkout $(fbgemm_commit) && git apply fbgemm_remove_unused.patch && \
|
||||||
|
git submodule update --init --recursive && \
|
||||||
|
cd fbgemm_gpu && \
|
||||||
|
pip install -r requirements.txt && \
|
||||||
|
CUDA_ARCH_LIST="8.0;9.0a" NVCC_GENCODE="-gencode=arch=compute_80,code=sm_80 -gencode=arch=compute_90a,code=sm_90a" TORCH_CUDA_ARCH_LIST="8.0;9.0a" python setup.py --package_variant genai build
|
||||||
|
|
||||||
|
install-fbgemm: build-fbgemm
|
||||||
|
cd fbgemm/fbgemm_gpu && \
|
||||||
|
CUDA_ARCH_LIST="8.0;9.0a" NVCC_GENCODE="-gencode=arch=compute_80,code=sm_80 -gencode=arch=compute_90a,code=sm_90a" TORCH_CUDA_ARCH_LIST="8.0;9.0a" python setup.py --package_variant genai install
|
|
@ -0,0 +1,306 @@
|
||||||
|
diff --git a/fbgemm_gpu/CMakeLists.txt b/fbgemm_gpu/CMakeLists.txt
|
||||||
|
index 2244ea6f..96265a48 100644
|
||||||
|
--- a/fbgemm_gpu/CMakeLists.txt
|
||||||
|
+++ b/fbgemm_gpu/CMakeLists.txt
|
||||||
|
@@ -94,14 +94,14 @@ endif()
|
||||||
|
# Build Experimental Modules
|
||||||
|
################################################################################
|
||||||
|
|
||||||
|
-if(NOT FBGEMM_CPU_ONLY AND NOT USE_ROCM)
|
||||||
|
- # TODO: Figure out NCCL/RCCL integration with ROCm
|
||||||
|
- add_subdirectory(experimental/example)
|
||||||
|
-endif()
|
||||||
|
-
|
||||||
|
-if(NOT FBGEMM_CPU_ONLY)
|
||||||
|
- add_subdirectory(experimental/gemm)
|
||||||
|
-endif()
|
||||||
|
+# if(NOT FBGEMM_CPU_ONLY AND NOT USE_ROCM)
|
||||||
|
+# # TODO: Figure out NCCL/RCCL integration with ROCm
|
||||||
|
+# add_subdirectory(experimental/example)
|
||||||
|
+# endif()
|
||||||
|
+
|
||||||
|
+# if(NOT FBGEMM_CPU_ONLY)
|
||||||
|
+# add_subdirectory(experimental/gemm)
|
||||||
|
+# endif()
|
||||||
|
|
||||||
|
if(NOT FBGEMM_CPU_ONLY AND NOT USE_ROCM)
|
||||||
|
# CUTLASS currently doesn't build on ROCm and CK hasnt yet been added:
|
||||||
|
diff --git a/fbgemm_gpu/FbgemmGpu.cmake b/fbgemm_gpu/FbgemmGpu.cmake
|
||||||
|
index c56773fe..0c0d349e 100644
|
||||||
|
--- a/fbgemm_gpu/FbgemmGpu.cmake
|
||||||
|
+++ b/fbgemm_gpu/FbgemmGpu.cmake
|
||||||
|
@@ -446,53 +446,55 @@ set_source_files_properties(${fbgemm_sources}
|
||||||
|
################################################################################
|
||||||
|
|
||||||
|
set(fbgemm_gpu_sources_static_cpu
|
||||||
|
- codegen/training/forward/embedding_forward_split_cpu.cpp
|
||||||
|
- codegen/inference/embedding_forward_quantized_host_cpu.cpp
|
||||||
|
- codegen/training/backward/embedding_backward_dense_host_cpu.cpp
|
||||||
|
- codegen/utils/embedding_bounds_check_host_cpu.cpp
|
||||||
|
- src/merge_pooled_embedding_ops/merge_pooled_embedding_ops_cpu.cpp
|
||||||
|
- src/permute_multi_embedding_ops/permute_multi_embedding_function.cpp
|
||||||
|
- src/permute_multi_embedding_ops/permute_multi_embedding_ops_cpu.cpp
|
||||||
|
- src/permute_pooled_embedding_ops/permute_pooled_embedding_function.cpp
|
||||||
|
- src/permute_pooled_embedding_ops/permute_pooled_embedding_ops_cpu.cpp
|
||||||
|
- src/permute_pooled_embedding_ops/permute_pooled_embedding_ops_split_cpu.cpp
|
||||||
|
- src/jagged_tensor_ops/jagged_tensor_ops_autograd.cpp
|
||||||
|
- src/jagged_tensor_ops/jagged_tensor_ops_meta.cpp
|
||||||
|
- src/jagged_tensor_ops/jagged_tensor_ops_cpu.cpp
|
||||||
|
- src/input_combine_ops/input_combine_cpu.cpp
|
||||||
|
- src/layout_transform_ops/layout_transform_ops_cpu.cpp
|
||||||
|
+ # codegen/training/forward/embedding_forward_split_cpu.cpp
|
||||||
|
+ # codegen/inference/embedding_forward_quantized_host_cpu.cpp
|
||||||
|
+ # codegen/training/backward/embedding_backward_dense_host_cpu.cpp
|
||||||
|
+ # codegen/utils/embedding_bounds_check_host_cpu.cpp
|
||||||
|
+ # src/merge_pooled_embedding_ops/merge_pooled_embedding_ops_cpu.cpp
|
||||||
|
+ # src/permute_multi_embedding_ops/permute_multi_embedding_function.cpp
|
||||||
|
+ # src/permute_multi_embedding_ops/permute_multi_embedding_ops_cpu.cpp
|
||||||
|
+ # src/permute_pooled_embedding_ops/permute_pooled_embedding_function.cpp
|
||||||
|
+ # src/permute_pooled_embedding_ops/permute_pooled_embedding_ops_cpu.cpp
|
||||||
|
+ # src/permute_pooled_embedding_ops/permute_pooled_embedding_ops_split_cpu.cpp
|
||||||
|
+ # src/jagged_tensor_ops/jagged_tensor_ops_autograd.cpp
|
||||||
|
+ # src/jagged_tensor_ops/jagged_tensor_ops_meta.cpp
|
||||||
|
+ # src/jagged_tensor_ops/jagged_tensor_ops_cpu.cpp
|
||||||
|
+ # src/input_combine_ops/input_combine_cpu.cpp
|
||||||
|
+ # src/layout_transform_ops/layout_transform_ops_cpu.cpp
|
||||||
|
src/quantize_ops/quantize_ops_cpu.cpp
|
||||||
|
src/quantize_ops/quantize_ops_meta.cpp
|
||||||
|
- src/sparse_ops/sparse_ops_cpu.cpp
|
||||||
|
- src/sparse_ops/sparse_ops_meta.cpp
|
||||||
|
- src/embedding_inplace_ops/embedding_inplace_update_cpu.cpp
|
||||||
|
- src/split_embeddings_cache/linearize_cache_indices.cpp
|
||||||
|
- src/split_embeddings_cache/lfu_cache_populate_byte.cpp
|
||||||
|
- src/split_embeddings_cache/lru_cache_populate_byte.cpp
|
||||||
|
- src/split_embeddings_cache/lxu_cache.cpp
|
||||||
|
- src/split_embeddings_cache/split_embeddings_cache_ops.cpp
|
||||||
|
- codegen/training/index_select/batch_index_select_dim0_ops.cpp
|
||||||
|
- codegen/training/index_select/batch_index_select_dim0_cpu_host.cpp)
|
||||||
|
+ # src/sparse_ops/sparse_ops_cpu.cpp
|
||||||
|
+ # src/sparse_ops/sparse_ops_meta.cpp
|
||||||
|
+ # src/embedding_inplace_ops/embedding_inplace_update_cpu.cpp
|
||||||
|
+ # src/split_embeddings_cache/linearize_cache_indices.cpp
|
||||||
|
+ # src/split_embeddings_cache/lfu_cache_populate_byte.cpp
|
||||||
|
+ # src/split_embeddings_cache/lru_cache_populate_byte.cpp
|
||||||
|
+ # src/split_embeddings_cache/lxu_cache.cpp
|
||||||
|
+ # src/split_embeddings_cache/split_embeddings_cache_ops.cpp
|
||||||
|
+ # codegen/training/index_select/batch_index_select_dim0_ops.cpp
|
||||||
|
+ # codegen/training/index_select/batch_index_select_dim0_cpu_host.cpp)
|
||||||
|
+)
|
||||||
|
|
||||||
|
if(NOT FBGEMM_CPU_ONLY)
|
||||||
|
list(APPEND fbgemm_gpu_sources_static_cpu
|
||||||
|
- codegen/inference/embedding_forward_quantized_host.cpp
|
||||||
|
- codegen/utils/embedding_bounds_check_host.cpp
|
||||||
|
- src/intraining_embedding_pruning_ops/intraining_embedding_pruning_gpu.cpp
|
||||||
|
- src/layout_transform_ops/layout_transform_ops_gpu.cpp
|
||||||
|
- src/memory_utils/memory_utils.cpp
|
||||||
|
- src/memory_utils/memory_utils_ops.cpp
|
||||||
|
- src/memory_utils/memory_utils_ops_cpu.cpp
|
||||||
|
- src/permute_pooled_embedding_ops/permute_pooled_embedding_ops_gpu.cpp
|
||||||
|
- src/permute_pooled_embedding_ops/permute_pooled_embedding_ops_split_gpu.cpp
|
||||||
|
+ # codegen/inference/embedding_forward_quantized_host.cpp
|
||||||
|
+ # codegen/utils/embedding_bounds_check_host.cpp
|
||||||
|
+ # src/intraining_embedding_pruning_ops/intraining_embedding_pruning_gpu.cpp
|
||||||
|
+ # src/layout_transform_ops/layout_transform_ops_gpu.cpp
|
||||||
|
+ # src/memory_utils/memory_utils.cpp
|
||||||
|
+ # src/memory_utils/memory_utils_ops.cpp
|
||||||
|
+ # src/memory_utils/memory_utils_ops_cpu.cpp
|
||||||
|
+ # src/permute_pooled_embedding_ops/permute_pooled_embedding_ops_gpu.cpp
|
||||||
|
+ # src/permute_pooled_embedding_ops/permute_pooled_embedding_ops_split_gpu.cpp
|
||||||
|
src/quantize_ops/quantize_ops_gpu.cpp
|
||||||
|
- src/sparse_ops/sparse_ops_gpu.cpp
|
||||||
|
- src/split_embeddings_utils/split_embeddings_utils.cpp
|
||||||
|
- src/split_embeddings_cache/split_embeddings_cache_ops.cu
|
||||||
|
- src/metric_ops/metric_ops_host.cpp
|
||||||
|
- src/embedding_inplace_ops/embedding_inplace_update_gpu.cpp
|
||||||
|
- src/input_combine_ops/input_combine_gpu.cpp
|
||||||
|
- codegen/training/index_select/batch_index_select_dim0_host.cpp)
|
||||||
|
+ # src/sparse_ops/sparse_ops_gpu.cpp
|
||||||
|
+ # src/split_embeddings_utils/split_embeddings_utils.cpp
|
||||||
|
+ # src/split_embeddings_cache/split_embeddings_cache_ops.cu
|
||||||
|
+ # src/metric_ops/metric_ops_host.cpp
|
||||||
|
+ # src/embedding_inplace_ops/embedding_inplace_update_gpu.cpp
|
||||||
|
+ # src/input_combine_ops/input_combine_gpu.cpp
|
||||||
|
+ # codegen/training/index_select/batch_index_select_dim0_host.cpp)
|
||||||
|
+ )
|
||||||
|
|
||||||
|
if(NVML_LIB_PATH OR USE_ROCM)
|
||||||
|
message(STATUS "Adding merge_pooled_embeddings sources")
|
||||||
|
@@ -516,36 +518,36 @@ endif()
|
||||||
|
|
||||||
|
if(NOT FBGEMM_CPU_ONLY)
|
||||||
|
set(fbgemm_gpu_sources_static_gpu
|
||||||
|
- codegen/utils/embedding_bounds_check.cu
|
||||||
|
- codegen/inference/embedding_forward_quantized_split_lookup.cu
|
||||||
|
- src/embedding_inplace_ops/embedding_inplace_update.cu
|
||||||
|
- src/histogram_binning_calibration_ops.cu
|
||||||
|
- src/input_combine_ops/input_combine.cu
|
||||||
|
- src/intraining_embedding_pruning_ops/intraining_embedding_pruning.cu
|
||||||
|
- src/memory_utils/memory_utils.cu
|
||||||
|
- src/memory_utils/memory_utils_ops.cu
|
||||||
|
- src/jagged_tensor_ops/batched_dense_vec_jagged_2d_mul_backward.cu
|
||||||
|
- src/jagged_tensor_ops/batched_dense_vec_jagged_2d_mul_forward.cu
|
||||||
|
- src/jagged_tensor_ops/dense_to_jagged_forward.cu
|
||||||
|
- src/jagged_tensor_ops/jagged_dense_bmm_forward.cu
|
||||||
|
- src/jagged_tensor_ops/jagged_dense_dense_elementwise_add_jagged_output_forward.cu
|
||||||
|
- src/jagged_tensor_ops/jagged_dense_elementwise_mul_backward.cu
|
||||||
|
- src/jagged_tensor_ops/jagged_dense_elementwise_mul_forward.cu
|
||||||
|
- src/jagged_tensor_ops/jagged_index_add_2d_forward.cu
|
||||||
|
- src/jagged_tensor_ops/jagged_index_select_2d_forward.cu
|
||||||
|
- src/jagged_tensor_ops/jagged_jagged_bmm_forward.cu
|
||||||
|
- src/jagged_tensor_ops/jagged_softmax_backward.cu
|
||||||
|
- src/jagged_tensor_ops/jagged_softmax_forward.cu
|
||||||
|
- src/jagged_tensor_ops/jagged_tensor_ops.cu
|
||||||
|
- src/jagged_tensor_ops/jagged_to_padded_dense_backward.cu
|
||||||
|
- src/jagged_tensor_ops/jagged_to_padded_dense_forward.cu
|
||||||
|
- src/jagged_tensor_ops/jagged_unique_indices.cu
|
||||||
|
- src/jagged_tensor_ops/keyed_jagged_index_select_dim1.cu
|
||||||
|
- src/layout_transform_ops/layout_transform_ops.cu
|
||||||
|
- src/metric_ops/metric_ops.cu
|
||||||
|
- src/permute_pooled_embedding_ops/permute_pooled_embedding_ops_split.cu
|
||||||
|
- src/permute_pooled_embedding_ops/permute_pooled_embedding_ops.cu
|
||||||
|
- src/permute_multi_embedding_ops/permute_multi_embedding_ops.cu
|
||||||
|
+ # codegen/utils/embedding_bounds_check.cu
|
||||||
|
+ # codegen/inference/embedding_forward_quantized_split_lookup.cu
|
||||||
|
+ # src/embedding_inplace_ops/embedding_inplace_update.cu
|
||||||
|
+ # src/histogram_binning_calibration_ops.cu
|
||||||
|
+ # src/input_combine_ops/input_combine.cu
|
||||||
|
+ # src/intraining_embedding_pruning_ops/intraining_embedding_pruning.cu
|
||||||
|
+ # src/memory_utils/memory_utils.cu
|
||||||
|
+ # src/memory_utils/memory_utils_ops.cu
|
||||||
|
+ # src/jagged_tensor_ops/batched_dense_vec_jagged_2d_mul_backward.cu
|
||||||
|
+ # src/jagged_tensor_ops/batched_dense_vec_jagged_2d_mul_forward.cu
|
||||||
|
+ # src/jagged_tensor_ops/dense_to_jagged_forward.cu
|
||||||
|
+ # src/jagged_tensor_ops/jagged_dense_bmm_forward.cu
|
||||||
|
+ # src/jagged_tensor_ops/jagged_dense_dense_elementwise_add_jagged_output_forward.cu
|
||||||
|
+ # src/jagged_tensor_ops/jagged_dense_elementwise_mul_backward.cu
|
||||||
|
+ # src/jagged_tensor_ops/jagged_dense_elementwise_mul_forward.cu
|
||||||
|
+ # src/jagged_tensor_ops/jagged_index_add_2d_forward.cu
|
||||||
|
+ # src/jagged_tensor_ops/jagged_index_select_2d_forward.cu
|
||||||
|
+ # src/jagged_tensor_ops/jagged_jagged_bmm_forward.cu
|
||||||
|
+ # src/jagged_tensor_ops/jagged_softmax_backward.cu
|
||||||
|
+ # src/jagged_tensor_ops/jagged_softmax_forward.cu
|
||||||
|
+ # src/jagged_tensor_ops/jagged_tensor_ops.cu
|
||||||
|
+ # src/jagged_tensor_ops/jagged_to_padded_dense_backward.cu
|
||||||
|
+ # src/jagged_tensor_ops/jagged_to_padded_dense_forward.cu
|
||||||
|
+ # src/jagged_tensor_ops/jagged_unique_indices.cu
|
||||||
|
+ # src/jagged_tensor_ops/keyed_jagged_index_select_dim1.cu
|
||||||
|
+ # src/layout_transform_ops/layout_transform_ops.cu
|
||||||
|
+ # src/metric_ops/metric_ops.cu
|
||||||
|
+ # src/permute_pooled_embedding_ops/permute_pooled_embedding_ops_split.cu
|
||||||
|
+ # src/permute_pooled_embedding_ops/permute_pooled_embedding_ops.cu
|
||||||
|
+ # src/permute_multi_embedding_ops/permute_multi_embedding_ops.cu
|
||||||
|
src/quantize_ops/quantize_bfloat16.cu
|
||||||
|
src/quantize_ops/quantize_fp8_rowwise.cu
|
||||||
|
src/quantize_ops/quantize_fused_8bit_rowwise.cu
|
||||||
|
@@ -554,39 +556,40 @@ if(NOT FBGEMM_CPU_ONLY)
|
||||||
|
src/quantize_ops/quantize_msfp.cu
|
||||||
|
src/quantize_ops/quantize_padded_fp8_rowwise.cu
|
||||||
|
src/quantize_ops/quantize_mx.cu
|
||||||
|
- src/sparse_ops/sparse_async_cumsum.cu
|
||||||
|
- src/sparse_ops/sparse_block_bucketize_features.cu
|
||||||
|
- src/sparse_ops/sparse_bucketize_features.cu
|
||||||
|
- src/sparse_ops/sparse_batched_unary_embeddings.cu
|
||||||
|
- src/sparse_ops/sparse_compute_frequency_sequence.cu
|
||||||
|
- src/sparse_ops/sparse_expand_into_jagged_permute.cu
|
||||||
|
- src/sparse_ops/sparse_group_index.cu
|
||||||
|
- src/sparse_ops/sparse_index_add.cu
|
||||||
|
- src/sparse_ops/sparse_index_select.cu
|
||||||
|
- src/sparse_ops/sparse_invert_permute.cu
|
||||||
|
- src/sparse_ops/sparse_pack_segments_backward.cu
|
||||||
|
- src/sparse_ops/sparse_pack_segments_forward.cu
|
||||||
|
- src/sparse_ops/sparse_permute_1d.cu
|
||||||
|
- src/sparse_ops/sparse_permute_2d.cu
|
||||||
|
- src/sparse_ops/sparse_permute102.cu
|
||||||
|
- src/sparse_ops/sparse_permute_embeddings.cu
|
||||||
|
- src/sparse_ops/sparse_range.cu
|
||||||
|
- src/sparse_ops/sparse_reorder_batched_ad.cu
|
||||||
|
- src/sparse_ops/sparse_segment_sum_csr.cu
|
||||||
|
- src/sparse_ops/sparse_zipf.cu
|
||||||
|
- src/split_embeddings_cache/lfu_cache_find.cu
|
||||||
|
- src/split_embeddings_cache/lfu_cache_populate.cu
|
||||||
|
- src/split_embeddings_cache/lfu_cache_populate_byte.cu
|
||||||
|
- src/split_embeddings_cache/lru_cache_find.cu
|
||||||
|
- src/split_embeddings_cache/lru_cache_populate.cu
|
||||||
|
- src/split_embeddings_cache/lru_cache_populate_byte.cu
|
||||||
|
- src/split_embeddings_cache/lxu_cache.cu
|
||||||
|
- src/split_embeddings_cache/linearize_cache_indices.cu
|
||||||
|
- src/split_embeddings_cache/reset_weight_momentum.cu
|
||||||
|
- src/split_embeddings_utils/generate_vbe_metadata.cu
|
||||||
|
- src/split_embeddings_utils/get_infos_metadata.cu
|
||||||
|
- src/split_embeddings_utils/radix_sort_pairs.cu
|
||||||
|
- src/split_embeddings_utils/transpose_embedding_input.cu)
|
||||||
|
+ # src/sparse_ops/sparse_async_cumsum.cu
|
||||||
|
+ # src/sparse_ops/sparse_block_bucketize_features.cu
|
||||||
|
+ # src/sparse_ops/sparse_bucketize_features.cu
|
||||||
|
+ # src/sparse_ops/sparse_batched_unary_embeddings.cu
|
||||||
|
+ # src/sparse_ops/sparse_compute_frequency_sequence.cu
|
||||||
|
+ # src/sparse_ops/sparse_expand_into_jagged_permute.cu
|
||||||
|
+ # src/sparse_ops/sparse_group_index.cu
|
||||||
|
+ # src/sparse_ops/sparse_index_add.cu
|
||||||
|
+ # src/sparse_ops/sparse_index_select.cu
|
||||||
|
+ # src/sparse_ops/sparse_invert_permute.cu
|
||||||
|
+ # src/sparse_ops/sparse_pack_segments_backward.cu
|
||||||
|
+ # src/sparse_ops/sparse_pack_segments_forward.cu
|
||||||
|
+ # src/sparse_ops/sparse_permute_1d.cu
|
||||||
|
+ # src/sparse_ops/sparse_permute_2d.cu
|
||||||
|
+ # src/sparse_ops/sparse_permute102.cu
|
||||||
|
+ # src/sparse_ops/sparse_permute_embeddings.cu
|
||||||
|
+ # src/sparse_ops/sparse_range.cu
|
||||||
|
+ # src/sparse_ops/sparse_reorder_batched_ad.cu
|
||||||
|
+ # src/sparse_ops/sparse_segment_sum_csr.cu
|
||||||
|
+ # src/sparse_ops/sparse_zipf.cu
|
||||||
|
+ # src/split_embeddings_cache/lfu_cache_find.cu
|
||||||
|
+ # src/split_embeddings_cache/lfu_cache_populate.cu
|
||||||
|
+ # src/split_embeddings_cache/lfu_cache_populate_byte.cu
|
||||||
|
+ # src/split_embeddings_cache/lru_cache_find.cu
|
||||||
|
+ # src/split_embeddings_cache/lru_cache_populate.cu
|
||||||
|
+ # src/split_embeddings_cache/lru_cache_populate_byte.cu
|
||||||
|
+ # src/split_embeddings_cache/lxu_cache.cu
|
||||||
|
+ # src/split_embeddings_cache/linearize_cache_indices.cu
|
||||||
|
+ # src/split_embeddings_cache/reset_weight_momentum.cu
|
||||||
|
+ # src/split_embeddings_utils/generate_vbe_metadata.cu
|
||||||
|
+ # src/split_embeddings_utils/get_infos_metadata.cu
|
||||||
|
+ # src/split_embeddings_utils/radix_sort_pairs.cu
|
||||||
|
+ # src/split_embeddings_utils/transpose_embedding_input.cu)
|
||||||
|
+ )
|
||||||
|
|
||||||
|
set_source_files_properties(${fbgemm_gpu_sources_static_gpu}
|
||||||
|
PROPERTIES COMPILE_OPTIONS
|
||||||
|
diff --git a/fbgemm_gpu/experimental/gen_ai/CMakeLists.txt b/fbgemm_gpu/experimental/gen_ai/CMakeLists.txt
|
||||||
|
index 01f1d6ab..a6b8d7a8 100644
|
||||||
|
--- a/fbgemm_gpu/experimental/gen_ai/CMakeLists.txt
|
||||||
|
+++ b/fbgemm_gpu/experimental/gen_ai/CMakeLists.txt
|
||||||
|
@@ -25,23 +25,24 @@ set(fbgemm_sources_include_directories
|
||||||
|
${THIRDPARTY}/json/include
|
||||||
|
${NCCL_INCLUDE_DIRS})
|
||||||
|
|
||||||
|
-set(attention_ops_sources
|
||||||
|
- src/attention/attention.cpp
|
||||||
|
- src/attention/gqa_attn_splitk.cu)
|
||||||
|
+# set(attention_ops_sources
|
||||||
|
+# src/attention/attention.cpp
|
||||||
|
+# src/attention/gqa_attn_splitk.cu)
|
||||||
|
|
||||||
|
set(quantize_ops_sources
|
||||||
|
src/quantize/cutlass_extensions.cu
|
||||||
|
src/quantize/quantize.cu
|
||||||
|
src/quantize/quantize.cpp)
|
||||||
|
|
||||||
|
-set(comm_ops_sources
|
||||||
|
- src/comm/car.cu
|
||||||
|
- src/comm/car.cpp)
|
||||||
|
+# set(comm_ops_sources
|
||||||
|
+# src/comm/car.cu
|
||||||
|
+# src/comm/car.cpp)
|
||||||
|
|
||||||
|
set(experimental_gen_ai_cpp_source_files
|
||||||
|
- ${attention_ops_sources}
|
||||||
|
+ # ${attention_ops_sources}
|
||||||
|
${quantize_ops_sources}
|
||||||
|
- ${comm_ops_sources})
|
||||||
|
+ # ${comm_ops_sources}
|
||||||
|
+)
|
||||||
|
|
||||||
|
set_source_files_properties(${experimental_gen_ai_cpp_source_files}
|
||||||
|
PROPERTIES INCLUDE_DIRECTORIES
|
|
@ -0,0 +1,11 @@
|
||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
# This script is required to patch torch < 2.4
|
||||||
|
# It adds the 90a cuda target (H100)
|
||||||
|
# This target is required to build FBGEMM kernels
|
||||||
|
|
||||||
|
torch_cuda_arch=$(python -c "import torch; print(torch.__file__)" | sed 's/\/__init__.py//; s|$|/share/cmake/Caffe2/Modules_CUDA_fix/upstream/FindCUDA/select_compute_arch.cmake|')
|
||||||
|
|
||||||
|
sed -i '189s/\[0-9]\\\\\.\[0-9](/[0-9]\\\\.[0-9]a?(/' $torch_cuda_arch
|
||||||
|
sed -i '245s/\[0-9()]+\+"/[0-9()]+a?"/' $torch_cuda_arch
|
||||||
|
sed -i '246s/\[0-9]+\+"/[0-9]+a?"/' $torch_cuda_arch
|
|
@ -8,6 +8,7 @@ from typing import Optional
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from huggingface_hub import hf_hub_download
|
from huggingface_hub import hf_hub_download
|
||||||
|
|
||||||
|
from text_generation_server.utils.log import log_master
|
||||||
|
|
||||||
app = typer.Typer()
|
app = typer.Typer()
|
||||||
|
|
||||||
|
@ -87,15 +88,17 @@ def serve(
|
||||||
)
|
)
|
||||||
|
|
||||||
if len(lora_adapter_ids) > 0:
|
if len(lora_adapter_ids) > 0:
|
||||||
logger.warning(
|
log_master(
|
||||||
f"LoRA adapters are enabled. This is an experimental feature and may not work as expected."
|
logger.warning,
|
||||||
|
f"LoRA adapters are enabled. This is an experimental feature and may not work as expected.",
|
||||||
)
|
)
|
||||||
|
|
||||||
# TODO: enable lora with cuda graphs. for now disable cuda graphs if lora is enabled
|
# TODO: enable lora with cuda graphs. for now disable cuda graphs if lora is enabled
|
||||||
# and warn the user
|
# and warn the user
|
||||||
if len(lora_adapter_ids) > 0 and os.getenv("CUDA_GRAPHS", None) is not None:
|
if len(lora_adapter_ids) > 0 and os.getenv("CUDA_GRAPHS", None) is not None:
|
||||||
logger.warning(
|
log_master(
|
||||||
f"LoRa adapter are not supported with CUDA Graphs. Disabling CUDA Graphs."
|
logger.warning,
|
||||||
|
f"LoRa adapter are not supported with CUDA Graphs. Disabling CUDA Graphs.",
|
||||||
)
|
)
|
||||||
global CUDA_GRAPHS
|
global CUDA_GRAPHS
|
||||||
CUDA_GRAPHS = None
|
CUDA_GRAPHS = None
|
||||||
|
|
|
@ -3,6 +3,7 @@ import torch
|
||||||
from text_generation_server.utils.import_utils import SYSTEM
|
from text_generation_server.utils.import_utils import SYSTEM
|
||||||
from text_generation_server.models.globals import FLASH_DECODING
|
from text_generation_server.models.globals import FLASH_DECODING
|
||||||
from text_generation_server.layers.attention import Seqlen
|
from text_generation_server.layers.attention import Seqlen
|
||||||
|
from text_generation_server.utils.log import log_master
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
|
|
||||||
major, minor = torch.cuda.get_device_capability()
|
major, minor = torch.cuda.get_device_capability()
|
||||||
|
@ -136,7 +137,10 @@ if ENGINE != "triton":
|
||||||
try:
|
try:
|
||||||
import flash_attn_2_cuda
|
import flash_attn_2_cuda
|
||||||
|
|
||||||
logger.info("ROCm: using Flash Attention 2 Composable Kernel implementation.")
|
log_master(
|
||||||
|
logger.info,
|
||||||
|
"ROCm: using Flash Attention 2 Composable Kernel implementation.",
|
||||||
|
)
|
||||||
except ImportError as e:
|
except ImportError as e:
|
||||||
if major >= 8:
|
if major >= 8:
|
||||||
architecture_suffix = f"-{SYSTEM}"
|
architecture_suffix = f"-{SYSTEM}"
|
||||||
|
|
|
@ -4,19 +4,11 @@ from functools import lru_cache
|
||||||
import bitsandbytes as bnb
|
import bitsandbytes as bnb
|
||||||
import torch
|
import torch
|
||||||
from bitsandbytes.nn import Int8Params, Params4bit
|
from bitsandbytes.nn import Int8Params, Params4bit
|
||||||
from loguru import logger
|
from text_generation_server.utils.weights import UnquantizedWeight
|
||||||
from text_generation_server.utils.weights import Weight
|
|
||||||
|
|
||||||
|
|
||||||
@lru_cache(1)
|
|
||||||
def warn_deprecate_bnb():
|
|
||||||
logger.warning(
|
|
||||||
"Bitsandbytes 8bit is deprecated, using `eetq` is a drop-in replacement, and has much better performnce"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class BNBWeight(Weight):
|
class BNBWeight(UnquantizedWeight):
|
||||||
weight: torch.Tensor
|
weight: torch.Tensor
|
||||||
|
|
||||||
def get_linear(self, bias: torch.Tensor):
|
def get_linear(self, bias: torch.Tensor):
|
||||||
|
@ -82,7 +74,7 @@ class Linear8bitLt(torch.nn.Module):
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class BNBFP4Weight(Weight):
|
class BNBFP4Weight(UnquantizedWeight):
|
||||||
weight: torch.Tensor
|
weight: torch.Tensor
|
||||||
|
|
||||||
def get_linear(self, bias: torch.Tensor):
|
def get_linear(self, bias: torch.Tensor):
|
||||||
|
@ -90,7 +82,7 @@ class BNBFP4Weight(Weight):
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class BNBNF4Weight(Weight):
|
class BNBNF4Weight(UnquantizedWeight):
|
||||||
weight: torch.Tensor
|
weight: torch.Tensor
|
||||||
|
|
||||||
def get_linear(self, bias: torch.Tensor):
|
def get_linear(self, bias: torch.Tensor):
|
||||||
|
|
|
@ -2,11 +2,11 @@ from dataclasses import dataclass
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from EETQ import quant_weights, w8_a16_gemm
|
from EETQ import quant_weights, w8_a16_gemm
|
||||||
from text_generation_server.utils.weights import Weight
|
from text_generation_server.utils.weights import UnquantizedWeight
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class EETQWeight(Weight):
|
class EETQWeight(UnquantizedWeight):
|
||||||
weight: torch.Tensor
|
weight: torch.Tensor
|
||||||
|
|
||||||
def get_linear(self, bias: torch.Tensor):
|
def get_linear(self, bias: torch.Tensor):
|
||||||
|
|
|
@ -1,8 +1,29 @@
|
||||||
from dataclasses import dataclass
|
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Optional, Union, List
|
||||||
|
from loguru import logger
|
||||||
|
|
||||||
from text_generation_server.utils.import_utils import SYSTEM
|
from text_generation_server.utils.import_utils import SYSTEM
|
||||||
from text_generation_server.utils.weights import Weight
|
from text_generation_server.utils.weights import (
|
||||||
|
Weight,
|
||||||
|
WeightsLoader,
|
||||||
|
UnquantizedWeight,
|
||||||
|
Weights,
|
||||||
|
)
|
||||||
|
from text_generation_server.utils.log import log_master, log_once
|
||||||
|
|
||||||
|
FBGEMM_MM_AVAILABLE = False
|
||||||
|
FBGEMM_DYN_AVAILABLE = False
|
||||||
|
try:
|
||||||
|
import fbgemm_gpu.experimental.gen_ai
|
||||||
|
|
||||||
|
if SYSTEM == "cuda":
|
||||||
|
major, _ = torch.cuda.get_device_capability()
|
||||||
|
FBGEMM_MM_AVAILABLE = major == 9
|
||||||
|
FBGEMM_DYN_AVAILABLE = major >= 8
|
||||||
|
except (ImportError, ModuleNotFoundError):
|
||||||
|
log_master(logger.warning, "FBGEMM fp8 kernels are not installed.")
|
||||||
|
|
||||||
|
|
||||||
def get_fp8_linear() -> torch.nn.Module:
|
def get_fp8_linear() -> torch.nn.Module:
|
||||||
|
@ -21,12 +42,17 @@ def get_fp8_linear() -> torch.nn.Module:
|
||||||
return Fp8Linear
|
return Fp8Linear
|
||||||
|
|
||||||
|
|
||||||
def fp8_quantize(weight, qdtype=torch.float8_e4m3fn):
|
def fp8_quantize(weight, scale_upper_bound=None, qdtype=torch.float8_e4m3fn):
|
||||||
device = weight.device
|
if FBGEMM_DYN_AVAILABLE:
|
||||||
|
qweight, scale = torch.ops.fbgemm.quantize_fp8_per_row(
|
||||||
|
weight, bs=None, scale_ub=scale_upper_bound, output_dtype=qdtype
|
||||||
|
)
|
||||||
|
return qweight, scale
|
||||||
|
|
||||||
# weight, scale = quant_weights(weight, torch.int8, False)
|
# weight, scale = quant_weights(weight, torch.int8, False)
|
||||||
finfo = torch.finfo(qdtype)
|
finfo = torch.finfo(qdtype)
|
||||||
# Calculate the scale as dtype max divided by absmax
|
# Calculate the scale as dtype max divided by absmax
|
||||||
scale = finfo.max / weight.abs().max().clamp(min=1e-12)
|
scale = finfo.max / weight.abs().max().clamp(min=1e-12, max=scale_upper_bound)
|
||||||
# scale and clamp the tensor to bring it to
|
# scale and clamp the tensor to bring it to
|
||||||
# the representative range of float8 data type
|
# the representative range of float8 data type
|
||||||
# (as default cast is unsaturated)
|
# (as default cast is unsaturated)
|
||||||
|
@ -38,27 +64,166 @@ def fp8_quantize(weight, qdtype=torch.float8_e4m3fn):
|
||||||
return qweight, scale
|
return qweight, scale
|
||||||
|
|
||||||
|
|
||||||
|
class HybridFP8UnquantLoader(WeightsLoader):
|
||||||
|
"""Weight loader that loads FP8 and unquantized Torch tensors."""
|
||||||
|
|
||||||
|
def __init__(self, activation_scale_ub: Optional[float], to_fp8: bool):
|
||||||
|
self.activation_scale_ub = activation_scale_ub
|
||||||
|
self.to_fp8 = to_fp8
|
||||||
|
|
||||||
|
def get_weights(self, weights: "Weights", prefix: str):
|
||||||
|
w = weights.get_tensor(f"{prefix}.weight")
|
||||||
|
|
||||||
|
if w.dtype == torch.float8_e4m3fn:
|
||||||
|
# FP8 branch
|
||||||
|
scale = weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False)
|
||||||
|
return Fp8Weight(
|
||||||
|
weight=w,
|
||||||
|
weight_scale=scale,
|
||||||
|
activation_scale_ub=self.activation_scale_ub,
|
||||||
|
dtype=weights.dtype,
|
||||||
|
)
|
||||||
|
if self.to_fp8:
|
||||||
|
return Fp8Weight(weight=w, dtype=weights.dtype)
|
||||||
|
|
||||||
|
return UnquantizedWeight(w)
|
||||||
|
|
||||||
|
def get_weights_col_packed(
|
||||||
|
self,
|
||||||
|
weights: Weights,
|
||||||
|
prefix: str,
|
||||||
|
block_sizes: Union[int, List[int]],
|
||||||
|
):
|
||||||
|
w = weights.get_packed_sharded(
|
||||||
|
f"{prefix}.weight", dim=0, block_sizes=block_sizes
|
||||||
|
)
|
||||||
|
|
||||||
|
if w.dtype == torch.float8_e4m3fn:
|
||||||
|
# FP8 branch
|
||||||
|
scale = weights.get_packed_sharded(
|
||||||
|
f"{prefix}.weight_scale", dim=0, block_sizes=block_sizes, to_dtype=False
|
||||||
|
)
|
||||||
|
return Fp8Weight(
|
||||||
|
weight=w,
|
||||||
|
weight_scale=scale,
|
||||||
|
activation_scale_ub=self.activation_scale_ub,
|
||||||
|
dtype=weights.dtype,
|
||||||
|
)
|
||||||
|
if self.to_fp8:
|
||||||
|
return Fp8Weight(weight=w, dtype=weights.dtype)
|
||||||
|
|
||||||
|
return UnquantizedWeight(w)
|
||||||
|
|
||||||
|
def get_multi_weights_col(self, weights: "Weights", prefixes: List[str], dim: int):
|
||||||
|
w = [weights.get_sharded(f"{p}.weight", dim=0) for p in prefixes]
|
||||||
|
w = torch.cat(w, dim=dim)
|
||||||
|
|
||||||
|
# FP8 branch
|
||||||
|
if w.dtype == torch.float8_e4m3fn:
|
||||||
|
scale = [
|
||||||
|
weights.get_sharded(f"{p}.weight_scale", dim=0, to_dtype=False)
|
||||||
|
for p in prefixes
|
||||||
|
]
|
||||||
|
scale = torch.cat(scale, dim=0)
|
||||||
|
return Fp8Weight(
|
||||||
|
weight=w,
|
||||||
|
weight_scale=scale,
|
||||||
|
activation_scale_ub=self.activation_scale_ub,
|
||||||
|
dtype=weights.dtype,
|
||||||
|
)
|
||||||
|
if self.to_fp8:
|
||||||
|
return Fp8Weight(weight=w, dtype=weights.dtype)
|
||||||
|
|
||||||
|
return UnquantizedWeight(w)
|
||||||
|
|
||||||
|
def get_weights_row(self, weights: "Weights", prefix: str):
|
||||||
|
w = weights.get_sharded(f"{prefix}.weight", dim=1)
|
||||||
|
# FP8 branch
|
||||||
|
if w.dtype == torch.float8_e4m3fn:
|
||||||
|
scale = weights.get_sharded(f"{prefix}.weight_scale", dim=0, to_dtype=False)
|
||||||
|
return Fp8Weight(
|
||||||
|
weight=w,
|
||||||
|
weight_scale=scale,
|
||||||
|
activation_scale_ub=self.activation_scale_ub,
|
||||||
|
dtype=weights.dtype,
|
||||||
|
)
|
||||||
|
if self.to_fp8:
|
||||||
|
return Fp8Weight(weight=w, dtype=weights.dtype)
|
||||||
|
|
||||||
|
return UnquantizedWeight(w)
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class Fp8Weight(Weight):
|
class Fp8Weight(Weight):
|
||||||
weight: torch.Tensor
|
weight: torch.Tensor
|
||||||
|
dtype: torch.dtype
|
||||||
|
weight_scale: Optional[torch.Tensor] = None
|
||||||
|
activation_scale_ub: Optional[float] = None
|
||||||
|
|
||||||
def get_linear(self, bias: torch.Tensor):
|
def get_linear(self, bias: torch.Tensor):
|
||||||
return get_fp8_linear()(self.weight, bias)
|
if self.weight_scale is None:
|
||||||
|
return get_fp8_linear().from_unquant(self.weight, bias, self.dtype)
|
||||||
|
return get_fp8_linear().from_fp8(
|
||||||
|
self.weight, self.weight_scale, self.activation_scale_ub, bias, self.dtype
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class Fp8Linear(torch.nn.Module):
|
class Fp8Linear(torch.nn.Module):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
weight,
|
qweight,
|
||||||
|
scale,
|
||||||
|
scale_upper_bound,
|
||||||
bias,
|
bias,
|
||||||
|
dtype,
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.dtype = weight.dtype
|
self.dtype = dtype
|
||||||
self.qweight, self.scale = fp8_quantize(weight)
|
self.qweight = qweight
|
||||||
|
self.scale = scale
|
||||||
|
self.scale_upper_bound = (
|
||||||
|
torch.tensor(
|
||||||
|
[scale_upper_bound], dtype=torch.float32, device=qweight.device
|
||||||
|
)
|
||||||
|
if scale_upper_bound is not None
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
|
||||||
self.bias = bias if bias is not None else None
|
self.bias = bias if bias is not None else None
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_unquant(cls, weight, bias, dtype):
|
||||||
|
qweight, scale = fp8_quantize(weight)
|
||||||
|
return cls(
|
||||||
|
qweight=qweight, scale=scale, scale_upper_bound=None, bias=bias, dtype=dtype
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_fp8(cls, weight, scale, input_scale, bias, dtype):
|
||||||
|
return cls(
|
||||||
|
qweight=weight,
|
||||||
|
scale=scale,
|
||||||
|
scale_upper_bound=input_scale,
|
||||||
|
bias=bias,
|
||||||
|
dtype=dtype,
|
||||||
|
)
|
||||||
|
|
||||||
def forward(self, input: torch.Tensor) -> torch.Tensor:
|
def forward(self, input: torch.Tensor) -> torch.Tensor:
|
||||||
|
if FBGEMM_MM_AVAILABLE:
|
||||||
|
qinput, scale = fp8_quantize(
|
||||||
|
input, scale_upper_bound=self.scale_upper_bound
|
||||||
|
)
|
||||||
|
|
||||||
|
y = torch.ops.fbgemm.f8f8bf16_rowwise(
|
||||||
|
qinput,
|
||||||
|
self.qweight,
|
||||||
|
scale,
|
||||||
|
self.scale,
|
||||||
|
use_fast_accum=True,
|
||||||
|
bias=self.bias,
|
||||||
|
)
|
||||||
|
return y.to(self.dtype)
|
||||||
|
|
||||||
qinput, scale = fp8_quantize(input)
|
qinput, scale = fp8_quantize(input)
|
||||||
output, _ = torch._scaled_mm(
|
output, _ = torch._scaled_mm(
|
||||||
qinput,
|
qinput,
|
||||||
|
|
|
@ -9,11 +9,12 @@ from loguru import logger
|
||||||
|
|
||||||
from text_generation_server.layers.exl2 import Exl2Weight
|
from text_generation_server.layers.exl2 import Exl2Weight
|
||||||
from text_generation_server.layers.gptq import GPTQWeight
|
from text_generation_server.layers.gptq import GPTQWeight
|
||||||
|
from text_generation_server.utils.log import log_master
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from exllamav2_kernels import make_q_matrix, gemm_half_q_half
|
from exllamav2_kernels import make_q_matrix, gemm_half_q_half
|
||||||
except ImportError:
|
except ImportError:
|
||||||
logger.error("exllamav2_kernels not installed.")
|
log_master(logger.warning, "exllamav2_kernels not installed.")
|
||||||
raise
|
raise
|
||||||
|
|
||||||
# Dummy tensor to pass instead of g_idx since there is no way to pass "None" to a C++ extension
|
# Dummy tensor to pass instead of g_idx since there is no way to pass "None" to a C++ extension
|
||||||
|
|
|
@ -503,7 +503,8 @@ class GPTQMarlinFP8Linear(nn.Module):
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
weight: torch.Tensor,
|
qweight: torch.Tensor,
|
||||||
|
scale: torch.Tensor,
|
||||||
bias: Optional[torch.Tensor],
|
bias: Optional[torch.Tensor],
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
@ -513,7 +514,6 @@ class GPTQMarlinFP8Linear(nn.Module):
|
||||||
|
|
||||||
log_once(logger.info, "GPU does not support FP8, using Marlin FP8 kernel")
|
log_once(logger.info, "GPU does not support FP8, using Marlin FP8 kernel")
|
||||||
|
|
||||||
qweight, scale = fp8_quantize(weight)
|
|
||||||
scale = scale.to(torch.float16)
|
scale = scale.to(torch.float16)
|
||||||
qweight, scales = repack_fp8_for_marlin(qweight, scale)
|
qweight, scales = repack_fp8_for_marlin(qweight, scale)
|
||||||
|
|
||||||
|
@ -529,6 +529,15 @@ class GPTQMarlinFP8Linear(nn.Module):
|
||||||
out_features // 64 * 16, dtype=torch.int, device=qweight.device
|
out_features // 64 * 16, dtype=torch.int, device=qweight.device
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_unquant(cls, weight, bias, _dtype):
|
||||||
|
qweight, scale = fp8_quantize(weight)
|
||||||
|
return cls(qweight=qweight, scale=scale, bias=bias)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_fp8(cls, weight, scale, _input_scale, bias, _dtype):
|
||||||
|
return cls(qweight=weight, scale=scale, bias=bias)
|
||||||
|
|
||||||
def forward(self, A: torch.Tensor) -> torch.Tensor:
|
def forward(self, A: torch.Tensor) -> torch.Tensor:
|
||||||
assert marlin_kernels is not None
|
assert marlin_kernels is not None
|
||||||
|
|
||||||
|
|
|
@ -34,6 +34,7 @@ from text_generation_server.models.custom_modeling.t5_modeling import (
|
||||||
)
|
)
|
||||||
|
|
||||||
from text_generation_server.utils.import_utils import SYSTEM
|
from text_generation_server.utils.import_utils import SYSTEM
|
||||||
|
from text_generation_server.utils.log import log_master
|
||||||
|
|
||||||
# The flag below controls whether to allow TF32 on matmul. This flag defaults to False
|
# The flag below controls whether to allow TF32 on matmul. This flag defaults to False
|
||||||
# in PyTorch 1.12 and later.
|
# in PyTorch 1.12 and later.
|
||||||
|
@ -47,9 +48,7 @@ torch.set_grad_enabled(False)
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"Model",
|
"Model",
|
||||||
"BLOOMSharded",
|
|
||||||
"CausalLM",
|
"CausalLM",
|
||||||
"GalacticaSharded",
|
|
||||||
"Seq2SeqLM",
|
"Seq2SeqLM",
|
||||||
"get_model",
|
"get_model",
|
||||||
]
|
]
|
||||||
|
@ -125,7 +124,7 @@ try:
|
||||||
)
|
)
|
||||||
from text_generation_server.layers.attention import SUPPORTS_WINDOWING
|
from text_generation_server.layers.attention import SUPPORTS_WINDOWING
|
||||||
except ImportError as e:
|
except ImportError as e:
|
||||||
logger.warning(f"Could not import Flash Attention enabled models: {e}")
|
log_master(logger.warning, f"Could not import Flash Attention enabled models: {e}")
|
||||||
SUPPORTS_WINDOWING = False
|
SUPPORTS_WINDOWING = False
|
||||||
FLASH_ATTENTION = False
|
FLASH_ATTENTION = False
|
||||||
|
|
||||||
|
@ -137,7 +136,7 @@ MAMBA_AVAILABLE = True
|
||||||
try:
|
try:
|
||||||
from text_generation_server.models.mamba import Mamba
|
from text_generation_server.models.mamba import Mamba
|
||||||
except ImportError as e:
|
except ImportError as e:
|
||||||
logger.warning(f"Could not import Mamba: {e}")
|
log_master(logger.warning, f"Could not import Mamba: {e}")
|
||||||
MAMBA_AVAILABLE = False
|
MAMBA_AVAILABLE = False
|
||||||
|
|
||||||
if MAMBA_AVAILABLE:
|
if MAMBA_AVAILABLE:
|
||||||
|
@ -311,6 +310,12 @@ def get_model(
|
||||||
if quantize in ["awq", "exl2", "gptq", "marlin"]:
|
if quantize in ["awq", "exl2", "gptq", "marlin"]:
|
||||||
# These quantizers only work with float16 params.
|
# These quantizers only work with float16 params.
|
||||||
dtype = torch.float16
|
dtype = torch.float16
|
||||||
|
elif quantize == "fp8":
|
||||||
|
from text_generation_server.layers.fp8 import FBGEMM_MM_AVAILABLE
|
||||||
|
|
||||||
|
if FBGEMM_MM_AVAILABLE:
|
||||||
|
# fbgemm kernels are fp8xfp8->bf16
|
||||||
|
dtype = torch.bfloat16
|
||||||
else:
|
else:
|
||||||
# Keep it as default for now and let
|
# Keep it as default for now and let
|
||||||
# every model resolve their own default dtype.
|
# every model resolve their own default dtype.
|
||||||
|
@ -433,7 +438,9 @@ def get_model(
|
||||||
|
|
||||||
speculate = get_speculate()
|
speculate = get_speculate()
|
||||||
if speculate > 0:
|
if speculate > 0:
|
||||||
logger.info(f"Using speculation {method} with {speculate} input ids.")
|
log_master(
|
||||||
|
logger.info, f"Using speculation {method} with {speculate} input ids."
|
||||||
|
)
|
||||||
|
|
||||||
if model_type is None:
|
if model_type is None:
|
||||||
# TODO: fix how we determine model type for Mamba
|
# TODO: fix how we determine model type for Mamba
|
||||||
|
@ -448,10 +455,10 @@ def get_model(
|
||||||
if quantization_config is not None and quantize is None:
|
if quantization_config is not None and quantize is None:
|
||||||
method = quantization_config.get("quant_method", None)
|
method = quantization_config.get("quant_method", None)
|
||||||
if method in {"gptq", "awq", "exl2"}:
|
if method in {"gptq", "awq", "exl2"}:
|
||||||
logger.info(f"Auto selecting quantization method {method}")
|
log_master(logger.info, f"Auto selecting quantization method {method}")
|
||||||
quantize = method
|
quantize = method
|
||||||
else:
|
else:
|
||||||
logger.info(f"Unknown quantization method {method}")
|
log_master(logger.warning, f"Unknown quantization method {method}")
|
||||||
|
|
||||||
if quantize == "exl2" and sharded:
|
if quantize == "exl2" and sharded:
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
|
@ -593,7 +600,7 @@ def get_model(
|
||||||
)
|
)
|
||||||
except RuntimeError as e:
|
except RuntimeError as e:
|
||||||
# Lots of legacy models with various weight names.
|
# Lots of legacy models with various weight names.
|
||||||
logger.warning(f"Couldn't load flash gpt2 variant: {e}")
|
log_master(logger.warning, f"Couldn't load flash gpt2 variant: {e}")
|
||||||
return CausalLM.fallback(
|
return CausalLM.fallback(
|
||||||
model_id,
|
model_id,
|
||||||
revision,
|
revision,
|
||||||
|
|
|
@ -33,7 +33,6 @@ from text_generation_server.layers.attention import (
|
||||||
attention,
|
attention,
|
||||||
reshape_and_cache,
|
reshape_and_cache,
|
||||||
)
|
)
|
||||||
from text_generation_server.models.globals import FLASH_DECODING
|
|
||||||
from text_generation_server.layers import (
|
from text_generation_server.layers import (
|
||||||
TensorParallelRowLinear,
|
TensorParallelRowLinear,
|
||||||
TensorParallelColumnLinear,
|
TensorParallelColumnLinear,
|
||||||
|
@ -42,16 +41,15 @@ from text_generation_server.layers import (
|
||||||
TensorParallelMultiAdapterLinear,
|
TensorParallelMultiAdapterLinear,
|
||||||
TensorParallelAdapterRowLinear,
|
TensorParallelAdapterRowLinear,
|
||||||
)
|
)
|
||||||
from text_generation_server.layers.fp8 import Fp8Weight
|
|
||||||
from text_generation_server.layers.rotary import PositionRotaryEmbedding
|
from text_generation_server.layers.rotary import PositionRotaryEmbedding
|
||||||
from text_generation_server.layers.layernorm import (
|
from text_generation_server.layers.layernorm import (
|
||||||
FastRMSNorm,
|
FastRMSNorm,
|
||||||
)
|
)
|
||||||
from text_generation_server.utils.weights import (
|
from text_generation_server.utils.weights import (
|
||||||
DefaultWeightsLoader,
|
|
||||||
UnquantizedWeight,
|
UnquantizedWeight,
|
||||||
Weights,
|
Weights,
|
||||||
)
|
)
|
||||||
|
from text_generation_server.layers.fp8 import HybridFP8UnquantLoader
|
||||||
|
|
||||||
if SYSTEM == "rocm":
|
if SYSTEM == "rocm":
|
||||||
try:
|
try:
|
||||||
|
@ -113,12 +111,12 @@ def load_attention(config, prefix: str, weights, layer_id):
|
||||||
|
|
||||||
@contextmanager
|
@contextmanager
|
||||||
def no_fp8(weights: Weights):
|
def no_fp8(weights: Weights):
|
||||||
|
"""De-activate fp8 auto conversion for the duration of this context manager"""
|
||||||
weights_loader = weights.weights_loader
|
weights_loader = weights.weights_loader
|
||||||
if (
|
if isinstance(weights_loader, HybridFP8UnquantLoader) and weights_loader.to_fp8:
|
||||||
isinstance(weights_loader, DefaultWeightsLoader)
|
weights_loader = HybridFP8UnquantLoader(
|
||||||
and weights_loader.weight_class is Fp8Weight
|
weights_loader.activation_scale_ub, to_fp8=False
|
||||||
):
|
)
|
||||||
weights_loader = DefaultWeightsLoader(UnquantizedWeight)
|
|
||||||
|
|
||||||
with weights.use_loader(weights_loader):
|
with weights.use_loader(weights_loader):
|
||||||
yield
|
yield
|
||||||
|
@ -418,7 +416,22 @@ class FlashLlamaModel(torch.nn.Module):
|
||||||
process_group = weights.process_group
|
process_group = weights.process_group
|
||||||
self.tp_rank = process_group.rank()
|
self.tp_rank = process_group.rank()
|
||||||
self.tp_world_size = process_group.size()
|
self.tp_world_size = process_group.size()
|
||||||
self.layers = nn.ModuleList(
|
|
||||||
|
# Skip fp8 quant for first and last layers
|
||||||
|
self.layers = nn.ModuleList()
|
||||||
|
with no_fp8(weights):
|
||||||
|
self.layers.append(
|
||||||
|
FlashLlamaLayer(
|
||||||
|
index=0,
|
||||||
|
prefix=(
|
||||||
|
"model.layers.0" if not prefix else "{prefix}.model.layers.0"
|
||||||
|
),
|
||||||
|
config=config,
|
||||||
|
weights=weights,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
self.layers.extend(
|
||||||
[
|
[
|
||||||
FlashLlamaLayer(
|
FlashLlamaLayer(
|
||||||
index=layer_id,
|
index=layer_id,
|
||||||
|
@ -430,9 +443,26 @@ class FlashLlamaModel(torch.nn.Module):
|
||||||
config=config,
|
config=config,
|
||||||
weights=weights,
|
weights=weights,
|
||||||
)
|
)
|
||||||
for layer_id in range(config.num_hidden_layers)
|
# Skip first and last layers
|
||||||
|
for layer_id in range(1, config.num_hidden_layers - 1)
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
with no_fp8(weights):
|
||||||
|
last_layer_id = config.num_hidden_layers - 1
|
||||||
|
self.layers.append(
|
||||||
|
FlashLlamaLayer(
|
||||||
|
index=last_layer_id,
|
||||||
|
prefix=(
|
||||||
|
f"model.layers.{last_layer_id}"
|
||||||
|
if not prefix
|
||||||
|
else f"{prefix}.model.layers.{last_layer_id}"
|
||||||
|
),
|
||||||
|
config=config,
|
||||||
|
weights=weights,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
self.norm = FastRMSNorm.load(
|
self.norm = FastRMSNorm.load(
|
||||||
prefix="model.norm" if not prefix else f"{prefix}.model.norm",
|
prefix="model.norm" if not prefix else f"{prefix}.model.norm",
|
||||||
weights=weights,
|
weights=weights,
|
||||||
|
|
|
@ -23,14 +23,13 @@ from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE
|
||||||
from text_generation_server.utils.chunks import concat_text_chunks
|
from text_generation_server.utils.chunks import concat_text_chunks
|
||||||
from text_generation_server.utils.import_utils import SYSTEM
|
from text_generation_server.utils.import_utils import SYSTEM
|
||||||
from text_generation_server.models import Model
|
from text_generation_server.models import Model
|
||||||
|
from text_generation_server.utils.log import log_master
|
||||||
from text_generation_server.utils.tokens import batch_top_tokens
|
from text_generation_server.utils.tokens import batch_top_tokens
|
||||||
from text_generation_server.utils.dist import RANK
|
|
||||||
from text_generation_server.utils.speculate import get_speculate
|
from text_generation_server.utils.speculate import get_speculate
|
||||||
from text_generation_server.utils import (
|
from text_generation_server.utils import (
|
||||||
initialize_torch_distributed,
|
initialize_torch_distributed,
|
||||||
weight_files,
|
weight_files,
|
||||||
Weights,
|
Weights,
|
||||||
hub,
|
|
||||||
)
|
)
|
||||||
from text_generation_server.models.types import (
|
from text_generation_server.models.types import (
|
||||||
Batch,
|
Batch,
|
||||||
|
@ -1156,31 +1155,36 @@ class FlashCausalLM(Model):
|
||||||
f"tunableop_{MODEL_ID.replace('/', '-')}_tp{self.world_size}_rank{self.rank}.csv",
|
f"tunableop_{MODEL_ID.replace('/', '-')}_tp{self.world_size}_rank{self.rank}.csv",
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.info(
|
log_master(
|
||||||
f"PyTorch TunableOp (https://github.com/fxmarty/pytorch/tree/2.3-patched/aten/src/ATen/cuda/tunable) is enabled. The warmup may take several minutes, picking the ROCm optimal matrix multiplication kernel for the target lengths {', '.join([str(seqlen) for seqlen in tuning_sequences])}, with typical 5-8% latency improvement for small sequence lengths. The picked GEMMs are saved in the file {tunableop_filepath}. To disable TunableOp, please launch TGI with `PYTORCH_TUNABLEOP_ENABLED=0`."
|
logger.info,
|
||||||
|
f"PyTorch TunableOp (https://github.com/fxmarty/pytorch/tree/2.3-patched/aten/src/ATen/cuda/tunable) is enabled. The warmup may take several minutes, picking the ROCm optimal matrix multiplication kernel for the target lengths {', '.join([str(seqlen) for seqlen in tuning_sequences])}, with typical 5-8% latency improvement for small sequence lengths. The picked GEMMs are saved in the file {tunableop_filepath}. To disable TunableOp, please launch TGI with `PYTORCH_TUNABLEOP_ENABLED=0`.",
|
||||||
)
|
)
|
||||||
|
|
||||||
if os.path.isfile(tunableop_filepath):
|
if os.path.isfile(tunableop_filepath):
|
||||||
logger.info(
|
log_master(
|
||||||
f"The file {tunableop_filepath} already exists and will be reused."
|
logger.info,
|
||||||
|
f"The file {tunableop_filepath} already exists and will be reused.",
|
||||||
)
|
)
|
||||||
torch.cuda.tunable.read_file(tunableop_filepath)
|
torch.cuda.tunable.read_file(tunableop_filepath)
|
||||||
|
|
||||||
os.makedirs(HUGGINGFACE_HUB_CACHE, exist_ok=True)
|
os.makedirs(HUGGINGFACE_HUB_CACHE, exist_ok=True)
|
||||||
|
|
||||||
for seqlen in tuning_sequences:
|
for seqlen in tuning_sequences:
|
||||||
logger.info(f"Warming up TunableOp for seqlen={seqlen}")
|
log_master(logger.info, f"Warming up TunableOp for seqlen={seqlen}")
|
||||||
self.tunableop_warmup(seqlen)
|
self.tunableop_warmup(seqlen)
|
||||||
torch.cuda.tunable.write_file(tunableop_filepath)
|
torch.cuda.tunable.write_file(tunableop_filepath)
|
||||||
torch.cuda.tunable.tuning_enable(False)
|
torch.cuda.tunable.tuning_enable(False)
|
||||||
else:
|
else:
|
||||||
logger.info(
|
log_master(
|
||||||
"PyTorch ROCm TunableOp (https://github.com/pytorch/pytorch/tree/main/aten/src/ATen/cuda/tunable) is disabled. TunableOp brings an additional 5-8% latency improvement for small sequence lengths but requires a warmup. If necessary, please use the environment variable PYTORCH_TUNABLEOP_ENABLED=1 to enable TunableOp."
|
logger.info,
|
||||||
|
"PyTorch ROCm TunableOp (https://github.com/pytorch/pytorch/tree/main/aten/src/ATen/cuda/tunable) is disabled. TunableOp brings an additional 5-8% latency improvement for small sequence lengths but requires a warmup. If necessary, please use the environment variable PYTORCH_TUNABLEOP_ENABLED=1 to enable TunableOp.",
|
||||||
)
|
)
|
||||||
|
|
||||||
if CUDA_GRAPHS:
|
if CUDA_GRAPHS:
|
||||||
try:
|
try:
|
||||||
logger.info(f"Cuda Graphs are enabled for sizes {CUDA_GRAPHS}")
|
log_master(
|
||||||
|
logger.info, f"Cuda Graphs are enabled for sizes {CUDA_GRAPHS}"
|
||||||
|
)
|
||||||
# Warmup cuda graphs
|
# Warmup cuda graphs
|
||||||
for bs in CUDA_GRAPHS:
|
for bs in CUDA_GRAPHS:
|
||||||
if self.speculate is None or self.speculate + 1 <= bs:
|
if self.speculate is None or self.speculate + 1 <= bs:
|
||||||
|
@ -1188,7 +1192,9 @@ class FlashCausalLM(Model):
|
||||||
except torch.cuda.OutOfMemoryError:
|
except torch.cuda.OutOfMemoryError:
|
||||||
logger.exception(f"Decode cuda graph warmup failed")
|
logger.exception(f"Decode cuda graph warmup failed")
|
||||||
else:
|
else:
|
||||||
logger.info(f"Cuda Graphs are disabled (CUDA_GRAPHS={CUDA_GRAPHS}).")
|
log_master(
|
||||||
|
logger.info, f"Cuda Graphs are disabled (CUDA_GRAPHS={CUDA_GRAPHS})."
|
||||||
|
)
|
||||||
|
|
||||||
return int(num_blocks * BLOCK_SIZE)
|
return int(num_blocks * BLOCK_SIZE)
|
||||||
|
|
||||||
|
@ -1540,8 +1546,7 @@ class FlashCausalLM(Model):
|
||||||
left = 0
|
left = 0
|
||||||
|
|
||||||
if n_accepted_ids > 1:
|
if n_accepted_ids > 1:
|
||||||
if RANK == 0:
|
log_master(logger.debug, f"Speculated ids {n_accepted_ids - 1}")
|
||||||
logger.debug(f"Speculated ids {n_accepted_ids - 1}")
|
|
||||||
|
|
||||||
current_stopped = False
|
current_stopped = False
|
||||||
for j in range(index, index + n_accepted_ids):
|
for j in range(index, index + n_accepted_ids):
|
||||||
|
|
|
@ -1,15 +1,16 @@
|
||||||
import torch
|
import torch
|
||||||
import os
|
import os
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
from typing import Dict
|
from typing import Dict, Optional
|
||||||
|
|
||||||
|
from text_generation_server.utils.log import log_master
|
||||||
|
|
||||||
MEM_POOL = torch.cuda.graph_pool_handle() if torch.cuda.is_available() else None
|
MEM_POOL = torch.cuda.graph_pool_handle() if torch.cuda.is_available() else None
|
||||||
# This is overridden by the cli
|
# This is overridden by the cli
|
||||||
FLASH_DECODING = os.getenv("FLASH_DECODING") in {"1", "true", "True"}
|
FLASH_DECODING = os.getenv("FLASH_DECODING") in {"1", "true", "True"}
|
||||||
BLOCK_SIZE: int = 256 if FLASH_DECODING else 16
|
BLOCK_SIZE: int = 256 if FLASH_DECODING else 16
|
||||||
if FLASH_DECODING:
|
if FLASH_DECODING:
|
||||||
logger.info("Using FLASH_DECODING")
|
log_master(logger.info, "Using FLASH_DECODING")
|
||||||
|
|
||||||
|
|
||||||
cuda_graphs = os.getenv("CUDA_GRAPHS")
|
cuda_graphs = os.getenv("CUDA_GRAPHS")
|
||||||
if cuda_graphs is not None:
|
if cuda_graphs is not None:
|
||||||
|
@ -26,11 +27,9 @@ else:
|
||||||
if cuda_graphs is not None:
|
if cuda_graphs is not None:
|
||||||
cuda_graphs.sort(reverse=True)
|
cuda_graphs.sort(reverse=True)
|
||||||
|
|
||||||
|
|
||||||
CUDA_GRAPHS = cuda_graphs
|
CUDA_GRAPHS = cuda_graphs
|
||||||
|
|
||||||
# This is overridden at model loading.
|
# This is overridden at model loading.
|
||||||
global MODEL_ID
|
|
||||||
MODEL_ID = None
|
MODEL_ID = None
|
||||||
|
|
||||||
|
|
||||||
|
@ -41,8 +40,7 @@ def set_model_id(model_id: str):
|
||||||
|
|
||||||
# NOTE: eventually we should move this into the router and pass back the
|
# NOTE: eventually we should move this into the router and pass back the
|
||||||
# index in all cases.
|
# index in all cases.
|
||||||
global ADAPTER_TO_INDEX
|
ADAPTER_TO_INDEX: Optional[Dict[str, int]] = None
|
||||||
ADAPTER_TO_INDEX: Dict[str, int] = None
|
|
||||||
|
|
||||||
|
|
||||||
def set_adapter_to_index(adapter_to_index: Dict[str, int]):
|
def set_adapter_to_index(adapter_to_index: Dict[str, int]):
|
||||||
|
|
|
@ -15,6 +15,7 @@ from text_generation_server.utils.adapter import (
|
||||||
AdapterParameters,
|
AdapterParameters,
|
||||||
AdapterSource,
|
AdapterSource,
|
||||||
)
|
)
|
||||||
|
from text_generation_server.utils.log import log_master
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
|
|
||||||
|
|
||||||
|
@ -204,8 +205,9 @@ class Model(ABC):
|
||||||
f"order to use the dynamic adapter loading feature."
|
f"order to use the dynamic adapter loading feature."
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.info(
|
log_master(
|
||||||
f"Loading adapter weights into model: {','.join(adapter_parameters.adapter_ids)}"
|
logger.info,
|
||||||
|
f"Loading adapter weights into model: {','.join(adapter_parameters.adapter_ids)}",
|
||||||
)
|
)
|
||||||
weight_names = tuple([v[0] for v in self.target_to_layer.values()])
|
weight_names = tuple([v[0] for v in self.target_to_layer.values()])
|
||||||
(
|
(
|
||||||
|
@ -240,8 +242,9 @@ class Model(ABC):
|
||||||
layer_weights.add_adapter(adapter_index, adapter_weights)
|
layer_weights.add_adapter(adapter_index, adapter_weights)
|
||||||
|
|
||||||
if len(unused_weight_names) > 0:
|
if len(unused_weight_names) > 0:
|
||||||
logger.warning(
|
log_master(
|
||||||
f"{','.join(adapter_parameters.adapter_ids)} unused adapter weights: {unused_weight_names}"
|
logger.warning,
|
||||||
|
f"{','.join(adapter_parameters.adapter_ids)} unused adapter weights: {unused_weight_names}",
|
||||||
)
|
)
|
||||||
|
|
||||||
if adapter_tokenizer is not None:
|
if adapter_tokenizer is not None:
|
||||||
|
|
|
@ -1,4 +1,3 @@
|
||||||
from itertools import repeat
|
|
||||||
import torch
|
import torch
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from io import BytesIO
|
from io import BytesIO
|
||||||
|
@ -13,6 +12,7 @@ from text_generation_server.models.flash_causal_lm import (
|
||||||
FlashCausalLMBatch,
|
FlashCausalLMBatch,
|
||||||
FlashCausalLM,
|
FlashCausalLM,
|
||||||
)
|
)
|
||||||
|
from text_generation_server.utils.log import log_master
|
||||||
from transformers import AutoProcessor
|
from transformers import AutoProcessor
|
||||||
|
|
||||||
tracer = trace.get_tracer(__name__)
|
tracer = trace.get_tracer(__name__)
|
||||||
|
@ -56,8 +56,9 @@ def image_text_replacement(processor, image_input, config, image_id: int) -> str
|
||||||
num_features = get_number_of_features(height, width, config)
|
num_features = get_number_of_features(height, width, config)
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
|
|
||||||
logger.info(
|
log_master(
|
||||||
f"Found {num_features} features in image of resolution {height}x{width}"
|
logger.info,
|
||||||
|
f"Found {num_features} features in image of resolution {height}x{width}",
|
||||||
)
|
)
|
||||||
return "<image>" * num_features
|
return "<image>" * num_features
|
||||||
|
|
||||||
|
|
|
@ -56,7 +56,7 @@ def initialize_torch_distributed():
|
||||||
backend = "nccl"
|
backend = "nccl"
|
||||||
options = ProcessGroupNCCL.Options()
|
options = ProcessGroupNCCL.Options()
|
||||||
options.is_high_priority_stream = True
|
options.is_high_priority_stream = True
|
||||||
options._timeout = timedelta(seconds=60)
|
options._timeout = timedelta(seconds=120)
|
||||||
else:
|
else:
|
||||||
backend = "gloo"
|
backend = "gloo"
|
||||||
options = None
|
options = None
|
||||||
|
@ -76,7 +76,7 @@ def initialize_torch_distributed():
|
||||||
backend="ccl",
|
backend="ccl",
|
||||||
world_size=WORLD_SIZE,
|
world_size=WORLD_SIZE,
|
||||||
rank=RANK,
|
rank=RANK,
|
||||||
timeout=timedelta(seconds=60),
|
timeout=timedelta(seconds=120),
|
||||||
pg_options=options,
|
pg_options=options,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
|
@ -84,7 +84,7 @@ def initialize_torch_distributed():
|
||||||
backend=backend,
|
backend=backend,
|
||||||
world_size=WORLD_SIZE,
|
world_size=WORLD_SIZE,
|
||||||
rank=RANK,
|
rank=RANK,
|
||||||
timeout=timedelta(seconds=60),
|
timeout=timedelta(seconds=120),
|
||||||
pg_options=options,
|
pg_options=options,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
|
|
|
@ -1,6 +1,15 @@
|
||||||
from functools import lru_cache
|
from functools import lru_cache
|
||||||
|
from text_generation_server.utils.dist import RANK
|
||||||
|
|
||||||
|
|
||||||
@lru_cache(10)
|
@lru_cache(10)
|
||||||
def log_once(log, msg: str):
|
def log_once(log, msg: str, master=True):
|
||||||
|
if master:
|
||||||
|
log_master(log, msg)
|
||||||
|
else:
|
||||||
|
log(msg)
|
||||||
|
|
||||||
|
|
||||||
|
def log_master(log, msg: str):
|
||||||
|
if RANK == 0:
|
||||||
log(msg)
|
log(msg)
|
||||||
|
|
|
@ -11,6 +11,7 @@ from text_generation_server.utils.weights import (
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# TODO: Split this config to have a single config type per quant method
|
||||||
@dataclass
|
@dataclass
|
||||||
class _QuantizerConfig:
|
class _QuantizerConfig:
|
||||||
bits: int
|
bits: int
|
||||||
|
@ -21,6 +22,11 @@ class _QuantizerConfig:
|
||||||
sym: bool
|
sym: bool
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class _FP8QuantizerConfig:
|
||||||
|
activation_scale_ub: float
|
||||||
|
|
||||||
|
|
||||||
# We should probably do this with Pytantic JSON deserialization,
|
# We should probably do this with Pytantic JSON deserialization,
|
||||||
# but for now we'll stay close to the old _set_gptq_params.
|
# but for now we'll stay close to the old _set_gptq_params.
|
||||||
def _get_quantizer_config(model_id, revision):
|
def _get_quantizer_config(model_id, revision):
|
||||||
|
@ -39,6 +45,13 @@ def _get_quantizer_config(model_id, revision):
|
||||||
filename = hf_hub_download(model_id, filename=filename, revision=revision)
|
filename = hf_hub_download(model_id, filename=filename, revision=revision)
|
||||||
with open(filename, "r") as f:
|
with open(filename, "r") as f:
|
||||||
data = json.load(f)
|
data = json.load(f)
|
||||||
|
|
||||||
|
# FP8 config
|
||||||
|
if data["quantization_config"]["quant_method"] == "fbgemm_fp8":
|
||||||
|
return _FP8QuantizerConfig(
|
||||||
|
activation_scale_ub=data["quantization_config"]["activation_scale_ub"]
|
||||||
|
)
|
||||||
|
|
||||||
bits = data["quantization_config"]["bits"]
|
bits = data["quantization_config"]["bits"]
|
||||||
groupsize = data["quantization_config"]["group_size"]
|
groupsize = data["quantization_config"]["group_size"]
|
||||||
# Order is important here, desc_act is missing on some real models
|
# Order is important here, desc_act is missing on some real models
|
||||||
|
@ -99,6 +112,12 @@ def get_loader(
|
||||||
if quantize in {"awq", "gptq"}:
|
if quantize in {"awq", "gptq"}:
|
||||||
from text_generation_server.layers.gptq import GPTQWeightsLoader
|
from text_generation_server.layers.gptq import GPTQWeightsLoader
|
||||||
|
|
||||||
|
# TODO: improve check once we have one config type per quantize value
|
||||||
|
if not isinstance(quantizer_config, _QuantizerConfig):
|
||||||
|
raise ValueError(
|
||||||
|
f"Quantize is set to `{quantize}` but received a `{quantizer_config.__class__.__name__}` config."
|
||||||
|
)
|
||||||
|
|
||||||
return GPTQWeightsLoader(
|
return GPTQWeightsLoader(
|
||||||
bits=quantizer_config.bits,
|
bits=quantizer_config.bits,
|
||||||
desc_act=quantizer_config.desc_act,
|
desc_act=quantizer_config.desc_act,
|
||||||
|
@ -127,18 +146,28 @@ def get_loader(
|
||||||
from text_generation_server.layers.exl2 import Exl2WeightsLoader
|
from text_generation_server.layers.exl2 import Exl2WeightsLoader
|
||||||
|
|
||||||
return Exl2WeightsLoader()
|
return Exl2WeightsLoader()
|
||||||
elif quantize == "fp8":
|
|
||||||
from text_generation_server.layers.fp8 import Fp8Weight
|
|
||||||
|
|
||||||
return DefaultWeightsLoader(Fp8Weight)
|
|
||||||
elif quantize == "marlin":
|
elif quantize == "marlin":
|
||||||
from text_generation_server.layers.marlin import MarlinWeightsLoader
|
from text_generation_server.layers.marlin import MarlinWeightsLoader
|
||||||
|
|
||||||
|
# TODO: improve check once we have one config type per quantize value
|
||||||
|
if not isinstance(quantizer_config, _QuantizerConfig):
|
||||||
|
raise ValueError(
|
||||||
|
f"Quantize is set to `{quantize}` but received a `{quantizer_config.__class__.__name__}` config."
|
||||||
|
)
|
||||||
|
|
||||||
return MarlinWeightsLoader(
|
return MarlinWeightsLoader(
|
||||||
bits=quantizer_config.bits,
|
bits=quantizer_config.bits,
|
||||||
is_marlin_24=quantizer_config.checkpoint_format == "marlin_24",
|
is_marlin_24=quantizer_config.checkpoint_format == "marlin_24",
|
||||||
)
|
)
|
||||||
elif quantize is None:
|
elif quantize == "fp8" or quantize is None:
|
||||||
return DefaultWeightsLoader(UnquantizedWeight)
|
from text_generation_server.layers.fp8 import HybridFP8UnquantLoader
|
||||||
|
|
||||||
|
# Since the default for the quantize config is _QuantizerConfig,
|
||||||
|
# we need to add this check to not get an attribute error
|
||||||
|
activation_scale_ub = None
|
||||||
|
if isinstance(quantizer_config, _FP8QuantizerConfig):
|
||||||
|
activation_scale_ub = quantizer_config.activation_scale_ub
|
||||||
|
|
||||||
|
return HybridFP8UnquantLoader(activation_scale_ub, to_fp8=quantize == "fp8")
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unknown quantization method: {quantize}")
|
raise ValueError(f"Unknown quantization method: {quantize}")
|
||||||
|
|
|
@ -1,12 +1,12 @@
|
||||||
|
import torch
|
||||||
|
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
from dataclasses import dataclass
|
|
||||||
from enum import Enum, auto
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Dict, List, Optional, Union
|
from typing import Dict, List, Optional, Union, Type
|
||||||
|
|
||||||
import torch
|
|
||||||
from safetensors import safe_open
|
from safetensors import safe_open
|
||||||
|
from dataclasses import dataclass
|
||||||
|
|
||||||
from text_generation_server.utils.import_utils import SYSTEM
|
from text_generation_server.utils.import_utils import SYSTEM
|
||||||
|
|
||||||
|
|
||||||
|
@ -84,7 +84,7 @@ class Weight(ABC):
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class UnquantizedWeight:
|
class UnquantizedWeight(Weight):
|
||||||
weight: torch.Tensor
|
weight: torch.Tensor
|
||||||
|
|
||||||
def get_linear(self, bias: torch.Tensor):
|
def get_linear(self, bias: torch.Tensor):
|
||||||
|
@ -99,7 +99,7 @@ class UnquantizedWeight:
|
||||||
class DefaultWeightsLoader(WeightsLoader):
|
class DefaultWeightsLoader(WeightsLoader):
|
||||||
"""Weight loader that loads (unquantized) Torch tensors."""
|
"""Weight loader that loads (unquantized) Torch tensors."""
|
||||||
|
|
||||||
def __init__(self, weight_class):
|
def __init__(self, weight_class: Type[UnquantizedWeight]):
|
||||||
"""Create a loader. Weights will be wrapped using the given `weights_class`,
|
"""Create a loader. Weights will be wrapped using the given `weights_class`,
|
||||||
normally this will be `UnquantizedWeight`, but a quantizer-specific class
|
normally this will be `UnquantizedWeight`, but a quantizer-specific class
|
||||||
such as `Fp8Weight` can be used to quantize the weights during loading.
|
such as `Fp8Weight` can be used to quantize the weights during loading.
|
||||||
|
@ -208,20 +208,29 @@ class Weights:
|
||||||
def get_shape(self, tensor_name: str):
|
def get_shape(self, tensor_name: str):
|
||||||
return self._get_slice(tensor_name).get_shape()
|
return self._get_slice(tensor_name).get_shape()
|
||||||
|
|
||||||
def get_tensor(self, tensor_name: str, to_device=True):
|
def get_tensor(self, tensor_name: str, to_device=True, to_dtype=True):
|
||||||
filename, tensor_name = self.get_filename(tensor_name)
|
filename, tensor_name = self.get_filename(tensor_name)
|
||||||
f = self._get_handle(filename)
|
f = self._get_handle(filename)
|
||||||
tensor = f.get_tensor(tensor_name)
|
tensor = f.get_tensor(tensor_name)
|
||||||
# Special case for gptq which shouldn't convert
|
# Special case for gptq which shouldn't convert
|
||||||
# u4 which are disguised as int32. Exl2 uses int16
|
# u4 which are disguised as int32. Exl2 uses int16
|
||||||
# as well.
|
# as well. FP8 uses torch.float8_e4m3fn
|
||||||
if tensor.dtype not in [torch.int16, torch.int32, torch.int64]:
|
if (
|
||||||
|
tensor.dtype
|
||||||
|
not in [
|
||||||
|
torch.float8_e4m3fn,
|
||||||
|
torch.int16,
|
||||||
|
torch.int32,
|
||||||
|
torch.int64,
|
||||||
|
]
|
||||||
|
and to_dtype
|
||||||
|
):
|
||||||
tensor = tensor.to(dtype=self.dtype)
|
tensor = tensor.to(dtype=self.dtype)
|
||||||
if to_device:
|
if to_device:
|
||||||
tensor = tensor.to(device=self.device)
|
tensor = tensor.to(device=self.device)
|
||||||
return tensor
|
return tensor
|
||||||
|
|
||||||
def get_partial_sharded(self, tensor_name: str, dim: int):
|
def get_partial_sharded(self, tensor_name: str, dim: int, to_dtype=True):
|
||||||
filename, tensor_name = self.get_filename(tensor_name)
|
filename, tensor_name = self.get_filename(tensor_name)
|
||||||
f = self._get_handle(filename)
|
f = self._get_handle(filename)
|
||||||
slice_ = f.get_slice(tensor_name)
|
slice_ = f.get_slice(tensor_name)
|
||||||
|
@ -241,12 +250,16 @@ class Weights:
|
||||||
raise NotImplementedError("Let's make that generic when needed")
|
raise NotImplementedError("Let's make that generic when needed")
|
||||||
# Special case for gptq which shouldn't convert
|
# Special case for gptq which shouldn't convert
|
||||||
# u4 which are disguised as int32. exl2 uses int16.
|
# u4 which are disguised as int32. exl2 uses int16.
|
||||||
if tensor.dtype not in (torch.int16, torch.int32):
|
# FP8 uses torch.float8_e4m3fn.
|
||||||
|
if (
|
||||||
|
tensor.dtype not in (torch.float8_e4m3fn, torch.int16, torch.int32)
|
||||||
|
and to_dtype
|
||||||
|
):
|
||||||
tensor = tensor.to(dtype=self.dtype)
|
tensor = tensor.to(dtype=self.dtype)
|
||||||
tensor = tensor.to(device=self.device)
|
tensor = tensor.to(device=self.device)
|
||||||
return tensor
|
return tensor
|
||||||
|
|
||||||
def get_sharded(self, tensor_name: str, dim: int):
|
def get_sharded(self, tensor_name: str, dim: int, to_dtype=True):
|
||||||
filename, tensor_name = self.get_filename(tensor_name)
|
filename, tensor_name = self.get_filename(tensor_name)
|
||||||
f = self._get_handle(filename)
|
f = self._get_handle(filename)
|
||||||
slice_ = f.get_slice(tensor_name)
|
slice_ = f.get_slice(tensor_name)
|
||||||
|
@ -255,10 +268,14 @@ class Weights:
|
||||||
assert (
|
assert (
|
||||||
size % world_size == 0
|
size % world_size == 0
|
||||||
), f"The choosen size {size} is not compatible with sharding on {world_size} shards"
|
), f"The choosen size {size} is not compatible with sharding on {world_size} shards"
|
||||||
return self.get_partial_sharded(tensor_name, dim)
|
return self.get_partial_sharded(tensor_name, dim, to_dtype=to_dtype)
|
||||||
|
|
||||||
def get_packed_sharded(
|
def get_packed_sharded(
|
||||||
self, tensor_name: str, dim: int, block_sizes: Union[int, List[int]]
|
self,
|
||||||
|
tensor_name: str,
|
||||||
|
dim: int,
|
||||||
|
block_sizes: Union[int, List[int]],
|
||||||
|
to_dtype=True,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
"""
|
"""
|
||||||
Get a shard from a tensor that packs multiple tensors.
|
Get a shard from a tensor that packs multiple tensors.
|
||||||
|
@ -304,7 +321,16 @@ class Weights:
|
||||||
tensor = tensor.to(device=self.device)
|
tensor = tensor.to(device=self.device)
|
||||||
|
|
||||||
# Avoid casting quantizer dtypes.
|
# Avoid casting quantizer dtypes.
|
||||||
if tensor.dtype not in [torch.int16, torch.int32, torch.int64]:
|
if (
|
||||||
|
tensor.dtype
|
||||||
|
not in [
|
||||||
|
torch.float8_e4m3fn,
|
||||||
|
torch.int16,
|
||||||
|
torch.int32,
|
||||||
|
torch.int64,
|
||||||
|
]
|
||||||
|
and to_dtype
|
||||||
|
):
|
||||||
tensor = tensor.to(dtype=self.dtype)
|
tensor = tensor.to(dtype=self.dtype)
|
||||||
|
|
||||||
return tensor
|
return tensor
|
||||||
|
|
Loading…
Reference in New Issue