feat(fp8): use fbgemm kernels and load fp8 weights directly (#2248)
* feat(fp8): add support for fbgemm * allow loading fp8 weights directly * update outlines * fix makefile * build fbgemm * avoid circular import and fix dockerfile * add default dtype * refactored weights loader * fix auto conversion * fix quantization config parsing * force new nccl on install * missing get_weights implementation * increase timeout
This commit is contained in:
parent
e5c1d6d611
commit
53ec0b790b
|
@ -183,4 +183,3 @@ jobs:
|
|||
export HF_TOKEN=${{ secrets.HF_TOKEN }}
|
||||
echo $DOCKER_IMAGE
|
||||
pytest -s -vv integration-tests ${PYTEST_FLAGS}
|
||||
|
||||
|
|
17
Dockerfile
17
Dockerfile
|
@ -161,6 +161,17 @@ COPY server/custom_kernels/ .
|
|||
# Build specific version of transformers
|
||||
RUN python setup.py build
|
||||
|
||||
# Build FBGEMM CUDA kernels
|
||||
FROM kernel-builder AS fbgemm-builder
|
||||
|
||||
WORKDIR /usr/src
|
||||
|
||||
COPY server/Makefile-fbgemm Makefile
|
||||
COPY server/fbgemm_remove_unused.patch fbgemm_remove_unused.patch
|
||||
COPY server/fix_torch90a.sh fix_torch90a.sh
|
||||
|
||||
RUN make build-fbgemm
|
||||
|
||||
# Build vllm CUDA kernels
|
||||
FROM kernel-builder AS vllm-builder
|
||||
|
||||
|
@ -225,10 +236,10 @@ COPY --from=eetq-kernels-builder /usr/src/eetq/build/lib.linux-x86_64-cpython-31
|
|||
# Copy build artifacts from marlin kernels builder
|
||||
COPY --from=marlin-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
|
||||
COPY --from=lorax-punica-builder /usr/src/lorax-punica/server/punica_kernels/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
|
||||
|
||||
# Copy builds artifacts from vllm builder
|
||||
# Copy build artifacts from fbgemm builder
|
||||
COPY --from=fbgemm-builder /usr/src/fbgemm/fbgemm_gpu/_skbuild/linux-x86_64-3.10/cmake-install /opt/conda/lib/python3.10/site-packages
|
||||
# Copy build artifacts from vllm builder
|
||||
COPY --from=vllm-builder /usr/src/vllm/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
|
||||
|
||||
# Copy build artifacts from mamba builder
|
||||
COPY --from=mamba-builder /usr/src/mamba/build/lib.linux-x86_64-cpython-310/ /opt/conda/lib/python3.10/site-packages
|
||||
COPY --from=mamba-builder /usr/src/causal-conv1d/build/lib.linux-x86_64-cpython-310/ /opt/conda/lib/python3.10/site-packages
|
||||
|
|
|
@ -5,6 +5,7 @@ include Makefile-awq
|
|||
include Makefile-eetq
|
||||
include Makefile-selective-scan
|
||||
include Makefile-lorax-punica
|
||||
include Makefile-fbgemm
|
||||
|
||||
unit-tests:
|
||||
pytest -s -vv -m "not private" tests
|
||||
|
@ -27,8 +28,9 @@ install-server: gen-server
|
|||
install: install-cuda
|
||||
echo "Installed server"
|
||||
|
||||
install-cuda: install-server install-flash-attention-v2-cuda install-vllm-cuda install-flash-attention
|
||||
install-cuda: install-server install-flash-attention-v2-cuda install-vllm-cuda install-flash-attention install-fbgemm
|
||||
pip install -e ".[bnb]"
|
||||
pip install nvidia-nccl-cu12==2.22.3
|
||||
|
||||
install-rocm: install-server install-flash-attention-v2-rocm install-vllm-rocm
|
||||
|
||||
|
@ -36,5 +38,6 @@ run-dev:
|
|||
SAFETENSORS_FAST_GPU=1 python -m torch.distributed.run --nproc_per_node=2 text_generation_server/cli.py serve bigscience/bloom-560m --sharded
|
||||
|
||||
export-requirements:
|
||||
poetry export -o requirements_cuda.txt --without-hashes --with cuda
|
||||
poetry export -o requirements_cuda.txt --without-hashes
|
||||
poetry export -o requirements_rocm.txt --without-hashes
|
||||
poetry export -o requirements_intel.txt --without-hashes
|
||||
|
|
|
@ -0,0 +1,15 @@
|
|||
fbgemm_commit := 9cf0429b726931cfab72b8264730bea682f32fca
|
||||
|
||||
build-fbgemm:
|
||||
chmod +x fix_torch90a.sh && ./fix_torch90a.sh && \
|
||||
git clone https://github.com/pytorch/FBGEMM.git fbgemm && \
|
||||
cp fbgemm_remove_unused.patch fbgemm && \
|
||||
cd fbgemm && git fetch && git checkout $(fbgemm_commit) && git apply fbgemm_remove_unused.patch && \
|
||||
git submodule update --init --recursive && \
|
||||
cd fbgemm_gpu && \
|
||||
pip install -r requirements.txt && \
|
||||
CUDA_ARCH_LIST="8.0;9.0a" NVCC_GENCODE="-gencode=arch=compute_80,code=sm_80 -gencode=arch=compute_90a,code=sm_90a" TORCH_CUDA_ARCH_LIST="8.0;9.0a" python setup.py --package_variant genai build
|
||||
|
||||
install-fbgemm: build-fbgemm
|
||||
cd fbgemm/fbgemm_gpu && \
|
||||
CUDA_ARCH_LIST="8.0;9.0a" NVCC_GENCODE="-gencode=arch=compute_80,code=sm_80 -gencode=arch=compute_90a,code=sm_90a" TORCH_CUDA_ARCH_LIST="8.0;9.0a" python setup.py --package_variant genai install
|
|
@ -0,0 +1,306 @@
|
|||
diff --git a/fbgemm_gpu/CMakeLists.txt b/fbgemm_gpu/CMakeLists.txt
|
||||
index 2244ea6f..96265a48 100644
|
||||
--- a/fbgemm_gpu/CMakeLists.txt
|
||||
+++ b/fbgemm_gpu/CMakeLists.txt
|
||||
@@ -94,14 +94,14 @@ endif()
|
||||
# Build Experimental Modules
|
||||
################################################################################
|
||||
|
||||
-if(NOT FBGEMM_CPU_ONLY AND NOT USE_ROCM)
|
||||
- # TODO: Figure out NCCL/RCCL integration with ROCm
|
||||
- add_subdirectory(experimental/example)
|
||||
-endif()
|
||||
-
|
||||
-if(NOT FBGEMM_CPU_ONLY)
|
||||
- add_subdirectory(experimental/gemm)
|
||||
-endif()
|
||||
+# if(NOT FBGEMM_CPU_ONLY AND NOT USE_ROCM)
|
||||
+# # TODO: Figure out NCCL/RCCL integration with ROCm
|
||||
+# add_subdirectory(experimental/example)
|
||||
+# endif()
|
||||
+
|
||||
+# if(NOT FBGEMM_CPU_ONLY)
|
||||
+# add_subdirectory(experimental/gemm)
|
||||
+# endif()
|
||||
|
||||
if(NOT FBGEMM_CPU_ONLY AND NOT USE_ROCM)
|
||||
# CUTLASS currently doesn't build on ROCm and CK hasnt yet been added:
|
||||
diff --git a/fbgemm_gpu/FbgemmGpu.cmake b/fbgemm_gpu/FbgemmGpu.cmake
|
||||
index c56773fe..0c0d349e 100644
|
||||
--- a/fbgemm_gpu/FbgemmGpu.cmake
|
||||
+++ b/fbgemm_gpu/FbgemmGpu.cmake
|
||||
@@ -446,53 +446,55 @@ set_source_files_properties(${fbgemm_sources}
|
||||
################################################################################
|
||||
|
||||
set(fbgemm_gpu_sources_static_cpu
|
||||
- codegen/training/forward/embedding_forward_split_cpu.cpp
|
||||
- codegen/inference/embedding_forward_quantized_host_cpu.cpp
|
||||
- codegen/training/backward/embedding_backward_dense_host_cpu.cpp
|
||||
- codegen/utils/embedding_bounds_check_host_cpu.cpp
|
||||
- src/merge_pooled_embedding_ops/merge_pooled_embedding_ops_cpu.cpp
|
||||
- src/permute_multi_embedding_ops/permute_multi_embedding_function.cpp
|
||||
- src/permute_multi_embedding_ops/permute_multi_embedding_ops_cpu.cpp
|
||||
- src/permute_pooled_embedding_ops/permute_pooled_embedding_function.cpp
|
||||
- src/permute_pooled_embedding_ops/permute_pooled_embedding_ops_cpu.cpp
|
||||
- src/permute_pooled_embedding_ops/permute_pooled_embedding_ops_split_cpu.cpp
|
||||
- src/jagged_tensor_ops/jagged_tensor_ops_autograd.cpp
|
||||
- src/jagged_tensor_ops/jagged_tensor_ops_meta.cpp
|
||||
- src/jagged_tensor_ops/jagged_tensor_ops_cpu.cpp
|
||||
- src/input_combine_ops/input_combine_cpu.cpp
|
||||
- src/layout_transform_ops/layout_transform_ops_cpu.cpp
|
||||
+ # codegen/training/forward/embedding_forward_split_cpu.cpp
|
||||
+ # codegen/inference/embedding_forward_quantized_host_cpu.cpp
|
||||
+ # codegen/training/backward/embedding_backward_dense_host_cpu.cpp
|
||||
+ # codegen/utils/embedding_bounds_check_host_cpu.cpp
|
||||
+ # src/merge_pooled_embedding_ops/merge_pooled_embedding_ops_cpu.cpp
|
||||
+ # src/permute_multi_embedding_ops/permute_multi_embedding_function.cpp
|
||||
+ # src/permute_multi_embedding_ops/permute_multi_embedding_ops_cpu.cpp
|
||||
+ # src/permute_pooled_embedding_ops/permute_pooled_embedding_function.cpp
|
||||
+ # src/permute_pooled_embedding_ops/permute_pooled_embedding_ops_cpu.cpp
|
||||
+ # src/permute_pooled_embedding_ops/permute_pooled_embedding_ops_split_cpu.cpp
|
||||
+ # src/jagged_tensor_ops/jagged_tensor_ops_autograd.cpp
|
||||
+ # src/jagged_tensor_ops/jagged_tensor_ops_meta.cpp
|
||||
+ # src/jagged_tensor_ops/jagged_tensor_ops_cpu.cpp
|
||||
+ # src/input_combine_ops/input_combine_cpu.cpp
|
||||
+ # src/layout_transform_ops/layout_transform_ops_cpu.cpp
|
||||
src/quantize_ops/quantize_ops_cpu.cpp
|
||||
src/quantize_ops/quantize_ops_meta.cpp
|
||||
- src/sparse_ops/sparse_ops_cpu.cpp
|
||||
- src/sparse_ops/sparse_ops_meta.cpp
|
||||
- src/embedding_inplace_ops/embedding_inplace_update_cpu.cpp
|
||||
- src/split_embeddings_cache/linearize_cache_indices.cpp
|
||||
- src/split_embeddings_cache/lfu_cache_populate_byte.cpp
|
||||
- src/split_embeddings_cache/lru_cache_populate_byte.cpp
|
||||
- src/split_embeddings_cache/lxu_cache.cpp
|
||||
- src/split_embeddings_cache/split_embeddings_cache_ops.cpp
|
||||
- codegen/training/index_select/batch_index_select_dim0_ops.cpp
|
||||
- codegen/training/index_select/batch_index_select_dim0_cpu_host.cpp)
|
||||
+ # src/sparse_ops/sparse_ops_cpu.cpp
|
||||
+ # src/sparse_ops/sparse_ops_meta.cpp
|
||||
+ # src/embedding_inplace_ops/embedding_inplace_update_cpu.cpp
|
||||
+ # src/split_embeddings_cache/linearize_cache_indices.cpp
|
||||
+ # src/split_embeddings_cache/lfu_cache_populate_byte.cpp
|
||||
+ # src/split_embeddings_cache/lru_cache_populate_byte.cpp
|
||||
+ # src/split_embeddings_cache/lxu_cache.cpp
|
||||
+ # src/split_embeddings_cache/split_embeddings_cache_ops.cpp
|
||||
+ # codegen/training/index_select/batch_index_select_dim0_ops.cpp
|
||||
+ # codegen/training/index_select/batch_index_select_dim0_cpu_host.cpp)
|
||||
+)
|
||||
|
||||
if(NOT FBGEMM_CPU_ONLY)
|
||||
list(APPEND fbgemm_gpu_sources_static_cpu
|
||||
- codegen/inference/embedding_forward_quantized_host.cpp
|
||||
- codegen/utils/embedding_bounds_check_host.cpp
|
||||
- src/intraining_embedding_pruning_ops/intraining_embedding_pruning_gpu.cpp
|
||||
- src/layout_transform_ops/layout_transform_ops_gpu.cpp
|
||||
- src/memory_utils/memory_utils.cpp
|
||||
- src/memory_utils/memory_utils_ops.cpp
|
||||
- src/memory_utils/memory_utils_ops_cpu.cpp
|
||||
- src/permute_pooled_embedding_ops/permute_pooled_embedding_ops_gpu.cpp
|
||||
- src/permute_pooled_embedding_ops/permute_pooled_embedding_ops_split_gpu.cpp
|
||||
+ # codegen/inference/embedding_forward_quantized_host.cpp
|
||||
+ # codegen/utils/embedding_bounds_check_host.cpp
|
||||
+ # src/intraining_embedding_pruning_ops/intraining_embedding_pruning_gpu.cpp
|
||||
+ # src/layout_transform_ops/layout_transform_ops_gpu.cpp
|
||||
+ # src/memory_utils/memory_utils.cpp
|
||||
+ # src/memory_utils/memory_utils_ops.cpp
|
||||
+ # src/memory_utils/memory_utils_ops_cpu.cpp
|
||||
+ # src/permute_pooled_embedding_ops/permute_pooled_embedding_ops_gpu.cpp
|
||||
+ # src/permute_pooled_embedding_ops/permute_pooled_embedding_ops_split_gpu.cpp
|
||||
src/quantize_ops/quantize_ops_gpu.cpp
|
||||
- src/sparse_ops/sparse_ops_gpu.cpp
|
||||
- src/split_embeddings_utils/split_embeddings_utils.cpp
|
||||
- src/split_embeddings_cache/split_embeddings_cache_ops.cu
|
||||
- src/metric_ops/metric_ops_host.cpp
|
||||
- src/embedding_inplace_ops/embedding_inplace_update_gpu.cpp
|
||||
- src/input_combine_ops/input_combine_gpu.cpp
|
||||
- codegen/training/index_select/batch_index_select_dim0_host.cpp)
|
||||
+ # src/sparse_ops/sparse_ops_gpu.cpp
|
||||
+ # src/split_embeddings_utils/split_embeddings_utils.cpp
|
||||
+ # src/split_embeddings_cache/split_embeddings_cache_ops.cu
|
||||
+ # src/metric_ops/metric_ops_host.cpp
|
||||
+ # src/embedding_inplace_ops/embedding_inplace_update_gpu.cpp
|
||||
+ # src/input_combine_ops/input_combine_gpu.cpp
|
||||
+ # codegen/training/index_select/batch_index_select_dim0_host.cpp)
|
||||
+ )
|
||||
|
||||
if(NVML_LIB_PATH OR USE_ROCM)
|
||||
message(STATUS "Adding merge_pooled_embeddings sources")
|
||||
@@ -516,36 +518,36 @@ endif()
|
||||
|
||||
if(NOT FBGEMM_CPU_ONLY)
|
||||
set(fbgemm_gpu_sources_static_gpu
|
||||
- codegen/utils/embedding_bounds_check.cu
|
||||
- codegen/inference/embedding_forward_quantized_split_lookup.cu
|
||||
- src/embedding_inplace_ops/embedding_inplace_update.cu
|
||||
- src/histogram_binning_calibration_ops.cu
|
||||
- src/input_combine_ops/input_combine.cu
|
||||
- src/intraining_embedding_pruning_ops/intraining_embedding_pruning.cu
|
||||
- src/memory_utils/memory_utils.cu
|
||||
- src/memory_utils/memory_utils_ops.cu
|
||||
- src/jagged_tensor_ops/batched_dense_vec_jagged_2d_mul_backward.cu
|
||||
- src/jagged_tensor_ops/batched_dense_vec_jagged_2d_mul_forward.cu
|
||||
- src/jagged_tensor_ops/dense_to_jagged_forward.cu
|
||||
- src/jagged_tensor_ops/jagged_dense_bmm_forward.cu
|
||||
- src/jagged_tensor_ops/jagged_dense_dense_elementwise_add_jagged_output_forward.cu
|
||||
- src/jagged_tensor_ops/jagged_dense_elementwise_mul_backward.cu
|
||||
- src/jagged_tensor_ops/jagged_dense_elementwise_mul_forward.cu
|
||||
- src/jagged_tensor_ops/jagged_index_add_2d_forward.cu
|
||||
- src/jagged_tensor_ops/jagged_index_select_2d_forward.cu
|
||||
- src/jagged_tensor_ops/jagged_jagged_bmm_forward.cu
|
||||
- src/jagged_tensor_ops/jagged_softmax_backward.cu
|
||||
- src/jagged_tensor_ops/jagged_softmax_forward.cu
|
||||
- src/jagged_tensor_ops/jagged_tensor_ops.cu
|
||||
- src/jagged_tensor_ops/jagged_to_padded_dense_backward.cu
|
||||
- src/jagged_tensor_ops/jagged_to_padded_dense_forward.cu
|
||||
- src/jagged_tensor_ops/jagged_unique_indices.cu
|
||||
- src/jagged_tensor_ops/keyed_jagged_index_select_dim1.cu
|
||||
- src/layout_transform_ops/layout_transform_ops.cu
|
||||
- src/metric_ops/metric_ops.cu
|
||||
- src/permute_pooled_embedding_ops/permute_pooled_embedding_ops_split.cu
|
||||
- src/permute_pooled_embedding_ops/permute_pooled_embedding_ops.cu
|
||||
- src/permute_multi_embedding_ops/permute_multi_embedding_ops.cu
|
||||
+ # codegen/utils/embedding_bounds_check.cu
|
||||
+ # codegen/inference/embedding_forward_quantized_split_lookup.cu
|
||||
+ # src/embedding_inplace_ops/embedding_inplace_update.cu
|
||||
+ # src/histogram_binning_calibration_ops.cu
|
||||
+ # src/input_combine_ops/input_combine.cu
|
||||
+ # src/intraining_embedding_pruning_ops/intraining_embedding_pruning.cu
|
||||
+ # src/memory_utils/memory_utils.cu
|
||||
+ # src/memory_utils/memory_utils_ops.cu
|
||||
+ # src/jagged_tensor_ops/batched_dense_vec_jagged_2d_mul_backward.cu
|
||||
+ # src/jagged_tensor_ops/batched_dense_vec_jagged_2d_mul_forward.cu
|
||||
+ # src/jagged_tensor_ops/dense_to_jagged_forward.cu
|
||||
+ # src/jagged_tensor_ops/jagged_dense_bmm_forward.cu
|
||||
+ # src/jagged_tensor_ops/jagged_dense_dense_elementwise_add_jagged_output_forward.cu
|
||||
+ # src/jagged_tensor_ops/jagged_dense_elementwise_mul_backward.cu
|
||||
+ # src/jagged_tensor_ops/jagged_dense_elementwise_mul_forward.cu
|
||||
+ # src/jagged_tensor_ops/jagged_index_add_2d_forward.cu
|
||||
+ # src/jagged_tensor_ops/jagged_index_select_2d_forward.cu
|
||||
+ # src/jagged_tensor_ops/jagged_jagged_bmm_forward.cu
|
||||
+ # src/jagged_tensor_ops/jagged_softmax_backward.cu
|
||||
+ # src/jagged_tensor_ops/jagged_softmax_forward.cu
|
||||
+ # src/jagged_tensor_ops/jagged_tensor_ops.cu
|
||||
+ # src/jagged_tensor_ops/jagged_to_padded_dense_backward.cu
|
||||
+ # src/jagged_tensor_ops/jagged_to_padded_dense_forward.cu
|
||||
+ # src/jagged_tensor_ops/jagged_unique_indices.cu
|
||||
+ # src/jagged_tensor_ops/keyed_jagged_index_select_dim1.cu
|
||||
+ # src/layout_transform_ops/layout_transform_ops.cu
|
||||
+ # src/metric_ops/metric_ops.cu
|
||||
+ # src/permute_pooled_embedding_ops/permute_pooled_embedding_ops_split.cu
|
||||
+ # src/permute_pooled_embedding_ops/permute_pooled_embedding_ops.cu
|
||||
+ # src/permute_multi_embedding_ops/permute_multi_embedding_ops.cu
|
||||
src/quantize_ops/quantize_bfloat16.cu
|
||||
src/quantize_ops/quantize_fp8_rowwise.cu
|
||||
src/quantize_ops/quantize_fused_8bit_rowwise.cu
|
||||
@@ -554,39 +556,40 @@ if(NOT FBGEMM_CPU_ONLY)
|
||||
src/quantize_ops/quantize_msfp.cu
|
||||
src/quantize_ops/quantize_padded_fp8_rowwise.cu
|
||||
src/quantize_ops/quantize_mx.cu
|
||||
- src/sparse_ops/sparse_async_cumsum.cu
|
||||
- src/sparse_ops/sparse_block_bucketize_features.cu
|
||||
- src/sparse_ops/sparse_bucketize_features.cu
|
||||
- src/sparse_ops/sparse_batched_unary_embeddings.cu
|
||||
- src/sparse_ops/sparse_compute_frequency_sequence.cu
|
||||
- src/sparse_ops/sparse_expand_into_jagged_permute.cu
|
||||
- src/sparse_ops/sparse_group_index.cu
|
||||
- src/sparse_ops/sparse_index_add.cu
|
||||
- src/sparse_ops/sparse_index_select.cu
|
||||
- src/sparse_ops/sparse_invert_permute.cu
|
||||
- src/sparse_ops/sparse_pack_segments_backward.cu
|
||||
- src/sparse_ops/sparse_pack_segments_forward.cu
|
||||
- src/sparse_ops/sparse_permute_1d.cu
|
||||
- src/sparse_ops/sparse_permute_2d.cu
|
||||
- src/sparse_ops/sparse_permute102.cu
|
||||
- src/sparse_ops/sparse_permute_embeddings.cu
|
||||
- src/sparse_ops/sparse_range.cu
|
||||
- src/sparse_ops/sparse_reorder_batched_ad.cu
|
||||
- src/sparse_ops/sparse_segment_sum_csr.cu
|
||||
- src/sparse_ops/sparse_zipf.cu
|
||||
- src/split_embeddings_cache/lfu_cache_find.cu
|
||||
- src/split_embeddings_cache/lfu_cache_populate.cu
|
||||
- src/split_embeddings_cache/lfu_cache_populate_byte.cu
|
||||
- src/split_embeddings_cache/lru_cache_find.cu
|
||||
- src/split_embeddings_cache/lru_cache_populate.cu
|
||||
- src/split_embeddings_cache/lru_cache_populate_byte.cu
|
||||
- src/split_embeddings_cache/lxu_cache.cu
|
||||
- src/split_embeddings_cache/linearize_cache_indices.cu
|
||||
- src/split_embeddings_cache/reset_weight_momentum.cu
|
||||
- src/split_embeddings_utils/generate_vbe_metadata.cu
|
||||
- src/split_embeddings_utils/get_infos_metadata.cu
|
||||
- src/split_embeddings_utils/radix_sort_pairs.cu
|
||||
- src/split_embeddings_utils/transpose_embedding_input.cu)
|
||||
+ # src/sparse_ops/sparse_async_cumsum.cu
|
||||
+ # src/sparse_ops/sparse_block_bucketize_features.cu
|
||||
+ # src/sparse_ops/sparse_bucketize_features.cu
|
||||
+ # src/sparse_ops/sparse_batched_unary_embeddings.cu
|
||||
+ # src/sparse_ops/sparse_compute_frequency_sequence.cu
|
||||
+ # src/sparse_ops/sparse_expand_into_jagged_permute.cu
|
||||
+ # src/sparse_ops/sparse_group_index.cu
|
||||
+ # src/sparse_ops/sparse_index_add.cu
|
||||
+ # src/sparse_ops/sparse_index_select.cu
|
||||
+ # src/sparse_ops/sparse_invert_permute.cu
|
||||
+ # src/sparse_ops/sparse_pack_segments_backward.cu
|
||||
+ # src/sparse_ops/sparse_pack_segments_forward.cu
|
||||
+ # src/sparse_ops/sparse_permute_1d.cu
|
||||
+ # src/sparse_ops/sparse_permute_2d.cu
|
||||
+ # src/sparse_ops/sparse_permute102.cu
|
||||
+ # src/sparse_ops/sparse_permute_embeddings.cu
|
||||
+ # src/sparse_ops/sparse_range.cu
|
||||
+ # src/sparse_ops/sparse_reorder_batched_ad.cu
|
||||
+ # src/sparse_ops/sparse_segment_sum_csr.cu
|
||||
+ # src/sparse_ops/sparse_zipf.cu
|
||||
+ # src/split_embeddings_cache/lfu_cache_find.cu
|
||||
+ # src/split_embeddings_cache/lfu_cache_populate.cu
|
||||
+ # src/split_embeddings_cache/lfu_cache_populate_byte.cu
|
||||
+ # src/split_embeddings_cache/lru_cache_find.cu
|
||||
+ # src/split_embeddings_cache/lru_cache_populate.cu
|
||||
+ # src/split_embeddings_cache/lru_cache_populate_byte.cu
|
||||
+ # src/split_embeddings_cache/lxu_cache.cu
|
||||
+ # src/split_embeddings_cache/linearize_cache_indices.cu
|
||||
+ # src/split_embeddings_cache/reset_weight_momentum.cu
|
||||
+ # src/split_embeddings_utils/generate_vbe_metadata.cu
|
||||
+ # src/split_embeddings_utils/get_infos_metadata.cu
|
||||
+ # src/split_embeddings_utils/radix_sort_pairs.cu
|
||||
+ # src/split_embeddings_utils/transpose_embedding_input.cu)
|
||||
+ )
|
||||
|
||||
set_source_files_properties(${fbgemm_gpu_sources_static_gpu}
|
||||
PROPERTIES COMPILE_OPTIONS
|
||||
diff --git a/fbgemm_gpu/experimental/gen_ai/CMakeLists.txt b/fbgemm_gpu/experimental/gen_ai/CMakeLists.txt
|
||||
index 01f1d6ab..a6b8d7a8 100644
|
||||
--- a/fbgemm_gpu/experimental/gen_ai/CMakeLists.txt
|
||||
+++ b/fbgemm_gpu/experimental/gen_ai/CMakeLists.txt
|
||||
@@ -25,23 +25,24 @@ set(fbgemm_sources_include_directories
|
||||
${THIRDPARTY}/json/include
|
||||
${NCCL_INCLUDE_DIRS})
|
||||
|
||||
-set(attention_ops_sources
|
||||
- src/attention/attention.cpp
|
||||
- src/attention/gqa_attn_splitk.cu)
|
||||
+# set(attention_ops_sources
|
||||
+# src/attention/attention.cpp
|
||||
+# src/attention/gqa_attn_splitk.cu)
|
||||
|
||||
set(quantize_ops_sources
|
||||
src/quantize/cutlass_extensions.cu
|
||||
src/quantize/quantize.cu
|
||||
src/quantize/quantize.cpp)
|
||||
|
||||
-set(comm_ops_sources
|
||||
- src/comm/car.cu
|
||||
- src/comm/car.cpp)
|
||||
+# set(comm_ops_sources
|
||||
+# src/comm/car.cu
|
||||
+# src/comm/car.cpp)
|
||||
|
||||
set(experimental_gen_ai_cpp_source_files
|
||||
- ${attention_ops_sources}
|
||||
+ # ${attention_ops_sources}
|
||||
${quantize_ops_sources}
|
||||
- ${comm_ops_sources})
|
||||
+ # ${comm_ops_sources}
|
||||
+)
|
||||
|
||||
set_source_files_properties(${experimental_gen_ai_cpp_source_files}
|
||||
PROPERTIES INCLUDE_DIRECTORIES
|
|
@ -0,0 +1,11 @@
|
|||
#!/bin/bash
|
||||
|
||||
# This script is required to patch torch < 2.4
|
||||
# It adds the 90a cuda target (H100)
|
||||
# This target is required to build FBGEMM kernels
|
||||
|
||||
torch_cuda_arch=$(python -c "import torch; print(torch.__file__)" | sed 's/\/__init__.py//; s|$|/share/cmake/Caffe2/Modules_CUDA_fix/upstream/FindCUDA/select_compute_arch.cmake|')
|
||||
|
||||
sed -i '189s/\[0-9]\\\\\.\[0-9](/[0-9]\\\\.[0-9]a?(/' $torch_cuda_arch
|
||||
sed -i '245s/\[0-9()]+\+"/[0-9()]+a?"/' $torch_cuda_arch
|
||||
sed -i '246s/\[0-9]+\+"/[0-9]+a?"/' $torch_cuda_arch
|
|
@ -8,6 +8,7 @@ from typing import Optional
|
|||
from enum import Enum
|
||||
from huggingface_hub import hf_hub_download
|
||||
|
||||
from text_generation_server.utils.log import log_master
|
||||
|
||||
app = typer.Typer()
|
||||
|
||||
|
@ -87,15 +88,17 @@ def serve(
|
|||
)
|
||||
|
||||
if len(lora_adapter_ids) > 0:
|
||||
logger.warning(
|
||||
f"LoRA adapters are enabled. This is an experimental feature and may not work as expected."
|
||||
log_master(
|
||||
logger.warning,
|
||||
f"LoRA adapters are enabled. This is an experimental feature and may not work as expected.",
|
||||
)
|
||||
|
||||
# TODO: enable lora with cuda graphs. for now disable cuda graphs if lora is enabled
|
||||
# and warn the user
|
||||
if len(lora_adapter_ids) > 0 and os.getenv("CUDA_GRAPHS", None) is not None:
|
||||
logger.warning(
|
||||
f"LoRa adapter are not supported with CUDA Graphs. Disabling CUDA Graphs."
|
||||
log_master(
|
||||
logger.warning,
|
||||
f"LoRa adapter are not supported with CUDA Graphs. Disabling CUDA Graphs.",
|
||||
)
|
||||
global CUDA_GRAPHS
|
||||
CUDA_GRAPHS = None
|
||||
|
|
|
@ -3,6 +3,7 @@ import torch
|
|||
from text_generation_server.utils.import_utils import SYSTEM
|
||||
from text_generation_server.models.globals import FLASH_DECODING
|
||||
from text_generation_server.layers.attention import Seqlen
|
||||
from text_generation_server.utils.log import log_master
|
||||
from loguru import logger
|
||||
|
||||
major, minor = torch.cuda.get_device_capability()
|
||||
|
@ -136,7 +137,10 @@ if ENGINE != "triton":
|
|||
try:
|
||||
import flash_attn_2_cuda
|
||||
|
||||
logger.info("ROCm: using Flash Attention 2 Composable Kernel implementation.")
|
||||
log_master(
|
||||
logger.info,
|
||||
"ROCm: using Flash Attention 2 Composable Kernel implementation.",
|
||||
)
|
||||
except ImportError as e:
|
||||
if major >= 8:
|
||||
architecture_suffix = f"-{SYSTEM}"
|
||||
|
|
|
@ -4,19 +4,11 @@ from functools import lru_cache
|
|||
import bitsandbytes as bnb
|
||||
import torch
|
||||
from bitsandbytes.nn import Int8Params, Params4bit
|
||||
from loguru import logger
|
||||
from text_generation_server.utils.weights import Weight
|
||||
|
||||
|
||||
@lru_cache(1)
|
||||
def warn_deprecate_bnb():
|
||||
logger.warning(
|
||||
"Bitsandbytes 8bit is deprecated, using `eetq` is a drop-in replacement, and has much better performnce"
|
||||
)
|
||||
from text_generation_server.utils.weights import UnquantizedWeight
|
||||
|
||||
|
||||
@dataclass
|
||||
class BNBWeight(Weight):
|
||||
class BNBWeight(UnquantizedWeight):
|
||||
weight: torch.Tensor
|
||||
|
||||
def get_linear(self, bias: torch.Tensor):
|
||||
|
@ -82,7 +74,7 @@ class Linear8bitLt(torch.nn.Module):
|
|||
|
||||
|
||||
@dataclass
|
||||
class BNBFP4Weight(Weight):
|
||||
class BNBFP4Weight(UnquantizedWeight):
|
||||
weight: torch.Tensor
|
||||
|
||||
def get_linear(self, bias: torch.Tensor):
|
||||
|
@ -90,7 +82,7 @@ class BNBFP4Weight(Weight):
|
|||
|
||||
|
||||
@dataclass
|
||||
class BNBNF4Weight(Weight):
|
||||
class BNBNF4Weight(UnquantizedWeight):
|
||||
weight: torch.Tensor
|
||||
|
||||
def get_linear(self, bias: torch.Tensor):
|
||||
|
|
|
@ -2,11 +2,11 @@ from dataclasses import dataclass
|
|||
|
||||
import torch
|
||||
from EETQ import quant_weights, w8_a16_gemm
|
||||
from text_generation_server.utils.weights import Weight
|
||||
from text_generation_server.utils.weights import UnquantizedWeight
|
||||
|
||||
|
||||
@dataclass
|
||||
class EETQWeight(Weight):
|
||||
class EETQWeight(UnquantizedWeight):
|
||||
weight: torch.Tensor
|
||||
|
||||
def get_linear(self, bias: torch.Tensor):
|
||||
|
|
|
@ -1,8 +1,29 @@
|
|||
from dataclasses import dataclass
|
||||
|
||||
import torch
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, Union, List
|
||||
from loguru import logger
|
||||
|
||||
from text_generation_server.utils.import_utils import SYSTEM
|
||||
from text_generation_server.utils.weights import Weight
|
||||
from text_generation_server.utils.weights import (
|
||||
Weight,
|
||||
WeightsLoader,
|
||||
UnquantizedWeight,
|
||||
Weights,
|
||||
)
|
||||
from text_generation_server.utils.log import log_master, log_once
|
||||
|
||||
FBGEMM_MM_AVAILABLE = False
|
||||
FBGEMM_DYN_AVAILABLE = False
|
||||
try:
|
||||
import fbgemm_gpu.experimental.gen_ai
|
||||
|
||||
if SYSTEM == "cuda":
|
||||
major, _ = torch.cuda.get_device_capability()
|
||||
FBGEMM_MM_AVAILABLE = major == 9
|
||||
FBGEMM_DYN_AVAILABLE = major >= 8
|
||||
except (ImportError, ModuleNotFoundError):
|
||||
log_master(logger.warning, "FBGEMM fp8 kernels are not installed.")
|
||||
|
||||
|
||||
def get_fp8_linear() -> torch.nn.Module:
|
||||
|
@ -21,12 +42,17 @@ def get_fp8_linear() -> torch.nn.Module:
|
|||
return Fp8Linear
|
||||
|
||||
|
||||
def fp8_quantize(weight, qdtype=torch.float8_e4m3fn):
|
||||
device = weight.device
|
||||
def fp8_quantize(weight, scale_upper_bound=None, qdtype=torch.float8_e4m3fn):
|
||||
if FBGEMM_DYN_AVAILABLE:
|
||||
qweight, scale = torch.ops.fbgemm.quantize_fp8_per_row(
|
||||
weight, bs=None, scale_ub=scale_upper_bound, output_dtype=qdtype
|
||||
)
|
||||
return qweight, scale
|
||||
|
||||
# weight, scale = quant_weights(weight, torch.int8, False)
|
||||
finfo = torch.finfo(qdtype)
|
||||
# Calculate the scale as dtype max divided by absmax
|
||||
scale = finfo.max / weight.abs().max().clamp(min=1e-12)
|
||||
scale = finfo.max / weight.abs().max().clamp(min=1e-12, max=scale_upper_bound)
|
||||
# scale and clamp the tensor to bring it to
|
||||
# the representative range of float8 data type
|
||||
# (as default cast is unsaturated)
|
||||
|
@ -38,27 +64,166 @@ def fp8_quantize(weight, qdtype=torch.float8_e4m3fn):
|
|||
return qweight, scale
|
||||
|
||||
|
||||
class HybridFP8UnquantLoader(WeightsLoader):
|
||||
"""Weight loader that loads FP8 and unquantized Torch tensors."""
|
||||
|
||||
def __init__(self, activation_scale_ub: Optional[float], to_fp8: bool):
|
||||
self.activation_scale_ub = activation_scale_ub
|
||||
self.to_fp8 = to_fp8
|
||||
|
||||
def get_weights(self, weights: "Weights", prefix: str):
|
||||
w = weights.get_tensor(f"{prefix}.weight")
|
||||
|
||||
if w.dtype == torch.float8_e4m3fn:
|
||||
# FP8 branch
|
||||
scale = weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False)
|
||||
return Fp8Weight(
|
||||
weight=w,
|
||||
weight_scale=scale,
|
||||
activation_scale_ub=self.activation_scale_ub,
|
||||
dtype=weights.dtype,
|
||||
)
|
||||
if self.to_fp8:
|
||||
return Fp8Weight(weight=w, dtype=weights.dtype)
|
||||
|
||||
return UnquantizedWeight(w)
|
||||
|
||||
def get_weights_col_packed(
|
||||
self,
|
||||
weights: Weights,
|
||||
prefix: str,
|
||||
block_sizes: Union[int, List[int]],
|
||||
):
|
||||
w = weights.get_packed_sharded(
|
||||
f"{prefix}.weight", dim=0, block_sizes=block_sizes
|
||||
)
|
||||
|
||||
if w.dtype == torch.float8_e4m3fn:
|
||||
# FP8 branch
|
||||
scale = weights.get_packed_sharded(
|
||||
f"{prefix}.weight_scale", dim=0, block_sizes=block_sizes, to_dtype=False
|
||||
)
|
||||
return Fp8Weight(
|
||||
weight=w,
|
||||
weight_scale=scale,
|
||||
activation_scale_ub=self.activation_scale_ub,
|
||||
dtype=weights.dtype,
|
||||
)
|
||||
if self.to_fp8:
|
||||
return Fp8Weight(weight=w, dtype=weights.dtype)
|
||||
|
||||
return UnquantizedWeight(w)
|
||||
|
||||
def get_multi_weights_col(self, weights: "Weights", prefixes: List[str], dim: int):
|
||||
w = [weights.get_sharded(f"{p}.weight", dim=0) for p in prefixes]
|
||||
w = torch.cat(w, dim=dim)
|
||||
|
||||
# FP8 branch
|
||||
if w.dtype == torch.float8_e4m3fn:
|
||||
scale = [
|
||||
weights.get_sharded(f"{p}.weight_scale", dim=0, to_dtype=False)
|
||||
for p in prefixes
|
||||
]
|
||||
scale = torch.cat(scale, dim=0)
|
||||
return Fp8Weight(
|
||||
weight=w,
|
||||
weight_scale=scale,
|
||||
activation_scale_ub=self.activation_scale_ub,
|
||||
dtype=weights.dtype,
|
||||
)
|
||||
if self.to_fp8:
|
||||
return Fp8Weight(weight=w, dtype=weights.dtype)
|
||||
|
||||
return UnquantizedWeight(w)
|
||||
|
||||
def get_weights_row(self, weights: "Weights", prefix: str):
|
||||
w = weights.get_sharded(f"{prefix}.weight", dim=1)
|
||||
# FP8 branch
|
||||
if w.dtype == torch.float8_e4m3fn:
|
||||
scale = weights.get_sharded(f"{prefix}.weight_scale", dim=0, to_dtype=False)
|
||||
return Fp8Weight(
|
||||
weight=w,
|
||||
weight_scale=scale,
|
||||
activation_scale_ub=self.activation_scale_ub,
|
||||
dtype=weights.dtype,
|
||||
)
|
||||
if self.to_fp8:
|
||||
return Fp8Weight(weight=w, dtype=weights.dtype)
|
||||
|
||||
return UnquantizedWeight(w)
|
||||
|
||||
|
||||
@dataclass
|
||||
class Fp8Weight(Weight):
|
||||
weight: torch.Tensor
|
||||
dtype: torch.dtype
|
||||
weight_scale: Optional[torch.Tensor] = None
|
||||
activation_scale_ub: Optional[float] = None
|
||||
|
||||
def get_linear(self, bias: torch.Tensor):
|
||||
return get_fp8_linear()(self.weight, bias)
|
||||
if self.weight_scale is None:
|
||||
return get_fp8_linear().from_unquant(self.weight, bias, self.dtype)
|
||||
return get_fp8_linear().from_fp8(
|
||||
self.weight, self.weight_scale, self.activation_scale_ub, bias, self.dtype
|
||||
)
|
||||
|
||||
|
||||
class Fp8Linear(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
weight,
|
||||
qweight,
|
||||
scale,
|
||||
scale_upper_bound,
|
||||
bias,
|
||||
dtype,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.dtype = weight.dtype
|
||||
self.qweight, self.scale = fp8_quantize(weight)
|
||||
self.dtype = dtype
|
||||
self.qweight = qweight
|
||||
self.scale = scale
|
||||
self.scale_upper_bound = (
|
||||
torch.tensor(
|
||||
[scale_upper_bound], dtype=torch.float32, device=qweight.device
|
||||
)
|
||||
if scale_upper_bound is not None
|
||||
else None
|
||||
)
|
||||
|
||||
self.bias = bias if bias is not None else None
|
||||
|
||||
@classmethod
|
||||
def from_unquant(cls, weight, bias, dtype):
|
||||
qweight, scale = fp8_quantize(weight)
|
||||
return cls(
|
||||
qweight=qweight, scale=scale, scale_upper_bound=None, bias=bias, dtype=dtype
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_fp8(cls, weight, scale, input_scale, bias, dtype):
|
||||
return cls(
|
||||
qweight=weight,
|
||||
scale=scale,
|
||||
scale_upper_bound=input_scale,
|
||||
bias=bias,
|
||||
dtype=dtype,
|
||||
)
|
||||
|
||||
def forward(self, input: torch.Tensor) -> torch.Tensor:
|
||||
if FBGEMM_MM_AVAILABLE:
|
||||
qinput, scale = fp8_quantize(
|
||||
input, scale_upper_bound=self.scale_upper_bound
|
||||
)
|
||||
|
||||
y = torch.ops.fbgemm.f8f8bf16_rowwise(
|
||||
qinput,
|
||||
self.qweight,
|
||||
scale,
|
||||
self.scale,
|
||||
use_fast_accum=True,
|
||||
bias=self.bias,
|
||||
)
|
||||
return y.to(self.dtype)
|
||||
|
||||
qinput, scale = fp8_quantize(input)
|
||||
output, _ = torch._scaled_mm(
|
||||
qinput,
|
||||
|
|
|
@ -9,11 +9,12 @@ from loguru import logger
|
|||
|
||||
from text_generation_server.layers.exl2 import Exl2Weight
|
||||
from text_generation_server.layers.gptq import GPTQWeight
|
||||
from text_generation_server.utils.log import log_master
|
||||
|
||||
try:
|
||||
from exllamav2_kernels import make_q_matrix, gemm_half_q_half
|
||||
except ImportError:
|
||||
logger.error("exllamav2_kernels not installed.")
|
||||
log_master(logger.warning, "exllamav2_kernels not installed.")
|
||||
raise
|
||||
|
||||
# Dummy tensor to pass instead of g_idx since there is no way to pass "None" to a C++ extension
|
||||
|
|
|
@ -503,7 +503,8 @@ class GPTQMarlinFP8Linear(nn.Module):
|
|||
|
||||
def __init__(
|
||||
self,
|
||||
weight: torch.Tensor,
|
||||
qweight: torch.Tensor,
|
||||
scale: torch.Tensor,
|
||||
bias: Optional[torch.Tensor],
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
@ -513,7 +514,6 @@ class GPTQMarlinFP8Linear(nn.Module):
|
|||
|
||||
log_once(logger.info, "GPU does not support FP8, using Marlin FP8 kernel")
|
||||
|
||||
qweight, scale = fp8_quantize(weight)
|
||||
scale = scale.to(torch.float16)
|
||||
qweight, scales = repack_fp8_for_marlin(qweight, scale)
|
||||
|
||||
|
@ -529,6 +529,15 @@ class GPTQMarlinFP8Linear(nn.Module):
|
|||
out_features // 64 * 16, dtype=torch.int, device=qweight.device
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_unquant(cls, weight, bias, _dtype):
|
||||
qweight, scale = fp8_quantize(weight)
|
||||
return cls(qweight=qweight, scale=scale, bias=bias)
|
||||
|
||||
@classmethod
|
||||
def from_fp8(cls, weight, scale, _input_scale, bias, _dtype):
|
||||
return cls(qweight=weight, scale=scale, bias=bias)
|
||||
|
||||
def forward(self, A: torch.Tensor) -> torch.Tensor:
|
||||
assert marlin_kernels is not None
|
||||
|
||||
|
|
|
@ -34,6 +34,7 @@ from text_generation_server.models.custom_modeling.t5_modeling import (
|
|||
)
|
||||
|
||||
from text_generation_server.utils.import_utils import SYSTEM
|
||||
from text_generation_server.utils.log import log_master
|
||||
|
||||
# The flag below controls whether to allow TF32 on matmul. This flag defaults to False
|
||||
# in PyTorch 1.12 and later.
|
||||
|
@ -47,9 +48,7 @@ torch.set_grad_enabled(False)
|
|||
|
||||
__all__ = [
|
||||
"Model",
|
||||
"BLOOMSharded",
|
||||
"CausalLM",
|
||||
"GalacticaSharded",
|
||||
"Seq2SeqLM",
|
||||
"get_model",
|
||||
]
|
||||
|
@ -125,7 +124,7 @@ try:
|
|||
)
|
||||
from text_generation_server.layers.attention import SUPPORTS_WINDOWING
|
||||
except ImportError as e:
|
||||
logger.warning(f"Could not import Flash Attention enabled models: {e}")
|
||||
log_master(logger.warning, f"Could not import Flash Attention enabled models: {e}")
|
||||
SUPPORTS_WINDOWING = False
|
||||
FLASH_ATTENTION = False
|
||||
|
||||
|
@ -137,7 +136,7 @@ MAMBA_AVAILABLE = True
|
|||
try:
|
||||
from text_generation_server.models.mamba import Mamba
|
||||
except ImportError as e:
|
||||
logger.warning(f"Could not import Mamba: {e}")
|
||||
log_master(logger.warning, f"Could not import Mamba: {e}")
|
||||
MAMBA_AVAILABLE = False
|
||||
|
||||
if MAMBA_AVAILABLE:
|
||||
|
@ -311,6 +310,12 @@ def get_model(
|
|||
if quantize in ["awq", "exl2", "gptq", "marlin"]:
|
||||
# These quantizers only work with float16 params.
|
||||
dtype = torch.float16
|
||||
elif quantize == "fp8":
|
||||
from text_generation_server.layers.fp8 import FBGEMM_MM_AVAILABLE
|
||||
|
||||
if FBGEMM_MM_AVAILABLE:
|
||||
# fbgemm kernels are fp8xfp8->bf16
|
||||
dtype = torch.bfloat16
|
||||
else:
|
||||
# Keep it as default for now and let
|
||||
# every model resolve their own default dtype.
|
||||
|
@ -433,7 +438,9 @@ def get_model(
|
|||
|
||||
speculate = get_speculate()
|
||||
if speculate > 0:
|
||||
logger.info(f"Using speculation {method} with {speculate} input ids.")
|
||||
log_master(
|
||||
logger.info, f"Using speculation {method} with {speculate} input ids."
|
||||
)
|
||||
|
||||
if model_type is None:
|
||||
# TODO: fix how we determine model type for Mamba
|
||||
|
@ -448,10 +455,10 @@ def get_model(
|
|||
if quantization_config is not None and quantize is None:
|
||||
method = quantization_config.get("quant_method", None)
|
||||
if method in {"gptq", "awq", "exl2"}:
|
||||
logger.info(f"Auto selecting quantization method {method}")
|
||||
log_master(logger.info, f"Auto selecting quantization method {method}")
|
||||
quantize = method
|
||||
else:
|
||||
logger.info(f"Unknown quantization method {method}")
|
||||
log_master(logger.warning, f"Unknown quantization method {method}")
|
||||
|
||||
if quantize == "exl2" and sharded:
|
||||
raise RuntimeError(
|
||||
|
@ -593,7 +600,7 @@ def get_model(
|
|||
)
|
||||
except RuntimeError as e:
|
||||
# Lots of legacy models with various weight names.
|
||||
logger.warning(f"Couldn't load flash gpt2 variant: {e}")
|
||||
log_master(logger.warning, f"Couldn't load flash gpt2 variant: {e}")
|
||||
return CausalLM.fallback(
|
||||
model_id,
|
||||
revision,
|
||||
|
|
|
@ -33,7 +33,6 @@ from text_generation_server.layers.attention import (
|
|||
attention,
|
||||
reshape_and_cache,
|
||||
)
|
||||
from text_generation_server.models.globals import FLASH_DECODING
|
||||
from text_generation_server.layers import (
|
||||
TensorParallelRowLinear,
|
||||
TensorParallelColumnLinear,
|
||||
|
@ -42,16 +41,15 @@ from text_generation_server.layers import (
|
|||
TensorParallelMultiAdapterLinear,
|
||||
TensorParallelAdapterRowLinear,
|
||||
)
|
||||
from text_generation_server.layers.fp8 import Fp8Weight
|
||||
from text_generation_server.layers.rotary import PositionRotaryEmbedding
|
||||
from text_generation_server.layers.layernorm import (
|
||||
FastRMSNorm,
|
||||
)
|
||||
from text_generation_server.utils.weights import (
|
||||
DefaultWeightsLoader,
|
||||
UnquantizedWeight,
|
||||
Weights,
|
||||
)
|
||||
from text_generation_server.layers.fp8 import HybridFP8UnquantLoader
|
||||
|
||||
if SYSTEM == "rocm":
|
||||
try:
|
||||
|
@ -113,12 +111,12 @@ def load_attention(config, prefix: str, weights, layer_id):
|
|||
|
||||
@contextmanager
|
||||
def no_fp8(weights: Weights):
|
||||
"""De-activate fp8 auto conversion for the duration of this context manager"""
|
||||
weights_loader = weights.weights_loader
|
||||
if (
|
||||
isinstance(weights_loader, DefaultWeightsLoader)
|
||||
and weights_loader.weight_class is Fp8Weight
|
||||
):
|
||||
weights_loader = DefaultWeightsLoader(UnquantizedWeight)
|
||||
if isinstance(weights_loader, HybridFP8UnquantLoader) and weights_loader.to_fp8:
|
||||
weights_loader = HybridFP8UnquantLoader(
|
||||
weights_loader.activation_scale_ub, to_fp8=False
|
||||
)
|
||||
|
||||
with weights.use_loader(weights_loader):
|
||||
yield
|
||||
|
@ -418,7 +416,22 @@ class FlashLlamaModel(torch.nn.Module):
|
|||
process_group = weights.process_group
|
||||
self.tp_rank = process_group.rank()
|
||||
self.tp_world_size = process_group.size()
|
||||
self.layers = nn.ModuleList(
|
||||
|
||||
# Skip fp8 quant for first and last layers
|
||||
self.layers = nn.ModuleList()
|
||||
with no_fp8(weights):
|
||||
self.layers.append(
|
||||
FlashLlamaLayer(
|
||||
index=0,
|
||||
prefix=(
|
||||
"model.layers.0" if not prefix else "{prefix}.model.layers.0"
|
||||
),
|
||||
config=config,
|
||||
weights=weights,
|
||||
)
|
||||
)
|
||||
|
||||
self.layers.extend(
|
||||
[
|
||||
FlashLlamaLayer(
|
||||
index=layer_id,
|
||||
|
@ -430,9 +443,26 @@ class FlashLlamaModel(torch.nn.Module):
|
|||
config=config,
|
||||
weights=weights,
|
||||
)
|
||||
for layer_id in range(config.num_hidden_layers)
|
||||
# Skip first and last layers
|
||||
for layer_id in range(1, config.num_hidden_layers - 1)
|
||||
]
|
||||
)
|
||||
|
||||
with no_fp8(weights):
|
||||
last_layer_id = config.num_hidden_layers - 1
|
||||
self.layers.append(
|
||||
FlashLlamaLayer(
|
||||
index=last_layer_id,
|
||||
prefix=(
|
||||
f"model.layers.{last_layer_id}"
|
||||
if not prefix
|
||||
else f"{prefix}.model.layers.{last_layer_id}"
|
||||
),
|
||||
config=config,
|
||||
weights=weights,
|
||||
)
|
||||
)
|
||||
|
||||
self.norm = FastRMSNorm.load(
|
||||
prefix="model.norm" if not prefix else f"{prefix}.model.norm",
|
||||
weights=weights,
|
||||
|
|
|
@ -23,14 +23,13 @@ from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE
|
|||
from text_generation_server.utils.chunks import concat_text_chunks
|
||||
from text_generation_server.utils.import_utils import SYSTEM
|
||||
from text_generation_server.models import Model
|
||||
from text_generation_server.utils.log import log_master
|
||||
from text_generation_server.utils.tokens import batch_top_tokens
|
||||
from text_generation_server.utils.dist import RANK
|
||||
from text_generation_server.utils.speculate import get_speculate
|
||||
from text_generation_server.utils import (
|
||||
initialize_torch_distributed,
|
||||
weight_files,
|
||||
Weights,
|
||||
hub,
|
||||
)
|
||||
from text_generation_server.models.types import (
|
||||
Batch,
|
||||
|
@ -1156,31 +1155,36 @@ class FlashCausalLM(Model):
|
|||
f"tunableop_{MODEL_ID.replace('/', '-')}_tp{self.world_size}_rank{self.rank}.csv",
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"PyTorch TunableOp (https://github.com/fxmarty/pytorch/tree/2.3-patched/aten/src/ATen/cuda/tunable) is enabled. The warmup may take several minutes, picking the ROCm optimal matrix multiplication kernel for the target lengths {', '.join([str(seqlen) for seqlen in tuning_sequences])}, with typical 5-8% latency improvement for small sequence lengths. The picked GEMMs are saved in the file {tunableop_filepath}. To disable TunableOp, please launch TGI with `PYTORCH_TUNABLEOP_ENABLED=0`."
|
||||
log_master(
|
||||
logger.info,
|
||||
f"PyTorch TunableOp (https://github.com/fxmarty/pytorch/tree/2.3-patched/aten/src/ATen/cuda/tunable) is enabled. The warmup may take several minutes, picking the ROCm optimal matrix multiplication kernel for the target lengths {', '.join([str(seqlen) for seqlen in tuning_sequences])}, with typical 5-8% latency improvement for small sequence lengths. The picked GEMMs are saved in the file {tunableop_filepath}. To disable TunableOp, please launch TGI with `PYTORCH_TUNABLEOP_ENABLED=0`.",
|
||||
)
|
||||
|
||||
if os.path.isfile(tunableop_filepath):
|
||||
logger.info(
|
||||
f"The file {tunableop_filepath} already exists and will be reused."
|
||||
log_master(
|
||||
logger.info,
|
||||
f"The file {tunableop_filepath} already exists and will be reused.",
|
||||
)
|
||||
torch.cuda.tunable.read_file(tunableop_filepath)
|
||||
|
||||
os.makedirs(HUGGINGFACE_HUB_CACHE, exist_ok=True)
|
||||
|
||||
for seqlen in tuning_sequences:
|
||||
logger.info(f"Warming up TunableOp for seqlen={seqlen}")
|
||||
log_master(logger.info, f"Warming up TunableOp for seqlen={seqlen}")
|
||||
self.tunableop_warmup(seqlen)
|
||||
torch.cuda.tunable.write_file(tunableop_filepath)
|
||||
torch.cuda.tunable.tuning_enable(False)
|
||||
else:
|
||||
logger.info(
|
||||
"PyTorch ROCm TunableOp (https://github.com/pytorch/pytorch/tree/main/aten/src/ATen/cuda/tunable) is disabled. TunableOp brings an additional 5-8% latency improvement for small sequence lengths but requires a warmup. If necessary, please use the environment variable PYTORCH_TUNABLEOP_ENABLED=1 to enable TunableOp."
|
||||
log_master(
|
||||
logger.info,
|
||||
"PyTorch ROCm TunableOp (https://github.com/pytorch/pytorch/tree/main/aten/src/ATen/cuda/tunable) is disabled. TunableOp brings an additional 5-8% latency improvement for small sequence lengths but requires a warmup. If necessary, please use the environment variable PYTORCH_TUNABLEOP_ENABLED=1 to enable TunableOp.",
|
||||
)
|
||||
|
||||
if CUDA_GRAPHS:
|
||||
try:
|
||||
logger.info(f"Cuda Graphs are enabled for sizes {CUDA_GRAPHS}")
|
||||
log_master(
|
||||
logger.info, f"Cuda Graphs are enabled for sizes {CUDA_GRAPHS}"
|
||||
)
|
||||
# Warmup cuda graphs
|
||||
for bs in CUDA_GRAPHS:
|
||||
if self.speculate is None or self.speculate + 1 <= bs:
|
||||
|
@ -1188,7 +1192,9 @@ class FlashCausalLM(Model):
|
|||
except torch.cuda.OutOfMemoryError:
|
||||
logger.exception(f"Decode cuda graph warmup failed")
|
||||
else:
|
||||
logger.info(f"Cuda Graphs are disabled (CUDA_GRAPHS={CUDA_GRAPHS}).")
|
||||
log_master(
|
||||
logger.info, f"Cuda Graphs are disabled (CUDA_GRAPHS={CUDA_GRAPHS})."
|
||||
)
|
||||
|
||||
return int(num_blocks * BLOCK_SIZE)
|
||||
|
||||
|
@ -1540,8 +1546,7 @@ class FlashCausalLM(Model):
|
|||
left = 0
|
||||
|
||||
if n_accepted_ids > 1:
|
||||
if RANK == 0:
|
||||
logger.debug(f"Speculated ids {n_accepted_ids - 1}")
|
||||
log_master(logger.debug, f"Speculated ids {n_accepted_ids - 1}")
|
||||
|
||||
current_stopped = False
|
||||
for j in range(index, index + n_accepted_ids):
|
||||
|
|
|
@ -1,15 +1,16 @@
|
|||
import torch
|
||||
import os
|
||||
from loguru import logger
|
||||
from typing import Dict
|
||||
from typing import Dict, Optional
|
||||
|
||||
from text_generation_server.utils.log import log_master
|
||||
|
||||
MEM_POOL = torch.cuda.graph_pool_handle() if torch.cuda.is_available() else None
|
||||
# This is overridden by the cli
|
||||
FLASH_DECODING = os.getenv("FLASH_DECODING") in {"1", "true", "True"}
|
||||
BLOCK_SIZE: int = 256 if FLASH_DECODING else 16
|
||||
if FLASH_DECODING:
|
||||
logger.info("Using FLASH_DECODING")
|
||||
|
||||
log_master(logger.info, "Using FLASH_DECODING")
|
||||
|
||||
cuda_graphs = os.getenv("CUDA_GRAPHS")
|
||||
if cuda_graphs is not None:
|
||||
|
@ -26,11 +27,9 @@ else:
|
|||
if cuda_graphs is not None:
|
||||
cuda_graphs.sort(reverse=True)
|
||||
|
||||
|
||||
CUDA_GRAPHS = cuda_graphs
|
||||
|
||||
# This is overridden at model loading.
|
||||
global MODEL_ID
|
||||
MODEL_ID = None
|
||||
|
||||
|
||||
|
@ -41,8 +40,7 @@ def set_model_id(model_id: str):
|
|||
|
||||
# NOTE: eventually we should move this into the router and pass back the
|
||||
# index in all cases.
|
||||
global ADAPTER_TO_INDEX
|
||||
ADAPTER_TO_INDEX: Dict[str, int] = None
|
||||
ADAPTER_TO_INDEX: Optional[Dict[str, int]] = None
|
||||
|
||||
|
||||
def set_adapter_to_index(adapter_to_index: Dict[str, int]):
|
||||
|
|
|
@ -15,6 +15,7 @@ from text_generation_server.utils.adapter import (
|
|||
AdapterParameters,
|
||||
AdapterSource,
|
||||
)
|
||||
from text_generation_server.utils.log import log_master
|
||||
from loguru import logger
|
||||
|
||||
|
||||
|
@ -204,8 +205,9 @@ class Model(ABC):
|
|||
f"order to use the dynamic adapter loading feature."
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Loading adapter weights into model: {','.join(adapter_parameters.adapter_ids)}"
|
||||
log_master(
|
||||
logger.info,
|
||||
f"Loading adapter weights into model: {','.join(adapter_parameters.adapter_ids)}",
|
||||
)
|
||||
weight_names = tuple([v[0] for v in self.target_to_layer.values()])
|
||||
(
|
||||
|
@ -240,8 +242,9 @@ class Model(ABC):
|
|||
layer_weights.add_adapter(adapter_index, adapter_weights)
|
||||
|
||||
if len(unused_weight_names) > 0:
|
||||
logger.warning(
|
||||
f"{','.join(adapter_parameters.adapter_ids)} unused adapter weights: {unused_weight_names}"
|
||||
log_master(
|
||||
logger.warning,
|
||||
f"{','.join(adapter_parameters.adapter_ids)} unused adapter weights: {unused_weight_names}",
|
||||
)
|
||||
|
||||
if adapter_tokenizer is not None:
|
||||
|
|
|
@ -1,4 +1,3 @@
|
|||
from itertools import repeat
|
||||
import torch
|
||||
from PIL import Image
|
||||
from io import BytesIO
|
||||
|
@ -13,6 +12,7 @@ from text_generation_server.models.flash_causal_lm import (
|
|||
FlashCausalLMBatch,
|
||||
FlashCausalLM,
|
||||
)
|
||||
from text_generation_server.utils.log import log_master
|
||||
from transformers import AutoProcessor
|
||||
|
||||
tracer = trace.get_tracer(__name__)
|
||||
|
@ -56,8 +56,9 @@ def image_text_replacement(processor, image_input, config, image_id: int) -> str
|
|||
num_features = get_number_of_features(height, width, config)
|
||||
from loguru import logger
|
||||
|
||||
logger.info(
|
||||
f"Found {num_features} features in image of resolution {height}x{width}"
|
||||
log_master(
|
||||
logger.info,
|
||||
f"Found {num_features} features in image of resolution {height}x{width}",
|
||||
)
|
||||
return "<image>" * num_features
|
||||
|
||||
|
|
|
@ -56,7 +56,7 @@ def initialize_torch_distributed():
|
|||
backend = "nccl"
|
||||
options = ProcessGroupNCCL.Options()
|
||||
options.is_high_priority_stream = True
|
||||
options._timeout = timedelta(seconds=60)
|
||||
options._timeout = timedelta(seconds=120)
|
||||
else:
|
||||
backend = "gloo"
|
||||
options = None
|
||||
|
@ -76,7 +76,7 @@ def initialize_torch_distributed():
|
|||
backend="ccl",
|
||||
world_size=WORLD_SIZE,
|
||||
rank=RANK,
|
||||
timeout=timedelta(seconds=60),
|
||||
timeout=timedelta(seconds=120),
|
||||
pg_options=options,
|
||||
)
|
||||
else:
|
||||
|
@ -84,7 +84,7 @@ def initialize_torch_distributed():
|
|||
backend=backend,
|
||||
world_size=WORLD_SIZE,
|
||||
rank=RANK,
|
||||
timeout=timedelta(seconds=60),
|
||||
timeout=timedelta(seconds=120),
|
||||
pg_options=options,
|
||||
)
|
||||
else:
|
||||
|
|
|
@ -1,6 +1,15 @@
|
|||
from functools import lru_cache
|
||||
from text_generation_server.utils.dist import RANK
|
||||
|
||||
|
||||
@lru_cache(10)
|
||||
def log_once(log, msg: str):
|
||||
log(msg)
|
||||
def log_once(log, msg: str, master=True):
|
||||
if master:
|
||||
log_master(log, msg)
|
||||
else:
|
||||
log(msg)
|
||||
|
||||
|
||||
def log_master(log, msg: str):
|
||||
if RANK == 0:
|
||||
log(msg)
|
||||
|
|
|
@ -11,6 +11,7 @@ from text_generation_server.utils.weights import (
|
|||
)
|
||||
|
||||
|
||||
# TODO: Split this config to have a single config type per quant method
|
||||
@dataclass
|
||||
class _QuantizerConfig:
|
||||
bits: int
|
||||
|
@ -21,6 +22,11 @@ class _QuantizerConfig:
|
|||
sym: bool
|
||||
|
||||
|
||||
@dataclass
|
||||
class _FP8QuantizerConfig:
|
||||
activation_scale_ub: float
|
||||
|
||||
|
||||
# We should probably do this with Pytantic JSON deserialization,
|
||||
# but for now we'll stay close to the old _set_gptq_params.
|
||||
def _get_quantizer_config(model_id, revision):
|
||||
|
@ -39,6 +45,13 @@ def _get_quantizer_config(model_id, revision):
|
|||
filename = hf_hub_download(model_id, filename=filename, revision=revision)
|
||||
with open(filename, "r") as f:
|
||||
data = json.load(f)
|
||||
|
||||
# FP8 config
|
||||
if data["quantization_config"]["quant_method"] == "fbgemm_fp8":
|
||||
return _FP8QuantizerConfig(
|
||||
activation_scale_ub=data["quantization_config"]["activation_scale_ub"]
|
||||
)
|
||||
|
||||
bits = data["quantization_config"]["bits"]
|
||||
groupsize = data["quantization_config"]["group_size"]
|
||||
# Order is important here, desc_act is missing on some real models
|
||||
|
@ -99,6 +112,12 @@ def get_loader(
|
|||
if quantize in {"awq", "gptq"}:
|
||||
from text_generation_server.layers.gptq import GPTQWeightsLoader
|
||||
|
||||
# TODO: improve check once we have one config type per quantize value
|
||||
if not isinstance(quantizer_config, _QuantizerConfig):
|
||||
raise ValueError(
|
||||
f"Quantize is set to `{quantize}` but received a `{quantizer_config.__class__.__name__}` config."
|
||||
)
|
||||
|
||||
return GPTQWeightsLoader(
|
||||
bits=quantizer_config.bits,
|
||||
desc_act=quantizer_config.desc_act,
|
||||
|
@ -127,18 +146,28 @@ def get_loader(
|
|||
from text_generation_server.layers.exl2 import Exl2WeightsLoader
|
||||
|
||||
return Exl2WeightsLoader()
|
||||
elif quantize == "fp8":
|
||||
from text_generation_server.layers.fp8 import Fp8Weight
|
||||
|
||||
return DefaultWeightsLoader(Fp8Weight)
|
||||
elif quantize == "marlin":
|
||||
from text_generation_server.layers.marlin import MarlinWeightsLoader
|
||||
|
||||
# TODO: improve check once we have one config type per quantize value
|
||||
if not isinstance(quantizer_config, _QuantizerConfig):
|
||||
raise ValueError(
|
||||
f"Quantize is set to `{quantize}` but received a `{quantizer_config.__class__.__name__}` config."
|
||||
)
|
||||
|
||||
return MarlinWeightsLoader(
|
||||
bits=quantizer_config.bits,
|
||||
is_marlin_24=quantizer_config.checkpoint_format == "marlin_24",
|
||||
)
|
||||
elif quantize is None:
|
||||
return DefaultWeightsLoader(UnquantizedWeight)
|
||||
elif quantize == "fp8" or quantize is None:
|
||||
from text_generation_server.layers.fp8 import HybridFP8UnquantLoader
|
||||
|
||||
# Since the default for the quantize config is _QuantizerConfig,
|
||||
# we need to add this check to not get an attribute error
|
||||
activation_scale_ub = None
|
||||
if isinstance(quantizer_config, _FP8QuantizerConfig):
|
||||
activation_scale_ub = quantizer_config.activation_scale_ub
|
||||
|
||||
return HybridFP8UnquantLoader(activation_scale_ub, to_fp8=quantize == "fp8")
|
||||
else:
|
||||
raise ValueError(f"Unknown quantization method: {quantize}")
|
||||
|
|
|
@ -1,12 +1,12 @@
|
|||
import torch
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from contextlib import contextmanager
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum, auto
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Union
|
||||
|
||||
import torch
|
||||
from typing import Dict, List, Optional, Union, Type
|
||||
from safetensors import safe_open
|
||||
from dataclasses import dataclass
|
||||
|
||||
from text_generation_server.utils.import_utils import SYSTEM
|
||||
|
||||
|
||||
|
@ -84,7 +84,7 @@ class Weight(ABC):
|
|||
|
||||
|
||||
@dataclass
|
||||
class UnquantizedWeight:
|
||||
class UnquantizedWeight(Weight):
|
||||
weight: torch.Tensor
|
||||
|
||||
def get_linear(self, bias: torch.Tensor):
|
||||
|
@ -99,7 +99,7 @@ class UnquantizedWeight:
|
|||
class DefaultWeightsLoader(WeightsLoader):
|
||||
"""Weight loader that loads (unquantized) Torch tensors."""
|
||||
|
||||
def __init__(self, weight_class):
|
||||
def __init__(self, weight_class: Type[UnquantizedWeight]):
|
||||
"""Create a loader. Weights will be wrapped using the given `weights_class`,
|
||||
normally this will be `UnquantizedWeight`, but a quantizer-specific class
|
||||
such as `Fp8Weight` can be used to quantize the weights during loading.
|
||||
|
@ -208,20 +208,29 @@ class Weights:
|
|||
def get_shape(self, tensor_name: str):
|
||||
return self._get_slice(tensor_name).get_shape()
|
||||
|
||||
def get_tensor(self, tensor_name: str, to_device=True):
|
||||
def get_tensor(self, tensor_name: str, to_device=True, to_dtype=True):
|
||||
filename, tensor_name = self.get_filename(tensor_name)
|
||||
f = self._get_handle(filename)
|
||||
tensor = f.get_tensor(tensor_name)
|
||||
# Special case for gptq which shouldn't convert
|
||||
# u4 which are disguised as int32. Exl2 uses int16
|
||||
# as well.
|
||||
if tensor.dtype not in [torch.int16, torch.int32, torch.int64]:
|
||||
# as well. FP8 uses torch.float8_e4m3fn
|
||||
if (
|
||||
tensor.dtype
|
||||
not in [
|
||||
torch.float8_e4m3fn,
|
||||
torch.int16,
|
||||
torch.int32,
|
||||
torch.int64,
|
||||
]
|
||||
and to_dtype
|
||||
):
|
||||
tensor = tensor.to(dtype=self.dtype)
|
||||
if to_device:
|
||||
tensor = tensor.to(device=self.device)
|
||||
return tensor
|
||||
|
||||
def get_partial_sharded(self, tensor_name: str, dim: int):
|
||||
def get_partial_sharded(self, tensor_name: str, dim: int, to_dtype=True):
|
||||
filename, tensor_name = self.get_filename(tensor_name)
|
||||
f = self._get_handle(filename)
|
||||
slice_ = f.get_slice(tensor_name)
|
||||
|
@ -241,12 +250,16 @@ class Weights:
|
|||
raise NotImplementedError("Let's make that generic when needed")
|
||||
# Special case for gptq which shouldn't convert
|
||||
# u4 which are disguised as int32. exl2 uses int16.
|
||||
if tensor.dtype not in (torch.int16, torch.int32):
|
||||
# FP8 uses torch.float8_e4m3fn.
|
||||
if (
|
||||
tensor.dtype not in (torch.float8_e4m3fn, torch.int16, torch.int32)
|
||||
and to_dtype
|
||||
):
|
||||
tensor = tensor.to(dtype=self.dtype)
|
||||
tensor = tensor.to(device=self.device)
|
||||
return tensor
|
||||
|
||||
def get_sharded(self, tensor_name: str, dim: int):
|
||||
def get_sharded(self, tensor_name: str, dim: int, to_dtype=True):
|
||||
filename, tensor_name = self.get_filename(tensor_name)
|
||||
f = self._get_handle(filename)
|
||||
slice_ = f.get_slice(tensor_name)
|
||||
|
@ -255,10 +268,14 @@ class Weights:
|
|||
assert (
|
||||
size % world_size == 0
|
||||
), f"The choosen size {size} is not compatible with sharding on {world_size} shards"
|
||||
return self.get_partial_sharded(tensor_name, dim)
|
||||
return self.get_partial_sharded(tensor_name, dim, to_dtype=to_dtype)
|
||||
|
||||
def get_packed_sharded(
|
||||
self, tensor_name: str, dim: int, block_sizes: Union[int, List[int]]
|
||||
self,
|
||||
tensor_name: str,
|
||||
dim: int,
|
||||
block_sizes: Union[int, List[int]],
|
||||
to_dtype=True,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Get a shard from a tensor that packs multiple tensors.
|
||||
|
@ -304,7 +321,16 @@ class Weights:
|
|||
tensor = tensor.to(device=self.device)
|
||||
|
||||
# Avoid casting quantizer dtypes.
|
||||
if tensor.dtype not in [torch.int16, torch.int32, torch.int64]:
|
||||
if (
|
||||
tensor.dtype
|
||||
not in [
|
||||
torch.float8_e4m3fn,
|
||||
torch.int16,
|
||||
torch.int32,
|
||||
torch.int64,
|
||||
]
|
||||
and to_dtype
|
||||
):
|
||||
tensor = tensor.to(dtype=self.dtype)
|
||||
|
||||
return tensor
|
||||
|
|
Loading…
Reference in New Issue